未验证 提交 cf408949 编写于 作者: F furnace 提交者: GitHub

[NPU] Add norm_grad kernel (#35237)

* [NPU] fix for test_norm_op_npu

* [NPU] add norm_grad

* [NPU] add CheckAxis for axis

* [NPU] delete debug codes

* norm can not use L2Normalize, norm_grad can use L2NormalizeGrad

* [NPU] delete useless codes

* [NPU] optimize norm_grad OpMaker

* Update python import path
上级 e928274c
......@@ -88,7 +88,11 @@ class NormOpGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
#ifndef PADDLE_WITH_ASCEND_CL
op->SetInput("Norm", this->Output("Norm"));
#else
op->SetInput("Out", this->Output("Out"));
#endif
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
......
......@@ -15,24 +15,26 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class NormNPUKernel : public framework::OpKernel<T> {
private:
void CheckAxis(int axis, int rank) const {
using DDim = framework::DDim;
using Tensor = framework::Tensor;
void CheckAxis(int axis, int rank) {
// check the axis is in [-rank, rank-1]
if (axis <= rank - 1 && axis >= -rank) return;
PADDLE_THROW(platform::errors::InvalidArgument(
"axis in norm operator must between (%d) and (%d)"
"but got (%d).",
-rank, rank - 1, axis));
}
}
template <typename DeviceContext, typename T>
class NormNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext &ctx) const override {
VLOG(4) << "Launch Norm Op Kernel on NPU." << std::endl;
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* out_y = ctx.Output<framework::Tensor>("Out");
auto* out_norm = ctx.Output<framework::Tensor>("Norm");
auto *in_x = ctx.Input<framework::Tensor>("X");
auto *out_y = ctx.Output<framework::Tensor>("Out");
auto *out_norm = ctx.Output<framework::Tensor>("Norm");
out_y->mutable_data<T>(ctx.GetPlace());
out_norm->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims();
......@@ -46,7 +48,7 @@ class NormNPUKernel : public framework::OpKernel<T> {
attr_input_norm["p"] = 2;
attr_input_norm["keepdim"] = true;
attr_input_norm["epsilon"] = eps;
const auto& runner =
const auto &runner =
NpuOpRunner("LpNorm", {*in_x}, {*out_norm}, attr_input_norm);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
......@@ -56,12 +58,48 @@ class NormNPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class NormGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
float epsilon = ctx.Attr<float>("epsilon");
int axis = ctx.Attr<int>("axis");
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<framework::Tensor>("Out");
auto *dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto xdim = x->dims();
CheckAxis(axis, xdim.size());
auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
framework::NPUAttributeMap attr_input_norm;
attr_input_norm["dim"] = std::vector<int>({axis});
attr_input_norm["eps"] = epsilon;
const auto &runner =
NpuOpRunner("L2NormalizeGrad", {*x, *y, *dy}, {*dx}, attr_input_norm);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
norm, ops::NormNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::NormNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>)
REGISTER_OP_NPU_KERNEL(
norm_grad, ops::NormGradNPUKernel<plat::NPUDeviceContext, float>,
ops::NormGradNPUKernel<plat::NPUDeviceContext, plat::float16>);
......@@ -20,26 +20,18 @@ import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
from paddle.fluid.tests.unittests.test_norm_op import l2_norm
SEED = 2021
def l2_norm(x, axis, epsilon):
x2 = x**2
s = np.sum(x2, axis=axis, keepdims=True)
r = np.sqrt(s) + epsilon
y = x / np.broadcast_to(r, x.shape)
return y, r
class TestNorm(OpTest):
class TestNPUNormOp(OpTest):
def setUp(self):
paddle.enable_static()
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "norm"
self.init_dtype()
self.init_test_case()
x = np.random.random(self.shape).astype(self.dtype)
y, norm = l2_norm(x, self.axis, self.epsilon)
......@@ -52,6 +44,8 @@ class TestNorm(OpTest):
def init_dtype(self):
self.dtype = np.float32
def init_test_case(self):
self.axis = 1
self.epsilon = 1e-10
self.shape = (2, 3, 4, 5)
......@@ -59,29 +53,50 @@ class TestNorm(OpTest):
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
if self.dtype == np.float16:
return
class TestNormOp2(TestNorm):
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.006)
class TestNPUNormOp2(TestNPUNormOp):
def init_test_case(self):
self.shape = [5, 3, 9, 7]
self.axis = 0
self.epsilon = 1e-8
self.dtype = np.float32
class TestNormOp3(TestNorm):
class TestNPUNormOp3(TestNPUNormOp):
def init_test_case(self):
self.shape = [5, 3, 2, 7]
self.axis = -1
self.epsilon = 1e-8
self.dtype = np.float32
class TestNormOp4(TestNorm):
@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
"however it is desirable to cover the forward pass")
class TestNPUNormOp4(TestNPUNormOp):
def init_test_case(self):
self.shape = [128, 1024, 14, 14]
self.axis = 2
self.epsilon = 1e-8
self.dtype = np.float32
def test_check_grad(self):
pass
@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
"however it is desirable to cover the forward pass")
class TestNPUNormOp5(TestNPUNormOp):
def init_test_case(self):
self.shape = [2048, 2048]
self.axis = 1
self.epsilon = 1e-8
def test_check_grad(self):
pass
class API_NormTest(unittest.TestCase):
......@@ -96,13 +111,15 @@ class API_NormTest(unittest.TestCase):
self.assertRaises(TypeError, test_norm_x_type)
class TestNormFP16(TestNorm):
class TestNPUNormOpFP16(TestNPUNormOp):
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def init_test_case(self):
self.axis = -1
self.epsilon = 1e-10
self.shape = (2, 3, 100)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册