Commit 600c0b69 authored by Herman van Hovell, committed by Reynold Xin

[SPARK-13713][SQL] Migrate parser from ANTLR3 to ANTLR4

### What changes were proposed in this pull request?
The current ANTLR3 parser is quite complex to maintain and suffers from code blow-ups. This PR introduces a new parser that is based on ANTLR4.

This parser is based on [Presto's SQL parser](https://github.com/facebook/presto/blob/master/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4). The current implementation can parse and create Catalyst and SQL plans. Large parts of the HiveQl DDL and some of the DML functionality are currently missing; the plan is to add these in follow-up PRs.
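
For orientation, a minimal sketch of how the new entry point is exercised (`ParserSketch` is a hypothetical wrapper; `CatalystSqlParser` and `ParseException` are the pieces introduced by this change):

```scala
// Illustrative sketch only; not part of this PR's diff.
import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException}

object ParserSketch {
  def main(args: Array[String]): Unit = {
    // A valid statement is turned into an unresolved Catalyst LogicalPlan.
    val plan = CatalystSqlParser.parsePlan("SELECT a, b FROM db.c WHERE x < 1")
    println(plan.treeString)

    // Syntax errors surface as ParseException, carrying line/position information
    // and the offending command text.
    try CatalystSqlParser.parsePlan("select from tbl")
    catch {
      case e: ParseException =>
        println(s"parse error at line=${e.line} pos=${e.startPosition}")
    }
  }
}
```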

This PR is a work in progress; work remains in the following areas:

- [x] Error handling should be improved.
- [x] Documentation should be improved.
- [x] Multi-Insert needs to be tested.
- [ ] Naming and package locations.

### How was this patch tested?

Catalyst and SQL unit tests.
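
For orientation, the new suites (e.g. `PlanParserSuite` below) parse a SQL string and compare the result against a plan built with the Catalyst DSL. A condensed sketch of that pattern, with `RoundTripSketchSuite` as a hypothetical name:

```scala
// Condensed sketch of the round-trip pattern used by PlanParserSuite in this PR.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.PlanTest

class RoundTripSketchSuite extends PlanTest {
  test("parse and compare against a DSL-built plan") {
    // Parse the SQL text with the new ANTLR4-based parser.
    val parsed = CatalystSqlParser.parsePlan("SELECT * FROM t WHERE x < 1")
    // Build the expected unresolved plan with the Catalyst DSL helpers added in this PR.
    val expected = table("t").where('x < 1).select(star())
    comparePlans(parsed, expected)
  }
}
```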

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #11557 from hvanhovell/ngParser.
Parent 1528ff4c
......@@ -238,6 +238,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
(BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model)
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
(BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://www.antlr.org/)
(BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org)
(BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org)
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)
......
......@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.jar
antlr-runtime-3.5.2.jar
antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
......
......@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
......
......@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
......
......@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
apacheds-i18n-2.0.0-M15.jar
......
......@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
apacheds-i18n-2.0.0-M15.jar
......
......@@ -178,6 +178,7 @@
<jsr305.version>1.3.9</jsr305.version>
<libthrift.version>0.9.2</libthrift.version>
<antlr.version>3.5.2</antlr.version>
<antlr4.version>4.5.2-1</antlr4.version>
<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
......@@ -1759,6 +1760,11 @@
<artifactId>antlr-runtime</artifactId>
<version>${antlr.version}</version>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
<version>${antlr4.version}</version>
</dependency>
</dependencies>
</dependencyManagement>
......
......@@ -25,6 +25,7 @@ import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
import com.simplytyped.Antlr4Plugin._
import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
import com.typesafe.tools.mima.plugin.MimaKeys
......@@ -401,7 +402,10 @@ object OldDeps {
}
object Catalyst {
lazy val settings = Seq(
lazy val settings = antlr4Settings ++ Seq(
antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng"),
antlr4GenListener in Antlr4 := true,
antlr4GenVisitor in Antlr4 := true,
// ANTLR code-generation step.
//
// This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of
......@@ -414,7 +418,7 @@ object Catalyst {
"SparkSqlLexer.g",
"SparkSqlParser.g")
val sourceDir = (sourceDirectory in Compile).value / "antlr3"
val targetDir = (sourceManaged in Compile).value
val targetDir = (sourceManaged in Compile).value / "antlr3"
// Create default ANTLR Tool.
val antlr = new org.antlr.Tool
......
......@@ -23,3 +23,9 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3"
libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3"
libraryDependencies += "org.antlr" % "antlr" % "3.5.2"
// TODO I am not sure we want such a dep.
resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases"
addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10")
......@@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
from pyspark.sql.functions import UserDefinedFunction, sha2
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, IllegalArgumentException
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
class UTCOffsetTimezone(datetime.tzinfo):
......@@ -1130,7 +1130,9 @@ class SQLTests(ReusedPySparkTestCase):
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc"))
def test_capture_parse_exception(self):
self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
......
......@@ -33,6 +33,12 @@ class AnalysisException(CapturedException):
"""
class ParseException(CapturedException):
"""
Failed to parse a SQL command.
"""
class IllegalArgumentException(CapturedException):
"""
Passed an illegal or inappropriate argument.
......@@ -49,6 +55,8 @@ def capture_sql_exception(f):
e.java_exception.getStackTrace()))
if s.startswith('org.apache.spark.sql.AnalysisException: '):
raise AnalysisException(s.split(': ', 1)[1], stackTrace)
if s.startswith('org.apache.spark.sql.catalyst.parser.ng.ParseException: '):
raise ParseException(s.split(': ', 1)[1], stackTrace)
if s.startswith('java.lang.IllegalArgumentException: '):
raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
raise
......
......@@ -75,6 +75,10 @@
<groupId>org.antlr</groupId>
<artifactId>antlr-runtime</artifactId>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
......
......@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
......@@ -161,6 +161,10 @@ package object dsl {
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
def abs(e: Expression): Expression = Abs(e)
def star(names: String*): Expression = names match {
case Seq() => UnresolvedStar(None)
case target => UnresolvedStar(Option(target))
}
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
......@@ -231,6 +235,12 @@ package object dsl {
AttributeReference(s, structType, nullable = true)()
def struct(attrs: AttributeReference*): AttributeReference =
struct(StructType.fromAttributes(attrs))
/** Create a function. */
def function(exprs: Expression*): UnresolvedFunction =
UnresolvedFunction(s, exprs, isDistinct = false)
def distinctFunction(exprs: Expression*): UnresolvedFunction =
UnresolvedFunction(s, exprs, isDistinct = true)
}
implicit class DslAttribute(a: AttributeReference) {
......@@ -243,8 +253,20 @@ package object dsl {
object expressions extends ExpressionConversions // scalastyle:ignore
object plans { // scalastyle:ignore
def table(ref: String): LogicalPlan =
UnresolvedRelation(TableIdentifier(ref), None)
def table(db: String, ref: String): LogicalPlan =
UnresolvedRelation(TableIdentifier(ref, Option(db)), None)
implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) {
def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan)
def select(exprs: Expression*): LogicalPlan = {
val namedExpressions = exprs.map {
case e: NamedExpression => e
case e => UnresolvedAlias(e)
}
Project(namedExpressions, logicalPlan)
}
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
......@@ -296,6 +318,14 @@ package object dsl {
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
def as(alias: String): LogicalPlan = logicalPlan match {
case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
case plan => SubqueryAlias(alias, plan)
}
def distribute(exprs: Expression*): LogicalPlan =
RepartitionByExpression(exprs, logicalPlan)
def analyze: LogicalPlan =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser.ng
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.ParseCancellationException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.types.DataType
/**
* Base SQL parsing infrastructure.
*/
abstract class AbstractSqlParser extends ParserInterface with Logging {
/** Creates/Resolves DataType for a given SQL string. */
def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
// TODO add this to the parser interface.
astBuilder.visitSingleDataType(parser.singleDataType())
}
/** Creates Expression for a given SQL string. */
override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
astBuilder.visitSingleExpression(parser.singleExpression())
}
/** Creates TableIdentifier for a given SQL string. */
override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
}
/** Creates LogicalPlan for a given SQL string. */
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
astBuilder.visitSingleStatement(parser.singleStatement()) match {
case plan: LogicalPlan => plan
case _ => nativeCommand(sqlText)
}
}
/** Get the builder (visitor) which converts a ParseTree into an AST. */
protected def astBuilder: AstBuilder
/** Create a native command, or fail when this is not supported. */
protected def nativeCommand(sqlText: String): LogicalPlan = {
val position = Origin(None, None)
throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
}
protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
logInfo(s"Parsing command: $command")
val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
lexer.removeErrorListeners()
lexer.addErrorListener(ParseErrorListener)
val tokenStream = new CommonTokenStream(lexer)
val parser = new SqlBaseParser(tokenStream)
parser.addParseListener(PostProcessor)
parser.removeErrorListeners()
parser.addErrorListener(ParseErrorListener)
try {
try {
// first, try parsing with potentially faster SLL mode
parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
toResult(parser)
}
catch {
case e: ParseCancellationException =>
// if we fail, parse with LL mode
tokenStream.reset() // rewind input stream
parser.reset()
// Try Again.
parser.getInterpreter.setPredictionMode(PredictionMode.LL)
toResult(parser)
}
}
catch {
case e: ParseException if e.command.isDefined =>
throw e
case e: ParseException =>
throw e.withCommand(command)
case e: AnalysisException =>
val position = Origin(e.line, e.startPosition)
throw new ParseException(Option(command), e.message, position, position)
}
}
}
/**
* Concrete SQL parser for Catalyst-only SQL statements.
*/
object CatalystSqlParser extends AbstractSqlParser {
val astBuilder = new AstBuilder
}
/**
* This string stream provides the lexer with upper case characters only. This greatly simplifies
* lexing the stream, while we can maintain the original command.
*
* This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream
*
* The comment below (taken from the original class) describes the rationale for doing this:
*
* This class provides an implementation for a case insensitive token checker for the lexical
* analysis part of antlr. By converting the token stream into upper case at the time when lexical
* rules are checked, this class ensures that the lexical rules need to just match the token with
* upper case letters as opposed to combination of upper case and lower case characters. This is
* purely used for matching lexical rules. The actual token text is stored in the same way as the
* user input without actually converting it into an upper case. The token values are generated by
* the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead
* function and is purely used for matching lexical rules. This also means that the grammar will
* only accept capitalized tokens in case it is run from other tools like antlrworks which do not
* have the ANTLRNoCaseStringStream implementation.
*/
private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) {
override def LA(i: Int): Int = {
val la = super.LA(i)
if (la == 0 || la == IntStream.EOF) la
else Character.toUpperCase(la)
}
}
/**
* The ParseErrorListener converts parse errors into AnalysisExceptions.
*/
case object ParseErrorListener extends BaseErrorListener {
override def syntaxError(
recognizer: Recognizer[_, _],
offendingSymbol: scala.Any,
line: Int,
charPositionInLine: Int,
msg: String,
e: RecognitionException): Unit = {
val position = Origin(Some(line), Some(charPositionInLine))
throw new ParseException(None, msg, position, position)
}
}
/**
* A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It
* contains fields and an extended error message that make reporting and diagnosing errors easier.
*/
class ParseException(
val command: Option[String],
message: String,
val start: Origin,
val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) {
def this(message: String, ctx: ParserRuleContext) = {
this(Option(ParserUtils.command(ctx)),
message,
ParserUtils.position(ctx.getStart),
ParserUtils.position(ctx.getStop))
}
override def getMessage: String = {
val builder = new StringBuilder
builder ++= "\n" ++= message
start match {
case Origin(Some(l), Some(p)) =>
builder ++= s"(line $l, pos $p)\n"
command.foreach { cmd =>
val (above, below) = cmd.split("\n").splitAt(l)
builder ++= "\n== SQL ==\n"
above.foreach(builder ++= _ += '\n')
builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n"
below.foreach(builder ++= _ += '\n')
}
case _ =>
command.foreach { cmd =>
builder ++= "\n== SQL ==\n" ++= cmd
}
}
builder.toString
}
def withCommand(cmd: String): ParseException = {
new ParseException(Option(cmd), message, start, stop)
}
}
/**
* The post-processor validates and cleans up the parse tree during the parse process.
*/
case object PostProcessor extends SqlBaseBaseListener {
/** Remove the back ticks from an Identifier. */
override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
replaceTokenByIdentifier(ctx, 1) { token =>
// Remove the double back ticks in the string.
token.setText(token.getText.replace("``", "`"))
token
}
}
/** Treat non-reserved keywords as Identifiers. */
override def exitNonReserved(ctx: NonReservedContext): Unit = {
replaceTokenByIdentifier(ctx, 0)(identity)
}
private def replaceTokenByIdentifier(
ctx: ParserRuleContext,
stripMargins: Int)(
f: CommonToken => CommonToken = identity): Unit = {
val parent = ctx.getParent
parent.removeLastChild()
val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
parent.addChild(f(new CommonToken(
new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
SqlBaseParser.IDENTIFIER,
token.getChannel,
token.getStartIndex + stripMargins,
token.getStopIndex - stripMargins)))
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser.ng
import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.TerminalNode
import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
/**
* A collection of utility methods for use during the parsing process.
*/
object ParserUtils {
/** Get the command which created the token. */
def command(ctx: ParserRuleContext): String = {
command(ctx.getStart.getInputStream)
}
/** Get the command which created the token. */
def command(stream: CharStream): String = {
stream.getText(Interval.of(0, stream.size()))
}
/** Get the code that creates the given node. */
def source(ctx: ParserRuleContext): String = {
val stream = ctx.getStart.getInputStream
stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
}
/** Get all the text which comes after the given rule. */
def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)
/** Get all the text which comes after the given token. */
def remainder(token: Token): String = {
val stream = token.getInputStream
val interval = Interval.of(token.getStopIndex + 1, stream.size())
stream.getText(interval)
}
/** Convert a string token into a string. */
def string(token: Token): String = unescapeSQLString(token.getText)
/** Convert a string node into a string. */
def string(node: TerminalNode): String = unescapeSQLString(node.getText)
/** Get the origin (line and position) of the token. */
def position(token: Token): Origin = {
Origin(Option(token.getLine), Option(token.getCharPositionInLine))
}
/** Assert that a condition holds. If it does not, throw a parse exception. */
def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
if (!f) {
throw new ParseException(message, ctx)
}
}
/**
* Register the origin of the context. Any TreeNode created in the closure will be assigned the
* registered origin. This method restores the previously set origin after completion of the
* closure.
*/
def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = {
val current = CurrentOrigin.get
CurrentOrigin.set(position(ctx.getStart))
try {
f
} finally {
CurrentOrigin.set(current)
}
}
/** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
/**
* Create a plan using the block of code when the given context exists. Otherwise return the
* original plan.
*/
def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = {
if (ctx != null) {
f
} else {
plan
}
}
/**
* Map a [[LogicalPlan]] to another [[LogicalPlan]] using the passed function, if the passed
* context exists. The original plan is returned when the context does not exist.
*/
def optionalMap[C <: ParserRuleContext](
ctx: C)(
f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
if (ctx != null) {
f(ctx, plan)
} else {
plan
}
}
}
}
......@@ -21,15 +21,20 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.unsafe.types.CalendarInterval
class CatalystQlSuite extends PlanTest {
val parser = new CatalystQl()
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
val star = UnresolvedAlias(UnresolvedStar(None))
test("test case insensitive") {
val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation)
val result = OneRowRelation.select(1)
assert(result === parser.parsePlan("seLect 1"))
assert(result === parser.parsePlan("select 1"))
assert(result === parser.parsePlan("SELECT 1"))
......@@ -37,52 +42,31 @@ class CatalystQlSuite extends PlanTest {
test("test NOT operator with comparison operations") {
val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE")
val expected = Project(
UnresolvedAlias(
Not(
GreaterThan(Literal(true), Literal(true)))
) :: Nil,
OneRowRelation)
val expected = OneRowRelation.select(Not(GreaterThan(true, true)))
comparePlans(parsed, expected)
}
test("test Union Distinct operator") {
val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1")
val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1")
val expected =
Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
SubqueryAlias("u_1",
Distinct(
Union(
Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
UnresolvedRelation(TableIdentifier("t0"), None)),
Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
UnresolvedRelation(TableIdentifier("t1"), None))))))
val parsed1 = parser.parsePlan(
"SELECT * FROM t0 UNION SELECT * FROM t1")
val parsed2 = parser.parsePlan(
"SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1")
val expected = Distinct(Union(table("t0").select(star), table("t1").select(star)))
.as("u_1").select(star)
comparePlans(parsed1, expected)
comparePlans(parsed2, expected)
}
test("test Union All operator") {
val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1")
val expected =
Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
SubqueryAlias("u_1",
Union(
Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
UnresolvedRelation(TableIdentifier("t0"), None)),
Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
UnresolvedRelation(TableIdentifier("t1"), None)))))
val expected = Union(table("t0").select(star), table("t1").select(star)).as("u_1").select(star)
comparePlans(parsed, expected)
}
test("support hive interval literal") {
def checkInterval(sql: String, result: CalendarInterval): Unit = {
val parsed = parser.parsePlan(sql)
val expected = Project(
UnresolvedAlias(
Literal(result)
) :: Nil,
OneRowRelation)
val expected = OneRowRelation.select(Literal(result))
comparePlans(parsed, expected)
}
......@@ -129,11 +113,7 @@ class CatalystQlSuite extends PlanTest {
test("support scientific notation") {
def assertRight(input: String, output: Double): Unit = {
val parsed = parser.parsePlan("SELECT " + input)
val expected = Project(
UnresolvedAlias(
Literal(output)
) :: Nil,
OneRowRelation)
val expected = OneRowRelation.select(Literal(output))
comparePlans(parsed, expected)
}
......
......@@ -18,19 +18,24 @@
package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.types._
class DataTypeParserSuite extends SparkFunSuite {
abstract class AbstractDataTypeParserSuite extends SparkFunSuite {
def parse(sql: String): DataType
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
assert(DataTypeParser.parse(dataTypeString) === expectedDataType)
assert(parse(dataTypeString) === expectedDataType)
}
}
def intercept(sql: String)
def unsupported(dataTypeString: String): Unit = {
test(s"$dataTypeString is not supported") {
intercept[DataTypeException](DataTypeParser.parse(dataTypeString))
intercept(dataTypeString)
}
}
......@@ -97,13 +102,6 @@ class DataTypeParserSuite extends SparkFunSuite {
StructField("arrAy", ArrayType(DoubleType, true), true) ::
StructField("anotherArray", ArrayType(StringType, true), true) :: Nil)
)
// A column name can be a reserved word in our DDL parser and SqlParser.
checkDataType(
"Struct<TABLE: string, CASE:boolean>",
StructType(
StructField("TABLE", StringType, true) ::
StructField("CASE", BooleanType, true) :: Nil)
)
// Use backticks to quote column names having special characters.
checkDataType(
"struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>",
......@@ -118,6 +116,43 @@ class DataTypeParserSuite extends SparkFunSuite {
unsupported("it is not a data type")
unsupported("struct<x+y: int, 1.1:timestamp>")
unsupported("struct<x: int")
}
class DataTypeParserSuite extends AbstractDataTypeParserSuite {
override def intercept(sql: String): Unit =
intercept[DataTypeException](DataTypeParser.parse(sql))
override def parse(sql: String): DataType =
DataTypeParser.parse(sql)
// A column name can be a reserved word in our DDL parser and SqlParser.
checkDataType(
"Struct<TABLE: string, CASE:boolean>",
StructType(
StructField("TABLE", StringType, true) ::
StructField("CASE", BooleanType, true) :: Nil)
)
unsupported("struct<x int, y string>")
unsupported("struct<`x``y` int>")
}
class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite {
override def intercept(sql: String): Unit =
intercept[ParseException](CatalystSqlParser.parseDataType(sql))
override def parse(sql: String): DataType =
CatalystSqlParser.parseDataType(sql)
// A column name can be a reserved word in our DDL parser and SqlParser.
unsupported("Struct<TABLE: string, CASE:boolean>")
checkDataType(
"struct<x int, y string>",
(new StructType).add("x", IntegerType).add("y", StringType))
checkDataType(
"struct<`x``y` int>",
(new StructType).add("x`y", IntegerType))
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser.ng
import org.apache.spark.SparkFunSuite
/**
* Test various parser errors.
*/
class ErrorParserSuite extends SparkFunSuite {
def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = {
val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql))
// Check position.
assert(e.line.isDefined)
assert(e.line.get === line)
assert(e.startPosition.isDefined)
assert(e.startPosition.get === startPosition)
// Check messages.
val error = e.getMessage
messages.foreach { message =>
assert(error.contains(message))
}
}
test("no viable input") {
intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^")
intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^")
intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^")
}
test("extraneous input") {
intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^")
intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^")
}
test("mismatched input") {
intercept("select * from r order by q from t", 1, 27,
"mismatched input",
"---------------------------^^^")
intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^")
}
test("semantic errors") {
intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
"Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported",
"^^^")
intercept("select * from r where a in (select * from t)", 1, 24,
"IN with a Sub-query is currently not supported",
"------------------------^^^")
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser.ng
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
/**
* Test basic expression parsing. If a type of expression is supported it should be tested here.
*
* Please note that some of the expressions tested don't have to be sound expressions, only their
* structure needs to be valid. Unsound expressions should be caught by the Analyzer or
* CheckAnalysis classes.
*/
class ExpressionParserSuite extends PlanTest {
import CatalystSqlParser._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
def assertEqual(sqlCommand: String, e: Expression): Unit = {
compareExpressions(parseExpression(sqlCommand), e)
}
def intercept(sqlCommand: String, messages: String*): Unit = {
val e = intercept[ParseException](parseExpression(sqlCommand))
messages.foreach { message =>
assert(e.message.contains(message))
}
}
test("star expressions") {
// Global Star
assertEqual("*", UnresolvedStar(None))
// Targeted Star
assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b"))))
}
// NamedExpression (Alias/Multialias)
test("named expressions") {
// No Alias
val r0 = 'a
assertEqual("a", r0)
// Single Alias.
val r1 = 'a as "b"
assertEqual("a as b", r1)
assertEqual("a b", r1)
// Multi-Alias
assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c")))
assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c")))
// Numeric literals without a space between the literal qualifier and the alias should not be
// interpreted as such. An unresolved reference should be returned instead.
// TODO add the JIRA-ticket number.
assertEqual("1SL", Symbol("1SL"))
// Aliased star is allowed.
assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b)
}
test("binary logical expressions") {
// And
assertEqual("a and b", 'a && 'b)
// Or
assertEqual("a or b", 'a || 'b)
// Combination And/Or check precedence
assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd))
assertEqual("a or b or c and d", 'a || 'b || ('c && 'd))
// Multiple AND/OR get converted into a balanced tree
assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f))
assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f))
}
test("long binary logical expressions") {
def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
val e = parseExpression(sql)
assert(e.collect { case _: EqualTo => true }.size === 1000)
assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
}
testVeryBinaryExpression(" AND ", classOf[And])
testVeryBinaryExpression(" OR ", classOf[Or])
}
test("not expressions") {
assertEqual("not a", !'a)
assertEqual("!a", !'a)
assertEqual("not true > true", Not(GreaterThan(true, true)))
}
test("exists expression") {
intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported")
}
test("comparison expressions") {
assertEqual("a = b", 'a === 'b)
assertEqual("a == b", 'a === 'b)
assertEqual("a <=> b", 'a <=> 'b)
assertEqual("a <> b", 'a =!= 'b)
assertEqual("a != b", 'a =!= 'b)
assertEqual("a < b", 'a < 'b)
assertEqual("a <= b", 'a <= 'b)
assertEqual("a > b", 'a > 'b)
assertEqual("a >= b", 'a >= 'b)
}
test("between expressions") {
assertEqual("a between b and c", 'a >= 'b && 'a <= 'c)
assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c))
}
test("in expressions") {
assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd))
assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd)))
}
test("in sub-query") {
intercept("a in (select b from c)", "IN with a Sub-query is currently not supported")
}
test("like expressions") {
assertEqual("a like 'pattern%'", 'a like "pattern%")
assertEqual("a not like 'pattern%'", !('a like "pattern%"))
assertEqual("a rlike 'pattern%'", 'a rlike "pattern%")
assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%"))
assertEqual("a regexp 'pattern%'", 'a rlike "pattern%")
assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
}
test("is null expressions") {
assertEqual("a is null", 'a.isNull)
assertEqual("a is not null", 'a.isNotNull)
assertEqual("a = b is null", ('a === 'b).isNull)
assertEqual("a = b is not null", ('a === 'b).isNotNull)
}
test("binary arithmetic expressions") {
// Simple operations
assertEqual("a * b", 'a * 'b)
assertEqual("a / b", 'a / 'b)
assertEqual("a DIV b", ('a / 'b).cast(LongType))
assertEqual("a % b", 'a % 'b)
assertEqual("a + b", 'a + 'b)
assertEqual("a - b", 'a - 'b)
assertEqual("a & b", 'a & 'b)
assertEqual("a ^ b", 'a ^ 'b)
assertEqual("a | b", 'a | 'b)
// Check precedences
assertEqual(
"a * t | b ^ c & d - e + f % g DIV h / i * k",
'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k)))))
}
test("unary arithmetic expressions") {
assertEqual("+a", 'a)
assertEqual("-a", -'a)
assertEqual("~a", ~'a)
assertEqual("-+~~a", -(~(~'a)))
}
test("cast expressions") {
// Note that DataType parsing is tested elsewhere.
assertEqual("cast(a as int)", 'a.cast(IntegerType))
assertEqual("cast(a as timestamp)", 'a.cast(TimestampType))
assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType)))
assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType))
}
test("function expressions") {
assertEqual("foo()", 'foo.function())
assertEqual("foo.bar()", Symbol("foo.bar").function())
assertEqual("foo(*)", 'foo.function(star()))
assertEqual("count(*)", 'count.function(1))
assertEqual("foo(a, b)", 'foo.function('a, 'b))
assertEqual("foo(all a, b)", 'foo.function('a, 'b))
assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b))
assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b))
assertEqual("`select`(all a, b)", 'select.function('a, 'b))
}
test("window function expressions") {
val func = 'foo.function(star())
def windowed(
partitioning: Seq[Expression] = Seq.empty,
ordering: Seq[SortOrder] = Seq.empty,
frame: WindowFrame = UnspecifiedFrame): Expression = {
WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame))
}
// Basic window testing.
assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1")))
assertEqual("foo(*) over ()", windowed())
assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b)))
assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b)))
assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b)))
assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc)))
assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc)))
// Test use of expressions in window functions.
assertEqual(
"sum(product + 1) over (partition by ((product) + (1)) order by 2)",
WindowExpression('sum.function('product + 1),
WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
assertEqual(
"sum(product + 1) over (partition by ((product / 2) + 1) order by 2)",
WindowExpression('sum.function('product + 1),
WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
// Range/Row
val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame))
val boundaries = Seq(
("10 preceding", ValuePreceding(10), CurrentRow),
("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis
("unbounded preceding", UnboundedPreceding, CurrentRow),
("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis
("between unbounded preceding and current row", UnboundedPreceding, CurrentRow),
("between unbounded preceding and unbounded following",
UnboundedPreceding, UnboundedFollowing),
("between 10 preceding and current row", ValuePreceding(10), CurrentRow),
("between current row and 5 following", CurrentRow, ValueFollowing(5)),
("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5))
)
frameTypes.foreach {
case (frameTypeSql, frameType) =>
boundaries.foreach {
case (boundarySql, begin, end) =>
val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)"
val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end))
assertEqual(query, expr)
}
}
// We cannot use non integer constants.
intercept("foo(*) over (partition by a order by b rows 10.0 preceding)",
"Frame bound value must be a constant integer.")
// We cannot use an arbitrary expression.
intercept("foo(*) over (partition by a order by b rows exp(b) preceding)",
"Frame bound value must be a constant integer.")
}
test("row constructor") {
// Note that '(a)' will be interpreted as a nested expression.
assertEqual("(a, b)", CreateStruct(Seq('a, 'b)))
assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c)))
}
test("scalar sub-query") {
assertEqual(
"(select max(val) from tbl) > current",
ScalarSubquery(table("tbl").select('max.function('val))) > 'current)
assertEqual(
"a = (select b from s)",
'a === ScalarSubquery(table("s").select('b)))
}
test("case when") {
assertEqual("case a when 1 then b when 2 then c else d end",
CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd)))
assertEqual("case when a = 1 then b when a = 2 then c else d end",
CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd))
}
test("dereference") {
assertEqual("a.b", UnresolvedAttribute("a.b"))
assertEqual("`select`.b", UnresolvedAttribute("select.b"))
assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis.
assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b"))
}
test("reference") {
// Regular
assertEqual("a", 'a)
// Starting with a digit.
assertEqual("1a", Symbol("1a"))
// Quoted using a keyword.
assertEqual("`select`", 'select)
// Unquoted using an unreserved keyword.
assertEqual("columns", 'columns)
}
test("subscript") {
assertEqual("a[b]", 'a.getItem('b))
assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1))
assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b))
}
test("parenthesis") {
assertEqual("(a)", 'a)
assertEqual("r * (a + b)", 'r * ('a + 'b))
}
test("type constructors") {
// Dates.
assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11")))
intercept[IllegalArgumentException] {
parseExpression("DAtE 'mar 11 2016'")
}
// Timestamps.
assertEqual("tImEstAmp '2016-03-11 20:54:00.000'",
Literal(Timestamp.valueOf("2016-03-11 20:54:00.000")))
intercept[IllegalArgumentException] {
parseExpression("timestamP '2016-33-11 20:54:00.000'")
}
// Unsupported datatype.
intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.")
}
test("literals") {
// NULL
assertEqual("null", Literal(null))
// Boolean
assertEqual("trUe", Literal(true))
assertEqual("False", Literal(false))
// Integral should have the narrowest possible type
assertEqual("787324", Literal(787324))
assertEqual("7873247234798249234", Literal(7873247234798249234L))
assertEqual("78732472347982492793712334",
Literal(BigDecimal("78732472347982492793712334").underlying()))
// Decimal
assertEqual("7873247234798249279371.2334",
Literal(BigDecimal("7873247234798249279371.2334").underlying()))
// Scientific Decimal
assertEqual("9.0e1", 90d)
assertEqual(".9e+2", 90d)
assertEqual("0.9e+2", 90d)
assertEqual("900e-1", 90d)
assertEqual("900.0E-1", 90d)
assertEqual("9.e+1", 90d)
intercept(".e3")
// Tiny Int Literal
assertEqual("10Y", Literal(10.toByte))
intercept("-1000Y")
// Small Int Literal
assertEqual("10S", Literal(10.toShort))
intercept("40000S")
// Long Int Literal
assertEqual("10L", Literal(10L))
intercept("78732472347982492793712334L")
// Double Literal
assertEqual("10.0D", Literal(10.0D))
// TODO we need to figure out if we should throw an exception here!
assertEqual("1E309", Literal(Double.PositiveInfinity))
}
test("strings") {
// Single Strings.
assertEqual("\"hello\"", "hello")
assertEqual("'hello'", "hello")
// Multi-Strings.
assertEqual("\"hello\" 'world'", "helloworld")
assertEqual("'hello' \" \" 'world'", "hello world")
// 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
// regular '%'; to get the correct result you need to add another escaped '\'.
// TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
assertEqual("'pattern%'", "pattern%")
assertEqual("'no-pattern\\%'", "no-pattern\\%")
assertEqual("'pattern\\\\%'", "pattern\\%")
assertEqual("'pattern\\\\\\%'", "pattern\\\\%")
// Escaped characters.
// See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00')
assertEqual("'\\''", "\'") // Single quote
assertEqual("'\\\"'", "\"") // Double quote
assertEqual("'\\b'", "\b") // Backspace
assertEqual("'\\n'", "\n") // Newline
assertEqual("'\\r'", "\r") // Carriage return
assertEqual("'\\t'", "\t") // Tab character
assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows)
// Octals
assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
// Unicode
assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)")
}
test("intervals") {
def intervalLiteral(u: String, s: String): Literal = {
Literal(CalendarInterval.fromSingleUnitString(u, s))
}
// Empty interval statement
intercept("interval", "at least one time unit should be given for interval literal")
// Single Intervals.
val units = Seq(
"year",
"month",
"week",
"day",
"hour",
"minute",
"second",
"millisecond",
"microsecond")
val forms = Seq("", "s")
val values = Seq("0", "10", "-7", "21")
units.foreach { unit =>
forms.foreach { form =>
values.foreach { value =>
val expected = intervalLiteral(unit, value)
assertEqual(s"interval $value $unit$form", expected)
assertEqual(s"interval '$value' $unit$form", expected)
}
}
}
// Hive nanosecond notation.
assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789"))
assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789"))
// Non Existing unit
intercept("interval 10 nanoseconds", "No interval can be constructed")
// Year-Month intervals.
val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0")
yearMonthValues.foreach { value =>
val result = Literal(CalendarInterval.fromYearMonthString(value))
assertEqual(s"interval '$value' year to month", result)
}
// Day-Time intervals.
val dayTimeValues = Seq(
"99 11:22:33.123456789",
"-99 11:22:33.123456789",
"10 9:8:7.123456789",
"1 0:0:0",
"-1 0:0:0",
"1 0:0:1")
dayTimeValues.foreach { value =>
val result = Literal(CalendarInterval.fromDayTimeString(value))
assertEqual(s"interval '$value' day to second", result)
}
// Unknown FROM TO intervals
intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.")
// Composed intervals.
assertEqual(
"interval 3 months 22 seconds 1 millisecond",
Literal(new CalendarInterval(3, 22001000L)))
assertEqual(
"interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second",
Literal(new CalendarInterval(14,
22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND)))
}
test("composed expressions") {
assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q"))
assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar)))
intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'")
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser.ng
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{BooleanType, IntegerType}
class PlanParserSuite extends PlanTest {
import CatalystSqlParser._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
comparePlans(parsePlan(sqlCommand), plan)
}
def intercept(sqlCommand: String, messages: String*): Unit = {
val e = intercept[ParseException](parsePlan(sqlCommand))
messages.foreach { message =>
assert(e.message.contains(message))
}
}
test("case insensitive") {
val plan = table("a").select(star())
assertEqual("sELEct * FroM a", plan)
assertEqual("select * fRoM a", plan)
assertEqual("SELECT * FROM a", plan)
}
test("show functions") {
assertEqual("show functions", ShowFunctions(None, None))
assertEqual("show functions foo", ShowFunctions(None, Some("foo")))
assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar")))
assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*")))
intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name")
}
test("describe function") {
assertEqual("describe function bar", DescribeFunction("bar", isExtended = false))
assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true))
assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false))
assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true))
}
test("set operations") {
val a = table("a").select(star())
val b = table("b").select(star())
assertEqual("select * from a union select * from b", Distinct(a.union(b)))
assertEqual("select * from a union distinct select * from b", Distinct(a.union(b)))
assertEqual("select * from a union all select * from b", a.union(b))
assertEqual("select * from a except select * from b", a.except(b))
intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.")
assertEqual("select * from a except distinct select * from b", a.except(b))
assertEqual("select * from a intersect select * from b", a.intersect(b))
intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.")
assertEqual("select * from a intersect distinct select * from b", a.intersect(b))
}
test("common table expressions") {
def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = {
val ctes = namedPlans.map {
case (name, cte) =>
name -> SubqueryAlias(name, cte)
}.toMap
With(plan, ctes)
}
assertEqual(
"with cte1 as (select * from a) select * from cte1",
cte(table("cte1").select(star()), "cte1" -> table("a").select(star())))
assertEqual(
"with cte1 (select 1) select * from cte1",
cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1)))
assertEqual(
"with cte1 (select 1), cte2 as (select * from cte1) select * from cte2",
cte(table("cte2").select(star()),
"cte1" -> OneRowRelation.select(1),
"cte2" -> table("cte1").select(star())))
intercept(
"with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1",
"Name 'cte1' is used for multiple common table expressions")
}
test("simple select query") {
assertEqual("select 1", OneRowRelation.select(1))
assertEqual("select a, b", OneRowRelation.select('a, 'b))
assertEqual("select a, b from db.c", table("db", "c").select('a, 'b))
assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
assertEqual(
"select a, b from db.c having x < 1",
table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType)))
assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
}
test("reverse select query") {
assertEqual("from a", table("a"))
assertEqual("from a select b, c", table("a").select('b, 'c))
assertEqual(
"from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c))
assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c)))
assertEqual(
"from (from a union all from b) c select *",
table("a").union(table("b")).as("c").select(star()))
}
test("transform query spec") {
val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null)
assertEqual("select transform(a, b) using 'func' from e where f < 10",
p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string)))
assertEqual("map a, b using 'func' as c, d from e",
p.copy(output = Seq('c.string, 'd.string)))
assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e",
p.copy(output = Seq('c.int, 'd.decimal(10, 0))))
}
test("multi select query") {
assertEqual(
"from a select * select * where s < 10",
table("a").select(star()).union(table("a").where('s < 10).select(star())))
intercept(
"from a select * select * from x where a.s < 10",
"Multi-Insert queries cannot have a FROM clause in their individual SELECT statements")
assertEqual(
"from a insert into tbl1 select * insert into tbl2 select * where s < 10",
table("a").select(star()).insertInto("tbl1").union(
table("a").where('s < 10).select(star()).insertInto("tbl2")))
}
test("query organization") {
// Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
val baseSql = "select * from t"
val basePlan = table("t").select(star())
val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame))
val limitWindowClauses = Seq(
("", (p: LogicalPlan) => p),
(" limit 10", (p: LogicalPlan) => p.limit(10)),
(" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)),
(" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10))
)
val orderSortDistrClusterClauses = Seq(
("", basePlan),
(" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)),
(" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)),
(" distribute by a, b", basePlan.distribute('a, 'b)),
(" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)),
(" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc))
)
orderSortDistrClusterClauses.foreach {
case (s1, p1) =>
limitWindowClauses.foreach {
case (s2, pf2) =>
assertEqual(baseSql + s1 + s2, pf2(p1))
}
}
val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported"
intercept(s"$baseSql order by a sort by a", msg)
intercept(s"$baseSql cluster by a distribute by a", msg)
intercept(s"$baseSql order by a cluster by a", msg)
intercept(s"$baseSql order by a distribute by a", msg)
}
test("insert into") {
val sql = "select * from t"
val plan = table("t").select(star())
def insert(
partition: Map[String, Option[String]],
overwrite: Boolean = false,
ifNotExists: Boolean = false): LogicalPlan =
InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)
// Single inserts
assertEqual(s"insert overwrite table s $sql",
insert(Map.empty, overwrite = true))
assertEqual(s"insert overwrite table s if not exists $sql",
insert(Map.empty, overwrite = true, ifNotExists = true))
assertEqual(s"insert into s $sql",
insert(Map.empty))
assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql",
insert(Map("c" -> Option("d"), "e" -> Option("1"))))
assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql",
insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true))
// Multi insert
val plan2 = table("t").where('x > 5).select(star())
assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
InsertIntoTable(
table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union(
InsertIntoTable(
table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false)))
}
test("aggregation") {
val sql = "select a, b, sum(c) as c from d group by a, b"
// Normal
assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c")))
// Cube
assertEqual(s"$sql with cube",
table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
// Rollup
assertEqual(s"$sql with rollup",
table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
// Grouping Sets
assertEqual(s"$sql grouping sets((a, b), (a), ())",
GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c"))))
intercept(s"$sql grouping sets((a, b), (c), ())",
"c doesn't show up in the GROUP BY list")
}
test("limit") {
val sql = "select * from t"
val plan = table("t").select(star())
assertEqual(s"$sql limit 10", plan.limit(10))
assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType)))
}
test("window spec") {
// Note that WindowSpecs are tested in the ExpressionParserSuite
val sql = "select * from t"
val plan = table("t").select(star())
val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc),
SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1)))
// Test window resolution.
val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec)
assertEqual(
s"""$sql
|window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
| w2 as w1,
| w3 as w1""".stripMargin,
WithWindowDefinition(ws1, plan))
// Fail with no reference.
intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'")
// Fail when resolved reference is not a window spec.
intercept(
s"""$sql
|window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
| w2 as w1,
| w3 as w2""".stripMargin,
"Window reference 'w2' is not a window specification"
)
}
test("lateral view") {
// Single lateral view
assertEqual(
"select * from t lateral view explode(x) expl as x",
table("t")
.generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
.select(star()))
// Multiple lateral views
assertEqual(
"""select *
|from t
|lateral view explode(x) expl
|lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
table("t")
.generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty)
.generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z"))
.select(star()))
// Multi-Insert lateral views.
val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
assertEqual(
"""from t1
|lateral view explode(x) expl as x
|insert into t2
|select *
|lateral view json_tuple(x, y) jtup q, z
|insert into t3
|select *
|where s < 10
""".stripMargin,
Union(from
.generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z"))
.select(star())
.insertInto("t2"),
from.where('s < 10).select(star()).insertInto("t3")))
// Unsupported generator.
intercept(
"select * from t lateral view posexplode(x) posexpl as x, y",
"Generator function 'posexplode' is not supported")
}
test("joins") {
// Test single joins.
val testUnconditionalJoin = (sql: String, jt: JoinType) => {
assertEqual(
s"select * from t as tt $sql u",
table("t").as("tt").join(table("u"), jt, None).select(star()))
}
val testConditionalJoin = (sql: String, jt: JoinType) => {
assertEqual(
s"select * from t $sql u as uu on a = b",
table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star()))
}
val testNaturalJoin = (sql: String, jt: JoinType) => {
assertEqual(
s"select * from t tt natural $sql u as uu",
table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star()))
}
val testUsingJoin = (sql: String, jt: JoinType) => {
assertEqual(
s"select * from t $sql u using(a, b)",
table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star()))
}
val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin)
def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = {
tests.foreach(_(sql, jt))
}
test("cross join", Inner, Seq(testUnconditionalJoin))
test(",", Inner, Seq(testUnconditionalJoin))
test("join", Inner, testAll)
test("inner join", Inner, testAll)
test("left join", LeftOuter, testAll)
test("left outer join", LeftOuter, testAll)
test("right join", RightOuter, testAll)
test("right outer join", RightOuter, testAll)
test("full join", FullOuter, testAll)
test("full outer join", FullOuter, testAll)
// Test multiple consecutive joins
assertEqual(
"select * from a join b join c right join d",
table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
}
test("sampled relations") {
val sql = "select * from t"
assertEqual(s"$sql tablesample(100 rows)",
table("t").limit(100).select(star()))
assertEqual(s"$sql tablesample(43 percent) as x",
Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
assertEqual(s"$sql tablesample(bucket 4 out of 10) as x",
Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x",
"TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported")
intercept(s"$sql tablesample(bucket 11 out of 10) as x",
s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]")
}
test("sub-query") {
val plan = table("t0").select('id)
assertEqual("select id from (t0)", plan)
assertEqual("select id from ((((((t0))))))", plan)
assertEqual(
"(select * from t1) union distinct (select * from t2)",
Distinct(table("t1").select(star()).union(table("t2").select(star()))))
assertEqual(
"select * from ((select * from t1) union (select * from t2)) t",
Distinct(
table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star()))
assertEqual(
"""select id
|from (((select id from t0)
| union all
| (select id from t0))
| union all
| (select id from t0)) as u_1
""".stripMargin,
plan.union(plan).union(plan).as("u_1").select('id))
}
test("scalar sub-query") {
assertEqual(
"select (select max(b) from s) ss from t",
table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss")))
assertEqual(
"select * from t where a = (select b from s)",
table("t").where('a === ScalarSubquery(table("s").select('b))).select(star()))
assertEqual(
"select g from t group by g having a > (select b from s)",
table("t")
.groupBy('g)('g)
.where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType)))
}
test("table reference") {
assertEqual("table t", table("t"))
assertEqual("table d.t", table("d", "t"))
}
test("inline table") {
assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
Seq('col1.int),
Seq(1, 2, 3, 4).map(x => Row(x))))
assertEqual(
"values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)",
LocalRelation.fromExternalRows(
Seq('a.int, 'b.string),
Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl"))
intercept("values (a, 'a'), (b, 'b')",
"All expressions in an inline table must be constants.")
intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)",
"Number of aliases must match the number of fields in an inline table.")
intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)"))
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser.ng
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
class TableIdentifierParserSuite extends SparkFunSuite {
import CatalystSqlParser._
test("table identifier") {
// Regular names.
assert(TableIdentifier("q") === parseTableIdentifier("q"))
assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q"))
// Illegal names.
intercept[ParseException](parseTableIdentifier(""))
intercept[ParseException](parseTableIdentifier("d.q.g"))
// SQL Keywords.
val keywords = Seq("select", "from", "where", "left", "right")
keywords.foreach { keyword =>
intercept[ParseException](parseTableIdentifier(keyword))
assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`"))
assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`"))
}
}
}
......@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.util._
/**
......@@ -32,6 +32,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
plan transformAllExpressions {
case s: ScalarSubquery =>
ScalarSubquery(s.query, ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
......@@ -40,21 +42,25 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
}
/**
* Normalizes the filter conditions that appear in the plan. For instance,
* ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
* etc., will all now be equivalent.
* Normalizes plans:
 * - Filter: the filter conditions that appear in a plan, e.g.
 *   ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)),
 *   etc., will all be normalized to an equivalent form.
 * - Sample: the seed is replaced by 0L.
*/
private def normalizeFilters(plan: LogicalPlan) = {
private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
case sample: Sample =>
sample.copy(seed = 0L)(true)
}
}
/** Fails the test if the two plans do not match */
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
val normalized1 = normalizeFilters(normalizeExprIds(plan1))
val normalized2 = normalizeFilters(normalizeExprIds(plan2))
val normalized1 = normalizePlan(normalizeExprIds(plan1))
val normalized2 = normalizePlan(normalizeExprIds(plan2))
if (normalized1 != normalized2) {
fail(
s"""
......
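A minimal sketch of what the new seed normalization buys in plan-comparison tests (the relation, fraction and seeds below are made up; Sample's argument order follows the nodes used in the parser tests above): two plans that differ only in their sampling seed now compare as equal.

  val rel = UnresolvedRelation(TableIdentifier("t"), None)
  val s1 = Sample(0, 0.1, withReplacement = false, 123L, rel)(true)
  val s2 = Sample(0, 0.1, withReplacement = false, 456L, rel)(true)
  comparePlans(s1, s2) // succeeds: normalizePlan rewrites both seeds to 0L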
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import scala.collection.JavaConverters._
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder}
import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.execution.command.{DescribeCommand => _, _}
import org.apache.spark.sql.execution.datasources._
/**
* Concrete parser for Spark SQL statements.
*/
object SparkSqlParser extends AbstractSqlParser {
val astBuilder = new SparkSqlAstBuilder
}
/**
* Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
*/
class SparkSqlAstBuilder extends AstBuilder {
import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._
/**
* Create a [[SetCommand]] logical plan.
*
 * Note that everything after the SET keyword is assumed to be part of the
* key-value pair. The split between key and value is made by searching for the first `=`
* character in the raw string.
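 *
 * For example (illustrative inputs; the configuration key is hypothetical):
 *   "SET spark.sql.key=some value"  =>  SetCommand(Some("spark.sql.key" -> Some("some value")))
 *   "SET spark.sql.key"             =>  SetCommand(Some("spark.sql.key" -> None))
 *   "SET"                           =>  SetCommand(None)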
*/
override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) {
// Construct the command.
val raw = remainder(ctx.SET.getSymbol)
val keyValueSeparatorIndex = raw.indexOf('=')
if (keyValueSeparatorIndex >= 0) {
val key = raw.substring(0, keyValueSeparatorIndex).trim
val value = raw.substring(keyValueSeparatorIndex + 1).trim
SetCommand(Some(key -> Option(value)))
} else if (raw.nonEmpty) {
SetCommand(Some(raw.trim -> None))
} else {
SetCommand(None)
}
}
/**
* Create a [[SetDatabaseCommand]] logical plan.
*/
override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) {
SetDatabaseCommand(ctx.db.getText)
}
/**
* Create a [[ShowTablesCommand]] logical plan.
*/
override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) {
if (ctx.LIKE != null) {
logWarning("SHOW TABLES LIKE option is ignored.")
}
ShowTablesCommand(Option(ctx.db).map(_.getText))
}
/**
* Create a [[RefreshTable]] logical plan.
*/
override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) {
RefreshTable(visitTableIdentifier(ctx.tableIdentifier))
}
/**
* Create a [[CacheTableCommand]] logical plan.
*/
override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
val query = Option(ctx.query).map(plan)
CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null)
}
/**
* Create an [[UncacheTableCommand]] logical plan.
*/
override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) {
UncacheTableCommand(ctx.identifier.getText)
}
/**
* Create a [[ClearCacheCommand]] logical plan.
*/
override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) {
ClearCacheCommand
}
/**
* Create an [[ExplainCommand]] logical plan.
*/
override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) {
val options = ctx.explainOption.asScala
if (options.exists(_.FORMATTED != null)) {
logWarning("EXPLAIN FORMATTED option is ignored.")
}
if (options.exists(_.LOGICAL != null)) {
logWarning("EXPLAIN LOGICAL option is ignored.")
}
    // Create the explain command.
val statement = plan(ctx.statement)
if (isExplainableStatement(statement)) {
ExplainCommand(statement, extended = options.exists(_.EXTENDED != null))
} else {
ExplainCommand(OneRowRelation)
}
}
/**
* Determine if a plan should be explained at all.
*/
protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match {
case _: datasources.DescribeCommand => false
case _ => true
}
/**
* Create a [[DescribeCommand]] logical plan.
*/
override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) {
// FORMATTED and columns are not supported. Return null and let the parser decide what to do
    // with this (raise an exception or pass it on to a different system).
if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) {
null
} else {
datasources.DescribeCommand(
visitTableIdentifier(ctx.tableIdentifier),
ctx.EXTENDED != null)
}
}
  /** Type to keep track of a table header: (table identifier, temporary, ifNotExists, external). */
type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean)
/**
* Validate a create table statement and return the [[TableIdentifier]].
*/
override def visitCreateTableHeader(
ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
val temporary = ctx.TEMPORARY != null
val ifNotExists = ctx.EXISTS != null
assert(!temporary || !ifNotExists,
"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.",
ctx)
(visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null)
}
/**
* Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan.
*
* TODO add bucketing and partitioning.
*/
override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) {
val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
if (external) {
logWarning("EXTERNAL option is not supported.")
}
val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)
val provider = ctx.tableProvider.qualifiedName.getText
if (ctx.query != null) {
// Get the backing query.
val query = plan(ctx.query)
// Determine the storage mode.
val mode = if (ifNotExists) {
SaveMode.Ignore
} else if (temp) {
SaveMode.Overwrite
} else {
SaveMode.ErrorIfExists
}
CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query)
} else {
val struct = Option(ctx.colTypeList).map(createStructType)
CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false)
}
}
/**
* Convert a table property list into a key-value map.
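 *
 * For example (an illustrative sketch; the surrounding OPTIONS syntax is assumed from the new
 * grammar), a clause such as OPTIONS (path "/some/path", mode "append") would be converted
 * into Map("path" -> "/some/path", "mode" -> "append").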
*/
override def visitTablePropertyList(
ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
ctx.tableProperty.asScala.map { property =>
      // A key can either be a String or a collection of dot-separated elements. We need to treat
// these differently.
val key = if (property.key.STRING != null) {
string(property.key.STRING)
} else {
property.key.getText
}
val value = Option(property.value).map(string).orNull
key -> value
}.toMap
}
}
......@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.parser.CatalystQl
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
......@@ -1172,8 +1172,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl())
Column(parser.parseExpression(expr))
Column(SparkSqlParser.parseExpression(expr))
}
//////////////////////////////////////////////////////////////////////////////////////////////
......
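With expr now delegating to the new parser, typical usage is unchanged; a hedged example (the DataFrame `df` and its column names are hypothetical):

  import org.apache.spark.sql.functions._
  val projected = df.select(expr("a + b").as("a_plus_b"))  // parsed via SparkSqlParser.parseExpression
  val filtered = df.where(expr("a > 10"))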
......@@ -81,7 +81,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
*/
lazy val sqlParser: ParserInterface = new SparkQl(conf)
lazy val sqlParser: ParserInterface = SparkSqlParser
/**
* Planner that converts optimized logical plans to physical plans.
......
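For reference, a hedged sketch of how the new parser is reached once SessionState exposes it (the ParserInterface method names are taken from calls made elsewhere in this commit):

  import org.apache.spark.sql.execution.SparkSqlParser
  val plan = SparkSqlParser.parsePlan("select a, b from t where a > 1")
  val expression = SparkSqlParser.parseExpression("a + 1")
  val ident = SparkSqlParser.parseTableIdentifier("db.t")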
......@@ -329,8 +329,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("full outer join") {
upperCaseData.where('N <= 4).registerTempTable("left")
upperCaseData.where('N >= 3).registerTempTable("right")
upperCaseData.where('N <= 4).registerTempTable("`left`")
upperCaseData.where('N >= 3).registerTempTable("`right`")
val left = UnresolvedRelation(TableIdentifier("left"), None)
val right = UnresolvedRelation(TableIdentifier("right"), None)
......
......@@ -1656,7 +1656,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e2 = intercept[AnalysisException] {
sql("select interval 23 nanosecond")
}
assert(e2.message.contains("cannot recognize input near"))
assert(e2.message.contains("No interval can be constructed"))
}
test("SPARK-8945: add and subtract expressions for interval type") {
......