Commit b23982a2 authored by Yang Yu

Add ReorderLoDTensorByRank

It is useful for reordering RNN memory blocks.
Parent 0f1c685c
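A minimal usage sketch, assuming the fluid Python API introduced in this diff (variable names and shapes are illustrative):

import paddle.v2.fluid as fluid

x = fluid.layers.data(name='x', shape=[1], lod_level=1)
ref = fluid.layers.data(name='ref', shape=[1], lod_level=1)
# Build a rank table from the reference sequences.
table = fluid.layers.lod_rank_table(ref)
# Reorder the sequences of x to follow that order, e.g. so an RNN
# memory block lines up with the reordered input batch.
reordered = fluid.layers.reorder_lod_tensor_by_rank(x=x, rank_table=table)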
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/framework/lod_rank_table.h>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
namespace paddle {
namespace operators {
class ReorderLoDTensorProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ReorderLoDTensorProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the input lod tensor need to be reordered.");
AddInput("RankTable",
"(LoDRankTable) the rank table that input need follow");
AddOutput("Out", "(LoDTensor) reordered lod tensor");
AddComment(R"DOC(ReorderLoDTensorLoDRankTable
Reorder the input X by the rank of `RankTable`. If `RankTable` is ordered by
index [3, 0, 2, 1]. Input X will reorder its sequence, the third sequence of
X will be the first sequence of Output.
NOTE: The RankTable does not need to be calculated by X.
)DOC");
}
};
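// Base class shared by the forward and backward operators: Run() resolves
// the X, RankTable, and Out variables from the scope, allocates the output,
// and delegates the actual reordering to process().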
class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
public:
ReorderLoDTensorByRankTableBase(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto &x =
detail::Ref(scope.FindVar(Input("X")),
"Cannot find input lod tensor variable %s", Input("X"))
.Get<framework::LoDTensor>();
auto &rank_table = detail::Ref(scope.FindVar(Input("RankTable")),
"Cannot find input rank table variable %s",
Input("RankTable"))
.Get<framework::LoDRankTable>();
auto &out =
*detail::Ref(scope.FindVar(Output("Out")),
"Cannot find output lod tensor variable %s", Output("Out"))
.GetMutable<framework::LoDTensor>();
out.Resize(x.dims());
out.mutable_data(x.place(), x.type());
this->process(dev_ctx, x, rank_table, &out);
}
protected:
virtual void process(const platform::DeviceContext &dev_ctx,
const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const = 0;
struct AbsoluteRankTableItem {
size_t offset; // the absolute/accumulated offset.
size_t length;  // the number of elements in this sequence
framework::LoD lod;
};
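// Split x at LoD level 0 into one record per sequence, each holding the
// sequence's absolute offset, its length, and its sub-LoD.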
std::vector<AbsoluteRankTableItem> GetAbsoluteOffsetAndLengthByLoDRankTable(
const framework::LoDTensor &x) const {
std::vector<AbsoluteRankTableItem> absolute_table;
size_t level = 0;
size_t size = x.lod()[level].size();
for (size_t i = 0; i < size - 1; ++i) {
auto lod_offset =
framework::GetSubLoDAndAbsoluteOffset(x.lod(), i, i + 1, level);
auto &offset = lod_offset.second;
absolute_table.emplace_back();
absolute_table.back().length = offset.second - offset.first;
absolute_table.back().offset = offset.first;
absolute_table.back().lod = lod_offset.first;
}
return absolute_table;
}
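// Copy one sequence slice (and append its sub-LoD) from x into out at
// out_offset; returns the offset just past the copied slice.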
size_t CopyTensorAndLod(const platform::DeviceContext &dev_ctx,
const AbsoluteRankTableItem &item,
const framework::LoDTensor &x,
framework::LoDTensor *out, size_t out_offset) const {
auto &out_lod = *out->mutable_lod();
auto len = item.length;
auto x_offset = item.offset;
if (out_lod.empty()) {
for (size_t i = 0; i < item.lod.size(); ++i) {
out_lod.push_back(std::vector<size_t>({0}));
}
}
for (size_t i = 0; i < out_lod.size(); ++i) {
auto &out_v = out_lod[i];
auto &new_lod_v = item.lod[i];
for (auto &detail : new_lod_v) {
out_v.push_back(out_v.back() + detail);
}
}
auto x_sliced = x.Slice(x_offset, x_offset + len);
auto out_sliced = out->Slice(out_offset, out_offset + len);
framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced);
out_offset += len;
return out_offset;
}
};
class ReorderLoDTensorByRankTableOp : public ReorderLoDTensorByRankTableBase {
public:
ReorderLoDTensorByRankTableOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {}
protected:
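// Forward: emit the sequences of x in the order listed in the rank table.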
void process(const platform::DeviceContext &dev_ctx,
const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const override {
auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x);
size_t out_offset = 0;
out->mutable_lod()->clear();
for (auto &item : rank_table.items()) {
out_offset = this->CopyTensorAndLod(dev_ctx, absolute_table[item.index],
x, out, out_offset);
}
}
};
class IdentityInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim("Out", context->GetInputDim("X"));
}
};
class ReorderLodTensorByRankGradOpMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("reorder_lod_tensor_by_rank_grad");
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetInput("RankTable", Input("RankTable"));
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
class ReorderLoDTensorByRankGradOp : public ReorderLoDTensorByRankTableBase {
public:
ReorderLoDTensorByRankGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ReorderLoDTensorByRankTableBase(type, inputs, outputs, attrs) {}
protected:
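// Backward: invert the rank-table permutation by sorting it on the original
// sequence index, then copy each gradient slice back to its original slot.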
void process(const platform::DeviceContext &dev_ctx,
const framework::LoDTensor &x,
const framework::LoDRankTable &rank_table,
framework::LoDTensor *out) const override {
auto absolute_table = GetAbsoluteOffsetAndLengthByLoDRankTable(x);
// offsets = enumerate([item.index for item in rank_table.items()])
std::vector<std::pair<size_t, size_t>> offsets;
offsets.reserve(rank_table.items().size());
for (size_t i = 0; i < rank_table.items().size(); ++i) {
offsets.push_back({i, rank_table.items()[i].index});
}
// offsets.sort(key=lambda x: x[1])
std::sort(
offsets.begin(), offsets.end(),
[](const std::pair<size_t, size_t> &a,
const std::pair<size_t, size_t> &b) { return a.second < b.second; });
// Copy TensorAndLod
size_t out_offset = 0;
for (auto &offset : offsets) {
out_offset = this->CopyTensorAndLod(dev_ctx, absolute_table[offset.first],
x, out, out_offset);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(reorder_lod_tensor_by_rank,
ops::ReorderLoDTensorByRankTableOp,
ops::ReorderLodTensorByRankGradOpMaker,
ops::ReorderLoDTensorProtoMaker, ops::IdentityInferShape);
REGISTER_OPERATOR(reorder_lod_tensor_by_rank_grad,
ops::ReorderLoDTensorByRankGradOp, ops::IdentityInferShape);
@@ -389,6 +389,9 @@ class Operator(object):
% (in_proto.name, len(in_args)))
in_arg_names = []
for arg in in_args:
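# An input argument may be given either as a variable name (string)
# or as a Variable object.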
if isinstance(arg, basestring):
in_arg_names.append(arg)
else:
in_arg_names.append(arg.name)
self.desc.set_input(in_proto.name, in_arg_names)
else:
@@ -194,3 +194,9 @@ class LayerHelper(object):
else:
# For integer and boolean types, initialize with all zeros
return Constant()
def is_instance(self, param_name, cls):
    param = self.kwargs.get(param_name, None)
    if not isinstance(param, cls):
        raise TypeError(
            "The input {0} parameter of method {1} must be {2}".format(
                param_name, self.layer_type, cls.__name__))
@@ -10,7 +10,7 @@ __all__ = [
'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor',
'increment', 'array_write', 'create_array', 'less_than', 'array_read',
'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock',
'StaticRNN'
'StaticRNN', 'reorder_lod_tensor_by_rank'
]
@@ -963,3 +963,26 @@ class DynamicRNN(object):
if self.status != DynamicRNN.IN_RNN:
raise ValueError("{0} can only be invoked inside rnn block.".format(
method))
def reorder_lod_tensor_by_rank(x, rank_table):
    """
    Reorder the sequences of the input LoDTensor `x` to follow the order
    given by `rank_table`.

    Args:
        x (Variable): The input LoDTensor to be reordered.
        rank_table (Variable): The LoDRankTable whose order the sequences
            of `x` should follow.

    Returns:
        Variable: The reordered LoDTensor.
    """
helper = LayerHelper('reorder_lod_tensor_by_rank', **locals())
helper.is_instance('x', Variable)
helper.is_instance('rank_table', Variable)
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type='reorder_lod_tensor_by_rank',
inputs={'X': [x],
'RankTable': [rank_table]},
outputs={'Out': [out]})
return out
import unittest
import paddle.v2.fluid as fluid
import numpy
class TestReorderLoDTensor(unittest.TestCase):
def test_reorder(self):
dat = fluid.layers.data(name='input', shape=[1], lod_level=2)
dat.stop_gradient = False
rank_dat = fluid.layers.data(name='ref', shape=[1], lod_level=1)
table = fluid.layers.lod_rank_table(rank_dat)
new_dat = fluid.layers.reorder_lod_tensor_by_rank(
x=dat, rank_table=table)
loss = fluid.layers.mean(x=new_dat)
fluid.backward.append_backward_ops(loss=loss)
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_startup_program())
ref = fluid.Tensor()
ref_lod = [0, 3, 4, 7, 8, 14]
ref.set_lod([ref_lod])
ref.set(numpy.random.random(size=[14, 1]).astype('float32'), cpu)
input = fluid.Tensor()
lod_level_0 = numpy.random.randint(low=1, high=5, size=5)
lod_level_0 = [0] + numpy.cumsum(lod_level_0).tolist()
lod_level_1 = numpy.random.randint(low=1, high=5, size=lod_level_0[-1])
lod_level_1 = [0] + numpy.cumsum(lod_level_1).tolist()
input.set_lod([lod_level_0, lod_level_1])
input.set(
numpy.random.random(size=[lod_level_1[-1], 1]).astype('float32'),
cpu)
ig = exe.run(fluid.default_main_program(),
feed={'input': input,
'ref': ref},
fetch_list=['input@GRAD'],
return_numpy=False)[0]
self.assertAlmostEqual(numpy.array(ig).sum(), 1.0, delta=0.001)
self.assertEqual(input.lod(), ig.lod())
if __name__ == '__main__':
unittest.main()