From e928274c0d90a44d83137ab8f5b9e38feeffe1ee Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Tue, 7 Sep 2021 15:51:47 +0800
Subject: [PATCH] [NPU] log_softmax_grad, test=develop (#35484)

* [NPU] log_softmax_grad, test=develop

* remove debug files, test=develop

* update lookup_table_v2 for CANN 5.0.x, test=develop
---
 paddle/fluid/operators/log_softmax_op_npu.cc  | 53 +++++++++---
 .../fluid/operators/lookup_table_v2_op_npu.cc |  5 --
 .../unittests/npu/test_log_softmax_op_npu.py  | 80 ++++++++++++++++++-
 3 files changed, 117 insertions(+), 21 deletions(-)

diff --git a/paddle/fluid/operators/log_softmax_op_npu.cc b/paddle/fluid/operators/log_softmax_op_npu.cc
index d955bef6ce2..a2c3a1b323a 100644
--- a/paddle/fluid/operators/log_softmax_op_npu.cc
+++ b/paddle/fluid/operators/log_softmax_op_npu.cc
@@ -14,9 +14,13 @@
 #include "paddle/fluid/operators/log_softmax_op.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
+
 namespace paddle {
 namespace operators {
-template <typename DeviceContext, typename T>
+
+using NPUDeviceContext = platform::NPUDeviceContext;
+
+template <typename T>
 class LogSoftmaxNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -24,22 +28,47 @@ class LogSoftmaxNPUKernel : public framework::OpKernel<T> {
     auto* Out = ctx.Output<framework::Tensor>("Out");
     const int rank = X->dims().size();
     const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
-    std::vector<int> axes;
-    axes.push_back(axis);
-    framework::NPUAttributeMap attr_input = {{"axes", axes}};
     Out->mutable_data<T>(ctx.GetPlace());
-    const auto& runner = NpuOpRunner("LogSoftmaxV2", {*X}, {*Out}, attr_input);
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
-    runner.Run(stream);
+
+    if (X->numel() != 0) {
+      auto stream = ctx.template device_context<NPUDeviceContext>().stream();
+      const auto& runner = NpuOpRunner("LogSoftmaxV2", {*X}, {*Out},
+                                       {{"axes", std::vector<int>{axis}}});
+      runner.Run(stream);
+    }
   }
 };
+
+template <typename T>
+class LogSoftmaxGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* Out = ctx.Input<framework::Tensor>("Out");
+    auto* dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    const int rank = dOut->dims().size();
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
+
+    // allocate memory on device.
+    dX->mutable_data<T>(ctx.GetPlace());
+
+    if (dOut->numel() != 0) {
+      auto stream = ctx.template device_context<NPUDeviceContext>().stream();
+      const auto& runner = NpuOpRunner("LogSoftmaxGrad", {*dOut, *Out}, {*dX},
+                                       {{"axis", std::vector<int>{axis}}});
+      runner.Run(stream);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OP_NPU_KERNEL(
-    log_softmax,
-    ops::LogSoftmaxNPUKernel<plat::NPUDeviceContext, float>);
+REGISTER_OP_NPU_KERNEL(log_softmax, ops::LogSoftmaxNPUKernel<float>,
+                       ops::LogSoftmaxNPUKernel<plat::float16>);
+
+REGISTER_OP_NPU_KERNEL(log_softmax_grad, ops::LogSoftmaxGradNPUKernel<float>,
+                       ops::LogSoftmaxGradNPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc
index c65fa634070..387cd92b69f 100644
--- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc
@@ -29,11 +29,6 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
     auto *output_t = ctx.Output<framework::LoDTensor>("Out");  // float tensor
     auto *table_t = ctx.Input<framework::LoDTensor>("W");
 
-    // It seems cann 20.1 accepts int64, but cann 20.2+ not.
-    PADDLE_ENFORCE_EQ(ids_t->type(), framework::proto::VarType::INT32,
-                      platform::errors::Unimplemented(
-                          "The index of LookupTableV2 should be int32."));
-
     auto *table_var = ctx.InputVar("W");
     PADDLE_ENFORCE_EQ(
         table_var->IsType<framework::LoDTensor>(), true,
diff --git a/python/paddle/fluid/tests/unittests/npu/test_log_softmax_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_log_softmax_op_npu.py
index e8b680d1ddc..f6baefec7f2 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_log_softmax_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_log_softmax_op_npu.py
@@ -22,9 +22,10 @@ import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core
 import paddle.nn.functional as F
+
 from test_log_softmax import ref_log_softmax, ref_log_softmax_grad
+
 paddle.enable_static()
-np.random.seed(10)
 
 
 class TestLogSoftmaxNPUOp(OpTest):
@@ -55,10 +56,16 @@ class TestLogSoftmaxNPUOp(OpTest):
         pass
 
     def test_check_output(self):
-        self.check_output_with_place(self.place)
+        if self.dtype == np.float16:
+            self.check_output_with_place(self.place, atol=1e-2)
+        else:
+            self.check_output_with_place(self.place)
 
     def test_check_grad(self):
-        pass
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['X'], ['Out'], user_defined_grads=[self.x_grad])
 
 
 def test_class(op_type, typename):
@@ -88,8 +95,73 @@ def test_class2(op_type, typename):
     globals()[cls_name] = TestLogSoftmaxAxis
 
 
-for _typename in {'float32'}:
+for _typename in {np.float32, np.float16}:
     test_class("logsoftmax", _typename)
     test_class2("logsoftmax", _typename)
+
+
+class TestNNLogSoftmaxAPI(unittest.TestCase):
+    def setUp(self):
+        self.x_shape = [2, 3, 4, 5]
+        self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32)
+        self.place = paddle.NPUPlace(0) \
+            if paddle.fluid.core.is_compiled_with_npu() \
+            else paddle.CPUPlace()
+
+    def check_api(self, axis=-1):
+        ref_out = np.apply_along_axis(ref_log_softmax, axis, self.x)
+
+        logsoftmax = paddle.nn.LogSoftmax(axis)
+        # test static api
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data(name='x', shape=self.x_shape)
+            y = logsoftmax(x)
+            exe = paddle.static.Executor(self.place)
+            out = exe.run(feed={'x': self.x}, fetch_list=[y])
+            self.assertTrue(np.allclose(out[0], ref_out))
+
+        # test dygraph api
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x)
+        y = logsoftmax(x)
+        self.assertTrue(np.allclose(y.numpy(), ref_out))
+        paddle.enable_static()
+
+    def test_check_api(self):
+        for axis in [-1, 1]:
+            self.check_api(axis)
+
+
+class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase):
+    def setUp(self):
+        self.x_shape = [2, 3, 4, 5]
+        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
+        self.place = paddle.NPUPlace(0) \
+            if paddle.fluid.core.is_compiled_with_npu() \
+            else paddle.CPUPlace()
+
+    def check_api(self, axis=-1, dtype=None):
+        x = self.x.copy()
+        if dtype is not None:
+            x = x.astype(dtype)
+        ref_out = np.apply_along_axis(ref_log_softmax, axis, x)
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data(name='x', shape=self.x_shape)
+            y = F.log_softmax(x, axis, dtype)
+            exe = paddle.static.Executor(self.place)
+            out = exe.run(feed={'x': self.x}, fetch_list=[y])
+            self.assertTrue(np.allclose(out[0], ref_out))
+
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x)
+        y = F.log_softmax(x, axis, dtype)
+        self.assertTrue(np.allclose(y.numpy(), ref_out))
+        paddle.enable_static()
+
+    def test_check_api(self):
+        for axis in [-1, 1]:
+            self.check_api(axis)
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
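
A side note on the math behind the new kernel and test: the CANN "LogSoftmaxGrad" op invoked above takes {dOut, Out} and writes dX, and the test's user_defined_grads is built from the same backward rule. Below is a minimal NumPy sketch of that rule, assuming the reference helper imported from test_log_softmax follows the standard log-softmax backward formula; the function name is illustrative, not the actual ref_log_softmax_grad:

    import numpy as np

    def log_softmax_grad_sketch(out, dout, axis=-1):
        # out is log_softmax(x), so exp(out) recovers the softmax probabilities
        # and dx = dout - softmax(x) * sum(dout along `axis`).
        return dout - np.exp(out) * dout.sum(axis=axis, keepdims=True)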