From dbe2497768841015a49808f4bf2336dc41dd64ea Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 26 May 2020 11:36:52 +0800 Subject: [PATCH] fix mkldnn concat bug. test=develop (#24722) --- paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index bd9bee88732..40f64800a0b 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -134,6 +134,15 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { EnforceLayouts(multi_input); Tensor* output = ctx.Output("Out"); int concat_axis = ctx.Attr("axis"); + const int rank = multi_input[0]->dims().size(); + PADDLE_ENFORCE_EQ( + concat_axis >= -rank && concat_axis < rank, true, + platform::errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", + -rank, rank, concat_axis)); + if (concat_axis < 0) { + concat_axis = concat_axis + rank; + } auto& dev_ctx = ctx.template device_context(); auto place = GetCpuPlace(ctx); -- GitLab