未验证 提交 01425309 编写于 作者: Y Yang Yu

Rename shrink_state -> shrink_rnn_memory

Follow comments
上级 b4dddb29
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class ShrinkStateOp : public ArrayOp { class ShrinkRNNMemoryOp : public ArrayOp {
public: public:
ShrinkStateOp(const std::string &type, ShrinkRNNMemoryOp(const std::string &type,
const framework::VariableNameMap &inputs, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
...@@ -36,18 +36,12 @@ class ShrinkStateOp : public ArrayOp { ...@@ -36,18 +36,12 @@ class ShrinkStateOp : public ArrayOp {
PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set"); PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
auto &rank_table = rank_table_var->Get<framework::LoDRankTable>(); auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();
int dst_num_rows = 0;
{
auto &rank_items = rank_table.items(); auto &rank_items = rank_table.items();
for (auto &rank_item : rank_items) { int dst_num_rows =
if (offset < rank_item.length) { std::lower_bound(rank_items.begin(), rank_items.end(), offset,
++dst_num_rows; [](const framework::LoDRankTable::TableItem &a,
} else { size_t b) { return a.length > b; }) -
break; rank_items.begin();
}
}
}
auto *out_var = scope.FindVar(Output("Out")); auto *out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set"); PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
...@@ -58,9 +52,9 @@ class ShrinkStateOp : public ArrayOp { ...@@ -58,9 +52,9 @@ class ShrinkStateOp : public ArrayOp {
} }
}; };
class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker { class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ShrinkStateOpProtoMaker(framework::OpProto *proto, ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", ""); AddInput("X", "");
...@@ -71,7 +65,7 @@ class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -71,7 +65,7 @@ class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class ShrinkStateOpInferShape : public framework::InferShapeBase { class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X")); PADDLE_ENFORCE(context->HasInput("X"));
...@@ -81,9 +75,9 @@ class ShrinkStateOpInferShape : public framework::InferShapeBase { ...@@ -81,9 +75,9 @@ class ShrinkStateOpInferShape : public framework::InferShapeBase {
} }
}; };
class ShrinkStateGradOp : public ArrayOp { class ShrinkRNNMemoryGradOp : public ArrayOp {
public: public:
ShrinkStateGradOp(const std::string &type, ShrinkRNNMemoryGradOp(const std::string &type,
const framework::VariableNameMap &inputs, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
...@@ -92,8 +86,7 @@ class ShrinkStateGradOp : public ArrayOp { ...@@ -92,8 +86,7 @@ class ShrinkStateGradOp : public ArrayOp {
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::DeviceContext &dev_ctx) const override {
auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out"))); auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
auto dx_name = Output(framework::GradVarName("X")); auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
auto *dx_var = scope.FindVar(dx_name);
PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr"); PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
auto *x_var = scope.FindVar(Input("X")); auto *x_var = scope.FindVar(Input("X"));
PADDLE_ENFORCE(x_var != nullptr); PADDLE_ENFORCE(x_var != nullptr);
...@@ -110,7 +103,7 @@ class ShrinkStateGradOp : public ArrayOp { ...@@ -110,7 +103,7 @@ class ShrinkStateGradOp : public ArrayOp {
auto height = dout_tensor.dims()[0]; auto height = dout_tensor.dims()[0];
dx_tensor.Slice(0, static_cast<int>(height)) dx_tensor.Slice(0, static_cast<int>(height))
.CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx); .CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
if (height < dout_tensor.dims()[0]) { if (dx_tensor.dims()[0] < height) {
auto rest_tensor = dx_tensor.Slice( auto rest_tensor = dx_tensor.Slice(
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0])); static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
math::set_constant(dev_ctx, &rest_tensor, 0.0f); math::set_constant(dev_ctx, &rest_tensor, 0.0f);
...@@ -119,7 +112,7 @@ class ShrinkStateGradOp : public ArrayOp { ...@@ -119,7 +112,7 @@ class ShrinkStateGradOp : public ArrayOp {
} }
}; };
class ShrikStateGradInferShape : public framework::InferShapeBase { class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X")); PADDLE_ENFORCE(context->HasInput("X"));
...@@ -129,14 +122,14 @@ class ShrikStateGradInferShape : public framework::InferShapeBase { ...@@ -129,14 +122,14 @@ class ShrikStateGradInferShape : public framework::InferShapeBase {
} }
}; };
class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker { class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
std::unique_ptr<framework::OpDescBind> Apply() const override { std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *op = new framework::OpDescBind(); auto *op = new framework::OpDescBind();
op->SetType("shrink_state_grad"); op->SetType("shrink_rnn_memory_grad");
op->SetInput("X", Input("X")); op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
...@@ -149,8 +142,8 @@ class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker { ...@@ -149,8 +142,8 @@ class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(shrink_state, ops::ShrinkStateOp, REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
ops::ShrinkStateOpInferShape, ops::ShrinkStateOpProtoMaker, ops::ShrinkRNNMemoryInferShape,
ops::ShrinkStateGradOpMaker); ops::ShrinkRNNMemoryOpProtoMaker, ops::ShrinkRNNGradOpMaker);
REGISTER_OPERATOR(shrink_state_grad, ops::ShrinkStateGradOp, REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
ops::ShrikStateGradInferShape); ops::ShrinkRNNMemoryGradInferShape);
...@@ -85,7 +85,6 @@ class WriteToArrayInferVarType : public framework::VarTypeInference { ...@@ -85,7 +85,6 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDescBind &op_desc, void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override { framework::BlockDescBind *block) const override {
VLOG(10) << "I am here?";
for (auto &out_var : op_desc.OutputArgumentNames()) { for (auto &out_var : op_desc.OutputArgumentNames()) {
VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY"; VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY";
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY); block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
......
...@@ -844,7 +844,7 @@ def shrink_memory(x, i, table, main_program=None): ...@@ -844,7 +844,7 @@ def shrink_memory(x, i, table, main_program=None):
helper = LayerHelper('shrink_memory', **locals()) helper = LayerHelper('shrink_memory', **locals())
out = helper.create_tmp_variable(dtype=x.data_type) out = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op( helper.append_op(
type='shrink_state', type='shrink_rnn_memory',
inputs={'X': [x], inputs={'X': [x],
'I': [i], 'I': [i],
'RankTable': [table]}, 'RankTable': [table]},
......
...@@ -7,8 +7,8 @@ from paddle.v2.framework.framework import g_main_program ...@@ -7,8 +7,8 @@ from paddle.v2.framework.framework import g_main_program
import numpy import numpy
class TestShrinkState(unittest.TestCase): class TestShrinkRNNMemory(unittest.TestCase):
def test_shrink_state(self): def test_shrink_rnn_memory(self):
x = layers.data('x', shape=[100], data_type='float32') x = layers.data('x', shape=[100], data_type='float32')
x.stop_gradient = False x.stop_gradient = False
table = layers.lod_rank_table(x=x) table = layers.lod_rank_table(x=x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册