Unverified commit bf481550, authored by Jacek Czaja, committed by GitHub

[Ready to merge] oneDNN NHWC matmul & elementwise kernels fixes (#42506)

* - fix for a crash
- more fixes
- added diagnostics
- matmul output fixes
- compilation fix
- stop rotating too-small shapes

* - added enabling of the matmul_v2 oneDNN test
Parent ae4d1ec1
@@ -103,11 +103,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
     std::vector<int> out_dims_array(max_dim);
 #ifdef PADDLE_WITH_MKLDNN
     // (jczaja): Broadcasting of dims has to be done on Paddle shapes (NHWC)
-    // if model is using NHWC.
+    // if model is using NHWC and any of shapes in at least 3D
     bool should_rotate =
         ctx->IsRunMKLDNNKernel() &&
         (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
-         framework::DataLayout::kNHWC);
+         framework::DataLayout::kNHWC) &&
+        (x_dims.size() >= 3 || y_dims.size() >= 3);
     if (should_rotate) {
       // Pick bigger shape and rotate this one
       bool x_over_y = (x_dims.size() > y_dims.size());
......
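The guard above matters because broadcast shape inference is layout-dependent. A minimal numpy sketch, illustration only and not part of the patch (numpy stands in for Paddle's dim handling), of how a per-channel operand broadcasts differently when the channel axis is last (NHWC) versus second (NCHW):

# Illustration only: broadcasting a per-channel tensor depends on where the
# channel axis sits, so elementwise shape inference must know the layout.
import numpy as np

c = 8
per_channel = np.arange(c, dtype=np.float32)        # shape (8,)

nhwc = np.zeros((2, 4, 4, c), dtype=np.float32)     # channel-last activation
nchw = np.zeros((2, c, 4, 4), dtype=np.float32)     # channel-second activation

out_nhwc = nhwc + per_channel                       # broadcasts over the last axis
out_nchw = nchw + per_channel.reshape(1, c, 1, 1)   # channel axis aligned explicitly
print(out_nhwc.shape, out_nchw.shape)               # (2, 4, 4, 8) (2, 8, 4, 4)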
@@ -585,6 +585,19 @@ class MatMulOp : public framework::OperatorWithKernel {
     auto dim_x = GetDimForInput(*context, "X");
     auto dim_y = GetDimForInput(*context, "Y");
+#ifdef PADDLE_WITH_MKLDNN
+    // (jczaja): For NHWC execution output shape needs
+    // to be computed like instead x*y we are to do y*x
+    bool channelwise_onednn =
+        context->IsRunMKLDNNKernel() &&
+        (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
+         framework::DataLayout::kNHWC);
+    if (channelwise_onednn) {
+      std::swap(dim_x, dim_y);
+    }
+#endif
     auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(
         RowMatrixFromVector(dim_x), 0,
         context->Attrs().Get<bool>("transpose_X"));
@@ -770,6 +783,21 @@ class MatMulOp : public framework::OperatorWithKernel {
           framework::TransToProtoVarType(tensor.dtype()), tensor.place(),
           tensor.layout());
     } else {
+#ifdef PADDLE_WITH_MKLDNN
+      // When matmul is first oneDNN op in a chain (there was some non oneDNN op
+      // previously)
+      // then we also need to rotate shape NHWC -> NCWH
+      if ((expected_kernel_type.data_layout_ ==
+           framework::DataLayout::kMKLDNN) &&
+          (tensor.layout() != framework::DataLayout::kMKLDNN) &&
+          paddle::platform::MKLDNNDeviceContext::tls()
+                  .get_cur_paddle_data_layout() ==
+              framework::DataLayout::kNHWC) {
+        return framework::OpKernelType(expected_kernel_type.data_type_,
+                                       tensor.place(),
+                                       framework::DataLayout::kNHWC);
+      }
+#endif
       return framework::OpKernelType(expected_kernel_type.data_type_,
                                      tensor.place(), tensor.layout());
     }
......
@@ -274,6 +274,22 @@ class MatMulV2Op : public framework::OperatorWithKernel {
           framework::TransToProtoVarType(tensor.dtype()), tensor.place(),
           tensor.layout());
     } else {
+#ifdef PADDLE_WITH_MKLDNN
+      // When matmul_v2 is first oneDNN op in a chain (there was some non oneDNN
+      // op
+      // previously)
+      // then we also need to rotate shape NHWC -> NCWH
+      if ((expected_kernel_type.data_layout_ ==
+           framework::DataLayout::kMKLDNN) &&
+          (tensor.layout() != framework::DataLayout::kMKLDNN) &&
+          paddle::platform::MKLDNNDeviceContext::tls()
+                  .get_cur_paddle_data_layout() ==
+              framework::DataLayout::kNHWC) {
+        return framework::OpKernelType(expected_kernel_type.data_type_,
+                                       tensor.place(),
+                                       framework::DataLayout::kNHWC);
+      }
+#endif
       return framework::OpKernelType(expected_kernel_type.data_type_,
                                      tensor.place(), tensor.layout());
     }
......
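Both matmul and matmul_v2 gain the same check in GetKernelTypeForVar. A condensed sketch of the decision, with string stand-ins for the framework's DataLayout values and a made-up function name:

# Condensed sketch of the added branch: when entering an oneDNN kernel with a
# tensor that is not yet in oneDNN layout, and the model was registered as
# NHWC, report the variable as NHWC so its shape gets rotated on transfer.
def pick_var_layout(expected_layout, tensor_layout, registered_model_layout):
    if (expected_layout == "MKLDNN"
            and tensor_layout != "MKLDNN"
            and registered_model_layout == "NHWC"):
        return "NHWC"
    return tensor_layout

print(pick_var_layout("MKLDNN", "ANYLAYOUT", "NHWC"))  # NHWC  (first oneDNN op in chain)
print(pick_var_layout("MKLDNN", "MKLDNN", "NHWC"))     # MKLDNN (already inside the chain)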
...@@ -78,13 +78,6 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p, ...@@ -78,13 +78,6 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
inline void MatchShapeToLayout(framework::Tensor* tensor_in, inline void MatchShapeToLayout(framework::Tensor* tensor_in,
framework::DataLayout from, framework::DataLayout from,
framework::DataLayout to) { framework::DataLayout to) {
// In these data layouts, channel dimension is either on 2nd position: nChw or
// at last nhwC, so for dim==2 these layouts are the same and nothing should
// be done. Similarly for dim==1 when you have just one possible combination.
if (tensor_in->dims().size() < 3) {
return;
}
auto print_dims = [](const std::vector<int>& dims) { auto print_dims = [](const std::vector<int>& dims) {
std::ostringstream oss; std::ostringstream oss;
...@@ -101,6 +94,15 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in, ...@@ -101,6 +94,15 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in,
return oss.str(); return oss.str();
}; };
// In these data layouts, channel dimension is either on 2nd position: nChw or
// at last nhwC, so for dim==2 these layouts are the same and nothing should
// be done. Similarly for dim==1 when you have just one possible combination.
if (tensor_in->dims().size() < 3) {
VLOG(3) << "Keeping kMKLDNN/kNHWC/kNDHWC output_shape"
<< print_dims(phi::vectorize<int>(tensor_in->dims()));
return;
}
switch (from) { switch (from) {
case framework::DataLayout::kMKLDNN: case framework::DataLayout::kMKLDNN:
if ((to == framework::DataLayout::kNHWC) || if ((to == framework::DataLayout::kNHWC) ||
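The early return is moved after print_dims so the skipped case can now be logged. The rationale for skipping: for rank 1 and 2 the channel-second and channel-last orderings describe the same shape. A small illustration in which the helper name and rotation logic are mine, not Paddle's:

# Illustration: rotating a dims vector from channel-second (NCHW-like) to
# channel-last (NHWC-like) is the identity for ranks below 3.
def rotate_nchw_to_nhwc(dims):
    if len(dims) < 3:                      # nChw and nhwC coincide here
        return list(dims)
    return [dims[0]] + list(dims[2:]) + [dims[1]]

print(rotate_nchw_to_nhwc([8, 3, 224, 224]))  # [8, 224, 224, 3]
print(rotate_nchw_to_nhwc([8, 3]))            # [8, 3]  (unchanged)
print(rotate_nchw_to_nhwc([8]))               # [8]     (unchanged)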
@@ -571,6 +573,12 @@ inline void RegisterModelLayout(
     std::vector<std::unique_ptr<framework::OperatorBase>>& ops,
     const platform::Place& place) {
   if (platform::is_cpu_place(place)) {
+    // If there is already registered NHWC then quit this call
+    // not to overwrite setting with analysis of internal "while" op block
+    if (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
+        framework::DataLayout::kNHWC)
+      return;
+
     VLOG(4) << "RegisterModelLayout for mkldnn";
     auto check_attrib = [](std::unique_ptr<framework::OperatorBase>& op,
                            const std::string& attrib_name) -> bool {
......
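A minimal sketch of the intent of this guard, with placeholder names standing in for the thread-local layout state: once NHWC has been registered for the model, a later scan, for example of an inner while-op block, no longer overwrites it:

# Placeholder sketch of a "register once" layout guard, not Paddle's code.
_registered_layout = "ANYLAYOUT"

def register_model_layout(detected_layout):
    global _registered_layout
    if _registered_layout == "NHWC":   # already decided for the outer block
        return
    _registered_layout = detected_layout

register_model_layout("NHWC")       # outer block detects NHWC
register_model_layout("ANYLAYOUT")  # inner while-block scan cannot undo it
print(_registered_layout)           # NHWC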
@@ -67,6 +67,8 @@ class TestMatMulV2VectorXVectorOneDNNOp(OpTest):
         self.y_shape = (100, )
         self.trans_x = False
         self.trans_y = False
+        self._cpu_only = True
+        self.use_mkldnn = True

     def set_inputs(self, x, y):
         self.inputs = {'X': x, 'Y': y}
......