Unverified commit ed478a3e, authored by zhulei, committed by GitHub

[NPU] Add p_norm_grad (#36497)

Parent: 7eab0fa6
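
For reference, the gradient this kernel implements: with y = (\sum_i |x_i|^p)^{1/p},
standard calculus gives

    \frac{\partial L}{\partial x_i}
        = \frac{\partial L}{\partial y}
          \cdot \operatorname{sign}(x_i)
          \cdot \frac{|x_i|^{p-1}}{y^{p-1}}

The special cases mirror the branches in the kernel below: for p = 0 the
gradient is taken to be zero, and for p = +/-inf it is sign(x_i) * dL/dy on
the elements where |x_i| attains the norm and zero elsewhere.
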
@@ -81,6 +81,122 @@ class PnormNPUKernel : public framework::OpKernel<T> {
}
};

template <typename DeviceContext, typename T>
class PnormGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Out");
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
auto xdim = x->dims();
float porder = ctx.Attr<float>("porder");
bool keepdim = ctx.Attr<bool>("keepdim");
int axis = ctx.Attr<int>("axis");
axis = axis < 0 ? xdim.size() + axis : axis;
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
    // Work on reshaped views of Out and dOut so they broadcast against X
    // along the reduced axis.
    Tensor y_share(y->type());
Tensor dy_share(dy->type());
y_share.ShareDataWith(*y);
dy_share.ShareDataWith(*dy);
auto ydim = xdim;
if (!keepdim) {
ydim[axis] = 1;
} else {
ydim = y->dims();
}
y_share.Resize(ydim);
dy_share.Resize(ydim);
    if (porder == 0) {
      // p = 0 counts the non-zero elements, which is not differentiable, so
      // the gradient is defined as zero everywhere.
      FillNpuTensorWithConstant(dx, static_cast<T>(0));
      dx->Resize(xdim);
    } else if (porder == INFINITY || porder == -INFINITY) {
      // p = +/-inf: the norm is the max (or min) of |x|, so the gradient
      // flows only through the elements whose absolute value attains it.
Tensor x_abs;
x_abs.mutable_data<T>(xdim, place);
const auto& r_abs = NpuOpRunner("Abs", {*x}, {x_abs}, {});
r_abs.Run(stream);
Tensor t_cond;
t_cond.mutable_data<bool>(xdim, place);
const auto& r_equal =
NpuOpRunner("Equal", {x_abs, y_share}, {t_cond}, {});
r_equal.Run(stream);
Tensor t_zero;
t_zero.mutable_data<T>({1}, place);
FillNpuTensorWithConstant(&t_zero, static_cast<T>(0));
Tensor x_sign;
x_sign.mutable_data<T>(xdim, place);
const auto& r_sign = NpuOpRunner("Sign", {*x}, {x_sign}, {});
r_sign.Run(stream);
const auto& r_mul = NpuOpRunner("Mul", {x_sign, dy_share}, {*dx}, {});
r_mul.Run(stream);
const auto& r_sel =
NpuOpRunner("SelectV2", {t_cond, *dx, t_zero}, {*dx}, {});
r_sel.Run(stream);
    } else {
      // General p: dx = dy * sign(x) * |x|^(p-1) / y^(p-1).
Tensor x_abs;
x_abs.mutable_data<T>(xdim, place);
const auto& r_abs = NpuOpRunner("Abs", {*x}, {x_abs}, {});
r_abs.Run(stream);
Tensor x_sign;
x_sign.mutable_data<T>(xdim, place);
const auto& r_sign = NpuOpRunner("Sign", {*x}, {x_sign}, {});
r_sign.Run(stream);
Tensor y_pow;
y_pow.mutable_data<T>(ydim, place);
      if (porder >= 1) {
        // p >= 1: compute |x|^(p-1) / y^(p-1) directly.
const auto& r_pow1 = NpuOpRunner(
"Power", {x_abs}, {x_abs},
{{"power", (porder - 1)}, {"scale", 1.0f}, {"shift", 0.0f}});
r_pow1.Run(stream);
const auto& r_pow2 = NpuOpRunner(
"Power", {y_share}, {y_pow},
{{"power", (porder - 1)}, {"scale", 1.0f}, {"shift", 0.0f}});
r_pow2.Run(stream);
const auto& r_div = NpuOpRunner("DivNoNan", {x_abs, y_pow}, {*dx}, {});
r_div.Run(stream);
      } else {
        // p < 1: compute the equivalent ratio y^(1-p) / |x|^(1-p) so the
        // exponent passed to Power stays positive.
const auto& r_pow1 = NpuOpRunner(
"Power", {x_abs}, {x_abs},
{{"power", (1 - porder)}, {"scale", 1.0f}, {"shift", 0.0f}});
r_pow1.Run(stream);
const auto& r_pow2 = NpuOpRunner(
"Power", {y_share}, {y_pow},
{{"power", (1 - porder)}, {"scale", 1.0f}, {"shift", 0.0f}});
r_pow2.Run(stream);
const auto& r_div = NpuOpRunner("DivNoNan", {y_pow, x_abs}, {*dx}, {});
r_div.Run(stream);
}
const auto& r_mul1 = NpuOpRunner("Mul", {*dx, x_sign}, {*dx}, {});
r_mul1.Run(stream);
const auto& r_mul2 = NpuOpRunner("Mul", {*dx, dy_share}, {*dx}, {});
r_mul2.Run(stream);
}
}
};

}  // namespace operators
}  // namespace paddle
@@ -90,3 +206,7 @@ namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
p_norm, ops::PnormNPUKernel<plat::NPUDeviceContext, float>,
ops::PnormNPUKernel<plat::NPUDeviceContext, plat::float16>);

REGISTER_OP_NPU_KERNEL(
p_norm_grad, ops::PnormGradNPUKernel<plat::NPUDeviceContext, float>,
ops::PnormGradNPUKernel<plat::NPUDeviceContext, plat::float16>);
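
A minimal usage sketch that exercises both kernels through dynamic-graph
autograd. This is not part of the commit; it assumes a PaddlePaddle build with
NPU support in which "npu:0" is a valid device string:

    import paddle

    paddle.set_device("npu:0")  # assumes an NPU-enabled build

    x = paddle.rand([2, 3, 4, 5], dtype="float32")
    x.stop_gradient = False

    # Forward dispatches to PnormNPUKernel; backward to PnormGradNPUKernel.
    y = paddle.norm(x, p=2.0, axis=1)
    y.sum().backward()

    print(x.grad.shape)  # [2, 3, 4, 5]
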
@@ -27,7 +27,6 @@ paddle.enable_static()
class TestPnormOp(OpTest):
    def set_npu(self):
        self.__class__.use_npu = True
        self.__class__.no_need_check_grad = True

    def setUp(self):
        self.set_npu()
@@ -51,6 +50,12 @@ class TestPnormOp(OpTest):
else:
self.check_output_with_place(paddle.NPUPlace(0))

    def test_check_grad(self):
        # Gradient checking is skipped for float16 inputs.
        if self.dtype == "float16":
            return
        self.check_grad_with_place(
            paddle.NPUPlace(0), ['X'], 'Out', user_defined_grads=self.gradient)

    def init_test_case(self):
self.shape = [2, 3, 4, 5]
self.axis = 1
@@ -131,6 +136,16 @@ class TestPnormOp5(TestPnormOp3):
self.init_dtype()


class TestPnormOp6(TestPnormOp3):
def init_test_case(self):
self.shape = [2, 3, 4, 5]
self.axis = 1
self.epsilon = 1e-12
self.porder = 0.5
self.keepdim = False
self.init_dtype()


class TestPnormOpfp16(TestPnormOp):
def init_dtype(self):
self.dtype = "float16"