未验证 提交 853af66f 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] support cann 20.3 (#32044)

* fix compile problem on cann 20.3

* fix ut

* fix test_mul

* fix check_finite_and_scale

* fix lookup_table_v2_grad

* fix cmake

* support print op
上级 78959a39
......@@ -21,6 +21,11 @@ else()
set(ASCEND_DIR /usr/local/Ascend)
endif()
if(EXISTS ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include/graph/ascend_string.h)
# It means CANN 20.2 +
add_definitions(-DPADDLE_WITH_ASCEND_STRING)
endif()
if(WITH_ASCEND)
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
......@@ -43,9 +48,7 @@ if(WITH_ASCEND)
set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so)
INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR})
if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h)
add_definitions(-DPADDLE_WITH_ASCEND_STRING)
endif()
ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib})
......
......@@ -159,7 +159,6 @@ endif()
if (WITH_ASCEND_CL)
cc_test(range_op_npu_test SRCS range_op_npu_test.cc DEPS op_registry range_op scope device_context enforce executor)
cc_test(lookup_table_v2_op_npu_test SRCS lookup_table_v2_op_npu_test.cc DEPS op_registry lookup_table_v2_op scope device_context enforce executor compare_op)
cc_test(expand_op_npu_test SRCS expand_op_npu_test.cc DEPS op_registry expand_op scope device_context enforce executor compare_op)
endif()
......
......@@ -61,7 +61,6 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
size_t x_size = xs.size();
for (size_t i = 0; i < x_size; ++i) {
found_inf_data = true;
const auto* x = xs[i];
auto* out = outs[i];
out->mutable_data<T>(ctx.GetPlace());
......@@ -77,6 +76,8 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
NpuOpRunner("CheckNumerics", {*x}, {check_xout},
{{"message", std::string("check_nan_and_inf")}});
runner_checknumerics.Run(stream);
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
} catch (platform::EnforceNotMet& exception) {
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
found_inf_data = true;
......
......@@ -110,10 +110,10 @@ void Compare(f::Scope *scope, const p::DeviceContext &ctx) {
// out found_inf
Tensor found_inf_tensor;
found_inf_tensor.Resize({1});
bool *is_finite_data =
bool *found_inf_data =
found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
f::TensorCopy(*found_inf, place, &found_inf_tensor);
EXPECT_FALSE(*is_finite_data);
EXPECT_TRUE(*found_inf_data);
ctx.Wait();
}
......
......@@ -28,6 +28,12 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids"); // int tensor
auto *output_t = ctx.Output<framework::LoDTensor>("Out"); // float tensor
auto *table_t = ctx.Input<framework::LoDTensor>("W");
// It seems cann 20.1 accepts int64, but cann 20.2+ not.
PADDLE_ENFORCE_EQ(ids_t->type(), framework::proto::VarType::INT32,
platform::errors::Unimplemented(
"The index of LookupTableV2 should be int32."));
auto *table_var = ctx.InputVar("W");
PADDLE_ENFORCE_EQ(
table_var->IsType<framework::LoDTensor>(), true,
......@@ -49,28 +55,26 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
auto *output_grad_t =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *table_grad_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
table_grad_t->mutable_data<T>(ctx.GetPlace());
auto *p = table_grad_t->mutable_data<T>(ctx.GetPlace());
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// step2: ZerosLike x in device
Tensor zeroslike_w(table_grad_t->type());
zeroslike_w.Resize(table_grad_t->dims());
auto p = zeroslike_w.mutable_data<T>(ctx.GetPlace());
platform::NPUMemsetAsync(static_cast<void *>(p), 0,
zeroslike_w.numel() * sizeof(T), stream);
table_grad_t->numel() * sizeof(T), stream);
table_grad_t->mutable_data<T>(ctx.GetPlace());
// NOTE(zhiqiu): It seems in cann 20.1, the first input and output
// can be different tensor, but in cann 20.2+, it does inplace operation.
// Thus, the first input and output should be same tensor.
auto runner_scatter =
NpuOpRunner("ScatterAdd", {zeroslike_w, *ids_t, *output_grad_t},
{*table_grad_t}, {});
NpuOpRunner("ScatterAdd", {*table_grad_t, *ids_t, *output_grad_t},
{*table_grad_t}, {{"use_locking", true}});
runner_scatter.Run(stream);
}
};
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef _WIN32
#include <unistd.h>
#endif
#include <cmath>
#include <iostream>
#include <numeric>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(lookup_table_v2);
USE_OP_DEVICE_KERNEL(lookup_table_v2, NPU);
template <typename T>
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto ids = scope->Var("Ids");
auto out = scope->Var("Out");
auto w = scope->Var("W");
auto ids_t = ids->GetMutable<f::LoDTensor>();
auto out_t = out->GetMutable<f::LoDTensor>();
auto w_t = w->GetMutable<f::LoDTensor>();
int bsz = 10;
int dim = 32;
int seqlen = 8;
int vocab_size = 100;
TensorFromVector(std::vector<int64_t>(bsz * seqlen, 3), ctx, ids_t);
std::vector<T> val(vocab_size * dim, 10.);
TensorFromVector(val, ctx, w_t);
ids_t->Resize({bsz, seqlen});
w_t->Resize({vocab_size, dim});
out_t->Resize({bsz, seqlen, dim});
ctx.Wait();
auto place = ctx.GetPlace();
out_t->mutable_data<T>(place);
f::AttributeMap attrs = {{}};
auto op = f::OpRegistry::CreateOp("lookup_table_v2",
{{"W", {"W"}}, {"Ids", {"Ids"}}},
{{"Out", {"Out"}}}, attrs);
op->Run(*scope, place);
std::vector<T> out_v;
TensorToVector(*out_t, ctx, &out_v);
ctx.Wait();
EXPECT_EQ(out_t->numel(), bsz * seqlen * dim);
T res = std::accumulate(out_v.begin(), out_v.end(), 0.);
float eps = 1.e-6;
EXPECT_LT(fabs(res - bsz * seqlen * dim * 10.), eps);
}
template <typename T>
void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto w = scope->Var("W");
auto ids = scope->Var("Ids");
auto out = scope->Var("DOut");
auto dw = scope->Var("DW");
auto w_t = w->GetMutable<f::LoDTensor>();
auto ids_t = ids->GetMutable<f::LoDTensor>();
auto out_t = out->GetMutable<f::LoDTensor>();
auto dw_t = dw->GetMutable<f::LoDTensor>();
int bsz = 2;
int dim = 2;
int seqlen = 2;
int vocab_size = 4;
std::vector<int64_t> val_int(bsz * seqlen, 3);
std::vector<T> val(vocab_size * dim, 0.);
std::vector<T> val_out(bsz * seqlen * dim, 1.);
TensorFromVector(val_int, ctx, ids_t);
TensorFromVector(val, ctx, w_t);
TensorFromVector(val, ctx, dw_t);
TensorFromVector(val_out, ctx, out_t);
w_t->Resize({vocab_size, dim});
ids_t->Resize({bsz, seqlen});
out_t->Resize({bsz, seqlen, dim});
dw_t->Resize({vocab_size, dim});
ctx.Wait();
auto place = ctx.GetPlace();
out_t->mutable_data<T>(place);
w_t->mutable_data<T>(place);
dw_t->mutable_data<T>(place);
f::AttributeMap attrs = {{}};
auto op = f::OpRegistry::CreateOp(
"lookup_table_v2_grad",
{{"Ids", {"Ids"}}, {"W", {"W"}}, {"Out@GRAD", {"DOut"}}},
{{"W@GRAD", {"DW"}}}, attrs);
op->Run(*scope, place);
ctx.Wait();
std::vector<T> w_v;
TensorToVector(*dw_t, ctx, &w_v);
ctx.Wait();
EXPECT_EQ(dw_t->numel(), vocab_size * dim);
T res = std::accumulate(w_v.begin(), w_v.end(), 0.);
float eps = 1.e-6;
EXPECT_LT(fabs(res - bsz * seqlen * dim), eps);
}
TEST(lookup_table_v2, NPU_fp32) {
f::Scope scope;
p::NPUDeviceContext ctx(p::NPUPlace(0));
Compare<float>(&scope, ctx);
}
TEST(lookup_table_v2_grad, NPU_fp32) {
f::Scope scope;
p::NPUDeviceContext ctx(p::NPUPlace(0));
CompareGrad<float>(&scope, ctx);
}
......@@ -125,6 +125,11 @@ void TensorFormatter::FormatData(const framework::LoDTensor& print_tensor,
framework::LoDTensor cpu_tensor;
platform::CPUPlace cpu_place;
TensorCopy(print_tensor, cpu_place, &cpu_tensor);
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(print_tensor.place())) {
platform::DeviceContextPool::Instance().Get(print_tensor.place())->Wait();
}
#endif
data = cpu_tensor.data<T>();
}
......
......@@ -23,7 +23,17 @@ limitations under the License. */
namespace paddle {
namespace platform {
// For ACL 20.1
#ifdef PADDLE_WITH_ASCEND_STRING
// For CANN 20.2+
// ACL_AICORE_ARITHMETIC_UTILIZATION = 0, record arithmetic stats
// ACL_AICORE_PIPE_UTILIZATION = 1, record pipeline
// ACL_AICORE_MEMORY_BANDWIDTH = 2, record memory
// ACL_AICORE_L0B_AND_WIDTH = 3, recore internal memory
// ACL_AICORE_RESOURCE_CONFLICT_RATIO = 5, record pipeline ratio
constexpr aclprofAicoreMetrics default_metrics =
ACL_AICORE_ARITHMETIC_UTILIZATION;
#else
// For CANN 20.1
// ACL_AICORE_ARITHMATIC_THROUGHPUT = 0, record arithmetic stats
// ACL_AICORE_PIPELINE = 1, record pipeline
// ACL_AICORE_SYNCHRONIZATION = 2, record sync
......@@ -32,6 +42,7 @@ namespace platform {
// ACL_AICORE_STALL = 5, record pipeline ratio
constexpr aclprofAicoreMetrics default_metrics =
ACL_AICORE_ARITHMATIC_THROUGHPUT;
#endif
// ACL_PROF_ACL_API, record ACL API stats
// ACL_PROF_TASK_TIME, record AI core stats
......
......@@ -14,6 +14,8 @@
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
......
......@@ -41,7 +41,7 @@ class TestLookupTableV2(OpTest):
vocab = 10
dim = 20
w = np.ones([vocab, dim]).astype(self.dtype)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int64)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32)
out = np.ones([bsz, seqlen, dim]).astype(self.dtype)
self.inputs = {
......
......@@ -248,8 +248,9 @@ class TestMulNet3_2(unittest.TestCase):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred))
self.assertTrue(np.allclose(npu_loss, cpu_loss))
self.assertTrue(np.allclose(
npu_pred, cpu_pred, atol=1e-5)) # atol needed on cann 20.3
self.assertTrue(np.allclose(npu_loss, cpu_loss, atol=1e-5))
@unittest.skipIf(not paddle.is_compiled_with_npu(),
......
......@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid.core as core
......
......@@ -36,7 +36,7 @@ class TestAssign(OpTest):
self.op_type = "assign"
self.init_dtype()
x = np.rand.random([3,3])
x = np.random.random([3, 3]).astype(self.dtype)
self.inputs = {'X': x}
self.attrs = {}
......@@ -46,7 +46,7 @@ class TestAssign(OpTest):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.int64
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)
......@@ -54,4 +54,3 @@ class TestAssign(OpTest):
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册