diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index ddf214a4b30ac9369013b3aeacdf2a1282f41411..968bbdb1a5f036110b17ce19a24c98f15b422eb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -76,6 +76,4 @@ case class ScalarSubquery(
   override def withNewPlan(plan: LogicalPlan): ScalarSubquery = ScalarSubquery(plan, exprId)
 
   override def toString: String = s"subquery#${exprId.id}"
-
-  // TODO: support sql()
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 13a78c609e0142e323028dc25f3a48d3b8f8fec5..9a14ccff57f83a1482793ba94c8cce8f8f9eed98 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -24,13 +24,22 @@ import scala.util.control.NonFatal
 import org.apache.spark.Logging
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, NonSQLExpression,
-  SortOrder}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.types.{DataType, NullType}
+
+/**
+ * A placeholder for the generated SQL of a subquery expression.
+ */
+case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable {
+  override def dataType: DataType = NullType
+  override def nullable: Boolean = true
+  override def sql: String = s"($query)"
+}
 
 /**
  * A builder class used to convert a resolved logical plan into a SQL query string. Note that this
@@ -46,7 +55,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
   def toSQL: String = {
     val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
     try {
-      canonicalizedPlan.transformAllExpressions {
+      val replaced = canonicalizedPlan.transformAllExpressions {
+        case e: SubqueryExpression =>
+          SubqueryHolder(new SQLBuilder(e.query, sqlContext).toSQL)
         case e: NonSQLExpression =>
           throw new UnsupportedOperationException(
             s"Expression $e doesn't have a SQL representation"
@@ -54,14 +65,14 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
         case e => e
       }
 
-      val generatedSQL = toSQL(canonicalizedPlan, true)
+      val generatedSQL = toSQL(replaced, true)
       logDebug(
         s"""Built SQL query string successfully from given logical plan:
            |
            |# Original logical plan:
            |${logicalPlan.treeString}
            |# Canonicalized logical plan:
-           |${canonicalizedPlan.treeString}
+           |${replaced.treeString}
            |# Generated SQL:
            |$generatedSQL
          """.stripMargin)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
index d68c602a887f7a83df73feb15c04faed1c58906e..72765f05e7e4952888af9c5ec5cfb6674a7ddec5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
@@ -268,4 +268,9 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkSqlGeneration("SELECT input_file_name()")
     checkSqlGeneration("SELECT monotonically_increasing_id()")
   }
+
+  test("subquery") {
+    checkSqlGeneration("SELECT 1 + (SELECT 2)")
+    checkSqlGeneration("SELECT 1 + (SELECT 2 + (SELECT 3 as a))")
+  }
 }
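
The heart of the patch is the recursion in `toSQL`: every `SubqueryExpression` is rendered by a fresh `SQLBuilder`, and the resulting string is wrapped in parentheses by `SubqueryHolder.sql`, which is why the nested query in the new test round-trips. Below is a minimal standalone sketch of that idea, not Spark code; the names `MiniExpr`, `MiniSubquery`, and `buildSql` are hypothetical stand-ins for `Expression`, `ScalarSubquery`/`SubqueryHolder`, and `SQLBuilder.toSQL`.

```scala
// Minimal standalone sketch (not Spark code) of the recursion the patch adds:
// a nested builder renders each subquery "plan", and a placeholder node wraps
// that SQL in parentheses, mirroring SubqueryHolder.sql above.
object SubquerySqlSketch {
  sealed trait MiniExpr { def sql: String }

  case class Lit(v: Int) extends MiniExpr {
    def sql: String = v.toString
  }

  case class Add(left: MiniExpr, right: MiniExpr) extends MiniExpr {
    def sql: String = s"${left.sql} + ${right.sql}"
  }

  // Plays the role of ScalarSubquery: it carries a nested "plan"
  // (here reduced to a single projected expression).
  case class MiniSubquery(projection: MiniExpr) extends MiniExpr {
    // Like SubqueryHolder: delegate to a nested builder and parenthesize the result.
    def sql: String = s"(${buildSql(projection)})"
  }

  // Stand-in for SQLBuilder.toSQL on a one-expression projection.
  def buildSql(e: MiniExpr): String = s"SELECT ${e.sql}"

  def main(args: Array[String]): Unit = {
    // Mirrors the nested test case; prints: SELECT 1 + (SELECT 2 + (SELECT 3))
    val plan = Add(Lit(1), MiniSubquery(Add(Lit(2), MiniSubquery(Lit(3)))))
    println(buildSql(plan))
  }
}
```

The output has the same nesting shape that the new `ExpressionToSQLSuite` cases exercise through `checkSqlGeneration`.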