Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions flower/api/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,8 @@ def get(self):
:query state: filter tasks by state
:query received_start: filter tasks by received date (must be greater than) format %Y-%m-%d %H:%M
:query received_end: filter tasks by received date (must be less than) format %Y-%m-%d %H:%M
:query only_fields: returns only selected fields for tasks (comma-separated)
:query except_fields: returns all but selected fields for tasks (comma-separated)
:reqheader Authorization: optional OAuth token to authenticate
:statuscode 200: no error
:statuscode 401: unauthorized request
Expand All @@ -507,6 +509,8 @@ def get(self):
received_end = self.get_argument('received_end', None)
sort_by = self.get_argument('sort_by', None)
search = self.get_argument('search', None)
only_fields = self.get_argument('only_fields', None)
except_fields = self.get_argument('except_fields', None)

limit = limit and int(limit)
offset = max(offset, 0)
Expand All @@ -522,7 +526,7 @@ def get(self):
received_end=received_end,
search=search
):
task = tasks.as_dict(task)
task = tasks.as_dict(task, only_fields=only_fields, except_fields=except_fields)
worker = task.pop('worker', None)
if worker is not None:
task['worker'] = worker.hostname
Comment on lines +529 to 532
Copy link

Copilot AI Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The worker hostname transformation on lines 530-532 happens after filtering, which means if 'worker' is included in only_fields or excluded from except_fields, the transformation may not work as expected. The worker field should be transformed before applying field filtering.

Suggested change
task = tasks.as_dict(task, only_fields=only_fields, except_fields=except_fields)
worker = task.pop('worker', None)
if worker is not None:
task['worker'] = worker.hostname
if hasattr(task, 'worker') and task.worker is not None:
task.worker = task.worker.hostname
task = tasks.as_dict(task, only_fields=only_fields, except_fields=except_fields)

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -621,18 +625,22 @@ def get(self, taskid):
"worker": "celery@worker1"
}

:query only_fields: returns only selected fields for task (comma-separated)
:query except_fields: returns all but selected fields for task (comma-separated)
:reqheader Authorization: optional OAuth token to authenticate
:statuscode 200: no error
:statuscode 401: unauthorized request
:statuscode 404: unknown task
"""
only_fields = self.get_argument('only_fields', None)
except_fields = self.get_argument('except_fields', None)

task = tasks.get_task_by_id(self.application.events, taskid)
if not task:
raise HTTPError(404, f"Unknown task '{taskid}'")

response = task.as_dict()
if task.worker is not None:
response = tasks.as_dict(task, only_fields=only_fields, except_fields=except_fields)
if task.worker is not None and 'worker' in response:
response['worker'] = task.worker.hostname

Comment on lines +642 to 645
Copy link

Copilot AI Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition checks if 'worker' exists in the filtered response, but the worker hostname assignment happens after filtering. If 'worker' is excluded by filtering, this condition will prevent the assignment, but if 'worker' is included, the assignment will overwrite the filtered value. Consider applying the worker hostname transformation before filtering.

Suggested change
response = tasks.as_dict(task, only_fields=only_fields, except_fields=except_fields)
if task.worker is not None and 'worker' in response:
response['worker'] = task.worker.hostname
# Ensure the worker field is set to the hostname before filtering
if task.worker is not None:
task.worker = task.worker.hostname
response = tasks.as_dict(task, only_fields=only_fields, except_fields=except_fields)

Copilot uses AI. Check for mistakes.
self.write(response)
33 changes: 31 additions & 2 deletions flower/utils/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,34 @@ def get_task_by_id(events, task_id):
return events.state.tasks.get(task_id)


def as_dict(task):
return task.as_dict()
def filter_dict(task_dict, only_fields=None, except_fields=None):
"""
Filter a dictionary based on only_fields or except_fields parameters.

Args:
task_dict (dict): The dictionary to filter
only_fields (str or list): Fields to include (excludes all others)
except_fields (str or list): Fields to exclude (includes all others)

Returns:
dict: The filtered dictionary
"""
Comment on lines +79 to +80
Copy link

Copilot AI Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function doesn't validate mutual exclusivity of only_fields and except_fields parameters. If both are provided, only_fields takes precedence silently, which could lead to unexpected behavior. Consider adding validation to reject requests with both parameters or documenting this precedence clearly.

Suggested change
dict: The filtered dictionary
"""
dict: The filtered dictionary
Raises:
ValueError: If both only_fields and except_fields are provided.
"""
if only_fields and except_fields:
raise ValueError("Cannot specify both only_fields and except_fields. Please provide only one.")

Copilot uses AI. Check for mistakes.
if only_fields:
# Convert comma-separated string to list if necessary
if isinstance(only_fields, str):
only_fields = [field.strip() for field in only_fields.split(',')]
# Keep only the specified fields
return {k: v for k, v in task_dict.items() if k in only_fields}
elif except_fields:
# Convert comma-separated string to list if necessary
if isinstance(except_fields, str):
except_fields = [field.strip() for field in except_fields.split(',')]
# Remove the specified fields
return {k: v for k, v in task_dict.items() if k not in except_fields}

return task_dict


def as_dict(task, only_fields=None, except_fields=None):
task_dict = task.as_dict()
return filter_dict(task_dict, only_fields, except_fields)
122 changes: 120 additions & 2 deletions tests/unit/api/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from celery.result import AsyncResult

from flower.events import EventsState
from flower.utils.tasks import filter_dict
from tests.unit.utils import task_succeeded_events

from . import BaseApiTestCase
Expand Down Expand Up @@ -93,7 +94,28 @@ class MockTasks:
@staticmethod
def get_task_by_id(events, task_id):
from celery.events.state import Task
return Task()
task = Task()
# Set some test data on the task
task.name = 'test_task'
task.state = 'SUCCESS'
return task

@staticmethod
def as_dict(task, only_fields=None, except_fields=None):
# Create a mock dictionary with test data
task_dict = {
'name': task.name,
'state': task.state,
'worker': 'test_worker',
'received': 1234567890,
'started': 1234567891,
'succeeded': 1234567892,
'timestamp': 1234567892,
'runtime': 2.0
}

# Use the filter_dict function to handle field filtering
return filter_dict(task_dict, only_fields, except_fields)


class TaskTests(BaseApiTestCase):
Expand All @@ -106,7 +128,68 @@ def get_app(self, capp=None):

@patch('flower.api.tasks.tasks', new=MockTasks)
def test_task_info(self):
self.get('/api/task/info/123')
# Make the request
r = self.get('/api/task/info/123')

# Parse the response
task = json.loads(r.body.decode("utf-8"))

# Assert the response status code
self.assertEqual(200, r.code)

# Assert the task data
self.assertEqual('test_task', task['name'])
self.assertEqual('SUCCESS', task['state'])
self.assertEqual('test_worker', task['worker'])
self.assertEqual(1234567890, task['received'])
self.assertEqual(1234567891, task['started'])
self.assertEqual(1234567892, task['succeeded'])
self.assertEqual(1234567892, task['timestamp'])
self.assertEqual(2.0, task['runtime'])

def test_task_info_field_selection(self):
state = EventsState()
state.get_or_create_worker('worker1')
events = [Event('worker-online', hostname='worker1')]
events += task_succeeded_events(worker='worker1', name='task1',
id='123')

for i, e in enumerate(events):
e['clock'] = i
e['local_received'] = time.time()
state.event(e)
self.app.events.state = state

# Test only_fields parameter
params = dict(only_fields='name,state')

r = self.get('/api/task/info/123?' + '&'.join(
map(lambda x: '%s=%s' % x, params.items())))

task = json.loads(r.body.decode("utf-8"))

self.assertEqual(200, r.code)
# Check that only the specified fields are returned
self.assertEqual(2, len(task))
self.assertIn('name', task)
self.assertIn('state', task)
self.assertNotIn('worker', task)
self.assertNotIn('received', task)

# Test except_fields parameter
params = dict(except_fields='worker,received')

r = self.get('/api/task/info/123?' + '&'.join(
map(lambda x: '%s=%s' % x, params.items())))

task = json.loads(r.body.decode("utf-8"))

self.assertEqual(200, r.code)
# Check that the specified fields are not returned
self.assertIn('name', task)
self.assertIn('state', task)
self.assertNotIn('worker', task)
self.assertNotIn('received', task)

def test_tasks_pagination(self):
state = EventsState()
Expand Down Expand Up @@ -216,3 +299,38 @@ def test_tasks_pagination(self):
self.assertEqual(1, len(table))
firstFetchedTaskName = table[list(table)[0]]['name']
self.assertEqual("task1", firstFetchedTaskName)

# Test only_fields parameter
params = dict(limit=4, offset=0, sort_by='name', only_fields='name,state')

r = self.get('/api/tasks?' + '&'.join(
map(lambda x: '%s=%s' % x, params.items())))

table = json.loads(r.body.decode("utf-8"), object_pairs_hook=OrderedDict)

self.assertEqual(200, r.code)
self.assertEqual(4, len(table))
# Check that only the specified fields are returned
task = table[list(table)[0]]
self.assertEqual(2, len(task))
self.assertIn('name', task)
self.assertIn('state', task)
self.assertNotIn('worker', task)
self.assertNotIn('received', task)

# Test except_fields parameter
params = dict(limit=4, offset=0, sort_by='name', except_fields='worker,received')

r = self.get('/api/tasks?' + '&'.join(
map(lambda x: '%s=%s' % x, params.items())))

table = json.loads(r.body.decode("utf-8"), object_pairs_hook=OrderedDict)

self.assertEqual(200, r.code)
self.assertEqual(4, len(table))
# Check that the specified fields are not returned
task = table[list(table)[0]]
self.assertIn('name', task)
self.assertIn('state', task)
self.assertNotIn('worker', task)
self.assertNotIn('received', task)