From a3cc4a4a691d519233a9f9a1ac947692f2574e39 Mon Sep 17 00:00:00 2001 From: Meiyim Date: Mon, 15 Mar 2021 11:31:43 +0800 Subject: [PATCH] [NPU] Support npu op table_lookup_v2 and table_lookup_v2_grad (#31399) * [npu] support npu kernel `table_lookup_v2` * clean up * +python test * +cmake * clean up * remove int8 kernel + python unitest for fp16 * clean up --- paddle/fluid/operators/CMakeLists.txt | 3 + .../fluid/operators/lookup_table_v2_op_npu.cc | 80 ++++++++++ .../operators/lookup_table_v2_op_npu_test.cc | 142 ++++++++++++++++++ .../npu/test_lookup_table_v2_op_npu.py | 142 ++++++++++++++++++ 4 files changed, 367 insertions(+) create mode 100644 paddle/fluid/operators/lookup_table_v2_op_npu.cc create mode 100644 paddle/fluid/operators/lookup_table_v2_op_npu_test.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ad4e1cd55f..4797b0e715 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -156,6 +156,9 @@ cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) if (WITH_PYTHON) cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind) endif() +if (WITH_ASCEND_CL) + cc_test(lookup_table_v2_op_npu_test SRCS lookup_table_v2_op_npu_test.cc DEPS op_registry lookup_table_v2_op scope device_context enforce executor compare_op) +endif() if (WITH_ASCEND_CL) cc_test(range_op_npu_test SRCS range_op_npu_test.cc DEPS op_registry range_op scope device_context enforce executor) diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc new file mode 100644 index 0000000000..e7cc048ed3 --- /dev/null +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class LookupTableV2NPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *ids_t = ctx.Input("Ids"); // int tensor + auto *output_t = ctx.Output("Out"); // float tensor + auto *table_t = ctx.Input("W"); + auto *table_var = ctx.InputVar("W"); + PADDLE_ENFORCE_EQ( + table_var->IsType(), true, + platform::errors::InvalidArgument("npu only accept LoDTensor")); + output_t->mutable_data(ctx.GetPlace()); + framework::NPUAttributeMap attr_input = {{"validate_indices", false}}; + + auto runner = + NpuOpRunner("Gather", {*table_t, *ids_t}, {*output_t}, attr_input); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class LookupTableV2GradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *ids_t = ctx.Input("Ids"); + auto *output_grad_t = + ctx.Input(framework::GradVarName("Out")); + auto *table_t = ctx.Input("W"); + auto *table_grad_t = + ctx.Output(framework::GradVarName("W")); + framework::NPUAttributeMap attr_input = {{"use_locking", true}}; + + auto runner = NpuOpRunner("ScatterAdd", {*table_t, *ids_t, *output_grad_t}, + {*table_grad_t}, attr_input); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + lookup_table_v2, + ops::LookupTableV2NPUKernel, + ops::LookupTableV2NPUKernel); + +REGISTER_OP_NPU_KERNEL( + lookup_table_v2_grad, ops::LookupTableV2GradNPUKernel, + ops::LookupTableV2GradNPUKernel); diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc new file mode 100644 index 0000000000..f37915834b --- /dev/null +++ b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc @@ -0,0 +1,142 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef _WIN32 +#include +#endif + +#include +#include +#include +#include +#include // NOLINT +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + +USE_OP(lookup_table_v2); +USE_OP_DEVICE_KERNEL(lookup_table_v2, NPU); + +template +void Compare(f::Scope* scope, const p::DeviceContext& ctx) { + // init + auto ids = scope->Var("Ids"); + auto out = scope->Var("Out"); + auto w = scope->Var("W"); + + auto ids_t = ids->GetMutable(); + auto out_t = out->GetMutable(); + auto w_t = w->GetMutable(); + int bsz = 10; + int dim = 32; + int seqlen = 8; + int vocab_size = 100; + TensorFromVector(std::vector(bsz * seqlen, 3), ctx, ids_t); + std::vector val(vocab_size * dim, 10.); + TensorFromVector(val, ctx, w_t); + ids_t->Resize({bsz, seqlen}); + w_t->Resize({vocab_size, dim}); + out_t->Resize({bsz, seqlen, dim}); + ctx.Wait(); + + auto place = ctx.GetPlace(); + out_t->mutable_data(place); + f::AttributeMap attrs = {{}}; + auto op = f::OpRegistry::CreateOp("lookup_table_v2", + {{"W", {"W"}}, {"Ids", {"Ids"}}}, + {{"Out", {"Out"}}}, attrs); + op->Run(*scope, place); + std::vector out_v; + TensorToVector(*out_t, ctx, &out_v); + ctx.Wait(); + EXPECT_EQ(out_t->numel(), bsz * seqlen * dim); + T res = std::accumulate(out_v.begin(), out_v.end(), 0.); + float eps = 1.e-6; + EXPECT_LT(fabs(res - bsz * seqlen * dim * 10.), eps); +} + +template +void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) { + // init + auto w = scope->Var("W"); + auto ids = scope->Var("Ids"); + auto out = scope->Var("DOut"); + auto dw = scope->Var("DW"); + + auto w_t = w->GetMutable(); + auto ids_t = ids->GetMutable(); + auto out_t = out->GetMutable(); + auto dw_t = dw->GetMutable(); + + int bsz = 2; + int dim = 2; + int seqlen = 2; + int vocab_size = 4; + + std::vector val_int(bsz * seqlen, 3); + std::vector val(vocab_size * dim, 0.); + std::vector val_out(bsz * seqlen * dim, 1.); + + TensorFromVector(val_int, ctx, ids_t); + TensorFromVector(val, ctx, w_t); + TensorFromVector(val, ctx, dw_t); + TensorFromVector(val_out, ctx, out_t); + + w_t->Resize({vocab_size, dim}); + ids_t->Resize({bsz, seqlen}); + out_t->Resize({bsz, seqlen, dim}); + dw_t->Resize({vocab_size, dim}); + + ctx.Wait(); + + auto place = ctx.GetPlace(); + out_t->mutable_data(place); + w_t->mutable_data(place); + dw_t->mutable_data(place); + f::AttributeMap attrs = {{}}; + auto op = f::OpRegistry::CreateOp( + "lookup_table_v2_grad", + {{"Ids", {"Ids"}}, {"W", {"W"}}, {"Out@GRAD", {"DOut"}}}, + {{"W@GRAD", {"DW"}}}, attrs); + op->Run(*scope, place); + ctx.Wait(); + std::vector w_v; + TensorToVector(*dw_t, ctx, &w_v); + ctx.Wait(); + EXPECT_EQ(dw_t->numel(), vocab_size * dim); + T res = std::accumulate(w_v.begin(), w_v.end(), 0.); + float eps = 1.e-6; + EXPECT_LT(fabs(res - bsz * seqlen * dim), eps); +} + +TEST(lookup_table_v2, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx); +} + +TEST(lookup_table_v2_grad, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + CompareGrad(&scope, ctx); +} diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py new file mode 100644 index 0000000000..99016e5d62 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLookupTableV2(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "lookup_table_v2" + self.place = paddle.NPUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + bsz=2 + seqlen=2 + vocab=3 + dim=2 + w = np.ones([vocab, dim]).astype(self.dtype) + x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64) + out = np.ones([bsz, seqlen, dim]).astype(self.dtype) + + self.inputs = {'W': OpTest.np_dtype_to_fluid_dtype(w), 'Ids': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = { + 'is_sparse': False, + 'is_distributed': False, + 'remote_prefetch':False, + 'padding_idx': -1 + } + self.outputs = {'Out': out} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + # TODO(ascendrc): Add grad test + # def test_check_grad(self): + # if self.dtype == np.float16: + # return + # self.check_grad(['X'], 'Out') + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLookupTableV2FP16(TestLookupTableV2): + no_need_check_grad = True + def init_dtype(self): + self.dtype = np.float16 + +#@unittest.skipIf(not paddle.is_compiled_with_npu(), +# "core is not compiled with NPU") +#class TestLookupTableV2Int8(TestLookupTableV2): +# def init_dtype(self): +# self.dtype = np.int8 +# +#@unittest.skipIf(not paddle.is_compiled_with_npu(), +# "core is not compiled with NPU") +#class TestLookupTableV2UInt8(TestLookupTableV2): +# def init_dtype(self): +# self.dtype = np.uint8 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLookupTableV2Net(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + bsz=3 + seqlen=2 + vocab=3 + dim=2 + + ids_np = np.random.randint(0, vocab, size=(bsz, seqlen)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + emb = paddle.nn.Embedding(vocab, dim) + ids = paddle.static.data(name="ids", shape=[bsz, seqlen], dtype='int64') + res = emb(ids) + loss = res.sum() + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + for epoch in range(1): + loss_res, w = exe.run( + main_prog, + feed={"ids": ids_np}, + fetch_list=[loss, emb.weight]) + if epoch % 10 == 0: + print(w) + print("Epoch {} | Loss: {}".format(epoch, loss)) + + return loss_res + + def test_npu(self): + cpu_loss = self._test(False) + npu_loss = self._test(True) + self.assertTrue(np.allclose(npu_loss, cpu_loss)) + + + +if __name__ == '__main__': + unittest.main() + -- GitLab