未验证 提交 dbe24977 编写于 作者: W Wilber 提交者: GitHub

fix mkldnn concat bug. test=develop (#24722)

上级 b9260b36
......@@ -134,6 +134,15 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out");
int concat_axis = ctx.Attr<int>("axis");
const int rank = multi_input[0]->dims().size();
PADDLE_ENFORCE_EQ(
concat_axis >= -rank && concat_axis < rank, true,
platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank, rank, concat_axis));
if (concat_axis < 0) {
concat_axis = concat_axis + rank;
}
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
auto place = GetCpuPlace(ctx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册