未验证 提交 01425309 编写于 作者: Y Yang Yu

Rename shrink_state -> shrink_rnn_memory

Follow comments
上级 b4dddb29
......@@ -18,9 +18,9 @@
namespace paddle {
namespace operators {
class ShrinkStateOp : public ArrayOp {
class ShrinkRNNMemoryOp : public ArrayOp {
public:
ShrinkStateOp(const std::string &type,
ShrinkRNNMemoryOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
......@@ -36,18 +36,12 @@ class ShrinkStateOp : public ArrayOp {
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
int dst_num_rows = 0;
{
auto &rank_items = rank_table.items();
for (auto &rank_item : rank_items) {
if (offset < rank_item.length) {
++dst_num_rows;
} else {
break;
}
}
}
int dst_num_rows =
std::lower_bound(rank_items.begin(), rank_items.end(), offset,
[](const framework::LoDRankTable::TableItem &a,
size_t b) { return a.length > b; }) -
rank_items.begin();
auto *out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
......@@ -58,9 +52,9 @@ class ShrinkStateOp : public ArrayOp {
}
};
class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker {
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ShrinkStateOpProtoMaker(framework::OpProto *proto,
ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
......@@ -71,7 +65,7 @@ class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker {
}
};
class ShrinkStateOpInferShape : public framework::InferShapeBase {
class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"));
......@@ -81,9 +75,9 @@ class ShrinkStateOpInferShape : public framework::InferShapeBase {
}
};
class ShrinkStateGradOp : public ArrayOp {
class ShrinkRNNMemoryGradOp : public ArrayOp {
public:
ShrinkStateGradOp(const std::string &type,
ShrinkRNNMemoryGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
......@@ -92,8 +86,7 @@ class ShrinkStateGradOp : public ArrayOp {
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
auto dx_name = Output(framework::GradVarName("X"));
auto *dx_var = scope.FindVar(dx_name);
auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr);
......@@ -110,7 +103,7 @@ class ShrinkStateGradOp : public ArrayOp {
auto height = dout_tensor.dims()[0];
dx_tensor.Slice(0, static_cast<int>(height))
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
if (height < dout_tensor.dims()[0]) {
if (dx_tensor.dims()[0] < height) {
auto rest_tensor = dx_tensor.Slice(
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
......@@ -119,7 +112,7 @@ class ShrinkStateGradOp : public ArrayOp {
}
};
class ShrikStateGradInferShape : public framework::InferShapeBase {
class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"));
......@@ -129,14 +122,14 @@ class ShrikStateGradInferShape : public framework::InferShapeBase {
}
};
class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker {
class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *op = new framework::OpDescBind();
op->SetType("shrink_state_grad");
op->SetType("shrink_rnn_memory_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
......@@ -149,8 +142,8 @@ class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(shrink_state, ops::ShrinkStateOp,
ops::ShrinkStateOpInferShape, ops::ShrinkStateOpProtoMaker,
ops::ShrinkStateGradOpMaker);
REGISTER_OPERATOR(shrink_state_grad, ops::ShrinkStateGradOp,
ops::ShrikStateGradInferShape);
REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
ops::ShrinkRNNMemoryInferShape,
ops::ShrinkRNNMemoryOpProtoMaker, ops::ShrinkRNNGradOpMaker);
REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
ops::ShrinkRNNMemoryGradInferShape);
......@@ -85,7 +85,6 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
VLOG(10) << "I am here?";
for (auto &out_var : op_desc.OutputArgumentNames()) {
VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY";
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
......
......@@ -844,7 +844,7 @@ def shrink_memory(x, i, table, main_program=None):
helper = LayerHelper('shrink_memory', **locals())
out = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type='shrink_state',
type='shrink_rnn_memory',
inputs={'X': [x],
'I': [i],
'RankTable': [table]},
......
......@@ -7,8 +7,8 @@ from paddle.v2.framework.framework import g_main_program
import numpy
class TestShrinkState(unittest.TestCase):
def test_shrink_state(self):
class TestShrinkRNNMemory(unittest.TestCase):
def test_shrink_rnn_memory(self):
x = layers.data('x', shape=[100], data_type='float32')
x.stop_gradient = False
table = layers.lod_rank_table(x=x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册