diff --git a/net/ipv4/gre.c b/net/ipv4/gre.c index b2e805af9b87a03675d7bac1a7e210124757e566..1e294d510ac3cd06d6baf48608a31ff58555b4cd 100644 --- a/net/ipv4/gre.c +++ b/net/ipv4/gre.c @@ -26,46 +26,32 @@ static const struct gre_protocol __rcu *gre_proto[GREPROTO_MAX] __read_mostly; -static DEFINE_SPINLOCK(gre_proto_lock); int gre_add_protocol(const struct gre_protocol *proto, u8 version) { if (version >= GREPROTO_MAX) - goto err_out; - - spin_lock(&gre_proto_lock); - if (gre_proto[version]) - goto err_out_unlock; - - RCU_INIT_POINTER(gre_proto[version], proto); - spin_unlock(&gre_proto_lock); - return 0; + return -EINVAL; -err_out_unlock: - spin_unlock(&gre_proto_lock); -err_out: - return -1; + return (cmpxchg((const struct gre_protocol **)&gre_proto[version], NULL, proto) == NULL) ? + 0 : -EBUSY; } EXPORT_SYMBOL_GPL(gre_add_protocol); int gre_del_protocol(const struct gre_protocol *proto, u8 version) { + int ret; + if (version >= GREPROTO_MAX) - goto err_out; - - spin_lock(&gre_proto_lock); - if (rcu_dereference_protected(gre_proto[version], - lockdep_is_held(&gre_proto_lock)) != proto) - goto err_out_unlock; - RCU_INIT_POINTER(gre_proto[version], NULL); - spin_unlock(&gre_proto_lock); + return -EINVAL; + + ret = (cmpxchg((const struct gre_protocol **)&gre_proto[version], proto, NULL) == proto) ? + 0 : -EBUSY; + + if (ret) + return ret; + synchronize_rcu(); return 0; - -err_out_unlock: - spin_unlock(&gre_proto_lock); -err_out: - return -1; } EXPORT_SYMBOL_GPL(gre_del_protocol);