Skip to content

Commit 167274d

Browse files
committed
refactor: extract core logic from CollectQDevS3Files with callback functions
- Extract collectS3FilesCore function with callback interfaces - Add FindFileMetaFunc and SaveFileMetaFunc for better naming - Add comprehensive unit tests for core logic - Replace Chinese comments with English - All tests passing (30/30)
1 parent 1d4ac63 commit 167274d

File tree

2 files changed

+211
-33
lines changed

2 files changed

+211
-33
lines changed

backend/plugins/q_dev/tasks/s3_file_collector.go

Lines changed: 100 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,79 +41,87 @@ func isCSVFile(key string) bool {
4141
return strings.HasSuffix(key, ".csv")
4242
}
4343

44-
45-
46-
var _ plugin.SubTaskEntryPoint = CollectQDevS3Files
47-
48-
// CollectQDevS3Files 收集S3文件元数据
49-
func CollectQDevS3Files(taskCtx plugin.SubTaskContext) errors.Error {
50-
data := taskCtx.GetData().(*QDevTaskData)
51-
db := taskCtx.GetDal()
52-
53-
// 列出指定前缀下的所有对象
44+
// S3ListObjectsFunc defines the callback for listing S3 objects.
// collectS3FilesCore calls it once per page with the bucket, the normalized
// prefix and the current continuation token; a non-nil error is returned to
// the caller unchanged.
45+
type S3ListObjectsFunc func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error)
46+
47+
// FindFileMetaFunc defines the callback for finding existing file metadata.
// It is called with the connection id and the S3 object key. A nil error means
// a record exists (the returned value is then inspected for Processed).
// NOTE(review): collectS3FilesCore treats ANY non-nil error as "no record" and
// goes on to save a new one, so implementations cannot currently signal a
// query failure distinctly from "not found".
48+
type FindFileMetaFunc func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error)
49+
50+
// SaveFileMetaFunc defines the callback for saving file metadata.
// It is called for each CSV object that has no existing record; a non-nil
// error aborts collectS3FilesCore and is returned unchanged.
51+
type SaveFileMetaFunc func(fileMeta *models.QDevS3FileMeta) error
52+
53+
// ProgressFunc defines the callback for progress tracking.
// collectS3FilesCore calls it with 1 after each newly saved file record.
54+
type ProgressFunc func(increment int)
55+
56+
// LogFunc defines the callback for logging.
// It receives a printf-style format string and its arguments.
57+
type LogFunc func(format string, args ...interface{})
58+
59+
// collectS3FilesCore contains the core logic for collecting S3 files
60+
func collectS3FilesCore(
61+
bucket, prefix string,
62+
connectionId uint64,
63+
listObjects S3ListObjectsFunc,
64+
findFileMeta FindFileMetaFunc,
65+
saveFileMeta SaveFileMetaFunc,
66+
progress ProgressFunc,
67+
logDebug LogFunc,
68+
) error {
69+
// List all objects under the specified prefix
5470
var continuationToken *string
55-
prefix := normalizeS3Prefix(data.Options.S3Prefix)
56-
57-
taskCtx.SetProgress(0, -1)
71+
normalizedPrefix := normalizeS3Prefix(prefix)
5872
csvFilesFound := 0
5973

6074
for {
6175
input := &s3.ListObjectsV2Input{
62-
Bucket: aws.String(data.S3Client.Bucket),
63-
Prefix: aws.String(prefix),
76+
Bucket: aws.String(bucket),
77+
Prefix: aws.String(normalizedPrefix),
6478
ContinuationToken: continuationToken,
6579
}
6680

67-
result, err := data.S3Client.S3.ListObjectsV2(input)
81+
result, err := listObjects(input)
6882
if err != nil {
69-
return errors.Convert(err)
83+
return err
7084
}
7185

72-
// 处理每个CSV文件
86+
// Process each CSV file
7387
for _, object := range result.Contents {
7488
// Only process CSV files
7589
if !isCSVFile(*object.Key) {
76-
taskCtx.GetLogger().Debug("Skipping non-CSV file: %s", *object.Key)
90+
logDebug("Skipping non-CSV file: %s", *object.Key)
7791
continue
7892
}
7993

8094
csvFilesFound++
8195

8296
// Check if this file already exists in our database
83-
existingFile := &models.QDevS3FileMeta{}
84-
err = db.First(existingFile, dal.Where("connection_id = ? AND s3_path = ?",
85-
data.Options.ConnectionId, *object.Key))
86-
97+
existingFile, err := findFileMeta(connectionId, *object.Key)
8798
if err == nil {
8899
// File already exists in database, skip it if it's already processed
89100
if existingFile.Processed {
90-
taskCtx.GetLogger().Debug("Skipping already processed file: %s", *object.Key)
101+
logDebug("Skipping already processed file: %s", *object.Key)
91102
continue
92103
}
93104
// Otherwise, we'll keep the existing record (which is still marked as unprocessed)
94-
taskCtx.GetLogger().Debug("Found existing unprocessed file: %s", *object.Key)
105+
logDebug("Found existing unprocessed file: %s", *object.Key)
95106
continue
96-
} else if !db.IsErrorNotFound(err) {
97-
return errors.Default.Wrap(err, "failed to query existing file metadata")
98107
}
99108

100109
// This is a new file, save its metadata
101110
fileMeta := &models.QDevS3FileMeta{
102-
ConnectionId: data.Options.ConnectionId,
111+
ConnectionId: connectionId,
103112
FileName: *object.Key,
104113
S3Path: *object.Key,
105114
Processed: false,
106115
}
107116

108-
err = db.Create(fileMeta)
109-
if err != nil {
110-
return errors.Default.Wrap(err, "failed to create file metadata")
117+
if err := saveFileMeta(fileMeta); err != nil {
118+
return err
111119
}
112120

113-
taskCtx.IncProgress(1)
121+
progress(1)
114122
}
115123

116-
// 如果没有更多对象,退出循环
124+
// If there are no more objects, exit the loop
117125
if !*result.IsTruncated {
118126
break
119127
}
@@ -129,6 +137,65 @@ func CollectQDevS3Files(taskCtx plugin.SubTaskContext) errors.Error {
129137
return nil
130138
}
131139

140+
141+
142+
var _ plugin.SubTaskEntryPoint = CollectQDevS3Files
143+
144+
// CollectQDevS3Files collects S3 file metadata.
// It is the subtask entry point: it reads the task data and DAL from taskCtx,
// adapts them into the callback functions expected by collectS3FilesCore, runs
// the core logic, and converts the resulting error with errors.Convert.
145+
func CollectQDevS3Files(taskCtx plugin.SubTaskContext) errors.Error {
146+
data := taskCtx.GetData().(*QDevTaskData)
147+
db := taskCtx.GetDal()
148+
149+
// NOTE(review): presumably -1 means the total is unknown up front — confirm
// against the SubTaskContext documentation.
taskCtx.SetProgress(0, -1)
150+
151+
// Define callback functions
152+
// listObjects forwards the request to the S3 client held in the task data.
listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
153+
return data.S3Client.S3.ListObjectsV2(input)
154+
}
155+
156+
// findFileMeta looks up a record by connection id and S3 path.
findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
157+
existingFile := &models.QDevS3FileMeta{}
158+
err := db.First(existingFile, dal.Where("connection_id = ? AND s3_path = ?", connectionId, s3Path))
159+
if err != nil {
160+
if db.IsErrorNotFound(err) {
161+
// Not-found is passed through unwrapped; other errors are wrapped below.
return nil, err
162+
}
163+
return nil, errors.Default.Wrap(err, "failed to query existing file metadata")
164+
}
165+
return existingFile, nil
166+
}
167+
168+
// saveFileMeta inserts a new record, wrapping any failure with context.
saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
169+
err := db.Create(fileMeta)
170+
if err != nil {
171+
return errors.Default.Wrap(err, "failed to create file metadata")
172+
}
173+
return nil
174+
}
175+
176+
// progress forwards the increment to the task's progress counter.
progress := func(increment int) {
177+
taskCtx.IncProgress(increment)
178+
}
179+
180+
// logDebug forwards to the task logger at debug level.
logDebug := func(format string, args ...interface{}) {
181+
taskCtx.GetLogger().Debug(format, args...)
182+
}
183+
184+
// Call the core function
185+
err := collectS3FilesCore(
186+
data.S3Client.Bucket,
187+
data.Options.S3Prefix,
188+
data.Options.ConnectionId,
189+
listObjects,
190+
findFileMeta,
191+
saveFileMeta,
192+
progress,
193+
logDebug,
194+
)
195+
196+
return errors.Convert(err)
197+
}
198+
132199
var CollectQDevS3FilesMeta = plugin.SubTaskMeta{
133200
Name: "collectQDevS3Files",
134201
EntryPoint: CollectQDevS3Files,

backend/plugins/q_dev/tasks/s3_file_collector_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ limitations under the License.
1818
package tasks
1919

2020
import (
21+
"errors"
2122
"testing"
2223

24+
"github.com/apache/incubator-devlake/plugins/q_dev/models"
25+
"github.com/aws/aws-sdk-go/aws"
26+
"github.com/aws/aws-sdk-go/service/s3"
2327
"github.com/stretchr/testify/assert"
2428
)
2529

@@ -60,3 +64,110 @@ func TestIsCSVFile(t *testing.T) {
6064
assert.Equal(t, test.expected, result)
6165
}
6266
}
67+
68+
// TestCollectS3FilesCore_Success runs collectS3FilesCore against a single,
// non-truncated listing of two CSV objects and one non-CSV object, with a
// lookup that always fails. It asserts that both CSV keys are saved, that
// progress is incremented twice, and that the non-CSV skip message is logged.
func TestCollectS3FilesCore_Success(t *testing.T) {
69+
// Mock functions
70+
listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
71+
return &s3.ListObjectsV2Output{
72+
Contents: []*s3.Object{
73+
{Key: aws.String("file1.csv")},
74+
{Key: aws.String("file2.txt")},
75+
{Key: aws.String("data.csv")},
76+
},
77+
IsTruncated: aws.Bool(false),
78+
}, nil
79+
}
80+
81+
// Every lookup reports an error, so each CSV object is handled as new.
findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
82+
return nil, errors.New("not found")
83+
}
84+
85+
// Record the S3 path of every record passed to the save callback.
createdFiles := []string{}
86+
saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
87+
createdFiles = append(createdFiles, fileMeta.S3Path)
88+
return nil
89+
}
90+
91+
progressCount := 0
92+
progress := func(increment int) {
93+
progressCount += increment
94+
}
95+
96+
// Only the raw format string is recorded (args are ignored), which is why
// the assertion below matches the unformatted "%s" template.
logMessages := []string{}
97+
logDebug := func(format string, args ...interface{}) {
98+
logMessages = append(logMessages, format)
99+
}
100+
101+
err := collectS3FilesCore("bucket", "prefix", 1, listObjects, findFileMeta, saveFileMeta, progress, logDebug)
102+
103+
assert.NoError(t, err)
104+
assert.Equal(t, 2, len(createdFiles))
105+
assert.Contains(t, createdFiles, "file1.csv")
106+
assert.Contains(t, createdFiles, "data.csv")
107+
assert.Equal(t, 2, progressCount)
108+
assert.Contains(t, logMessages, "Skipping non-CSV file: %s")
109+
}
110+
111+
// TestCollectS3FilesCore_NoCSVFiles runs collectS3FilesCore against a listing
// that contains only non-CSV keys and asserts that an error whose message
// contains "no CSV files found" is returned.
func TestCollectS3FilesCore_NoCSVFiles(t *testing.T) {
112+
listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
113+
return &s3.ListObjectsV2Output{
114+
Contents: []*s3.Object{
115+
{Key: aws.String("file1.txt")},
116+
{Key: aws.String("file2.json")},
117+
},
118+
IsTruncated: aws.Bool(false),
119+
}, nil
120+
}
121+
122+
findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
123+
return nil, errors.New("not found")
124+
}
125+
126+
saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
127+
return nil
128+
}
129+
130+
// Progress and logging are irrelevant to this scenario, so use no-op stubs.
progress := func(increment int) {}
131+
logDebug := func(format string, args ...interface{}) {}
132+
133+
err := collectS3FilesCore("bucket", "prefix", 1, listObjects, findFileMeta, saveFileMeta, progress, logDebug)
134+
135+
assert.Error(t, err)
136+
assert.Contains(t, err.Error(), "no CSV files found")
137+
}
138+
139+
// TestCollectS3FilesCore_SkipProcessedFiles runs collectS3FilesCore against a
// listing of two CSV keys that both already have metadata records (one marked
// processed, one not) and asserts that no error is returned and that the save
// callback is never invoked.
func TestCollectS3FilesCore_SkipProcessedFiles(t *testing.T) {
140+
listObjects := func(input *s3.ListObjectsV2Input) (*s3.ListObjectsV2Output, error) {
141+
return &s3.ListObjectsV2Output{
142+
Contents: []*s3.Object{
143+
{Key: aws.String("processed.csv")},
144+
{Key: aws.String("unprocessed.csv")},
145+
},
146+
IsTruncated: aws.Bool(false),
147+
}, nil
148+
}
149+
150+
// Both listed keys resolve to an existing record; any other key is reported
// as missing.
findFileMeta := func(connectionId uint64, s3Path string) (*models.QDevS3FileMeta, error) {
151+
if s3Path == "processed.csv" {
152+
return &models.QDevS3FileMeta{Processed: true}, nil
153+
}
154+
if s3Path == "unprocessed.csv" {
155+
return &models.QDevS3FileMeta{Processed: false}, nil
156+
}
157+
return nil, errors.New("not found")
158+
}
159+
160+
// Record the S3 path of every record passed to the save callback.
createdFiles := []string{}
161+
saveFileMeta := func(fileMeta *models.QDevS3FileMeta) error {
162+
createdFiles = append(createdFiles, fileMeta.S3Path)
163+
return nil
164+
}
165+
166+
progress := func(increment int) {}
167+
logDebug := func(format string, args ...interface{}) {}
168+
169+
err := collectS3FilesCore("bucket", "prefix", 1, listObjects, findFileMeta, saveFileMeta, progress, logDebug)
170+
171+
assert.NoError(t, err)
172+
assert.Equal(t, 0, len(createdFiles)) // No new files should be created
173+
}

0 commit comments

Comments
 (0)