diff --git a/paddle/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/operators/reorder_lod_tensor_by_rank_op.cc
index 384047428d086ed9294ebefff47fee1939de0ad0..369bd4391c9884c7a521641d2096898960992022 100644
--- a/paddle/operators/reorder_lod_tensor_by_rank_op.cc
+++ b/paddle/operators/reorder_lod_tensor_by_rank_op.cc
@@ -19,21 +19,28 @@ namespace paddle {
 namespace operators {

-class ReorderLoDTensorProtoMaker : public framework::OpProtoAndCheckerMaker {
+class ReorderLoDTensorByRankTableOpProtoMaker
+    : public framework::OpProtoAndCheckerMaker {
  public:
-  ReorderLoDTensorProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
+  ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto,
+                                          OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "(LoDTensor) the input lod tensor need to be reordered.");
     AddInput("RankTable",
              "(LoDRankTable) the rank table that input need follow");
     AddOutput("Out", "(LoDTensor) reordered lod tensor");
-    AddComment(R"DOC(ReorderLoDTensorLoDRankTable
+    AddComment(R"DOC(ReorderLoDTensorByRankTable

 Reorder the input X by the rank of `RankTable`. If `RankTable` is ordered by
 index [3, 0, 2, 1]. Input X will reorder its sequence, the third sequence of X
 will be the first sequence of Output.

 NOTE: The RankTable does not need to be calculated by X.
+
+For example:
+The X = [Seq0, Seq1, Seq2, Seq3]. The indices of RankTable are [3, 0, 2, 1].
+
+The Out = [Seq3, Seq0, Seq2, Seq1] with correct LoD information.
 )DOC");
   }
 };

@@ -146,8 +153,9 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase {
     size_t out_offset = 0;
     out->mutable_lod()->clear();
     for (auto &item : rank_table.items()) {
-      out_offset = this->CopyTensorAndLod(dev_ctx, absolute_table[item.index],
-                                          x, out, out_offset);
+      PADDLE_ENFORCE_LT(item.index, absolute_table.size());
+      out_offset = CopyTensorAndLod(dev_ctx, absolute_table[item.index], x, out,
+                                    out_offset);
     }
   }
 };
@@ -220,6 +228,7 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(reorder_lod_tensor_by_rank,
                   ops::ReorderLoDTensorByRankTableOp,
                   ops::ReorderLodTensorByRankGradOpMaker,
-                  ops::ReorderLoDTensorProtoMaker, ops::IdentityInferShape);
+                  ops::ReorderLoDTensorByRankTableOpProtoMaker,
+                  ops::IdentityInferShape);
 REGISTER_OPERATOR(reorder_lod_tensor_by_rank_grad,
                   ops::ReorderLoDTensorByRankGradOp, ops::IdentityInferShape);
diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py
index d66c834654759d3a391eb7bdb5581b7a32a89af1..f49cabfee87242c5ae352a5445a586301c9d9b47 100644
--- a/python/paddle/v2/fluid/layers/control_flow.py
+++ b/python/paddle/v2/fluid/layers/control_flow.py
@@ -3,6 +3,7 @@ from ..framework import Program, Variable, Operator
 from .. import core
 from tensor import assign, fill_constant
 import contextlib
+from ..registry import autodoc

 __all__ = [
     'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
@@ -983,16 +984,8 @@ class DynamicRNN(object):
                 method))


+@autodoc
 def reorder_lod_tensor_by_rank(x, rank_table):
-    """
-
-    Args:
-        x(Variable):
-        rank_table(Variable):
-
-    Returns:
-
-    """
     helper = LayerHelper('reorder_lod_tensor_by_rank', **locals())
     helper.is_instance('x', Variable)
     helper.is_instance('rank_table', Variable)
diff --git a/python/paddle/v2/fluid/registry.py b/python/paddle/v2/fluid/registry.py
index 6f5dd365ded628ad49800f0a04f208ec49cca4c5..7aa82906114b355277185211134bb791e5dc43f9 100644
--- a/python/paddle/v2/fluid/registry.py
+++ b/python/paddle/v2/fluid/registry.py
@@ -8,7 +8,7 @@ import proto.framework_pb2 as framework_pb2
 from framework import OpProtoHolder, Variable, Program, Operator
 from paddle.v2.fluid.layer_helper import LayerHelper, unique_name

-__all__ = ['deprecated', 'register_layer']
+__all__ = ['deprecated', 'register_layer', 'autodoc']


 def _convert_(name):
@@ -175,12 +175,18 @@ def deprecated(func_or_class):
         """
         Wrap func with deprecated warning
         """
-        warnings.simplefilter('always', DeprecationWarning)  #turn off filter
+        warnings.simplefilter('always', DeprecationWarning)  # turn off filter
         warnings.warn(
             "Call to deprecated function {}.".format(func.__name__),
             category=DeprecationWarning,
             stacklevel=2)
-        warnings.simplefilter('default', DeprecationWarning)  #reset filter
+        warnings.simplefilter('default', DeprecationWarning)  # reset filter
         return func(*args, **kwargs)

     return func_wrapper
+
+
+def autodoc(func):
+    func.__doc__ = _generate_doc_string_(OpProtoHolder.instance().get_op_proto(
+        func.__name__))
+    return func
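A minimal usage sketch of the layer touched by this change, assuming the function is imported directly from `paddle.v2.fluid.layers.control_flow` and that `lod_rank_table` is used to build the rank table; the variable names, shapes, and dtypes below are illustrative only, not part of the patch:

import paddle.v2.fluid as fluid
from paddle.v2.fluid.layers.control_flow import reorder_lod_tensor_by_rank

# Hypothetical LoD inputs; names, shapes and dtypes are illustrative only.
x = fluid.layers.data(name='x', shape=[10], dtype='float32', lod_level=1)
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=1)

# Build a rank table from y and reorder the sequences of x to follow it.
table = fluid.layers.lod_rank_table(x=y, level=0)
reordered = reorder_lod_tensor_by_rank(x=x, rank_table=table)

# With @autodoc the docstring is generated from the OpProto comment above.
print(reorder_lod_tensor_by_rank.__doc__)

With `@autodoc`, `reorder_lod_tensor_by_rank.__doc__` carries the ReorderLoDTensorByRankTable description from the operator's proto instead of the empty stub docstring removed in the diff.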