From c791df09cf49219d2cc4916a2dfb7722eb4ed8d6 Mon Sep 17 00:00:00 2001
From: Jack Zhou
Date: Wed, 14 Oct 2020 17:31:15 +0800
Subject: [PATCH] Add elementwise XPU OP kernels (div, max, mul, sub) for the
 KUNLUN core; common broadcast is not yet supported

Add elementwise XPU OP kernels (div, max, mul, sub) for the KUNLUN core;
common broadcast is not yet supported.
---
 .../elementwise/elementwise_div_op_xpu.cc     |  43 +++++
 .../elementwise/elementwise_max_op_xpu.cc     |  45 +++++
 .../elementwise/elementwise_mul_op_xpu.cc     |  40 +++++
 .../elementwise/elementwise_sub_op_xpu.cc     |  49 ++++++
 .../operators/elementwise/elementwise_xpu.h   | 165 ++++++++++++++++--
 .../fluid/tests/unittests/xpu/elementwise.py  | 100 +++++++++++
 .../xpu/test_elementwise_div_op_xpu.py        | 138 +++++++++++++++
 .../xpu/test_elementwise_max_op_xpu.py        | 129 ++++++++++++++
 .../xpu/test_elementwise_mul_op_xpu.py        | 153 ++++++++++++++++
 .../xpu/test_elementwise_sub_op_xpu.py        | 128 ++++++++++++++
 10 files changed, 975 insertions(+), 15 deletions(-)
 create mode 100644 paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc
 create mode 100644 paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc
 create mode 100644 paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc
 create mode 100644 paddle/fluid/operators/elementwise/elementwise_sub_op_xpu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/elementwise.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_elementwise_div_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_elementwise_max_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_elementwise_mul_op_xpu.py
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_elementwise_sub_op_xpu.py

diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc
new file mode 100644
index 0000000000..6cc4276680
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc
@@ -0,0 +1,43 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_xpu.h" +namespace paddle { +namespace operators { + +template +struct XPUDivFunctor { + int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) { + return xpu::elementwise_div(ctx, x, y, z, len); + } +}; + +template +class ElementwiseDivXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUElementwise>(ctx); + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + elementwise_div, + ops::ElementwiseDivXPUKernel); +#endif diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc new file mode 100644 index 0000000000..232cfa0239 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/elementwise/elementwise_max_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_xpu.h" +namespace paddle { +namespace operators { + +template +struct XPUMaxFunctor { + int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) { + return xpu::elementwise_max(ctx, x, y, z, len); + } +}; + +template +class ElementwiseMaxXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUElementwise>(ctx); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + elementwise_max, + ops::ElementwiseMaxXPUKernel); +#endif diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc new file mode 100644 index 0000000000..d9a6ca844a --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc @@ -0,0 +1,40 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_xpu.h" +namespace paddle { +namespace operators { +template +class ElementwiseMulXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUElementwise>(ctx); + } +}; +DEFINE_XPU_GRAD_KERNEL(Mul, mul, true); +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + elementwise_mul, + ops::ElementwiseMulXPUKernel); +REGISTER_OP_XPU_KERNEL(elementwise_mul_grad, + ops::ElementwiseMulGradXPUKernel< + paddle::platform::XPUDeviceContext, float>); + +#endif diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op_xpu.cc new file mode 100644 index 0000000000..4e205fe492 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op_xpu.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_xpu.h" +namespace paddle { +namespace operators { + +template +struct XPUSubFunctor { + int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) { + return xpu::elementwise_sub(ctx, x, y, z, len); + } +}; + +template +class ElementwiseSubXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUElementwise>(ctx); + } +}; + +DEFINE_XPU_GRAD_KERNEL(Sub, sub, false); +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + elementwise_sub, + ops::ElementwiseSubXPUKernel); +REGISTER_OP_XPU_KERNEL(elementwise_sub_grad, + ops::ElementwiseSubGradXPUKernel< + paddle::platform::XPUDeviceContext, float>); + +#endif diff --git a/paddle/fluid/operators/elementwise/elementwise_xpu.h b/paddle/fluid/operators/elementwise/elementwise_xpu.h index 53c4332e91..9145ab856d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_xpu.h +++ b/paddle/fluid/operators/elementwise/elementwise_xpu.h @@ -13,9 +13,131 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once #ifdef PADDLE_WITH_XPU +#include #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/place.h" +#define XPU_MALLOC(addr, num_bytes) \ + PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast(addr), num_bytes), \ + XPU_SUCCESS, \ + platform::errors::ResourceExhausted( \ + "\n\nOut of memory error on XPU, Cannot" \ + "allocate %s memory on XPU. 
\n\nPlease " \ + "check whether there is any other process " \ + "using XPU.\n", \ + string::HumanReadableSize(num_bytes))) + +#define DEFINE_XPU_GRAD_KERNEL(kernel_type, kernel_name, use_x_y_data) \ + template \ + class Elementwise##kernel_type##GradXPUKernel \ + : public ElemwiseGradKernel { \ + public: \ + void Compute(const framework::ExecutionContext& ctx) const override { \ + ElemwiseGradKernel::Compute(ctx); \ + using Tensor = framework::Tensor; \ + auto* dout = ctx.Input(framework::GradVarName("Out")); \ + auto* dx = ctx.Output(framework::GradVarName("X")); \ + auto* dy = ctx.Output(framework::GradVarName("Y")); \ + auto dx_dims = dout->dims(); \ + auto dy_dims_untrimed = dout->dims(); \ + T* dx_data = NULL; \ + T* dy_data = NULL; \ + const T* y_data = nullptr; \ + const T* x_data = nullptr; \ + T* y_broadcast = nullptr; \ + if (use_x_y_data) { \ + auto* x = ctx.Input("X"); \ + auto* y = ctx.Input("Y"); \ + y_data = y->data(); \ + x_data = x->data(); \ + } else { \ + x_data = dout->data(); \ + y_data = dout->data(); \ + } \ + int axis = ctx.Attr("axis"); \ + PADDLE_ENFORCE_GE( \ + dx_dims.size(), dy_dims_untrimed.size(), \ + platform::errors::InvalidArgument( \ + "Rank of first input must >= rank of second input.")); \ + if (dx != nullptr) { \ + dx->mutable_data(ctx.GetPlace()); \ + dx_dims = dx->dims(); \ + dx_data = dx->data(); \ + } \ + if (dy != nullptr) { \ + dy->mutable_data(ctx.GetPlace()); \ + dy_dims_untrimed = dy->dims(); \ + dy_data = dy->data(); \ + } \ + int pre, n, post, is_run_common_broadcast; \ + if (dx_dims == dy_dims_untrimed) { \ + pre = post = 1; \ + n = dout->numel(); \ + } else { \ + axis = (axis == -1 ? dx_dims.size() - dy_dims_untrimed.size() : axis); \ + PADDLE_ENFORCE_EQ(axis >= 0 && axis < dx_dims.size(), true, \ + platform::errors::InvalidArgument( \ + "Axis should be in range [0, dx_dims)")); \ + auto dy_dims = trim_trailing_singular_dims(dy_dims_untrimed); \ + axis = (dy_dims.size() == 0) ? 
dx_dims.size() : axis; \ + get_mid_dims(dx_dims, dy_dims, axis, &pre, &n, &post, \ + &is_run_common_broadcast); \ + } \ + int len = pre * n * post; \ + auto& dev_ctx = \ + ctx.template device_context(); \ + if (dx == nullptr) { \ + XPU_MALLOC(&dx_data, len * sizeof(float)); \ + } \ + if (dy == nullptr) { \ + XPU_MALLOC(&dy_data, len * sizeof(float)); \ + } else { \ + if (len != n) { \ + XPU_MALLOC(&dy_data, len * sizeof(float)); \ + } \ + } \ + if (use_x_y_data) { \ + if (len != n) { \ + XPU_MALLOC(&y_broadcast, len * sizeof(float)); \ + int res = \ + xpu::broadcast_ew(dev_ctx.x_context(), y_data, y_broadcast, pre, \ + n, post, xpu::ElementwiseOp::ASSIGN); \ + PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS, \ + platform::errors::Fatal("XPU kernel error!")); \ + y_data = y_broadcast; \ + } \ + } \ + int res = xpu::elementwise_##kernel_name##_grad( \ + dev_ctx.x_context(), x_data, y_data, dout->data() /*out*/, \ + dout->data(), dx_data, dy_data, len); \ + PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS, \ + platform::errors::Fatal("XPU kernel error!")); \ + if ((dy != nullptr) && (len != n)) { \ + int res = xpu::reduce_ew(dev_ctx.x_context(), dy_data, dy->data(), \ + pre, n, post, xpu::ElementwiseOp::ASSIGN); \ + PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS, \ + platform::errors::Fatal("XPU kernel error!")); \ + dev_ctx.Wait(); \ + xpu_free(dy_data); \ + } \ + if ((len != n || dx == nullptr || dy == nullptr) && \ + !(dy != nullptr && len != n)) { \ + dev_ctx.Wait(); \ + } \ + if (dx == nullptr) { \ + xpu_free(dx_data); \ + } \ + if (dy == nullptr) { \ + xpu_free(dy_data); \ + } \ + if (use_x_y_data) { \ + if (len != n) { \ + xpu_free(y_broadcast); \ + } \ + } \ + } \ + } + namespace paddle { namespace operators { @@ -35,13 +157,16 @@ struct XPUMulFunctor { template void XPUElementwise(const framework::ExecutionContext& ctx) { - PADDLE_ENFORCE(platform::is_xpu_place(ctx.GetPlace()), - "This kernel only runs on XPU device."); + PADDLE_ENFORCE_EQ(platform::is_xpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet( + "This kernel only runs on XPU device.")); auto x_var = ctx.InputVar("X"); PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::Fatal("Cannot get input Variable X")); - PADDLE_ENFORCE(x_var->IsType(), - "XPU only support LoDTensor"); + PADDLE_ENFORCE_EQ( + x_var->IsType(), true, + platform::errors::InvalidArgument( + "XPU only support LoDTensor, Input(X) is not LoDTensor")); auto x = x_var->Get(); auto* y = ctx.Input("Y"); @@ -52,14 +177,21 @@ void XPUElementwise(const framework::ExecutionContext& ctx) { auto x_dims = x.dims(); auto y_dims_untrimed = y->dims(); PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(), - "Rank of first input must >= rank of second input."); + platform::errors::InvalidArgument( + "Rank of first input must >= rank of second input.")); axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); + PADDLE_ENFORCE_EQ( + axis >= 0 && axis < x_dims.size(), true, + platform::errors::InvalidArgument("Axis should be in range [0, x_dims)")); auto y_dims = trim_trailing_singular_dims(y_dims_untrimed); axis = (y_dims.size() == 0) ? 
x_dims.size() : axis; int pre, n, post, is_common_broadcast; get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &is_common_broadcast); + + PADDLE_ENFORCE_NE(is_common_broadcast, 1, + platform::errors::Unimplemented( + "X's shape should be equal to Y's shape.")); + int len = pre * n * post; const T* x_data = x.data(); @@ -74,15 +206,17 @@ void XPUElementwise(const framework::ExecutionContext& ctx) { if (std::is_same>::value) { int res = xpu::matrix_vector_add(dev_ctx.x_context(), x_data, y_data, z_data, pre, n); - PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d", - res); + PADDLE_ENFORCE_EQ( + res, xpu::Error_t::SUCCESS, + platform::errors::Fatal("XPU kernel error! res = %d", res)); return; } if (std::is_same>::value) { int res = xpu::matrix_vector_mul(dev_ctx.x_context(), x_data, y_data, z_data, pre, n); - PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d", - res); + PADDLE_ENFORCE_EQ( + res, xpu::Error_t::SUCCESS, + platform::errors::Fatal("XPU kernel error! res = %d", res)); return; } } @@ -92,15 +226,16 @@ void XPUElementwise(const framework::ExecutionContext& ctx) { len * sizeof(T)) == XPU_SUCCESS); int res = xpu::broadcast_ew(dev_ctx.x_context(), y_data, y_broadcast, pre, n, post, xpu::ElementwiseOp::ASSIGN); - PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d", - res); + PADDLE_ENFORCE_EQ( + res, xpu::Error_t::SUCCESS, + platform::errors::Fatal("XPU kernel error! res = %d", res)); y_data = y_broadcast; } Functor functor; int res = functor(dev_ctx.x_context(), x_data, y_data, z_data, len); - PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d", - res); + PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS, + platform::errors::Fatal("XPU kernel error! res = %d", res)); if (pre != 1 || post != 1) { dev_ctx.Wait(); diff --git a/python/paddle/fluid/tests/unittests/xpu/elementwise.py b/python/paddle/fluid/tests/unittests/xpu/elementwise.py new file mode 100644 index 0000000000..64a7abb54a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/elementwise.py @@ -0,0 +1,100 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
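# --- Illustrative sketch (not part of this patch) ----------------------------
# The XPU kernels above flatten a broadcast into a (pre, n, post) decomposition
# via get_mid_dims: Y's trimmed shape must match a contiguous run of X's shape
# starting at `axis`. Anything else is reported as common broadcast, which
# XPUElementwise rejects with an Unimplemented error. The NumPy snippet below
# mirrors that decomposition; `mid_dims` is a hypothetical helper written only
# for this sketch (trailing-singular-dim trimming is omitted).
import numpy as np

def mid_dims(x_shape, y_shape, axis):
    # Split x_shape into (pre, n, post) around y_shape, loosely mirroring the
    # C++ get_mid_dims helper.
    pre = int(np.prod(x_shape[:axis], dtype=np.int64))
    n = int(np.prod(y_shape, dtype=np.int64))
    post = int(np.prod(x_shape[axis + len(y_shape):], dtype=np.int64))
    is_common = list(x_shape[axis:axis + len(y_shape)]) != list(y_shape)
    return pre, n, post, is_common

x = np.random.rand(2, 100, 3).astype(np.float32)
y = np.random.rand(100).astype(np.float32)
pre, n, post, is_common = mid_dims(x.shape, y.shape, axis=1)  # (2, 100, 3, False)
# Broadcasting Y over (pre, n, post) is what xpu::broadcast_ew does before the
# flat elementwise call:
y_b = np.broadcast_to(y.reshape(1, n, 1), (pre, n, post))
out = x.reshape(pre, n, post) * y_b                           # elementwise_mul
assert np.allclose(out.reshape(x.shape), x * y.reshape(1, 100, 1))
# A pair like X=[2, 3, 50], Y=[2, 1, 50] does not fit this pattern, so the XPU
# kernel (and the tests marked is_common_broadcast) expects it to fail.
# ------------------------------------------------------------------------------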
+import numpy as np +import paddle +import paddle.fluid as fluid +paddle.enable_static() + + +class TestXPUElementwiseOpBase(object): + def setUp(self, op_type): + self.op_type = op_type + self.attrs = {'use_xpu': True} + self.is_common_broadcast = False + self.is_x_size_less_than_y = False + self.grad_implemented = False + self.y_grad_implemented = True + self.dtype = np.float32 + self.__class__.op_type = self.op_type + self.__class__.use_xpu = True + self.__class__.dtype = self.dtype + + def net(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.layers.data( + name='X', shape=self.inputs['X'].shape, dtype=self.dtype) + y = fluid.layers.data( + name='Y', shape=self.inputs['Y'].shape, dtype=self.dtype) + op = getattr(fluid.layers, self.op_type) + z = op(x, y) + exe = fluid.Executor(place) + z_value = exe.run(feed=self.inputs, fetch_list=[z.name]) + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + if not self.is_common_broadcast and not self.is_x_size_less_than_y: + self.check_output_with_place(place, atol=1e-3) + else: + with self.assertRaises(BaseException): + self.net(place) + + def _check_grad_xpu_helper(self, + inputs_to_check, + output_names, + no_grad_set=None, + max_relative_error=0.05): + if self.grad_implemented and not self.is_common_broadcast \ + and not self.is_x_size_less_than_y: + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place( + place, + inputs_to_check, + output_names, + no_grad_set=no_grad_set, + max_relative_error=max_relative_error) + + def test_check_grad_normal(self): + self._check_grad_xpu_helper(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + self._check_grad_xpu_helper(['Y'], 'Out', set("X")) + + def test_check_grad_ingore_y(self): + if self.y_grad_implemented: + self._check_grad_xpu_helper(['X'], 'Out', set("Y")) + + def init_axis(self): + self.axis = -1 + + def make_input(self, x_shape=[13, 17], y_shape=[13, 17]): + self.inputs = { + 'X': np.random.uniform(0.1, 1, x_shape).astype(self.dtype), + 'Y': np.random.uniform(0.1, 1, y_shape).astype(self.dtype) + } + + def reshape_input(self, x_shape=None, y_shape=None): + if x_shape is None: + x = self.inputs['X'] + else: + x = self.inputs['X'].reshape(x_shape) + if y_shape is None: + y = self.inputs['Y'] + else: + y = self.inputs['Y'].reshape(y_shape) + return x, y + + def make_output(self, x_shape=None, y_shape=None): + pass diff --git a/python/paddle/fluid/tests/unittests/xpu/test_elementwise_div_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_div_op_xpu.py new file mode 100644 index 0000000000..cb6e412cb0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_div_op_xpu.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
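# --- Usage sketch (not part of this patch) ------------------------------------
# A minimal static-graph run of the new XPU divide kernel, assuming a Paddle
# build with XPU support and a KUNLUN device visible as XPUPlace(0). It uses
# fluid.data and fluid.layers.elementwise_div from the same 2.0-era static API
# that the test base class drives via getattr(fluid.layers, op_type).
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

x_np = np.random.uniform(0.1, 1, [13, 17]).astype('float32')
y_np = np.random.uniform(0.1, 1, [13, 17]).astype('float32')

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.data(name='X', shape=[13, 17], dtype='float32')
    y = fluid.data(name='Y', shape=[13, 17], dtype='float32')
    out = fluid.layers.elementwise_div(x, y)  # should dispatch to the XPU kernel

exe = fluid.Executor(paddle.XPUPlace(0))
exe.run(startup)
res, = exe.run(main, feed={'X': x_np, 'Y': y_np}, fetch_list=[out])
np.testing.assert_allclose(res, x_np / y_np, atol=1e-3)  # tolerance the tests use
# ------------------------------------------------------------------------------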
+import sys +sys.path.append("..") +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from op_test import OpTest, skip_check_grad_ci +from elementwise import TestXPUElementwiseOpBase +paddle.enable_static() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseDivOp(OpTest, TestXPUElementwiseOpBase): + def setUp(self): + TestXPUElementwiseOpBase.setUp(self, "elementwise_div") + self.make_input() + self.make_output() + + def make_output(self, x_shape=None, y_shape=None): + x, y = self.reshape_input(x_shape, y_shape) + self.outputs = {'Out': np.divide(x, y)} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_scalar(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_scalar, self).setUp() + self.grad_implemented = False + self.make_input([20, 3, 4], [1]) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_Vector(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_Vector, self).setUp() + self.make_input([100, ], [100, ]) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_broadcast_0(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_broadcast_0, self).setUp() + self.attrs['axis'] = 0 + self.make_input([100, 3, 4], [100, ]) + self.make_output(y_shape=[100, 1, 1]) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_broadcast_1(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_broadcast_1, self).setUp() + self.attrs['axis'] = 1 + self.make_input([2, 100, 4], [100, ]) + self.make_output(y_shape=[1, 100, 1]) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_broadcast_2(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_broadcast_2, self).setUp() + self.make_input([2, 3, 100], [100, ]) + self.make_output(y_shape=[1, 1, 100]) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_broadcast_3(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_broadcast_3, self).setUp() + self.attrs['axis'] = 1 + self.make_input([2, 10, 12, 5], [10, 12]) + self.make_output(y_shape=[1, 10, 12, 1]) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_broadcast_4(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_broadcast_4, self).setUp() + self.is_common_broadcast = True + self.make_input([2, 3, 50], [2, 1, 50]) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_broadcast_5(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_broadcast_5, self).setUp() + self.is_common_broadcast = True + self.make_input([2, 3, 4, 20], [2, 3, 1, 20]) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_commonuse_1(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_commonuse_1, self).setUp() + self.is_common_broadcast = True + self.make_input([2, 3, 100], [1, 1, 
100]) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseDivOp_xsize_lessthan_ysize(TestXPUElementwiseDivOp): + def setUp(self): + super(TestElementwiseDivOp_xsize_lessthan_ysize, self).setUp() + self.is_x_size_less_than_y = True + self.attrs['axis'] = 2 + self.make_input([10, 12], [2, 3, 10, 12]) + self.make_output(x_shape=[1, 1, 10, 12]) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_elementwise_max_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_max_op_xpu.py new file mode 100644 index 0000000000..340c5895c1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_max_op_xpu.py @@ -0,0 +1,129 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +sys.path.append("..") +import unittest +import numpy as np +from op_test import OpTest, skip_check_grad_ci +import paddle +from elementwise import TestXPUElementwiseOpBase +paddle.enable_static() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseOp(OpTest, TestXPUElementwiseOpBase): + def setUp(self): + TestXPUElementwiseOpBase.setUp(self, "elementwise_max") + self.make_input() + self.make_output() + + def make_input(self, x_shape=[13, 17], y_shape=[13, 17], idx_list=None): + x = np.random.random(x_shape).astype(self.dtype) + sgn = np.random.choice([-1, 1], y_shape).astype(self.dtype) + if idx_list is None: + y = x + sgn * np.random.uniform(0.1, 1, y_shape).astype(self.dtype) + else: + x_temp = x + for idx in idx_list: + x_temp = np.take(x_temp, [0], axis=idx) + sgn = sgn.reshape(x_temp.shape) + y = x_temp + sgn * np.random.uniform(0.1, 1, x_temp.shape) + y = y.reshape(y_shape).astype(self.dtype) + + self.inputs = {'X': x, 'Y': y} + + def make_output(self, x_shape=None, y_shape=None): + x, y = self.reshape_input(x_shape, y_shape) + self.outputs = {'Out': np.maximum(x, y)} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMaxOp_scalar(TestXPUElementwiseOp): + def setUp(self): + super(TestElementwiseMaxOp_scalar, self).setUp() + self.make_input([2, 3, 20], [1]) + self.make_output() + self.grad_implemented = False + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMaxOp_Vector(TestXPUElementwiseOp): + def setUp(self): + super(TestElementwiseMaxOp_Vector, self).setUp() + self.make_input([100, ], [100, ]) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMaxOp_broadcast_0(TestXPUElementwiseOp): + def setUp(self): + super(TestElementwiseMaxOp_broadcast_0, self).setUp() + self.attrs['axis'] = 0 + self.make_input([100, 5, 2], [100, ], [1, 2]) + self.make_output(y_shape=[100, 1, 1]) + + +@unittest.skipIf(not 
paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMaxOp_broadcast_1(TestXPUElementwiseOp): + def setUp(self): + super(TestElementwiseMaxOp_broadcast_1, self).setUp() + self.attrs['axis'] = 1 + self.make_input([2, 100, 3], [100, ], [0, 2]) + self.make_output(y_shape=[1, 100, 1]) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMaxOp_broadcast_2(TestXPUElementwiseOp): + def setUp(self): + super(TestElementwiseMaxOp_broadcast_2, self).setUp() + self.make_input([1, 3, 100], [100, ], [0, 1]) + self.make_output(y_shape=[1, 1, 100]) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMaxOp_broadcast_3(TestXPUElementwiseOp): + def setUp(self): + super(TestElementwiseMaxOp_broadcast_3, self).setUp() + self.attrs['axis'] = 1 + self.make_input([2, 50, 2, 1], [50, 2], [0, 3]) + self.make_output(y_shape=[1, 50, 2, 1]) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMaxOp_broadcast_4(TestXPUElementwiseOp): + def setUp(self): + super(TestElementwiseMaxOp_broadcast_4, self).setUp() + self.make_input([2, 3, 4, 5], [2, 3, 1, 5]) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMaxOp_broadcast_5(TestXPUElementwiseOp): + def setUp(self): + super(TestElementwiseMaxOp_broadcast_5, self).setUp() + self.make_input([2, 3, 100], [1, 1, 100]) + self.make_output() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_elementwise_mul_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_mul_op_xpu.py new file mode 100644 index 0000000000..3fa9c6d84e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_mul_op_xpu.py @@ -0,0 +1,153 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
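# --- Gradient reference sketch (not part of this patch) ------------------------
# DEFINE_XPU_GRAD_KERNEL(Mul, mul, true) broadcasts Y when needed, calls
# xpu::elementwise_mul_grad on the flattened data, and reduces dY back to Y's
# shape. The NumPy lines below are the mathematically equivalent reference for
# the (pre, n, post) case the kernel supports, not the kernel code itself.
import numpy as np

pre, n, post = 2, 100, 3
x = np.random.rand(pre, n, post).astype(np.float32)     # X: [2, 100, 3]
y = np.random.rand(n).astype(np.float32)                # Y: [100], axis = 1
dout = np.random.rand(pre, n, post).astype(np.float32)  # gradient of Out

y_b = np.broadcast_to(y.reshape(1, n, 1), x.shape)      # the broadcast_ew step
dx = dout * y_b                                         # d(x*y)/dx = y
dy = (dout * x).sum(axis=(0, 2))                        # reduce over pre and post
assert dx.shape == x.shape and dy.shape == y.shape
# ------------------------------------------------------------------------------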
+import sys +sys.path.append("..") +import unittest +import numpy as np +from op_test import OpTest, skip_check_grad_ci +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +import paddle +from elementwise import TestXPUElementwiseOpBase +paddle.enable_static() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseMulOp(OpTest, TestXPUElementwiseOpBase): + def init_kernel_type(self): + self.use_mkldnn = False + + def setUp(self): + TestXPUElementwiseOpBase.setUp(self, "elementwise_mul") + self.init_kernel_type() + self.init_axis() + self.attrs['axis'] = self.axis + self.attrs['use_mkldnn'] = self.use_mkldnn + self.grad_implemented = True + self.make_input() + self.make_output() + + def make_output(self, x_shape=None, y_shape=None): + x, y = self.reshape_input(x_shape, y_shape) + self.outputs = {'Out': np.multiply(x, y)} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseMulOp_scalar(TestXPUElementwiseMulOp): + def setUp(self): + super(TestXPUElementwiseMulOp_scalar, self).setUp() + self.make_input((10, 3, 4), (1, )) + self.make_output() + self.grad_implemented = False + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseMulOp_Vector(TestXPUElementwiseMulOp): + def setUp(self): + super(TestXPUElementwiseMulOp_Vector, self).setUp() + self.make_input((100, ), (100, )) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseMulOp_broadcast_0(TestXPUElementwiseMulOp): + def setUp(self): + super(TestXPUElementwiseMulOp_broadcast_0, self).setUp() + self.make_input((100, 2, 3), (100, )) + self.make_output(y_shape=(100, 1, 1)) + self.y_grad_implemented = False + + def init_axis(self): + self.axis = 0 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMulOp_broadcast_1(TestXPUElementwiseMulOp): + def setUp(self): + super(TestElementwiseMulOp_broadcast_1, self).setUp() + self.attrs['axis'] = 1 + self.y_grad_implemented = False + self.make_input((2, 100, 3), (100, )) + self.make_output(y_shape=(1, 100, 1)) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMulOp_broadcast_2(TestXPUElementwiseMulOp): + def setUp(self): + super(TestElementwiseMulOp_broadcast_2, self).setUp() + self.y_grad_implemented = False + self.make_input((2, 3, 100), (100, )) + self.make_output(y_shape=(1, 1, 100)) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMulOp_broadcast_3(TestXPUElementwiseMulOp): + def setUp(self): + super(TestElementwiseMulOp_broadcast_3, self).setUp() + self.attrs['axis'] = 1 + self.y_grad_implemented = False + self.make_input((2, 10, 12, 3), (10, 12)) + self.make_output(y_shape=(1, 10, 12, 1)) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMulOp_broadcast_4(TestXPUElementwiseMulOp): + def setUp(self): + super(TestElementwiseMulOp_broadcast_4, self).setUp() + self.is_common_broadcast = True + self.make_input((10, 2, 11), (10, 1, 11)) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseMulOp_broadcast_5(TestXPUElementwiseMulOp): + def setUp(self): + 
super(TestElementwiseMulOp_broadcast_5, self).setUp() + self.is_common_broadcast = True + self.make_input((10, 4, 2, 3), (10, 4, 1, 3)) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseMulOp_commonuse_1(TestXPUElementwiseMulOp): + def setUp(self): + super(TestXPUElementwiseMulOp_commonuse_1, self).setUp() + self.is_common_broadcast = True + self.make_input((2, 3, 100), (1, 1, 100)) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseMulOp_xsize_lessthan_ysize(TestXPUElementwiseMulOp): + def setUp(self): + super(TestXPUElementwiseMulOp_xsize_lessthan_ysize, self).setUp() + self.attrs['axis'] = 2 + self.is_x_size_less_than_y = True + self.make_input((10, 10), (2, 2, 10, 10)) + self.make_output(x_shape=(1, 1, 10, 10)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_elementwise_sub_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_sub_op_xpu.py new file mode 100644 index 0000000000..22aa07be95 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_sub_op_xpu.py @@ -0,0 +1,128 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
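# --- Gradient reference sketch (not part of this patch) ------------------------
# DEFINE_XPU_GRAD_KERNEL(Sub, sub, false) passes dOut in place of X and Y
# because the subtraction gradient never reads the forward inputs. In NumPy
# terms, for the supported (pre, n, post) broadcast of Y:
import numpy as np

pre, n, post = 2, 100, 3
dout = np.random.rand(pre, n, post).astype(np.float32)  # gradient of Out: [2, 100, 3]
dx = dout                                               # d(x-y)/dx = 1
dy = -dout.sum(axis=(0, 2))                             # d(x-y)/dy = -1, reduced to Y: [100]
# ------------------------------------------------------------------------------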
+import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest, skip_check_grad_ci +import paddle +from elementwise import TestXPUElementwiseOpBase +paddle.enable_static() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseSubOp(OpTest, TestXPUElementwiseOpBase): + def setUp(self): + TestXPUElementwiseOpBase.setUp(self, "elementwise_sub") + self.make_input() + self.make_output() + self.grad_implemented = True + + def make_output(self, x_shape=None, y_shape=None): + x, y = self.reshape_input(x_shape, y_shape) + self.outputs = {'Out': x - y} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_scalar(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_scalar, self).setUp() + self.grad_implemented = False + self.make_input((10, 3, 4), (1, )) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_Vector(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_Vector, self).setUp() + self.make_input((100, ), (100, )) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_broadcast_0(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_broadcast_0, self).setUp() + self.attrs['axis'] = 0 + self.make_input((100, 3, 2), (100, )) + self.make_output(y_shape=(100, 1, 1)) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_broadcast_1(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_broadcast_1, self).setUp() + self.attrs['axis'] = 1 + self.make_input((2, 100, 3), (100, )) + self.make_output(y_shape=(1, 100, 1)) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_broadcast_2(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_broadcast_2, self).setUp() + self.make_input((2, 3, 100), (100, )) + self.make_output(y_shape=(1, 1, 100)) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_broadcast_3(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_broadcast_3, self).setUp() + self.attrs['axis'] = 1 + self.make_input((2, 10, 12, 3), (10, 12)) + self.make_output(y_shape=(1, 10, 12, 1)) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_broadcast_4(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_broadcast_4, self).setUp() + self.is_common_broadcast = True + self.make_input((2, 5, 3, 12), (2, 5, 1, 12)) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_commonuse_1(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_commonuse_1, self).setUp() + self.is_common_broadcast = True + self.make_input((2, 3, 100), (1, 1, 100)) + self.make_output() + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestElementwiseSubOp_xsize_lessthan_ysize(TestXPUElementwiseSubOp): + def setUp(self): + super(TestElementwiseSubOp_xsize_lessthan_ysize, self).setUp() + self.attrs['axis'] = 2 + self.is_x_size_less_than_y = 
True
+        self.make_input((10, 12), (2, 3, 10, 12))
+        self.make_output(x_shape=(1, 1, 10, 12))
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
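For anyone extending these tests, a new shape case is a small subclass of the shared TestXPUElementwiseOpBase; the class below is a hypothetical illustration of that pattern (its name and shapes are made up for this sketch and are not part of the patch).

import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle
from op_test import OpTest
from elementwise import TestXPUElementwiseOpBase

paddle.enable_static()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_same_shape_sketch(OpTest, TestXPUElementwiseOpBase):
    def setUp(self):
        # Pick the op, build inputs/outputs, and let the base class drive
        # check_output_with_place / check_grad_with_place.
        TestXPUElementwiseOpBase.setUp(self, "elementwise_sub")
        self.grad_implemented = True
        self.make_input((4, 64), (4, 64))  # same-shape case, no broadcast
        self.make_output()

    def make_output(self, x_shape=None, y_shape=None):
        x, y = self.reshape_input(x_shape, y_shape)
        self.outputs = {'Out': x - y}


if __name__ == '__main__':
    unittest.main()

Cases that need common broadcast or an X smaller than Y should instead set is_common_broadcast or is_x_size_less_than_y, so that the base class asserts the run raises, matching the Unimplemented check in elementwise_xpu.h.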