Skip to content
151 changes: 118 additions & 33 deletions backend/plugins/q_dev/tasks/s3_file_collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,87 +28,172 @@ import (
"github.com/aws/aws-sdk-go/service/s3"
)

var _ plugin.SubTaskEntryPoint = CollectQDevS3Files

// CollectQDevS3Files 收集S3文件元数据
func CollectQDevS3Files(taskCtx plugin.SubTaskContext) errors.Error {
data := taskCtx.GetData().(*QDevTaskData)
db := taskCtx.GetDal()

// 列出指定前缀下的所有对象
var continuationToken *string
prefix := data.Options.S3Prefix
// normalizeS3Prefix ensures a non-empty prefix ends with exactly one trailing
// "/" so that the S3 listing matches whole "directory" components only.
// An empty prefix is returned unchanged (which lists the entire bucket).
//
// Note: the previous revision both mutated prefix and appended "/" to the
// return value, producing a double slash; this returns the suffix once.
func normalizeS3Prefix(prefix string) string {
	if prefix != "" && !strings.HasSuffix(prefix, "/") {
		return prefix + "/"
	}
	return prefix
}

taskCtx.SetProgress(0, -1)
// isCSVFile reports whether the S3 object key names a CSV file, i.e. the key
// ends with the ".csv" extension. The comparison is case-sensitive: keys such
// as "data.CSV" are deliberately not matched (pinned by the unit tests).
func isCSVFile(key string) bool {
	const csvExt = ".csv"
	return len(key) >= len(csvExt) && key[len(key)-len(csvExt):] == csvExt
}

// S3ListObjectsFunc returns one page of S3 objects. It mirrors the AWS SDK's
// ListObjectsV2 signature so the real client method can be passed through
// directly, or replaced by an in-memory stub in tests.
type S3ListObjectsFunc func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error)

// FindFileMetaFunc looks up previously collected file metadata by connection
// id and S3 object key. A non-nil error is interpreted by the collector as
// "no existing record for this key".
type FindFileMetaFunc func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error)

// SaveFileMetaFunc persists the metadata record of a newly discovered file.
type SaveFileMetaFunc func(fileMeta *models.QDevS3FileMeta) error

// ProgressFunc reports that `increment` additional files have been recorded.
type ProgressFunc func(increment int)

// LogFunc emits a printf-style debug log line.
type LogFunc func(format string, args ...interface{})

// collectS3FilesCore contains the core logic for collecting S3 files
func collectS3FilesCore(
bucket, prefix string,
connectionId uint64,
listObjects S3ListObjectsFunc,
findFileMeta FindFileMetaFunc,
saveFileMeta SaveFileMetaFunc,
progress ProgressFunc,
logDebug LogFunc,
) error {
// List all objects under the specified prefix
var continuationToken *string
normalizedPrefix := normalizeS3Prefix(prefix)
csvFilesFound := 0

for {
input := &s3.ListObjectsV2Input{
Bucket: aws.String(data.S3Client.Bucket),
Prefix: aws.String(prefix),
Bucket: aws.String(bucket),
Prefix: aws.String(normalizedPrefix),
ContinuationToken: continuationToken,
}

result, err := data.S3Client.S3.ListObjectsV2(input)
result, err := listObjects(input)
if err != nil {
return errors.Convert(err)
return err
}

// 处理每个CSV文件
// Process each CSV file
for _, object := range result.Contents {
// Only process CSV files
if !strings.HasSuffix(*object.Key, ".csv") {
taskCtx.GetLogger().Debug("Skipping non-CSV file: %s", *object.Key)
if !isCSVFile(*object.Key) {
logDebug("Skipping non-CSV file: %s", *object.Key)
continue
}

// Check if this file already exists in our database
existingFile := &models.QDevS3FileMeta{}
err = db.First(existingFile, dal.Where("connection_id = ? AND s3_path = ?",
data.Options.ConnectionId, *object.Key))
csvFilesFound++

// Check if this file already exists in our database
existingFile, err := findFileMeta(connectionId, *object.Key)
if err == nil {
// File already exists in database, skip it if it's already processed
if existingFile.Processed {
taskCtx.GetLogger().Debug("Skipping already processed file: %s", *object.Key)
logDebug("Skipping already processed file: %s", *object.Key)
continue
}
// Otherwise, we'll keep the existing record (which is still marked as unprocessed)
taskCtx.GetLogger().Debug("Found existing unprocessed file: %s", *object.Key)
logDebug("Found existing unprocessed file: %s", *object.Key)
continue
} else if !db.IsErrorNotFound(err) {
return errors.Default.Wrap(err, "failed to query existing file metadata")
}

// This is a new file, save its metadata
fileMeta := &models.QDevS3FileMeta{
ConnectionId: data.Options.ConnectionId,
ConnectionId: connectionId,
FileName: *object.Key,
S3Path: *object.Key,
Processed: false,
}

err = db.Create(fileMeta)
if err != nil {
return errors.Default.Wrap(err, "failed to create file metadata")
if err := saveFileMeta(fileMeta); err != nil {
return err
}

taskCtx.IncProgress(1)
progress(1)
}

// 如果没有更多对象,退出循环
// If there are no more objects, exit the loop
if !*result.IsTruncated {
break
}

continuationToken = result.NextContinuationToken
}

// Check if no CSV files were found
if csvFilesFound == 0 {
return errors.BadInput.New("no CSV files found in S3 path. Please verify the S3 bucket and prefix configuration")
}

return nil
}

var _ plugin.SubTaskEntryPoint = CollectQDevS3Files

// CollectQDevS3Files collects S3 file metadata for the Q Dev plugin.
// It adapts the task context's S3 client, DAL, progress counter and logger
// into the plain callbacks expected by collectS3FilesCore.
func CollectQDevS3Files(taskCtx plugin.SubTaskContext) errors.Error {
	data := taskCtx.GetData().(*QDevTaskData)
	db := taskCtx.GetDal()

	taskCtx.SetProgress(0, -1)

	coreErr := collectS3FilesCore(
		data.S3Client.Bucket,
		data.Options.S3Prefix,
		data.Options.ConnectionId,
		// listObjects: delegate paging straight to the AWS SDK client.
		func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
			return data.S3Client.S3.ListObjectsV2(input)
		},
		// findFileMeta: fetch the existing record for (connection, path).
		// A not-found error is passed through unchanged so the core treats
		// the file as new; any other DB error is wrapped with context.
		func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
			fileMeta := &models.QDevS3FileMeta{}
			dbErr := db.First(fileMeta, dal.Where("connection_id = ? AND s3_path = ?", connectionId, s3Path))
			if dbErr == nil {
				return fileMeta, nil
			}
			if db.IsErrorNotFound(dbErr) {
				return nil, dbErr
			}
			return nil, errors.Default.Wrap(dbErr, "failed to query existing file metadata")
		},
		// saveFileMeta: insert a freshly discovered file record.
		func(fileMeta *models.QDevS3FileMeta) error {
			if createErr := db.Create(fileMeta); createErr != nil {
				return errors.Default.Wrap(createErr, "failed to create file metadata")
			}
			return nil
		},
		// progress and debug-log adapters onto the task context.
		func(increment int) {
			taskCtx.IncProgress(increment)
		},
		func(format string, args ...interface{}) {
			taskCtx.GetLogger().Debug(format, args...)
		},
	)

	return errors.Convert(coreErr)
}

var CollectQDevS3FilesMeta = plugin.SubTaskMeta{
Name: "collectQDevS3Files",
EntryPoint: CollectQDevS3Files,
Expand Down
173 changes: 173 additions & 0 deletions backend/plugins/q_dev/tasks/s3_file_collector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tasks

import (
"errors"
"testing"

"github.com/apache/incubator-devlake/plugins/q_dev/models"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
)

func TestNormalizeS3Prefix(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"", ""},
{"prefix", "prefix/"},
{"prefix/", "prefix/"},
{"path/to/folder", "path/to/folder/"},
{"path/to/folder/", "path/to/folder/"},
}

for _, test := range tests {
result := normalizeS3Prefix(test.input)
assert.Equal(t, test.expected, result)
}
}

func TestIsCSVFile(t *testing.T) {
tests := []struct {
key string
expected bool
}{
{"file.csv", true},
{"data.CSV", false},
{"report.csv", true},
{"document.txt", false},
{"path/to/file.csv", true},
{"file.csv.backup", false},
{"", false},
}

for _, test := range tests {
result := isCSVFile(test.key)
assert.Equal(t, test.expected, result)
}
}

// TestCollectS3FilesCore_Success drives the collector with a single listing
// page of mixed objects and checks that only the CSV keys are persisted,
// progress advances once per new CSV, and non-CSV keys are debug-logged.
func TestCollectS3FilesCore_Success(t *testing.T) {
	mockList := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
		return &s3.ListObjectsV2Output{
			Contents: []*s3.Object{
				{Key: aws.String("file1.csv")},
				{Key: aws.String("file2.txt")},
				{Key: aws.String("data.csv")},
			},
			IsTruncated: aws.Bool(false),
		}, nil
	}

	// Every lookup misses, so each CSV is treated as new.
	mockFind := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
		return nil, errors.New("not found")
	}

	saved := []string{}
	mockSave := func(fileMeta *models.QDevS3FileMeta) error {
		saved = append(saved, fileMeta.S3Path)
		return nil
	}

	totalProgress := 0
	mockProgress := func(increment int) {
		totalProgress += increment
	}

	debugFormats := []string{}
	mockLog := func(format string, args ...interface{}) {
		debugFormats = append(debugFormats, format)
	}

	err := collectS3FilesCore("bucket", "prefix", 1, mockList, mockFind, mockSave, mockProgress, mockLog)

	assert.NoError(t, err)
	assert.Len(t, saved, 2)
	assert.Contains(t, saved, "file1.csv")
	assert.Contains(t, saved, "data.csv")
	assert.Equal(t, 2, totalProgress)
	assert.Contains(t, debugFormats, "Skipping non-CSV file: %s")
}

// TestCollectS3FilesCore_NoCSVFiles verifies that a listing containing no
// CSV objects is surfaced as a configuration error.
func TestCollectS3FilesCore_NoCSVFiles(t *testing.T) {
	mockList := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
		return &s3.ListObjectsV2Output{
			Contents: []*s3.Object{
				{Key: aws.String("file1.txt")},
				{Key: aws.String("file2.json")},
			},
			IsTruncated: aws.Bool(false),
		}, nil
	}

	err := collectS3FilesCore(
		"bucket", "prefix", 1,
		mockList,
		func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
			return nil, errors.New("not found")
		},
		func(fileMeta *models.QDevS3FileMeta) error { return nil },
		func(increment int) {},
		func(format string, args ...interface{}) {},
	)

	assert.Error(t, err)
	assert.Contains(t, err.Error(), "no CSV files found")
}

// TestCollectS3FilesCore_SkipProcessedFiles verifies that keys already known
// to the database — whether processed or still pending — never trigger a new
// metadata insert.
func TestCollectS3FilesCore_SkipProcessedFiles(t *testing.T) {
	mockList := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
		return &s3.ListObjectsV2Output{
			Contents: []*s3.Object{
				{Key: aws.String("processed.csv")},
				{Key: aws.String("unprocessed.csv")},
			},
			IsTruncated: aws.Bool(false),
		}, nil
	}

	// Both keys already exist; only their Processed flag differs.
	mockFind := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
		switch s3Path {
		case "processed.csv":
			return &models.QDevS3FileMeta{Processed: true}, nil
		case "unprocessed.csv":
			return &models.QDevS3FileMeta{Processed: false}, nil
		default:
			return nil, errors.New("not found")
		}
	}

	saved := []string{}
	mockSave := func(fileMeta *models.QDevS3FileMeta) error {
		saved = append(saved, fileMeta.S3Path)
		return nil
	}

	err := collectS3FilesCore("bucket", "prefix", 1, mockList, mockFind, mockSave,
		func(increment int) {}, func(format string, args ...interface{}) {})

	assert.NoError(t, err)
	assert.Empty(t, saved) // neither key is new, so nothing is created
}