Unverified commit 153e030b authored by fuyou765, committed by GitHub

[MLU]add mlu kernel for reciprocal and reciprocal grad op (#43855)

Parent bcf57274
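For reference, the forward kernel computes out = 1 / x, and since d(1/x)/dx = -1/x^2 = -out^2, the backward kernel computes dx = -out^2 * dout: it squares out with MLUCnnl::Square and then multiplies by dout with MLUCnnl::OpTensor using an alpha1 scale of -1. A minimal NumPy sketch of the same math (the helper names here are illustrative, not part of the patch):

import numpy as np

def reciprocal(x):
    # forward: out = 1 / x
    return 1.0 / x

def reciprocal_grad(out, dout):
    # backward: dx = -out^2 * dout, mirroring Square followed by
    # OpTensor(MUL, alpha1=-1, alpha2=1, beta=0)
    return -np.square(out) * dout

x = np.random.uniform(1, 2, [11, 17]).astype(np.float32)
out = reciprocal(x)
dx = reciprocal_grad(out, np.ones_like(out))

# cross-check the analytic gradient against a central finite difference
eps = 1e-3
numeric = (reciprocal(x + eps) - reciprocal(x - eps)) / (2 * eps)
np.testing.assert_allclose(dx, numeric, rtol=1e-2)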
@@ -399,11 +399,81 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel<T> {
}
};

template <typename DeviceContext, typename T>
class ReciprocalMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");
    auto place = ctx.GetPlace();
    out->mutable_data<T>(place);

    MLUCnnlTensorDesc x_desc(*x);
    MLUCnnlTensorDesc out_desc(*out);

    // out = 1 / x
    MLUCnnl::Reciprocal(
        ctx, x_desc.get(), GetBasePtr(x), out_desc.get(), GetBasePtr(out));
  }
};

template <typename DeviceContext, typename T>
class ReciprocalGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out = ctx.Input<Tensor>("Out");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto place = ctx.GetPlace();
    dx->mutable_data<T>(place);

    Tensor square_out;
    square_out.Resize(out->dims());
    square_out.mutable_data<T>(place);

    MLUCnnlTensorDesc out_desc(*out);
    MLUCnnlTensorDesc dout_desc(*dout);
    MLUCnnlTensorDesc dx_desc(*dx);
    MLUCnnlTensorDesc square_out_desc(square_out);

    // square_out = out^2
    MLUCnnl::Square(ctx,
                    out_desc.get(),
                    GetBasePtr(out),
                    square_out_desc.get(),
                    GetBasePtr(&square_out));

    // OpTensor with CNNL_OP_TENSOR_MUL computes
    //   dx = (alpha1 * dout) * (alpha2 * square_out) + beta * dx,
    // so alpha1 = -1, alpha2 = 1, beta = 0 yields dx = -dout * out^2.
    cnnlOpTensorDesc_t op_tensor_op = CNNL_OP_TENSOR_MUL;
    cnnlDataType_t op_tensor_comp_type = CNNL_DTYPE_FLOAT;
    cnnlNanPropagation_t op_tensor_nan_opt = CNNL_NOT_PROPAGATE_NAN;
    MLUCnnlOpTensorDesc op_tensor_desc(
        op_tensor_op, op_tensor_comp_type, op_tensor_nan_opt);
    float alpha1_float = -1;
    float alpha2_float = 1;
    float beta_float = 0;
    MLUCnnl::OpTensor(ctx,
                      op_tensor_desc.get(),
                      dout_desc.get(),
                      GetBasePtr(dout),
                      square_out_desc.get(),
                      GetBasePtr(&square_out),
                      dx_desc.get(),
                      GetBasePtr(dx),
                      op_tensor_comp_type,
                      alpha1_float,
                      alpha2_float,
                      beta_float);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// reciprocal
REGISTER_OP_MLU_KERNEL(
reciprocal,
ops::ReciprocalMLUKernel<paddle::platform::MLUDeviceContext, float>,
ops::ReciprocalMLUKernel<paddle::platform::MLUDeviceContext,
paddle::platform::float16>);

REGISTER_OP_MLU_KERNEL(
reciprocal_grad,
ops::ReciprocalGradMLUKernel<paddle::platform::MLUDeviceContext, float>,
ops::ReciprocalGradMLUKernel<paddle::platform::MLUDeviceContext,
paddle::platform::float16>);
// relu
REGISTER_OP_MLU_KERNEL(
relu,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division

import numpy as np
import unittest
import sys

sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci

import paddle

paddle.enable_static()
class TestMLUReciprocal(OpTest):

    def setUp(self):
        self.op_type = "reciprocal"
        self.set_mlu()
        self.init_dtype()

        np.random.seed(1024)
        x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        out = np.reciprocal(x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X'],
                                   'Out',
                                   max_relative_error=0.01)

    def set_mlu(self):
        self.__class__.use_mlu = True
        self.place = paddle.MLUPlace(0)

    def init_dtype(self):
        self.dtype = np.float32


class TestMLUReciprocalFp16(TestMLUReciprocal):

    def set_mlu(self):
        self.__class__.use_mlu = True
        self.place = paddle.MLUPlace(0)

    def init_dtype(self):
        self.dtype = np.float16


if __name__ == '__main__':
    unittest.main()
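For context, a minimal static-graph usage sketch (assuming a PaddlePaddle build compiled with MLU support, as exercised by the test above; on an MLUPlace, paddle.reciprocal should dispatch to the kernel registered in this patch):

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[11, 17], dtype='float32')
    out = paddle.reciprocal(x)

place = paddle.MLUPlace(0)  # requires an MLU device and an MLU-enabled build
exe = paddle.static.Executor(place)
exe.run(startup_prog)

x_np = np.random.uniform(1, 2, [11, 17]).astype('float32')
out_np, = exe.run(main_prog, feed={'x': x_np}, fetch_list=[out])
np.testing.assert_allclose(out_np, 1.0 / x_np, rtol=1e-5)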