From 5439f07dd787ec79048aa37cd734cbf3b42624bb Mon Sep 17 00:00:00 2001
From: qipengh
Date: Thu, 21 Apr 2022 21:31:38 +0800
Subject: [PATCH] [MLU]:add elementwise_div op (#41810)

---
 .../elementwise/elementwise_div_op_mlu.cc     | 141 ++++++++++
 .../mlu/test_elementwise_div_op_mlu.py        | 253 ++++++++++++++++++
 2 files changed, 394 insertions(+)
 create mode 100644 paddle/fluid/operators/elementwise/elementwise_div_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_elementwise_div_op_mlu.py

diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op_mlu.cc b/paddle/fluid/operators/elementwise/elementwise_div_op_mlu.cc
new file mode 100644
index 0000000000..1a7d757a27
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op_mlu.cc
@@ -0,0 +1,141 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_mlu.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class ElementwiseDivMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    MLUBinaryOp<DIV, T>(ctx);
+  }
+};
+
+template <typename T>
+class ElementwiseDivGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out = ctx.Input<Tensor>("Out");
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    int axis = ctx.Attr<int>("axis");
+
+    const auto& x_dims = x->dims();
+    const auto& y_dims = y->dims();
+    // Normalize a negative axis to the leading-dimension offset convention.
+    axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
+                     : axis);
+    int max_dim = std::max(x_dims.size(), y_dims.size());
+    std::vector<int> x_dims_array(max_dim);
+    std::vector<int> y_dims_array(max_dim);
+    std::vector<int> out_dims_array(max_dim);
+    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
+                           y_dims_array.data(), out_dims_array.data(), max_dim,
+                           axis);
+
+    MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(),
+                             ToCnnlDataType<T>());
+    MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(),
+                             ToCnnlDataType<T>());
+    MLUCnnlTensorDesc dout_desc(*dout);
+    MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
+                                    CNNL_NOT_PROPAGATE_NAN);
+
+    // compute dout/y == 1/y * dout
+    Tensor dout_div_y(dout->dtype());
+    dout_div_y.Resize(dout->dims());
+    dout_div_y.mutable_data<T>(ctx.GetPlace());
+    MLUBinary<DIV>(
+        ctx, CNNL_COMPUTATION_HIGH_PRECISION, dout_desc.get(),
+        GetBasePtr(dout), y_desc.get(), GetBasePtr(y), dout_desc.get(),
+        GetBasePtr(&dout_div_y));
+
+    if (dx) {
+      // compute dx = dout/y = 1/y * dout
+      if (dx->dims() != dout->dims()) {
+        dx->mutable_data<T>(ctx.GetPlace());
+
+        std::vector<int> reduce_axes;
+        GetReduceAxes(axis, dout_div_y.dims(), dx->dims(), &reduce_axes);
+        MLUCnnlReduceDesc reduction_desc(
+            reduce_axes, CNNL_REDUCE_ADD, ToCnnlDataType<T>(),
+            CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+        MLUCnnlTensorDesc dx_desc(*dx);
+        MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
+                        nullptr, dout_desc.get(), GetBasePtr(&dout_div_y), 0,
+                        nullptr, nullptr, dx_desc.get(), GetBasePtr(dx));
+      } else {
+        dx->ShareDataWith(dout_div_y);
+      }
+    }
+
+    if (dy) {
+      // compute dy = -out * (dout/y) = -out/y * dout
+      Tensor neg_out(out->type());
+      neg_out.mutable_data<T>(out->dims(), ctx.GetPlace());
+
+      MLUCnnlTensorDesc out_desc(*out);
+      MLUUnary<NEG>(ctx, CNNL_COMPUTATION_HIGH_PRECISION, out_desc.get(),
+                    GetBasePtr(out), out_desc.get(), GetBasePtr(&neg_out));
+
+      Tensor dy_temp(y->dtype());
+      dy_temp.Resize(dout->dims());
+      dy_temp.mutable_data<T>(ctx.GetPlace());
+
+      MLUCnnl::OpTensor(ctx, mul_op_desc.get(), dout_desc.get(),
+                        GetBasePtr(&neg_out), dout_desc.get(),
+                        GetBasePtr(&dout_div_y), dout_desc.get(),
+                        GetBasePtr(&dy_temp), ToCnnlDataType<T>());
+
+      if (dy->dims() != dout->dims()) {
+        dy->mutable_data<T>(ctx.GetPlace());
+
+        std::vector<int> reduce_axes;
+        GetReduceAxes(axis, dy_temp.dims(), dy->dims(), &reduce_axes);
+        MLUCnnlReduceDesc reduction_desc(
+            reduce_axes, CNNL_REDUCE_ADD, ToCnnlDataType<T>(),
+            CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+        MLUCnnlTensorDesc dy_desc(*dy);
+        MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
+                        nullptr, dout_desc.get(), GetBasePtr(&dy_temp), 0,
+                        nullptr, nullptr, dy_desc.get(), GetBasePtr(dy));
+      } else {
+        dy->ShareDataWith(dy_temp);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(elementwise_div, ops::ElementwiseDivMLUKernel<int>,
+                       ops::ElementwiseDivMLUKernel<float>,
+                       ops::ElementwiseDivMLUKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(elementwise_div_grad,
+                       ops::ElementwiseDivGradMLUKernel<int>,
+                       ops::ElementwiseDivGradMLUKernel<float>,
+                       ops::ElementwiseDivGradMLUKernel<plat::float16>);
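As a sanity check on the backward formulas above (dx = dout / y and dy = -Out * dout / y, each followed by a sum over the axes that broadcasting expanded), here is a minimal NumPy sketch of the same math. It is not part of the patch; the helper names `div_grad_reference` and `reduce_to_shape` are illustrative only, and NumPy's trailing-axis broadcasting is assumed (the default axis = -1 case that the kernel's `axis` attribute and `GetReduceAxes` generalize).

```python
import numpy as np


def reduce_to_shape(grad, shape):
    # Sum over axes that broadcasting added or expanded, mirroring the
    # GetReduceAxes + MLUCnnl::Reduce(CNNL_REDUCE_ADD) path in the kernel.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad


def div_grad_reference(x, y, dout):
    out = x / y
    dout_div_y = dout / y  # shared term, materialized once like dout_div_y above
    dx = reduce_to_shape(dout_div_y, x.shape)          # d(x/y)/dx = 1/y
    dy = reduce_to_shape(-out * dout_div_y, y.shape)   # d(x/y)/dy = -x/y^2 = -out/y
    return dx, dy


# Broadcast case, e.g. X: [2, 3, 4] divided by Y: [4]
x = np.random.uniform(1, 2, (2, 3, 4))
y = np.random.uniform(1, 2, (4,))
dout = np.ones((2, 3, 4))
dx, dy = div_grad_reference(x, y, dout)
assert dx.shape == x.shape and dy.shape == y.shape
```

Note how `dy` reuses the already-negated forward output, which is exactly why the kernel computes `neg_out` with `MLUUnary<NEG>` instead of recomputing x/y^2.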
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_elementwise_div_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_div_op_mlu.py
new file mode 100644
index 0000000000..8fdac75c4c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_div_op_mlu.py
@@ -0,0 +1,253 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.core import ops
+
+paddle.enable_static()
+SEED = 2022
+
+
+class TestElementwiseDiv(OpTest):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+
+        self.init_dtype()
+        np.random.seed(SEED)
+        x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        out = np.divide(x, y)
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(y)
+        }
+        self.attrs = {}
+        self.outputs = {'Out': out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(
+            self.place, ['X', 'Y'], 'Out', max_relative_error=0.05)
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad_with_place(
+            self.place, ['Y'],
+            'Out',
+            max_relative_error=0.05,
+            no_grad_set=set("X"))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad_with_place(
+            self.place, ['X'],
+            'Out',
+            max_relative_error=0.05,
+            no_grad_set=set("Y"))
+
+
+class TestElementwiseDivFp16(OpTest):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+
+        self.init_dtype()
+        np.random.seed(SEED)
+        x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
+        y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
+        out = np.divide(x, y)
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(y)
+        }
+        self.attrs = {}
+        self.outputs = {'Out': out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.__class__.no_need_check_grad = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-5)
+
+
+@skip_check_grad_ci(
+    reason="[skip shape check] Use y_shape(1) to test broadcast.")
+class TestElementwiseDiv_scalar(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [20, 3, 4]).astype(np.float32),
+            'Y': np.random.uniform(0.1, 1, [1]).astype(np.float32)
+        }
+        self.outputs = {'Out': self.inputs['X'] / self.inputs['Y']}
+
+
+class TestElementwiseDiv_Vector(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [100]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [100]).astype("float32")
+        }
+        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
+
+
+class TestElementwiseDiv_broadcast_0(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [100, 3, 4]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [100]).astype("float32")
+        }
+
+        self.attrs = {'axis': 0}
+        self.outputs = {
+            'Out':
+            np.divide(self.inputs['X'], self.inputs['Y'].reshape(100, 1, 1))
+        }
+
+
+class TestElementwiseDiv_broadcast_1(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 100, 4]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [100]).astype("float32")
+        }
+
+        self.attrs = {'axis': 1}
+        self.outputs = {
+            'Out':
+            np.divide(self.inputs['X'], self.inputs['Y'].reshape(1, 100, 1))
+        }
+
+
+class TestElementwiseDiv_broadcast_2(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [100]).astype("float32")
+        }
+
+        self.outputs = {
+            'Out':
+            np.divide(self.inputs['X'], self.inputs['Y'].reshape(1, 1, 100))
+        }
+
+
+class TestElementwiseDiv_broadcast_3(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 10, 12, 5]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [10, 12]).astype("float32")
+        }
+
+        self.attrs = {'axis': 1}
+        self.outputs = {
+            'Out':
+            np.divide(self.inputs['X'], self.inputs['Y'].reshape(1, 10, 12, 1))
+        }
+
+
+class TestElementwiseDiv_broadcast_4(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 3, 50]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [2, 1, 50]).astype("float32")
+        }
+        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
+
+
+class TestElementwiseDiv_broadcast_5(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 3, 4, 20]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [2, 3, 1, 20]).astype("float32")
+        }
+        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
+
+
+class TestElementwiseDiv_commonuse_1(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [1, 1, 100]).astype("float32"),
+        }
+        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
+
+
+class TestElementwiseDiv_commonuse_2(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [30, 3, 1, 5]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [30, 1, 4, 1]).astype("float32"),
+        }
+        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
+
+
+class TestElementwiseDiv_xsize_lessthan_ysize(TestElementwiseDiv):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [10, 12]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [2, 3, 10, 12]).astype("float32"),
+        }
+
+        self.attrs = {'axis': 2}
+
+        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
+
+
+if __name__ == '__main__':
+    unittest.main()
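For a quick manual smoke test of the new kernel on hardware, the op is reachable through `paddle.divide`, which lowers to `elementwise_div` in static graphs. The sketch below is illustrative only and not part of the patch; it assumes a PaddlePaddle build with MLU support, so that the `paddle.device.MLUPlace` used by the tests above is available.

```python
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[11, 17], dtype='float32')
    y = paddle.static.data(name='y', shape=[11, 17], dtype='float32')
    out = paddle.divide(x, y)  # lowers to the elementwise_div op

place = paddle.device.MLUPlace(0)  # requires an MLU-enabled Paddle build
exe = paddle.static.Executor(place)
exe.run(startup_prog)

x_np = np.random.uniform(1, 2, [11, 17]).astype('float32')
y_np = np.random.uniform(1, 2, [11, 17]).astype('float32')
out_np, = exe.run(main_prog, feed={'x': x_np, 'y': y_np}, fetch_list=[out])
np.testing.assert_allclose(out_np, x_np / y_np, rtol=1e-5)
```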