Unverified commit 19b02d95, authored by Aganlengzi and committed by GitHub

[NPU] modifications for model ernie-1.0 (#36642)

* [NPU] modifications for model ernie-1.0

* rollback 503003 and change cast to dtype
Parent 2dd0a46a
@@ -21,6 +21,38 @@ namespace operators {
using Tensor = framework::Tensor;
static void CumsumImp(const Tensor& input, Tensor* output,
const framework::NPUAttributeMap& attr_input,
const framework::ExecutionContext& ctx) {
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
if (input.type() == framework::proto::VarType::INT64) {
Tensor tmp_input;
tmp_input.mutable_data<float>(input.dims(), ctx.GetPlace());
auto dst_acl_dtype = ConvertToNpuDtype(tmp_input.type());
const auto& cast_runner_1 =
NpuOpRunner("Cast", {input}, {tmp_input},
{{"dst_type", static_cast<int>(dst_acl_dtype)}});
cast_runner_1.Run(stream);
Tensor tmp_output;
tmp_output.mutable_data<float>(output->dims(), ctx.GetPlace());
const auto& runner =
NpuOpRunner("CumsumD", {tmp_input}, {tmp_output}, attr_input);
runner.Run(stream);
dst_acl_dtype = ConvertToNpuDtype(output->type());
const auto& cast_runner_2 =
NpuOpRunner("Cast", {tmp_output}, {*output},
{{"dst_type", static_cast<int>(dst_acl_dtype)}});
cast_runner_2.Run(stream);
} else {
const auto& runner = NpuOpRunner("CumsumD", {input}, {*output}, attr_input);
runner.Run(stream);
}
}
template <typename DeviceContext, typename T>
class CumSumNPUKernel : public framework::OpKernel<T> {
public:
@@ -36,10 +68,6 @@ class CumSumNPUKernel : public framework::OpKernel<T> {
framework::NPUAttributeMap attr_input = {
{"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}};
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
bool flatten = ctx.Attr<bool>("flatten");
if (flatten) {
PADDLE_ENFORCE_EQ(
@@ -53,11 +81,9 @@ class CumSumNPUKernel : public framework::OpKernel<T> {
new_x.Resize(framework::make_ddim({x->numel()}));
const auto& runner = NpuOpRunner("CumsumD", {new_x}, {*out}, attr_input);
runner.Run(stream);
CumsumImp(new_x, out, attr_input, ctx);
} else {
const auto& runner = NpuOpRunner("CumsumD", {*x}, {*out}, attr_input);
runner.Run(stream);
CumsumImp(*x, out, attr_input, ctx);
}
}
};
@@ -69,5 +95,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
cumsum, ops::CumSumNPUKernel<plat::NPUDeviceContext, int>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::CumSumNPUKernel<plat::NPUDeviceContext, int64_t>,
#endif
ops::CumSumNPUKernel<plat::NPUDeviceContext, float>,
ops::CumSumNPUKernel<plat::NPUDeviceContext, plat::float16>);
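For context, the new CumsumImp helper handles int64 input by casting it to float, running CumsumD on the float copy, and casting the result back to the output dtype. Below is a minimal NumPy sketch of that round-trip; it is illustrative only and not part of the commit.

```python
import numpy as np

def cumsum_int64_via_float(x, axis=-1):
    # Mirror of CumsumImp's int64 path: cast to float32, take the
    # cumulative sum there, then cast back to the original dtype.
    # Caveat: float32 is only exact for integers up to 2**24, so very
    # large accumulated sums could lose precision under this scheme.
    out = np.cumsum(x.astype(np.float32), axis=axis)
    return out.astype(x.dtype)

x = np.random.randint(1, 10000, size=(5, 6, 10)).astype(np.int64)
assert np.array_equal(cumsum_int64_via_float(x), np.cumsum(x, axis=-1))
```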
@@ -167,10 +167,16 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::ElementwiseSubNPUKernel<int64_t>,
#endif
ops::ElementwiseSubNPUKernel<float>,
ops::ElementwiseSubNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(elementwise_sub_grad,
ops::ElementwiseSubGradNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::ElementwiseSubGradNPUKernel<int64_t>,
#endif
ops::ElementwiseSubGradNPUKernel<float>,
ops::ElementwiseSubGradNPUKernel<plat::float16>);
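With the registrations above, elementwise_sub also dispatches for int64 tensors when the build defines PADDLE_WITH_ASCEND_INT64. A minimal usage sketch, assuming an NPU-enabled PaddlePaddle build where 'npu:0' is a valid device string (not part of the commit):

```python
import paddle

paddle.set_device('npu:0')  # assumes an NPU-enabled build of PaddlePaddle
x = paddle.to_tensor([3, 5, 7], dtype='int64')
y = paddle.to_tensor([1, 2, 3], dtype='int64')
# paddle.subtract lowers to the elementwise_sub kernel registered above.
print(paddle.subtract(x, y))  # expected values: [2, 3, 4]
```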
@@ -21,6 +21,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
constexpr int64_t kNoPadding = -1;
template <typename DeviceContext, typename T>
class LookupTableV2NPUKernel : public framework::OpKernel<T> {
public:
@@ -35,16 +38,52 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument("npu only accept LoDTensor"));
output_t->mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner;
runner.SetType("GatherV2")
.AddInput(*table_t)
.AddInput(*ids_t)
.AddInput(std::vector<int32_t>{0})
int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
if (padding_idx == kNoPadding) {
NpuOpRunner runner;
runner.SetType("GatherV2")
.AddInput(*table_t)
.AddInput(*ids_t)
.AddInput(std::vector<int32_t>{0})
#if (CANN_VERSION_CODE >= 503003)
.AddAttrs({{"batch_dims", 0}})
#endif
.AddOutput(*output_t);
runner.Run();
} else {
Tensor tmp_table_t(table_t->type());
tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());
Tensor index;
index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
FillNpuTensorWithConstant<int32_t>(&index,
static_cast<int32_t>(padding_idx));
auto update_dim = framework::make_ddim({1, table_t->dims()[1]});
Tensor update;
update.mutable_data<T>(update_dim, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&update, static_cast<T>(0));
update.Resize(update_dim);
NpuOpRunner update_runner;
update_runner.SetType("TensorScatterUpdate")
.AddInput(*table_t)
.AddInput(index)
.AddInput(update)
.AddOutput(tmp_table_t);
update_runner.Run();
NpuOpRunner runner;
runner.SetType("GatherV2")
.AddInput(tmp_table_t)
.AddInput(*ids_t)
.AddInput(std::vector<int32_t>{0})
#if (CANN_VERSION_CODE >= 503003)
.AddAttrs({{"batch_dims", 0}})
#endif
.AddOutput(*output_t);
runner.Run();
}
}
};
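The new padding_idx branch in LookupTableV2NPUKernel scatters a zero row into a temporary copy of the embedding table at padding_idx (TensorScatterUpdate) and then gathers from that copy (GatherV2), so ids equal to padding_idx come back as zero vectors. A minimal NumPy sketch of the same semantics, for illustration only:

```python
import numpy as np

def lookup_with_padding(table, ids, padding_idx=-1):
    # padding_idx == -1 corresponds to kNoPadding: a plain gather on axis 0.
    if padding_idx == -1:
        return table[ids]
    # Otherwise zero out the padding row in a copy of the table (the
    # TensorScatterUpdate step) and gather from that copy (the GatherV2 step).
    tmp = table.copy()
    tmp[padding_idx] = 0
    return tmp[ids]

w = np.random.random((10, 20)).astype(np.float32)
ids = np.random.randint(0, 10, size=(6, 8))
out = lookup_with_padding(w, ids, padding_idx=3)
assert np.all(out[ids == 3] == 0)
```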
@@ -249,5 +249,45 @@ class TestNPUCumSumWithFlatten2(TestNPUCumSumOp1):
self.outputs = {'Out': self.inputs['X'].cumsum()}
#----------------Cumsum Int64----------------
class TestNPUCumSumOpInt64(TestNPUCumSumOp1):
def init_testcase(self):
self.attrs = {'axis': -1, 'reverse': True}
self.inputs = {
'X': np.random.randint(
1, 10000, size=(5, 6, 10)).astype(self.dtype)
}
self.outputs = {
'Out': np.flip(
np.flip(
self.inputs['X'], axis=2).cumsum(axis=2), axis=2)
}
def create_test_int64(parent):
class TestCumSumInt64(parent):
def init_dtype(self):
self.dtype = np.int64
cls_name = "{0}_{1}".format(parent.__name__, "Int64")
TestCumSumInt64.__name__ = cls_name
globals()[cls_name] = TestCumSumInt64
create_test_int64(TestNPUCumSumOp1)
create_test_int64(TestNPUCumSumOp2)
create_test_int64(TestNPUCumSumOp3)
create_test_int64(TestNPUCumSumOp4)
create_test_int64(TestNPUCumSumOp5)
create_test_int64(TestNPUCumSumOp7)
create_test_int64(TestNPUCumSumExclusive1)
create_test_int64(TestNPUCumSumExclusive2)
create_test_int64(TestNPUCumSumExclusive3)
create_test_int64(TestNPUCumSumExclusive4)
create_test_int64(TestNPUCumSumExclusive5)
create_test_int64(TestNPUCumSumReverseExclusive)
create_test_int64(TestNPUCumSumWithFlatten1)
create_test_int64(TestNPUCumSumWithFlatten2)
if __name__ == '__main__':
unittest.main()
@@ -95,6 +95,11 @@ class TestElementwiseSubOpInt32(TestElementwiseSubOp):
self.dtype = np.int32
class TestElementwiseSubOpInt64(TestElementwiseSubOp):
def init_dtype(self):
self.dtype = np.int64
class TestSubtractAPI(unittest.TestCase):
def test_name(self):
with paddle.static.program_guard(paddle.static.Program()):
@@ -33,14 +33,15 @@ class TestLookupTableV2(OpTest):
self.place = paddle.NPUPlace(0)
self.init_dtype()
self.init_dim()
self.init_dims()
self.init_padding_idx()
np.random.seed(SEED)
bsz = 6
seqlen = 8
vocab = 10
w = np.ones([vocab, self.dim]).astype(self.dtype)
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32)
out = np.ones([bsz, seqlen, self.dim]).astype(self.dtype)
w = np.random.random([self.vocab, self.dim]).astype(self.dtype)
x = np.random.randint(
0, self.vocab, size=(self.bsz, self.seqlen)).astype(np.int32)
out = w[x]
if self.padding_idx != -1:
out[np.squeeze(x == self.padding_idx)] = np.zeros(self.dim)
self.inputs = {
'W': OpTest.np_dtype_to_fluid_dtype(w),
@@ -50,7 +51,7 @@ class TestLookupTableV2(OpTest):
'is_sparse': False,
'is_distributed': False,
'remote_prefetch': False,
'padding_idx': -1
'padding_idx': self.padding_idx
}
self.outputs = {'Out': out}
@@ -60,10 +61,16 @@ class TestLookupTableV2(OpTest):
def init_dtype(self):
self.dtype = np.float32
def init_dim(self):
def init_dims(self):
self.bsz = 6
self.seqlen = 8
self.vocab = 10
# embedding_dim is not a multiple of 32
self.dim = 20
def init_padding_idx(self):
self.padding_idx = -1
def test_check_output(self):
self.check_output_with_place(self.place)
@@ -85,7 +92,10 @@ class TestLookupTableV2FP16(TestLookupTableV2):
class TestLookupTableV2Dim32(TestLookupTableV2):
def init_dim(self):
def init_dims(self):
self.bsz = 6
self.seqlen = 8
self.vocab = 10
# embedding_dim is a multiple of 32
self.dim = 64
@@ -96,7 +106,10 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
def init_dtype(self):
self.dtype = np.float16
def init_dim(self):
def init_dims(self):
self.bsz = 6
self.seqlen = 8
self.vocab = 10
self.dim = 64
def set_npu(self):
@@ -104,5 +117,10 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
self.__class__.no_need_check_grad = True
class TestLookupTableV2WithPadding(TestLookupTableV2):
def init_padding_idx(self):
self.padding_idx = np.random.randint(0, self.vocab)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False, scale=1.0):
"""Reference forward implementation using np.matmul."""
# np.matmul does not support the transpose flags, so we manually
# transpose X and Y appropriately.
if transpose_X:
if X.ndim == 1:
X = X.reshape((X.size, ))
elif X.ndim == 2:
X = X.T
else:
dim = [i for i in range(len(X.shape))]
dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
X = np.transpose(X, tuple(dim))
if transpose_Y:
if Y.ndim == 1:
Y = Y.reshape((Y.size, ))
else:
dim = [i for i in range(len(Y.shape))]
dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
Y = np.transpose(Y, tuple(dim))
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float64")
if abs(scale - 1.0) > 1e-09:
Out = Out * scale
return Out
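As a quick sanity check, the helper above should agree with np.matmul on manually transposed operands once the scale is applied; this usage example is illustrative and not part of the test file itself:

```python
X = np.random.random((2, 100)).astype("float64")
Y = np.random.random((2, 100)).astype("float64")
ref = reference_matmul(X, Y, transpose_X=False, transpose_Y=True, scale=0.5)
assert np.allclose(ref, 0.5 * np.matmul(X, Y.T))
```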
class TestMatMulOp(OpTest):
"""
basic case
"""
def setUp(self):
self.set_npu()
self.op_type = "matmul"
self.init_dtype()
self.init_alpha()
self.config()
X = np.random.random(self.x_shape).astype(self.dtype)
Y = np.random.random(self.y_shape).astype(self.dtype)
# -0.1 ~ 0.1
X = -0.1 + 0.2 * X
Y = -0.1 + 0.2 * Y
Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y,
self.alpha)
Out = Out.astype(self.dtype)
self.inputs = {'X': X, 'Y': Y}
self.attrs = {
'transpose_X': self.transpose_X,
'transpose_Y': self.transpose_Y,
'alpha': self.alpha
}
self.outputs = {'Out': Out}
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def config(self):
self.x_shape = (100, )
self.y_shape = (100, )
self.transpose_X = False
self.transpose_Y = False
def init_alpha(self):
self.alpha = 1.0
def init_dtype(self):
self.dtype = "float32"
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-7)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
class TestMatMulOp1(TestMatMulOp):
"""
case x_ndim == 1, y_ndim != 1
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 3, 2, 100)
self.transpose_X = False
self.transpose_Y = True
class TestMatMulOp2(TestMatMulOp):
"""
case x_ndim != 1, y_ndim == 1
"""
def config(self):
self.x_shape = (1, 2, 100, 1)
self.y_shape = (100, )
self.transpose_X = True
self.transpose_Y = False
class TestMatMulOp3(TestMatMulOp):
"""
case [M, K] x [K, N] = [M, N]
"""
def config(self):
self.x_shape = (2, 100)
self.y_shape = (100, 2)
self.transpose_X = False
self.transpose_Y = False
class TestMatMulOp4(TestMatMulOp):
"""
case [M, K] x [K, N] = [M, N]
"""
def config(self):
self.x_shape = (2, 100)
self.y_shape = (2, 100)
self.transpose_X = False
self.transpose_Y = True
class TestMatMulOp5(TestMatMulOp):
"""
case [M, K] x [K, N] = [M, N]
"""
def config(self):
self.x_shape = (100, 2)
self.y_shape = (100, 2)
self.transpose_X = True
self.transpose_Y = False
class TestMatMulOp6(TestMatMulOp):
"""
case [B, M, K] x [K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (2, 2, 25)
self.y_shape = (25, 4)
self.transpose_X = False
self.transpose_Y = False
class TestMatMulOp7(TestMatMulOp):
"""
case [B, M, K] x [K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (1, 2, 25)
self.y_shape = (4, 25)
self.transpose_X = False
self.transpose_Y = True
class TestMatMulOp8(TestMatMulOp):
"""
case [B, M, K] x [K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (1, 25, 4)
self.y_shape = (25, 4)
self.transpose_X = True
self.transpose_Y = False
class TestMatMulOp9(TestMatMulOp):
"""
case [B, M, K] x [B, K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (2, 5, 10)
self.y_shape = (2, 10, 5)
self.transpose_X = False
self.transpose_Y = False
class TestMatMulOp10(TestMatMulOp):
"""
case [B, M, K] x [B, K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (2, 10, 5)
self.y_shape = (2, 10, 5)
self.transpose_X = True
self.transpose_Y = False
class TestMatMulOp11(TestMatMulOp):
"""
case [B, M, K] x [B, K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (2, 5, 10)
self.y_shape = (2, 5, 10)
self.transpose_X = False
self.transpose_Y = True
class TestMatMulOp12(TestMatMulOp):
"""
case to check the gradient for special case
"""
def config(self):
self.x_shape = (100)
self.y_shape = (1, 2, 2, 100, 2)
self.transpose_X = False
self.transpose_Y = False
class TestMatMulOp13(TestMatMulOp):
"""
case to check the gradient for special case
"""
def config(self):
self.x_shape = (2, 1, 100)
self.y_shape = (100)
self.transpose_X = False
self.transpose_Y = False
#--------------------test matmul alpha--------------------
def create_test_alpha_class(parent):
class TestMatMulOpAlphaCase(parent):
def init_alpha(self):
self.alpha = 0.125
cls_name = "{0}_{1}".format(parent.__name__, "Alpha")
TestMatMulOpAlphaCase.__name__ = cls_name
globals()[cls_name] = TestMatMulOpAlphaCase
create_test_alpha_class(TestMatMulOp)
create_test_alpha_class(TestMatMulOp1)
create_test_alpha_class(TestMatMulOp2)
create_test_alpha_class(TestMatMulOp3)
create_test_alpha_class(TestMatMulOp4)
create_test_alpha_class(TestMatMulOp5)
create_test_alpha_class(TestMatMulOp6)
create_test_alpha_class(TestMatMulOp9)
create_test_alpha_class(TestMatMulOp10)
create_test_alpha_class(TestMatMulOp11)
create_test_alpha_class(TestMatMulOp12)
create_test_alpha_class(TestMatMulOp13)
#--------------------test matmul fp16--------------------
def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
class TestMatMulOpFp16Case(parent):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=atol)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X', 'Y'],
'Out',
max_relative_error=max_relative_error)
cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
TestMatMulOpFp16Case.__name__ = cls_name
globals()[cls_name] = TestMatMulOpFp16Case
create_test_fp16_class(TestMatMulOp)
create_test_fp16_class(TestMatMulOp1)
create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMulOp13)
if __name__ == "__main__":
unittest.main()