diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index c1c93e17fd82ea048ba27b127b1527d9a8c9da41..db8f5ab0456792f903093b9cf20e2541f00add5c 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -24,6 +24,10 @@ if(WITH_DOUBLE)
     add_definitions(-DPADDLE_TYPE_DOUBLE)
 endif(WITH_DOUBLE)
 
+if(WITH_TESTING)
+    add_definitions(-DPADDLE_WITH_TESTING)
+endif(WITH_TESTING)
+
 if(NOT WITH_TIMER)
     add_definitions(-DPADDLE_DISABLE_TIMER)
 endif(NOT WITH_TIMER)
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 1e9ace99876e00b9f733da3e66f37dc0dc2f8cad..15f80b57206c90f689acfdcac60a0d9011025fc0 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -142,9 +142,9 @@ class OperatorBase {
 // Macro for define a clone method.
 // If you are writing an kernel operator, `Clone` will be defined when you
 // register it. i.e. `Clone` method is not needed to define by yourself.
-#define DEFINE_OP_CLONE_METHOD(cls)                       \
-  std::unique_ptr<OperatorBase> Clone() const final {     \
-    return std::unique_ptr<OperatorBase>(new cls(*this)); \
+#define DEFINE_OP_CLONE_METHOD(cls)                                            \
+  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {     \
+    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
   }
 
 // Macro for define a default constructor for Operator.
diff --git a/paddle/framework/tensor_array.h b/paddle/framework/tensor_array.h
index 94a14c2df492b175cf6a643800937878e95c5f37..293da04997304be41810446cb3e866d545805f83 100644
--- a/paddle/framework/tensor_array.h
+++ b/paddle/framework/tensor_array.h
@@ -87,12 +87,12 @@ class TensorArray {
   LoDTensor Stack() const;
 
   /*
-   * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
+   * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
    */
   void Unstack(const LoDTensor &source) const;
 
   /*
-   * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
+   * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
    * with memory of tensors shared.
    */
   void UnstackShared(const LoDTensor &source) const;
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index d132c1813e6871add95c5017d563da35a7912fe4..7dae8fe2f99f9ec1233d0a0f6180cc9f287fc150 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -133,3 +133,4 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor)
 cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
 cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
 cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
+cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
diff --git a/paddle/operators/dynamic_recurrent_op.cc b/paddle/operators/dynamic_recurrent_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b919aef8fb62e5b2331c2d842556e0642ea6b095
--- /dev/null
+++ b/paddle/operators/dynamic_recurrent_op.cc
@@ -0,0 +1,276 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/dynamic_recurrent_op.h"
+
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Scope;
+using framework::TensorArray;
+using framework::LoDTensor;
+using framework::Variable;
+
+namespace detail {
+
+inline void CreateVariables(Scope& scope,
+                            const std::vector<std::string>& var_names) {
+  for (const auto& name : var_names) {
+    scope.NewVar(name);
+  }
+}
+
+}  // namespace detail
+
+class DynamicRecurrentOpProtoAndCheckerMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto,
+                                         framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    const auto& name = DynamicRecurrentOp::kArgName;
+    // inputs and outputs stored in proto
+    AddInput(name.inlinks,
+             "the inputs that need to be segmented for each step.")
+        .AsDuplicable();
+    AddInput(name.boot_memories, "variables to initialize memories.")
+        .AsDuplicable();
+
+    AddOutput(name.outlinks,
+              "the outputs that need to be concatenated for all steps.")
+        .AsDuplicable();
+    AddOutput(name.step_scopes, "step scopes");
+
+    // Attributes stored in AttributeMap
+    AddAttr<std::vector<std::string>>(name.pre_memories,
+                                      "names of pre-memories");
+    AddAttr<std::vector<std::string>>(name.memories, "names of memories");
+
+    AddComment("This is an RNN operator for variable-length sequences.");
+  }
+};
+
+void DynamicRecurrentOp::Run(const Scope& scope,
+                             const platform::DeviceContext& dev_ctx) const {
+  cache_.Init(kArgName, *this, scope, &arg_);
+  SplitInputs();
+  CreateScopes();
+  WriteStepInputs();
+  InitStates();
+
+  // call stepnet in all the time steps
+  for (size_t step = 0; step < cache_.num_steps; step++) {
+    auto& step_scope = cache_.GetScope(step);
+    stepnet_->Run(step_scope, dev_ctx);
+  }
+
+  WriteStepOutputs();
+  ConcatOutputs();
+}
+
+void DynamicRecurrentOp::SplitInputs() const {
+  // TODO(superjom) make level a config
+  // TODO(superjom) check that all the inputs have the same LoD
+  int level = 0;
+  const auto& inlinks = cache_.inlinks;
+  for (const auto& item : inlinks) {
+    const auto& var = item.second;
+    const auto& tensor = var->Get<LoDTensor>();
+    TensorArray& ta = step_inputs_[item.first];
+    dy_seq_metas_[item.first] =
+        ta.Unpack(tensor, level, true /*length_descend*/);
+
+    if (cache_.num_steps) {
+      PADDLE_ENFORCE_EQ(ta.size(), cache_.num_steps,
+                        "inputs should have the same number of steps");
+    } else {
+      cache_.num_steps = ta.size();
+    }
+  }
+}
+
+void DynamicRecurrentOp::WriteStepInputs() const {
+  for (const auto& item : cache_.inlinks) {
+    auto ta_it = step_inputs_.find(item.first);
+    PADDLE_ENFORCE(ta_it != step_inputs_.end(),
+                   "step_inputs_ not compatible with memory set");
+    TensorArray& ta = ta_it->second;
+    for (size_t step = 0; step < ta.size(); step++) {
+      auto tensor = ta.Read(step);
+      auto& step_scope = cache_.GetScope(step);
+      Variable* var = step_scope.FindVar(item.first);
+      if (var == nullptr) {
+        var = step_scope.NewVar(item.first);
+      }
+      var->GetMutable<LoDTensor>()->ShareDataWith<value_type>(tensor);
+    }
+  }
+}
+
+void DynamicRecurrentOp::WriteStepOutputs() const {
+  for (size_t step = 0; step < cache_.scopes->size(); step++) {
+    auto& scope = cache_.GetScope(step);
+    for (auto& item : step_outputs_) {
+      auto* var = scope.FindVar(item.first);
+      if (var == nullptr) {
+        var = scope.NewVar(item.first);
+      }
+      auto* tensor = var->GetMutable<LoDTensor>();
+      item.second.WriteShared(step, *tensor);
+    }
+  }
+}
+
+void DynamicRecurrentOp::CreateScopes() const {
+  PADDLE_ENFORCE_GT(cache_.num_steps, 0);
+  // resize scopes
+  size_t num_scopes_need_create = cache_.num_steps - cache_.scopes->size();
+  for (size_t i = 0; i < num_scopes_need_create; i++) {
+    cache_.scopes->emplace_back(&cache_.scope->NewScope());
+  }
+
+  // init temporary inputs
+  PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first");
+  std::vector<std::string> memories;
+  std::vector<std::string> pre_memories;
+  std::transform(arg_.memories.begin(), arg_.memories.end(),
+                 std::back_inserter(memories),
+                 [](const rnn::MemoryAttr& m) { return m.var; });
+  std::transform(arg_.memories.begin(), arg_.memories.end(),
+                 std::back_inserter(pre_memories),
+                 [](const rnn::MemoryAttr& m) { return m.pre_var; });
+
+  for (size_t step = 0; step < cache_.num_steps; step++) {
+    auto& scope = cache_.GetScope(step);
+    detail::CreateVariables(scope, arg_.inlinks);
+    detail::CreateVariables(scope, arg_.outlinks);
+    detail::CreateVariables(scope, memories);
+    detail::CreateVariables(scope, pre_memories);
+  }
+}
+
+void DynamicRecurrentOp::ConcatOutputs() const {
+  // TODO(superjom) transform this to a config
+  int level = 0;
+  // TODO(superjom) pass in some lod
+  // just a placeholder
+  framework::LoD lod;
+  for (auto& item : step_outputs_) {
+    auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod);
+    auto& output = cache_.outlinks[item.first]->Get<LoDTensor>();
+    const_cast<LoDTensor*>(&output)->ShareDataWith<value_type>(tensor);
+  }
+}
+
+void DynamicRecurrentOp::InitStates() const {
+  // init the first state
+  // TODO(superjom) prepare for the scenario where the boot state does not exist
+  for (auto memory : arg_.memories) {
+    auto* boot_state_var = cache_.scope->FindVar(memory.boot_var);
+    PADDLE_ENFORCE_NOT_NULL(boot_state_var);
+    auto& boot_state = boot_state_var->Get<LoDTensor>();
+    const auto& dims = boot_state.dims();
+
+    for (size_t step = 0; step < cache_.num_steps; step++) {
+      auto& cur_scope = cache_.GetScope(step);
+      // link pre-state to boot_state
+      // init state and pre-state
+      auto* pre_state = cur_scope.FindVar(memory.pre_var);
+      PADDLE_ENFORCE_NOT_NULL(pre_state);
+      pre_state->GetMutable<LoDTensor>();
+
+      auto* state = cur_scope.FindVar(memory.var);
+      PADDLE_ENFORCE_NOT_NULL(state);
+      state->GetMutable<LoDTensor>()->Resize(dims);
+      state->GetMutable<LoDTensor>()->mutable_data<value_type>(
+          platform::CPUPlace());
+
+      if (step == 0) {
+        auto* pre_state_tensor = pre_state->GetMutable<LoDTensor>();
+        pre_state_tensor->Resize(boot_state.dims());
+        pre_state_tensor->ShareDataWith<value_type>(boot_state);
+      } else {
+        auto& pre_scope = cache_.GetScope(step - 1);
+        auto* state_pre = pre_scope.FindVar(memory.var);
+        PADDLE_ENFORCE_NOT_NULL(state_pre);
+        pre_state->GetMutable<LoDTensor>()->ShareDataWith<value_type>(
+            *state_pre->GetMutable<LoDTensor>());
+      }
+    }
+  }
+}
+
+void DynamicRecurrentOp::ArgCache::Init(
+    const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
+    const paddle::framework::Scope& scope, rnn::Argument* arg) {
+  this->scope = &scope;
+  InitArgument(name, op, arg);
+  CacheScopes(scope, *arg);
+  CacheInlinks(scope, arg->inlinks);
+  CacheOutlinks(scope, arg->outlinks);
+}
+
+void DynamicRecurrentOp::ArgCache::InitArgument(const rnn::ArgumentName& name,
+                                                const OperatorBase& op,
+                                                rnn::Argument* arg) {
+  rnn::InitArgument(name, arg, op, false /*is_grad*/);
+}
+
+void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope,
+                                               const rnn::Argument& arg) {
+  auto scopes_var = scope.FindVar(arg.step_scopes);
+  PADDLE_ENFORCE(scopes_var != nullptr,
+                 "the step_scopes output argument [%s] should be created "
+                 "first by the framework.",
+                 arg.step_scopes);
+  this->scopes = scopes_var->GetMutable<std::vector<Scope*>>();
+}
+
+void DynamicRecurrentOp::ArgCache::CacheInlinks(
+    const Scope& scope, const std::vector<std::string>& names) {
+  for (auto name : names) {
+    auto* var = GetVariable(scope, name);
+    inlinks[name] = var;
+  }
+}
+
+void DynamicRecurrentOp::ArgCache::CacheOutlinks(
+    const Scope& scope, const std::vector<std::string>& names) {
+  for (auto name : names) {
+    auto* var = GetVariable(scope, name);
+    outlinks[name] = var;
+  }
+}
+
+Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
+                                                    const std::string& name) {
+  auto* var = scope.FindVar(name);
+  PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] does not exist in scope", name);
+  return var;
+}
+
+const rnn::ArgumentName DynamicRecurrentOp::kArgName{
+    "step_net", "step_scopes", "inlinks",      "outlinks",
+    "memories", "pre_memories", "boot_memories"};
+
+void DynamicRecurrentGradientOp::Run(
+    const Scope& scope, const platform::DeviceContext& dev_ctx) const {}
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(
+    dynamic_recurrent, paddle::operators::DynamicRecurrentOp,
+    paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker);
diff --git a/paddle/operators/dynamic_recurrent_op.h b/paddle/operators/dynamic_recurrent_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6a2970f27fd5bcb25e924dbc567e254159b55a3e
--- /dev/null
+++ b/paddle/operators/dynamic_recurrent_op.h
@@ -0,0 +1,158 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#ifdef PADDLE_WITH_TESTING
+#include "gtest/gtest.h"
+#endif
+
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/framework/operator.h"
+#include "paddle/framework/tensor_array.h"
+#include "paddle/framework/variable.h"
+#include "paddle/operators/rnn/recurrent_op_utils.h"
+
+namespace paddle {
+namespace operators {
+
+class DynamicRecurrentOp : public framework::OperatorBase {
+ public:
+  static const rnn::ArgumentName kArgName;
+  using value_type = float;
+
+  DynamicRecurrentOp(const std::string& type,
+                     const framework::VariableNameMap& inputs,
+                     const framework::VariableNameMap& outputs,
+                     const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  DynamicRecurrentOp(const DynamicRecurrentOp& o)
+      : framework::OperatorBase(
+            static_cast<const framework::OperatorBase&>(o)) {
+    // TODO(yuyang18): Implement copy ctor well.
+    PADDLE_THROW("Not implemented");
+  }
+
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const override;
+
+  /*
+   * Split the inputs (LoDTensors) into segments for each time step.
+   */
+  void SplitInputs() const;
+
+  /*
+   * Create step-scopes to store temporary outputs in each time step.
+   */
+  void CreateScopes() const;
+
+  /*
+   * Link TensorArray steps to the corresponding variables located in
+   * step-scopes.
+   */
+  void WriteStepInputs() const;
+
+  /*
+   * Write the output of each step to the corresponding TensorArray.
+   */
+  void WriteStepOutputs() const;
+
+  /*
+   * Initialize the states. Each state will have a corresponding pre-state,
+   * which shares its memory with the state in the previous time step. The
+   * pre-state in the first time step will be initialized with a zero tensor,
+   * or with a tensor from the parent scope if one is provided.
+   */
+  void InitStates() const;
+
+  /*
+   * Concatenate the outputs of all time steps and generate a LoDTensor.
+   */
+  void ConcatOutputs() const;
+
+  /*
+   * Set a stepnet that is created according to a RecurrentOp's stepnet.
+   */
+  void SetStepNet(std::unique_ptr<OperatorBase> net) {
+    PADDLE_ENFORCE_NOT_NULL(net);
+    stepnet_ = std::move(net);
+  }
+  const OperatorBase& GetStepNet() const { return *stepnet_; }
+
+ protected:
+  struct ArgCache {
+    framework::Scope const* scope;
+    std::vector<framework::Scope*>* scopes;
+    std::map<std::string, framework::Variable*> inlinks;
+    std::map<std::string, framework::Variable*> outlinks;
+
+    size_t num_steps{0};
+
+    void Init(const rnn::ArgumentName& name, const OperatorBase& op,
+              const framework::Scope& scope, rnn::Argument* arg);
+
+    framework::Scope& GetScope(size_t index) {
+      PADDLE_ENFORCE_LT(index, num_steps);
+      return *scopes->at(index);
+    }
+
+   private:
+    void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op,
+                      rnn::Argument* arg);
+    void CacheScopes(const framework::Scope& scope, const rnn::Argument& arg);
+    void CacheInlinks(const framework::Scope& scope,
+                      const std::vector<std::string>& names);
+    void CacheOutlinks(const framework::Scope& scope,
+                       const std::vector<std::string>& names);
+    framework::Variable* GetVariable(const framework::Scope& scope,
+                                     const std::string& name);
+  };
+
+ private:
+  std::unique_ptr<OperatorBase> stepnet_;
+  mutable framework::TensorArray states_;
+  mutable std::map<std::string, framework::TensorArray> step_inputs_;
+  mutable std::map<std::string, framework::TensorArray> step_outputs_;
+  mutable std::map<std::string, std::vector<framework::DySeqMeta>>
+      dy_seq_metas_;
+  mutable rnn::Argument arg_;
+  mutable ArgCache cache_;
+
+#ifdef PADDLE_WITH_TESTING
+  friend class DynamicRecurrentOpTestHelper;
+  FRIEND_TEST(DynamicRecurrentOpTestHelper, SplitInputs);
+  FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateCache);
+  FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateScopes);
+  FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepInputs);
+  FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepOutputs);
+  FRIEND_TEST(DynamicRecurrentOpTestHelper, InitStates);
+  FRIEND_TEST(DynamicRecurrentOpTestHelper, ConcatOutputs);
+#endif
+};
+
+class DynamicRecurrentGradientOp : public framework::OperatorBase {
+ public:
+  DynamicRecurrentGradientOp(const std::string& type,
+                             const framework::VariableNameMap& inputs,
+                             const framework::VariableNameMap& outputs,
+                             const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const override;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/dynamic_recurrent_op_test.cc b/paddle/operators/dynamic_recurrent_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..675a7890f3fa6bb7ab9dbbdb04894b2557214a8a
--- /dev/null
+++ b/paddle/operators/dynamic_recurrent_op_test.cc
@@ -0,0 +1,222 @@
+#include "paddle/operators/dynamic_recurrent_op.h"
+
+#include <gtest/gtest.h>
+
+#include "paddle/framework/ddim.h"
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/framework/op_desc.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/net_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Scope;
+using framework::TensorArray;
+using framework::LoDTensor;
+using framework::Variable;
+
+class TestOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+  DEFINE_OP_CLONE_METHOD(TestOp);
+  void Run(const Scope& scope,
+           const platform::DeviceContext& dev_ctx) const override {}
+};
+
+void OpDescNewVar(const std::string& param_name,
+                  std::initializer_list<const char*> arguments,
+                  paddle::framework::OpDesc::Var* var) {
+  var->set_parameter(param_name);
+  for (auto& arg_name : arguments) {
+    var->add_arguments(arg_name);
+  }
+}
+
+// create a LoD tensor in scope with specific dims
+LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims,
+                     const platform::Place& place) {
+  auto* var = scope.NewVar(name);
+  auto* tensor = var->GetMutable<LoDTensor>();
+  tensor->Resize(dims);
+  tensor->mutable_data<float>(place);
+  return tensor;
+}
+
+class DynamicRecurrentOpTestHelper : public ::testing::Test {
+ protected:
+  const rnn::ArgumentName argname = DynamicRecurrentOp::kArgName;
+
+  virtual void SetUp() override {
+    CreateGlobalVariables();
+
+    auto op_desc = CreateOpDesc();
+    op = paddle::framework::OpRegistry::CreateOp(op_desc);
+    dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
+    InitCacheManually();
+    InitStepNet();
+  }
+
+  framework::OpDesc CreateOpDesc() {
+    // create op
+    paddle::framework::OpDesc op_desc;
+    op_desc.set_type("dynamic_recurrent");
+
+    OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs());
+    OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs());
+    OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs());
+    OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs());
+
+    // set pre-memories
+    auto pre_memories = op_desc.mutable_attrs()->Add();
+    pre_memories->set_name(argname.pre_memories);
+    pre_memories->set_type(paddle::framework::AttrType::STRINGS);
+    auto pre_memories_item = pre_memories->add_strings();
+    *pre_memories_item = "mem@pre";
+
+    // set memories
+    auto memories = op_desc.mutable_attrs()->Add();
+    memories->set_name(argname.memories);
+    memories->set_type(paddle::framework::AttrType::STRINGS);
+    auto memories_item = memories->add_strings();
+    *memories_item = "mem";
+    return op_desc;
+  }
+
+  void CreateGlobalVariables() {
+    platform::CPUPlace place;
+    scope.NewVar("step_scopes");
+    CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place);
+    // auto* out0 =
+    CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
+    auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place);
+    // 10 instances in 4 sentences, whose lengths are 4, 3, 2 and 1 respectively.
+    framework::LoD in0_lod(1);
+    for (int x : std::vector<int>{0, 4, 7, 9, 10}) {
+      in0_lod[0].push_back(x);
+    }
+    in0->set_lod(in0_lod);
+    in0->Resize(framework::make_ddim({10, 8}));
+    // set the content, each sentence's content is seqid.batchid
+    // the seqid starts from 0
+    int start = 0;
+    for (size_t seqid = 0; seqid < in0_lod.size() - 1; seqid++) {
+      for (size_t batchid = 0;
+           batchid < in0_lod[0][seqid + 1] - in0_lod[0][seqid]; batchid++) {
+        float v = seqid + batchid * 0.1;
+
+        for (size_t dim = 0; dim < 8; dim++) {
+          in0->data<float>()[start * 8 + dim] = v;
+        }
+        start++;
+      }
+    }
+  }
+
+  void InitCacheManually() {
+    dop->cache_.Init(DynamicRecurrentOp::kArgName, *dop, scope, &dop->arg_);
+  }
+
+  void InitStepNet() {
+    std::unique_ptr<framework::OperatorBase> stepnet{new NetOp};
+    dynamic_cast<NetOp*>(stepnet.get())
+        ->AppendOp(std::unique_ptr<TestOp>(new TestOp(
+            "test", {{"inlinks", {"in0"}}, {"boot_memories", {"boot_mem"}}},
+            {{"outlinks", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
+    dop->SetStepNet(std::move(stepnet));
+  }
+
+ protected:
+  DynamicRecurrentOp* dop;
+  std::unique_ptr<framework::OperatorBase> op;
+  paddle::platform::CPUDeviceContext device_context;
+  paddle::framework::Scope scope;
+};
+
+TEST_F(DynamicRecurrentOpTestHelper, CreateCache) {
+  const rnn::Argument& arg = dop->arg_;
+  ASSERT_EQ(arg.inlinks.size(), 1UL);
+  ASSERT_EQ(arg.outlinks.size(), 1UL);
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
+  dop->SplitInputs();
+  auto& in0_ta = dop->step_inputs_["in0"];
+  ASSERT_EQ(in0_ta.size(), 4UL);
+
+  const auto& batch0 = in0_ta.Read(0);
+  const auto& batch1 = in0_ta.Read(1);
+  const auto& batch2 = in0_ta.Read(2);
+  const auto& batch3 = in0_ta.Read(3);
+  EXPECT_EQ(batch0.dims()[0], 4);
+  EXPECT_EQ(batch1.dims()[0], 3);
+  EXPECT_EQ(batch2.dims()[0], 2);
+  EXPECT_EQ(batch3.dims()[0], 1);
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, CreateScopes) {
+  dop->SplitInputs();
+  dop->CreateScopes();
+  ASSERT_EQ(dop->cache_.num_steps, 4UL);
+  ASSERT_EQ(dop->cache_.scopes->size(), 4UL);
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
+  dop->SplitInputs();
+  dop->CreateScopes();
+  dop->WriteStepInputs();
+
+  for (size_t step = 0; step < dop->cache_.num_steps; step++) {
+    auto& scope = dop->cache_.GetScope(step);
+    for (auto name : std::vector<std::string>({"in0"})) {
+      ASSERT_TRUE(scope.FindVar(name) != nullptr);
+    }
+  }
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
+  dop->SplitInputs();
+  dop->CreateScopes();
+  dop->WriteStepInputs();
+  dop->WriteStepOutputs();
+
+  for (size_t step = 0; step < dop->cache_.num_steps; step++) {
+    auto& scope = dop->cache_.GetScope(step);
+    for (auto name : std::vector<std::string>({"out0"})) {
+      ASSERT_TRUE(scope.FindVar(name));
+    }
+  }
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, ConcatOutputs) {
+  // Let's leave this test to the Python unittest.
+}
+
+TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
+  dop->SplitInputs();
+  dop->CreateScopes();
+  dop->WriteStepInputs();
+  dop->WriteStepOutputs();
+  dop->InitStates();
+
+  for (size_t step = 0; step < dop->cache_.num_steps; step++) {
+    auto& scope = dop->cache_.GetScope(step);
+    auto state = scope.FindVar("mem");
+    ASSERT_TRUE(state != nullptr);
+
+    auto* pre_state = scope.FindVar("mem@pre");
+    ASSERT_TRUE(pre_state != nullptr);
+
+    auto* boot_state = scope.FindVar("boot_mem");
+    ASSERT_TRUE(boot_state != nullptr);
+
+    if (step == 0) {
+      // check that pre_state is a reference to boot_state
+      ASSERT_EQ(boot_state->Get<LoDTensor>().data<float>(),
+                pre_state->Get<LoDTensor>().data<float>());
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle