diff --git a/include/net/sctp/structs.h b/include/net/sctp/structs.h index 20e72129be1ce0063eeafcbaadcee1f37e0c614c..344da04e2e305c479ae04791b1ca33c23836ed6e 100644 --- a/include/net/sctp/structs.h +++ b/include/net/sctp/structs.h @@ -955,7 +955,7 @@ void sctp_transport_route(struct sctp_transport *, union sctp_addr *, void sctp_transport_pmtu(struct sctp_transport *, struct sock *sk); void sctp_transport_free(struct sctp_transport *); void sctp_transport_reset_timers(struct sctp_transport *); -void sctp_transport_hold(struct sctp_transport *); +int sctp_transport_hold(struct sctp_transport *); void sctp_transport_put(struct sctp_transport *); void sctp_transport_update_rto(struct sctp_transport *, __u32); void sctp_transport_raise_cwnd(struct sctp_transport *, __u32, __u32); diff --git a/net/sctp/input.c b/net/sctp/input.c index bf61dfb8e09eea1034c7b9ddd40583752033230c..49d2cc751386f3208c67eed17ec9c3f15b801f50 100644 --- a/net/sctp/input.c +++ b/net/sctp/input.c @@ -935,15 +935,22 @@ static struct sctp_association *__sctp_lookup_association( struct sctp_transport **pt) { struct sctp_transport *t; + struct sctp_association *asoc = NULL; + rcu_read_lock(); t = sctp_addrs_lookup_transport(net, local, peer); - if (!t || t->dead) - return NULL; + if (!t || !sctp_transport_hold(t)) + goto out; - sctp_association_hold(t->asoc); + asoc = t->asoc; + sctp_association_hold(asoc); *pt = t; - return t->asoc; + sctp_transport_put(t); + +out: + rcu_read_unlock(); + return asoc; } /* Look up an association. protected by RCU read lock */ @@ -955,9 +962,7 @@ struct sctp_association *sctp_lookup_association(struct net *net, { struct sctp_association *asoc; - rcu_read_lock(); asoc = __sctp_lookup_association(net, laddr, paddr, transportp); - rcu_read_unlock(); return asoc; } diff --git a/net/sctp/transport.c b/net/sctp/transport.c index aab9e3f29755a58eb7584aa763b4b1a46d2e6608..69f3799570a7e60a6a691d7f0443b664b4d54cc4 100644 --- a/net/sctp/transport.c +++ b/net/sctp/transport.c @@ -296,9 +296,9 @@ void sctp_transport_route(struct sctp_transport *transport, } /* Hold a reference to a transport. */ -void sctp_transport_hold(struct sctp_transport *transport) +int sctp_transport_hold(struct sctp_transport *transport) { - atomic_inc(&transport->refcnt); + return atomic_add_unless(&transport->refcnt, 1, 0); } /* Release a reference to a transport and clean up