Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 153 additions & 21 deletions flower/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import logging
from urllib.parse import quote

from concurrent.futures import ThreadPoolExecutor

Expand Down Expand Up @@ -37,21 +38,21 @@ class Flower(tornado.web.Application):

def __init__(self, options=None, capp=None, events=None,
io_loop=None, **kwargs):

handlers = default_handlers
if options is not None and options.url_prefix:
handlers = [rewrite_handler(h, options.url_prefix) for h in handlers]
kwargs.update(handlers=handlers)

super().__init__(**kwargs)

self.options = options or default_options
self.io_loop = io_loop or ioloop.IOLoop.instance()
self.ssl_options = kwargs.get('ssl_options', None)

self.capp = capp or celery.Celery()
self.capp.loader.import_default_modules()

self.executor = self.pool_executor_cls(max_workers=self.max_workers)
self.io_loop.set_default_executor(self.executor)

self.inspector = Inspector(self.io_loop, self.capp, self.options.inspect_timeout / 1000.0)

self.events = events or Events(
Expand All @@ -63,33 +64,85 @@ def __init__(self, options=None, capp=None, events=None,
io_loop=self.io_loop,
max_workers_in_memory=self.options.max_workers,
max_tasks_in_memory=self.options.max_tasks)
self.started = False

def start(self):
self._http_server = None
self._executor = None

def _start_executor(self):
if self._executor is None:
logging.debug("Starting executor...")
ctx = self.pool_executor_cls(max_workers=self.max_workers)
self._executor = ctx.__enter__() # pylint: disable=unnecessary-dunder-call
self.io_loop.set_default_executor(self._executor)

def _stop_executor(self):
if self._executor is not None:
logging.debug("Stop executor...")
self._executor.__exit__(None, None, None)
self._executor = None

def _start_events(self):
self.events.start()

def _stop_events(self):
self.events.stop()

def _start_http_server(self):
logging.debug("Starting HTTP server...")
if not self.options.unix_socket:
self.listen(self.options.port, address=self.options.address,
ssl_options=self.ssl_options,
xheaders=self.options.xheaders)
http_server = self.listen(
self.options.port,
address=self.options.address,
ssl_options=self.ssl_options,
xheaders=self.options.xheaders
)
else:
from tornado.netutil import bind_unix_socket
server = HTTPServer(self)
socket = bind_unix_socket(self.options.unix_socket, mode=0o777)
server.add_socket(socket)

self.started = True
self.update_workers()
http_server = HTTPServer(self)
socket = bind_unix_socket(self.options.unix_socket, mode=0o777)
http_server.add_socket(socket)
self._http_server = http_server

def _stop_http_server(self):
logging.debug("Stopping HTTP server...")
self.io_loop.run_sync(
self._http_server.close_all_connections, timeout=5
)
self._http_server.stop()
self._http_server = None

def start_server(self):
if self._http_server is not None:
logging.debug("Flower server already started.")
return
logging.debug("Starting Flower server...")
self._start_executor()
self._start_events()
self._start_http_server()
logging.debug("Flower server started.")

def stop_server(self):
if self._http_server is None:
logging.debug("Flower server already stopped.")
return
logging.debug("Stopping Flower server...")
self._stop_events()
self._stop_http_server()
self._stop_executor()
logging.debug("Flower server stopped.")

def serve_forever(self):
if not self._http_server:
raise RuntimeError("The server is not running")
logging.debug("Starting event loop...")
self.io_loop.start()

def stop(self):
if self.started:
self.events.stop()
logging.debug("Stopping executors...")
self.executor.shutdown(wait=False)
logging.debug("Stopping event loop...")
self.io_loop.stop()
self.started = False
def shutdown(self):
if self._http_server:
raise RuntimeError("The server is still running")
logging.debug("Stopping event loop...")
self.io_loop.stop()

@property
def transport(self):
Expand All @@ -101,3 +154,82 @@ def workers(self):

def update_workers(self, workername=None):
return self.inspector.inspect(workername)

def _get_scheme(self):
if self.options.unix_socket:
return "http+unix"
if self.ssl_options:
return "https"
return "http"

def _get_socket(self):
sockets = getattr(self._http_server, "_sockets", None) # pylint: disable=protected-access
if sockets:
return list(sockets.values())[0]
return None

def _get_domain(self):
if self.options.unix_socket:
raise RuntimeError("UNIX socket")

sock = self._get_socket()
if sock is not None:
return sock.getsockname()[0]

return self.options.address or "0.0.0.0"

def _get_port(self):
if self.options.unix_socket:
raise RuntimeError("UNIX socket")

sock = self._get_socket()
if sock is not None:
return sock.getsockname()[1]

return self.options.port

def _get_authority(self):
if self.options.unix_socket:
return quote(self.options.unix_socket)

return f"{self._get_domain()}:{self._get_port()}"

def _get_url_path(self, path=None):
path = path or ""
if not self.options.url_prefix:
return path

prefix = self.options.url_prefix.strip("/")
return f"/{prefix}{path}"

def get_url(self, path=None):
path = self._get_url_path(path)
return f"{self._get_scheme()}://{self._get_authority()}{path}"

#
# For backward compatibility
#

def start(self):
self.start_server()
self.update_workers()
self.serve_forever()

def stop(self):
self.stop_server()
self.shutdown()

@property
def started(self):
return self._http_server is not None

@started.setter
def started(self, value):
if value:
self.start_server()
else:
self.stop_server()

@property
def executor(self):
return self._executor
36 changes: 18 additions & 18 deletions flower/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,17 @@ def flower(ctx, tornado_argv):
atexit.register(flower_app.stop)
signal.signal(signal.SIGTERM, sigterm_handler)

if not ctx.obj.quiet:
print_banner(app, 'ssl_options' in settings)
try:
flower_app.start_server()
finally:
# Print the banner even when server failed to start
if not ctx.obj.quiet:
print_banner(flower_app)

flower_app.update_workers()

try:
flower_app.start()
flower_app.serve_forever()
except (KeyboardInterrupt, SystemExit):
pass

Expand Down Expand Up @@ -158,24 +164,18 @@ def is_flower_envvar(name):
name[len(ENV_VAR_PREFIX):].lower() in default_options


def print_banner(app, ssl):
if not options.unix_socket:
if options.url_prefix:
prefix_str = f'/{options.url_prefix}/'
else:
prefix_str = ''

logger.info(
"Visit me at http%s://%s:%s%s", 's' if ssl else '',
options.address or '0.0.0.0', options.port,
prefix_str
)
def print_banner(flower_app):
if not flower_app.options.unix_socket:
url = flower_app.get_url()
logger.info("Visit me at %s", url)
Comment on lines +169 to +170
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the fix. We get the URL from the Flower app. Before we got it from the Celery app before the server started.

else:
logger.info("Visit me via unix socket file: %s", options.unix_socket)
unix_socket = flower_app.options.unix_socket
logger.info("Visit me via unix socket file: %s", unix_socket)

logger.info('Broker: %s', app.connection().as_uri())
capp = flower_app.capp
logger.info('Broker: %s', capp.connection().as_uri())
logger.info(
'Registered tasks: \n%s',
pformat(sorted(app.tasks.keys()))
pformat(sorted(capp.tasks.keys()))
)
logger.debug('Settings: %s', pformat(settings))
57 changes: 45 additions & 12 deletions tests/unit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,60 @@

import celery
import tornado.testing
from tornado.ioloop import IOLoop
from tornado.options import options
from tornado.httpclient import AsyncHTTPClient, HTTPResponse

from flower import command # noqa: F401 side effect - define options
from flower.app import Flower
from flower.events import Events
from flower.urls import handlers, settings


class AsyncHTTPTestCase(tornado.testing.AsyncHTTPTestCase):
class AsyncHTTPTestCase(tornado.testing.AsyncTestCase):

def _get_celery_app(self):
return celery.Celery()
def setUp(self) -> None:
super().setUp()
self._http_client = AsyncHTTPClient()
self._capp = celery.Celery()
self._start_flower()

def get_app(self, capp=None):
if not capp:
capp = self._get_celery_app()
events = Events(capp, IOLoop.current())
app = Flower(capp=capp, events=events,
options=options, handlers=handlers, **settings)
return app
def _start_flower(self):
self._app = Flower(
capp=self._capp,
io_loop=self.io_loop,
options=options,
handlers=handlers,
**settings
)
self._app.start_server()

def _stop_flower(self):
self._app.stop_server()

def _restart_flower(self, reset_celery_app=False):
self._stop_flower()
if reset_celery_app:
self._capp = celery.Celery()
self._start_flower()

def tearDown(self) -> None:
self._http_client.close()
self._app.stop_server()
del self._http_client
del self._app
super().tearDown()

def fetch(
self, path: str, raise_error: bool = False, **kwargs
) -> HTTPResponse:
url = self._app.get_url(path)

def fetch():
return self._http_client.fetch(url, raise_error=raise_error, **kwargs)

return self.io_loop.run_sync(
fetch,
timeout=tornado.testing.get_async_test_timeout(),
)

def get(self, url, **kwargs):
return self.fetch(url, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/api/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def test_unknown_worker(self):

class WorkerControlTests(BaseApiTestCase):
def setUp(self):
BaseApiTestCase.setUp(self)
super().setUp()
self.is_worker = ControlHandler.is_worker
ControlHandler.is_worker = lambda *args: True

def tearDown(self):
BaseApiTestCase.tearDown(self)
super().tearDown()
ControlHandler.is_worker = self.is_worker

def test_shutdown(self):
Expand Down
8 changes: 1 addition & 7 deletions tests/unit/api/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,6 @@ def get_task_by_id(events, task_id):


class TaskTests(BaseApiTestCase):
def setUp(self):
self.app = super().get_app()
super().setUp()

def get_app(self, capp=None):
return self.app

@patch('flower.api.tasks.tasks', new=MockTasks)
def test_task_info(self):
Expand All @@ -127,7 +121,7 @@ def test_tasks_pagination(self):
e['clock'] = i
e['local_received'] = time.time()
state.event(e)
self.app.events.state = state
self._app.events.state = state

# Test limit 4 and offset 0
params = dict(limit=4, offset=0, sort_by='name')
Expand Down
Loading
Loading