未验证 提交 6d3a68cb 编写于 作者: C cambriconhsq 提交者: GitHub

[MLU]add mlu kernel for sqrt op (#43326)

上级 8045fcfd
...@@ -108,6 +108,43 @@ class ActivationGradMLUKernelV3 : public framework::OpKernel<T> { ...@@ -108,6 +108,43 @@ class ActivationGradMLUKernelV3 : public framework::OpKernel<T> {
} }
}; };
// For sqrt
template <typename T>
class SqrtMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
MLUCnnlTensorDesc input_desc(*x);
MLUCnnlTensorDesc output_desc(*out);
cnnlComputationPreference_t prefer = CNNL_COMPUTATION_FAST;
MLUCnnl::Sqrt(ctx, prefer, input_desc.get(), GetBasePtr(x),
output_desc.get(), GetBasePtr(out));
}
};
template <typename T>
class SqrtGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
MLUCnnlTensorDesc data_desc(*out);
MLUCnnl::SqrtGrad(ctx, data_desc.get(), GetBasePtr(out), GetBasePtr(dout),
GetBasePtr(dx));
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -170,3 +207,9 @@ REGISTER_OP_MLU_KERNEL( ...@@ -170,3 +207,9 @@ REGISTER_OP_MLU_KERNEL(
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU, float>, ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU, float>,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU, ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU,
paddle::platform::float16>); paddle::platform::float16>);
// sqrt
REGISTER_OP_MLU_KERNEL(sqrt, ops::SqrtMLUKernel<float>,
ops::SqrtMLUKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(sqrt_grad, ops::SqrtGradMLUKernel<float>,
ops::SqrtGradMLUKernel<paddle::platform::float16>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append('..')
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
import paddle
import paddle.nn.functional as F
paddle.enable_static()
np.random.seed(10)
class TestSqrt(OpTest):
def setUp(self):
self.op_type = "sqrt"
self.dtype = 'float32'
self.set_mlu()
self.python_api = paddle.sqrt
np.random.seed(1023)
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.sqrt(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out', check_eager=False)
def test_check_output(self):
self.check_output_with_place(self.place)
class TestSqrtHalf(OpTest):
def setUp(self):
self.op_type = "sqrt"
self.dtype = 'float16'
self.set_mlu()
self.python_api = paddle.sqrt
np.random.seed(1023)
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.sqrt(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'],
'Out',
check_eager=False,
max_relative_error=0.85)
def test_check_output(self):
self.check_output_with_place(self.place)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册