From fead5631564b4363c44ff6ec49a80a34eb5e4939 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 24 Mar 2021 15:26:55 +0800 Subject: [PATCH] [NPU] fix bug of lookup_table_v2_grad (#31834) --- .../fluid/operators/controlflow/fetch_op.cc | 5 + .../fluid/operators/lookup_table_v2_op_npu.cc | 19 +++- .../npu/test_lookup_table_v2_op_npu.py | 91 ++++--------------- 3 files changed, 38 insertions(+), 77 deletions(-) diff --git a/paddle/fluid/operators/controlflow/fetch_op.cc b/paddle/fluid/operators/controlflow/fetch_op.cc index d86b6b4842..fdd1b776bd 100644 --- a/paddle/fluid/operators/controlflow/fetch_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_op.cc @@ -44,6 +44,11 @@ static void DataCopy(const framework::LoDTensor &src_item, TensorCopySync(src_item, platform::CPUPlace(), dst_item); } #else +#ifdef PADDLE_WITH_ASCEND_CL + if (platform::is_npu_place(src_item.place())) { + platform::DeviceContextPool::Instance().Get(src_item.place())->Wait(); + } +#endif TensorCopySync(src_item, platform::CPUPlace(), dst_item); #endif } else { diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index 8ab4d70fd3..fab2d7f7aa 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -51,18 +51,27 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel { auto *ids_t = ctx.Input("Ids"); auto *output_grad_t = ctx.Input(framework::GradVarName("Out")); - auto *table_t = ctx.Input("W"); auto *table_grad_t = ctx.Output(framework::GradVarName("W")); table_grad_t->mutable_data(ctx.GetPlace()); - framework::NPUAttributeMap attr_input = {{"use_locking", true}}; - auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t}, - {*table_grad_t}, attr_input); auto stream = ctx.template device_context() .stream(); - runner.Run(stream); + + // step2: ZerosLike x in device + Tensor zeroslike_w(table_grad_t->type()); + zeroslike_w.Resize(table_grad_t->dims()); + auto p = zeroslike_w.mutable_data(ctx.GetPlace()); + + platform::NPUMemsetAsync(static_cast(p), 0, + zeroslike_w.numel() * sizeof(T), stream); + + table_grad_t->mutable_data(ctx.GetPlace()); + auto runner_scatter = + NpuOpRunner("ScatterAdd", {zeroslike_w, *ids_t, *output_grad_t}, + {*table_grad_t}, {}); + runner_scatter.Run(stream); } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py index 99016e5d62..400ddd9d4a 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py @@ -36,19 +36,22 @@ class TestLookupTableV2(OpTest): self.init_dtype() np.random.seed(SEED) - bsz=2 - seqlen=2 - vocab=3 - dim=2 + bsz = 6 + seqlen = 8 + vocab = 10 + dim = 20 w = np.ones([vocab, dim]).astype(self.dtype) x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64) out = np.ones([bsz, seqlen, dim]).astype(self.dtype) - self.inputs = {'W': OpTest.np_dtype_to_fluid_dtype(w), 'Ids': OpTest.np_dtype_to_fluid_dtype(x)} + self.inputs = { + 'W': OpTest.np_dtype_to_fluid_dtype(w), + 'Ids': OpTest.np_dtype_to_fluid_dtype(x) + } self.attrs = { 'is_sparse': False, 'is_distributed': False, - 'remote_prefetch':False, + 'remote_prefetch': False, 'padding_idx': -1 } self.outputs = {'Out': out} @@ -62,81 +65,25 @@ class TestLookupTableV2(OpTest): def test_check_output(self): self.check_output_with_place(self.place, check_dygraph=False) - # TODO(ascendrc): Add grad test - # def test_check_grad(self): - # if self.dtype == np.float16: - # return - # self.check_grad(['X'], 'Out') + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad_with_place( + self.place, ['W'], 'Out', check_dygraph=False) + @unittest.skipIf(not paddle.is_compiled_with_npu(), "core is not compiled with NPU") class TestLookupTableV2FP16(TestLookupTableV2): no_need_check_grad = True + def init_dtype(self): self.dtype = np.float16 - -#@unittest.skipIf(not paddle.is_compiled_with_npu(), -# "core is not compiled with NPU") -#class TestLookupTableV2Int8(TestLookupTableV2): -# def init_dtype(self): -# self.dtype = np.int8 -# -#@unittest.skipIf(not paddle.is_compiled_with_npu(), -# "core is not compiled with NPU") -#class TestLookupTableV2UInt8(TestLookupTableV2): -# def init_dtype(self): -# self.dtype = np.uint8 - - -@unittest.skipIf(not paddle.is_compiled_with_npu(), - "core is not compiled with NPU") -class TestLookupTableV2Net(unittest.TestCase): - def _test(self, run_npu=True): - main_prog = paddle.static.Program() - startup_prog = paddle.static.Program() - main_prog.random_seed = SEED - startup_prog.random_seed = SEED - np.random.seed(SEED) - - bsz=3 - seqlen=2 - vocab=3 - dim=2 - - ids_np = np.random.randint(0, vocab, size=(bsz, seqlen)).astype('int64') - - with paddle.static.program_guard(main_prog, startup_prog): - emb = paddle.nn.Embedding(vocab, dim) - ids = paddle.static.data(name="ids", shape=[bsz, seqlen], dtype='int64') - res = emb(ids) - loss = res.sum() - - if run_npu: - place = paddle.NPUPlace(0) - else: - place = paddle.CPUPlace() - - exe = paddle.static.Executor(place) - exe.run(startup_prog) - - for epoch in range(1): - loss_res, w = exe.run( - main_prog, - feed={"ids": ids_np}, - fetch_list=[loss, emb.weight]) - if epoch % 10 == 0: - print(w) - print("Epoch {} | Loss: {}".format(epoch, loss)) - - return loss_res - - def test_npu(self): - cpu_loss = self._test(False) - npu_loss = self._test(True) - self.assertTrue(np.allclose(npu_loss, cpu_loss)) + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True if __name__ == '__main__': unittest.main() - -- GitLab