diff --git a/include/net/mctp.h b/include/net/mctp.h index f2d98f6993c0e5426e744e95500af0a06bf2c069..0a460ba185b87b26aebca4f9f72729f6226c7708 100644 --- a/include/net/mctp.h +++ b/include/net/mctp.h @@ -84,9 +84,21 @@ struct mctp_sock { * updates to either list are performed under the netns_mctp->keys * lock. * - * - there is a single destruction path for a mctp_sk_key - through socket - * unhash (see mctp_sk_unhash). This performs the list removal under - * keys_lock. + * - a key may have a sk_buff attached as part of an in-progress message + * reassembly (->reasm_head). The reassembly context is protected by + * reasm_lock, which may be acquired with the keys lock (above) held, if + * necessary. Consequently, keys lock *cannot* be acquired with the + * reasm_lock held. + * + * - there are two destruction paths for a mctp_sk_key: + * + * - through socket unhash (see mctp_sk_unhash). This performs the list + * removal under keys_lock. + * + * - where a key is established to receive a reply message: after receiving + * the (complete) reply, or during reassembly errors. Here, we clean up + * the reassembly context (marking reasm_dead, to prevent another from + * starting), and remove the socket from the netns & socket lists. */ struct mctp_sk_key { mctp_eid_t peer_addr; @@ -102,6 +114,13 @@ struct mctp_sk_key { /* per-socket list */ struct hlist_node sklist; + /* incoming fragment reassembly context */ + spinlock_t reasm_lock; + struct sk_buff *reasm_head; + struct sk_buff **reasm_tailp; + bool reasm_dead; + u8 last_seq; + struct rcu_head rcu; }; diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index 52bd7f2b78db3a7ca2242bbd5c1321e0330da8a6..9ca836df19d03b1d94023a539357671282d882dd 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -263,6 +263,14 @@ static void mctp_sk_unhash(struct sock *sk) hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { hlist_del_rcu(&key->sklist); hlist_del_rcu(&key->hlist); + + spin_lock(&key->reasm_lock); + if (key->reasm_head) + kfree_skb(key->reasm_head); + key->reasm_head = NULL; + key->reasm_dead = true; + spin_unlock(&key->reasm_lock); + kfree_rcu(key, rcu); } spin_unlock_irqrestore(&net->mctp.keys_lock, flags); diff --git a/net/mctp/route.c b/net/mctp/route.c index cc9891672eaa58fbe36ce81fbc097862014d7af8..160220e6f241a3a6ff42c7bbe253397fef4df195 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -23,6 +23,8 @@ #include #include +static const unsigned int mctp_message_maxlen = 64 * 1024; + /* route output callbacks */ static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb) { @@ -105,14 +107,125 @@ static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb, return ret; } +static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk, + mctp_eid_t local, mctp_eid_t peer, + u8 tag, gfp_t gfp) +{ + struct mctp_sk_key *key; + + key = kzalloc(sizeof(*key), gfp); + if (!key) + return NULL; + + key->peer_addr = peer; + key->local_addr = local; + key->tag = tag; + key->sk = &msk->sk; + spin_lock_init(&key->reasm_lock); + + return key; +} + +static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) +{ + struct net *net = sock_net(&msk->sk); + struct mctp_sk_key *tmp; + unsigned long flags; + int rc = 0; + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + + hlist_for_each_entry(tmp, &net->mctp.keys, hlist) { + if (mctp_key_match(tmp, key->local_addr, key->peer_addr, + key->tag)) { + rc = -EEXIST; + break; + } + } + + if (!rc) { + hlist_add_head(&key->hlist, &net->mctp.keys); + hlist_add_head(&key->sklist, &msk->keys); + } + + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + return rc; +} + +/* Must be called with key->reasm_lock, which it will release. Will schedule + * the key for an RCU free. + */ +static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net, + unsigned long flags) + __releases(&key->reasm_lock) +{ + struct sk_buff *skb; + + skb = key->reasm_head; + key->reasm_head = NULL; + key->reasm_dead = true; + spin_unlock_irqrestore(&key->reasm_lock, flags); + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + hlist_del_rcu(&key->hlist); + hlist_del_rcu(&key->sklist); + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + kfree_rcu(key, rcu); + + if (skb) + kfree_skb(skb); +} + +static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb) +{ + struct mctp_hdr *hdr = mctp_hdr(skb); + u8 exp_seq, this_seq; + + this_seq = (hdr->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT) + & MCTP_HDR_SEQ_MASK; + + if (!key->reasm_head) { + key->reasm_head = skb; + key->reasm_tailp = &(skb_shinfo(skb)->frag_list); + key->last_seq = this_seq; + return 0; + } + + exp_seq = (key->last_seq + 1) & MCTP_HDR_SEQ_MASK; + + if (this_seq != exp_seq) + return -EINVAL; + + if (key->reasm_head->len + skb->len > mctp_message_maxlen) + return -EINVAL; + + skb->next = NULL; + skb->sk = NULL; + *key->reasm_tailp = skb; + key->reasm_tailp = &skb->next; + + key->last_seq = this_seq; + + key->reasm_head->data_len += skb->len; + key->reasm_head->len += skb->len; + key->reasm_head->truesize += skb->truesize; + + return 0; +} + static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) { struct net *net = dev_net(skb->dev); struct mctp_sk_key *key; struct mctp_sock *msk; struct mctp_hdr *mh; + unsigned long f; + u8 tag, flags; + int rc; msk = NULL; + rc = -EINVAL; /* we may be receiving a locally-routed packet; drop source sk * accounting @@ -121,50 +234,144 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) /* ensure we have enough data for a header and a type */ if (skb->len < sizeof(struct mctp_hdr) + 1) - goto drop; + goto out; /* grab header, advance data ptr */ mh = mctp_hdr(skb); skb_pull(skb, sizeof(struct mctp_hdr)); if (mh->ver != 1) - goto drop; + goto out; - /* TODO: reassembly */ - if ((mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM)) - != (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM)) - goto drop; + flags = mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM); + tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); rcu_read_lock(); - /* 1. lookup socket matching (src,dest,tag) */ + + /* lookup socket / reasm context, exactly matching (src,dest,tag) */ key = mctp_lookup_key(net, skb, mh->src); - /* 2. lookup socket macthing (BCAST,dest,tag) */ - if (!key) - key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY); + if (flags & MCTP_HDR_FLAG_SOM) { + if (key) { + msk = container_of(key->sk, struct mctp_sock, sk); + } else { + /* first response to a broadcast? do a more general + * key lookup to find the socket, but don't use this + * key for reassembly - we'll create a more specific + * one for future packets if required (ie, !EOM). + */ + key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY); + if (key) { + msk = container_of(key->sk, + struct mctp_sock, sk); + key = NULL; + } + } - /* 3. SOM? -> lookup bound socket, conditionally (!EOM) create - * mapping for future (1)/(2). - */ - if (key) - msk = container_of(key->sk, struct mctp_sock, sk); - else if (!msk && (mh->flags_seq_tag & MCTP_HDR_FLAG_SOM)) - msk = mctp_lookup_bind(net, skb); + if (!key && !msk && (tag & MCTP_HDR_FLAG_TO)) + msk = mctp_lookup_bind(net, skb); - if (!msk) - goto unlock_drop; + if (!msk) { + rc = -ENOENT; + goto out_unlock; + } - sock_queue_rcv_skb(&msk->sk, skb); + /* single-packet message? deliver to socket, clean up any + * pending key. + */ + if (flags & MCTP_HDR_FLAG_EOM) { + sock_queue_rcv_skb(&msk->sk, skb); + if (key) { + spin_lock_irqsave(&key->reasm_lock, f); + /* we've hit a pending reassembly; not much we + * can do but drop it + */ + __mctp_key_unlock_drop(key, net, f); + } + rc = 0; + goto out_unlock; + } - rcu_read_unlock(); + /* broadcast response or a bind() - create a key for further + * packets for this message + */ + if (!key) { + key = mctp_key_alloc(msk, mh->dest, mh->src, + tag, GFP_ATOMIC); + if (!key) { + rc = -ENOMEM; + goto out_unlock; + } - return 0; + /* we can queue without the reasm lock here, as the + * key isn't observable yet + */ + mctp_frag_queue(key, skb); + + /* if the key_add fails, we've raced with another + * SOM packet with the same src, dest and tag. There's + * no way to distinguish future packets, so all we + * can do is drop; we'll free the skb on exit from + * this function. + */ + rc = mctp_key_add(key, msk); + if (rc) + kfree(key); + + } else { + /* existing key: start reassembly */ + spin_lock_irqsave(&key->reasm_lock, f); + + if (key->reasm_head || key->reasm_dead) { + /* duplicate start? drop everything */ + __mctp_key_unlock_drop(key, net, f); + rc = -EEXIST; + } else { + rc = mctp_frag_queue(key, skb); + spin_unlock_irqrestore(&key->reasm_lock, f); + } + } + + } else if (key) { + /* this packet continues a previous message; reassemble + * using the message-specific key + */ + + spin_lock_irqsave(&key->reasm_lock, f); + + /* we need to be continuing an existing reassembly... */ + if (!key->reasm_head) + rc = -EINVAL; + else + rc = mctp_frag_queue(key, skb); + + /* end of message? deliver to socket, and we're done with + * the reassembly/response key + */ + if (!rc && flags & MCTP_HDR_FLAG_EOM) { + sock_queue_rcv_skb(key->sk, key->reasm_head); + key->reasm_head = NULL; + __mctp_key_unlock_drop(key, net, f); + } else { + spin_unlock_irqrestore(&key->reasm_lock, f); + } + + } else { + /* not a start, no matching key */ + rc = -ENOENT; + } -unlock_drop: +out_unlock: rcu_read_unlock(); -drop: - kfree_skb(skb); - return 0; +out: + if (rc) + kfree_skb(skb); + return rc; +} + +static unsigned int mctp_route_mtu(struct mctp_route *rt) +{ + return rt->mtu ?: READ_ONCE(rt->dev->dev->mtu); } static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb) @@ -228,8 +435,6 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key, lockdep_assert_held(&mns->keys_lock); - key->sk = &msk->sk; - /* we hold the net->key_lock here, allowing updates to both * then net and sk */ @@ -251,11 +456,9 @@ static int mctp_alloc_local_tag(struct mctp_sock *msk, u8 tagbits; /* be optimistic, alloc now */ - key = kzalloc(sizeof(*key), GFP_KERNEL); + key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL); if (!key) return -ENOMEM; - key->local_addr = saddr; - key->peer_addr = daddr; /* 8 possible tag values */ tagbits = 0xff; @@ -340,6 +543,86 @@ int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb) return rc; } +static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb, + unsigned int mtu, u8 tag) +{ + const unsigned int hlen = sizeof(struct mctp_hdr); + struct mctp_hdr *hdr, *hdr2; + unsigned int pos, size; + struct sk_buff *skb2; + int rc; + u8 seq; + + hdr = mctp_hdr(skb); + seq = 0; + rc = 0; + + if (mtu < hlen + 1) { + kfree_skb(skb); + return -EMSGSIZE; + } + + /* we've got the header */ + skb_pull(skb, hlen); + + for (pos = 0; pos < skb->len;) { + /* size of message payload */ + size = min(mtu - hlen, skb->len - pos); + + skb2 = alloc_skb(MCTP_HEADER_MAXLEN + hlen + size, GFP_KERNEL); + if (!skb2) { + rc = -ENOMEM; + break; + } + + /* generic skb copy */ + skb2->protocol = skb->protocol; + skb2->priority = skb->priority; + skb2->dev = skb->dev; + memcpy(skb2->cb, skb->cb, sizeof(skb2->cb)); + + if (skb->sk) + skb_set_owner_w(skb2, skb->sk); + + /* establish packet */ + skb_reserve(skb2, MCTP_HEADER_MAXLEN); + skb_reset_network_header(skb2); + skb_put(skb2, hlen + size); + skb2->transport_header = skb2->network_header + hlen; + + /* copy header fields, calculate SOM/EOM flags & seq */ + hdr2 = mctp_hdr(skb2); + hdr2->ver = hdr->ver; + hdr2->dest = hdr->dest; + hdr2->src = hdr->src; + hdr2->flags_seq_tag = tag & + (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + + if (pos == 0) + hdr2->flags_seq_tag |= MCTP_HDR_FLAG_SOM; + + if (pos + size == skb->len) + hdr2->flags_seq_tag |= MCTP_HDR_FLAG_EOM; + + hdr2->flags_seq_tag |= seq << MCTP_HDR_SEQ_SHIFT; + + /* copy message payload */ + skb_copy_bits(skb, pos, skb_transport_header(skb2), size); + + /* do route, but don't drop the rt reference */ + rc = rt->output(rt, skb2); + if (rc) + break; + + seq = (seq + 1) & MCTP_HDR_SEQ_MASK; + pos += size; + } + + mctp_route_release(rt); + consume_skb(skb); + return rc; +} + int mctp_local_output(struct sock *sk, struct mctp_route *rt, struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag) { @@ -347,6 +630,7 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt, struct mctp_skb_cb *cb = mctp_cb(skb); struct mctp_hdr *hdr; unsigned long flags; + unsigned int mtu; mctp_eid_t saddr; int rc; u8 tag; @@ -376,26 +660,32 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt, tag = req_tag; } - /* TODO: we have the route MTU here; packetise */ + skb->protocol = htons(ETH_P_MCTP); + skb->priority = 0; skb_reset_transport_header(skb); skb_push(skb, sizeof(struct mctp_hdr)); skb_reset_network_header(skb); + skb->dev = rt->dev->dev; + + /* cb->net will have been set on initial ingress */ + cb->src = saddr; + + /* set up common header fields */ hdr = mctp_hdr(skb); hdr->ver = 1; hdr->dest = daddr; hdr->src = saddr; - hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM | /* TODO */ - tag; - skb->dev = rt->dev->dev; - skb->protocol = htons(ETH_P_MCTP); - skb->priority = 0; + mtu = mctp_route_mtu(rt); - /* cb->net will have been set on initial ingress */ - cb->src = saddr; - - return mctp_do_route(rt, skb); + if (skb->len + sizeof(struct mctp_hdr) <= mtu) { + hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM | + tag; + return mctp_do_route(rt, skb); + } else { + return mctp_do_fragment_route(rt, skb, mtu, tag); + } } /* route management */