Unverified commit 1b71a718, authored by Jackwaterveg, committed by GitHub

[NPU] Add square grad (#34889)

* test=develop

* test=develop
Parent 40f62737
@@ -386,6 +386,35 @@ class SquareNPUKernel : public framework::OpKernel<T> {
  }
};
template <typename DeviceContext, typename T>
class SquareGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    // Gradient of y = x^2 is dy/dx = 2 * x, so dx = dout * 2 * x.
    auto factor = static_cast<float>(2.0);

    auto place = ctx.GetPlace();
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    // Step 1: Compute x_muls_factor = factor * x
    Tensor x_muls_factor(x->type());
    x_muls_factor.mutable_data<T>(x->dims(), place);
    const auto& runner_muls_1 =
        NpuOpRunner("Muls", {*x}, {x_muls_factor}, {{"value", factor}});
    runner_muls_1.Run(stream);

    // Step 2: Compute dx = dout * factor * x
    dx->mutable_data<T>(place);
    const auto& runner_mul_2 =
        NpuOpRunner("Mul", {*dout, x_muls_factor}, {*dx}, {});
    runner_mul_2.Run(stream);
  }
};
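For intuition, here is a minimal NumPy sketch (hypothetical, not part of this commit) of what SquareGradNPUKernel computes: since y = x^2, the chain rule gives dx = dout * 2x, which the kernel realizes as a scalar Muls followed by an element-wise Mul.

```python
import numpy as np

def square_grad_reference(x, dout, factor=2.0):
    """Mirror of SquareGradNPUKernel: dx = dout * (factor * x)."""
    x_muls_factor = factor * x   # Step 1: Muls(x, value=factor)
    return dout * x_muls_factor  # Step 2: Mul(dout, x_muls_factor)

x = np.array([1.0, -2.0, 3.0], dtype=np.float32)
dout = np.ones_like(x)
print(square_grad_reference(x, dout))  # [ 2. -4.  6.] == 2 * x
```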
template <typename DeviceContext, typename T>
class SigmoidNPUKernel : public framework::OpKernel<T> {
 public:
@@ -869,6 +898,12 @@ REGISTER_OP_NPU_KERNEL(
                         paddle::platform::float16>,
    ops::SquareNPUKernel<paddle::platform::NPUDeviceContext, int>);

REGISTER_OP_NPU_KERNEL(
    square_grad,
    ops::SquareGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::SquareGradNPUKernel<paddle::platform::NPUDeviceContext,
                             paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
    sigmoid, ops::SigmoidNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::SigmoidNPUKernel<paddle::platform::NPUDeviceContext,
......
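Once square_grad is registered, gradients of paddle.square are dispatched to this kernel on NPU devices. A hedged usage sketch (assumes a PaddlePaddle build with NPU support; the device string is an assumption and may vary by build):

```python
import paddle

paddle.set_device("npu:0")  # assumption: NPU-enabled Paddle build

x = paddle.to_tensor([1.0, -2.0, 3.0], stop_gradient=False)
y = paddle.square(x)        # forward: SquareNPUKernel
y.sum().backward()          # backward: dispatches to SquareGradNPUKernel
print(x.grad)               # [2., -4., 6.], i.e. 2 * x
```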
@@ -50,12 +50,10 @@ class TestSquare(OpTest):
    def test_check_output(self):
        self.check_output_with_place(self.place)

    # TODO(ascendrc): Add grad test
    # def test_check_grad(self):
    #     if self.dtype == np.float16:
    #         return
    #     self.check_grad(['X'], 'Out')
    #
    def test_check_grad(self):
        if self.dtype == np.float16:
            return
        self.check_grad_with_place(self.place, ['X'], 'Out')
class TestSquareFp16(OpTest):
......
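The re-enabled test delegates to check_grad_with_place, which compares the kernel's analytic gradient against a finite-difference estimate; the early return for float16 skips the check because half precision is too coarse for that comparison. A standalone sketch of the same idea (helper names are hypothetical, not OpTest internals):

```python
import numpy as np

def numeric_grad(f, x, eps=1e-3):
    """Central-difference estimate of d(sum(f(x)))/dx."""
    grad = np.zeros_like(x)
    flat_x, flat_g = x.reshape(-1), grad.reshape(-1)
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + eps
        f_plus = f(x).sum()
        flat_x[i] = orig - eps
        f_minus = f(x).sum()
        flat_x[i] = orig
        flat_g[i] = (f_plus - f_minus) / (2 * eps)
    return grad

x = np.random.rand(4).astype(np.float64)
analytic = 2.0 * x                    # the formula SquareGradNPUKernel implements
numeric = numeric_grad(np.square, x)  # finite-difference reference
assert np.allclose(analytic, numeric, atol=1e-4)
```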