未验证 提交 cdf2579d 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #14053 from jczaja/prv-seqpool-max

Max Sequence pool optimization
...@@ -64,7 +64,7 @@ paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', ' ...@@ -64,7 +64,7 @@ paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', '
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None)) paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None))
paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None)) paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None)) paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
......
...@@ -31,7 +31,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -31,7 +31,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T> template <typename T, bool is_test>
class MaxSeqPoolFunctor { class MaxSeqPoolFunctor {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
...@@ -70,7 +70,41 @@ class MaxSeqPoolFunctor { ...@@ -70,7 +70,41 @@ class MaxSeqPoolFunctor {
} }
} }
}; };
// Instantisation of Max Sequence Pooling for test phase eg. no need to fill
// index buffer
template <typename T>
class MaxSeqPoolFunctor<T, true> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index) {
auto in_dims = input.dims();
auto out_dims = output->dims();
PADDLE_ENFORCE_GT(in_dims.size(), 1);
PADDLE_ENFORCE_GT(out_dims.size(), 1);
for (int64_t i = 1; i < in_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
}
auto starts = input.lod()[0];
const T* in_data = input.data<T>();
T* out_data = output->data<T>();
int64_t num_seq = out_dims[0];
int64_t dim = output->numel() / num_seq;
for (int64_t i = 0; i < num_seq; ++i) {
std::memcpy(&out_data[i * dim], &in_data[starts[i] * dim],
dim * sizeof(T));
for (size_t j = starts[i] + 1; j < starts[i + 1]; ++j) {
for (int64_t k = 0; k < dim; ++k) {
if (in_data[j * dim + k] > out_data[i * dim + k]) {
out_data[i * dim + k] = in_data[j * dim + k];
}
}
}
}
}
};
template <typename T> template <typename T>
class MaxSeqPoolGradFunctor { class MaxSeqPoolGradFunctor {
public: public:
...@@ -188,11 +222,16 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -188,11 +222,16 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
/* max pool has index output */ /* max pool has index output */
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const std::string pooltype, const framework::LoDTensor& input, const std::string pooltype, const framework::LoDTensor& input,
framework::Tensor* output, framework::Tensor* output, bool is_test,
framework::Tensor* index = nullptr) { framework::Tensor* index = nullptr) {
if (pooltype == "MAX") { if (pooltype == "MAX") {
math::MaxSeqPoolFunctor<T> max_pool; if (is_test) {
max_pool(context, input, output, index); math::MaxSeqPoolFunctor<T, true> max_pool;
max_pool(context, input, output, index);
} else {
math::MaxSeqPoolFunctor<T, false> max_pool;
max_pool(context, input, output, index);
}
return; return;
} }
if (pooltype == "LAST") { if (pooltype == "LAST") {
...@@ -200,6 +239,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -200,6 +239,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
last_pool(context, input, output); last_pool(context, input, output);
return; return;
} }
if (pooltype == "FIRST") { if (pooltype == "FIRST") {
math::FirstSeqPoolFunctor<T> first_pool; math::FirstSeqPoolFunctor<T> first_pool;
first_pool(context, input, output); first_pool(context, input, output);
......
...@@ -133,7 +133,7 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> { ...@@ -133,7 +133,7 @@ class SequencePoolFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const std::string pooltype, const framework::LoDTensor& input, const std::string pooltype, const framework::LoDTensor& input,
framework::Tensor* output, framework::Tensor* output, bool is_test,
framework::Tensor* index = nullptr) { framework::Tensor* index = nullptr) {
auto& lod = input.lod()[0]; auto& lod = input.lod()[0];
const size_t item_dim = output->numel() / output->dims()[0]; const size_t item_dim = output->numel() / output->dims()[0];
......
...@@ -28,7 +28,7 @@ class SequencePoolFunctor { ...@@ -28,7 +28,7 @@ class SequencePoolFunctor {
/* max pool has index output */ /* max pool has index output */
void operator()(const DeviceContext& context, const std::string pooltype, void operator()(const DeviceContext& context, const std::string pooltype,
const framework::LoDTensor& input, framework::Tensor* output, const framework::LoDTensor& input, framework::Tensor* output,
framework::Tensor* index = nullptr); bool is_test = false, framework::Tensor* index = nullptr);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
......
...@@ -47,6 +47,7 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -47,6 +47,7 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor<int>) This tensor is used for the sequence max-pooling " "(Tensor<int>) This tensor is used for the sequence max-pooling "
"to record the max indexes.") "to record the max indexes.")
.AsIntermediate(); .AsIntermediate();
AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"pooltype", "pooltype",
"(string, default 'AVERAGE') the pooling pooltype of SequencePoolOp.") "(string, default 'AVERAGE') the pooling pooltype of SequencePoolOp.")
......
...@@ -32,10 +32,6 @@ class SequencePoolKernel : public framework::OpKernel<T> { ...@@ -32,10 +32,6 @@ class SequencePoolKernel : public framework::OpKernel<T> {
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<Tensor>("Out"); auto* out = context.Output<Tensor>("Out");
std::string pooltype = context.Attr<std::string>("pooltype"); std::string pooltype = context.Attr<std::string>("pooltype");
Tensor* index = nullptr;
if (pooltype == "MAX") {
index = context.Output<Tensor>("MaxIndex");
}
auto dims = in->dims(); auto dims = in->dims();
auto lod = in->lod(); auto lod = in->lod();
...@@ -48,13 +44,22 @@ class SequencePoolKernel : public framework::OpKernel<T> { ...@@ -48,13 +44,22 @@ class SequencePoolKernel : public framework::OpKernel<T> {
dims[0] = lod[0].size() - 1; dims[0] = lod[0].size() - 1;
out->Resize({dims}); out->Resize({dims});
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
if (pooltype == "MAX") { Tensor* index = nullptr;
const bool is_test = context.Attr<bool>("is_test");
// Do not create index buffer for inference (is_test) mode
// TODO(jczaja): Skip index buffer creation for other devices eg. GPU
if (pooltype == "MAX" &&
(is_test == false ||
platform::is_cpu_place(context.GetPlace()) == false)) {
index = context.Output<Tensor>("MaxIndex");
index->Resize({dims}); index->Resize({dims});
index->mutable_data<int>(context.GetPlace()); index->mutable_data<int>(context.GetPlace());
} }
math::SequencePoolFunctor<DeviceContext, T> pool; math::SequencePoolFunctor<DeviceContext, T> pool;
pool(context.template device_context<DeviceContext>(), pooltype, *in, out, pool(context.template device_context<DeviceContext>(), pooltype, *in, out,
index); is_test, index);
} }
}; };
......
...@@ -1825,7 +1825,7 @@ def conv3d(input, ...@@ -1825,7 +1825,7 @@ def conv3d(input,
return helper.append_activation(pre_act) return helper.append_activation(pre_act)
def sequence_pool(input, pool_type): def sequence_pool(input, pool_type, is_test=False):
""" """
This function add the operator for sequence pooling. This function add the operator for sequence pooling.
It pools features of all time-steps of each instance, and is applied It pools features of all time-steps of each instance, and is applied
...@@ -1862,6 +1862,7 @@ def sequence_pool(input, pool_type): ...@@ -1862,6 +1862,7 @@ def sequence_pool(input, pool_type):
input(variable): The input variable which is a LoDTensor. input(variable): The input variable which is a LoDTensor.
pool_type (string): The pooling type of sequence_pool. pool_type (string): The pooling type of sequence_pool.
It supports average, sum, sqrt and max. It supports average, sum, sqrt and max.
is_test(bool, Default False): Used distinguish training from scoring mode.
Returns: Returns:
The sequence pooling variable which is a Tensor. The sequence pooling variable which is a Tensor.
...@@ -1889,7 +1890,8 @@ def sequence_pool(input, pool_type): ...@@ -1889,7 +1890,8 @@ def sequence_pool(input, pool_type):
inputs={"X": input}, inputs={"X": input},
outputs={"Out": pool_out, outputs={"Out": pool_out,
"MaxIndex": max_index}, "MaxIndex": max_index},
attrs={"pooltype": pool_type.upper()}) attrs={"pooltype": pool_type.upper(),
"is_test": is_test})
# when pool_type is max, variable max_index is initialized, # when pool_type is max, variable max_index is initialized,
# so we stop the gradient explicitly here # so we stop the gradient explicitly here
......
...@@ -184,6 +184,20 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D): ...@@ -184,6 +184,20 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D):
out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11)) out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
class TestSeqMaxPool2DInference(TestSeqMaxPool2D):
def compute(self, x, offset, out):
self.attrs = {'pooltype': "MAX", 'is_test': True}
for i in range(len(offset[0]) - 1):
sub_x = np.reshape(x[offset[0][i]:offset[0][i + 1], :],
(-1, 3 * 11))
out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
def test_check_grad(self):
"""Grad computation does not apply to Sequence MAX
Pool executed when is_test is true """
return
class TestSeqLastPool2D(TestSeqAvgPool2D): class TestSeqLastPool2D(TestSeqAvgPool2D):
def compute(self, x, offset, out): def compute(self, x, offset, out):
self.attrs = {'pooltype': "LAST"} self.attrs = {'pooltype': "LAST"}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册