diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_main.c b/drivers/net/ethernet/mellanox/mlx5/core/en_main.c index 54bae797b3381410d632d1168bc877b8dee961a3..491cff9cf368e75f6f9c5aae2b9ec6d665817e60 100644 --- a/drivers/net/ethernet/mellanox/mlx5/core/en_main.c +++ b/drivers/net/ethernet/mellanox/mlx5/core/en_main.c @@ -3144,11 +3144,21 @@ static int mlx5e_xdp_set(struct net_device *netdev, struct bpf_prog *prog) if (was_opened && reset) mlx5e_close_locked(netdev); + if (was_opened && !reset) { + /* num_channels is invariant here, so we can take the + * batched reference right upfront. + */ + prog = bpf_prog_add(prog, priv->params.num_channels); + if (IS_ERR(prog)) { + err = PTR_ERR(prog); + goto unlock; + } + } - /* exchange programs */ + /* exchange programs, extra prog reference we got from caller + * as long as we don't fail from this point onwards. + */ old_prog = xchg(&priv->xdp_prog, prog); - if (prog) - bpf_prog_add(prog, 1); if (old_prog) bpf_prog_put(old_prog); @@ -3164,7 +3174,6 @@ static int mlx5e_xdp_set(struct net_device *netdev, struct bpf_prog *prog) /* exchanging programs w/o reset, we update ref counts on behalf * of the channels RQs here. */ - bpf_prog_add(prog, priv->params.num_channels); for (i = 0; i < priv->params.num_channels; i++) { struct mlx5e_channel *c = priv->channel[i];