未验证 提交 e80acff3 编写于 作者: J jakpiase 提交者: GitHub

added fix for matmul and support for 6 rank tensor (#35740)

上级 bd79ae09
...@@ -655,6 +655,24 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -655,6 +655,24 @@ class MatMulOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument("reshape_out supported rank is 3, " platform::errors::InvalidArgument("reshape_out supported rank is 3, "
"received %d", "received %d",
reshape_out_size)); reshape_out_size));
auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);
// if "-1" is present then one of reshape dims must be infered
if (it != reshape_out.end()) {
int index = std::distance(reshape_out.begin(), it);
auto ddim_out_vec = framework::vectorize(ddim_out);
int ddim_out_product =
std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
std::multiplies<int>());
int reshape_out_product = std::accumulate(
reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());
reshape_out[index] = ddim_out_product / reshape_out_product;
}
framework::DDim shape_out = framework::DDim shape_out =
ddim_out.transpose(transpose_out).reshape(reshape_out); ddim_out.transpose(transpose_out).reshape(reshape_out);
context->SetOutputDim("Out", shape_out); context->SetOutputDim("Out", shape_out);
......
...@@ -358,6 +358,8 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size, ...@@ -358,6 +358,8 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
} else if (data_format == MKLDNNMemoryFormat::nhwc) { } else if (data_format == MKLDNNMemoryFormat::nhwc) {
return MKLDNNMemoryFormat::ndhwc; return MKLDNNMemoryFormat::ndhwc;
} }
} else if (dims_size == 6) {
return MKLDNNMemoryFormat::abcdef;
} }
return data_format; return data_format;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册