/* mcast_snoop.c */
/* Copyright (C) 2010: YOSHIFUJI Hideaki <yoshfuji@linux-ipv6.org>
 * Copyright (C) 2015: Linus Lüssing <linus.luessing@c0d3.blue>
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of version 2 of the GNU General Public
 * License as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 *
 *
 * Based on the MLD support added to br_multicast.c by YOSHIFUJI Hideaki.
 */

#include <linux/skbuff.h>
#include <net/ipv6.h>
#include <net/mld.h>
#include <net/addrconf.h>
#include <net/ip6_checksum.h>

/* Sanity check the fixed IPv6 header found at the skb network header.
 *
 * Makes sure the fixed header is present in the linear data, that the
 * version field is 6 and that the payload length announced by the header
 * is non-zero and fits into the skb.
 *
 * Returns 0 if the header looks sane, -EINVAL otherwise.
 */
static int ipv6_mc_check_ip6hdr(struct sk_buff *skb)
{
	const struct ipv6hdr *hdr;
	unsigned int hdr_end = skb_network_offset(skb) + sizeof(*hdr);
	unsigned int pkt_end;

	if (!pskb_may_pull(skb, hdr_end))
		return -EINVAL;

	/* only fetch the header pointer after the pull above */
	hdr = ipv6_hdr(skb);
	if (hdr->version != 6)
		return -EINVAL;

	pkt_end = hdr_end + ntohs(hdr->payload_len);

	/* pkt_end <= hdr_end rejects a zero payload length (and a wrapped
	 * sum), the second test rejects packets cut short of their
	 * announced length
	 */
	if (pkt_end <= hdr_end || pkt_end > skb->len)
		return -EINVAL;

	return 0;
}

static int ipv6_mc_check_exthdrs(struct sk_buff *skb)
{
	const struct ipv6hdr *ip6h;
50
	int offset;
51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
	u8 nexthdr;
	__be16 frag_off;

	ip6h = ipv6_hdr(skb);

	if (ip6h->nexthdr != IPPROTO_HOPOPTS)
		return -ENOMSG;

	nexthdr = ip6h->nexthdr;
	offset = skb_network_offset(skb) + sizeof(*ip6h);
	offset = ipv6_skip_exthdr(skb, offset, &nexthdr, &frag_off);

	if (offset < 0)
		return -EINVAL;

	if (nexthdr != IPPROTO_ICMPV6)
		return -ENOMSG;

	skb_set_transport_header(skb, offset);

	return 0;
}

/* Check an MLDv2 report: the fixed struct mld2_report header following
 * the transport header offset must be pullable into the linear data.
 *
 * Returns 0 if it is, -EINVAL otherwise.
 */
static int ipv6_mc_check_mld_reportv2(struct sk_buff *skb)
{
	const unsigned int hdr_end = skb_transport_offset(skb) +
				     sizeof(struct mld2_report);

	if (!pskb_may_pull(skb, hdr_end))
		return -EINVAL;

	return 0;
}

static int ipv6_mc_check_mld_query(struct sk_buff *skb)
{
	struct mld_msg *mld;
	unsigned int len = skb_transport_offset(skb);

	/* RFC2710+RFC3810 (MLDv1+MLDv2) require link-local source addresses */
	if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL))
		return -EINVAL;

	len += sizeof(struct mld_msg);
	if (skb->len < len)
		return -EINVAL;

	/* MLDv1? */
	if (skb->len != len) {
		/* or MLDv2? */
		len += sizeof(struct mld2_query) - sizeof(struct mld_msg);
		if (skb->len < len || !pskb_may_pull(skb, len))
			return -EINVAL;
	}

	mld = (struct mld_msg *)skb_transport_header(skb);

	/* RFC2710+RFC3810 (MLDv1+MLDv2) require the multicast link layer
	 * all-nodes destination address (ff02::1) for general queries
	 */
	if (ipv6_addr_any(&mld->mld_mca) &&
	    !ipv6_addr_is_ll_all_nodes(&ipv6_hdr(skb)->daddr))
		return -EINVAL;

	return 0;
}

static int ipv6_mc_check_mld_msg(struct sk_buff *skb)
{
	struct mld_msg *mld = (struct mld_msg *)skb_transport_header(skb);

	switch (mld->mld_type) {
	case ICMPV6_MGM_REDUCTION:
	case ICMPV6_MGM_REPORT:
		/* fall through */
		return 0;
	case ICMPV6_MLD2_REPORT:
		return ipv6_mc_check_mld_reportv2(skb);
	case ICMPV6_MGM_QUERY:
		return ipv6_mc_check_mld_query(skb);
	default:
		return -ENOMSG;
	}
}

/* Checksum callback handed to skb_checksum_trimmed(): validates the skb
 * as ICMPv6 using the IPv6 pseudo header and passes the result of
 * skb_checksum_validate() through unchanged.
 */
static inline __sum16 ipv6_mc_validate_checksum(struct sk_buff *skb)
{
	__sum16 csum_result;

	csum_result = skb_checksum_validate(skb, IPPROTO_ICMPV6,
					    ip6_compute_pseudo);

	return csum_result;
}

static int __ipv6_mc_check_mld(struct sk_buff *skb,
			       struct sk_buff **skb_trimmed)

{
	struct sk_buff *skb_chk = NULL;
	unsigned int transport_len;
	unsigned int len = skb_transport_offset(skb) + sizeof(struct mld_msg);
146
	int ret = -EINVAL;
147 148 149 150 151 152 153

	transport_len = ntohs(ipv6_hdr(skb)->payload_len);
	transport_len -= skb_transport_offset(skb) - sizeof(struct ipv6hdr);

	skb_chk = skb_checksum_trimmed(skb, transport_len,
				       ipv6_mc_validate_checksum);
	if (!skb_chk)
154
		goto err;
155

156 157
	if (!pskb_may_pull(skb_chk, len))
		goto err;
158 159

	ret = ipv6_mc_check_mld_msg(skb_chk);
160 161
	if (ret)
		goto err;
162 163 164

	if (skb_trimmed)
		*skb_trimmed = skb_chk;
165 166
	/* free now unneeded clone */
	else if (skb_chk != skb)
167 168
		kfree_skb(skb_chk);

169 170 171 172 173 174 175
	ret = 0;

err:
	if (ret && skb_chk && skb_chk != skb)
		kfree_skb(skb_chk);

	return ret;
176 177 178 179 180 181 182 183
}

/**
 * ipv6_mc_check_mld - checks whether this is a sane MLD packet
 * @skb: the skb to validate
 * @skb_trimmed: to store an skb pointer trimmed to IPv6 packet tail (optional)
 *
 * Checks whether an IPv6 packet is a valid MLD packet by validating, in
 * this order, the IPv6 header, the extension header chain and finally the
 * checksum and MLD message. The first failing step ends the validation.
 * If all steps pass, the skb transport header has been set accordingly
 * and zero is returned.
 *
 * -EINVAL: A broken packet was detected, i.e. it violates some internet
 *  standard
 * -ENOMSG: IP header validation succeeded but it is not an MLD packet.
 * -ENOMEM: A memory allocation failure happened.
 *  NOTE(review): none of the helpers in this file produces -ENOMEM; a
 *  failing skb_checksum_trimmed() is reported as -EINVAL - confirm
 *  whether this code is still reachable.
 *
 * Optionally, an skb pointer might be provided via skb_trimmed (or set it
 * to NULL): After parsing an MLD packet successfully it will point to
 * an skb which has its tail aligned to the IP packet end. This might
 * either be the originally provided skb or a trimmed, cloned version if
 * the skb frame had data beyond the IP packet. A cloned skb allows us
 * to leave the original skb and its full frame unchanged (which might be
 * desirable for layer 2 frame jugglers).
 *
 * Caller needs to set the skb network header and free any returned skb if it
 * differs from the provided skb.
 */
int ipv6_mc_check_mld(struct sk_buff *skb, struct sk_buff **skb_trimmed)
{
	int err = ipv6_mc_check_ip6hdr(skb);

	if (err < 0)
		return err;

	err = ipv6_mc_check_exthdrs(skb);
	if (err < 0)
		return err;

	return __ipv6_mc_check_mld(skb, skb_trimmed);
}
EXPORT_SYMBOL(ipv6_mc_check_mld);