提交 9189567a 编写于 作者: Y Yang Yu

Follow comments

上级 8823a12e
......@@ -19,21 +19,28 @@
namespace paddle {
namespace operators {
class ReorderLoDTensorProtoMaker : public framework::OpProtoAndCheckerMaker {
class ReorderLoDTensorByRankTableOpProtoMaker
: public framework::OpProtoAndCheckerMaker {
public:
ReorderLoDTensorProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto,
OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the input lod tensor need to be reordered.");
AddInput("RankTable",
"(LoDRankTable) the rank table that input need follow");
AddOutput("Out", "(LoDTensor) reordered lod tensor");
AddComment(R"DOC(ReorderLoDTensorLoDRankTable
AddComment(R"DOC(ReorderLoDTensorByRankTable
Reorder the input X by the rank of `RankTable`. If `RankTable` is ordered by
index [3, 0, 2, 1]. Input X will reorder its sequence, the third sequence of
X will be the first sequence of Output.
NOTE: The RankTable does not need to be calculated by X.
For example:
The X = [Seq0, Seq1, Seq2, Seq3]. The indices of RankTable are [3, 0, 2, 1].
The Out = [Seq3, Seq0, Seq2, Seq1] with correct LoD information.
)DOC");
}
};
......@@ -146,8 +153,9 @@ class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase {
size_t out_offset = 0;
out->mutable_lod()->clear();
for (auto &item : rank_table.items()) {
out_offset = this->CopyTensorAndLod(dev_ctx, absolute_table[item.index],
x, out, out_offset);
PADDLE_ENFORCE_LT(item.index, absolute_table.size());
out_offset = CopyTensorAndLod(dev_ctx, absolute_table[item.index], x, out,
out_offset);
}
}
};
......@@ -220,6 +228,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(reorder_lod_tensor_by_rank,
ops::ReorderLoDTensorByRankTableOp,
ops::ReorderLodTensorByRankGradOpMaker,
ops::ReorderLoDTensorProtoMaker, ops::IdentityInferShape);
ops::ReorderLoDTensorByRankTableOpProtoMaker,
ops::IdentityInferShape);
REGISTER_OPERATOR(reorder_lod_tensor_by_rank_grad,
ops::ReorderLoDTensorByRankGradOp, ops::IdentityInferShape);
......@@ -3,6 +3,7 @@ from ..framework import Program, Variable, Operator
from .. import core
from tensor import assign, fill_constant
import contextlib
from ..registry import autodoc
__all__ = [
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
......@@ -983,16 +984,8 @@ class DynamicRNN(object):
method))
@autodoc
def reorder_lod_tensor_by_rank(x, rank_table):
"""
Args:
x(Variable):
rank_table(Variable):
Returns:
"""
helper = LayerHelper('reorder_lod_tensor_by_rank', **locals())
helper.is_instance('x', Variable)
helper.is_instance('rank_table', Variable)
......
......@@ -8,7 +8,7 @@ import proto.framework_pb2 as framework_pb2
from framework import OpProtoHolder, Variable, Program, Operator
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
__all__ = ['deprecated', 'register_layer']
__all__ = ['deprecated', 'register_layer', 'autodoc']
def _convert_(name):
......@@ -175,12 +175,18 @@ def deprecated(func_or_class):
"""
Wrap func with deprecated warning
"""
warnings.simplefilter('always', DeprecationWarning) #turn off filter
warnings.simplefilter('always', DeprecationWarning) # turn off filter
warnings.warn(
"Call to deprecated function {}.".format(func.__name__),
category=DeprecationWarning,
stacklevel=2)
warnings.simplefilter('default', DeprecationWarning) #reset filter
warnings.simplefilter('default', DeprecationWarning) # reset filter
return func(*args, **kwargs)
return func_wrapper
def autodoc(func):
func.__doc__ = _generate_doc_string_(OpProtoHolder.instance().get_op_proto(
func.__name__))
return func
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册