// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Checksum updating actions
 *
 * Copyright (c) 2010 Gregoire Baron <baronchon@n7mm.org>
 */

#include <linux/types.h>
#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/spinlock.h>

#include <linux/netlink.h>
#include <net/netlink.h>
#include <linux/rtnetlink.h>

#include <linux/skbuff.h>

#include <net/ip.h>
#include <net/ipv6.h>
#include <net/icmp.h>
#include <linux/icmpv6.h>
#include <linux/igmp.h>
#include <net/tcp.h>
#include <net/udp.h>
#include <net/ip6_checksum.h>
#include <net/sctp/checksum.h>

#include <net/act_api.h>
#include <net/pkt_cls.h>

#include <linux/tc_act/tc_csum.h>
#include <net/tc_act/tc_csum.h>

/* Netlink attribute policy: TCA_CSUM_PARMS must carry a struct tc_csum. */
static const struct nla_policy csum_policy[TCA_CSUM_MAX + 1] = {
	[TCA_CSUM_PARMS] = { .len = sizeof(struct tc_csum), },
};

40
static unsigned int csum_net_id;
41
static struct tc_action_ops act_csum_ops;
42 43

/*
 * tcf_csum_init - parse netlink config and create or update a csum action.
 *
 * Returns ACT_P_CREATED when a new action instance was created, 0 when an
 * existing one was bound/overridden, or a negative errno on failure.  The
 * update_flags live in an RCU-managed params struct so the datapath never
 * takes the action lock.
 */
static int tcf_csum_init(struct net *net, struct nlattr *nla,
			 struct nlattr *est, struct tc_action **a, int ovr,
			 int bind, bool rtnl_held, struct tcf_proto *tp,
			 struct netlink_ext_ack *extack)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);
	struct tcf_csum_params *params_new;
	struct nlattr *tb[TCA_CSUM_MAX + 1];
	struct tcf_chain *goto_ch = NULL;
	struct tc_csum *parm;
	struct tcf_csum *p;
	int ret = 0, err;
	u32 index;

	if (nla == NULL)
		return -EINVAL;

	err = nla_parse_nested_deprecated(tb, TCA_CSUM_MAX, nla, csum_policy,
					  NULL);
	if (err < 0)
		return err;

	if (tb[TCA_CSUM_PARMS] == NULL)
		return -EINVAL;
	parm = nla_data(tb[TCA_CSUM_PARMS]);
	index = parm->index;
	err = tcf_idr_check_alloc(tn, &index, a, bind);
	if (!err) {
		/* No action with this index yet: create one. */
		ret = tcf_idr_create(tn, index, est, a,
				     &act_csum_ops, bind, true);
		if (ret) {
			tcf_idr_cleanup(tn, index);
			return ret;
		}
		ret = ACT_P_CREATED;
	} else if (err > 0) {
		if (bind)/* dont override defaults */
			return 0;
		if (!ovr) {
			tcf_idr_release(*a, bind);
			return -EEXIST;
		}
	} else {
		return err;
	}

	/* Validate the control action (e.g. goto chain) before committing. */
	err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack);
	if (err < 0)
		goto release_idr;

	p = to_tcf_csum(*a);

	params_new = kzalloc(sizeof(*params_new), GFP_KERNEL);
	if (unlikely(!params_new)) {
		err = -ENOMEM;
		goto put_chain;
	}
	params_new->update_flags = parm->update_flags;

	/* Publish the new params; readers use RCU, writers hold tcf_lock. */
	spin_lock_bh(&p->tcf_lock);
	goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch);
	rcu_swap_protected(p->params, params_new,
			   lockdep_is_held(&p->tcf_lock));
	spin_unlock_bh(&p->tcf_lock);

	/* goto_ch/params_new now hold the *previous* values; drop them. */
	if (goto_ch)
		tcf_chain_put_by_act(goto_ch);
	if (params_new)
		kfree_rcu(params_new, rcu);

	if (ret == ACT_P_CREATED)
		tcf_idr_insert(tn, *a);

	return ret;
put_chain:
	if (goto_ch)
		tcf_chain_put_by_act(goto_ch);
release_idr:
	tcf_idr_release(*a, bind);
	return err;
}

/**
 * tcf_csum_skb_nextlayer - Get next layer pointer
 * @skb: sk_buff to use
 * @ihl: previous summed headers length
 * @ipl: complete packet length
 * @jhl: next header length
 *
 * Check the expected next layer availability in the specified sk_buff.
 * Return the next layer pointer if pass, NULL otherwise.
 */
static void *tcf_csum_skb_nextlayer(struct sk_buff *skb,
				    unsigned int ihl, unsigned int ipl,
				    unsigned int jhl)
{
	int ntkoff = skb_network_offset(skb);
	int hl = ihl + jhl;

	if (!pskb_may_pull(skb, ipl + ntkoff) || (ipl < hl) ||
143
	    skb_try_make_writable(skb, hl + ntkoff))
144 145 146 147 148
		return NULL;
	else
		return (void *)(skb_network_header(skb) + ihl);
}

J
Jamal Hadi Salim 已提交
149 150
static int tcf_csum_ipv4_icmp(struct sk_buff *skb, unsigned int ihl,
			      unsigned int ipl)
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
{
	struct icmphdr *icmph;

	icmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmph));
	if (icmph == NULL)
		return 0;

	icmph->checksum = 0;
	skb->csum = csum_partial(icmph, ipl - ihl, 0);
	icmph->checksum = csum_fold(skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

/* Recompute the IGMP checksum over the whole IGMP message.
 * Returns 1 on success, 0 if the skb cannot provide the header.
 */
static int tcf_csum_ipv4_igmp(struct sk_buff *skb,
			      unsigned int ihl, unsigned int ipl)
{
	struct igmphdr *ih;

	ih = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*ih));
	if (!ih)
		return 0;

	/* Zero the field, sum the payload, fold the result back in. */
	ih->csum = 0;
	skb->csum = csum_partial(ih, ipl - ihl, 0);
	ih->csum = csum_fold(skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

J
Jamal Hadi Salim 已提交
185 186
static int tcf_csum_ipv6_icmp(struct sk_buff *skb, unsigned int ihl,
			      unsigned int ipl)
187 188
{
	struct icmp6hdr *icmp6h;
189
	const struct ipv6hdr *ip6h;
190 191 192 193 194

	icmp6h = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmp6h));
	if (icmp6h == NULL)
		return 0;

195
	ip6h = ipv6_hdr(skb);
196 197 198 199 200 201 202 203 204 205 206
	icmp6h->icmp6_cksum = 0;
	skb->csum = csum_partial(icmp6h, ipl - ihl, 0);
	icmp6h->icmp6_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
					      ipl - ihl, IPPROTO_ICMPV6,
					      skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

J
Jamal Hadi Salim 已提交
207 208
static int tcf_csum_ipv4_tcp(struct sk_buff *skb, unsigned int ihl,
			     unsigned int ipl)
209 210
{
	struct tcphdr *tcph;
211
	const struct iphdr *iph;
212

213 214 215
	if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV4)
		return 1;

216 217 218 219
	tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph));
	if (tcph == NULL)
		return 0;

220
	iph = ip_hdr(skb);
221 222 223 224 225 226 227 228 229 230
	tcph->check = 0;
	skb->csum = csum_partial(tcph, ipl - ihl, 0);
	tcph->check = tcp_v4_check(ipl - ihl,
				   iph->saddr, iph->daddr, skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

J
Jamal Hadi Salim 已提交
231 232
static int tcf_csum_ipv6_tcp(struct sk_buff *skb, unsigned int ihl,
			     unsigned int ipl)
233 234
{
	struct tcphdr *tcph;
235
	const struct ipv6hdr *ip6h;
236

237 238 239
	if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV6)
		return 1;

240 241 242 243
	tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph));
	if (tcph == NULL)
		return 0;

244
	ip6h = ipv6_hdr(skb);
245 246 247 248 249 250 251 252 253 254 255
	tcph->check = 0;
	skb->csum = csum_partial(tcph, ipl - ihl, 0);
	tcph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
				      ipl - ihl, IPPROTO_TCP,
				      skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

J
Jamal Hadi Salim 已提交
256 257
static int tcf_csum_ipv4_udp(struct sk_buff *skb, unsigned int ihl,
			     unsigned int ipl, int udplite)
258 259
{
	struct udphdr *udph;
260
	const struct iphdr *iph;
261 262
	u16 ul;

263 264 265
	if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP)
		return 1;

266 267 268
	/*
	 * Support both UDP and UDPLITE checksum algorithms, Don't use
	 * udph->len to get the real length without any protocol check,
269 270 271 272 273 274 275 276
	 * UDPLITE uses udph->len for another thing,
	 * Use iph->tot_len, or just ipl.
	 */

	udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph));
	if (udph == NULL)
		return 0;

277
	iph = ip_hdr(skb);
278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
	ul = ntohs(udph->len);

	if (udplite || udph->check) {

		udph->check = 0;

		if (udplite) {
			if (ul == 0)
				skb->csum = csum_partial(udph, ipl - ihl, 0);
			else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl))
				skb->csum = csum_partial(udph, ul, 0);
			else
				goto ignore_obscure_skb;
		} else {
			if (ul != ipl - ihl)
				goto ignore_obscure_skb;

			skb->csum = csum_partial(udph, ul, 0);
		}

		udph->check = csum_tcpudp_magic(iph->saddr, iph->daddr,
						ul, iph->protocol,
						skb->csum);

		if (!udph->check)
			udph->check = CSUM_MANGLED_0;
	}

	skb->ip_summed = CHECKSUM_NONE;

ignore_obscure_skb:
	return 1;
}

J
Jamal Hadi Salim 已提交
312 313
static int tcf_csum_ipv6_udp(struct sk_buff *skb, unsigned int ihl,
			     unsigned int ipl, int udplite)
314 315
{
	struct udphdr *udph;
316
	const struct ipv6hdr *ip6h;
317 318
	u16 ul;

319 320 321
	if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP)
		return 1;

322 323 324
	/*
	 * Support both UDP and UDPLITE checksum algorithms, Don't use
	 * udph->len to get the real length without any protocol check,
325 326 327 328 329 330 331 332
	 * UDPLITE uses udph->len for another thing,
	 * Use ip6h->payload_len + sizeof(*ip6h) ... , or just ipl.
	 */

	udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph));
	if (udph == NULL)
		return 0;

333
	ip6h = ipv6_hdr(skb);
334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
	ul = ntohs(udph->len);

	udph->check = 0;

	if (udplite) {
		if (ul == 0)
			skb->csum = csum_partial(udph, ipl - ihl, 0);

		else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl))
			skb->csum = csum_partial(udph, ul, 0);

		else
			goto ignore_obscure_skb;
	} else {
		if (ul != ipl - ihl)
			goto ignore_obscure_skb;

		skb->csum = csum_partial(udph, ul, 0);
	}

	udph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, ul,
				      udplite ? IPPROTO_UDPLITE : IPPROTO_UDP,
				      skb->csum);

	if (!udph->check)
		udph->check = CSUM_MANGLED_0;

	skb->ip_summed = CHECKSUM_NONE;

ignore_obscure_skb:
	return 1;
}

367 368 369 370 371
static int tcf_csum_sctp(struct sk_buff *skb, unsigned int ihl,
			 unsigned int ipl)
{
	struct sctphdr *sctph;

372
	if (skb_is_gso(skb) && skb_is_gso_sctp(skb))
373 374 375 376 377 378 379 380 381
		return 1;

	sctph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*sctph));
	if (!sctph)
		return 0;

	sctph->checksum = sctp_compute_cksum(skb,
					     skb_network_offset(skb) + ihl);
	skb->ip_summed = CHECKSUM_NONE;
382
	skb->csum_not_inet = 0;
383 384 385 386

	return 1;
}

387 388
static int tcf_csum_ipv4(struct sk_buff *skb, u32 update_flags)
{
389
	const struct iphdr *iph;
390 391 392 393 394 395 396 397 398 399 400 401
	int ntkoff;

	ntkoff = skb_network_offset(skb);

	if (!pskb_may_pull(skb, sizeof(*iph) + ntkoff))
		goto fail;

	iph = ip_hdr(skb);

	switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) {
	case IPPROTO_ICMP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP)
402 403
			if (!tcf_csum_ipv4_icmp(skb, iph->ihl * 4,
						ntohs(iph->tot_len)))
404 405 406 407
				goto fail;
		break;
	case IPPROTO_IGMP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_IGMP)
408 409
			if (!tcf_csum_ipv4_igmp(skb, iph->ihl * 4,
						ntohs(iph->tot_len)))
410 411 412 413
				goto fail;
		break;
	case IPPROTO_TCP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP)
414
			if (!tcf_csum_ipv4_tcp(skb, iph->ihl * 4,
415
					       ntohs(iph->tot_len)))
416 417 418 419
				goto fail;
		break;
	case IPPROTO_UDP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP)
420
			if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4,
421
					       ntohs(iph->tot_len), 0))
422 423 424 425
				goto fail;
		break;
	case IPPROTO_UDPLITE:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE)
426
			if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4,
427
					       ntohs(iph->tot_len), 1))
428 429
				goto fail;
		break;
430 431 432 433 434
	case IPPROTO_SCTP:
		if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) &&
		    !tcf_csum_sctp(skb, iph->ihl * 4, ntohs(iph->tot_len)))
			goto fail;
		break;
435 436 437
	}

	if (update_flags & TCA_CSUM_UPDATE_FLAG_IPV4HDR) {
438
		if (skb_try_make_writable(skb, sizeof(*iph) + ntkoff))
439 440
			goto fail;

441
		ip_send_check(ip_hdr(skb));
442 443 444 445 446 447 448 449
	}

	return 1;

fail:
	return 0;
}

J
Jamal Hadi Salim 已提交
450 451
static int tcf_csum_ipv6_hopopts(struct ipv6_opt_hdr *ip6xh, unsigned int ixhl,
				 unsigned int *pl)
452 453 454 455 456 457 458 459
{
	int off, len, optlen;
	unsigned char *xh = (void *)ip6xh;

	off = sizeof(*ip6xh);
	len = ixhl - off;

	while (len > 1) {
460
		switch (xh[off]) {
461
		case IPV6_TLV_PAD1:
462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519
			optlen = 1;
			break;
		case IPV6_TLV_JUMBO:
			optlen = xh[off + 1] + 2;
			if (optlen != 6 || len < 6 || (off & 3) != 2)
				/* wrong jumbo option length/alignment */
				return 0;
			*pl = ntohl(*(__be32 *)(xh + off + 2));
			goto done;
		default:
			optlen = xh[off + 1] + 2;
			if (optlen > len)
				/* ignore obscure options */
				goto done;
			break;
		}
		off += optlen;
		len -= optlen;
	}

done:
	return 1;
}

/* Dispatch checksum updates for an IPv6 packet according to @update_flags.
 * Extension headers are walked until an L4 protocol is found; fragments
 * and unknown next headers are ignored (returned as success, untouched).
 * Returns 1 on success/ignore, 0 on failure (caller drops the packet).
 */
static int tcf_csum_ipv6(struct sk_buff *skb, u32 update_flags)
{
	struct ipv6hdr *ip6h;
	struct ipv6_opt_hdr *ip6xh;
	unsigned int hl, ixhl;
	unsigned int pl;
	int ntkoff;
	u8 nexthdr;

	ntkoff = skb_network_offset(skb);

	hl = sizeof(*ip6h);

	if (!pskb_may_pull(skb, hl + ntkoff))
		goto fail;

	ip6h = ipv6_hdr(skb);

	pl = ntohs(ip6h->payload_len);
	nexthdr = ip6h->nexthdr;

	do {
		switch (nexthdr) {
		case NEXTHDR_FRAGMENT:
			goto ignore_skb;
		case NEXTHDR_ROUTING:
		case NEXTHDR_HOP:
		case NEXTHDR_DEST:
			if (!pskb_may_pull(skb, hl + sizeof(*ip6xh) + ntkoff))
				goto fail;
			ip6xh = (void *)(skb_network_header(skb) + hl);
			ixhl = ipv6_optlen(ip6xh);
			if (!pskb_may_pull(skb, hl + ixhl + ntkoff))
				goto fail;
			/* pskb_may_pull() may have reallocated; re-read. */
			ip6xh = (void *)(skb_network_header(skb) + hl);
			if ((nexthdr == NEXTHDR_HOP) &&
			    !(tcf_csum_ipv6_hopopts(ip6xh, ixhl, &pl)))
				goto fail;
			nexthdr = ip6xh->nexthdr;
			hl += ixhl;
			break;
		case IPPROTO_ICMPV6:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP)
				if (!tcf_csum_ipv6_icmp(skb,
							hl, pl + sizeof(*ip6h)))
					goto fail;
			goto done;
		case IPPROTO_TCP:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP)
				if (!tcf_csum_ipv6_tcp(skb,
						       hl, pl + sizeof(*ip6h)))
					goto fail;
			goto done;
		case IPPROTO_UDP:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP)
				if (!tcf_csum_ipv6_udp(skb, hl,
						       pl + sizeof(*ip6h), 0))
					goto fail;
			goto done;
		case IPPROTO_UDPLITE:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE)
				if (!tcf_csum_ipv6_udp(skb, hl,
						       pl + sizeof(*ip6h), 1))
					goto fail;
			goto done;
		case IPPROTO_SCTP:
			if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) &&
			    !tcf_csum_sctp(skb, hl, pl + sizeof(*ip6h)))
				goto fail;
			goto done;
		default:
			goto ignore_skb;
		}
	} while (pskb_may_pull(skb, hl + 1 + ntkoff));

done:
ignore_skb:
	return 1;

fail:
	return 0;
}

569 570
static int tcf_csum_act(struct sk_buff *skb, const struct tc_action *a,
			struct tcf_result *res)
571
{
572
	struct tcf_csum *p = to_tcf_csum(a);
573 574
	bool orig_vlan_tag_present = false;
	unsigned int vlan_hdr_count = 0;
575
	struct tcf_csum_params *params;
576
	u32 update_flags;
577
	__be16 protocol;
578 579
	int action;

580
	params = rcu_dereference_bh(p->params);
581

582
	tcf_lastuse_update(&p->tcf_tm);
583
	bstats_cpu_update(this_cpu_ptr(p->common.cpu_bstats), skb);
584

585
	action = READ_ONCE(p->tcf_action);
586
	if (unlikely(action == TC_ACT_SHOT))
587
		goto drop;
588

589
	update_flags = params->update_flags;
590 591 592
	protocol = tc_skb_protocol(skb);
again:
	switch (protocol) {
593 594 595 596 597 598 599 600
	case cpu_to_be16(ETH_P_IP):
		if (!tcf_csum_ipv4(skb, update_flags))
			goto drop;
		break;
	case cpu_to_be16(ETH_P_IPV6):
		if (!tcf_csum_ipv6(skb, update_flags))
			goto drop;
		break;
601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621
	case cpu_to_be16(ETH_P_8021AD): /* fall through */
	case cpu_to_be16(ETH_P_8021Q):
		if (skb_vlan_tag_present(skb) && !orig_vlan_tag_present) {
			protocol = skb->protocol;
			orig_vlan_tag_present = true;
		} else {
			struct vlan_hdr *vlan = (struct vlan_hdr *)skb->data;

			protocol = vlan->h_vlan_encapsulated_proto;
			skb_pull(skb, VLAN_HLEN);
			skb_reset_network_header(skb);
			vlan_hdr_count++;
		}
		goto again;
	}

out:
	/* Restore the skb for the pulled VLAN tags */
	while (vlan_hdr_count--) {
		skb_push(skb, VLAN_HLEN);
		skb_reset_network_header(skb);
622 623 624 625 626
	}

	return action;

drop:
627
	qstats_drop_inc(this_cpu_ptr(p->common.cpu_qstats));
628 629
	action = TC_ACT_SHOT;
	goto out;
630 631
}

J
Jamal Hadi Salim 已提交
632 633
static int tcf_csum_dump(struct sk_buff *skb, struct tc_action *a, int bind,
			 int ref)
634 635
{
	unsigned char *b = skb_tail_pointer(skb);
636
	struct tcf_csum *p = to_tcf_csum(a);
637
	struct tcf_csum_params *params;
638 639
	struct tc_csum opt = {
		.index   = p->tcf_index,
640 641
		.refcnt  = refcount_read(&p->tcf_refcnt) - ref,
		.bindcnt = atomic_read(&p->tcf_bindcnt) - bind,
642 643 644
	};
	struct tcf_t t;

645
	spin_lock_bh(&p->tcf_lock);
646 647 648
	params = rcu_dereference_protected(p->params,
					   lockdep_is_held(&p->tcf_lock));
	opt.action = p->tcf_action;
649 650
	opt.update_flags = params->update_flags;

651 652
	if (nla_put(skb, TCA_CSUM_PARMS, sizeof(opt), &opt))
		goto nla_put_failure;
653 654

	tcf_tm_dump(&t, &p->tcf_tm);
655
	if (nla_put_64bit(skb, TCA_CSUM_TM, sizeof(t), &t, TCA_CSUM_PAD))
656
		goto nla_put_failure;
657
	spin_unlock_bh(&p->tcf_lock);
658 659 660 661

	return skb->len;

nla_put_failure:
662
	spin_unlock_bh(&p->tcf_lock);
663 664 665 666
	nlmsg_trim(skb, b);
	return -1;
}

667 668 669 670 671 672
static void tcf_csum_cleanup(struct tc_action *a)
{
	struct tcf_csum *p = to_tcf_csum(a);
	struct tcf_csum_params *params;

	params = rcu_dereference_protected(p->params, 1);
673 674
	if (params)
		kfree_rcu(params, rcu);
675 676
}

677 678
static int tcf_csum_walker(struct net *net, struct sk_buff *skb,
			   struct netlink_callback *cb, int type,
679 680
			   const struct tc_action_ops *ops,
			   struct netlink_ext_ack *extack)
681 682 683
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

684
	return tcf_generic_walker(tn, skb, cb, type, ops, extack);
685 686
}

687
static int tcf_csum_search(struct net *net, struct tc_action **a, u32 index)
688 689 690
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

691
	return tcf_idr_search(tn, a, index);
692 693
}

694 695 696 697 698
static size_t tcf_csum_get_fill_size(const struct tc_action *act)
{
	return nla_total_size(sizeof(struct tc_csum));
}

699
static struct tc_action_ops act_csum_ops = {
700
	.kind		= "csum",
701
	.id		= TCA_ID_CSUM,
702
	.owner		= THIS_MODULE,
703
	.act		= tcf_csum_act,
704 705
	.dump		= tcf_csum_dump,
	.init		= tcf_csum_init,
706
	.cleanup	= tcf_csum_cleanup,
707 708
	.walk		= tcf_csum_walker,
	.lookup		= tcf_csum_search,
709
	.get_fill_size  = tcf_csum_get_fill_size,
710
	.size		= sizeof(struct tcf_csum),
711 712 713 714 715 716
};

static __net_init int csum_init_net(struct net *net)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

717
	return tc_action_net_init(tn, &act_csum_ops);
718 719
}

720
static void __net_exit csum_exit_net(struct list_head *net_list)
721
{
722
	tc_action_net_exit(net_list, csum_net_id);
723 724 725 726
}

static struct pernet_operations csum_net_ops = {
	.init = csum_init_net,
727
	.exit_batch = csum_exit_net,
728 729
	.id   = &csum_net_id,
	.size = sizeof(struct tc_action_net),
730 731 732 733 734 735 736
};

/* Module metadata. */
MODULE_DESCRIPTION("Checksum updating actions");
MODULE_LICENSE("GPL");

/* Module load: register the action ops and pernet hooks. */
static int __init csum_init_module(void)
{
	return tcf_register_action(&act_csum_ops, &csum_net_ops);
}

/* Module unload: unregister the action ops and pernet hooks. */
static void __exit csum_cleanup_module(void)
{
	tcf_unregister_action(&act_csum_ops, &csum_net_ops);
}

/* Wire up module entry/exit points. */
module_init(csum_init_module);
module_exit(csum_cleanup_module);