Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datajunction-server/datajunction_server/api/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ async def add_reference_dimension_link(
),
)
await session.commit()
await session.refresh(target_column)
return JSONResponse(
status_code=201,
content={
Expand Down
36 changes: 36 additions & 0 deletions datajunction-server/datajunction_server/construction/build_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,30 @@ async def dimension_join_path(
return join_path

await refresh_if_needed(session, current_link.dimension, ["current"])

# Check the reference links on this dimension node
await refresh_if_needed(session, current_link.dimension.current, ["columns"])
for col in current_link.dimension.current.columns:
if col.dimension:
# Check if it matches the reference link dimension attribute
if f"{col.dimension.name}.{col.dimension_column}" == dimension:
return join_path
# Check if it matches any of the reference link dimension's linked attributes
await refresh_if_needed(session, col.dimension, ["current"])
await refresh_if_needed(
session,
col.dimension.current,
["dimension_links"],
)
for link in col.dimension.current.dimension_links:
if (
link.foreign_keys.get(
f"{col.dimension.name}.{col.dimension_column}",
)
== dimension
):
return join_path

await refresh_if_needed(
session,
current_link.dimension.current,
Expand Down Expand Up @@ -1410,7 +1434,19 @@ def build_dimension_attribute(
if dimension_attr.name in link.foreign_keys_reversed
else None
)
reference_links = {
col.name: f"{col.dimension.name}.{col.dimension_column}"
for col in link.dimension.current.columns
if col.dimension
}
for col in node_query.select.projection:
if reference_links.get(col.alias_or_name.name) == full_column_name: # type: ignore
return ast.Column(
name=ast.Name(col.alias_or_name.name), # type: ignore
alias=ast.Name(alias) if alias else None,
_table=node_query,
_type=col.type, # type: ignore
)
if col.alias_or_name.name == dimension_attr.column_name or ( # type: ignore
foreign_key_column_name
and col.alias_or_name.identifier() == foreign_key_column_name # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion datajunction-server/datajunction_server/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,8 @@ class ColumnOutput(BaseModel):
type: str
attributes: Optional[List[AttributeOutput]]
dimension: Optional[NodeNameOutput]
dimension_column: Optional[str]
partition: Optional[PartitionOutput]
# order: Optional[int]

class Config: # pylint: disable=missing-class-docstring, too-few-public-methods
"""
Expand Down
6 changes: 3 additions & 3 deletions datajunction-server/datajunction_server/sql/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,13 @@ async def get_dimensions_dag( # pylint: disable=too-many-locals
)
.join(
graph_branches,
(current_rev.id == graph_branches.c.node_revision_id)
& (is_(graph_branches.c.dimension_column, None)),
(current_rev.id == graph_branches.c.node_revision_id),
# & (is_(graph_branches.c.dimension_column, None)),
)
.join(
next_node,
(next_node.id == graph_branches.c.dimension_id)
& (is_(graph_branches.c.dimension_column, None))
# & (is_(graph_branches.c.dimension_column, None))
& (is_(next_node.deactivated_at, None)),
)
.join(
Expand Down
74 changes: 74 additions & 0 deletions datajunction-server/tests/api/dimension_links_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ async def _link_events_to_users_without_role() -> Response:
return _link_events_to_users_without_role


@pytest.fixture
def reference_link_users_date(
dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name
):
"""
Create a reference link between users and date
"""

async def _reference_link_users_date() -> Response:
response = await dimensions_link_client.post(
"/nodes/default.users/columns/snapshot_date/link",
params={
"dimension_node": "default.date",
"dimension_column": "dateint",
},
)
return response

return _reference_link_users_date


@pytest.fixture
def link_events_to_users_with_role_direct(
dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name
Expand Down Expand Up @@ -964,6 +985,59 @@ async def test_measures_sql_with_reference_dimension_links(
assert response_data[0]["errors"] == []


@pytest.mark.asyncio
async def test_measures_sql_with_ref_link_on_dim_node(
dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name
link_events_to_users_without_role, # pylint: disable=redefined-outer-name
reference_link_users_date, # pylint: disable=redefined-outer-name
):
"""
Verify that measures SQL can be retrieved for dimension attributes that come from a
reference dimension link from one dim node to another dim node.
"""
await link_events_to_users_without_role()
await reference_link_users_date()

response = await dimensions_link_client.get(
"/sql/measures/v2",
params={
"metrics": ["default.elapsed_secs"],
"dimensions": [
"default.date.dateint",
],
},
)
response_data = response.json()
expected_sql = """
WITH default_DOT_events AS (
SELECT
default_DOT_events_table.user_id,
default_DOT_events_table.event_start_date,
default_DOT_events_table.event_end_date,
default_DOT_events_table.elapsed_secs,
default_DOT_events_table.user_registration_country
FROM examples.events AS default_DOT_events_table
),
default_DOT_users AS (
SELECT
default_DOT_users_table.user_id,
default_DOT_users_table.snapshot_date,
default_DOT_users_table.registration_country,
default_DOT_users_table.residence_country,
default_DOT_users_table.account_type
FROM examples.users AS default_DOT_users_table
)
SELECT
default_DOT_events.elapsed_secs default_DOT_events_DOT_elapsed_secs,
default_DOT_users.snapshot_date default_DOT_date_DOT_dateint
FROM default_DOT_events
LEFT JOIN default_DOT_users
ON default_DOT_events.user_id = default_DOT_users.user_id
AND default_DOT_events.event_start_date = default_DOT_users.snapshot_date
"""
assert str(parse(response_data[0]["sql"])) == str(parse(expected_sql))


@pytest.mark.asyncio
async def test_dimension_link_cross_join(
dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name
Expand Down
10 changes: 10 additions & 0 deletions datajunction-server/tests/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,6 +2258,16 @@
"primary_key": ["country_code"],
},
),
(
"/nodes/dimension/",
{
"description": "Date dimension",
"query": """SELECT 1 AS dateint""",
"mode": "published",
"name": "default.date",
"primary_key": ["dateint"],
},
),
(
"/nodes/metric/",
{
Expand Down
Loading