Commit f6595811 authored by Yibing Liu

Get sequence length in sequence_pad op & fix sequence_mask op

Parent a39eba77
@@ -131,7 +131,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
   AddAttr<std::string>(OpNamescopeAttrName(), "Operator name with namesope.")
       .SetDefault("");
+  AddAttr<std::vector<std::string>>(OpCreationCallstackAttrName(),
+                                    "Callstack for Op Creatation.")
+      .SetDefault({});
   Validate();
 }
...
@@ -40,6 +40,7 @@ class OpProtoAndCheckerMaker {
   static const char *OpRoleAttrName() { return "op_role"; }
   static const char *OpRoleVarAttrName() { return "op_role_var"; }
   static const char *OpNamescopeAttrName() { return "op_namescope"; }
+  static const char *OpCreationCallstackAttrName() { return "op_callstack"; }
 
   void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
...
@@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <gflags/gflags.h>
-#include <glog/logging.h>
+#include "paddle/fluid/framework/operator.h"
 #include <algorithm>
+#include <sstream>
+#include <string>
+#include <vector>
+#include "gflags/gflags.h"
+#include "glog/logging.h"
 #include "paddle/fluid/framework/data_transform.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/shape_inference.h"
 #include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -137,19 +139,48 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
 }
 
 void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
-  VLOG(4) << place << " " << DebugStringEx(&scope);
-  if (platform::is_gpu_place(place)) {
+  try {
+    if (VLOG_IS_ON(4)) {
+      VLOG(4) << place << " " << DebugStringEx(&scope);
+    }
+    if (platform::is_gpu_place(place)) {
 #ifndef PADDLE_WITH_CUDA
-    PADDLE_THROW("Cannot run operator on place %s", place);
+      PADDLE_THROW("Cannot run operator on place %s", place);
 #else
-    auto dev_id = boost::get<platform::CUDAPlace>(place).device;
-    platform::SetDeviceId(dev_id);
+      auto dev_id = boost::get<platform::CUDAPlace>(place).device;
+      platform::SetDeviceId(dev_id);
 #endif
+    }
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    platform::RecordEvent record_event(Type(), pool.Get(place));
+    RunImpl(scope, place);
+    if (VLOG_IS_ON(3)) {
+      VLOG(3) << place << " " << DebugStringEx(&scope);
+    }
+  } catch (platform::EnforceNotMet exception) {
+    if (Attrs().count("sub_block") != 0) {
+      throw exception;
+    }
+    auto& callstack = Attr<std::vector<std::string>>(
+        OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
+    if (callstack.empty()) {
+      throw exception;
+    }
+    std::ostringstream sout;
+    sout << "Invoke operator " << Type() << " error.\n";
+    sout << "Python Callstacks: \n";
+    for (auto& line : callstack) {
+      sout << line;
+    }
+    sout << "C++ Callstacks: \n";
+    sout << exception.err_str_;
+    exception.err_str_ = sout.str();
+    throw exception;
+  } catch (...) {
+    std::rethrow_exception(std::current_exception());
   }
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-  platform::RecordEvent record_event(Type(), pool.Get(place));
-  RunImpl(scope, place);
-  VLOG(3) << place << " " << DebugStringEx(&scope);
 }
 
 bool OperatorBase::HasInputs(const std::string& name) const {
@@ -177,7 +208,7 @@ const std::vector<std::string>& OperatorBase::Inputs(
 }
 
 bool OperatorBase::HasOutputs(const std::string& name) const {
-  if (outputs_.find(name) != outputs_.end()) {
+  if (outputs_.end() != outputs_.find(name)) {
     return true;
   } else {
     return false;
...
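The try/catch added to OperatorBase::Run works together with the new op_callstack attribute: every operator carries the Python stack frames recorded at its creation site, and when an EnforceNotMet escapes, those frames are prepended to the C++ error text before rethrowing. A minimal stdlib-only Python sketch of the same pattern (the Op class and its names are hypothetical, not Paddle's API):

    import traceback

    class Op(object):
        def __init__(self, op_type):
            self.type = op_type
            # Record the creation site, innermost caller first, dropping this
            # __init__ frame -- mirroring what framework.py stores in the new
            # "op_callstack" attribute.
            self.callstack = list(reversed(traceback.format_stack()))[1:]

        def run(self):
            try:
                raise RuntimeError("enforce failed")  # stand-in for EnforceNotMet
            except RuntimeError as e:
                msg = "Invoke operator %s error.\n" % self.type
                msg += "Python Callstacks:\n" + "".join(self.callstack)
                msg += "C++ Callstacks:\n" + str(e)
                raise RuntimeError(msg)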
@@ -23,4 +23,8 @@ REGISTER_OP_CPU_KERNEL(
     paddle::operators::SequenceMaskKernel<paddle::platform::CPUDeviceContext,
                                           int>,
     paddle::operators::SequenceMaskKernel<paddle::platform::CPUDeviceContext,
-                                          int64_t>);
+                                          int64_t>,
+    paddle::operators::SequenceMaskKernel<paddle::platform::CPUDeviceContext,
+                                          float>,
+    paddle::operators::SequenceMaskKernel<paddle::platform::CPUDeviceContext,
+                                          double>);
@@ -19,4 +19,8 @@ REGISTER_OP_CUDA_KERNEL(
     paddle::operators::SequenceMaskKernel<paddle::platform::CUDADeviceContext,
                                           int>,
     paddle::operators::SequenceMaskKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t>);
+                                          int64_t>,
+    paddle::operators::SequenceMaskKernel<paddle::platform::CUDADeviceContext,
+                                          float>,
+    paddle::operators::SequenceMaskKernel<paddle::platform::CUDADeviceContext,
+                                          double>);
@@ -29,10 +29,12 @@ class SequencePadOp : public framework::OperatorWithKernel {
                    "Input(PadValue) of SequencePadOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SequencePadOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Length"),
+                   "Output(Length) of SequencePadOp should not be null.");
 
     auto x_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE_GE(x_dims.size(), 2,
-                      "The rank of Input(x) can't be less than 2.");
+                      "The rank of Input(X) can't be less than 2.");
     auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size());
     auto pad_value_dims = ctx->GetInputDim("PadValue");
     PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) ||
@@ -41,8 +43,8 @@ class SequencePadOp : public framework::OperatorWithKernel {
                    "shape equals to time steps in sequences");
 
     int out_dim_0 = -1;
-    int out_dim_1 = -1;
 
+    int padded_length = ctx->Attrs().Get<int>("padded_length");
     if (ctx->IsRuntime()) {
       // run time
       framework::Variable* x_var =
@@ -58,7 +60,6 @@ class SequencePadOp : public framework::OperatorWithKernel {
       int seq_num = x_lod_0.size() - 1;
       int max_seq_len = math::MaximumSequenceLength(x_lod_0);
-      int padded_length = ctx->Attrs().Get<int>("padded_length");
       if (padded_length == -1) {
         padded_length = max_seq_len;
       }
@@ -66,19 +67,30 @@ class SequencePadOp : public framework::OperatorWithKernel {
                    "The Attr(padded_length) must be -1 or an int greater "
                    "than the length of the longest original sequence.");
       out_dim_0 = seq_num;
-      out_dim_1 = padded_length;
     } else {
       // compile time
+      if (padded_length == -1) {
+        padded_length = 1;
+      }
       framework::VarDesc* x_desc =
           boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("X")[0]);
       PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1);
     }
 
-    std::vector<int> out_dims_vec{out_dim_0, out_dim_1};
+    std::vector<int> out_dims_vec{out_dim_0, padded_length};
+    std::vector<int> len_dims_vec{out_dim_0, 1};
     auto time_step_dims_vec = framework::vectorize2int(time_step_dims);
     out_dims_vec.insert(out_dims_vec.end(), time_step_dims_vec.begin(),
                         time_step_dims_vec.end());
     ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
+    ctx->SetOutputDim("Length", framework::make_ddim(len_dims_vec));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
+    return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
@@ -96,6 +108,10 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput(
         "Out",
         "(LoDTensor) The output vairable, which contains padded sequences.");
+    AddOutput(
+        "Length",
+        "(LoDTensor) The output vairable, which contains the actual length of "
+        "sequences before padding.");
     AddAttr<int>(
         "padded_length",
         "The length of padded sequences. It can be setted to -1 or "
@@ -125,6 +141,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
       then we get LoDTensor:
           Out.data = [[a, b, 0, 0],
                       [c, d, e, 0]]
+          Length.data = [[2], [3]]
 
       Case 2:
@@ -138,7 +155,8 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
       then we get LoDTensor:
           Out.data = [[[a1, a2], [b1, b2], [0, 0]],
                       [[c1, c2], [d1, d2], [e1, e2]]]
+          Length.data = [[2], [3]]
 
       Case 3:
 
       Given a 1-level LoDTensor input(X):
@@ -151,6 +169,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
       then we get LoDTensor:
           Out.data = [[[a1, a2], [b1, b2], [p1, p2]],
                       [[c1, c2], [d1, d2], [e1, e2]]]
+          Length.data = [[2], [3]]
 
 )DOC");
   }
@@ -171,6 +190,13 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
       ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
     }
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
 };
 
 }  // namespace operators
...
@@ -32,6 +32,7 @@ class SequencePadOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto* x = ctx.Input<LoDTensor>("X");
     auto* out = ctx.Output<LoDTensor>("Out");
+    auto* len_t = ctx.Output<LoDTensor>("Length");
     out->mutable_data<T>(ctx.GetPlace());
 
     const auto* pad_value = ctx.Input<LoDTensor>("PadValue");
@@ -41,6 +42,15 @@ class SequencePadOpKernel : public framework::OpKernel<T> {
     math::PaddingLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *x, out, *pad_value,
         padded_length, 0, false, math::kBatchLengthWidth);
+
+    LoDTensor seq_len;
+    seq_len.Resize(len_t->dims());
+    int64_t* len_data = seq_len.mutable_data<int64_t>(platform::CPUPlace());
+    for (size_t i = 1; i < x->lod()[0].size(); ++i) {
+      len_data[i - 1] = x->lod()[0][i] - x->lod()[0][i - 1];
+    }
+    framework::TensorCopy(seq_len, ctx.GetPlace(),
+                          ctx.template device_context<DeviceContext>(), len_t);
   }
 };
...
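The loop added to the kernel derives each sequence length from consecutive level-0 LoD offsets, fills a CPU-side tensor, and copies it to the output place. The same arithmetic in NumPy, using the offsets from Case 1 of the DOC string above:

    import numpy as np

    # Level-0 LoD offsets for two sequences of lengths 2 and 3 (Case 1 above).
    lod0 = np.array([0, 2, 5], dtype=np.int64)

    # Vectorized form of: len_data[i - 1] = lod0[i] - lod0[i - 1]
    length = lod0[1:] - lod0[:-1]
    print(length)  # [2 3] -- what the op writes into Length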
@@ -30,6 +30,8 @@ class TopkOp : public framework::OperatorWithKernel {
                    "Output(Indices) of TopkOp should not be null.");
 
     auto input_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_EQ(input_dims.size(), 2,
+                      "Rank of TopK op's input must be 2.");
     const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
 
     PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
...
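The new PADDLE_ENFORCE_EQ makes TopK reject anything but a rank-2 (batch x feature) input up front instead of failing inside the kernel. For reference, the per-row k-largest selection the op computes, sketched in NumPy:

    import numpy as np

    x = np.array([[0.1, 0.9, 0.4],
                  [0.7, 0.2, 0.5]], dtype=np.float32)  # rank 2, as now enforced
    k = 2
    idx = np.argsort(-x, axis=1)[:, :k]       # indices of the k largest per row
    val = np.take_along_axis(x, idx, axis=1)  # their values, in descending order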
@@ -46,6 +46,9 @@ void BindConstValue(pybind11::module* m) {
   op_proto_and_checker_maker.def(
       "kOpNameScopeAttrName",
       framework::OpProtoAndCheckerMaker::OpNamescopeAttrName);
+  op_proto_and_checker_maker.def(
+      "kOpCreationCallstackAttrName",
+      framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName);
 }
 
 }  // namespace pybind
...
@@ -18,6 +18,7 @@ import collections
 import contextlib
 import re
 import six
+import traceback
 
 import numpy as np
@@ -572,6 +573,10 @@ class Operator(object):
             if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
                 del op_attrs[role_var_name]
 
+        callstack_var_name = op_maker.kOpCreationCallstackAttrName()
+        op_attrs[callstack_var_name] = list(
+            reversed(traceback.format_stack()))[1:]
+
         if len(self.desc.type()) != 0:
             return
         if type is None:
...
@@ -2680,7 +2680,8 @@ def sequence_pad(x, pad_value, maxlen=None):
             longest original sequence."
 
     Returns:
-        Variable: The padded sequence batch. All sequences has the same length.
+        Variable: The padded sequence batch and the original lengths before
+                  padding. All sequences has the same length.
 
     Examples:
         .. code-block:: python
@@ -2696,15 +2697,21 @@ def sequence_pad(x, pad_value, maxlen=None):
     helper = LayerHelper('sequence_pad', input=x, **locals())
     dtype = helper.input_dtype()
     out = helper.create_tmp_variable(dtype)
+    length = helper.create_tmp_variable(dtype)
+
+    pad_value.stop_gradient = True
+    length.stop_gradient = True
+
     if maxlen is None:
         maxlen = -1
     helper.append_op(
         type='sequence_pad',
         inputs={'X': x,
                 'PadValue': pad_value},
-        outputs={'Out': out},
+        outputs={'Out': out,
+                 'Length': length},
         attrs={'padded_length': maxlen})
-    return out
+    return out, length
 
 
 def beam_search(pre_ids,
@@ -5913,7 +5920,7 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None):
         inputs={'X': [x]},
         outputs={'Y': out},
         attrs={
-            'max_len': maxlen if maxlen is not None else -1,
+            'maxlen': maxlen if maxlen is not None else -1,
             'out_dtype': out.dtype
         })
     return out
...
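With the Python-side changes above, sequence_pad now returns a (padded batch, lengths) pair, and sequence_mask passes the attribute name the op actually declares ('maxlen' rather than 'max_len'). A usage sketch against the fluid API of this era; the data/assign setup is illustrative, not taken from this commit:

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[10, 5], dtype='float32', lod_level=1)
    pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32))

    # sequence_pad now returns the padded batch and the pre-padding lengths.
    out, length = fluid.layers.sequence_pad(x=x, pad_value=pad_value, maxlen=10)

    # The lengths feed sequence_mask directly; a float mask (enabled by the new
    # float/double kernel registrations) can be multiplied into out to zero pads.
    mask = fluid.layers.sequence_mask(length, maxlen=10, dtype='float32')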
@@ -69,7 +69,7 @@ class TestOperator(unittest.TestCase):
             set(mul_op.attr_names),
             set([
                 "x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
-                "op_namescope"
+                "op_namescope", "op_callstack"
             ]))
         self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
         self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
...
@@ -62,7 +62,8 @@ class TestSequencePadOp(OpTest):
             start_idx = end_idx
 
         out_data = np.array(padded_sequences)
-        self.outputs = {'Out': out_data}
+        length = np.array(self.x_len_lod[0])
+        self.outputs = {'Out': out_data, 'Length': length}
 
     def setUp(self):
         self.op_type = 'sequence_pad'
...