未验证 提交 92577b65 编写于 作者: J jods 提交者: GitHub

Failing test case for #3945 (#3946)

* Failing test case for #3945

* asking right questions

* - remove unnecessary SELECT * wrapper

- fix DB2 CTE location

* remove recursive flag for CTEs without dependencies

* hack the hack

* fix recursive cte detection

* monkeypatching at its best

* fix test baselines generation

---------
Co-authored-by: NMaceWindu <MaceWindu@users.noreply.github.com>
上级 d05a6455
......@@ -15,6 +15,8 @@ namespace LinqToDB.DataProvider.DB2
abstract partial class DB2SqlBuilderBase : BasicSqlBuilder<DB2Options>
{
public override bool CteFirst => false;
protected DB2SqlBuilderBase(IDataProvider? provider, MappingSchema mappingSchema, DataOptions dataOptions, ISqlOptimizer sqlOptimizer, SqlProviderFlags sqlProviderFlags)
: base(provider, mappingSchema, dataOptions, sqlOptimizer, sqlProviderFlags)
{
......
......@@ -18,6 +18,8 @@ namespace LinqToDB.DataProvider.Firebird
public partial class FirebirdSqlBuilder : BasicSqlBuilder<FirebirdOptions>
{
public override bool CteFirst => false;
public FirebirdSqlBuilder(IDataProvider provider, MappingSchema mappingSchema, DataOptions dataOptions, ISqlOptimizer sqlOptimizer, SqlProviderFlags sqlProviderFlags)
: base(provider, mappingSchema, dataOptions, sqlOptimizer, sqlProviderFlags)
{
......@@ -339,18 +341,6 @@ protected override void BuildDeleteQuery(SqlDeleteStatement deleteStatement)
}
}
protected override void BuildInsertQuery(SqlStatement statement, SqlInsertClause insertClause, bool addAlias)
{
if (statement is SqlStatementWithQueryBase withQuery && withQuery.With?.Clauses.Count > 0)
{
BuildInsertQuery2(statement, insertClause, addAlias);
}
else
{
base.BuildInsertQuery(statement, insertClause, addAlias);
}
}
protected override void BuildCreateTableCommand(SqlTable table)
{
string command;
......
......@@ -12,6 +12,8 @@ namespace LinqToDB.DataProvider.Oracle
abstract partial class OracleSqlBuilderBase : BasicSqlBuilder<OracleOptions>
{
public override bool CteFirst => false;
protected OracleSqlBuilderBase(IDataProvider? provider, MappingSchema mappingSchema, DataOptions dataOptions, ISqlOptimizer sqlOptimizer, SqlProviderFlags sqlProviderFlags)
: base(provider, mappingSchema, dataOptions, sqlOptimizer, sqlProviderFlags)
{
......@@ -150,18 +152,6 @@ protected override void BuildDeleteQuery(SqlDeleteStatement deleteStatement)
}
}
protected override void BuildInsertQuery(SqlStatement statement, SqlInsertClause insertClause, bool addAlias)
{
if (statement is SqlStatementWithQueryBase withQuery && withQuery.With?.Clauses.Count > 0)
{
BuildInsertQuery2(statement, insertClause, addAlias);
}
else
{
base.BuildInsertQuery(statement, insertClause, addAlias);
}
}
protected sealed override bool IsReserved(string word)
{
// TODO: now we use static 11g list
......
......@@ -65,6 +65,14 @@ protected BasicSqlBuilder(BasicSqlBuilder parentBuilder)
public virtual bool IsNestedJoinSupported => true;
public virtual bool IsNestedJoinParenthesisRequired => false;
/// <summary>
/// Identifies CTE clause location:
/// <list type="bullet">
/// <item><c>CteFirst = true</c> (default): WITH clause goes first in query</item>
/// <item><c>CteFirst = false</c>: WITH clause goes before SELECT</item>
/// </list>
/// </summary>
public virtual bool CteFirst => true;
/// <summary>
/// True if it is needed to wrap join condition with ()
......@@ -390,6 +398,12 @@ protected virtual void BuildCteBody(SelectQuery selectQuery)
protected virtual void BuildInsertQuery(SqlStatement statement, SqlInsertClause insertClause, bool addAlias)
{
if (!CteFirst && statement is SqlStatementWithQueryBase withQuery && withQuery.With?.Clauses.Count > 0)
{
BuildInsertQuery2(statement, insertClause, addAlias);
return;
}
BuildStep = Step.Tag; BuildTag(statement);
BuildStep = Step.WithClause; BuildWithClause(statement.GetWithClause());
BuildStep = Step.InsertClause; BuildInsertClause(statement, insertClause, addAlias);
......@@ -420,11 +434,6 @@ protected void BuildInsertQuery2(SqlStatement statement, SqlInsertClause insertC
BuildStep = Step.Tag; BuildTag(statement);
BuildStep = Step.InsertClause; BuildInsertClause(statement, insertClause, addAlias);
AppendIndent().AppendLine("SELECT * FROM");
AppendIndent().AppendLine(OpenParens);
++Indent;
BuildStep = Step.WithClause; BuildWithClause(statement.GetWithClause());
if (statement.QueryType == QueryType.Insert && statement.SelectQuery!.From.Tables.Count != 0)
......@@ -443,10 +452,6 @@ protected void BuildInsertQuery2(SqlStatement statement, SqlInsertClause insertC
BuildGetIdentity(insertClause);
else
BuildOutputSubclause(statement.GetOutputClause());
--Indent;
AppendIndent().AppendLine(")");
}
protected virtual void BuildMultiInsertQuery(SqlMultiInsertStatement statement)
......
......@@ -313,8 +313,7 @@ static void RegisterDependency(CteClause cteClause, Dictionary<CteClause, HashSe
}
});
// self-reference is allowed, so we do not need to add dependency
dependsOn.Remove(cteClause);
foundCte.Add(cteClause, dependsOn);
foreach (var clause in dependsOn)
......@@ -372,6 +371,17 @@ void FinalizeCte(SqlStatement statement)
if (!SqlProviderFlags.IsCommonTableExpressionsSupported)
throw new LinqToDBException("DataProvider do not supports Common Table Expressions.");
// basic detection of non-recursive CTEs
// for more complex cases we will need dependency cycles detection
foreach (var kvp in cteHolder.WriteableValue)
{
if (kvp.Value.Count == 0)
kvp.Key.IsRecursive = false;
// remove self-reference for topo-sort
kvp.Value.Remove(kvp.Key);
}
var ordered = TopoSorting.TopoSort(cteHolder.WriteableValue.Keys, cteHolder, static (cteHolder, i) => cteHolder.WriteableValue![i]).ToList();
Utils.MakeUniqueNames(ordered, null, static (n, a) => !ReservedWords.IsReserved(n), static c => c.Name, static (c, n, a) => c.Name = n,
......
......@@ -155,7 +155,7 @@ static QueryData GetQueryData(IQueryElement? root, SelectQuery selectQuery, Hash
{
var t = FindField(field, table);
if (t != null)
if (t != null && !FindField(q.Select.Columns, field))
{
var n = q.Select.Columns.Count;
var idx = q.Select.Add(field);
......@@ -172,18 +172,27 @@ static QueryData GetQueryData(IQueryElement? root, SelectQuery selectQuery, Hash
return null;
}
static bool FindField(List<SqlColumn> columns, SqlField field)
{
foreach (var column in columns)
if (column.Expression != field && column.Expression.Find(field, static (field, e) => field == e) != null)
return true;
return false;
}
static void ResolveFields(QueryData data)
{
if (data.Queries.Count == 0)
return;
var dic = new Dictionary<ISqlExpression,ISqlExpression>();
Dictionary<ISqlExpression,ISqlExpression>? dic = null;
foreach (var sqlExpression in data.Fields)
{
var field = (SqlField)sqlExpression;
if (dic.ContainsKey(field))
if (dic?.ContainsKey(field) == true)
continue;
var found = false;
......@@ -201,11 +210,11 @@ static void ResolveFields(QueryData data)
var expr = GetColumn(data, field);
if (expr != null)
dic.Add(field, expr);
(dic ??= new()).Add(field, expr);
}
}
if (dic.Count > 0)
if (dic != null)
data.Query.VisitParentFirst((dic, data), static (context, e) =>
{
ISqlExpression? ex;
......
......@@ -1384,5 +1384,22 @@ public void Issue2264([CteContextSource] string context)
query.ToArray();
}
[Test]
public void Issue3945([CteContextSource] string context)
{
using var db = GetDataContext(context);
using var tb = db.CreateLocalTable<TestFolder>();
var cte = db.GetCte<TestFolder>("CTE", cte => tb.Where(c => c.ParentId != null));
var join = from child in cte
join parent in tb on child.ParentId equals parent.Id
select new TestFolder
{
Id = TestData.Guid1,
Label = parent.Label + "/" + child.Label,
};
join.Insert(tb, x => x);
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册