Unverified commit 338f9e05, authored by R ronnywang, committed by GitHub

add sequence_mask_op_npu and tests (#34455)

Parent 46808afc
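
For reference, sequence_mask turns a tensor of lengths X into a mask of shape X.shape + (maxlen,), with mask[..., j] = (j < X[...]). A minimal NumPy sketch of that semantics, mirroring the calc_ground_truth_mask helper in the tests below (sequence_mask_ref is an illustrative name, not a Paddle API):

import numpy as np

def sequence_mask_ref(x, maxlen=None, dtype='int64'):
    # maxlen=None plays the role of maxlen < 0: infer it as the largest
    # length in x.
    x = np.asarray(x)
    if maxlen is None:
        maxlen = int(np.max(x))
    # index has shape [1, ..., 1, maxlen] and broadcasts against x[..., None].
    index = np.arange(maxlen).reshape([1] * x.ndim + [-1])
    return (index < x[..., None]).astype(dtype)

print(sequence_mask_ref([[0, 3, 4], [5, 7, 9]], maxlen=10))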
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class SequenceMaskNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Y");
int maxlen = ctx.Attr<int>("maxlen");
    // Input(MaxLenTensor), when provided, overrides the maxlen attribute.
    if (ctx.HasInput("MaxLenTensor")) {
      auto max_len_tensor = ctx.Input<Tensor>("MaxLenTensor");
      PADDLE_ENFORCE_NOT_NULL(max_len_tensor,
                              platform::errors::InvalidArgument(
                                  "Input(MaxLenTensor) should not be NULL. "
                                  "But received Input(MaxLenTensor) is NULL."));
      // Copy the scalar to the host so its value can be read.
      framework::Tensor temp;
      TensorCopySync(*max_len_tensor, platform::CPUPlace(), &temp);
      maxlen = *temp.data<int32_t>();
PADDLE_ENFORCE_GT(
maxlen, 0,
platform::errors::InvalidArgument(
"Input(MaxLenTensor) value should be greater than 0. But "
"received Input(MaxLenTensor) value = %d.",
maxlen));
}
    // A negative maxlen means "infer it from the data": use the largest
    // length found in X.
    if (maxlen < 0) {
      auto x_numel = x->numel();
      std::vector<T> x_vec;
      framework::TensorToVector(*x, dev_ctx, &x_vec);
      auto x_data = x_vec.data();
      maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
    }
    // The output shape is X's shape with maxlen appended.
    auto y_dim = framework::vectorize<int>(x->dims());
    y_dim.push_back(maxlen);
    // Cast X to int32 so it can be compared against the int32 index
    // tensor built below.
    Tensor cast_x;
    cast_x.mutable_data<int32_t>(x->dims(), ctx.GetPlace());
    const auto& cast1_runner =
        NpuOpRunner("Cast", {*x}, {cast_x},
                    {{"dst_type", ConvertToNpuDtype(cast_x.type())}});
    cast1_runner.Run(dev_ctx.stream());
    // Build the index tensor [0, 1, ..., maxlen - 1].
    Tensor tmp;
    tmp.mutable_data<int32_t>(framework::make_ddim({maxlen}), ctx.GetPlace());
    NpuOpRunner range_runner;
    range_runner.SetType("Range");
    range_runner.AddInput(std::vector<int32_t>({0}));
    range_runner.AddInput(std::vector<int32_t>({maxlen}));
    range_runner.AddInput(std::vector<int32_t>({1}));
    range_runner.AddOutput(tmp);
    range_runner.Run(dev_ctx.stream());
    // Broadcast the index tensor to the output shape.
    Tensor expand_tmp;
    expand_tmp.mutable_data<int32_t>(framework::make_ddim(y_dim),
                                     ctx.GetPlace());
    const auto& expand_runner =
        NpuOpRunner("ExpandD", {tmp}, {expand_tmp}, {{"shape", y_dim}});
    expand_runner.Run(dev_ctx.stream());
    // Append a trailing axis of size 1 to X, then tile it maxlen times
    // along that axis so it matches the output shape.
    auto x_dims = framework::vectorize<int>(x->dims());
    x_dims.push_back(1);
    cast_x.Resize(framework::make_ddim(x_dims));
    Tensor x_tmp;
    x_tmp.mutable_data<int32_t>(framework::make_ddim(y_dim), ctx.GetPlace());
    const auto& tile_runner =
        NpuOpRunner("TileWithAxis", {cast_x}, {x_tmp},
                    {{"axis", x->dims().size()}, {"tiles", maxlen}});
    tile_runner.Run(dev_ctx.stream());
    // mask[..., j] = (index[..., j] < x[...]), computed as a uint8 tensor.
    Tensor y_tmp;
    y_tmp.mutable_data<uint8_t>(framework::make_ddim(y_dim), ctx.GetPlace());
    const auto& less_runner =
        NpuOpRunner("Less", {expand_tmp, x_tmp}, {y_tmp}, {});
    less_runner.Run(dev_ctx.stream());
    // Allocate Y with the dtype requested via the out_dtype attribute.
    y->Resize(framework::make_ddim(y_dim));
    auto out_dtype = static_cast<framework::proto::VarType::Type>(
        ctx.Attr<int>("out_dtype"));
if (out_dtype == framework::proto::VarType::INT32) {
y->mutable_data<int32_t>(ctx.GetPlace());
} else if (out_dtype == framework::proto::VarType::INT64) {
y->mutable_data<int64_t>(ctx.GetPlace());
} else if (out_dtype == framework::proto::VarType::FP32) {
y->mutable_data<float>(ctx.GetPlace());
} else if (out_dtype == framework::proto::VarType::FP64) {
y->mutable_data<double>(ctx.GetPlace());
} else if (out_dtype == framework::proto::VarType::BOOL) {
y->mutable_data<bool>(ctx.GetPlace());
} else if (out_dtype == framework::proto::VarType::UINT8) {
y->mutable_data<uint8_t>(ctx.GetPlace());
} else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "out_dtype only supports int32, int64, fp32, fp64, "
          "bool, uint8, but received out_dtype is %d",
          out_dtype));
}
    // Cast the uint8 mask to the requested output dtype.
    const auto& cast2_runner = NpuOpRunner(
        "Cast", {y_tmp}, {*y}, {{"dst_type", ConvertToNpuDtype(out_dtype)}});
    cast2_runner.Run(dev_ctx.stream());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
sequence_mask, ops::SequenceMaskNPUKernel<plat::NPUDeviceContext, int32_t>,
ops::SequenceMaskNPUKernel<plat::NPUDeviceContext, int64_t>,
ops::SequenceMaskNPUKernel<plat::NPUDeviceContext, float>,
ops::SequenceMaskNPUKernel<plat::NPUDeviceContext, double>);
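
The kernel composes the mask from six NPU ops: Cast X to int32, Range to build [0, maxlen), ExpandD to the output shape, TileWithAxis to repeat X along a new trailing axis, Less to compare, and a final Cast to out_dtype. A short usage sketch through the public fluid API (assumes a PaddlePaddle build with NPU support; NPUPlace(0) selects the first NPU device, as in the tests below):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[2, 3], dtype='int64')
    y = fluid.layers.sequence_mask(x, maxlen=10, dtype='int64')

exe = fluid.Executor(paddle.NPUPlace(0))
exe.run(startup_prog)
mask, = exe.run(main_prog,
                feed={'x': np.array([[0, 3, 4], [5, 7, 9]], 'int64')},
                fetch_list=[y])
print(mask.shape)  # (2, 3, 10)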
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.framework import convert_np_dtype_to_dtype_, Program, program_guard
paddle.enable_static()
class SequenceMaskTestBase(OpTest):
def set_npu(self):
self.__class__.use_npu = True
def initDefaultParameters(self):
self.op_type = 'sequence_mask'
self.maxlen = 10
self.mask_dtype = 'int64'
self.x = [[0, 3, 4], [5, 7, 9]]
def initParameters(self):
pass
def setUp(self):
self.set_npu()
self.initDefaultParameters()
self.initParameters()
if not isinstance(self.x, np.ndarray):
self.x = np.array(self.x)
self.inputs = {'X': self.x}
self.outputs = {'Y': self.calc_ground_truth_mask()}
self.attrs = {
'maxlen': self.maxlen,
'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype)
}
def calc_ground_truth_mask(self):
maxlen = np.max(self.x) if self.maxlen < 0 else self.maxlen
shape = self.x.shape + (maxlen, )
index_broadcast = np.broadcast_to(
np.reshape(
range(maxlen), newshape=[1] * self.x.ndim + [-1]),
shape=shape)
x_broadcast = np.broadcast_to(
np.reshape(
self.x, newshape=self.x.shape + (-1, )), shape=shape)
return (index_broadcast < x_broadcast).astype(self.mask_dtype)
def test_check_output(self):
self.check_output_with_place(paddle.NPUPlace(0))
class SequenceMaskTest1(SequenceMaskTestBase):
def initParameters(self):
self.mask_dtype = 'bool'
class SequenceMaskTest2(SequenceMaskTestBase):
def initParameters(self):
self.mask_dtype = 'uint8'
class SequenceMaskTest3(SequenceMaskTestBase):
def initParameters(self):
self.mask_dtype = 'int32'
class SequenceMaskTest4(SequenceMaskTestBase):
def initParameters(self):
self.mask_dtype = 'float32'
class SequenceMaskTest5(SequenceMaskTestBase):
def initParameters(self):
self.mask_dtype = 'float64'
class SequenceMaskTest6(SequenceMaskTestBase):
def initParameters(self):
self.maxlen = -1
class SequenceMaskTestBase_tensor_attr(OpTest):
def set_npu(self):
self.__class__.use_npu = True
def initDefaultParameters(self):
self.op_type = 'sequence_mask'
self.maxlen = 10
self.maxlen_tensor = np.ones((1), 'int32') * 10
self.mask_dtype = 'int64'
self.x = [[0, 3, 4], [5, 7, 9]]
def initParameters(self):
pass
def setUp(self):
self.set_npu()
self.initDefaultParameters()
self.initParameters()
if not isinstance(self.x, np.ndarray):
self.x = np.array(self.x)
self.inputs = {'X': self.x, 'MaxLenTensor': self.maxlen_tensor}
self.outputs = {'Y': self.calc_ground_truth_mask()}
self.attrs = {'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype)}
def calc_ground_truth_mask(self):
maxlen = np.max(self.x) if self.maxlen < 0 else self.maxlen
shape = self.x.shape + (maxlen, )
index_broadcast = np.broadcast_to(
np.reshape(
range(maxlen), newshape=[1] * self.x.ndim + [-1]),
shape=shape)
x_broadcast = np.broadcast_to(
np.reshape(
self.x, newshape=self.x.shape + (-1, )), shape=shape)
return (index_broadcast < x_broadcast).astype(self.mask_dtype)
    def test_check_output(self):
        self.check_output_with_place(paddle.NPUPlace(0))
class SequenceMaskTest1_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'bool'
class SequenceMaskTest2_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'uint8'
class SequenceMaskTest3_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'int32'
class SequenceMaskTest4_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'float32'
class SequenceMaskTest5_tensor_attr(SequenceMaskTestBase_tensor_attr):
def initParameters(self):
self.mask_dtype = 'float64'
class TestSequenceMaskOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
input_data = np.random.uniform(1, 5, [4]).astype("float32")
def test_Variable():
# the input must be Variable
fluid.layers.sequence_mask(input_data, maxlen=4)
self.assertRaises(TypeError, test_Variable)
if __name__ == '__main__':
unittest.main()