From 01425309292983205a5fff9658799a0c3efcf6b9 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Tue, 7 Nov 2017 20:13:16 -0800 Subject: [PATCH] Rename shrink_state -> shrink_rnn_memory Follow comments --- ...nk_state_op.cc => shrink_rnn_memory_op.cc} | 67 +++++++++---------- .../operators/tensor_array_read_write_op.cc | 1 - python/paddle/v2/framework/layers.py | 2 +- ...ink_state.py => test_shrink_rnn_memory.py} | 4 +- 4 files changed, 33 insertions(+), 41 deletions(-) rename paddle/operators/{shrink_state_op.cc => shrink_rnn_memory_op.cc} (73%) rename python/paddle/v2/framework/tests/{test_shrink_state.py => test_shrink_rnn_memory.py} (95%) diff --git a/paddle/operators/shrink_state_op.cc b/paddle/operators/shrink_rnn_memory_op.cc similarity index 73% rename from paddle/operators/shrink_state_op.cc rename to paddle/operators/shrink_rnn_memory_op.cc index 5aaecf0aa..65bccc0c8 100644 --- a/paddle/operators/shrink_state_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -18,12 +18,12 @@ namespace paddle { namespace operators { -class ShrinkStateOp : public ArrayOp { +class ShrinkRNNMemoryOp : public ArrayOp { public: - ShrinkStateOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) + ShrinkRNNMemoryOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} void Run(const framework::Scope &scope, @@ -36,18 +36,12 @@ class ShrinkStateOp : public ArrayOp { PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set"); auto &rank_table = rank_table_var->Get(); - int dst_num_rows = 0; - - { - auto &rank_items = rank_table.items(); - for (auto &rank_item : rank_items) { - if (offset < rank_item.length) { - ++dst_num_rows; - } else { - break; - } - } - } + auto &rank_items = rank_table.items(); + int dst_num_rows = + std::lower_bound(rank_items.begin(), rank_items.end(), offset, + [](const framework::LoDRankTable::TableItem &a, + size_t b) { return a.length > b; }) - + rank_items.begin(); auto *out_var = scope.FindVar(Output("Out")); PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set"); @@ -58,10 +52,10 @@ class ShrinkStateOp : public ArrayOp { } }; -class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker { +class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: - ShrinkStateOpProtoMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", ""); AddInput("RankTable", ""); @@ -71,7 +65,7 @@ class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker { } }; -class ShrinkStateOpInferShape : public framework::InferShapeBase { +class ShrinkRNNMemoryInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { PADDLE_ENFORCE(context->HasInput("X")); @@ -81,19 +75,18 @@ class ShrinkStateOpInferShape : public framework::InferShapeBase { } }; -class ShrinkStateGradOp : public ArrayOp { +class ShrinkRNNMemoryGradOp : public ArrayOp { public: - ShrinkStateGradOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) + ShrinkRNNMemoryGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : ArrayOp(type, inputs, outputs, attrs) {} void Run(const framework::Scope &scope, const platform::DeviceContext &dev_ctx) const override { auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out"))); - auto dx_name = Output(framework::GradVarName("X")); - auto *dx_var = scope.FindVar(dx_name); + auto *dx_var = scope.FindVar(Output(framework::GradVarName("X"))); PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr"); auto *x_var = scope.FindVar(Input("X")); PADDLE_ENFORCE(x_var != nullptr); @@ -110,7 +103,7 @@ class ShrinkStateGradOp : public ArrayOp { auto height = dout_tensor.dims()[0]; dx_tensor.Slice(0, static_cast(height)) .CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx); - if (height < dout_tensor.dims()[0]) { + if (dx_tensor.dims()[0] < height) { auto rest_tensor = dx_tensor.Slice( static_cast(height), static_cast(dout_tensor.dims()[0])); math::set_constant(dev_ctx, &rest_tensor, 0.0f); @@ -119,7 +112,7 @@ class ShrinkStateGradOp : public ArrayOp { } }; -class ShrikStateGradInferShape : public framework::InferShapeBase { +class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { PADDLE_ENFORCE(context->HasInput("X")); @@ -129,14 +122,14 @@ class ShrikStateGradInferShape : public framework::InferShapeBase { } }; -class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker { +class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker { public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; protected: std::unique_ptr Apply() const override { auto *op = new framework::OpDescBind(); - op->SetType("shrink_state_grad"); + op->SetType("shrink_rnn_memory_grad"); op->SetInput("X", Input("X")); op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), InputGrad("X")); @@ -149,8 +142,8 @@ class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(shrink_state, ops::ShrinkStateOp, - ops::ShrinkStateOpInferShape, ops::ShrinkStateOpProtoMaker, - ops::ShrinkStateGradOpMaker); -REGISTER_OPERATOR(shrink_state_grad, ops::ShrinkStateGradOp, - ops::ShrikStateGradInferShape); +REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp, + ops::ShrinkRNNMemoryInferShape, + ops::ShrinkRNNMemoryOpProtoMaker, ops::ShrinkRNNGradOpMaker); +REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp, + ops::ShrinkRNNMemoryGradInferShape); diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc index 87b6b6929..eaf635274 100644 --- a/paddle/operators/tensor_array_read_write_op.cc +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -85,7 +85,6 @@ class WriteToArrayInferVarType : public framework::VarTypeInference { public: void operator()(const framework::OpDescBind &op_desc, framework::BlockDescBind *block) const override { - VLOG(10) << "I am here?"; for (auto &out_var : op_desc.OutputArgumentNames()) { VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY"; block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY); diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 8fc34501c..4504cf736 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -844,7 +844,7 @@ def shrink_memory(x, i, table, main_program=None): helper = LayerHelper('shrink_memory', **locals()) out = helper.create_tmp_variable(dtype=x.data_type) helper.append_op( - type='shrink_state', + type='shrink_rnn_memory', inputs={'X': [x], 'I': [i], 'RankTable': [table]}, diff --git a/python/paddle/v2/framework/tests/test_shrink_state.py b/python/paddle/v2/framework/tests/test_shrink_rnn_memory.py similarity index 95% rename from python/paddle/v2/framework/tests/test_shrink_state.py rename to python/paddle/v2/framework/tests/test_shrink_rnn_memory.py index 2601c769e..2090455b9 100644 --- a/python/paddle/v2/framework/tests/test_shrink_state.py +++ b/python/paddle/v2/framework/tests/test_shrink_rnn_memory.py @@ -7,8 +7,8 @@ from paddle.v2.framework.framework import g_main_program import numpy -class TestShrinkState(unittest.TestCase): - def test_shrink_state(self): +class TestShrinkRNNMemory(unittest.TestCase): + def test_shrink_rnn_memory(self): x = layers.data('x', shape=[100], data_type='float32') x.stop_gradient = False table = layers.lod_rank_table(x=x) -- GitLab