From f6595811a17952de335da307715a9cf073bc1b7d Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 14 Sep 2018 08:13:57 +0000 Subject: [PATCH] Get sequence length in sequence_pad op & fix sequence_mask op --- paddle/fluid/framework/op_proto_maker.cc | 4 +- paddle/fluid/framework/op_proto_maker.h | 1 + paddle/fluid/framework/operator.cc | 61 ++++++++++++++----- paddle/fluid/operators/sequence_mask_op.cc | 6 +- paddle/fluid/operators/sequence_mask_op.cu | 6 +- paddle/fluid/operators/sequence_pad_op.cc | 38 ++++++++++-- paddle/fluid/operators/sequence_pad_op.h | 10 +++ paddle/fluid/operators/top_k_op.cc | 2 + paddle/fluid/pybind/const_value.cc | 3 + python/paddle/fluid/framework.py | 5 ++ python/paddle/fluid/layers/nn.py | 15 +++-- .../tests/unittests/test_operator_desc.py | 2 +- .../tests/unittests/test_sequence_pad_op.py | 3 +- 13 files changed, 126 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 4fa047bf3e..f6a9a71095 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -131,7 +131,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, AddAttr(OpNamescopeAttrName(), "Operator name with namesope.") .SetDefault(""); - + AddAttr>(OpCreationCallstackAttrName(), + "Callstack for Op Creatation.") + .SetDefault({}); Validate(); } diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 18827385ad..c8386263bd 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -40,6 +40,7 @@ class OpProtoAndCheckerMaker { static const char *OpRoleAttrName() { return "op_role"; } static const char *OpRoleVarAttrName() { return "op_role_var"; } static const char *OpNamescopeAttrName() { return "op_namescope"; } + static const char *OpCreationCallstackAttrName() { return "op_callstack"; } void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d58d6e4f3e..ee119aa362 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include - +#include "paddle/fluid/framework/operator.h" #include - +#include +#include +#include +#include "gflags/gflags.h" +#include "glog/logging.h" #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/platform/profiler.h" @@ -137,19 +139,48 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { } void OperatorBase::Run(const Scope& scope, const platform::Place& place) { - VLOG(4) << place << " " << DebugStringEx(&scope); - if (platform::is_gpu_place(place)) { + try { + if (VLOG_IS_ON(4)) { + VLOG(4) << place << " " << DebugStringEx(&scope); + } + if (platform::is_gpu_place(place)) { #ifndef PADDLE_WITH_CUDA - PADDLE_THROW("Cannot run operator on place %s", place); + PADDLE_THROW("Cannot run operator on place %s", place); #else - auto dev_id = boost::get(place).device; - platform::SetDeviceId(dev_id); + auto dev_id = boost::get(place).device; + platform::SetDeviceId(dev_id); #endif + } + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::RecordEvent record_event(Type(), pool.Get(place)); + RunImpl(scope, place); + if (VLOG_IS_ON(3)) { + VLOG(3) << place << " " << DebugStringEx(&scope); + } + } catch (platform::EnforceNotMet exception) { + if (Attrs().count("sub_block") != 0) { + throw exception; + } + + auto& callstack = Attr>( + OpProtoAndCheckerMaker::OpCreationCallstackAttrName()); + + if (callstack.empty()) { + throw exception; + } + std::ostringstream sout; + sout << "Invoke operator " << Type() << " error.\n"; + sout << "Python Callstacks: \n"; + for (auto& line : callstack) { + sout << line; + } + sout << "C++ Callstacks: \n"; + sout << exception.err_str_; + exception.err_str_ = sout.str(); + throw exception; + } catch (...) { + std::rethrow_exception(std::current_exception()); } - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::RecordEvent record_event(Type(), pool.Get(place)); - RunImpl(scope, place); - VLOG(3) << place << " " << DebugStringEx(&scope); } bool OperatorBase::HasInputs(const std::string& name) const { @@ -177,7 +208,7 @@ const std::vector& OperatorBase::Inputs( } bool OperatorBase::HasOutputs(const std::string& name) const { - if (outputs_.find(name) != outputs_.end()) { + if (outputs_.end() != outputs_.find(name)) { return true; } else { return false; diff --git a/paddle/fluid/operators/sequence_mask_op.cc b/paddle/fluid/operators/sequence_mask_op.cc index e45c18d6af..798211f481 100644 --- a/paddle/fluid/operators/sequence_mask_op.cc +++ b/paddle/fluid/operators/sequence_mask_op.cc @@ -23,4 +23,8 @@ REGISTER_OP_CPU_KERNEL( paddle::operators::SequenceMaskKernel, paddle::operators::SequenceMaskKernel); + int64_t>, + paddle::operators::SequenceMaskKernel, + paddle::operators::SequenceMaskKernel); diff --git a/paddle/fluid/operators/sequence_mask_op.cu b/paddle/fluid/operators/sequence_mask_op.cu index ff5acf4d9e..2ad2377457 100644 --- a/paddle/fluid/operators/sequence_mask_op.cu +++ b/paddle/fluid/operators/sequence_mask_op.cu @@ -19,4 +19,8 @@ REGISTER_OP_CUDA_KERNEL( paddle::operators::SequenceMaskKernel, paddle::operators::SequenceMaskKernel); + int64_t>, + paddle::operators::SequenceMaskKernel, + paddle::operators::SequenceMaskKernel); diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc index 44d73aa407..4583b26256 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -29,10 +29,12 @@ class SequencePadOp : public framework::OperatorWithKernel { "Input(PadValue) of SequencePadOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SequencePadOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Length"), + "Output(Length) of SequencePadOp should not be null."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_GE(x_dims.size(), 2, - "The rank of Input(x) can't be less than 2."); + "The rank of Input(X) can't be less than 2."); auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size()); auto pad_value_dims = ctx->GetInputDim("PadValue"); PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) || @@ -41,8 +43,8 @@ class SequencePadOp : public framework::OperatorWithKernel { "shape equals to time steps in sequences"); int out_dim_0 = -1; - int out_dim_1 = -1; + int padded_length = ctx->Attrs().Get("padded_length"); if (ctx->IsRuntime()) { // run time framework::Variable* x_var = @@ -58,7 +60,6 @@ class SequencePadOp : public framework::OperatorWithKernel { int seq_num = x_lod_0.size() - 1; int max_seq_len = math::MaximumSequenceLength(x_lod_0); - int padded_length = ctx->Attrs().Get("padded_length"); if (padded_length == -1) { padded_length = max_seq_len; } @@ -66,19 +67,30 @@ class SequencePadOp : public framework::OperatorWithKernel { "The Attr(padded_length) must be -1 or an int greater " "than the length of the longest original sequence."); out_dim_0 = seq_num; - out_dim_1 = padded_length; } else { // compile time + if (padded_length == -1) { + padded_length = 1; + } framework::VarDesc* x_desc = boost::get(ctx->GetInputVarPtrs("X")[0]); PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1); } - std::vector out_dims_vec{out_dim_0, out_dim_1}; + std::vector out_dims_vec{out_dim_0, padded_length}; + std::vector len_dims_vec{out_dim_0, 1}; auto time_step_dims_vec = framework::vectorize2int(time_step_dims); out_dims_vec.insert(out_dims_vec.end(), time_step_dims_vec.begin(), time_step_dims_vec.end()); ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec)); + ctx->SetOutputDim("Length", framework::make_ddim(len_dims_vec)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -96,6 +108,10 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput( "Out", "(LoDTensor) The output vairable, which contains padded sequences."); + AddOutput( + "Length", + "(LoDTensor) The output vairable, which contains the actual length of " + "sequences before padding."); AddAttr( "padded_length", "The length of padded sequences. It can be setted to -1 or " @@ -125,6 +141,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { then we get LoDTensor: Out.data = [[a, b, 0, 0], [c, d, e, 0]] + Length.data = [[2], [3]] Case 2: @@ -138,7 +155,8 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { then we get LoDTensor: Out.data = [[[a1, a2], [b1, b2], [0, 0]], [[c1, c2], [d1, d2], [e1, e2]]] - + Length.data = [[2], [3]] + Case 3: Given a 1-level LoDTensor input(X): @@ -151,6 +169,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker { then we get LoDTensor: Out.data = [[[a1, a2], [b1, b2], [p1, p2]], [[c1, c2], [d1, d2], [e1, e2]]] + Length.data = [[2], [3]] )DOC"); } @@ -171,6 +190,13 @@ class SequencePadGradOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); } } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + return framework::OpKernelType(data_type, ctx.device_context()); + } }; } // namespace operators diff --git a/paddle/fluid/operators/sequence_pad_op.h b/paddle/fluid/operators/sequence_pad_op.h index 5fc9da69d7..840bd39a7f 100644 --- a/paddle/fluid/operators/sequence_pad_op.h +++ b/paddle/fluid/operators/sequence_pad_op.h @@ -32,6 +32,7 @@ class SequencePadOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { const auto* x = ctx.Input("X"); auto* out = ctx.Output("Out"); + auto* len_t = ctx.Output("Length"); out->mutable_data(ctx.GetPlace()); const auto* pad_value = ctx.Input("PadValue"); @@ -41,6 +42,15 @@ class SequencePadOpKernel : public framework::OpKernel { math::PaddingLoDTensorFunctor()( ctx.template device_context(), *x, out, *pad_value, padded_length, 0, false, math::kBatchLengthWidth); + + LoDTensor seq_len; + seq_len.Resize(len_t->dims()); + int64_t* len_data = seq_len.mutable_data(platform::CPUPlace()); + for (size_t i = 1; i < x->lod()[0].size(); ++i) { + len_data[i - 1] = x->lod()[0][i] - x->lod()[0][i - 1]; + } + framework::TensorCopy(seq_len, ctx.GetPlace(), + ctx.template device_context(), len_t); } }; diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index 4a8ac441cf..92a0697e27 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -30,6 +30,8 @@ class TopkOp : public framework::OperatorWithKernel { "Output(Indices) of TopkOp should not be null."); auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(input_dims.size(), 2, + "Rank of TopK op's input must be 2."); const int k = static_cast(ctx->Attrs().Get("k")); PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index f577068d1f..b4af7ed826 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -46,6 +46,9 @@ void BindConstValue(pybind11::module* m) { op_proto_and_checker_maker.def( "kOpNameScopeAttrName", framework::OpProtoAndCheckerMaker::OpNamescopeAttrName); + op_proto_and_checker_maker.def( + "kOpCreationCallstackAttrName", + framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName); } } // namespace pybind diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b0e0d27ff7..633d2334cf 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -18,6 +18,7 @@ import collections import contextlib import re import six +import traceback import numpy as np @@ -572,6 +573,10 @@ class Operator(object): if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0: del op_attrs[role_var_name] + callstack_var_name = op_maker.kOpCreationCallstackAttrName() + op_attrs[callstack_var_name] = list( + reversed(traceback.format_stack()))[1:] + if len(self.desc.type()) != 0: return if type is None: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8408e6d2a1..8af1f7b1bd 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2680,7 +2680,8 @@ def sequence_pad(x, pad_value, maxlen=None): longest original sequence." Returns: - Variable: The padded sequence batch. All sequences has the same length. + Variable: The padded sequence batch and the original lengths before + padding. All sequences has the same length. Examples: .. code-block:: python @@ -2696,15 +2697,21 @@ def sequence_pad(x, pad_value, maxlen=None): helper = LayerHelper('sequence_pad', input=x, **locals()) dtype = helper.input_dtype() out = helper.create_tmp_variable(dtype) + length = helper.create_tmp_variable(dtype) + + pad_value.stop_gradient = True + length.stop_gradient = True + if maxlen is None: maxlen = -1 helper.append_op( type='sequence_pad', inputs={'X': x, 'PadValue': pad_value}, - outputs={'Out': out}, + outputs={'Out': out, + 'Length': length}, attrs={'padded_length': maxlen}) - return out + return out, length def beam_search(pre_ids, @@ -5913,7 +5920,7 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): inputs={'X': [x]}, outputs={'Y': out}, attrs={ - 'max_len': maxlen if maxlen is not None else -1, + 'maxlen': maxlen if maxlen is not None else -1, 'out_dtype': out.dtype }) return out diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py index cac132e6e0..e0c9b28b3a 100644 --- a/python/paddle/fluid/tests/unittests/test_operator_desc.py +++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py @@ -69,7 +69,7 @@ class TestOperator(unittest.TestCase): set(mul_op.attr_names), set([ "x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var", - "op_namescope" + "op_namescope", "op_callstack" ])) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) diff --git a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py index 471515c817..3ac7371aa7 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py @@ -62,7 +62,8 @@ class TestSequencePadOp(OpTest): start_idx = end_idx out_data = np.array(padded_sequences) - self.outputs = {'Out': out_data} + length = np.array(self.x_len_lod[0]) + self.outputs = {'Out': out_data, 'Length': length} def setUp(self): self.op_type = 'sequence_pad' -- GitLab