diff --git a/AUTHORS.md b/AUTHORS.md
index 4060f75613ac4dadf353ff53a73fd0647a8052be..54a1097b50f7a09062f8987e62db6b5f5e89e0b7 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -42,6 +42,7 @@
 | QiJune | Jun Qi |
 | qingqing01 | Qing-Qing Dang |
 | reyoung | Yang Yu |
+| Sand3r- | Michal Gallus |
 | Superjom | Chun-Wei Yan |
 | tensor-tang | Jian Tang |
 | tianbingsz | Tian-Bing Xu |
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 40b0130b265471a1288d966c4cbcd4f0e1bdb9f1..6918e030bf859bc8a55baed9d944e16217b0efb6 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -100,6 +100,7 @@ class OperatorBase {
 
   const std::string& Type() const { return type_; }
 
+  bool HasAttr(const std::string& name) const { return attrs_.count(name); }
   template <typename T>
   inline const T& Attr(const std::string& name) const {
     PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..10290a4aeff6b6a023fb28961d12728aff891e83
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
@@ -0,0 +1,201 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <string>
+#include "paddle/fluid/operators/elementwise/elementwise_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+
+#include "paddle/fluid/platform/mkldnn_helper.h"
+
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include "xbyak.h"
+#include "xbyak_util.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::DataLayout;
+using mkldnn::memory;
+
+static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
+  std::transform(format.begin(), format.end(), format.begin(), ::tolower);
+
+  if (!format.compare("nchw")) {
+    return memory::format::nchw;
+  } else if (!format.compare("nchw16c")) {
+    return memory::format::nChw16c;
+  } else if (!format.compare("nchw8c")) {
+    return memory::format::nChw8c;
+  } else if (!format.compare("nhwc")) {
+    return memory::format::nhwc;
+  } else {
+    return memory::format::any;
+  }
+}
+
+static void UpdateDataFormat(const framework::ExecutionContext& ctx,
+                             framework::Tensor* tensor, const char* attribute) {
+  if (ctx.op().HasAttr(attribute)) {
+    auto format_as_string = ctx.Attr<std::string>(attribute);
+    auto format = StringToMKLDNNFormat(format_as_string);
+    if (format != memory::format::any) {
+      tensor->set_format(format);
+    }
+  }
+}
+
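+// Reorders *tensor into a plain layout (nchw for 4-D data, nc otherwise)
+// with an MKL-DNN reorder primitive, then swaps the reordered buffer back
+// into the tensor.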
+template <typename T>
+static void ReorderInput(framework::Tensor* tensor,
+                         const platform::Place& place,
+                         const mkldnn::engine& engine, bool isFourDim) {
+  using platform::to_void_cast;
+  auto dims = paddle::framework::vectorize2int(tensor->dims());
+  framework::Tensor out_tensor;
+  out_tensor.Resize(tensor->dims());
+  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
+  out_tensor.set_layout(tensor->layout());
+  mkldnn::memory input_memory = {
+      {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
+      to_void_cast<T>(tensor->data<T>())};
+  mkldnn::memory output_memory = {
+      {{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
+      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
+  platform::Reorder(input_memory, output_memory);
+  tensor->ShareDataWith(out_tensor);
+}
+
+template <typename T>
+class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    using Tensor = framework::Tensor;
+
+    int axis = ctx.Attr<int>("axis");
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* z = ctx.Output<Tensor>("Out");
+    const T* x_data = x->data<T>();
+    const T* y_data = y->data<T>();
+    T* z_data = z->mutable_data<T>(ctx.GetPlace());
+
+    auto x_dims = x->dims();
+    auto y_dims_untrimmed = y->dims();
+    auto x_int_dims = paddle::framework::vectorize2int(x_dims);
+
+    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
+    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
+
+    Xbyak::util::Cpu cpu;
+    const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
+    const bool are_dims_divisible = !(x_int_dims[1] % 16);
+    const bool is_x_format_correct = x->format() == memory::format::nChw16c;
+    const bool is_y_format_correct = y->format() == memory::format::nc;
+    if (is_x_format_correct && is_y_format_correct && are_dims_divisible &&
+        is_avx512_enabled) {
+      int pre, n, post;
+      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
+
+      if (post == 1) {
+        PADDLE_THROW("Not implemented when post is 1");
+      } else {
+        // Just check whether it works for SE-ResNeXt.
+        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
+
+        int n = x_dims[0];
+        int c = x_dims[1];
+        int h = x_dims[2];
+        int w = x_dims[3];
+
+        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
+                       "Y should be in nc format");
+
+        constexpr int simd_width = 16;
+        int C = c / simd_width;
+
+        const auto& multiply =
+            math::jitkernel::KernelPool::Instance()
+                .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
+
+#pragma omp parallel for collapse(2)
+        for (int ni = 0; ni < n; ni++) {
+          for (int ci = 0; ci < C; ci++) {
+            // Each (batch, channel-block) pair scales h*w contiguous
+            // 16-float chunks of X by the same 16 values of Y.
+            auto ptr_x =
+                x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+
+            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+            auto ptr_z =
+                z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
+
+            multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
+          }
+        }
+      }
+
+      z->set_layout(DataLayout::kMKLDNN);
+      z->set_format(x->format());
+    } else {
+      // Fallback to the naive version:
+      const bool are_inputs_in_same_format = x->format() == y->format();
+      const bool is_x_nchw = x->format() == memory::format::nchw;
+      const bool is_x_nc = x->format() == memory::format::nc;
+      const bool is_y_nchw = y->format() == memory::format::nchw;
+      const bool is_y_nc = y->format() == memory::format::nc;
+      if (!are_inputs_in_same_format) {
+        using platform::MKLDNNDeviceContext;
+        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+        const auto& mkldnn_engine = dev_ctx.GetEngine();
+        if (!(is_x_nchw || is_x_nc))
+          ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
+                          x->dims().size() == 4);
+        if (!(is_y_nchw || is_y_nc))
+          ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
+                          y->dims().size() == 4);
+      }
+
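+      // The fallback reuses the generic CPU broadcast machinery: a scalar
+      // multiply functor is applied row-wise or mid-wise depending on how Y
+      // broadcasts over X.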
+      auto mul_func = [](T a, T b) -> T { return a * b; };
+
+      TransformFunctor<decltype(mul_func), T,
+                       paddle::platform::CPUDeviceContext, T>
+          functor(
+              x, y, z,
+              ctx.template device_context<paddle::platform::CPUDeviceContext>(),
+              mul_func);
+
+      axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
+      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
+                     "Axis should be in range [0, x_dims)");
+
+      auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
+      axis = (y_dims.size() == 0) ? x_dims.size() : axis;
+
+      int pre, n, post;
+      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
+
+      if (post == 1) {
+        functor.RunRowWise(n, pre);
+      } else {
+        functor.RunMidWise(n, pre, post);
+      }
+      z->set_layout(DataLayout::kMKLDNN);
+      z->set_format(x->format());
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::ElementwiseMulMKLDNNKernel<float>)
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index f01f67692e1e5dd040971cb0dd1dd793648da97a..85a7817be9b3a82d40853b417d78a7fdf67f6c1f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -97,6 +97,20 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
         .EqualGreaterThan(-1);
     AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
         .SetDefault(false);
+    AddAttr<std::string>(
+        "x_data_format",
+        "(string, default \"\"). Only used in MKL-DNN. "
+        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
+        "Specifies the data format of input X; when it does not match what "
+        "the kernel expects, the input is reordered automatically.")
+        .SetDefault("");
+    AddAttr<std::string>(
+        "y_data_format",
+        "(string, default \"\"). Only used in MKL-DNN. "
+        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
+        "Specifies the data format of input Y; when it does not match what "
+        "the kernel expects, the input is reordered automatically.")
+        .SetDefault("");
     AddComment(string::Sprintf(R"DOC(
 Elementwise %s Operator
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
index 65f83ff4846601d1575daa994772cd869d526f56..64ef55de7cf73fea4538cc0d8fa6d316ddaff2f8 100644
--- a/paddle/fluid/operators/math/jit_code.h
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -322,6 +322,42 @@ class VActJitCode : public JitCode {
   ymm_t ymm_dst = ymm_t(1);
 };
 
+#ifdef PADDLE_WITH_MKLDNN
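+// Generates AVX512F code that multiplies one nChw16c block of X by the
+// matching 16 channel values of Y (NC layout): the same 16 floats of Y
+// scale h*w consecutive 16-float chunks of X.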
") + .SetDefault(""); AddComment(string::Sprintf(R"DOC( Elementwise %s Operator diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 65f83ff4846601d1575daa994772cd869d526f56..64ef55de7cf73fea4538cc0d8fa6d316ddaff2f8 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -322,6 +322,42 @@ class VActJitCode : public JitCode { ymm_t ymm_dst = ymm_t(1); }; +#ifdef PADDLE_WITH_MKLDNN +struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { + explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024) + : Xbyak::CodeGenerator(code_size) { + // RDI is ptr x_input + // RSI is ptr y_input + // RDX is ptr output + // RCX is height + // r8 is width + + push(rbx); + + xor_(rax, rax); + xor_(r10, r10); + vmovups(zmm3, ptr[rsi]); + + L("h_loop"); + xor_(rbx, rbx); + L("w_loop"); + vmovups(zmm2, ptr[rdi + rax]); + vmulps(zmm1, zmm2, zmm3); + vmovups(ptr[rdx + rax], zmm1); + add(rax, 64); + inc(rbx); + cmp(r8, rbx); + jnz("w_loop"); + inc(r10); + cmp(r10, rcx); + jnz("h_loop"); + + pop(rbx); + ret(); + } +}; +#endif + } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 7e163c1349e73d8fe5e436b98c9a8f67e6439506..82d808f415c3b4ed2688d034aad13610ae2ab0f4 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -95,6 +95,15 @@ class VAddBiasKernel : public Kernel { void (*Compute)(const T *, const T *, T *, int); }; +#ifdef PADDLE_WITH_MKLDNN +template +class EltwiseMulnChw16cNCKernel : public Kernel { + public: + // nChw16c = nChw16c .* NC + void (*Compute)(const float *, const float *, float *, int, int); +}; +#endif + template class VActKernel : public Kernel { public: diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 36a50f20434f313e93bfa3dd2c9d46963024caf7..a143b51439f55d1f80d7936dfad46e31bd19f0cb 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -226,6 +226,44 @@ bool VAddKernelImpl::useMKL(int d) { } #endif +#ifdef PADDLE_WITH_MKLDNN +/* EltwiseMul for nChw16c & NC inputs JitKernel */ +template +class EltwiseMulnChw16cNCKernelImpl + : public math::jitkernel::EltwiseMulnChw16cNCKernel { + public: + JITKERNEL_DECLARE_STATIC_FUNC; + explicit EltwiseMulnChw16cNCKernelImpl(int d) + : EltwiseMulnChw16cNCKernel() { + using mul_func_t = void (*)(const float*, const float*, float*, int, int); +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + // roughly estimate the size of code + size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8; + sz = sz > 4096 ? 
+      jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz));
+      this->Compute = (mul_func_t)jitcode_->getCode();
+      return;
+    }
+#endif
+    PADDLE_THROW(
+        "This kernel shouldn't be used in a non-Xbyak, non-MKL-DNN "
+        "environment");
+  }
+
+#ifdef PADDLE_WITH_XBYAK
+
+ private:
+  std::unique_ptr<gen::EltwiseMulnChw16cNC> jitcode_{nullptr};
+#endif
+};
+
+#ifdef PADDLE_WITH_XBYAK
+template <>
+bool EltwiseMulnChw16cNCKernelImpl<float>::useJIT(int d) {
+  return true;
+}
+#endif
+#endif
+
 /* VAddRelu JitKernel */
 template <typename T>
 class VAddReluKernelImpl : public VAddReluKernel<T> {
@@ -394,6 +432,9 @@ REGISTER_JITKERNEL(vscal, VScalKernel);
 REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
 REGISTER_JITKERNEL(vrelu, VReluKernel);
 REGISTER_JITKERNEL(videntity, VIdentityKernel);
+#ifdef PADDLE_WITH_MKLDNN
+REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel);
+#endif
 
 }  // namespace jitkernel
 }  // namespace math
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 690c4cf0ad6b2c741689e419223cfa6b6e1e5cf3..c195a28e452fbe073a9afb5d650f538176f688fd 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -362,7 +362,9 @@ class OpTest(unittest.TestCase):
         else:
             return []
         places = [fluid.CPUPlace()]
-        if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
+        cpu_only = getattr(self, '_cpu_only', False)
+        if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type)\
+                and not cpu_only:
             places.append(core.CUDAPlace(0))
         return places
 
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..536e9a1c58ec4a8b1b5a7c1d3a5fe737b38d24ab
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py
@@ -0,0 +1,263 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
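+
+# The tests below emulate MKL-DNN blocked layouts in plain NumPy. When the
+# channel count equals a single block (16 for nChw16c, 8 for nChw8c), the
+# blocked buffer has the same memory order as NHWC, so transposing an NCHW
+# array to (n, h, w, c) and re-labelling it with NCHW dims via reshape yields
+# exactly the bytes an nChw16c/nChw8c tensor would hold.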
+ +from __future__ import print_function +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from test_elementwise_mul_op import * + + +class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + self.y = np.random.rand(1, 16).astype(self.dtype) + + self.out = x * self.y.reshape(1, 16, 1, 1) + self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nc" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +@unittest.skip( + "Not implemented yet.") # TODO(mgallus): enable when implemented. +class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 8, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2) + self.y = np.random.rand(1, 8).astype(self.dtype) + + self.out = x * self.y.reshape(1, 8, 1, 1) + self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2) + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_BroadcastNCHW8c, self).setUp() + self.attrs["x_data_format"] = "nchw8c" + self.attrs["y_data_format"] = "nc" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = np.random.rand(1, 16).astype(self.dtype) + + self.out = self.x * self.y.reshape(1, 16, 1, 1) + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp): + def init_input_output(self): + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * self.y + + def 
setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = self.x * y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackWithReorder1, self).setUp() + self.attrs["x_data_format"] = "nchw" + self.attrs["y_data_format"] = "nchw16c" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp): + def init_input_output(self): + self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype) + x = np.random.rand(1, 16, 2, 2).astype(self.dtype) + self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + + self.out = x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackWithReorder2, self).setUp() + self.attrs["x_data_format"] = "nchw16c" + self.attrs["y_data_format"] = "nchw" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp): + def init_input_output(self): + self.x = np.random.rand(1, 16).astype(self.dtype) + self.y = np.random.rand(1, 16).astype(self.dtype) + + self.out = self.x * self.y + + def setUp(self): + super(TestElementwiseMulMKLDNNOp_FallbackNoReorders2, self).setUp() + self.attrs["x_data_format"] = "nc" + self.attrs["y_data_format"] = "nc" + self._cpu_only = True + + def init_kernel_type(self): + self.use_mkldnn = True + + def init_axis(self): + self.axis = 0 + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 53409e436c0739bce63a3a8f90591e0ca6836859..57ba34f833f824d13e0b82caea789f7f57622bc9 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -21,13 +21,24 @@ from paddle.fluid.op import Operator class ElementwiseMulOp(OpTest): + def init_kernel_type(self): + self.use_mkldnn = False + def setUp(self): self.op_type = "elementwise_mul" + self.dtype = np.float32 + self.axis = -1 + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + self.inputs = { - 'X': np.random.uniform(0.1, 1, [13, 17]).astype("float64"), - 'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float64") + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': 
OpTest.np_dtype_to_fluid_dtype(self.y) } - self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])} + self.outputs = {'Out': self.out} + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} def test_check_output(self): self.check_output() @@ -41,6 +52,17 @@ class ElementwiseMulOp(OpTest): def test_check_grad_ingore_y(self): self.check_grad(['X'], 'Out', no_grad_set=set('Y')) + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + def init_dtype(self): + pass + + def init_axis(self): + pass + class TestElementwiseMulOp_scalar(ElementwiseMulOp): def setUp(self): @@ -63,17 +85,13 @@ class TestElementwiseMulOp_Vector(ElementwiseMulOp): class TestElementwiseMulOp_broadcast_0(ElementwiseMulOp): - def setUp(self): - self.op_type = "elementwise_mul" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float64), - 'Y': np.random.rand(2).astype(np.float64) - } + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + self.out = self.x * self.y.reshape(2, 1, 1) - self.attrs = {'axis': 0} - self.outputs = { - 'Out': self.inputs['X'] * self.inputs['Y'].reshape(2, 1, 1) - } + def init_axis(self): + self.axis = 0 class TestElementwiseMulOp_broadcast_1(ElementwiseMulOp):
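
For reference, a minimal NumPy sketch (not part of the patch; the function name
and flat-buffer convention are illustrative only) of the product the
EltwiseMulnChw16cNC JIT kernel computes over an nChw16c X and an NC Y:

    import numpy as np

    def eltwise_mul_nchw16c_nc(x, y, n, C, h, w, simd_width=16):
        """x: flat nChw16c buffer (n*C*h*w*simd_width floats),
        y: flat NC buffer (n*C*simd_width floats); returns flat z."""
        z = np.empty_like(x)
        block = h * w * simd_width  # floats per (batch, channel-block) pair
        for ni in range(n):
            for ci in range(C):
                xo = ni * C * block + ci * block
                yo = ni * C * simd_width + ci * simd_width
                # every 16-float chunk of x is scaled by the same 16 y values
                z[xo:xo + block] = (x[xo:xo + block]
                                    .reshape(h * w, simd_width) *
                                    y[yo:yo + simd_width]).ravel()
        return z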