diff --git a/include/net/seg6_local.h b/include/net/seg6_local.h index 661fd5b4d3e0b9ea40079d161131d6471d7335aa..08359e2d8b357683bd8fd6a3ea82290b5c24a26b 100644 --- a/include/net/seg6_local.h +++ b/include/net/seg6_local.h @@ -21,10 +21,12 @@ extern int seg6_lookup_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr, u32 tbl_id); +extern bool seg6_bpf_has_valid_srh(struct sk_buff *skb); struct seg6_bpf_srh_state { - bool valid; + struct ipv6_sr_hdr *srh; u16 hdrlen; + bool valid; }; DECLARE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states); diff --git a/net/core/filter.c b/net/core/filter.c index 104d560946da8cbbc76f2d411e68842ea215a51a..7df1a0f1d1e16e663459a5929de4370519e8b606 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -4542,26 +4542,28 @@ BPF_CALL_4(bpf_lwt_seg6_store_bytes, struct sk_buff *, skb, u32, offset, { struct seg6_bpf_srh_state *srh_state = this_cpu_ptr(&seg6_bpf_srh_states); + struct ipv6_sr_hdr *srh = srh_state->srh; void *srh_tlvs, *srh_end, *ptr; - struct ipv6_sr_hdr *srh; int srhoff = 0; - if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) + if (srh == NULL) return -EINVAL; - srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); srh_tlvs = (void *)((char *)srh + ((srh->first_segment + 1) << 4)); srh_end = (void *)((char *)srh + sizeof(*srh) + srh_state->hdrlen); ptr = skb->data + offset; if (ptr >= srh_tlvs && ptr + len <= srh_end) - srh_state->valid = 0; + srh_state->valid = false; else if (ptr < (void *)&srh->flags || ptr + len > (void *)&srh->segments) return -EFAULT; if (unlikely(bpf_try_make_writable(skb, offset + len))) return -EFAULT; + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) + return -EINVAL; + srh_state->srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); memcpy(skb->data + offset, from, len); return 0; @@ -4577,52 +4579,78 @@ static const struct bpf_func_proto bpf_lwt_seg6_store_bytes_proto = { .arg4_type = ARG_CONST_SIZE }; -BPF_CALL_4(bpf_lwt_seg6_action, struct sk_buff *, skb, - u32, action, void *, param, u32, param_len) +static void bpf_update_srh_state(struct sk_buff *skb) { struct seg6_bpf_srh_state *srh_state = this_cpu_ptr(&seg6_bpf_srh_states); - struct ipv6_sr_hdr *srh; int srhoff = 0; - int err; - - if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) - return -EINVAL; - srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); - - if (!srh_state->valid) { - if (unlikely((srh_state->hdrlen & 7) != 0)) - return -EBADMSG; - - srh->hdrlen = (u8)(srh_state->hdrlen >> 3); - if (unlikely(!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3))) - return -EBADMSG; - srh_state->valid = 1; + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) { + srh_state->srh = NULL; + } else { + srh_state->srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); + srh_state->hdrlen = srh_state->srh->hdrlen << 3; + srh_state->valid = true; } +} + +BPF_CALL_4(bpf_lwt_seg6_action, struct sk_buff *, skb, + u32, action, void *, param, u32, param_len) +{ + struct seg6_bpf_srh_state *srh_state = + this_cpu_ptr(&seg6_bpf_srh_states); + int hdroff = 0; + int err; switch (action) { case SEG6_LOCAL_ACTION_END_X: + if (!seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; if (param_len != sizeof(struct in6_addr)) return -EINVAL; return seg6_lookup_nexthop(skb, (struct in6_addr *)param, 0); case SEG6_LOCAL_ACTION_END_T: + if (!seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; + if (param_len != sizeof(int)) + return -EINVAL; + return seg6_lookup_nexthop(skb, NULL, *(int *)param); + case SEG6_LOCAL_ACTION_END_DT6: + if (!seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; if (param_len != sizeof(int)) return -EINVAL; + + if (ipv6_find_hdr(skb, &hdroff, IPPROTO_IPV6, NULL, NULL) < 0) + return -EBADMSG; + if (!pskb_pull(skb, hdroff)) + return -EBADMSG; + + skb_postpull_rcsum(skb, skb_network_header(skb), hdroff); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb->encapsulation = 0; + + bpf_compute_data_pointers(skb); + bpf_update_srh_state(skb); return seg6_lookup_nexthop(skb, NULL, *(int *)param); case SEG6_LOCAL_ACTION_END_B6: + if (srh_state->srh && !seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; err = bpf_push_seg6_encap(skb, BPF_LWT_ENCAP_SEG6_INLINE, param, param_len); if (!err) - srh_state->hdrlen = - ((struct ipv6_sr_hdr *)param)->hdrlen << 3; + bpf_update_srh_state(skb); + return err; case SEG6_LOCAL_ACTION_END_B6_ENCAP: + if (srh_state->srh && !seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; err = bpf_push_seg6_encap(skb, BPF_LWT_ENCAP_SEG6, param, param_len); if (!err) - srh_state->hdrlen = - ((struct ipv6_sr_hdr *)param)->hdrlen << 3; + bpf_update_srh_state(skb); + return err; default: return -EINVAL; @@ -4644,15 +4672,14 @@ BPF_CALL_3(bpf_lwt_seg6_adjust_srh, struct sk_buff *, skb, u32, offset, { struct seg6_bpf_srh_state *srh_state = this_cpu_ptr(&seg6_bpf_srh_states); + struct ipv6_sr_hdr *srh = srh_state->srh; void *srh_end, *srh_tlvs, *ptr; - struct ipv6_sr_hdr *srh; struct ipv6hdr *hdr; int srhoff = 0; int ret; - if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) + if (unlikely(srh == NULL)) return -EINVAL; - srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); srh_tlvs = (void *)((unsigned char *)srh + sizeof(*srh) + ((srh->first_segment + 1) << 4)); @@ -4682,8 +4709,11 @@ BPF_CALL_3(bpf_lwt_seg6_adjust_srh, struct sk_buff *, skb, u32, offset, hdr = (struct ipv6hdr *)skb->data; hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr)); + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) + return -EINVAL; + srh_state->srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); srh_state->hdrlen += len; - srh_state->valid = 0; + srh_state->valid = false; return 0; } diff --git a/net/ipv6/seg6_local.c b/net/ipv6/seg6_local.c index e1025b493a185ee8eaba0941c05bc293b025343f..60325dbfe88b0ca8ffb08cc89666642c52739480 100644 --- a/net/ipv6/seg6_local.c +++ b/net/ipv6/seg6_local.c @@ -459,36 +459,57 @@ static int input_action_end_b6_encap(struct sk_buff *skb, DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states); +bool seg6_bpf_has_valid_srh(struct sk_buff *skb) +{ + struct seg6_bpf_srh_state *srh_state = + this_cpu_ptr(&seg6_bpf_srh_states); + struct ipv6_sr_hdr *srh = srh_state->srh; + + if (unlikely(srh == NULL)) + return false; + + if (unlikely(!srh_state->valid)) { + if ((srh_state->hdrlen & 7) != 0) + return false; + + srh->hdrlen = (u8)(srh_state->hdrlen >> 3); + if (!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3)) + return false; + + srh_state->valid = true; + } + + return true; +} + static int input_action_end_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt) { struct seg6_bpf_srh_state *srh_state = this_cpu_ptr(&seg6_bpf_srh_states); - struct seg6_bpf_srh_state local_srh_state; struct ipv6_sr_hdr *srh; - int srhoff = 0; int ret; srh = get_and_validate_srh(skb); - if (!srh) - goto drop; + if (!srh) { + kfree_skb(skb); + return -EINVAL; + } advance_nextseg(srh, &ipv6_hdr(skb)->daddr); /* preempt_disable is needed to protect the per-CPU buffer srh_state, * which is also accessed by the bpf_lwt_seg6_* helpers */ preempt_disable(); + srh_state->srh = srh; srh_state->hdrlen = srh->hdrlen << 3; - srh_state->valid = 1; + srh_state->valid = true; rcu_read_lock(); bpf_compute_data_pointers(skb); ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb); rcu_read_unlock(); - local_srh_state = *srh_state; - preempt_enable(); - switch (ret) { case BPF_OK: case BPF_REDIRECT: @@ -500,24 +521,17 @@ static int input_action_end_bpf(struct sk_buff *skb, goto drop; } - if (unlikely((local_srh_state.hdrlen & 7) != 0)) - goto drop; - - if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) - goto drop; - srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); - srh->hdrlen = (u8)(local_srh_state.hdrlen >> 3); - - if (!local_srh_state.valid && - unlikely(!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3))) + if (srh_state->srh && !seg6_bpf_has_valid_srh(skb)) goto drop; + preempt_enable(); if (ret != BPF_REDIRECT) seg6_lookup_nexthop(skb, NULL, 0); return dst_input(skb); drop: + preempt_enable(); kfree_skb(skb); return -EINVAL; }