From ed9ceb9f989b9444062f2071d7cd4790588e039e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Gallus?= Date: Tue, 26 Nov 2019 03:05:52 +0100 Subject: [PATCH] Refactor MKL-DNN ElementwiseMul (#21061) * Refactor MKL-DNN ElementwiseMul remove manual fallback, remove format attrs test=develop * Refine PADDLE_ENFORCEs in eltwise_mul_op.h test=develop * Make ElementwiseMulOp inherit from ElementwiseOp * Change type of simd_width to int test=develop * Remove Constructor extensions in ElementwiseOp and ElementwiseMulOp test=develop * Restore attributes test=develop * Fix test coverage for mkldnn eltwise mul test=develop * Conform to new is_run_common_broadcast API test=develop * Add UT for AreDimsAndFormatCorrect test=develop --- .../operators/elementwise/CMakeLists.txt | 1 + .../elementwise/elementwise_mul_op.cc | 2 +- .../elementwise/elementwise_mul_op.h | 45 +++++ .../operators/elementwise/elementwise_op.h | 15 +- .../mkldnn/elementwise_mul_mkldnn_op.cc | 165 ++++------------ .../test_elementwise_mul_op_dim.cc | 115 +++++++++++ .../mkldnn/test_elementwise_mul_mkldnn_op.py | 179 +----------------- 7 files changed, 205 insertions(+), 317 deletions(-) create mode 100644 paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc diff --git a/paddle/fluid/operators/elementwise/CMakeLists.txt b/paddle/fluid/operators/elementwise/CMakeLists.txt index 94886066ca5..6347cac91d5 100644 --- a/paddle/fluid/operators/elementwise/CMakeLists.txt +++ b/paddle/fluid/operators/elementwise/CMakeLists.txt @@ -4,3 +4,4 @@ register_operators() cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor) cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor) cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context 
enforce executor) +cc_test(test_elementwise_mul_op_correct_dims SRCS test_elementwise_mul_op_dim.cc DEPS op_registry elementwise_mul_op scope device_context enforce executor) diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index aaa6cfe0346..6c1870f953f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -117,7 +117,7 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp, +REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseMulOp, ops::ElementwiseMulOpMaker, ops::ElementwiseOpInferVarType, ops::ElementwiseMulOpGradMaker, ops::ElementwiseMulOpGradMaker); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 502da88cf04..e41ee0b7417 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -13,14 +13,59 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { +class ElementwiseMulOp : public ElementwiseOp { + public: + using Tensor = framework::Tensor; + using ElementwiseOp::ElementwiseOp; + +#ifdef PADDLE_WITH_MKLDNN + static bool AreDimsAndFormatCorrect(const framework::ExecutionContext& ctx, + int simd_width, + mkldnn::memory::format x_format) { + using Tensor = framework::Tensor; + using paddle::framework::vectorize; + using mkldnn::memory; + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto x_dims = vectorize(x->dims()); + const bool are_dims_divisable = !(x_dims[1] % simd_width); + const bool is_x_format_correct = x->format() == x_format; + const bool is_y_format_correct = vectorize(y->dims()).size() == 2; + return are_dims_divisable && is_x_format_correct && is_y_format_correct; + } +#endif + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + using mkldnn::memory; + if (platform::CanMKLDNNBeUsed(ctx)) { + bool can_use_avx512_kernel = + platform::MayIUse(platform::avx512f) && + AreDimsAndFormatCorrect(ctx, 16, memory::format::nChw16c); + if (can_use_avx512_kernel) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + template void default_elementwise_mul(const framework::ExecutionContext& ctx, const framework::Tensor* x, diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h 
b/paddle/fluid/operators/elementwise/elementwise_op.h index 120ee9336d6..4e25947d7d5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -130,21 +130,10 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { .EqualGreaterThan(-1); AddAttr("use_mkldnn", "(bool, default false). Used by MKLDNN.") .SetDefault(false); - AddAttr( - "x_data_format", - "(string, default NCHW) Only used in mkldnn" - "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " - "Defaults to \"\". Specify the data format of the output data, " - "the input will be transformed automatically. ") + AddAttr("x_data_format", "This parameter is no longer used.") .SetDefault(""); - AddAttr( - "y_data_format", - "(string, default \"\") Only used in mkldnn" - "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". " - "Defaults to \"\". Specify the data format of the output data, " - "the input will be transformed automatically. 
") + AddAttr("y_data_format", "This parameter is no longer used.") .SetDefault(""); - AddOpComment(); } diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index 6a3ecd493a8..bc74a67f3e5 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -32,38 +32,28 @@ using framework::DataLayout; using mkldnn::memory; using platform::StringToMKLDNNFormat; -static void UpdateDataFormat(const framework::ExecutionContext& ctx, - framework::Tensor* tensor, const char* attribute) { - if (ctx.op().HasAttr(attribute)) { - auto format_as_string = ctx.Attr(attribute); - auto format = StringToMKLDNNFormat(&format_as_string); - if (format != MKLDNNMemoryFormat::any) { - tensor->set_format(format); +template +static void ComputeBroadcastedMultiply(const T* x_data, const T* y_data, + T* z_data, int64_t n, int64_t c, + int64_t h, int64_t w, int simd_width, + void (*multiply)(const T*, const T*, T*, + int, int)) { + const int64_t C = c / simd_width; +#pragma omp parallel for collapse(2) + for (int ni = 0; ni < n; ni++) { + for (int ci = 0; ci < C; ci++) { + auto ptr_x = + x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; + auto ptr_z = + z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; + + multiply(ptr_x, ptr_y, ptr_z, h, w); } } } -template -static void ReorderInput(framework::Tensor* tensor, - const platform::Place& place, - const mkldnn::engine& engine, bool isFourDim) { - using platform::to_void_cast; - auto dims = paddle::framework::vectorize(tensor->dims()); - framework::Tensor out_tensor; - out_tensor.Resize(tensor->dims()); - out_tensor.set_format(isFourDim ? 
MKLDNNMemoryFormat::nchw - : MKLDNNMemoryFormat::nc); - out_tensor.set_layout(tensor->layout()); - mkldnn::memory input_memory = { - {{dims, platform::MKLDNNGetDataType(), tensor->format()}, engine}, - to_void_cast(tensor->data())}; - mkldnn::memory output_memory = { - {{dims, platform::MKLDNNGetDataType(), out_tensor.format()}, engine}, - to_void_cast(out_tensor.mutable_data(place))}; - platform::Reorder(input_memory, output_memory); - tensor->ShareDataWith(out_tensor); -} - template class ElementwiseMulMKLDNNKernel : public framework::OpKernel { public: @@ -82,105 +72,26 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { auto y_dims_untrimmed = y->dims(); auto x_int_dims = paddle::framework::vectorize(x_dims); - UpdateDataFormat(ctx, const_cast(x), "x_data_format"); - UpdateDataFormat(ctx, const_cast(y), "y_data_format"); - - const bool is_avx512_enabled = platform::MayIUse(platform::avx512f); - const bool are_dims_divisable = !(x_int_dims[1] % 16); - const bool is_x_format_correct = x->format() == MKLDNNMemoryFormat::nChw16c; - const bool is_y_format_correct = y->format() == MKLDNNMemoryFormat::nc; - if (is_x_format_correct && is_y_format_correct && are_dims_divisable && - is_avx512_enabled) { - int pre, n, post, is_run_common_broadcast; - get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post, - &is_run_common_broadcast); - - if (post == 1) { - PADDLE_THROW("Not implemented when post is 1"); - } else { - // Just check whether it works for RE-Resnext. 
- PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); - - int n = x_dims[0]; - int c = x_dims[1]; - int h = x_dims[2]; - int w = x_dims[3]; - - PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c, - "Y should be in nc format"); - - constexpr int simd_width = 16; - int C = c / simd_width; - - auto multiply = jit::KernelFuncs, - platform::CPUPlace>::Cache() - .At(0); -#pragma omp parallel for collapse(2) - for (int ni = 0; ni < n; ni++) { - for (int ci = 0; ci < C; ci++) { - auto ptr_x = - x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - - auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; - auto ptr_z = - z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - - multiply(ptr_x, ptr_y, ptr_z, h, w); - } - } - } - - z->set_layout(DataLayout::kMKLDNN); - z->set_format(x->format()); - } else { - // Fallback to naive version: - const bool are_inputs_in_same_format = x->format() == y->format(); - const bool is_x_nchw = x->format() == MKLDNNMemoryFormat::nchw; - const bool is_x_nc = x->format() == MKLDNNMemoryFormat::nc; - const bool is_x_x = x->format() == MKLDNNMemoryFormat::x; - const bool is_y_nchw = y->format() == MKLDNNMemoryFormat::nchw; - const bool is_y_nc = y->format() == MKLDNNMemoryFormat::nc; - const bool is_y_x = y->format() == MKLDNNMemoryFormat::x; - if (!are_inputs_in_same_format) { - using platform::MKLDNNDeviceContext; - auto& dev_ctx = ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - if (!(is_x_nchw || is_x_nc || is_x_x)) - ReorderInput(const_cast(x), ctx.GetPlace(), mkldnn_engine, - x->dims().size() == 4); - if (!(is_y_nchw || is_y_nc || is_y_x)) - ReorderInput(const_cast(y), ctx.GetPlace(), mkldnn_engine, - y->dims().size() == 4); - } - - auto mul_func = [](T a, T b) -> T { return a * b; }; - - TransformFunctor - functor( - x, y, z, - ctx.template device_context(), - mul_func); - - axis = (axis == -1 ? 
x_dims.size() - y_dims_untrimmed.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed); - axis = (y_dims.size() == 0) ? x_dims.size() : axis; - - int pre, n, post, is_run_common_broadcast; - get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, - &is_run_common_broadcast); - - if (post == 1) { - functor.RunRowWise(n, pre); - } else { - functor.RunMidWise(n, pre, post); - } - z->set_layout(DataLayout::kMKLDNN); - z->set_format(x->format()); - } + int pre, num, post, is_run_common_broadcast; + get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &num, &post, + &is_run_common_broadcast); + + if (post == 1) PADDLE_THROW("Not implemented when post is 1"); + + const int64_t n = x_dims[0]; + const int64_t c = x_dims[1]; + const int64_t h = x_dims[2]; + const int64_t w = x_dims[3]; + + const int simd_width = 16; + auto multiply = + jit::KernelFuncs, platform::CPUPlace>::Cache() + .At(0); + ComputeBroadcastedMultiply(x_data, y_data, z_data, n, c, h, w, simd_width, + multiply); + + z->set_layout(DataLayout::kMKLDNN); + z->set_format(x->format()); } }; } // namespace operators diff --git a/paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc b/paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc new file mode 100644 index 00000000000..4477aa0a3f8 --- /dev/null +++ b/paddle/fluid/operators/elementwise/test_elementwise_mul_op_dim.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" + +USE_OP(elementwise_mul); + +namespace paddle { +namespace operators { +#ifdef PADDLE_WITH_MKLDNN + +using framework::Scope; +using framework::LoDTensor; +using framework::OpRegistry; +using framework::OperatorBase; +using framework::RuntimeContext; +using framework::ExecutionContext; + +struct TestData { + int64_t channel_num; + MKLDNNMemoryFormat format; + framework::DDim y_dims; + bool supposed_to_fail; +}; + +void MainTest(const TestData& test_data) { + auto place = platform::CPUPlace(); + Scope scope; + + auto* x = scope.Var("x")->GetMutable(); + auto* y = scope.Var("y")->GetMutable(); + scope.Var("out")->GetMutable(); + + x->Resize({1, test_data.channel_num, 3, 3}); + y->Resize(test_data.y_dims); + + x->set_format(test_data.format); + y->set_format(MKLDNNMemoryFormat::nc); + + std::unique_ptr op = OpRegistry::CreateOp( + "elementwise_mul", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {}); + + auto& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = dynamic_cast(pool.Get(place)); + + RuntimeContext runtime_ctx = + RuntimeContext(op->Inputs(), op->Outputs(), scope); + ExecutionContext ctx = + ExecutionContext(*op, scope, *dev_ctx, runtime_ctx, nullptr); + bool result = ElementwiseMulOp::AreDimsAndFormatCorrect( + ctx, 16, MKLDNNMemoryFormat::nChw16c); + if (test_data.supposed_to_fail) + ASSERT_FALSE(result); + else + ASSERT_TRUE(result); +} + +// Checks if AreDimsAndFormatCorrect returns true when 
supplied with expected +// data +TEST(ElementwiseMulOpTester, correct_dims) { + TestData test_data; + test_data.channel_num = 16; + test_data.format = MKLDNNMemoryFormat::nChw16c; + test_data.y_dims = {1, test_data.channel_num}; + test_data.supposed_to_fail = false; + MainTest(test_data); +} + +// Checks if AreDimsAndFormatCorrect fails when channel_num is not divisible by +// 16 +TEST(ElementwiseMulOpTester, incorrect_channel_num) { + TestData test_data; + test_data.channel_num = 17; + test_data.format = MKLDNNMemoryFormat::nChw16c; + test_data.y_dims = {1, test_data.channel_num}; + test_data.supposed_to_fail = true; + MainTest(test_data); +} + +// Checks if AreDimsAndFormatCorrect fails when x format is different from +// nChw16c +TEST(ElementwiseMulOpTester, incorrect_format) { + TestData test_data; + test_data.channel_num = 16; + test_data.format = MKLDNNMemoryFormat::nchw; + test_data.y_dims = {1, test_data.channel_num}; + test_data.supposed_to_fail = true; + MainTest(test_data); +} + +// Checks if AreDimsAndFormatCorrect fails when y input is not 2-dimensional +TEST(ElementwiseMulOpTester, incorrect_y_dims) { + TestData test_data; + test_data.channel_num = 16; + test_data.format = MKLDNNMemoryFormat::nChw16c; + test_data.y_dims = {1, test_data.channel_num, 1}; + test_data.supposed_to_fail = true; + MainTest(test_data); +} +#endif +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py index 043c544f26a..f8dd1011af2 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py @@ -40,7 +40,7 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp): self.filter_size2 = [1, 16, 2, 2] self.dilations = [1, 1] self.use_cudnn = False - self.data_format = "NCHW" + self.data_format 
= "ANYLAYOUT" self.input = np.random.random(self.input_size).astype(self.dtype) self.filter = np.random.random(self.filter_size).astype(self.dtype) self.filter2 = np.random.random(self.filter_size2).astype(self.dtype) @@ -97,7 +97,8 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp): 'groups': self.groups, 'dilations': self.dilations, 'use_cudnn': self.use_cudnn, - 'use_mkldnn': self.use_mkldnn + 'use_mkldnn': self.use_mkldnn, + 'data_format': self.data_format }) elementwise_mul_op = block.append_op( type="elementwise_mul", @@ -152,179 +153,5 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp): pass -class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): - def init_input_output(self): - self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) - self.y = np.random.rand(1, 16).astype(self.dtype) - - self.out = self.x * self.y.reshape(1, 16, 1, 1) - - def init_kernel_type(self): - self.use_mkldnn = True - - def init_axis(self): - self.axis = 0 - - def test_check_grad_normal(self): - pass - - def test_check_grad_ingore_x(self): - pass - - def test_check_grad_ingore_y(self): - pass - - -class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): - def init_input_output(self): - x = np.random.rand(1, 16, 2, 2).astype(self.dtype) - self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) - y = np.random.rand(1, 16, 2, 2).astype(self.dtype) - self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) - - self.out = x * y - - def setUp(self): - super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp() - self.attrs["x_data_format"] = "nchw16c" - self.attrs["y_data_format"] = "nchw16c" - self._cpu_only = True - - def init_kernel_type(self): - self.use_mkldnn = True - - def init_axis(self): - self.axis = 0 - - def test_check_grad_normal(self): - pass - - def test_check_grad_ingore_x(self): - pass - - def test_check_grad_ingore_y(self): - pass - - -class 
TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): - def init_input_output(self): - x = np.random.rand(1, 16, 2, 2).astype(self.dtype) - self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) - y = np.random.rand(1, 16, 2, 2).astype(self.dtype) - self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) - - self.out = x * y - - def setUp(self): - super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp() - self.attrs["x_data_format"] = "nchw16c" - self.attrs["y_data_format"] = "nchw16c" - self._cpu_only = True - - def init_kernel_type(self): - self.use_mkldnn = True - - def init_axis(self): - self.axis = 0 - - def test_check_grad_normal(self): - pass - - def test_check_grad_ingore_x(self): - pass - - def test_check_grad_ingore_y(self): - pass - - -class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp): - def init_input_output(self): - self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) - y = np.random.rand(1, 16, 2, 2).astype(self.dtype) - self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) - - self.out = self.x * y - - def setUp(self): - super(TestElementwiseMulMKLDNNOp_FallbackWithReorder1, self).setUp() - self.attrs["x_data_format"] = "nchw" - self.attrs["y_data_format"] = "nchw16c" - self._cpu_only = True - - def init_kernel_type(self): - self.use_mkldnn = True - - def init_axis(self): - self.axis = 0 - - def test_check_grad_normal(self): - pass - - def test_check_grad_ingore_x(self): - pass - - def test_check_grad_ingore_y(self): - pass - - -class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp): - def init_input_output(self): - self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype) - x = np.random.rand(1, 16, 2, 2).astype(self.dtype) - self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) - - self.out = x * self.y - - def setUp(self): - super(TestElementwiseMulMKLDNNOp_FallbackWithReorder2, self).setUp() - self.attrs["x_data_format"] = "nchw16c" - self.attrs["y_data_format"] = "nchw" - self._cpu_only = 
True - - def init_kernel_type(self): - self.use_mkldnn = True - - def init_axis(self): - self.axis = 0 - - def test_check_grad_normal(self): - pass - - def test_check_grad_ingore_x(self): - pass - - def test_check_grad_ingore_y(self): - pass - - -class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp): - def init_input_output(self): - self.x = np.random.rand(1, 16).astype(self.dtype) - self.y = np.random.rand(1, 16).astype(self.dtype) - - self.out = self.x * self.y - - def setUp(self): - super(TestElementwiseMulMKLDNNOp_FallbackNoReorders2, self).setUp() - self.attrs["x_data_format"] = "nc" - self.attrs["y_data_format"] = "nc" - self._cpu_only = True - - def init_kernel_type(self): - self.use_mkldnn = True - - def init_axis(self): - self.axis = 0 - - def test_check_grad_normal(self): - pass - - def test_check_grad_ingore_x(self): - pass - - def test_check_grad_ingore_y(self): - pass - - if __name__ == '__main__': unittest.main() -- GitLab