diff --git a/paddle/fluid/operators/elementwise/elementwise_mlu.h b/paddle/fluid/operators/elementwise/elementwise_mlu.h
index d5c85e9f71cc19114b73e9b3d3eae16b440e9aca..50085f531a99dcab286b8c93731bbd49c479ddd1 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mlu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mlu.h
@@ -122,6 +122,7 @@ enum BINARY_FUNCTOR {
   DIVNONAN,
   MAXIMUM,
   MINIMUM,
+  POW,
 };
 
 template <BINARY_FUNCTOR func>
@@ -171,6 +172,18 @@ inline void MLUBinary<MINIMUM>(const framework::ExecutionContext& ctx,
   MLUCnnl::Minimum(ctx, in1_desc, in1, in2_desc, in2, out_desc, out);
 }
 
+template <>
+inline void MLUBinary<POW>(const framework::ExecutionContext& ctx,
+                           cnnlComputationPreference_t prefer,
+                           const cnnlTensorDescriptor_t x_desc,
+                           const void* x,
+                           const cnnlTensorDescriptor_t y_desc,
+                           const void* y,
+                           const cnnlTensorDescriptor_t out_desc,
+                           void* out) {
+  MLUCnnl::Pow(ctx, prefer, x_desc, x, y_desc, y, out_desc, out);
+}
+
 template <BINARY_FUNCTOR Functor, typename T>
 void MLUBinaryOp(const framework::ExecutionContext& ctx) {
   auto* x = ctx.Input<Tensor>("X");
diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op_mlu.cc b/paddle/fluid/operators/elementwise/elementwise_pow_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..431122641ec3d5afdc28af291a551055c21c26eb
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_pow_op_mlu.cc
@@ -0,0 +1,214 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/elementwise/elementwise_mlu.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class ElementwisePowMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    MLUBinaryOp<POW, T>(ctx);
+  }
+};
+
+template <typename T>
+class ElementwisePowGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    int axis = ctx.Attr<int>("axis");
+    auto place = ctx.GetPlace();
+
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    axis =
+        (axis < 0 ? std::abs(x_dims.size() - y_dims.size()) + axis + 1 : axis);
+
+    int max_dim = std::max(x_dims.size(), y_dims.size());
+    std::vector<int> x_dims_array(max_dim);
+    std::vector<int> y_dims_array(max_dim);
+    std::vector<int> out_dims_array(max_dim);
+    GetBroadcastDimsArrays(x_dims,
+                           y_dims,
+                           x_dims_array.data(),
+                           y_dims_array.data(),
+                           out_dims_array.data(),
+                           max_dim,
+                           axis);
+    cnnlDataType_t data_type = ToCnnlDataType<T>();
+    MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), data_type);
+    MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), data_type);
+    MLUCnnlTensorDesc out_desc(max_dim, out_dims_array.data(), data_type);
+
+    auto dout_dims = dout->dims();
+    if (dx) {
+      // dx = dout * y * pow(x, y - 1);
+      Tensor one_dx(y->type());
+      one_dx.mutable_data<T>(phi::make_ddim(y_dims_array), place);
+      FillMLUTensorWithHostValue(ctx, static_cast<T>(1), &one_dx);
+
+      Tensor sub_dx(y->type());
+      sub_dx.mutable_data<T>(phi::make_ddim(y_dims_array), place);
+      MLUCnnlOpTensorDesc op_tensor_desc(
+          CNNL_OP_TENSOR_SUB, data_type, CNNL_NOT_PROPAGATE_NAN);
+      MLUCnnl::OpTensor(ctx,
+                        op_tensor_desc.get(),
+                        y_desc.get(),
+                        GetBasePtr(y),
+                        y_desc.get(),
+                        GetBasePtr(&one_dx),
+                        y_desc.get(),
+                        GetBasePtr(&sub_dx),
+                        data_type);
+
+      Tensor tmp_dx(x->type());
+      tmp_dx.mutable_data<T>(phi::make_ddim(out_dims_array), place);
+      MLUCnnl::Pow(ctx,
+                   CNNL_COMPUTATION_HIGH_PRECISION,
+                   x_desc.get(),
+                   GetBasePtr(x),
+                   y_desc.get(),
+                   GetBasePtr(&sub_dx),
+                   out_desc.get(),
+                   GetBasePtr(&tmp_dx));
+
+      MLUCnnl::MulAx(ctx,
+                     y_desc.get(),
+                     GetBasePtr(y),
+                     out_desc.get(),
+                     GetBasePtr(&tmp_dx));
+      MLUCnnl::MulAx(ctx,
+                     out_desc.get(),
+                     GetBasePtr(dout),
+                     out_desc.get(),
+                     GetBasePtr(&tmp_dx));
+
+      if (x_dims != dout_dims) {
+        dx->mutable_data<T>(place);
+        std::vector<int> reduce_axes;
+        GetReduceAxes(axis, dout_dims, x_dims, &reduce_axes);
+        if (!reduce_axes.empty()) {
+          MLUCnnlReduceDesc reduction_desc(reduce_axes,
+                                           CNNL_REDUCE_ADD,
+                                           data_type,
+                                           CNNL_NOT_PROPAGATE_NAN,
+                                           CNNL_REDUCE_NO_INDICES,
+                                           CNNL_32BIT_INDICES);
+          MLUCnnlTensorDesc dx_desc(*dx);
+          MLUCnnl::Reduce(ctx,
+                          true /*need_workspace*/,
+                          reduction_desc.get(),
+                          nullptr,
+                          out_desc.get(),
+                          GetBasePtr(&tmp_dx),
+                          0,
+                          nullptr,
+                          nullptr,
+                          dx_desc.get(),
+                          GetBasePtr(dx));
+        }
+      } else {
+        dx->ShareDataWith(tmp_dx);
+      }
+    }
+    if (dy) {
+      // dy = dout * log(x) * pow(x, y)
+      Tensor tmp_dy(y->type());
+      tmp_dy.mutable_data<T>(phi::make_ddim(out_dims_array), place);
+      MLUCnnl::Pow(ctx,
+                   CNNL_COMPUTATION_HIGH_PRECISION,
+                   x_desc.get(),
+                   GetBasePtr(x),
+                   y_desc.get(),
+                   GetBasePtr(y),
+                   out_desc.get(),
+                   GetBasePtr(&tmp_dy));
+
+      Tensor log_x(x->type());
+      log_x.mutable_data<T>(x->dims(), place);
+      MLUCnnl::Log(ctx,
+                   CNNL_COMPUTATION_HIGH_PRECISION,
+                   CNNL_LOG_E,
+                   x_desc.get(),
+                   GetBasePtr(x),
+                   x_desc.get(),
+                   GetBasePtr(&log_x));
+      MLUCnnl::MulAx(ctx,
+                     x_desc.get(),
+                     GetBasePtr(&log_x),
+                     out_desc.get(),
+                     GetBasePtr(&tmp_dy));
+      MLUCnnl::MulAx(ctx,
+                     out_desc.get(),
+                     GetBasePtr(dout),
+                     out_desc.get(),
+                     GetBasePtr(&tmp_dy));
+
+      if (y_dims != dout_dims) {
+        dy->mutable_data<T>(place);
+        std::vector<int> reduce_axes;
+        GetReduceAxes(axis, dout_dims, y_dims, &reduce_axes);
+        if (!reduce_axes.empty()) {
+          MLUCnnlReduceDesc reduction_desc(reduce_axes,
+                                           CNNL_REDUCE_ADD,
+                                           data_type,
+                                           CNNL_NOT_PROPAGATE_NAN,
+                                           CNNL_REDUCE_NO_INDICES,
+                                           CNNL_32BIT_INDICES);
+          MLUCnnlTensorDesc dy_desc(*dy);
+          MLUCnnl::Reduce(ctx,
+                          true /*need_workspace*/,
+                          reduction_desc.get(),
+                          nullptr,
+                          out_desc.get(),
+                          GetBasePtr(&tmp_dy),
+                          0,
+                          nullptr,
+                          nullptr,
+                          dy_desc.get(),
+                          GetBasePtr(dy));
+        }
+      } else {
+        dy->ShareDataWith(tmp_dy);
+      }
+    }
+    if (!dx && !dy) {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "Not support all outputs to be empty."));
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(elementwise_pow,
+                       ops::ElementwisePowMLUKernel<float>,
+                       ops::ElementwisePowMLUKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(elementwise_pow_grad,
+                       ops::ElementwisePowGradMLUKernel<float>,
+                       ops::ElementwisePowGradMLUKernel<plat::float16>);
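Note for reviewers: the backward kernel above follows the closed-form gradients stated in its comments, dx = dout * y * pow(x, y - 1) and dy = dout * log(x) * pow(x, y), and then sum-reduces each gradient over any broadcast axes (CNNL_REDUCE_ADD) so it matches the shape of the corresponding input. The short NumPy sketch below is not part of the patch; it only reproduces that math on the host for the broadcast_1 shapes used in the new unit test, with illustrative variable names, so the formulas can be checked independently of CNNL.

import numpy as np

# Shapes mirror TestElementwisePowOp_broadcast_1: x is [2, 100, 1], y is [100], axis = 1,
# so y takes part in the computation as [1, 100, 1] and dy must be reduced back to [100].
rng = np.random.default_rng(2022)
x = rng.uniform(1, 2, (2, 100, 1)).astype(np.float32)
y = rng.uniform(1, 2, (100,)).astype(np.float32)
y_b = y.reshape(1, 100, 1)                      # y broadcast to x's rank at axis = 1
out = np.power(x, y_b)
dout = np.ones_like(out) / out.size             # uniform upstream gradient, as in OpTest

dx = dout * y_b * np.power(x, y_b - 1)          # dx = dout * y * pow(x, y - 1)
dy_full = dout * np.log(x) * np.power(x, y_b)   # dy = dout * log(x) * pow(x, y)
dy = dy_full.sum(axis=(0, 2))                   # reduce over the broadcast axes of y

assert dx.shape == x.shape and dy.shape == y.shape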
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_elementwise_pow_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_pow_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e04aed19c6923c5ac73e8ca9683fe507bb56d06
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_pow_op_mlu.py
@@ -0,0 +1,256 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import paddle.fluid as fluid
+import paddle
+
+import numpy as np
+import unittest
+import sys
+
+sys.path.append("..")
+from op_test import OpTest
+
+paddle.enable_static()
+SEED = 2022
+
+
+def ComputeGrad(x, y, out, axis):
+    grad = 1 / out.size
+    shape_x = x.shape
+    shape_y = y.shape
+    shape_out = out.shape
+    reduce_axes_x = []
+    reduce_axes_y = []
+
+    if shape_x != shape_out:
+        if len(shape_x) < len(shape_out):
+            src_axis = axis
+        else:
+            src_axis = 0
+
+        for ax in range(len(shape_out)):
+            if (ax < src_axis or ax >= src_axis + len(shape_x)) or (
+                    shape_out[ax] > 1 and shape_x[ax - src_axis] == 1):
+                reduce_axes_x.append(ax)
+
+    if shape_y != shape_out:
+        if len(shape_y) < len(shape_out):
+            src_axis = axis
+        else:
+            src_axis = 0
+
+        for ax in range(len(shape_out)):
+            if (ax < src_axis or ax >= src_axis + len(shape_y)) or (
+                    shape_out[ax] > 1 and shape_y[ax - src_axis] == 1):
+                reduce_axes_y.append(ax)
+
+    if len(reduce_axes_x) > 0:
+        for i in reduce_axes_x:
+            x = np.expand_dims(x, axis=i)
+
+    if len(reduce_axes_y) > 0:
+        for i in reduce_axes_y:
+            y = np.expand_dims(y, axis=i)
+
+    dx = y * np.power(x, y - 1) * grad
+    dy = np.log(x) * np.power(x, y) * grad
+
+    if len(reduce_axes_x) > 0:
+        for i, element in enumerate(reduce_axes_x):
+            dx = np.add.reduce(dx, element - i)
+
+    if len(reduce_axes_y) > 0:
+        for i, element in enumerate(reduce_axes_y):
+            dy = np.add.reduce(dy, element - i)
+
+    return dx, dy
+
+
+class TestElementwisePow(OpTest):
+
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "elementwise_pow"
+
+        self.init_dtype()
+        self.init_input_output()
+        self.init_axis()
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
+            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
+        }
+        self.attrs = {'axis': self.axis}
+        self.outputs = {'Out': self.out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def init_axis(self):
+        self.axis = -1
+
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.out = np.power(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['X', 'Y'],
+                                   'Out',
+                                   user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['Y'],
+                                   'Out',
+                                   no_grad_set=set("X"),
+                                   user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['X'],
+                                   'Out',
+                                   no_grad_set=set("Y"),
+                                   user_defined_grads=[dx])
+
+
+class TestElementwisePowFp16(TestElementwisePow):
+
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.out = np.power(self.x, self.y)
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        # self.__class__.no_need_check_grad = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-5)
+
+
+class TestElementwisePowOp_broadcast_0(TestElementwisePow):
+
+    def init_axis(self):
+        self.axis = 1
+
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [1, 11, 17]).astype(self.dtype)
+        self.out = np.power(self.x, self.y)
+
+    def test_check_grad_normal(self):
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['X', 'Y'],
+                                   'Out',
+                                   user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['Y'],
+                                   'Out',
+                                   no_grad_set=set("X"),
+                                   user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['X'],
+                                   'Out',
+                                   no_grad_set=set("Y"),
+                                   user_defined_grads=[dx])
+
+
+class TestElementwisePowOp_broadcast_1(TestElementwisePow):
+
+    def init_axis(self):
+        self.axis = 1
+
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(1, 2, [2, 100, 1]).astype(self.dtype)
+        self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
+        self.out = np.power(self.x, self.y.reshape(1, 100, 1))
+
+    def test_check_grad_normal(self):
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['X', 'Y'],
+                                   'Out',
+                                   user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['Y'],
+                                   'Out',
+                                   no_grad_set=set("X"),
+                                   user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['X'],
+                                   'Out',
+                                   no_grad_set=set("Y"),
+                                   user_defined_grads=[dx])
+
+
+class TestElementwisePowOp_broadcast_2(TestElementwisePow):
+
+    def init_axis(self):
+        self.axis = 0
+
+    def init_input_output(self):
+        np.random.seed(SEED)
+        self.x = np.random.uniform(0.1, 1, [100, 3, 1]).astype(self.dtype)
+        self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
+        self.out = np.power(self.x, self.y.reshape(100, 1, 1))
+
+    def test_check_grad_normal(self):
+        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['X', 'Y'],
+                                   'Out',
+                                   user_defined_grads=[dx, dy])
+
+    def test_check_grad_ingore_x(self):
+        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['Y'],
+                                   'Out',
+                                   no_grad_set=set("X"),
+                                   user_defined_grads=[dy])
+
+    def test_check_grad_ingore_y(self):
+        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
+        self.check_grad_with_place(self.place, ['X'],
+                                   'Out',
+                                   no_grad_set=set("Y"),
+                                   user_defined_grads=[dx])
+
+
+if __name__ == '__main__':
+    unittest.main()
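As a quick smoke test outside the OpTest harness, the new kernel can also be exercised in static-graph mode. The snippet below is only a hedged sketch: it assumes a Paddle wheel built with MLU support and an available device 0, and it routes through fluid.layers.elementwise_pow so the elementwise_pow op (and therefore the kernel registered above) is what ends up in the program; variable names here are illustrative only.

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
place = paddle.device.MLUPlace(0)  # assumes an MLU build with device 0 present

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[11, 17], dtype='float32')
    y = paddle.static.data(name='y', shape=[11, 17], dtype='float32')
    out = fluid.layers.elementwise_pow(x, y)  # lowers to the elementwise_pow op

exe = paddle.static.Executor(place)
x_np = np.random.uniform(1, 2, [11, 17]).astype('float32')
y_np = np.random.uniform(1, 2, [11, 17]).astype('float32')
res, = exe.run(main_prog, feed={'x': x_np, 'y': y_np}, fetch_list=[out])
np.testing.assert_allclose(res, np.power(x_np, y_np), rtol=1e-5)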