From 853af66fc73eaa068ca0b6a0579599c1a0d1f613 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 7 Apr 2021 15:55:11 +0800 Subject: [PATCH] [NPU] support cann 20.3 (#32044) * fix compile problem on cann 20.3 * fix ut * fix test_mul * fix check_finite_and_scale * fix lookup_table_v2_grad * fix cmake * support print op --- cmake/external/ascend.cmake | 9 +- paddle/fluid/operators/CMakeLists.txt | 1 - .../amp/check_finite_and_unscale_op_npu.cc | 3 +- .../check_finite_and_unscale_op_npu_test.cc | 4 +- .../fluid/operators/lookup_table_v2_op_npu.cc | 24 +-- .../operators/lookup_table_v2_op_npu_test.cc | 142 ------------------ paddle/fluid/operators/tensor_formatter.cc | 5 + paddle/fluid/platform/npu_profiler.h | 13 +- .../test_amp_check_finite_and_scale_op_npu.py | 2 + .../npu/test_lookup_table_v2_op_npu.py | 2 +- .../tests/unittests/npu/test_mul_op_npu.py | 5 +- .../unittests/npu/test_reduce_any_op_npu.py | 2 + .../tests/unittests/test_assign_op_npu.py | 5 +- 13 files changed, 51 insertions(+), 166 deletions(-) delete mode 100644 paddle/fluid/operators/lookup_table_v2_op_npu_test.cc diff --git a/cmake/external/ascend.cmake b/cmake/external/ascend.cmake index f46c5bf7ac0..ed98ca60e4d 100644 --- a/cmake/external/ascend.cmake +++ b/cmake/external/ascend.cmake @@ -21,6 +21,11 @@ else() set(ASCEND_DIR /usr/local/Ascend) endif() +if(EXISTS ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include/graph/ascend_string.h) + # It means CANN 20.2 + + add_definitions(-DPADDLE_WITH_ASCEND_STRING) +endif() + if(WITH_ASCEND) set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) @@ -43,9 +48,7 @@ if(WITH_ASCEND) set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so) INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR}) - if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h) - add_definitions(-DPADDLE_WITH_ASCEND_STRING) - endif() + ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL) SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib}) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 6fe18f24794..2d3550f8f06 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -159,7 +159,6 @@ endif() if (WITH_ASCEND_CL) cc_test(range_op_npu_test SRCS range_op_npu_test.cc DEPS op_registry range_op scope device_context enforce executor) - cc_test(lookup_table_v2_op_npu_test SRCS lookup_table_v2_op_npu_test.cc DEPS op_registry lookup_table_v2_op scope device_context enforce executor compare_op) cc_test(expand_op_npu_test SRCS expand_op_npu_test.cc DEPS op_registry expand_op scope device_context enforce executor compare_op) endif() diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc index 46f9f7ff089..3db45805025 100644 --- a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc @@ -61,7 +61,6 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel { size_t x_size = xs.size(); for (size_t i = 0; i < x_size; ++i) { - found_inf_data = true; const auto* x = xs[i]; auto* out = outs[i]; out->mutable_data(ctx.GetPlace()); @@ -77,6 +76,8 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel { NpuOpRunner("CheckNumerics", {*x}, {check_xout}, {{"message", std::string("check_nan_and_inf")}}); runner_checknumerics.Run(stream); + ctx.template device_context() + .Wait(); } catch (platform::EnforceNotMet& exception) { LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!"; found_inf_data = true; diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu_test.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu_test.cc index 99e81a4757d..1ed188b1593 100644 --- a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu_test.cc +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu_test.cc @@ -110,10 +110,10 @@ void Compare(f::Scope *scope, const p::DeviceContext &ctx) { // out found_inf Tensor found_inf_tensor; found_inf_tensor.Resize({1}); - bool *is_finite_data = + bool *found_inf_data = found_inf_tensor.mutable_data(paddle::platform::CPUPlace()); f::TensorCopy(*found_inf, place, &found_inf_tensor); - EXPECT_FALSE(*is_finite_data); + EXPECT_TRUE(*found_inf_data); ctx.Wait(); } diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc index 4516aa38fb3..320b498156b 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc @@ -28,6 +28,12 @@ class LookupTableV2NPUKernel : public framework::OpKernel { auto *ids_t = ctx.Input("Ids"); // int tensor auto *output_t = ctx.Output("Out"); // float tensor auto *table_t = ctx.Input("W"); + + // It seems cann 20.1 accepts int64, but cann 20.2+ not. + PADDLE_ENFORCE_EQ(ids_t->type(), framework::proto::VarType::INT32, + platform::errors::Unimplemented( + "The index of LookupTableV2 should be int32.")); + auto *table_var = ctx.InputVar("W"); PADDLE_ENFORCE_EQ( table_var->IsType(), true, @@ -49,28 +55,26 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto *ids_t = ctx.Input("Ids"); + auto *output_grad_t = ctx.Input(framework::GradVarName("Out")); auto *table_grad_t = ctx.Output(framework::GradVarName("W")); - table_grad_t->mutable_data(ctx.GetPlace()); + auto *p = table_grad_t->mutable_data(ctx.GetPlace()); auto stream = ctx.template device_context() .stream(); - // step2: ZerosLike x in device - Tensor zeroslike_w(table_grad_t->type()); - zeroslike_w.Resize(table_grad_t->dims()); - auto p = zeroslike_w.mutable_data(ctx.GetPlace()); - platform::NPUMemsetAsync(static_cast(p), 0, - zeroslike_w.numel() * sizeof(T), stream); + table_grad_t->numel() * sizeof(T), stream); - table_grad_t->mutable_data(ctx.GetPlace()); + // NOTE(zhiqiu): It seems in cann 20.1, the first input and output + // can be different tensor, but in cann 20.2+, it does inplace operation. + // Thus, the first input and output should be same tensor. auto runner_scatter = - NpuOpRunner("ScatterAdd", {zeroslike_w, *ids_t, *output_grad_t}, - {*table_grad_t}, {}); + NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t}, + {*table_grad_t}, {{"use_locking", true}}); runner_scatter.Run(stream); } }; diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc deleted file mode 100644 index f37915834bd..00000000000 --- a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef _WIN32 -#include -#endif - -#include -#include -#include -#include -#include // NOLINT -#include - -#include "gtest/gtest.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/string/printf.h" - -namespace f = paddle::framework; -namespace p = paddle::platform; -namespace m = paddle::operators::math; - -USE_OP(lookup_table_v2); -USE_OP_DEVICE_KERNEL(lookup_table_v2, NPU); - -template -void Compare(f::Scope* scope, const p::DeviceContext& ctx) { - // init - auto ids = scope->Var("Ids"); - auto out = scope->Var("Out"); - auto w = scope->Var("W"); - - auto ids_t = ids->GetMutable(); - auto out_t = out->GetMutable(); - auto w_t = w->GetMutable(); - int bsz = 10; - int dim = 32; - int seqlen = 8; - int vocab_size = 100; - TensorFromVector(std::vector(bsz * seqlen, 3), ctx, ids_t); - std::vector val(vocab_size * dim, 10.); - TensorFromVector(val, ctx, w_t); - ids_t->Resize({bsz, seqlen}); - w_t->Resize({vocab_size, dim}); - out_t->Resize({bsz, seqlen, dim}); - ctx.Wait(); - - auto place = ctx.GetPlace(); - out_t->mutable_data(place); - f::AttributeMap attrs = {{}}; - auto op = f::OpRegistry::CreateOp("lookup_table_v2", - {{"W", {"W"}}, {"Ids", {"Ids"}}}, - {{"Out", {"Out"}}}, attrs); - op->Run(*scope, place); - std::vector out_v; - TensorToVector(*out_t, ctx, &out_v); - ctx.Wait(); - EXPECT_EQ(out_t->numel(), bsz * seqlen * dim); - T res = std::accumulate(out_v.begin(), out_v.end(), 0.); - float eps = 1.e-6; - EXPECT_LT(fabs(res - bsz * seqlen * dim * 10.), eps); -} - -template -void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) { - // init - auto w = scope->Var("W"); - auto ids = scope->Var("Ids"); - auto out = scope->Var("DOut"); - auto dw = scope->Var("DW"); - - auto w_t = w->GetMutable(); - auto ids_t = ids->GetMutable(); - auto out_t = out->GetMutable(); - auto dw_t = dw->GetMutable(); - - int bsz = 2; - int dim = 2; - int seqlen = 2; - int vocab_size = 4; - - std::vector val_int(bsz * seqlen, 3); - std::vector val(vocab_size * dim, 0.); - std::vector val_out(bsz * seqlen * dim, 1.); - - TensorFromVector(val_int, ctx, ids_t); - TensorFromVector(val, ctx, w_t); - TensorFromVector(val, ctx, dw_t); - TensorFromVector(val_out, ctx, out_t); - - w_t->Resize({vocab_size, dim}); - ids_t->Resize({bsz, seqlen}); - out_t->Resize({bsz, seqlen, dim}); - dw_t->Resize({vocab_size, dim}); - - ctx.Wait(); - - auto place = ctx.GetPlace(); - out_t->mutable_data(place); - w_t->mutable_data(place); - dw_t->mutable_data(place); - f::AttributeMap attrs = {{}}; - auto op = f::OpRegistry::CreateOp( - "lookup_table_v2_grad", - {{"Ids", {"Ids"}}, {"W", {"W"}}, {"Out@GRAD", {"DOut"}}}, - {{"W@GRAD", {"DW"}}}, attrs); - op->Run(*scope, place); - ctx.Wait(); - std::vector w_v; - TensorToVector(*dw_t, ctx, &w_v); - ctx.Wait(); - EXPECT_EQ(dw_t->numel(), vocab_size * dim); - T res = std::accumulate(w_v.begin(), w_v.end(), 0.); - float eps = 1.e-6; - EXPECT_LT(fabs(res - bsz * seqlen * dim), eps); -} - -TEST(lookup_table_v2, NPU_fp32) { - f::Scope scope; - p::NPUDeviceContext ctx(p::NPUPlace(0)); - Compare(&scope, ctx); -} - -TEST(lookup_table_v2_grad, NPU_fp32) { - f::Scope scope; - p::NPUDeviceContext ctx(p::NPUPlace(0)); - CompareGrad(&scope, ctx); -} diff --git a/paddle/fluid/operators/tensor_formatter.cc b/paddle/fluid/operators/tensor_formatter.cc index e4fa4a96a5c..5bce5719d7c 100644 --- a/paddle/fluid/operators/tensor_formatter.cc +++ b/paddle/fluid/operators/tensor_formatter.cc @@ -125,6 +125,11 @@ void TensorFormatter::FormatData(const framework::LoDTensor& print_tensor, framework::LoDTensor cpu_tensor; platform::CPUPlace cpu_place; TensorCopy(print_tensor, cpu_place, &cpu_tensor); +#ifdef PADDLE_WITH_ASCEND_CL + if (platform::is_npu_place(print_tensor.place())) { + platform::DeviceContextPool::Instance().Get(print_tensor.place())->Wait(); + } +#endif data = cpu_tensor.data(); } diff --git a/paddle/fluid/platform/npu_profiler.h b/paddle/fluid/platform/npu_profiler.h index 05325aaf9ba..a7b674d0d0c 100644 --- a/paddle/fluid/platform/npu_profiler.h +++ b/paddle/fluid/platform/npu_profiler.h @@ -23,7 +23,17 @@ limitations under the License. */ namespace paddle { namespace platform { -// For ACL 20.1 +#ifdef PADDLE_WITH_ASCEND_STRING +// For CANN 20.2+ +// ACL_AICORE_ARITHMETIC_UTILIZATION = 0, record arithmetic stats +// ACL_AICORE_PIPE_UTILIZATION = 1, record pipeline +// ACL_AICORE_MEMORY_BANDWIDTH = 2, record memory +// ACL_AICORE_L0B_AND_WIDTH = 3, recore internal memory +// ACL_AICORE_RESOURCE_CONFLICT_RATIO = 5, record pipeline ratio +constexpr aclprofAicoreMetrics default_metrics = + ACL_AICORE_ARITHMETIC_UTILIZATION; +#else +// For CANN 20.1 // ACL_AICORE_ARITHMATIC_THROUGHPUT = 0, record arithmetic stats // ACL_AICORE_PIPELINE = 1, record pipeline // ACL_AICORE_SYNCHRONIZATION = 2, record sync @@ -32,6 +42,7 @@ namespace platform { // ACL_AICORE_STALL = 5, record pipeline ratio constexpr aclprofAicoreMetrics default_metrics = ACL_AICORE_ARITHMATIC_THROUGHPUT; +#endif // ACL_PROF_ACL_API, record ACL API stats // ACL_PROF_TASK_TIME, record AI core stats diff --git a/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py index 4cda0ceeccf..ac80ea4c62c 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py @@ -14,6 +14,8 @@ import unittest import numpy as np +import sys +sys.path.append("..") from op_test import OpTest, skip_check_grad_ci import paddle import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py index 400ddd9d4aa..2463ddb7137 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py @@ -41,7 +41,7 @@ class TestLookupTableV2(OpTest): vocab = 10 dim = 20 w = np.ones([vocab, dim]).astype(self.dtype) - x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64) + x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32) out = np.ones([bsz, seqlen, dim]).astype(self.dtype) self.inputs = { diff --git a/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py index e65a3dac739..4fcfd33b32f 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_mul_op_npu.py @@ -248,8 +248,9 @@ class TestMulNet3_2(unittest.TestCase): cpu_pred, cpu_loss = self._test(False) npu_pred, npu_loss = self._test(True) - self.assertTrue(np.allclose(npu_pred, cpu_pred)) - self.assertTrue(np.allclose(npu_loss, cpu_loss)) + self.assertTrue(np.allclose( + npu_pred, cpu_pred, atol=1e-5)) # atol needed on cann 20.3 + self.assertTrue(np.allclose(npu_loss, cpu_loss, atol=1e-5)) @unittest.skipIf(not paddle.is_compiled_with_npu(), diff --git a/python/paddle/fluid/tests/unittests/npu/test_reduce_any_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_reduce_any_op_npu.py index 087256b2980..583a648224d 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_reduce_any_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_reduce_any_op_npu.py @@ -16,6 +16,8 @@ from __future__ import print_function import unittest import numpy as np +import sys +sys.path.append("..") from op_test import OpTest, skip_check_grad_ci import paddle import paddle.fluid.core as core diff --git a/python/paddle/fluid/tests/unittests/test_assign_op_npu.py b/python/paddle/fluid/tests/unittests/test_assign_op_npu.py index 44515ce2e5b..ed21549b7e0 100644 --- a/python/paddle/fluid/tests/unittests/test_assign_op_npu.py +++ b/python/paddle/fluid/tests/unittests/test_assign_op_npu.py @@ -36,7 +36,7 @@ class TestAssign(OpTest): self.op_type = "assign" self.init_dtype() - x = np.rand.random([3,3]) + x = np.random.random([3, 3]).astype(self.dtype) self.inputs = {'X': x} self.attrs = {} @@ -46,7 +46,7 @@ class TestAssign(OpTest): self.__class__.use_npu = True def init_dtype(self): - self.dtype = np.int64 + self.dtype = np.float32 def test_check_output(self): self.check_output_with_place(self.place, check_dygraph=False) @@ -54,4 +54,3 @@ class TestAssign(OpTest): if __name__ == '__main__': unittest.main() - -- GitLab