提交 e895c98f 编写于 作者: S sneaxiy

add support to max_len is None

上级 64464cb1
......@@ -162,7 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'max_len', 'mask_dtype'], varargs=None, keywords=None, defaults=('int64',))
paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
......
......@@ -14,6 +14,14 @@
#pragma once
#ifdef __NVCC__
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#else
#include <algorithm>
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -26,50 +34,60 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
auto max_len = ctx->Attrs().Get<int>("max_len");
PADDLE_ENFORCE_GT(max_len, 1, "Attr(max_len) must be larger than 1");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
dim.push_back(max_len);
ctx->SetOutputDim("Y", framework::make_ddim(dim));
auto maxlen = ctx->Attrs().Get<int>("maxlen");
if (maxlen > 0) { // We can only infershape when maxlen > 0
auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
dim.push_back(maxlen);
ctx->SetOutputDim("Y", framework::make_ddim(dim));
}
}
};
class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of sequence_mask op.");
AddInput("X", "The input tensor of sequence_mask op.");
AddOutput("Y", "The output mask of sequence_mask op.");
AddAttr<int>("max_len", "The maximum length of the sequence.")
.GreaterThan(1);
AddAttr<int>("maxlen",
"The maximum length of the sequence. If maxlen < 0, maxlen "
"= max(Input(X)).")
.SetDefault(-1)
.AddCustomChecker([](int &v) {
PADDLE_ENFORCE(v < 0 || v >= 1,
"Attr(maxlen) must be less than 0 or larger than 1");
});
AddAttr<int>("out_dtype", "Output data type");
AddComment(R"DOC(
SequenceMask Operator
This operator outputs a Mask according to Input(X) and Attr(max_len).
This operator outputs a Mask according to Input(X) and Attr(maxlen).
Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the
Output(Y) is a mask with shape [d_1, d_2, ..., d_n, max_len], where:
Output(Y) is a mask with shape [d_1, d_2, ..., d_n, maxlen], where:
Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n))
If maxlen < 0, maxlen = max(X)
)DOC");
}
};
template <typename Tx, typename Ty>
struct SequenceMaskForRangeFunctor {
HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int max_len)
: x_(x), y_(y), max_len_(max_len) {}
HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int maxlen)
: x_(x), y_(y), maxlen_(maxlen) {}
HOSTDEVICE void operator()(int y_idx) const {
int x_idx = y_idx / max_len_;
int j = y_idx % max_len_;
int x_idx = y_idx / maxlen_;
int j = y_idx % maxlen_;
y_[y_idx] = static_cast<Ty>(j < x_[x_idx] ? 1 : 0);
}
private:
const Tx *x_;
Ty *y_;
int max_len_;
int maxlen_;
};
template <typename DeviceContext, typename Tx>
......@@ -77,14 +95,14 @@ struct SequenceMaskFunctor {
using Tensor = framework::LoDTensor;
SequenceMaskFunctor(const DeviceContext &ctx, const Tx *x, Tensor *y,
int limits, int max_len)
: ctx_(ctx), x_(x), y_(y), limits_(limits), max_len_(max_len) {}
int limits, int maxlen)
: ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {}
template <typename Ty>
void operator()() const {
auto *y_data = y_->mutable_data<Ty>(ctx_.GetPlace());
platform::ForRange<DeviceContext> for_range(ctx_, limits_);
for_range(SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, max_len_));
for_range(SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, maxlen_));
}
private:
......@@ -92,7 +110,7 @@ struct SequenceMaskFunctor {
const Tx *x_;
Tensor *y_;
int limits_;
int max_len_;
int maxlen_;
};
template <typename DeviceContext, typename Tx>
......@@ -103,13 +121,32 @@ class SequenceMaskKernel : public framework::OpKernel<Tx> {
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
auto max_len = ctx.Attr<int>("max_len");
auto maxlen = ctx.Attr<int>("maxlen");
auto *x_data = x->data<Tx>();
auto x_numel = x->numel();
if (maxlen < 0) {
#ifdef __NVCC__
VLOG(10)
<< "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
maxlen = static_cast<int>(
thrust::reduce(thrust::device_pointer_cast(x_data),
thrust::device_pointer_cast(x_data) + x_numel,
static_cast<Tx>(0), thrust::maximum<Tx>()));
#else
maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
#endif
auto y_dim = framework::vectorize2int(x->dims());
y_dim.push_back(maxlen);
y->Resize(framework::make_ddim(y_dim));
}
auto out_dtype = static_cast<framework::proto::VarType::Type>(
ctx.Attr<int>("out_dtype"));
auto &dev_ctx = ctx.template device_context<DeviceContext>();
framework::VisitDataType(out_dtype, SequenceMaskFunctor<DeviceContext, Tx>(
dev_ctx, x->data<Tx>(), y,
x->numel() * max_len, max_len));
framework::VisitDataType(out_dtype,
SequenceMaskFunctor<DeviceContext, Tx>(
dev_ctx, x_data, y, x_numel * maxlen, maxlen));
}
};
......
......@@ -5525,13 +5525,46 @@ def flatten(x, axis=1, name=None):
return out
def sequence_mask(x, max_len, mask_dtype='int64'):
def sequence_mask(x, maxlen=None, dtype='int64', name=None):
"""
**SequenceMask Layer**
This layer outputs a mask according to the input :code:`x` and
:code:`maxlen` with data type of :code:`dtype`.
Supposing :code:`x` is a Tensor with shape [d_1, d_2, ..., d_n], the
:code:`y` is a mask with shape [d_1, d_2, ..., d_n, maxlen], where:
.. math::
y(i_1, i_2,..., i_n, j) = (j < x(i_1, i_2,..., i_n))
Args:
x (Variable): Input tensor of sequence_mask layer,
whose elements are integers less than :code:`maxlen`.
maxlen (int|None): Maximum length of the sequence. If :code:`maxlen`
is None, it would be replace with :math:`max(x)`.
dtype (np.dtype|core.VarDesc.VarType|str): Data type of the output.
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
Returns:
Variable: The output sequence mask.
"""
helper = LayerHelper('sequence_mask', **locals())
y = helper.create_tmp_variable(dtype=mask_dtype)
if name is None:
out = helper.create_tmp_variable(dtype=dtype)
else:
out = helper.create_tmp_variable(dtype=dtype, name=name)
helper.append_op(
type='sequence_mask',
inputs={'X': [x]},
outputs={'Y': y},
attrs={'max_len': max_len,
'out_dtype': y.dtype})
return y
outputs={'Y': out},
attrs={
'max_len': maxlen if maxlen is not None else -1,
'out_dtype': out.dtype
})
return out
......@@ -13,7 +13,9 @@
# limitations under the License.
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid.framework import convert_np_dtype_to_dtype_
import paddle.fluid.core as core
import numpy as np
import copy
import unittest
......@@ -22,7 +24,7 @@ import unittest
class SequenceMaskTestBase(OpTest):
def initDefaultParameters(self):
self.op_type = 'sequence_mask'
self.max_len = 10
self.maxlen = 10
self.mask_dtype = 'int64'
self.x = [[0, 3, 4], [5, 7, 9]]
......@@ -38,15 +40,16 @@ class SequenceMaskTestBase(OpTest):
self.inputs = {'X': self.x}
self.outputs = {'Y': self.calc_ground_truth_mask()}
self.attrs = {
'max_len': self.max_len,
'maxlen': self.maxlen,
'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype)
}
def calc_ground_truth_mask(self):
shape = self.x.shape + (self.max_len, )
maxlen = np.max(self.x) if self.maxlen < 0 else self.maxlen
shape = self.x.shape + (maxlen, )
index_broadcast = np.broadcast_to(
np.reshape(
range(self.max_len), newshape=[1] * self.x.ndim + [-1]),
range(maxlen), newshape=[1] * self.x.ndim + [-1]),
shape=shape)
x_broadcast = np.broadcast_to(
np.reshape(
......@@ -82,5 +85,10 @@ class SequenceMaskTest5(SequenceMaskTestBase):
self.mask_dtype = 'float64'
class SequenceMaskTest6(SequenceMaskTestBase):
def initParameters(self):
self.maxlen = -1
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册