Unverified commit fead5631, authored by Leo Chen, committed by GitHub

[NPU] fix bug of lookup_table_v2_grad (#31834)

Parent: 149f76e6
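The fix targets the W-gradient of the NPU embedding (lookup_table_v2) kernel: as the second hunk below shows, the old code passed the embedding table `W` itself as the base tensor of CANN's in-place `ScatterAdd`, so the returned gradient was effectively `W + scatter(dOut)` rather than `scatter(dOut)`. A minimal NumPy sketch of the intended semantics and the off-by-`W` error (names such as `dW`/`dOut` are illustrative, not Paddle API):

```python
import numpy as np

# Shapes mirror the updated unit test below.
vocab, dim, bsz, seqlen = 10, 20, 6, 8
rng = np.random.default_rng(0)

W = rng.standard_normal((vocab, dim)).astype(np.float32)
ids = rng.integers(0, vocab, size=(bsz, seqlen))
dOut = rng.standard_normal((bsz, seqlen, dim)).astype(np.float32)

# Correct gradient: zero-initialized buffer, then scatter-add rows of dOut.
dW = np.zeros_like(W)
np.add.at(dW, ids.reshape(-1), dOut.reshape(-1, dim))

# The pre-fix kernel effectively used W as the scatter-add base,
# so its result was off by exactly W:
dW_buggy = W.copy()
np.add.at(dW_buggy, ids.reshape(-1), dOut.reshape(-1, dim))
assert np.allclose(dW_buggy - dW, W)
```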
@@ -44,6 +44,11 @@ static void DataCopy(const framework::LoDTensor &src_item,
      TensorCopySync(src_item, platform::CPUPlace(), dst_item);
    }
 #else
+#ifdef PADDLE_WITH_ASCEND_CL
+    if (platform::is_npu_place(src_item.place())) {
+      platform::DeviceContextPool::Instance().Get(src_item.place())->Wait();
+    }
+#endif
     TensorCopySync(src_item, platform::CPUPlace(), dst_item);
 #endif
   } else {
......
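Context for this first hunk: NPU work such as kernels and `NPUMemsetAsync` is queued asynchronously on a stream, while the fetch path copies to CPU synchronously, so the copy could read a buffer the device had not finished writing; the added `Wait()` drains the queue first. A toy host-side sketch of the same ordering hazard, with a single-worker thread pool standing in for the stream (illustrative only; the real call is `DeviceContextPool::Instance().Get(place)->Wait()`):

```python
import time
from concurrent.futures import ThreadPoolExecutor

buf = [0.0]                      # stands in for device memory
stream = ThreadPoolExecutor(1)   # a stream runs queued work in order

# "Async kernel/memset": queued now, finishes later.
fut = stream.submit(lambda: (time.sleep(0.1), buf.__setitem__(0, 42.0)))

# Reading buf here without waiting may observe the stale 0.0.
fut.result()                     # the analogue of dev_ctx->Wait()
print(buf[0])                    # guaranteed 42.0 after the wait
```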
@@ -51,18 +51,27 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
     auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
     auto *output_grad_t =
         ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-    auto *table_t = ctx.Input<framework::LoDTensor>("W");
     auto *table_grad_t =
         ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
-    table_grad_t->mutable_data<T>(ctx.GetPlace());
-
-    framework::NPUAttributeMap attr_input = {{"use_locking", true}};
-    auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t},
-                              {*table_grad_t}, attr_input);
+
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
-    runner.Run(stream);
+
+    // step2: ZerosLike x in device
+    Tensor zeroslike_w(table_grad_t->type());
+    zeroslike_w.Resize(table_grad_t->dims());
+    auto p = zeroslike_w.mutable_data<T>(ctx.GetPlace());
+
+    platform::NPUMemsetAsync(static_cast<void *>(p), 0,
+                             zeroslike_w.numel() * sizeof(T), stream);
+
+    table_grad_t->mutable_data<T>(ctx.GetPlace());
+    auto runner_scatter =
+        NpuOpRunner("ScatterAdd", {zeroslike_w, *ids_t, *output_grad_t},
+                    {*table_grad_t}, {});
+    runner_scatter.Run(stream);
   }
 };

 }  // namespace operators
......
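Two details of the rewritten kernel: the zero fill via `NPUMemsetAsync` is issued on the same stream that later runs `ScatterAdd`, so it is ordered before the scatter, and `ScatterAdd` accumulates when `Ids` contains duplicates, which is exactly what an embedding gradient needs. A NumPy sketch of the duplicate-index behavior (`np.add.at` is the unbuffered equivalent):

```python
import numpy as np

var = np.zeros((5, 3), dtype=np.float32)     # the zeroslike_w analogue
ids = np.array([1, 1, 4])                    # note the duplicate row id
updates = np.ones((3, 3), dtype=np.float32)  # one update row per id

# Plain fancy-index assignment would drop one of the duplicate updates;
# np.add.at accumulates them, matching ScatterAdd semantics.
np.add.at(var, ids, updates)
print(var[1])   # [2. 2. 2.]  two updates landed on row 1
print(var[4])   # [1. 1. 1.]
```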
@@ -36,19 +36,22 @@ class TestLookupTableV2(OpTest):
         self.init_dtype()
         np.random.seed(SEED)
-        bsz=2
-        seqlen=2
-        vocab=3
-        dim=2
+        bsz = 6
+        seqlen = 8
+        vocab = 10
+        dim = 20
         w = np.ones([vocab, dim]).astype(self.dtype)
         x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64)
         out = np.ones([bsz, seqlen, dim]).astype(self.dtype)
-        self.inputs = {'W': OpTest.np_dtype_to_fluid_dtype(w), 'Ids': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.inputs = {
+            'W': OpTest.np_dtype_to_fluid_dtype(w),
+            'Ids': OpTest.np_dtype_to_fluid_dtype(x)
+        }
         self.attrs = {
             'is_sparse': False,
             'is_distributed': False,
-            'remote_prefetch':False,
+            'remote_prefetch': False,
             'padding_idx': -1
         }
         self.outputs = {'Out': out}
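For reference, the forward pass of `lookup_table_v2` is a row gather, which is why the expected output here is all ones whenever `W` is all ones, regardless of the sampled ids. A NumPy sketch with the test's new shapes:

```python
import numpy as np

vocab, dim, bsz, seqlen = 10, 20, 6, 8
w = np.ones([vocab, dim], dtype=np.float32)
x = np.random.randint(0, vocab, size=(bsz, seqlen))

out = w[x]    # lookup_table_v2 forward: gather rows of W by id
assert out.shape == (bsz, seqlen, dim)
assert np.allclose(out, np.ones([bsz, seqlen, dim]))
```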
@@ -62,81 +65,25 @@ class TestLookupTableV2(OpTest):
     def test_check_output(self):
         self.check_output_with_place(self.place, check_dygraph=False)

-    # TODO(ascendrc): Add grad test
-    # def test_check_grad(self):
-    #     if self.dtype == np.float16:
-    #         return
-    #     self.check_grad(['X'], 'Out')
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['W'], 'Out', check_dygraph=False)
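`check_grad_with_place` compares the kernel's analytic `dW` against a numeric gradient. A hand-rolled sketch of that comparison for a sum-of-embeddings loss (shapes, `eps`, and the tolerance are arbitrary; OpTest automates this with its own settings):

```python
import numpy as np

vocab, dim, eps = 4, 3, 1e-3
W = np.random.rand(vocab, dim)
ids = np.array([0, 2, 2])                  # duplicate id on purpose

def loss(W):
    return W[ids].sum()                    # forward lookup + sum

# Analytic dW: zeros plus a ones-row for each occurrence of a row id.
dW = np.zeros_like(W)
np.add.at(dW, ids, np.ones((ids.size, dim)))

# Numeric dW via central differences, one coordinate at a time.
num = np.zeros_like(W)
for i in range(vocab):
    for j in range(dim):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        num[i, j] = (loss(Wp) - loss(Wm)) / (2 * eps)

assert np.allclose(num, dW, atol=1e-6)
```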
 @unittest.skipIf(not paddle.is_compiled_with_npu(),
                  "core is not compiled with NPU")
 class TestLookupTableV2FP16(TestLookupTableV2):
+    no_need_check_grad = True
+
     def init_dtype(self):
         self.dtype = np.float16
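`no_need_check_grad = True` marks the FP16 case so the OpTest machinery skips gradient checking (and `test_check_grad` above also returns early for `np.float16`). That is the usual choice for half precision, where a finite-difference perturbation can vanish entirely in rounding:

```python
import numpy as np

# float16 has a machine epsilon of ~9.8e-4, so a 1e-4 perturbation
# to 1.0 is rounded away and a numeric gradient would read as zero.
print(np.float16(1.0) + np.float16(1e-4) == np.float16(1.0))   # True
```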
-#@unittest.skipIf(not paddle.is_compiled_with_npu(),
-#                 "core is not compiled with NPU")
-#class TestLookupTableV2Int8(TestLookupTableV2):
-#    def init_dtype(self):
-#        self.dtype = np.int8
-#
-#@unittest.skipIf(not paddle.is_compiled_with_npu(),
-#                 "core is not compiled with NPU")
-#class TestLookupTableV2UInt8(TestLookupTableV2):
-#    def init_dtype(self):
-#        self.dtype = np.uint8
 @unittest.skipIf(not paddle.is_compiled_with_npu(),
                  "core is not compiled with NPU")
 class TestLookupTableV2Net(unittest.TestCase):
     def _test(self, run_npu=True):
         main_prog = paddle.static.Program()
         startup_prog = paddle.static.Program()
         main_prog.random_seed = SEED
         startup_prog.random_seed = SEED
         np.random.seed(SEED)

         bsz=3
         seqlen=2
         vocab=3
         dim=2

         ids_np = np.random.randint(0, vocab, size=(bsz, seqlen)).astype('int64')

         with paddle.static.program_guard(main_prog, startup_prog):
             emb = paddle.nn.Embedding(vocab, dim)
             ids = paddle.static.data(name="ids", shape=[bsz, seqlen], dtype='int64')
             res = emb(ids)
             loss = res.sum()

         if run_npu:
             place = paddle.NPUPlace(0)
         else:
             place = paddle.CPUPlace()

         exe = paddle.static.Executor(place)
         exe.run(startup_prog)

         for epoch in range(1):
             loss_res, w = exe.run(main_prog,
                                   feed={"ids": ids_np},
                                   fetch_list=[loss, emb.weight])
             if epoch % 10 == 0:
                 print(w)
                 print("Epoch {} | Loss: {}".format(epoch, loss))

         return loss_res

     def test_npu(self):
         cpu_loss = self._test(False)
         npu_loss = self._test(True)
         self.assertTrue(np.allclose(npu_loss, cpu_loss))

     def set_npu(self):
         self.__class__.use_npu = True
         self.__class__.no_need_check_grad = True


 if __name__ == '__main__':
     unittest.main()
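The net test above only compares the forward loss between CPU and NPU. As a final cross-check of the property this commit restores, a hedged dygraph sketch (standard Paddle 2.x imperative API, runnable on any device): the embedding weight gradient should equal the zeros-plus-scatter-add reference:

```python
import numpy as np
import paddle

vocab, dim = 3, 2
ids_np = np.array([[0, 2], [2, 1], [0, 0]])

emb = paddle.nn.Embedding(vocab, dim)
loss = emb(paddle.to_tensor(ids_np)).sum()
loss.backward()

# Reference dW: zero buffer, then scatter-add a ones-row per looked-up id.
ref = np.zeros((vocab, dim), dtype=np.float32)
np.add.at(ref, ids_np.reshape(-1), np.ones((ids_np.size, dim), np.float32))

grad = emb.weight.grad
grad = grad.numpy() if hasattr(grad, "numpy") else np.asarray(grad)
assert np.allclose(grad, ref)
```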