未验证 提交 4d842050 编写于 作者: P pangyoki 提交者: GitHub

[NPU] change ScatterAdd to EmbeddingDenseGrad in lookup_table NPU op (#33866)

* change ScatterAdd to EmbeddingDenseGrad in lookup_table NPU op

* EmbeddingDenseGrad only supports dim 32

* fix shape error
上级 871edade
...@@ -65,17 +65,31 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> { ...@@ -65,17 +65,31 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
const auto &runner_zeros = int embedding_dim = table_grad_t->dims()[1];
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream); if (embedding_dim % 32 == 0) {
// NOTE(pangyoki): The embedding_dim of Tensor used in
// NOTE(zhiqiu): It seems in cann 20.1, the first input and output // EmbeddingDenseGrad must be an integer multiple of 32.
// can be different tensor, but in cann 20.2+, it does inplace operation. int num_weights = table_grad_t->dims()[0];
// Thus, the first input and output should be same tensor. const auto &runner =
const auto &runner_scatter = NpuOpRunner("EmbeddingDenseGrad", {*output_grad_t, *ids_t},
NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t}, {*table_grad_t}, {{"num_weights", num_weights},
{*table_grad_t}, {{"use_locking", true}}); {"padding_idx", -1},
runner_scatter.Run(stream); {"scale_grad_by_freq", false}});
runner.Run(stream);
} else {
const auto &runner_zeros =
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream);
// NOTE(zhiqiu): It seems in cann 20.1, the first input and output
// can be different tensor, but in cann 20.2+, it does inplace operation.
// Thus, the first input and output should be same tensor.
const auto &runner_scatter =
NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t},
{*table_grad_t}, {{"use_locking", true}});
runner_scatter.Run(stream);
}
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -35,14 +35,14 @@ class TestLookupTableV2(OpTest): ...@@ -35,14 +35,14 @@ class TestLookupTableV2(OpTest):
self.place = paddle.NPUPlace(0) self.place = paddle.NPUPlace(0)
self.init_dtype() self.init_dtype()
self.init_dim()
np.random.seed(SEED) np.random.seed(SEED)
bsz = 6 bsz = 6
seqlen = 8 seqlen = 8
vocab = 10 vocab = 10
dim = 20 w = np.ones([vocab, self.dim]).astype(self.dtype)
w = np.ones([vocab, dim]).astype(self.dtype)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32) x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32)
out = np.ones([bsz, seqlen, dim]).astype(self.dtype) out = np.ones([bsz, seqlen, self.dim]).astype(self.dtype)
self.inputs = { self.inputs = {
'W': OpTest.np_dtype_to_fluid_dtype(w), 'W': OpTest.np_dtype_to_fluid_dtype(w),
...@@ -62,6 +62,10 @@ class TestLookupTableV2(OpTest): ...@@ -62,6 +62,10 @@ class TestLookupTableV2(OpTest):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
def init_dim(self):
# embedding_dim is not multiple of 32
self.dim = 20
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False) self.check_output_with_place(self.place, check_dygraph=False)
...@@ -85,5 +89,29 @@ class TestLookupTableV2FP16(TestLookupTableV2): ...@@ -85,5 +89,29 @@ class TestLookupTableV2FP16(TestLookupTableV2):
self.__class__.no_need_check_grad = True self.__class__.no_need_check_grad = True
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestLookupTableV2Dim32(TestLookupTableV2):
def init_dim(self):
# embedding_dim is multiple of 32
self.dim = 64
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestLookupTableV2Dim32FP16(TestLookupTableV2):
no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def init_dim(self):
self.dim = 64
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册