diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index bd9bee8873250da1cefba7ef8903a61447a438da..40f64800a0b81a161805857cb3e0a3855f386720 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -134,6 +134,15 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { EnforceLayouts(multi_input); Tensor* output = ctx.Output("Out"); int concat_axis = ctx.Attr("axis"); + const int rank = multi_input[0]->dims().size(); + PADDLE_ENFORCE_EQ( + concat_axis >= -rank && concat_axis < rank, true, + platform::errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", + -rank, rank, concat_axis)); + if (concat_axis < 0) { + concat_axis = concat_axis + rank; + } auto& dev_ctx = ctx.template device_context(); auto place = GetCpuPlace(ctx);