diff --git a/paddle/fluid/operators/activation_op_mlu.cc b/paddle/fluid/operators/activation_op_mlu.cc
index f66b75fd1f3197230f9c4d304c49a92d18bbcad0..90d0a72074b81da90dcadc71e4a468c7d8f3db94 100644
--- a/paddle/fluid/operators/activation_op_mlu.cc
+++ b/paddle/fluid/operators/activation_op_mlu.cc
@@ -108,6 +108,43 @@ class ActivationGradMLUKernelV3 : public framework::OpKernel<T> {
   }
 };
 
+// For sqrt
+template <typename T>
+class SqrtMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+    auto place = ctx.GetPlace();
+
+    out->mutable_data<T>(place);
+
+    MLUCnnlTensorDesc input_desc(*x);
+    MLUCnnlTensorDesc output_desc(*out);
+
+    cnnlComputationPreference_t prefer = CNNL_COMPUTATION_FAST;
+    MLUCnnl::Sqrt(ctx, prefer, input_desc.get(), GetBasePtr(x),
+                  output_desc.get(), GetBasePtr(out));
+  }
+};
+
+template <typename T>
+class SqrtGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out = ctx.Input<Tensor>("Out");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto place = ctx.GetPlace();
+
+    dx->mutable_data<T>(place);
+
+    MLUCnnlTensorDesc data_desc(*out);
+    MLUCnnl::SqrtGrad(ctx, data_desc.get(), GetBasePtr(out), GetBasePtr(dout),
+                      GetBasePtr(dx));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -170,3 +207,9 @@ REGISTER_OP_MLU_KERNEL(
     ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU, float>,
     ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU,
                                    paddle::platform::float16>);
+
+// sqrt
+REGISTER_OP_MLU_KERNEL(sqrt, ops::SqrtMLUKernel<float>,
+                       ops::SqrtMLUKernel<paddle::platform::float16>);
+REGISTER_OP_MLU_KERNEL(sqrt_grad, ops::SqrtGradMLUKernel<float>,
+                       ops::SqrtGradMLUKernel<paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sqrt_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_sqrt_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7bdc162acdb173d428a57076902c46fde2ac195
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_sqrt_op_mlu.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append('..')
+from op_test import OpTest
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid import compiler, Program, program_guard
+import paddle
+import paddle.nn.functional as F
+
+paddle.enable_static()
+np.random.seed(10)
+
+
+class TestSqrt(OpTest):
+
+    def setUp(self):
+        self.op_type = "sqrt"
+        self.dtype = 'float32'
+        self.set_mlu()
+        self.python_api = paddle.sqrt
+
+        np.random.seed(1023)
+        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
+        out = np.sqrt(x)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out', check_eager=False)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+class TestSqrtHalf(OpTest):
+
+    def setUp(self):
+        self.op_type = "sqrt"
+        self.dtype = 'float16'
+        self.set_mlu()
+        self.python_api = paddle.sqrt
+
+        np.random.seed(1023)
+        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
+        out = np.sqrt(x)
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'],
+                                   'Out',
+                                   check_eager=False,
+                                   max_relative_error=0.85)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+if __name__ == "__main__":
+    unittest.main()
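
A minimal usage sketch (not part of the patch) of what the two new kernels provide, assuming a PaddlePaddle build with MLU support and at least one visible MLU device; the device string 'mlu:0', the array shape, and the rtol below are illustrative assumptions:

import numpy as np
import paddle

paddle.set_device('mlu:0')  # assumption: an MLU-enabled build exposing 'mlu' devices

x_np = np.random.uniform(0.1, 1, [11, 17]).astype('float32')
x = paddle.to_tensor(x_np, stop_gradient=False)

y = paddle.sqrt(x)  # forward: dispatched to SqrtMLUKernel<float>
y.sum().backward()  # backward: dispatched to SqrtGradMLUKernel<float>

# The backward kernel is fed out and dout rather than x:
# dx = dout / (2 * out), i.e. d/dx sqrt(x) = 1 / (2 * sqrt(x)).
np.testing.assert_allclose(x.grad.numpy(), 1.0 / (2.0 * np.sqrt(x_np)),
                           rtol=1e-5)

The unit test checks the same gradient via OpTest's numeric differentiation; the loose max_relative_error=0.85 on the float16 variant presumably leaves room for half-precision noise in that numeric gradient.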