diff --git a/net/netlink/genetlink.c b/net/netlink/genetlink.c index c785080e94012e08d3740f6d5ada22684d27b7bc..a98c94594508c30d28514f9c89bf636c38184750 100644 --- a/net/netlink/genetlink.c +++ b/net/netlink/genetlink.c @@ -468,6 +468,45 @@ static void genl_dumpit_info_free(const struct genl_dumpit_info *info) kfree(info); } +static struct nlattr ** +genl_family_rcv_msg_attrs_parse(const struct genl_family *family, + struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + const struct genl_ops *ops, + int hdrlen, + enum genl_validate_flags no_strict_flag) +{ + enum netlink_validation validate = ops->validate & no_strict_flag ? + NL_VALIDATE_LIBERAL : + NL_VALIDATE_STRICT; + struct nlattr **attrbuf; + int err; + + if (family->maxattr && family->parallel_ops) { + attrbuf = kmalloc_array(family->maxattr + 1, + sizeof(struct nlattr *), GFP_KERNEL); + if (!attrbuf) + return ERR_PTR(-ENOMEM); + } else { + attrbuf = family->attrbuf; + } + + err = __nlmsg_parse(nlh, hdrlen, attrbuf, family->maxattr, + family->policy, validate, extack); + if (err && family->maxattr && family->parallel_ops) { + kfree(attrbuf); + return ERR_PTR(err); + } + return attrbuf; +} + +static void genl_family_rcv_msg_attrs_free(const struct genl_family *family, + struct nlattr **attrbuf) +{ + if (family->maxattr && family->parallel_ops) + kfree(attrbuf); +} + static int genl_lock_start(struct netlink_callback *cb) { const struct genl_ops *ops = genl_dumpit_info(cb)->ops; @@ -599,26 +638,11 @@ static int genl_family_rcv_msg_doit(const struct genl_family *family, if (!ops->doit) return -EOPNOTSUPP; - if (family->maxattr && family->parallel_ops) { - attrbuf = kmalloc_array(family->maxattr + 1, - sizeof(struct nlattr *), - GFP_KERNEL); - if (attrbuf == NULL) - return -ENOMEM; - } else - attrbuf = family->attrbuf; - - if (attrbuf) { - enum netlink_validation validate = NL_VALIDATE_STRICT; - - if (ops->validate & GENL_DONT_VALIDATE_STRICT) - validate = NL_VALIDATE_LIBERAL; - - err = __nlmsg_parse(nlh, hdrlen, attrbuf, family->maxattr, - family->policy, validate, extack); - if (err < 0) - goto out; - } + attrbuf = genl_family_rcv_msg_attrs_parse(family, nlh, extack, + ops, hdrlen, + GENL_DONT_VALIDATE_STRICT); + if (IS_ERR(attrbuf)) + return PTR_ERR(attrbuf); info.snd_seq = nlh->nlmsg_seq; info.snd_portid = NETLINK_CB(skb).portid; @@ -642,8 +666,7 @@ static int genl_family_rcv_msg_doit(const struct genl_family *family, family->post_doit(ops, skb, &info); out: - if (family->parallel_ops) - kfree(attrbuf); + genl_family_rcv_msg_attrs_free(family, attrbuf); return err; }