diff --git a/Dapper.Contrib.Tests/Tests.cs b/Dapper.Contrib.Tests/Tests.cs index 8d3108788077877f62dee3f70fe2af5309638ec2..687de7e5c036a11ab3a6e555cfb6d3f39f618a28 100644 --- a/Dapper.Contrib.Tests/Tests.cs +++ b/Dapper.Contrib.Tests/Tests.cs @@ -197,13 +197,37 @@ public void InsertGetUpdate() connection.Query("select * from Users").Count().IsEqualTo(0); connection.Update(notrackedUser).IsEqualTo(false); //returns false, user not found - - //insert with custom sqladapter - connection.Insert(new User { Name = "Adam", Age = 10 }, sqlAdapter: new SqlServerAdapter()).IsMoreThan(0); } } - + public void InsertWithCustomDbType() + { + SqlMapperExtensions.GetDatabaseType = (conn) => "SQLiteConnection"; + + bool sqliteCodeCalled = false; + using (var connection = GetOpenConnection()) + { + connection.DeleteAll(); + connection.Get(3).IsNull(); + try + { + var id = connection.Insert(new User { Name = "Adam", Age = 10 }); + } + catch (SqlCeException ex) + { + sqliteCodeCalled = ex.Message.IndexOf("last_insert_rowid", StringComparison.InvariantCultureIgnoreCase) >= 0; + } + catch (Exception) + { + } + } + SqlMapperExtensions.GetDatabaseType = null; + + if (!sqliteCodeCalled) + { + throw new Exception("Was expecting sqlite code to be called"); + } + } public void GetAll() { diff --git a/Dapper.Contrib/SqlMapperExtensions.cs b/Dapper.Contrib/SqlMapperExtensions.cs index 1887c078e94b3c3a60b5f5c9c84f0cd5d7fced9e..76a6768b7eac946f9f864d88fdf7dfa4060546ab 100644 --- a/Dapper.Contrib/SqlMapperExtensions.cs +++ b/Dapper.Contrib/SqlMapperExtensions.cs @@ -23,6 +23,7 @@ public interface IProxy //must be kept public { bool IsDirty { get; set; } } + public delegate string GetDatabaseTypeDelegate(IDbConnection connection); private static readonly ConcurrentDictionary> KeyProperties = new ConcurrentDictionary>(); private static readonly ConcurrentDictionary> TypeProperties = new ConcurrentDictionary>(); @@ -170,7 +171,7 @@ private static bool IsWriteable(PropertyInfo pi) public static IEnumerable GetAll(this IDbConnection connection, IDbTransaction transaction = null, int? commandTimeout = null) where T : class { var type = typeof(T); - var cacheType = typeof (List); + var cacheType = typeof(List); string sql; if (!GetQueries.TryGetValue(cacheType.TypeHandle, out sql)) @@ -184,7 +185,7 @@ private static bool IsWriteable(PropertyInfo pi) var name = GetTableName(type); // TODO: query information schema and only select fields that are both in information schema and underlying class / interface - sql = "select * from " + name ; + sql = "select * from " + name; GetQueries[cacheType.TypeHandle] = sql; } @@ -226,17 +227,17 @@ private static string GetTableName(Type type) return name; } + /// /// Inserts an entity into table "Ts" and returns identity id or number if inserted rows if inserting a list. /// /// Open SqlConnection /// Entity to insert, can be list of entities /// Identity of inserted entity, or number of inserted rows if inserting a list - public static long Insert(this IDbConnection connection, T entityToInsert, IDbTransaction transaction = null, - int? commandTimeout = null, ISqlAdapter sqlAdapter = null) where T : class + public static long Insert(this IDbConnection connection, T entityToInsert, IDbTransaction transaction = null, int? commandTimeout = null) where T : class { var isList = false; - + var type = typeof(T); if (type.IsArray || type.IsGenericType) @@ -271,12 +272,11 @@ private static string GetTableName(Type type) if (!isList) //single entity { - if(sqlAdapter == null) - sqlAdapter = GetFormatter(connection); - return sqlAdapter.Insert(connection, transaction, commandTimeout, name, sbColumnList.ToString(), + var adapter = GetFormatter(connection); + return adapter.Insert(connection, transaction, commandTimeout, name, sbColumnList.ToString(), sbParameterList.ToString(), keyProperties, entityToInsert); } - + //insert list of entities var cmd = String.Format("insert into {0} ({1}) values ({2})", name, sbColumnList, sbParameterList); return connection.Execute(cmd, entityToInsert, transaction, commandTimeout); @@ -386,11 +386,29 @@ private static string GetTableName(Type type) return deleted > 0; } + /// + /// Specifies a custom callback that detects the database type instead of relying on the default strategy (the name of the connection type object). + /// Please note that this callback is global and will be used by all the calls that require a database specific adapter. + /// + public static GetDatabaseTypeDelegate GetDatabaseType; + private static ISqlAdapter GetFormatter(IDbConnection connection) { - var name = connection.GetType().Name.ToLower(); - return !AdapterDictionary.ContainsKey(name) ? - new SqlServerAdapter() : + string name; + var getDatabaseType = GetDatabaseType; + if (getDatabaseType != null) + { + name = getDatabaseType(connection); + if (name != null) + name = name.ToLower(); + } + else + { + name = connection.GetType().Name.ToLower(); + } + + return !AdapterDictionary.ContainsKey(name) ? + new SqlServerAdapter() : AdapterDictionary[name]; } @@ -552,8 +570,8 @@ public TableAttribute(string tableName) Name = tableName; } -// ReSharper disable once MemberCanBePrivate.Global -// ReSharper disable once UnusedAutoPropertyAccessor.Global + // ReSharper disable once MemberCanBePrivate.Global + // ReSharper disable once UnusedAutoPropertyAccessor.Global public string Name { get; set; } } @@ -590,7 +608,7 @@ public int Insert(IDbConnection connection, IDbTransaction transaction, int? com { var cmd = String.Format("insert into {0} ({1}) values ({2})", tableName, columnList, parameterList); - connection.Execute(cmd, entityToInsert, transaction, commandTimeout); + connection.Execute(cmd, entityToInsert, transaction, commandTimeout); //NOTE: would prefer to use IDENT_CURRENT('tablename') or IDENT_SCOPE but these are not available on SQLCE var r = connection.Query("select @@IDENTITY id", transaction: transaction, commandTimeout: commandTimeout);