diff --git a/backend/plugins/q_dev/tasks/s3_file_collector.go b/backend/plugins/q_dev/tasks/s3_file_collector.go
index 37abf184563..67a44cd33c7 100644
--- a/backend/plugins/q_dev/tasks/s3_file_collector.go
+++ b/backend/plugins/q_dev/tasks/s3_file_collector.go
@@ -28,77 +28,100 @@ import (
 	"github.com/aws/aws-sdk-go/service/s3"
 )
 
-var _ plugin.SubTaskEntryPoint = CollectQDevS3Files
-
-// CollectQDevS3Files 收集S3文件元数据
-func CollectQDevS3Files(taskCtx plugin.SubTaskContext) errors.Error {
-	data := taskCtx.GetData().(*QDevTaskData)
-	db := taskCtx.GetDal()
-
-	// 列出指定前缀下的所有对象
-	var continuationToken *string
-	prefix := data.Options.S3Prefix
+// normalizeS3Prefix ensures the prefix ends with "/" if it's not empty
+func normalizeS3Prefix(prefix string) string {
 	if prefix != "" && !strings.HasSuffix(prefix, "/") {
-		prefix = prefix + "/"
+		return prefix + "/"
 	}
+	return prefix
+}
 
-	taskCtx.SetProgress(0, -1)
+// isCSVFile checks if the given S3 object key represents a CSV file
+func isCSVFile(key string) bool {
+	return strings.HasSuffix(key, ".csv")
+}
+
+// S3ListObjectsFunc defines the callback for listing S3 objects
+type S3ListObjectsFunc func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error)
+
+// FindFileMetaFunc defines the callback for finding existing file metadata.
+// It returns (nil, nil) when no record exists; a non-nil error means the
+// lookup itself failed.
+type FindFileMetaFunc func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error)
+
+// SaveFileMetaFunc defines the callback for saving file metadata
+type SaveFileMetaFunc func(fileMeta *models.QDevS3FileMeta) error
+
+// ProgressFunc defines the callback for progress tracking
+type ProgressFunc func(increment int)
+
+// LogFunc defines the callback for logging
+type LogFunc func(format string, args ...interface{})
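+
+// These callback types are the seams that make the collector testable:
+// CollectQDevS3Files wires them to the real S3 client and DAL, while the
+// unit tests substitute plain closures with no AWS or database involved.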
+
+// collectS3FilesCore contains the core logic for collecting S3 files
+func collectS3FilesCore(
+	bucket, prefix string,
+	connectionId uint64,
+	listObjects S3ListObjectsFunc,
+	findFileMeta FindFileMetaFunc,
+	saveFileMeta SaveFileMetaFunc,
+	progress ProgressFunc,
+	logDebug LogFunc,
+) error {
+	// List all objects under the specified prefix
+	var continuationToken *string
+	normalizedPrefix := normalizeS3Prefix(prefix)
+	csvFilesFound := 0
 
 	for {
 		input := &s3.ListObjectsV2Input{
-			Bucket:            aws.String(data.S3Client.Bucket),
-			Prefix:            aws.String(prefix),
+			Bucket:            aws.String(bucket),
+			Prefix:            aws.String(normalizedPrefix),
 			ContinuationToken: continuationToken,
 		}
 
-		result, err := data.S3Client.S3.ListObjectsV2(input)
+		result, err := listObjects(input)
 		if err != nil {
-			return errors.Convert(err)
+			return err
 		}
 
-		// 处理每个CSV文件
+		// Process each CSV file
 		for _, object := range result.Contents {
 			// Only process CSV files
-			if !strings.HasSuffix(*object.Key, ".csv") {
-				taskCtx.GetLogger().Debug("Skipping non-CSV file: %s", *object.Key)
+			if !isCSVFile(*object.Key) {
+				logDebug("Skipping non-CSV file: %s", *object.Key)
 				continue
 			}
 
-			// Check if this file already exists in our database
-			existingFile := &models.QDevS3FileMeta{}
-			err = db.First(existingFile, dal.Where("connection_id = ? AND s3_path = ?",
-				data.Options.ConnectionId, *object.Key))
-			if err == nil {
+			csvFilesFound++
+
+			// Check if this file already exists in our database;
+			// a lookup failure must surface instead of being treated as a new file
+			existingFile, err := findFileMeta(connectionId, *object.Key)
+			if err != nil {
+				return err
+			}
+			if existingFile != nil {
 				// File already exists in database, skip it if it's already processed
 				if existingFile.Processed {
-					taskCtx.GetLogger().Debug("Skipping already processed file: %s", *object.Key)
+					logDebug("Skipping already processed file: %s", *object.Key)
 					continue
 				}
 				// Otherwise, we'll keep the existing record (which is still marked as unprocessed)
-				taskCtx.GetLogger().Debug("Found existing unprocessed file: %s", *object.Key)
+				logDebug("Found existing unprocessed file: %s", *object.Key)
 				continue
-			} else if !db.IsErrorNotFound(err) {
-				return errors.Default.Wrap(err, "failed to query existing file metadata")
 			}
 
 			// This is a new file, save its metadata
 			fileMeta := &models.QDevS3FileMeta{
-				ConnectionId: data.Options.ConnectionId,
+				ConnectionId: connectionId,
 				FileName:     *object.Key,
 				S3Path:       *object.Key,
 				Processed:    false,
 			}
-			err = db.Create(fileMeta)
-			if err != nil {
-				return errors.Default.Wrap(err, "failed to create file metadata")
+			if err := saveFileMeta(fileMeta); err != nil {
+				return err
 			}
 
-			taskCtx.IncProgress(1)
+			progress(1)
 		}
 
-		// 如果没有更多对象,退出循环
+		// If there are no more objects, exit the loop
 		if !*result.IsTruncated {
 			break
 		}
@@ -106,9 +129,71 @@ func CollectQDevS3Files(taskCtx plugin.SubTaskContext) errors.Error {
 		continuationToken = result.NextContinuationToken
 	}
 
+	// Check if no CSV files were found
+	if csvFilesFound == 0 {
+		return errors.BadInput.New("no CSV files found in S3 path. Please verify the S3 bucket and prefix configuration")
+	}
+
 	return nil
 }
 
+var _ plugin.SubTaskEntryPoint = CollectQDevS3Files
+
+// CollectQDevS3Files collects S3 file metadata
+func CollectQDevS3Files(taskCtx plugin.SubTaskContext) errors.Error {
+	data := taskCtx.GetData().(*QDevTaskData)
+	db := taskCtx.GetDal()
+
+	taskCtx.SetProgress(0, -1)
+
+	// Define callback functions
+	listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
+		return data.S3Client.S3.ListObjectsV2(input)
+	}
+
+	findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
+		existingFile := &models.QDevS3FileMeta{}
+		err := db.First(existingFile, dal.Where("connection_id = ? AND s3_path = ?", connectionId, s3Path))
+		if err != nil {
+			if db.IsErrorNotFound(err) {
+				// A missing record is expected for new files; report it as (nil, nil)
+				return nil, nil
+			}
+			return nil, errors.Default.Wrap(err, "failed to query existing file metadata")
+		}
+		return existingFile, nil
+	}
+
+	saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
+		err := db.Create(fileMeta)
+		if err != nil {
+			return errors.Default.Wrap(err, "failed to create file metadata")
+		}
+		return nil
+	}
+
+	progress := func(increment int) {
+		taskCtx.IncProgress(increment)
+	}
+
+	logDebug := func(format string, args ...interface{}) {
+		taskCtx.GetLogger().Debug(format, args...)
+	}
+
+	// Call the core function
+	err := collectS3FilesCore(
+		data.S3Client.Bucket,
+		data.Options.S3Prefix,
+		data.Options.ConnectionId,
+		listObjects,
+		findFileMeta,
+		saveFileMeta,
+		progress,
+		logDebug,
+	)
+
+	return errors.Convert(err)
+}
+
 var CollectQDevS3FilesMeta = plugin.SubTaskMeta{
 	Name:             "collectQDevS3Files",
 	EntryPoint:       CollectQDevS3Files,
diff --git a/backend/plugins/q_dev/tasks/s3_file_collector_test.go b/backend/plugins/q_dev/tasks/s3_file_collector_test.go
new file mode 100644
index 00000000000..bd808afd4b4
--- /dev/null
+++ b/backend/plugins/q_dev/tasks/s3_file_collector_test.go
@@ -0,0 +1,173 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package tasks
+
+import (
+	"testing"
+
+	"github.com/apache/incubator-devlake/plugins/q_dev/models"
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/stretchr/testify/assert"
+)
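+
+// The tests below drive collectS3FilesCore directly through its callback
+// parameters, so no real AWS client or database is required.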
+
+func TestNormalizeS3Prefix(t *testing.T) {
+	tests := []struct {
+		input    string
+		expected string
+	}{
+		{"", ""},
+		{"prefix", "prefix/"},
+		{"prefix/", "prefix/"},
+		{"path/to/folder", "path/to/folder/"},
+		{"path/to/folder/", "path/to/folder/"},
+	}
+
+	for _, test := range tests {
+		result := normalizeS3Prefix(test.input)
+		assert.Equal(t, test.expected, result)
+	}
+}
+
+func TestIsCSVFile(t *testing.T) {
+	tests := []struct {
+		key      string
+		expected bool
+	}{
+		{"file.csv", true},
+		{"data.CSV", false}, // the suffix check is case-sensitive
+		{"report.csv", true},
+		{"document.txt", false},
+		{"path/to/file.csv", true},
+		{"file.csv.backup", false},
+		{"", false},
+	}
+
+	for _, test := range tests {
+		result := isCSVFile(test.key)
+		assert.Equal(t, test.expected, result)
+	}
+}
+
+func TestCollectS3FilesCore_Success(t *testing.T) {
+	// Mock functions
+	listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
+		return &s3.ListObjectsV2Output{
+			Contents: []*s3.Object{
+				{Key: aws.String("file1.csv")},
+				{Key: aws.String("file2.txt")},
+				{Key: aws.String("data.csv")},
+			},
+			IsTruncated: aws.Bool(false),
+		}, nil
+	}
+
+	findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
+		return nil, nil // nothing in the database yet
+	}
+
+	createdFiles := []string{}
+	saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
+		createdFiles = append(createdFiles, fileMeta.S3Path)
+		return nil
+	}
+
+	progressCount := 0
+	progress := func(increment int) {
+		progressCount += increment
+	}
+
+	logMessages := []string{}
+	logDebug := func(format string, args ...interface{}) {
+		logMessages = append(logMessages, format)
+	}
+
+	err := collectS3FilesCore("bucket", "prefix", 1, listObjects, findFileMeta, saveFileMeta, progress, logDebug)
+
+	assert.NoError(t, err)
+	assert.Equal(t, 2, len(createdFiles))
+	assert.Contains(t, createdFiles, "file1.csv")
+	assert.Contains(t, createdFiles, "data.csv")
+	assert.Equal(t, 2, progressCount)
+	assert.Contains(t, logMessages, "Skipping non-CSV file: %s")
+}
+
+func TestCollectS3FilesCore_NoCSVFiles(t *testing.T) {
+	listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
+		return &s3.ListObjectsV2Output{
+			Contents: []*s3.Object{
+				{Key: aws.String("file1.txt")},
+				{Key: aws.String("file2.json")},
+			},
+			IsTruncated: aws.Bool(false),
+		}, nil
+	}
+
+	findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
+		return nil, nil // nothing in the database yet
+	}
+
+	saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
+		return nil
+	}
+
+	progress := func(increment int) {}
+	logDebug := func(format string, args ...interface{}) {}
+
+	err := collectS3FilesCore("bucket", "prefix", 1, listObjects, findFileMeta, saveFileMeta, progress, logDebug)
+
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "no CSV files found")
+}
+
+func TestCollectS3FilesCore_SkipProcessedFiles(t *testing.T) {
+	listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
+		return &s3.ListObjectsV2Output{
+			Contents: []*s3.Object{
+				{Key: aws.String("processed.csv")},
+				{Key: aws.String("unprocessed.csv")},
+			},
+			IsTruncated: aws.Bool(false),
+		}, nil
+	}
+
+	findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
+		if s3Path == "processed.csv" {
+			return &models.QDevS3FileMeta{Processed: true}, nil
+		}
+		if s3Path == "unprocessed.csv" {
+			return &models.QDevS3FileMeta{Processed: false}, nil
+		}
+		return nil, nil // no other records exist
+	}
+
+	createdFiles := []string{}
+	saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
+		createdFiles = append(createdFiles, fileMeta.S3Path)
+		return nil
+	}
+
+	progress := func(increment int) {}
+	logDebug := func(format string, args ...interface{}) {}
+
+	err := collectS3FilesCore("bucket", "prefix", 1, listObjects, findFileMeta, saveFileMeta, progress, logDebug)
+
+	assert.NoError(t, err)
+	assert.Equal(t, 0, len(createdFiles)) // No new files should be created
+}
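+
+// TestCollectS3FilesCore_Pagination is a sketch covering the
+// continuation-token loop, which the tests above never exercise; the page
+// keys and token value are invented for illustration.
+func TestCollectS3FilesCore_Pagination(t *testing.T) {
+	// Return two pages: the first truncated with a continuation token,
+	// the second final.
+	listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
+		if input.ContinuationToken == nil {
+			return &s3.ListObjectsV2Output{
+				Contents:              []*s3.Object{{Key: aws.String("page1.csv")}},
+				IsTruncated:           aws.Bool(true),
+				NextContinuationToken: aws.String("token-1"),
+			}, nil
+		}
+		return &s3.ListObjectsV2Output{
+			Contents:    []*s3.Object{{Key: aws.String("page2.csv")}},
+			IsTruncated: aws.Bool(false),
+		}, nil
+	}
+
+	findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
+		return nil, nil // every key is new
+	}
+
+	createdFiles := []string{}
+	saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
+		createdFiles = append(createdFiles, fileMeta.S3Path)
+		return nil
+	}
+
+	progress := func(increment int) {}
+	logDebug := func(format string, args ...interface{}) {}
+
+	err := collectS3FilesCore("bucket", "prefix", 1, listObjects, findFileMeta, saveFileMeta, progress, logDebug)
+
+	assert.NoError(t, err)
+	// Both pages should have been visited, in order.
+	assert.Equal(t, []string{"page1.csv", "page2.csv"}, createdFiles)
+}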