未验证 提交 4d842050 编写于 作者: P pangyoki 提交者: GitHub

[NPU] change ScatterAdd to EmbeddingDenseGrad in lookup_table NPU op (#33866)

* change ScatterAdd to EmbeddingDenseGrad in lookup_table NPU op

* EmbeddingDenseGrad only supports dim 32

* fix shape error
上级 871edade
......@@ -65,6 +65,19 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
int embedding_dim = table_grad_t->dims()[1];
if (embedding_dim % 32 == 0) {
// NOTE(pangyoki): The embedding_dim of Tensor used in
// EmbeddingDenseGrad must be an integer multiple of 32.
int num_weights = table_grad_t->dims()[0];
const auto &runner =
NpuOpRunner("EmbeddingDenseGrad", {*output_grad_t, *ids_t},
{*table_grad_t}, {{"num_weights", num_weights},
{"padding_idx", -1},
{"scale_grad_by_freq", false}});
runner.Run(stream);
} else {
const auto &runner_zeros =
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream);
......@@ -77,6 +90,7 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
{*table_grad_t}, {{"use_locking", true}});
runner_scatter.Run(stream);
}
}
};
} // namespace operators
} // namespace paddle
......
......@@ -35,14 +35,14 @@ class TestLookupTableV2(OpTest):
self.place = paddle.NPUPlace(0)
self.init_dtype()
self.init_dim()
np.random.seed(SEED)
bsz = 6
seqlen = 8
vocab = 10
dim = 20
w = np.ones([vocab, dim]).astype(self.dtype)
w = np.ones([vocab, self.dim]).astype(self.dtype)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32)
out = np.ones([bsz, seqlen, dim]).astype(self.dtype)
out = np.ones([bsz, seqlen, self.dim]).astype(self.dtype)
self.inputs = {
'W': OpTest.np_dtype_to_fluid_dtype(w),
......@@ -62,6 +62,10 @@ class TestLookupTableV2(OpTest):
def init_dtype(self):
self.dtype = np.float32
def init_dim(self):
# embedding_dim is not multiple of 32
self.dim = 20
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)
......@@ -85,5 +89,29 @@ class TestLookupTableV2FP16(TestLookupTableV2):
self.__class__.no_need_check_grad = True
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestLookupTableV2Dim32(TestLookupTableV2):
def init_dim(self):
# embedding_dim is multiple of 32
self.dim = 64
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestLookupTableV2Dim32FP16(TestLookupTableV2):
no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def init_dim(self):
self.dim = 64
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册