@@ -165,6 +165,24 @@ binder::expression_vector LogicalHashJoin::getJoinNodeIDs(
165165 return result;
166166}
167167
168+ class JoinNodeIDUniquenessAnalyzer {
169+ public:
170+ bool isUnique (LogicalOperator* op, const binder::Expression& joinNodeID) {
171+ switch (op->getOperatorType ()) {
172+ case LogicalOperatorType::FILTER:
173+ case LogicalOperatorType::FLATTEN:
174+ case LogicalOperatorType::LIMIT:
175+ case LogicalOperatorType::PROJECTION:
176+ case LogicalOperatorType::SEMI_MASKER:
177+ return isUnique (op->getChild (0 ).get (), joinNodeID);
178+ case LogicalOperatorType::SCAN_NODE_TABLE:
179+ return *op->constCast <LogicalScanNodeTable>().getNodeID () == joinNodeID;
180+ default :
181+ return false ;
182+ }
183+ }
184+ };
185+
168186bool LogicalHashJoin::requireFlatProbeKeys () {
169187 // Flatten for multiple join keys.
170188 if (joinConditions.size () > 1 ) {
@@ -179,35 +197,8 @@ bool LogicalHashJoin::requireFlatProbeKeys() {
179197 if (probeKey->dataType .getLogicalTypeID () != LogicalTypeID::INTERNAL_ID) {
180198 return true ;
181199 }
182- return !isJoinKeyUniqueOnBuildSide (*buildKey);
183- }
184-
185- bool LogicalHashJoin::isJoinKeyUniqueOnBuildSide (const binder::Expression& joinNodeID) {
186- auto buildSchema = children[1 ]->getSchema ();
187- auto numGroupsInScope = buildSchema->getGroupsPosInScope ().size ();
188- bool hasProjectedOutGroups = buildSchema->getNumGroups () > numGroupsInScope;
189- if (numGroupsInScope > 1 || hasProjectedOutGroups) {
190- return false ;
191- }
192- // Now there is a single factorization group, we need to further make sure joinNodeID comes from
193- // ScanNodeID operator. Because if joinNodeID comes from a ColExtend we cannot guarantee the
194- // reverse mapping is still many-to-one. We look for the most simple pattern where build plan is
195- // linear.
196- auto op = children[1 ].get ();
197- while (op->getNumChildren () != 0 ) {
198- if (op->getNumChildren () > 1 ) {
199- return false ;
200- }
201- op = op->getChild (0 ).get ();
202- }
203- if (op->getOperatorType () != LogicalOperatorType::SCAN_NODE_TABLE) {
204- return false ;
205- }
206- auto scan = ku_dynamic_cast<LogicalScanNodeTable*>(op);
207- if (scan->getNodeID ()->getUniqueName () != joinNodeID.getUniqueName ()) {
208- return false ;
209- }
210- return true ;
200+ auto anaylzer = JoinNodeIDUniquenessAnalyzer ();
201+ return !anaylzer.isUnique (children[1 ].get (), *buildKey);
211202}
212203
213204} // namespace planner
0 commit comments