From 79918a84429d7dab4eff9487002a7eb01d4f2aaf Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Wed, 22 Aug 2018 20:06:48 +0800
Subject: [PATCH] add sequence_mask_op for DAM model

---
 paddle/fluid/API.spec                         |   1 +
 paddle/fluid/operators/batch_norm_op.cc       |   2 +-
 paddle/fluid/operators/sequence_mask_op.cc    |  26 ++++
 paddle/fluid/operators/sequence_mask_op.cu    |  22 ++++
 paddle/fluid/operators/sequence_mask_op.h     | 117 ++++++++++++++++++
 python/paddle/fluid/layers/nn.py              |  22 +++-
 python/paddle/fluid/nets.py                   |   2 +-
 .../tests/book/test_image_classification.py   |   5 +-
 .../tests/unittests/test_sequence_mask.py     |  86 +++++++++++++
 9 files changed, 278 insertions(+), 5 deletions(-)
 create mode 100644 paddle/fluid/operators/sequence_mask_op.cc
 create mode 100644 paddle/fluid/operators/sequence_mask_op.cu
 create mode 100644 paddle/fluid/operators/sequence_mask_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_sequence_mask.py

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 9250cde1b2..359db26ed6 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs
 paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
+paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'max_len', 'mask_dtype'], varargs=None, keywords=None, defaults=('int64',))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index 5912a1a17c..969f75544f 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Variance",
              "The global variance (for training) "
              "or estimated Variance (for testing)");
-    AddOutput("Y", "result after normalization").Reuse("X");
+    AddOutput("Y", "result after normalization");
     AddOutput("MeanOut",
               "Share memory with Mean. "
               "Store the global mean when training")
diff --git a/paddle/fluid/operators/sequence_mask_op.cc b/paddle/fluid/operators/sequence_mask_op.cc
new file mode 100644
index 0000000000..e45c18d6af
--- /dev/null
+++ b/paddle/fluid/operators/sequence_mask_op.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/sequence_mask_op.h"
+
+REGISTER_OPERATOR(sequence_mask, paddle::operators::SequenceMaskOp,
+                  paddle::operators::SequenceMaskOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+
+REGISTER_OP_CPU_KERNEL(
+    sequence_mask,
+    paddle::operators::SequenceMaskKernel<paddle::platform::CPUDeviceContext,
+                                          int>,
+    paddle::operators::SequenceMaskKernel<paddle::platform::CPUDeviceContext,
+                                          int64_t>);
diff --git a/paddle/fluid/operators/sequence_mask_op.cu b/paddle/fluid/operators/sequence_mask_op.cu
new file mode 100644
index 0000000000..ff5acf4d9e
--- /dev/null
+++ b/paddle/fluid/operators/sequence_mask_op.cu
@@ -0,0 +1,22 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/sequence_mask_op.h"
+
+REGISTER_OP_CUDA_KERNEL(
+    sequence_mask,
+    paddle::operators::SequenceMaskKernel<paddle::platform::CUDADeviceContext,
+                                          int>,
+    paddle::operators::SequenceMaskKernel<paddle::platform::CUDADeviceContext,
+                                          int64_t>);
diff --git a/paddle/fluid/operators/sequence_mask_op.h b/paddle/fluid/operators/sequence_mask_op.h
new file mode 100644
index 0000000000..237857b51d
--- /dev/null
+++ b/paddle/fluid/operators/sequence_mask_op.h
@@ -0,0 +1,117 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceMaskOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
+    auto max_len = ctx->Attrs().Get<int>("max_len");
+    PADDLE_ENFORCE_GT(max_len, 1, "Attr(max_len) must be larger than 1");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
+    auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
+    dim.push_back(max_len);
+    ctx->SetOutputDim("Y", framework::make_ddim(dim));
+  }
+};
+
+class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input of sequence_mask op.");
+    AddOutput("Y", "The output mask of sequence_mask op.");
+    AddAttr<int>("max_len", "The maximum length of the sequence.")
+        .GreaterThan(1);
+    AddAttr<int>("out_dtype", "Output data type");
+    AddComment(R"DOC(
+SequenceMask Operator
+
+This operator outputs a Mask according to Input(X) and Attr(max_len).
+Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the
+Output(Y) is a mask with shape [d_1, d_2, ..., d_n, max_len], where:
+
+Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n))
+    )DOC");
+  }
+};
+
+template <typename Tx, typename Ty>
+struct SequenceMaskForRangeFunctor {
+  HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int max_len)
+      : x_(x), y_(y), max_len_(max_len) {}
+
+  HOSTDEVICE void operator()(int y_idx) const {
+    int x_idx = y_idx / max_len_;
+    int j = y_idx % max_len_;
+    y_[y_idx] = static_cast<Ty>(j < x_[x_idx] ? 1 : 0);
+  }
+
+ private:
+  const Tx *x_;
+  Ty *y_;
+  int max_len_;
+};
+
+template <typename DeviceContext, typename Tx>
+struct SequenceMaskFunctor {
+  using Tensor = framework::LoDTensor;
+
+  SequenceMaskFunctor(const DeviceContext &ctx, const Tx *x, Tensor *y,
+                      int limits, int max_len)
+      : ctx_(ctx), x_(x), y_(y), limits_(limits), max_len_(max_len) {}
+
+  template <typename Ty>
+  void operator()() const {
+    auto *y_data = y_->mutable_data<Ty>(ctx_.GetPlace());
+    platform::ForRange<DeviceContext> for_range(ctx_, limits_);
+    for_range(SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, max_len_));
+  }
+
+ private:
+  const DeviceContext &ctx_;
+  const Tx *x_;
+  Tensor *y_;
+  int limits_;
+  int max_len_;
+};
+
+template <typename DeviceContext, typename Tx>
+class SequenceMaskKernel : public framework::OpKernel<Tx> {
+  using Tensor = framework::LoDTensor;
+
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *x = ctx.Input<Tensor>("X");
+    auto *y = ctx.Output<Tensor>("Y");
+    auto max_len = ctx.Attr<int>("max_len");
+    auto out_dtype = static_cast<framework::proto::VarType::Type>(
+        ctx.Attr<int>("out_dtype"));
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    framework::VisitDataType(out_dtype, SequenceMaskFunctor<DeviceContext, Tx>(
+                                            dev_ctx, x->data<Tx>(), y,
+                                            x->numel() * max_len, max_len));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 71592618f5..1fe457452f 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -27,6 +27,7 @@ from . import utils
 import random
 from .. import unique_name
 from functools import reduce
+import warnings

 __all__ = [
     'fc',
@@ -103,6 +104,7 @@ __all__ = [
     'rank_loss',
     'prelu',
     'flatten',
+    'sequence_mask',
 ]


@@ -2046,7 +2048,7 @@ def batch_norm(input,
         param_attr(ParamAttr): The parameter attribute for Parameter `scale`.
         bias_attr(ParamAttr): The parameter attribute for Parameter `bias`.
         data_layout(string, default NCHW): NCHW|NHWC
-        in_place(bool, Default False): Make the input and output of batch norm reuse memory.
+        in_place(bool, Default False): This argument is deprecated since 0.15.0.
         use_mkldnn(bool, Default false): ${use_mkldnn_comment}
         name(string, Default None): A name for this layer(optional). If set None, the layer
             will be named automatically.
@@ -2068,6 +2070,10 @@ def batch_norm(input,
     helper = LayerHelper('batch_norm', **locals())
     dtype = helper.input_dtype()

+    if in_place:
+        warnings.warn("The argument in_place is deprecated since 0.15.0, "
+                      "please do not set it to True.")
+
     input_shape = input.shape
     if data_layout == 'NCHW':
         channel_num = input_shape[1]
@@ -2117,7 +2123,7 @@ def batch_norm(input,
     saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
     saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)

-    batch_norm_out = input if in_place else helper.create_tmp_variable(dtype)
+    batch_norm_out = helper.create_tmp_variable(dtype)

     helper.append_op(
         type="batch_norm",
@@ -5517,3 +5523,15 @@ def flatten(x, axis=1, name=None):
         outputs={'Out': out},
         attrs={"axis": axis})
     return out
+
+
+def sequence_mask(x, max_len, mask_dtype='int64'):
+    helper = LayerHelper('sequence_mask', **locals())
+    y = helper.create_tmp_variable(dtype=mask_dtype)
+    helper.append_op(
+        type='sequence_mask',
+        inputs={'X': [x]},
+        outputs={'Y': y},
+        attrs={'max_len': max_len,
+               'out_dtype': y.dtype})
+    return y
diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py
index 051fe84364..01563cbbb7 100644
--- a/python/paddle/fluid/nets.py
+++ b/python/paddle/fluid/nets.py
@@ -229,7 +229,7 @@ def img_conv_group(input,
             use_mkldnn=use_mkldnn)

         if conv_with_batchnorm[i]:
-            tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True)
+            tmp = layers.batch_norm(input=tmp, act=conv_act)
             drop_rate = conv_batchnorm_drop_rate[i]
             if abs(drop_rate) > 1e-5:
                 tmp = layers.dropout(x=tmp, dropout_prob=drop_rate)
diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py
index 9fe361425c..cd1e8cd682 100644
--- a/python/paddle/fluid/tests/book/test_image_classification.py
+++ b/python/paddle/fluid/tests/book/test_image_classification.py
@@ -256,7 +256,10 @@ def main(net_type, use_cuda, is_local=True):
     save_dirname = "image_classification_" + net_type + ".inference.model"

     train(net_type, use_cuda, save_dirname, is_local)
-    infer(use_cuda, save_dirname)
+
+    # There is a bug in fluid.InferenceTranspiler for VGG.
+    if net_type == "resnet":
+        infer(use_cuda, save_dirname)


 class TestImageClassification(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/test_sequence_mask.py b/python/paddle/fluid/tests/unittests/test_sequence_mask.py
new file mode 100644
index 0000000000..c6d09df984
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_sequence_mask.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from op_test import OpTest
+from paddle.fluid.framework import convert_np_dtype_to_dtype_
+import numpy as np
+import copy
+import unittest
+
+
+class SequenceMaskTestBase(OpTest):
+    def initDefaultParameters(self):
+        self.op_type = 'sequence_mask'
+        self.max_len = 10
+        self.mask_dtype = 'int64'
+        self.x = [[0, 3, 4], [5, 7, 9]]
+
+    def initParameters(self):
+        pass
+
+    def setUp(self):
+        self.initDefaultParameters()
+        self.initParameters()
+        if not isinstance(self.x, np.ndarray):
+            self.x = np.array(self.x)
+
+        self.inputs = {'X': self.x}
+        self.outputs = {'Y': self.calc_ground_truth_mask()}
+        self.attrs = {
+            'max_len': self.max_len,
+            'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype)
+        }
+
+    def calc_ground_truth_mask(self):
+        shape = self.x.shape + (self.max_len, )
+        index_broadcast = np.broadcast_to(
+            np.reshape(
+                range(self.max_len), newshape=[1] * self.x.ndim + [-1]),
+            shape=shape)
+        x_broadcast = np.broadcast_to(
+            np.reshape(
+                self.x, newshape=self.x.shape + (-1, )), shape=shape)
+        return (index_broadcast < x_broadcast).astype(self.mask_dtype)
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class SequenceMaskTest1(SequenceMaskTestBase):
+    def initParameters(self):
+        self.mask_dtype = 'bool'
+
+
+class SequenceMaskTest2(SequenceMaskTestBase):
+    def initParameters(self):
+        self.mask_dtype = 'uint8'
+
+
+class SequenceMaskTest3(SequenceMaskTestBase):
+    def initParameters(self):
+        self.mask_dtype = 'int32'
+
+
+class SequenceMaskTest4(SequenceMaskTestBase):
+    def initParameters(self):
+        self.mask_dtype = 'float32'
+
+
+class SequenceMaskTest5(SequenceMaskTestBase):
+    def initParameters(self):
+        self.mask_dtype = 'float64'
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
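
Reviewer note: below is a minimal usage sketch of the new layer for anyone trying this patch locally. It is illustrative only, not part of the patch: it assumes the 0.15.0-era fluid API as it appears in this diff, and the program/variable names (lengths, mask) are hypothetical.

import numpy as np
import paddle.fluid as fluid

# Lengths for a batch of 3 sequences; append_batch_size=False keeps shape [3].
lengths = fluid.layers.data(
    name='lengths', shape=[3], dtype='int64', append_batch_size=False)
# mask has shape [3, 10]; mask[i][j] == 1 iff j < lengths[i],
# matching Y(i_1, ..., i_n, j) = (j < X(i_1, ..., i_n)) from the op's DOC.
mask = fluid.layers.sequence_mask(x=lengths, max_len=10, mask_dtype='int64')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(fluid.default_main_program(),
               feed={'lengths': np.array([0, 3, 4], dtype='int64')},
               fetch_list=[mask])
# Mirrors the first row of the unit test's input: row 0 is all zeros,
# row 1 begins with three ones, row 2 with four ones.
print(out)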