Commit ed9ceb9f authored by Michał Gallus, committed by Tao Luo

Refactor MKL-DNN ElementwiseMul (#21061)

* Refactor MKL-DNN ElementwiseMul

remove manual fallback, remove format attrs
test=develop

* Refine PADDLE_ENFORCEs in eltwise_mul_op.h

test=develop

* Make ElementwiseMulOp inherit from ElementwiseOp

* Change type of simd_width to int

test=develop

* Remove Constructor extensions in ElementwiseOp and ElementwiseMulOp

test=develop

* Restore attributes

test=develop

* Fix test coverage for mkldnn eltwise mul

test=develop

* Conform to new is_run_common_broadcast API

test=develop

* Add UT for AreDimsAndFormatCorrect

test=develop
Parent 0a93635b
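For context before the diff: the fast path below targets MKL-DNN's blocked nChw16c layout, in which the channel dimension is split into blocks of 16, so a logical NCHW tensor of shape [n, c, h, w] is stored as [n, c/16, h, w, 16]. A minimal standalone sketch of the offset arithmetic (illustrative only, not code from this commit; it assumes c is divisible by 16):

#include <cstdint>

// Offset of logical element (ni, ci, hi, wi) of an NCHW tensor stored in the
// blocked nChw16c layout, i.e. laid out as [n][c/16][h][w][16].
inline int64_t NChw16cOffset(int64_t ni, int64_t ci, int64_t hi, int64_t wi,
                             int64_t c, int64_t h, int64_t w) {
  const int64_t block = 16;
  const int64_t C = c / block;    // number of channel blocks (assumes c % 16 == 0)
  const int64_t cb = ci / block;  // channel block index
  const int64_t cr = ci % block;  // position within the block
  return (((ni * C + cb) * h + hi) * w + wi) * block + cr;
}

// Example: with c = 32, h = w = 2, logical element (ni=0, ci=17, hi=1, wi=0)
// maps to offset 97 = (block 1 start 64) + (hi * w * 16 = 32) + (cr = 1).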
@@ -4,3 +4,4 @@ register_operators()
 cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
 cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
 cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
+cc_test(test_elementwise_mul_op_correct_dims SRCS test_elementwise_mul_op_dim.cc DEPS op_registry elementwise_mul_op scope device_context enforce executor)
@@ -117,7 +117,7 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp,
+REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseMulOp,
                   ops::ElementwiseMulOpMaker, ops::ElementwiseOpInferVarType,
                   ops::ElementwiseMulOpGradMaker<paddle::framework::OpDesc>,
                   ops::ElementwiseMulOpGradMaker<paddle::imperative::OpBase>);
...
@@ -13,14 +13,59 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once

+#include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/platform/cpu_info.h"

 namespace paddle {
 namespace operators {

+class ElementwiseMulOp : public ElementwiseOp {
+ public:
+  using Tensor = framework::Tensor;
+  using ElementwiseOp::ElementwiseOp;
+
+#ifdef PADDLE_WITH_MKLDNN
+  static bool AreDimsAndFormatCorrect(const framework::ExecutionContext& ctx,
+                                      int simd_width,
+                                      mkldnn::memory::format x_format) {
+    using Tensor = framework::Tensor;
+    using paddle::framework::vectorize;
+    using mkldnn::memory;
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto x_dims = vectorize(x->dims());
+    const bool are_dims_divisable = !(x_dims[1] % simd_width);
+    const bool is_x_format_correct = x->format() == x_format;
+    const bool is_y_format_correct = vectorize(y->dims()).size() == 2;
+    return are_dims_divisable && is_x_format_correct && is_y_format_correct;
+  }
+#endif
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    using mkldnn::memory;
+    if (platform::CanMKLDNNBeUsed(ctx)) {
+      bool can_use_avx512_kernel =
+          platform::MayIUse(platform::avx512f) &&
+          AreDimsAndFormatCorrect(ctx, 16, memory::format::nChw16c);
+      if (can_use_avx512_kernel) {
+        return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                       framework::DataLayout::kMKLDNN,
+                                       framework::LibraryType::kMKLDNN);
+      }
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
 template <typename DeviceContext, typename T>
 void default_elementwise_mul(const framework::ExecutionContext& ctx,
                              const framework::Tensor* x,
...
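With the op class above, the choice of kernel moves from a runtime fallback inside the MKL-DNN kernel to GetExpectedKernelType: the MKL-DNN kernel is only selected when AVX-512F is available and AreDimsAndFormatCorrect holds. A simplified, framework-free sketch of those three conditions (the enum and function names here are illustrative stand-ins, not Paddle or MKL-DNN APIs):

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for mkldnn::memory::format.
enum class Format { nchw, nc, nChw16c };

// Mirrors the three checks in AreDimsAndFormatCorrect: channels divisible by
// the SIMD width, X stored in the expected blocked format, Y two-dimensional.
bool DimsAndFormatCorrect(const std::vector<int64_t>& x_dims, Format x_format,
                          const std::vector<int64_t>& y_dims, int simd_width,
                          Format wanted_x_format) {
  const bool dims_divisible = x_dims[1] % simd_width == 0;
  const bool x_format_ok = x_format == wanted_x_format;
  const bool y_is_2d = y_dims.size() == 2;
  return dims_divisible && x_format_ok && y_is_2d;
}

int main() {
  // Mirrors the cases in the new unit test below: 16 channels pass;
  // 17 channels, a non-blocked X format, and a 3-D Y all fail.
  assert(DimsAndFormatCorrect({1, 16, 3, 3}, Format::nChw16c, {1, 16}, 16,
                              Format::nChw16c));
  assert(!DimsAndFormatCorrect({1, 17, 3, 3}, Format::nChw16c, {1, 17}, 16,
                               Format::nChw16c));
  assert(!DimsAndFormatCorrect({1, 16, 3, 3}, Format::nchw, {1, 16}, 16,
                               Format::nChw16c));
  assert(!DimsAndFormatCorrect({1, 16, 3, 3}, Format::nChw16c, {1, 16, 1}, 16,
                               Format::nChw16c));
  return 0;
}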
@@ -130,21 +130,10 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
         .EqualGreaterThan(-1);
     AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
         .SetDefault(false);
-    AddAttr<std::string>(
-        "x_data_format",
-        "(string, default NCHW) Only used in mkldnn"
-        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
-        "Defaults to \"\". Specify the data format of the output data, "
-        "the input will be transformed automatically. ")
+    AddAttr<std::string>("x_data_format", "This parameter is no longer used.")
         .SetDefault("");
-    AddAttr<std::string>(
-        "y_data_format",
-        "(string, default \"\") Only used in mkldnn"
-        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
-        "Defaults to \"\". Specify the data format of the output data, "
-        "the input will be transformed automatically. ")
+    AddAttr<std::string>("y_data_format", "This parameter is no longer used.")
         .SetDefault("");
     AddOpComment();
   }
...
@@ -32,38 +32,28 @@ using framework::DataLayout;
 using mkldnn::memory;
 using platform::StringToMKLDNNFormat;

-static void UpdateDataFormat(const framework::ExecutionContext& ctx,
-                             framework::Tensor* tensor, const char* attribute) {
-  if (ctx.op().HasAttr(attribute)) {
-    auto format_as_string = ctx.Attr<std::string>(attribute);
-    auto format = StringToMKLDNNFormat(&format_as_string);
-    if (format != MKLDNNMemoryFormat::any) {
-      tensor->set_format(format);
-    }
-  }
-}
-
-template <typename T>
-static void ReorderInput(framework::Tensor* tensor,
-                         const platform::Place& place,
-                         const mkldnn::engine& engine, bool isFourDim) {
-  using platform::to_void_cast;
-  auto dims = paddle::framework::vectorize<int>(tensor->dims());
-  framework::Tensor out_tensor;
-  out_tensor.Resize(tensor->dims());
-  out_tensor.set_format(isFourDim ? MKLDNNMemoryFormat::nchw
-                                  : MKLDNNMemoryFormat::nc);
-  out_tensor.set_layout(tensor->layout());
-  mkldnn::memory input_memory = {
-      {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
-      to_void_cast<T>(tensor->data<T>())};
-  mkldnn::memory output_memory = {
-      {{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
-      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
-  platform::Reorder(input_memory, output_memory);
-  tensor->ShareDataWith(out_tensor);
-}
+template <typename T>
+static void ComputeBroadcastedMultiply(const T* x_data, const T* y_data,
+                                       T* z_data, int64_t n, int64_t c,
+                                       int64_t h, int64_t w, int simd_width,
+                                       void (*multiply)(const T*, const T*, T*,
+                                                        int, int)) {
+  const int64_t C = c / simd_width;
+#pragma omp parallel for collapse(2)
+  for (int ni = 0; ni < n; ni++) {
+    for (int ci = 0; ci < C; ci++) {
+      auto ptr_x =
+          x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+      auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+      auto ptr_z =
+          z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+      multiply(ptr_x, ptr_y, ptr_z, h, w);
+    }
+  }
+}

 template <typename T>
 class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
  public:
@@ -82,105 +72,26 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
     auto y_dims_untrimmed = y->dims();
     auto x_int_dims = paddle::framework::vectorize<int>(x_dims);

-    UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format");
-    UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format");
-
-    const bool is_avx512_enabled = platform::MayIUse(platform::avx512f);
-    const bool are_dims_divisable = !(x_int_dims[1] % 16);
-    const bool is_x_format_correct = x->format() == MKLDNNMemoryFormat::nChw16c;
-    const bool is_y_format_correct = y->format() == MKLDNNMemoryFormat::nc;
-    if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
-        is_avx512_enabled) {
-      int pre, n, post, is_run_common_broadcast;
-      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post,
-                   &is_run_common_broadcast);
-      if (post == 1) {
-        PADDLE_THROW("Not implemented when post is 1");
-      } else {
-        // Just check whether it works for RE-Resnext.
-        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
-
-        int n = x_dims[0];
-        int c = x_dims[1];
-        int h = x_dims[2];
-        int w = x_dims[3];
-
-        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
-                       "Y should be in nc format");
-
-        constexpr int simd_width = 16;
-        int C = c / simd_width;
-
-        auto multiply = jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>,
-                                         platform::CPUPlace>::Cache()
-                            .At(0);
-#pragma omp parallel for collapse(2)
-        for (int ni = 0; ni < n; ni++) {
-          for (int ci = 0; ci < C; ci++) {
-            auto ptr_x =
-                x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
-            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
-            auto ptr_z =
-                z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
-            multiply(ptr_x, ptr_y, ptr_z, h, w);
-          }
-        }
-      }
-
-      z->set_layout(DataLayout::kMKLDNN);
-      z->set_format(x->format());
-    } else {
-      // Fallback to naive version:
-      const bool are_inputs_in_same_format = x->format() == y->format();
-      const bool is_x_nchw = x->format() == MKLDNNMemoryFormat::nchw;
-      const bool is_x_nc = x->format() == MKLDNNMemoryFormat::nc;
-      const bool is_x_x = x->format() == MKLDNNMemoryFormat::x;
-      const bool is_y_nchw = y->format() == MKLDNNMemoryFormat::nchw;
-      const bool is_y_nc = y->format() == MKLDNNMemoryFormat::nc;
-      const bool is_y_x = y->format() == MKLDNNMemoryFormat::x;
-      if (!are_inputs_in_same_format) {
-        using platform::MKLDNNDeviceContext;
-        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-        const auto& mkldnn_engine = dev_ctx.GetEngine();
-        if (!(is_x_nchw || is_x_nc || is_x_x))
-          ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
-                          x->dims().size() == 4);
-        if (!(is_y_nchw || is_y_nc || is_y_x))
-          ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
-                          y->dims().size() == 4);
-      }
-
-      auto mul_func = [](T a, T b) -> T { return a * b; };
-
-      TransformFunctor<decltype(mul_func), T,
-                       paddle::platform::CPUDeviceContext, T>
-          functor(
-              x, y, z,
-              ctx.template device_context<paddle::platform::CPUDeviceContext>(),
-              mul_func);
-
-      axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
-      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
-                     "Axis should be in range [0, x_dims)");
-
-      auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
-      axis = (y_dims.size() == 0) ? x_dims.size() : axis;
-
-      int pre, n, post, is_run_common_broadcast;
-      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post,
-                   &is_run_common_broadcast);
-
-      if (post == 1) {
-        functor.RunRowWise(n, pre);
-      } else {
-        functor.RunMidWise(n, pre, post);
-      }
-
-      z->set_layout(DataLayout::kMKLDNN);
-      z->set_format(x->format());
-    }
+    int pre, num, post, is_run_common_broadcast;
+    get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &num, &post,
+                 &is_run_common_broadcast);
+    if (post == 1) PADDLE_THROW("Not implemented when post is 1");
+
+    const int64_t n = x_dims[0];
+    const int64_t c = x_dims[1];
+    const int64_t h = x_dims[2];
+    const int64_t w = x_dims[3];
+
+    const int simd_width = 16;
+    auto multiply =
+        jit::KernelFuncs<jit::NCHW16CMulNCTuple<T>, platform::CPUPlace>::Cache()
+            .At(0);
+    ComputeBroadcastedMultiply(x_data, y_data, z_data, n, c, h, w, simd_width,
+                               multiply);
+
+    z->set_layout(DataLayout::kMKLDNN);
+    z->set_format(x->format());
   }
 };

 }  // namespace operators
...
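The refactored kernel delegates the traversal to ComputeBroadcastedMultiply: for every (batch, channel-block) pair it hands one h*w*16 slab of X and the matching 16 channel values of Y to the JIT NCHW16CMulNC kernel. Below is a standalone sketch of that traversal, with a plain scalar loop standing in for the JIT kernel and made-up shapes (illustrative only); the file that follows it is the new unit test registered in the CMakeLists change above.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar stand-in for the JIT NCHW16CMulNC kernel: multiplies one h*w*16 slab
// of X (blocked layout) by 16 per-channel values of Y.
static void MulSlab(const float* x, const float* y, float* z, int h, int w) {
  for (int i = 0; i < h * w; ++i)
    for (int b = 0; b < 16; ++b) z[i * 16 + b] = x[i * 16 + b] * y[b];
}

int main() {
  const int64_t n = 1, c = 32, h = 2, w = 2, simd = 16;
  const int64_t C = c / simd;
  std::vector<float> x(n * c * h * w), y(n * c), z(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i % 7 + 1);
  for (std::size_t i = 0; i < y.size(); ++i) y[i] = static_cast<float>(i % 3 + 1);

  // Same traversal as ComputeBroadcastedMultiply: one kernel call per
  // (batch, channel-block) pair.
  for (int64_t ni = 0; ni < n; ++ni) {
    for (int64_t ci = 0; ci < C; ++ci) {
      const float* px = x.data() + ni * C * h * w * simd + ci * h * w * simd;
      const float* py = y.data() + ni * C * simd + ci * simd;
      float* pz = z.data() + ni * C * h * w * simd + ci * h * w * simd;
      MulSlab(px, py, pz, h, w);
    }
  }

  // Check against the layout definition: blocked element (ni, ci, hi, wi, b)
  // corresponds to logical channel ci * 16 + b, broadcast from Y (nc layout).
  for (int64_t ni = 0; ni < n; ++ni)
    for (int64_t ci = 0; ci < C; ++ci)
      for (int64_t hi = 0; hi < h; ++hi)
        for (int64_t wi = 0; wi < w; ++wi)
          for (int64_t b = 0; b < simd; ++b) {
            const int64_t off = (((ni * C + ci) * h + hi) * w + wi) * simd + b;
            assert(z[off] == x[off] * y[ni * c + ci * simd + b]);
          }
  return 0;
}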
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
USE_OP(elementwise_mul);
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_MKLDNN
using framework::Scope;
using framework::LoDTensor;
using framework::OpRegistry;
using framework::OperatorBase;
using framework::RuntimeContext;
using framework::ExecutionContext;
struct TestData {
int64_t channel_num;
MKLDNNMemoryFormat format;
framework::DDim y_dims;
bool supposed_to_fail;
};
void MainTest(const TestData& test_data) {
auto place = platform::CPUPlace();
Scope scope;
auto* x = scope.Var("x")->GetMutable<LoDTensor>();
auto* y = scope.Var("y")->GetMutable<LoDTensor>();
scope.Var("out")->GetMutable<LoDTensor>();
x->Resize({1, test_data.channel_num, 3, 3});
y->Resize(test_data.y_dims);
x->set_format(test_data.format);
y->set_format(MKLDNNMemoryFormat::nc);
std::unique_ptr<OperatorBase> op = OpRegistry::CreateOp(
"elementwise_mul", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {});
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
RuntimeContext runtime_ctx =
RuntimeContext(op->Inputs(), op->Outputs(), scope);
ExecutionContext ctx =
ExecutionContext(*op, scope, *dev_ctx, runtime_ctx, nullptr);
bool result = ElementwiseMulOp::AreDimsAndFormatCorrect(
ctx, 16, MKLDNNMemoryFormat::nChw16c);
if (test_data.supposed_to_fail)
ASSERT_FALSE(result);
else
ASSERT_TRUE(result);
}
// Checks if AreDimsAndFormatCorrect returns true when supplied with expected
// data
TEST(ElementwiseMulOpTester, correct_dims) {
TestData test_data;
test_data.channel_num = 16;
test_data.format = MKLDNNMemoryFormat::nChw16c;
test_data.y_dims = {1, test_data.channel_num};
test_data.supposed_to_fail = false;
MainTest(test_data);
}
// Checks if AreDimsAndFormatCorrect fails when channel_num is not divisible
// by 16
TEST(ElementwiseMulOpTester, incorrect_channel_num) {
TestData test_data;
test_data.channel_num = 17;
test_data.format = MKLDNNMemoryFormat::nChw16c;
test_data.y_dims = {1, test_data.channel_num};
test_data.supposed_to_fail = true;
MainTest(test_data);
}
// Checks if AreDimsAndFormatCorrect fails when x format is different from
// nchw16c
TEST(ElementwiseMulOpTester, incorrect_format) {
TestData test_data;
test_data.channel_num = 16;
test_data.format = MKLDNNMemoryFormat::nchw;
test_data.y_dims = {1, test_data.channel_num};
test_data.supposed_to_fail = true;
MainTest(test_data);
}
// Checks if AreDimsAndFormatCorrect fails when y input is not 2-dimensional
TEST(ElementwiseMulOpTester, incorrect_y_dims) {
TestData test_data;
test_data.channel_num = 16;
test_data.format = MKLDNNMemoryFormat::nChw16c;
test_data.y_dims = {1, test_data.channel_num, 1};
test_data.supposed_to_fail = true;
MainTest(test_data);
}
#endif
} // namespace operators
} // namespace paddle
@@ -40,7 +40,7 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp):
         self.filter_size2 = [1, 16, 2, 2]
         self.dilations = [1, 1]
         self.use_cudnn = False
-        self.data_format = "NCHW"
+        self.data_format = "ANYLAYOUT"
         self.input = np.random.random(self.input_size).astype(self.dtype)
         self.filter = np.random.random(self.filter_size).astype(self.dtype)
         self.filter2 = np.random.random(self.filter_size2).astype(self.dtype)
@@ -97,7 +97,8 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp):
                 'groups': self.groups,
                 'dilations': self.dilations,
                 'use_cudnn': self.use_cudnn,
-                'use_mkldnn': self.use_mkldnn
+                'use_mkldnn': self.use_mkldnn,
+                'data_format': self.data_format
             })
         elementwise_mul_op = block.append_op(
             type="elementwise_mul",
@@ -152,179 +153,5 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp):
         pass

-class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp):
-    def init_input_output(self):
-        self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        self.y = np.random.rand(1, 16).astype(self.dtype)
-        self.out = self.x * self.y.reshape(1, 16, 1, 1)
-
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-    def init_axis(self):
-        self.axis = 0
-
-    def test_check_grad_normal(self):
-        pass
-
-    def test_check_grad_ingore_x(self):
-        pass
-
-    def test_check_grad_ingore_y(self):
-        pass
-
-
-class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp):
-    def init_input_output(self):
-        x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
-        y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
-        self.out = x * y
-
-    def setUp(self):
-        super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp()
-        self.attrs["x_data_format"] = "nchw16c"
-        self.attrs["y_data_format"] = "nchw16c"
-        self._cpu_only = True
-
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-    def init_axis(self):
-        self.axis = 0
-
-    def test_check_grad_normal(self):
-        pass
-
-    def test_check_grad_ingore_x(self):
-        pass
-
-    def test_check_grad_ingore_y(self):
-        pass
-
-
-class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp):
-    def init_input_output(self):
-        x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
-        y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
-        self.out = x * y
-
-    def setUp(self):
-        super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp()
-        self.attrs["x_data_format"] = "nchw16c"
-        self.attrs["y_data_format"] = "nchw16c"
-        self._cpu_only = True
-
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-    def init_axis(self):
-        self.axis = 0
-
-    def test_check_grad_normal(self):
-        pass
-
-    def test_check_grad_ingore_x(self):
-        pass
-
-    def test_check_grad_ingore_y(self):
-        pass
-
-
-class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp):
-    def init_input_output(self):
-        self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
-        self.out = self.x * y
-
-    def setUp(self):
-        super(TestElementwiseMulMKLDNNOp_FallbackWithReorder1, self).setUp()
-        self.attrs["x_data_format"] = "nchw"
-        self.attrs["y_data_format"] = "nchw16c"
-        self._cpu_only = True
-
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-    def init_axis(self):
-        self.axis = 0
-
-    def test_check_grad_normal(self):
-        pass
-
-    def test_check_grad_ingore_x(self):
-        pass
-
-    def test_check_grad_ingore_y(self):
-        pass
-
-
-class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp):
-    def init_input_output(self):
-        self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
-        self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
-        self.out = x * self.y
-
-    def setUp(self):
-        super(TestElementwiseMulMKLDNNOp_FallbackWithReorder2, self).setUp()
-        self.attrs["x_data_format"] = "nchw16c"
-        self.attrs["y_data_format"] = "nchw"
-        self._cpu_only = True
-
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-    def init_axis(self):
-        self.axis = 0
-
-    def test_check_grad_normal(self):
-        pass
-
-    def test_check_grad_ingore_x(self):
-        pass
-
-    def test_check_grad_ingore_y(self):
-        pass
-
-
-class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp):
-    def init_input_output(self):
-        self.x = np.random.rand(1, 16).astype(self.dtype)
-        self.y = np.random.rand(1, 16).astype(self.dtype)
-        self.out = self.x * self.y
-
-    def setUp(self):
-        super(TestElementwiseMulMKLDNNOp_FallbackNoReorders2, self).setUp()
-        self.attrs["x_data_format"] = "nc"
-        self.attrs["y_data_format"] = "nc"
-        self._cpu_only = True
-
-    def init_kernel_type(self):
-        self.use_mkldnn = True
-
-    def init_axis(self):
-        self.axis = 0
-
-    def test_check_grad_normal(self):
-        pass
-
-    def test_check_grad_ingore_x(self):
-        pass
-
-    def test_check_grad_ingore_y(self):
-        pass
-
-
 if __name__ == '__main__':
     unittest.main()