EF 7/Core 中的 AddOrUpdate 发生了什么?

What happened to AddOrUpdate in EF 7 / Core?

我正在使用 EntityFramework.Core 7.0.0-rc1-final.

编写种子方法

DbSet的AddOrUpdate方法怎么了?

该功能尚未实现,仍在等待实施。请参阅 GitHub 问题 #629 和 #4526。

更新:根据下面的评论(未验证)- 此功能最终定于在 .NET Core 2.1 中发布!

您可以使用下面这个扩展方法,这是我在把我们的代码库迁移到 EF Core 时为打补丁而编写的:

   /// <summary>
   /// Inserts <paramref name="data"/>, or updates the stored row that has the same key.
   /// The key is the first public property of <typeparamref name="T"/> marked with [Key];
   /// composite keys are not supported.
   /// </summary>
   /// <remarks>
   /// DbSet.Find starts tracking the row it returns. Calling Update(data) with a second
   /// instance that has the same key would make EF Core throw because another instance with
   /// that key is already tracked. Instead, the scalar values of <paramref name="data"/> are
   /// copied onto the tracked row, and the change tracker picks them up on SaveChanges.
   /// In that case <paramref name="data"/> itself stays untracked.
   /// </remarks>
   /// <exception cref="InvalidOperationException"><typeparamref name="T"/> has no [Key] property.</exception>
   public static void AddOrUpdate<T>(this DbSet<T> dbSet, T data) where T : class
   {
       var t = typeof(T);
       PropertyInfo keyField = null;
       foreach (var propt in t.GetProperties())
       {
           if (propt.GetCustomAttribute<KeyAttribute>() != null)
           {
               keyField = propt;
               break; // assume no composite keys
           }
       }
       if (keyField == null)
       {
           throw new InvalidOperationException($"{t.FullName} does not have a KeyAttribute field. Unable to exec AddOrUpdate call.");
       }

       var keyVal = keyField.GetValue(data);
       var dbVal = dbSet.Find(keyVal);
       if (dbVal == null)
       {
           dbSet.Add(data);
           return;
       }

       if (object.ReferenceEquals(dbVal, data))
       {
           // data is itself the tracked instance, so marking it modified cannot conflict.
           dbSet.Update(data);
           return;
       }

       // Copy scalar values (value types, strings, byte arrays) onto the tracked row.
       // Navigation properties are skipped so unrelated graphs are not attached.
       foreach (var prop in t.GetProperties())
       {
           if (prop.Name == keyField.Name || !prop.CanRead || !prop.CanWrite || prop.GetIndexParameters().Length != 0)
           {
               continue;
           }

           var type = prop.PropertyType;
           var isScalar = type.IsValueType || type == typeof(string) || type == typeof(byte[]);
           if (isScalar)
           {
               prop.SetValue(dbVal, prop.GetValue(data));
           }
       }
   }

我想这就是你想要的。

public static class DbSetExtension
{
    /// <summary>
    /// Adds <paramref name="data"/> to the set, or, when a stored row has the same primary key,
    /// attaches that row as <see cref="EntityState.Modified"/> with the values of <paramref name="data"/>.
    /// </summary>
    /// <remarks>
    /// The primary key is read from the EF Core model, so fluent-API and composite keys work, but
    /// every key property must be a public CLR property. The lookup streams the whole table with
    /// AsNoTracking() and compares keys in memory, so this suits small tables such as seed data.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// <typeparamref name="T"/> is not a keyed entity type of the context, or part of its key is not a CLR property.
    /// </exception>
    public static void AddOrUpdate<T>(this DbSet<T> dbSet, T data) where T : class
    {
        var context = dbSet.GetContext();
        var t = typeof(T);
        var ids = GetPrimaryKeyNames(context, t);

        var keyFields = t.GetProperties().Where(p => ids.Contains(p.Name)).ToArray();
        if (keyFields.Length == 0)
        {
            throw new InvalidOperationException($"{t.FullName} does not have a KeyAttribute field. Unable to exec AddOrUpdate call.");
        }
        if (ids.Any(id => keyFields.All(p => p.Name != id)))
        {
            // Matching on only part of a composite key could select the wrong row.
            throw new InvalidOperationException($"{t.FullName} has primary key properties that are not CLR properties. Unable to exec AddOrUpdate call.");
        }

        var keyVals = keyFields.Select(p => p.GetValue(data)).ToArray();
        var dbVal = FindByKeyValues(dbSet, keyFields, keyVals);
        if (dbVal != null)
        {
            var entry = context.Entry(dbVal);
            entry.CurrentValues.SetValues(data);
            entry.State = EntityState.Modified;
            return;
        }
        dbSet.Add(data);
    }

    /// <summary>
    /// Adds <paramref name="data"/>, or updates the stored row whose identifier properties
    /// (selected by <paramref name="key"/>) equal those of <paramref name="data"/>.
    /// </summary>
    /// <param name="dbSet">The set to add to or update.</param>
    /// <param name="key">
    /// A property of <typeparamref name="T"/> (<c>x =&gt; x.Code</c>) or an anonymous object of its
    /// properties (<c>x =&gt; new { x.Name, x.Group }</c>).
    /// </param>
    /// <param name="data">The incoming values.</param>
    /// <remarks>
    /// On update, the stored primary key values are copied onto <paramref name="data"/> so SetValues
    /// does not try to change the key. The lookup streams the whole table and compares in memory.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// The key selects no property, names a property <typeparamref name="T"/> does not have, or
    /// <typeparamref name="T"/> is not a keyed entity type of the context.
    /// </exception>
    public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> key, T data) where T : class
    {
        var context = dbSet.GetContext();
        var t = typeof(T);
        var ids = GetPrimaryKeyNames(context, t);

        var keyFields = ResolveIdentifierProperties(key, data);
        if (keyFields.Length == 0)
        {
            // An empty identifier would otherwise match every row and overwrite the first one.
            throw new InvalidOperationException($"{t.FullName} does not have a KeyAttribute field. Unable to exec AddOrUpdate call.");
        }

        var keyVals = keyFields.Select(p => p.GetValue(data)).ToArray();
        var dbVal = FindByKeyValues(dbSet, keyFields, keyVals);
        if (dbVal != null)
        {
            foreach (var keyAttr in t.GetProperties().Where(p => p.CanWrite && ids.Contains(p.Name)))
            {
                keyAttr.SetValue(data, keyAttr.GetValue(dbVal));
            }

            var entry = context.Entry(dbVal);
            entry.CurrentValues.SetValues(data);
            entry.State = EntityState.Modified;
            return;
        }
        dbSet.Add(data);
    }

    /// <summary>Returns the names of the primary key properties of <paramref name="t"/> from the EF Core model.</summary>
    private static List<string> GetPrimaryKeyNames(DbContext context, Type t)
    {
        var entityType = context.Model.FindEntityType(t);
        if (entityType == null)
        {
            throw new InvalidOperationException($"{t.FullName} is not an entity type of {context.GetType().Name}. Unable to exec AddOrUpdate call.");
        }
        var primaryKey = entityType.FindPrimaryKey();
        if (primaryKey == null)
        {
            throw new InvalidOperationException($"{t.FullName} does not have a primary key. Unable to exec AddOrUpdate call.");
        }
        return primaryKey.Properties.Select(p => p.Name).ToList();
    }

    /// <summary>Maps an identifier expression to properties of <typeparamref name="T"/>.</summary>
    private static PropertyInfo[] ResolveIdentifierProperties<T>(Expression<Func<T, object>> key, T data)
    {
        if (key == null)
        {
            throw new ArgumentNullException(nameof(key));
        }

        var t = typeof(T);
        // Value-type members are boxed to object, which wraps them in a Convert node.
        var body = key.Body.NodeType == ExpressionType.Convert ? ((UnaryExpression)key.Body).Operand : key.Body;

        IEnumerable<string> names;
        var member = body as MemberExpression;
        if (member != null && member.Expression is ParameterExpression && member.Member is PropertyInfo)
        {
            // x => x.Id or x => x.Name: a single property of the entity.
            names = new[] { member.Member.Name };
        }
        else
        {
            // x => new { x.Name, x.Group }: use the property names of the key object.
            var keyObject = key.Compile()(data);
            if (keyObject == null)
            {
                throw new ArgumentException("The key expression returned null.", nameof(key));
            }
            names = keyObject.GetType().GetProperties().Select(p => p.Name);
        }

        var resolved = new List<PropertyInfo>();
        foreach (var name in names)
        {
            var property = t.GetProperty(name);
            if (property == null)
            {
                throw new InvalidOperationException($"{t.FullName} has no property named '{name}'. Unable to exec AddOrUpdate call.");
            }
            resolved.Add(property);
        }
        return resolved.ToArray();
    }

    /// <summary>Returns the first untracked row whose key properties equal <paramref name="keyValues"/>, or null.</summary>
    private static T FindByKeyValues<T>(DbSet<T> dbSet, PropertyInfo[] keyFields, object[] keyValues) where T : class
    {
        // object.Equals(a, b) is null-safe, unlike a.Equals(b) on a nullable column value.
        return dbSet.AsNoTracking().AsEnumerable().FirstOrDefault(entity =>
        {
            for (var i = 0; i < keyFields.Length; i++)
            {
                if (!Equals(keyFields[i].GetValue(entity), keyValues[i]))
                {
                    return false;
                }
            }
            return true;
        });
    }
}

public static class HackyDbSetGetContextTrick
{
    /// <summary>
    /// Returns the <see cref="DbContext"/> that owns <paramref name="dbSet"/>.
    /// </summary>
    /// <remarks>
    /// Resolves EF Core's ICurrentDbContext service through the DbSet's public infrastructure
    /// accessor. The previous approach read the private "_context" field of the internal DbSet
    /// implementation by reflection. That field is an implementation detail, and
    /// Type.GetField(NonPublic | Instance) does not return private fields declared on base
    /// classes, so the reflection approach yields null for derived DbSet types.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="dbSet"/> is null.</exception>
    public static DbContext GetContext<TEntity>(this DbSet<TEntity> dbSet)
        where TEntity : class
    {
        if (dbSet == null)
        {
            throw new ArgumentNullException(nameof(dbSet));
        }

        return Microsoft.EntityFrameworkCore.Infrastructure.AccessorExtensions
            .GetService<Microsoft.EntityFrameworkCore.Infrastructure.ICurrentDbContext>(dbSet)
            .Context;
    }
}

我认为,如果可以接受让实体继承一个基础实体类,那么这个方案是解决该问题的更简单方法。它之所以简单,是因为您的领域实体都继承自 DomainEntityBase,从而省去了其他方案中的许多复杂性。

public static class DbContextExtensions
{
    /// <summary>
    /// Marks each record for update when a row with its <see cref="DomainEntityBase.Id"/> is already
    /// stored, and adds it otherwise.
    /// </summary>
    /// <remarks>
    /// The stored ids are fetched with a single no-tracking query instead of one query per record.
    /// Records added earlier in the same unit of work but not yet saved are not treated as existing.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="records"/> is null.</exception>
    public static void AddOrUpdate<T>(this DbSet<T> dbSet, IEnumerable<T> records)
        where T : DomainEntityBase
    {
        if (records == null)
        {
            throw new ArgumentNullException(nameof(records));
        }

        // Enumerate the input once; it may be a lazy or one-shot sequence.
        var items = records.ToList();
        if (items.Count == 0)
        {
            return;
        }

        var ids = items.Select(r => r.Id).Distinct().ToList();
        var existingIds = new HashSet<Guid>(
            dbSet.AsNoTracking().Where(x => ids.Contains(x.Id)).Select(x => x.Id));

        foreach (var data in items)
        {
            if (existingIds.Contains(data.Id))
            {
                dbSet.Update(data);
            }
            else
            {
                dbSet.Add(data);
            }
        }
    }
}

/// <summary>
/// Base class for domain entities whose primary key is a <see cref="Guid"/> named Id.
/// </summary>
public class DomainEntityBase
{
    /// <summary>Primary key of the entity, marked with [Key].</summary>
    [Key] public Guid Id { get; set; }
}

我找到了一个不错的解决方案,它允许您指定用于匹配的属性。不过,它每次调用接收的是一个实体列表,而不是单个实体。它或许能为您提供一些思路,帮助您实现一个与旧版本行为一致的更好版本。

https://github.com/aspnet/MusicStore/blob/7787e963dd0b7293ff95b28dcae92407231e0300/samples/MusicStore/Models/SampleData.cs#L48

(代码不是我的)

有一个扩展方法Upsert

context.Upsert(new Role { Name = "Employee", NormalizedName = "employee" })
       .On(r => new { r.Name })
       .Run();

On Github

在 Entity Framework Core (2.0) 中,以上答案都不适用于我,所以这里给出对我有效的解决方案:

public static class DbSetExtensions
{
    /// <summary>
    /// Runs the single-entity AddOrUpdate for every entity in <paramref name="entities"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
    public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> identifierExpression, params T[] entities) where T : class
    {
        if (entities == null)
            throw new ArgumentNullException(nameof(entities));

        foreach (var entity in entities)
            AddOrUpdate(dbSet, identifierExpression, entity);
    }

    /// <summary>
    /// Adds <paramref name="entity"/> unless a stored row matches it on the properties selected by
    /// <paramref name="identifierExpression"/>. In that case the (tracked) row's non-key,
    /// non-collection properties are overwritten with the entity's values, and the row's key
    /// values are copied back onto <paramref name="entity"/>.
    /// </summary>
    /// <param name="dbSet">The set to add to or update.</param>
    /// <param name="identifierExpression">
    /// A property (<c>x =&gt; x.Name</c>) or an anonymous object of properties
    /// (<c>x =&gt; new { x.Name, x.Description }</c>) of <typeparamref name="T"/>.
    /// </param>
    /// <param name="entity">The incoming values.</param>
    /// <remarks>
    /// The match is executed by the database (one equality per selected property, combined with AND),
    /// so the database's comparison rules, such as string collation, apply.
    /// </remarks>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="NotSupportedException">The expression selects something other than properties of the entity.</exception>
    public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> identifierExpression, T entity) where T : class
    {
        if (identifierExpression == null)
            throw new ArgumentNullException(nameof(identifierExpression));
        if (entity == null)
            throw new ArgumentNullException(nameof(entity));

        // Pass the expression (not a compiled delegate) so the lookup runs as SQL
        // instead of loading the whole table and filtering in memory.
        var item = dbSet.FirstOrDefault(BuildMatchPredicate(identifierExpression, entity));
        if (item == null)
        {
            // easy case
            dbSet.Add(entity);
            return;
        }

        // get Key fields, using KeyAttribute if possible otherwise convention
        var dataType = typeof(T);
        var keyFields = dataType.GetProperties().Where(p => p.GetCustomAttribute<KeyAttribute>() != null).ToList();
        if (!keyFields.Any())
        {
            string idName = dataType.Name + "Id";
            keyFields = dataType.GetProperties().Where(p =>
                string.Equals(p.Name, "Id", StringComparison.OrdinalIgnoreCase) ||
                string.Equals(p.Name, idName, StringComparison.OrdinalIgnoreCase)).ToList();
        }

        // update all non key and non collection properties
        foreach (var property in dataType.GetProperties().Where(p => p.GetSetMethod() != null && p.GetGetMethod() != null))
        {
            // ignore collections (string implements IEnumerable but is a scalar)
            if (property.PropertyType != typeof(string) && property.PropertyType.GetInterface(nameof(System.Collections.IEnumerable)) != null)
                continue;

            // ignore ID fields
            if (keyFields.Any(k => k.Name == property.Name))
                continue;

            var incomingValue = property.GetValue(entity);
            if (!Equals(property.GetValue(item), incomingValue))
                property.SetValue(item, incomingValue);
        }

        // also update key values on incoming data item where appropriate
        foreach (var idField in keyFields.Where(p => p.GetSetMethod() != null && p.GetGetMethod() != null))
        {
            var storedValue = idField.GetValue(item);
            if (!Equals(idField.GetValue(entity), storedValue))
                idField.SetValue(entity, storedValue);
        }
    }

    /// <summary>Builds <c>p =&gt; p.A == a &amp;&amp; p.B == b ...</c> from the properties the expression selects.</summary>
    private static Expression<Func<T, bool>> BuildMatchPredicate<T>(Expression<Func<T, object>> identifierExpression, T entity)
    {
        var parameter = Expression.Parameter(typeof(T), "p");
        Expression body = null;
        foreach (var property in GetIdentifierProperties(identifierExpression))
        {
            // Compare per property with a constant of the property's own type. Comparing whole
            // anonymous objects would use reference equality and never match; an untyped constant
            // breaks nullable members and null values.
            var equality = Expression.Equal(
                Expression.Property(parameter, property),
                Expression.Constant(property.GetValue(entity), property.PropertyType));
            body = body == null ? equality : Expression.AndAlso(body, equality);
        }
        return Expression.Lambda<Func<T, bool>>(body, parameter);
    }

    /// <summary>Returns the properties of T selected by <c>x =&gt; x.Prop</c> or <c>x =&gt; new { x.A, x.B }</c>.</summary>
    private static List<PropertyInfo> GetIdentifierProperties<T>(Expression<Func<T, object>> identifierExpression)
    {
        var body = StripConvert(identifierExpression.Body);
        IEnumerable<Expression> selected = body.NodeType == ExpressionType.New
            ? ((NewExpression)body).Arguments
            : (IEnumerable<Expression>)new[] { body };

        var properties = new List<PropertyInfo>();
        foreach (var expression in selected.Select(e => StripConvert(e)))
        {
            var member = expression as MemberExpression;
            var property = member?.Member as PropertyInfo;
            if (property == null || !(member.Expression is ParameterExpression))
                throw new NotSupportedException("Unknown expression type for AddOrUpdate: " + expression.NodeType);
            properties.Add(property);
        }

        if (properties.Count == 0)
            throw new NotSupportedException("Unknown expression type for AddOrUpdate: " + body.NodeType);
        return properties;
    }

    /// <summary>Removes the boxing Convert node the compiler adds when a value-type member is returned as object.</summary>
    private static Expression StripConvert(Expression expression)
    {
        return expression.NodeType == ExpressionType.Convert
            ? ((UnaryExpression)expression).Operand
            : expression;
    }
}

如果您的 identifierExpression 更复杂,可能需要更新 ReplaceParameter() 方法。简单的属性访问器可以直接用于此实现。例如:

context.Projects.AddOrUpdate(x => x.Name, new Project { ... })
context.Projects.AddOrUpdate(x => new { x.Name, x.Description }, new Project { ... })

然后调用 context.SaveChanges() 将数据提交到数据库。

我以上面的一个回答为基础,做了两处修改:

  1. 我使用 Fluent API 指定键,因此查找的是实体模型中配置的主键,而不是实体上的特性(attribute)。
  2. 我启用了更改跟踪,结果遇到了其他人提到的"EF 已在跟踪该实体"的错误。此版本会查找已被跟踪的实体,将传入实体的值复制到该实体上,然后更新这个原始实体。

    /// <summary>
    /// Adds <paramref name="entity"/>, or copies its values onto the row with the same primary key
    /// and marks that row for update.
    /// </summary>
    /// <param name="entity">The incoming, possibly untracked, entity.</param>
    /// <returns>
    /// The instance the context tracks afterwards: the existing row when one was found,
    /// otherwise <paramref name="entity"/>.
    /// </returns>
    /// <remarks>
    /// The key is read from the EF Core model, so keys configured with the fluent API work.
    /// Composite keys are passed to Find in the order the key declares them. Copying onto the row
    /// returned by Find avoids EF Core's "another instance with the same key is already being
    /// tracked" error.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// No primary key is configured, or a key property is not a public property of the entity's type.
    /// </exception>
    public TEntity AddOrUpdate(TEntity entity)
    {
        var entityEntry = Context.Entry(entity);
        var t = typeof(TEntity);

        var keyProperties = entityEntry.Context.Model.FindEntityType(t)?.FindPrimaryKey()?.Properties;
        if (keyProperties == null)
        {
            throw new InvalidOperationException($"{t.FullName} does not have a primary key specified. Unable to exec AddOrUpdate call.");
        }

        // Find expects one value per key property, in key order.
        var keyValues = new object[keyProperties.Count];
        for (var i = 0; i < keyValues.Length; i++)
        {
            var primaryKeyField = entity.GetType().GetProperty(keyProperties[i].Name);
            if (primaryKeyField == null)
            {
                throw new InvalidOperationException($"{t.FullName} does not have a primary key specified. Unable to exec AddOrUpdate call.");
            }
            keyValues[i] = primaryKeyField.GetValue(entity);
        }

        var dbVal = DbSet.Find(keyValues);

        if (dbVal != null)
        {
            Context.Entry(dbVal).CurrentValues.SetValues(entity);
            DbSet.Update(dbVal);

            entity = dbVal;
        }
        else
        {
            DbSet.Add(entity);
        }

        return entity;
    }
    

到目前为止,我已经能够从中获得不错的里程,没有任何问题。

我在 EFCore 2.1 上使用它

下面这篇 Microsoft 文档文章 Disconnected entities 指出:只要数据库中的主键列使用自动生成的值(例如标识列),就可以用 Update 方法同时处理插入和更新。

引用文章:

If it is known whether or not an insert or update is needed, then either Add or Update can be used appropriately.

However, if the entity uses auto-generated key values, then the Update method can be used for both cases.

The Update method normally marks the entity for update, not insert. However, if the entity has an auto-generated key, and no key value has been set, then the entity is instead automatically marked for insert.

This behavior was introduced in EF Core 2.0. For earlier releases it is always necessary to explicitly choose either Add or Update.

If the entity is not using auto-generated keys, then the application must decide whether the entity should be inserted or updated.

我已在测试项目中尝试过,可以确认在 EF Core 2.2 中,对于使用自动生成键的实体,Update 方法既能添加也能更新实体。

上面链接的 Disconnected entities 文章还提供了自制 InsertOrUpdate 方法的示例代码,适用于早期版本的 EF Core,或实体没有自动生成键的情况。该示例代码针对某个特定的实体类,需要修改才能通用化。

这是我基于此线程中其他解决方案的解决方案。

  • 支持复合键
  • 支持 shadow-property 键。
  • 保持在 EF Core 领域内并且不使用反射。
  • 将 _appDb 更改为您的上下文。

        /// <summary>
        /// Returns the primary key value(s) of <paramref name="entity"/> as tracked by <c>_appDb</c>.
        /// </summary>
        /// <returns>
        /// The bare value for a single-property key, an <c>object[]</c> (in key order) for a composite
        /// key, or <c>null</c> when the entity type is not in the model or has no primary key.
        /// </returns>
        public object PrimaryKeyValues<TEntity>(TEntity entity)
        {
            // Null-propagate through both lookups: a keyless or unmapped type previously threw
            // NullReferenceException instead of producing the null that callers check for.
            var properties = _appDb.Model.FindEntityType(typeof(TEntity))?.FindPrimaryKey()?.Properties;
            if (properties == null)
                return null;

            var entry = _appDb.Entry(entity);

            // Materialize once instead of re-enumerating the lazy query for Count/Single/ToArray.
            var values = properties.Select(p => entry.Property(p.Name).CurrentValue).ToArray();

            return values.Length == 1 ? values[0] : values;
        }


        /// <summary>
        /// Adds <paramref name="entity"/>, or copies its values onto the stored row with the same
        /// primary key and marks that row for update. Supports single and composite keys.
        /// </summary>
        /// <returns>The tracked instance: the existing row if found, otherwise <paramref name="entity"/>.</returns>
        /// <exception cref="InvalidOperationException">No primary key value could be determined for the entity.</exception>
        public async Task<TEntity> AddOrUpdateAsync<TEntity>(TEntity entity) where TEntity : class
        {
            var pkValue = PrimaryKeyValues(entity);

            if (pkValue == null)
            {
                throw new InvalidOperationException($"{typeof(TEntity).FullName} does not have a primary key specified. Unable to exec AddOrUpdateAsync call.");
            }

            // PrimaryKeyValues returns a bare value for a single key and an object[] for a composite
            // key. Passing the object[] as a plain `object` would wrap it in a one-element params
            // array, so hand the array to FindAsync directly.
            var keyValues = pkValue as object[] ?? new[] { pkValue };

            if ((await _appDb.FindAsync(typeof(TEntity), keyValues)) is TEntity dbEntry)
            {
                _appDb.Entry(dbEntry).CurrentValues.SetValues(entity);
                _appDb.Update(dbEntry);

                entity = dbEntry;
            }
            else
            {
                _appDb.Add(entity);
            }

            return entity;
        }

更新 - 添加或更新范围

完整的解决方案(不支持影子属性作为键)。

DbContextExtensions.cs

        // FIND ALL
        // ===============================================================
        /// <summary>
        /// Loads the stored entities whose primary keys match those of <paramref name="args"/>.
        /// Returns all, some or none of them, depending on which exist in the database.
        /// </summary>
        /// <typeparam name="TEntity">The entity type.</typeparam>
        /// <param name="dbContext">The context to query.</param>
        /// <param name="args">Entities whose primary key values are looked up.</param>
        /// <returns>The matching stored entities.</returns>
        /// <remarks>
        /// Runs a single query whose filter is
        /// <c>(key1 = v1 AND key2 = v2) OR (key1 = w1 AND key2 = w2) ...</c> on the server, instead of
        /// loading the whole table and filtering it in memory. Shadow-property keys are not supported.
        /// </remarks>
        /// <exception cref="ArgumentException"><typeparamref name="TEntity"/> has no primary key.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
        /// <exception cref="NotSupportedException">A key property is not a CLR property.</exception>
        public static async Task<TEntity[]> FindAllAsync<TEntity>(this DbContext dbContext, IEnumerable<TEntity> args) where TEntity : class
        {
            var properties = dbContext.Model.FindEntityType(typeof(TEntity))?.FindPrimaryKey()?.Properties;

            if (properties == null)
                throw new ArgumentException($"{typeof(TEntity).FullName} does not have a primary key specified.");

            if (args == null)
                throw new ArgumentNullException(nameof(args), "Entities to find argument cannot be null");

            // Enumerate the input once.
            var entities = args.ToList();
            if (entities.Count == 0)
                return Array.Empty<TEntity>();

            var dbParameter = Expression.Parameter(typeof(TEntity), typeof(TEntity).Name);

            // Resolve the member access for every key property once.
            var keyMembers = properties.Select(p =>
            {
                var clrProperty = typeof(TEntity).GetProperty(p.Name);
                if (clrProperty == null)
                    throw new NotSupportedException($"Key property '{p.Name}' of {typeof(TEntity).FullName} is not a CLR property.");
                return new { p.Name, Access = Expression.Property(dbParameter, clrProperty) };
            }).ToList();

            Expression anyMatch = null;
            foreach (var entity in entities)
            {
                var entry = dbContext.Entry(entity);

                Expression allKeysMatch = null;
                foreach (var key in keyMembers)
                {
                    // Give the constant the property's type so nullable keys compare against boxed values.
                    var equal = Expression.Equal(
                        key.Access,
                        Expression.Constant(entry.Property(key.Name).CurrentValue, key.Access.Type));
                    allKeysMatch = allKeysMatch == null ? equal : Expression.AndAlso(allKeysMatch, equal);
                }

                anyMatch = anyMatch == null ? allKeysMatch : Expression.OrElse(anyMatch, allKeysMatch);
            }

            var predicate = Expression.Lambda<Func<TEntity, bool>>(anyMatch, dbParameter);

            // Await EF Core's own async query rather than offloading synchronous work with Task.Run.
            return await dbContext.Set<TEntity>().Where(predicate).ToArrayAsync();
        }

        // ADD OR UPDATE - RANGE - ASYNC
        // ===============================================================
        /// <summary>
        /// For each entity in a range, adds it when no row with its primary key exists, otherwise
        /// marks it for update. Bases the decision on the primary key.
        /// </summary>
        /// <typeparam name="TEntity">The entity type.</typeparam>
        /// <param name="dbContext">The context to add to or update.</param>
        /// <param name="entities">The entities to add or update.</param>
        /// <returns>How many entities were marked as added and how many as updated.</returns>
        /// <remarks>
        /// The rows loaded to check existence are detached again so the caller's instances can be
        /// attached by UpdateRange. Nothing is saved; call SaveChanges afterwards.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        public static async Task<(int AddedCount, int UpdatedCount)> AddOrUpdateRangeAsync<TEntity>(this DbContext dbContext, IEnumerable<TEntity> entities) where TEntity : class
        {
            if (entities == null)
                throw new ArgumentNullException(nameof(entities));

            // Enumerate the input once; it may be a lazy or one-shot sequence.
            var candidates = entities.ToList();

            // Build the set eagerly so every found row is detached and hashed exactly once,
            // not once per candidate it is compared against.
            var existingHashes = new HashSet<string>();
            foreach (var found in await dbContext.FindAllAsync(candidates))
            {
                dbContext.Entry(found).State = EntityState.Detached;
                existingHashes.Add(dbContext.PrimaryKeyHash(found));
            }

            var toUpdate = new List<TEntity>();
            var toAdd = new List<TEntity>();
            foreach (var entity in candidates)
            {
                (existingHashes.Contains(dbContext.PrimaryKeyHash(entity)) ? toUpdate : toAdd).Add(entity);
            }

            dbContext.UpdateRange(toUpdate);
            dbContext.AddRange(toAdd);

            return (AddedCount: toAdd.Count, UpdatedCount: toUpdate.Count);
        }


        // ADD OR UPDATE - ASYNC
        // ===============================================================
        /// <summary>
        /// Single-entity convenience wrapper: adds <paramref name="entity"/> when no row with its
        /// primary key exists, otherwise marks it for update (delegates to AddOrUpdateRangeAsync).
        /// </summary>
        /// <typeparam name="TEntity">The entity type.</typeparam>
        /// <param name="dbContext">The context to add to or update.</param>
        /// <param name="entity">The entity to add or update.</param>
        /// <returns>A task that completes when the entity has been marked as added or updated.</returns>
        public static async Task AddOrUpdateAsync<TEntity>(this DbContext dbContext, TEntity entity) where TEntity : class
        {
            var batch = new[] { entity };
            await dbContext.AddOrUpdateRangeAsync(batch);
        }

        // PK HASH
        // ===============================================================
        /// <summary>
        /// Returns a string identifying the entity's primary key: the Crypto.HashGUID digest of each
        /// key value, concatenated in key order.
        /// </summary>
        /// <typeparam name="TTarget">The entity type.</typeparam>
        /// <param name="dbContext">The context whose model defines the primary key.</param>
        /// <param name="entity">The entity whose key values are hashed.</param>
        /// <returns>The concatenated per-key digests.</returns>
        /// <exception cref="InvalidOperationException"><typeparamref name="TTarget"/> is not a keyed entity type of the context.</exception>
        public static string PrimaryKeyHash<TTarget>(this DbContext dbContext, TTarget entity)
        {
            var properties = dbContext.Model.FindEntityType(typeof(TTarget))?.FindPrimaryKey()?.Properties;
            if (properties == null)
                throw new InvalidOperationException($"{typeof(TTarget).FullName} does not have a primary key specified.");

            var entry = dbContext.Entry(entity);

            // Each digest is a fixed-length hex string, so plain concatenation is unambiguous.
            return string.Concat(properties.Select(p => Crypto.HashGUID(entry.Property(p.Name).CurrentValue)));
        }

Crypto.cs

    /// <summary>
    /// Hashing helpers.
    /// </summary>
    public class Crypto
    {
        /// <summary>
        /// Returns the MD5 digest of the UTF-8 bytes of <paramref name="obj"/>.ToString() as a
        /// 32-character lowercase hex string (the length of a GUID without dashes).
        /// </summary>
        /// <remarks>
        /// Intended for identity comparison only; MD5 is not suitable for security purposes.
        /// ToString() is culture-sensitive for some types (e.g. DateTime, decimal), so such values
        /// can hash differently under different cultures.
        /// </remarks>
        /// <param name="obj">The value to hash.</param>
        /// <returns>The lowercase hex digest.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
        public static string HashGUID(object obj)
        {
            if (obj == null)
                throw new ArgumentNullException(nameof(obj));

            byte[] digest;
            // Dispose the hash algorithm; the previous MD5CryptoServiceProvider was never disposed.
            using (var md5 = MD5.Create())
            {
                digest = md5.ComputeHash(Encoding.UTF8.GetBytes(obj.ToString()));
            }

            var hex = new StringBuilder(digest.Length * 2);
            foreach (var b in digest)
            {
                hex.Append(b.ToString("x2"));
            }
            return hex.ToString();
        }
    }

IEnumerableExtensions.cs

        /// <summary>
        /// Splits <paramref name="source"/> into the items that satisfy <paramref name="predicate"/>
        /// and the items that do not, preserving order within each part.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        /// <param name="source">The items to split; enumerated exactly once.</param>
        /// <param name="predicate">The test; invoked exactly once per item.</param>
        /// <returns>The matching items (True) and the remaining items (False).</returns>
        /// <remarks>
        /// The split is performed eagerly. Two lazy Where queries would evaluate the predicate
        /// twice per item and again on every enumeration of the results.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
        public static (IEnumerable<T> True, IEnumerable<T> False) DivideOn<T>(this IEnumerable<T> source, Func<T, bool> predicate)
        {
            if (source == null)
                throw new ArgumentNullException(nameof(source));
            if (predicate == null)
                throw new ArgumentNullException(nameof(predicate));

            var matching = new List<T>();
            var rest = new List<T>();
            foreach (var item in source)
            {
                (predicate(item) ? matching : rest).Add(item);
            }
            return (matching, rest);
        }

如果用到请点赞(ノ◕ヮ◕)ノ✲゚。⋆

我不明白为什么其他答案都要设法去查找主键。只需像 EF 6 的 AddOrUpdate 方法那样,在调用方法时直接把它传进去即可。

/// <summary>
/// Adds <paramref name="entity"/>, or copies its values onto the stored row it identifies and marks
/// that row for update.
/// </summary>
/// <param name="dbSet">The set to add to or update.</param>
/// <param name="context">The context that owns <paramref name="dbSet"/>.</param>
/// <param name="identifier">
/// Returns the entity's primary key value, or an <c>object[]</c> of the key values in key order for a
/// composite key. DbSet.Find looks up by primary key only, so unlike EF 6's AddOrUpdate this cannot
/// match on arbitrary properties.
/// </param>
/// <param name="entity">The incoming values.</param>
/// <returns>The tracked instance: the existing row if found, otherwise <paramref name="entity"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="context"/>, <paramref name="identifier"/> or <paramref name="entity"/> is null.</exception>
public static TEntity AddOrUpdate<TEntity>(this DbSet<TEntity> dbSet, DbContext context, Func<TEntity, object> identifier, TEntity entity) where TEntity : class
{
    if (context == null) throw new ArgumentNullException(nameof(context));
    if (identifier == null) throw new ArgumentNullException(nameof(identifier));
    if (entity == null) throw new ArgumentNullException(nameof(entity));

    var key = identifier(entity);
    // Find takes params object[]: an object[] returned as `object` would otherwise be wrapped
    // in a one-element array, so pass composite key arrays through as-is.
    var keyValues = key as object[] ?? new[] { key };

    TEntity result = dbSet.Find(keyValues);
    if (result == null)
    {
        dbSet.Add(entity);
        return entity;
    }

    context.Entry(result).CurrentValues.SetValues(entity);
    dbSet.Update(result);
    return result;
}

以后像这样使用它:

dbContext.MyModels.AddOrUpdate(dbContext, model => model.Id, new MyModel() { Id = 3 });

干净且高效。