提交 9189567a 编写于 作者: Y Yang Yu

Follow comments

上级 8823a12e
...@@ -19,21 +19,28 @@ ...@@ -19,21 +19,28 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class ReorderLoDTensorProtoMaker : public framework::OpProtoAndCheckerMaker { class ReorderLoDTensorByRankTableOpProtoMaker
: public framework::OpProtoAndCheckerMaker {
public: public:
ReorderLoDTensorProtoMaker(OpProto *proto, OpAttrChecker *op_checker) ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto,
OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the input lod tensor need to be reordered."); AddInput("X", "(LoDTensor) the input lod tensor need to be reordered.");
AddInput("RankTable", AddInput("RankTable",
"(LoDRankTable) the rank table that input need follow"); "(LoDRankTable) the rank table that input need follow");
AddOutput("Out", "(LoDTensor) reordered lod tensor"); AddOutput("Out", "(LoDTensor) reordered lod tensor");
AddComment(R"DOC(ReorderLoDTensorLoDRankTable AddComment(R"DOC(ReorderLoDTensorByRankTable
Reorder the input X by the rank of `RankTable`. If `RankTable` is ordered by Reorder the input X by the rank of `RankTable`. If `RankTable` is ordered by
index [3, 0, 2, 1]. Input X will reorder its sequence, the third sequence of index [3, 0, 2, 1]. Input X will reorder its sequence, the third sequence of
X will be the first sequence of Output. X will be the first sequence of Output.
NOTE: The RankTable does not need to be calculated by X. NOTE: The RankTable does not need to be calculated by X.
For example:
The X = [Seq0, Seq1, Seq2, Seq3]. The indices of RankTable are [3, 0, 2, 1].
The Out = [Seq3, Seq0, Seq2, Seq1] with correct LoD information.
)DOC"); )DOC");
} }
}; };
...@@ -146,8 +153,9 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase { ...@@ -146,8 +153,9 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase {
size_t out_offset = 0; size_t out_offset = 0;
out->mutable_lod()->clear(); out->mutable_lod()->clear();
for (auto &item : rank_table.items()) { for (auto &item : rank_table.items()) {
out_offset = this->CopyTensorAndLod(dev_ctx, absolute_table[item.index], PADDLE_ENFORCE_LT(item.index, absolute_table.size());
x, out, out_offset); out_offset = CopyTensorAndLod(dev_ctx, absolute_table[item.index], x, out,
out_offset);
} }
} }
}; };
...@@ -220,6 +228,7 @@ namespace ops = paddle::operators; ...@@ -220,6 +228,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(reorder_lod_tensor_by_rank, REGISTER_OPERATOR(reorder_lod_tensor_by_rank,
ops::ReorderLoDTensorByRankTableOp, ops::ReorderLoDTensorByRankTableOp,
ops::ReorderLodTensorByRankGradOpMaker, ops::ReorderLodTensorByRankGradOpMaker,
ops::ReorderLoDTensorProtoMaker, ops::IdentityInferShape); ops::ReorderLoDTensorByRankTableOpProtoMaker,
ops::IdentityInferShape);
REGISTER_OPERATOR(reorder_lod_tensor_by_rank_grad, REGISTER_OPERATOR(reorder_lod_tensor_by_rank_grad,
ops::ReorderLoDTensorByRankGradOp, ops::IdentityInferShape); ops::ReorderLoDTensorByRankGradOp, ops::IdentityInferShape);
...@@ -3,6 +3,7 @@ from ..framework import Program, Variable, Operator ...@@ -3,6 +3,7 @@ from ..framework import Program, Variable, Operator
from .. import core from .. import core
from tensor import assign, fill_constant from tensor import assign, fill_constant
import contextlib import contextlib
from ..registry import autodoc
__all__ = [ __all__ = [
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard', 'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
...@@ -983,16 +984,8 @@ class DynamicRNN(object): ...@@ -983,16 +984,8 @@ class DynamicRNN(object):
method)) method))
@autodoc
def reorder_lod_tensor_by_rank(x, rank_table): def reorder_lod_tensor_by_rank(x, rank_table):
"""
Args:
x(Variable):
rank_table(Variable):
Returns:
"""
helper = LayerHelper('reorder_lod_tensor_by_rank', **locals()) helper = LayerHelper('reorder_lod_tensor_by_rank', **locals())
helper.is_instance('x', Variable) helper.is_instance('x', Variable)
helper.is_instance('rank_table', Variable) helper.is_instance('rank_table', Variable)
......
...@@ -8,7 +8,7 @@ import proto.framework_pb2 as framework_pb2 ...@@ -8,7 +8,7 @@ import proto.framework_pb2 as framework_pb2
from framework import OpProtoHolder, Variable, Program, Operator from framework import OpProtoHolder, Variable, Program, Operator
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
__all__ = ['deprecated', 'register_layer'] __all__ = ['deprecated', 'register_layer', 'autodoc']
def _convert_(name): def _convert_(name):
...@@ -175,12 +175,18 @@ def deprecated(func_or_class): ...@@ -175,12 +175,18 @@ def deprecated(func_or_class):
""" """
Wrap func with deprecated warning Wrap func with deprecated warning
""" """
warnings.simplefilter('always', DeprecationWarning) #turn off filter warnings.simplefilter('always', DeprecationWarning) # turn off filter
warnings.warn( warnings.warn(
"Call to deprecated function {}.".format(func.__name__), "Call to deprecated function {}.".format(func.__name__),
category=DeprecationWarning, category=DeprecationWarning,
stacklevel=2) stacklevel=2)
warnings.simplefilter('default', DeprecationWarning) #reset filter warnings.simplefilter('default', DeprecationWarning) # reset filter
return func(*args, **kwargs) return func(*args, **kwargs)
return func_wrapper return func_wrapper
def autodoc(func):
func.__doc__ = _generate_doc_string_(OpProtoHolder.instance().get_op_proto(
func.__name__))
return func
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册