Commit 843ed8e3 authored by Yan Chunwei, committed by GitHub

dynamic recurrent op forward c++ implementation (#4597)

Parent 7506e481
@@ -24,6 +24,10 @@ if(WITH_DOUBLE)
   add_definitions(-DPADDLE_TYPE_DOUBLE)
 endif(WITH_DOUBLE)
+if(WITH_TESTING)
+  add_definitions(-DPADDLE_WITH_TESTING)
+endif(WITH_TESTING)
 if(NOT WITH_TIMER)
   add_definitions(-DPADDLE_DISABLE_TIMER)
 endif(NOT WITH_TIMER)
......
@@ -142,9 +142,9 @@ class OperatorBase {
 // Macro for defining a clone method.
 // If you are writing a kernel operator, `Clone` will be defined when you
 // register it, i.e. you do not need to define the `Clone` method yourself.
-#define DEFINE_OP_CLONE_METHOD(cls)                       \
-  std::unique_ptr<OperatorBase> Clone() const final {     \
-    return std::unique_ptr<OperatorBase>(new cls(*this)); \
-  }
+#define DEFINE_OP_CLONE_METHOD(cls)                                             \
+  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {      \
+    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
+  }
 // Macro for defining a default constructor for Operator.
......
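With the return type spelled out as `::paddle::framework::OperatorBase`, the macro can now be expanded in operator classes that live outside the `framework` namespace, without depending on which names happen to be visible there. A minimal sketch of its use (the `TestOp` in the new unit test below does essentially the same; `MyOp` is only illustrative):

namespace paddle {
namespace operators {
class MyOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  // Expands to: std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { ... }
  DEFINE_OP_CLONE_METHOD(MyOp);
  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {}
};
}  // namespace operators
}  // namespace paddle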
@@ -87,12 +87,12 @@ class TensorArray {
   LoDTensor Stack() const;
   /*
-   * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
+   * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
    */
   void Unstack(const LoDTensor &source) const;
   /*
-   * Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
+   * Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
    * with memory of tensors shared.
    */
   void UnstackShared(const LoDTensor &source) const;
......
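For orientation, a rough sketch of how these methods pair up (an illustration based on the comments above, not code from this commit), assuming `ta` is a `framework::TensorArray` and `source` is a `LoDTensor` of shape [N, D]:

ta.Unstack(source);           // N rank-1 tensors of shape [D]; values are copied
// ta.UnstackShared(source);  // same split, but each step shares memory with `source`
const auto& step0 = ta.Read(0);            // read a single step back
framework::LoDTensor packed = ta.Stack();  // re-assemble the [N, D] tensor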
@@ -133,3 +133,4 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor)
 cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
 cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
 cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
+cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
paddle/operators/dynamic_recurrent_op.cc

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
namespace detail {
inline void CreateVariables(Scope& scope,
const std::vector<std::string>& var_names) {
for (const auto& name : var_names) {
scope.NewVar(name);
}
}
} // namespace detail
class DynamicRecurrentOpProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public:
DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = DynamicRecurrentOp::kArgName;
// inputs and outputs stored in proto
AddInput(name.inlinks,
"the inputs that need to be segmented for each step.")
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
.AsDuplicable();
AddOutput(name.outlinks, "the outputs that need to be concatenated for all steps.")
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.pre_memories,
"names of pre-memories");
AddAttr<std::vector<std::string>>(name.memories, "names of memories");
AddComment("This is an RNN operator for variable-length sequences.");
}
};
void DynamicRecurrentOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
cache_.Init(kArgName, *this, scope, &arg_);
SplitInputs();
CreateScopes();
WriteStepInputs();
InitStates();
// call stepnet in all the time steps
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& step_scope = cache_.GetScope(step);
stepnet_->Run(step_scope, dev_ctx);
}
WriteStepOutputs();
ConcatOutputs();
}
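// Split every inlink LoDTensor into per-step tensors (sequences sorted by
// descending length) and remember the DySeqMeta needed to pack the outputs
// back later. num_steps is set from the first input processed and every other
// input must match it. For example, the unit test feeds an input with LoD
// {0, 4, 7, 9, 10}, i.e. 4 sequences of lengths 4, 3, 2 and 1, which unpacks
// into 4 steps whose batch sizes are 4, 3, 2 and 1.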
void DynamicRecurrentOp::SplitInputs() const {
// TODO(superjom) make level a config
// TODO(superjom) check that all the inputs have the same LoD
int level = 0;
const auto& inlinks = cache_.inlinks;
for (const auto& item : inlinks) {
const auto& var = item.second;
const auto& tensor = var->Get<LoDTensor>();
TensorArray& ta = step_inputs_[item.first];
dy_seq_metas_[item.first] =
ta.Unpack(tensor, level, true /*length_descend*/);
if (cache_.num_steps) {
PADDLE_ENFORCE_EQ(ta.size(), cache_.num_steps,
"inputs should have the same steps");
} else {
cache_.num_steps = ta.size();
}
}
}
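// Copy each step's slice of every inlink into the variable of the same name
// in the corresponding step scope; ShareDataWith is used, so no data is
// actually copied.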
void DynamicRecurrentOp::WriteStepInputs() const {
for (const auto& item : cache_.inlinks) {
auto ta_it = step_inputs_.find(item.first);
PADDLE_ENFORCE(ta_it != step_inputs_.end(),
"step_inputs_ not compatible with memory set");
TensorArray& ta = ta_it->second;
for (size_t step = 0; step < ta.size(); step++) {
auto tensor = ta.Read(step);
auto& step_scope = cache_.GetScope(step);
Variable* var = step_scope.FindVar(item.first);
if (var == nullptr) {
var = step_scope.NewVar(item.first);
}
var->GetMutable<LoDTensor>()->ShareDataWith<value_type>(tensor);
}
}
}
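// Ensure every output variable exists in each step scope and register its
// tensor in the matching output TensorArray via WriteShared, so the array
// entries share memory with the step-scope variables.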
void DynamicRecurrentOp::WriteStepOutputs() const {
for (size_t step = 0; step < cache_.scopes->size(); step++) {
auto& scope = cache_.GetScope(step);
for (auto& item : step_outputs_) {
auto* var = scope.FindVar(item.first);
if (var == nullptr) {
var = scope.NewVar(item.first);
}
auto* tensor = var->GetMutable<LoDTensor>();
item.second.WriteShared(step, *tensor);
}
}
}
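// Lazily grow the step-scope list to num_steps and pre-create, in every step
// scope, the inlink, outlink, memory and pre-memory variables that the
// stepnet will touch.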
void DynamicRecurrentOp::CreateScopes() const {
PADDLE_ENFORCE_GT(cache_.num_steps, 0);
// resize scopes
size_t num_scopes_need_create = cache_.num_steps - cache_.scopes->size();
for (size_t i = 0; i < num_scopes_need_create; i++) {
cache_.scopes->emplace_back(&cache_.scope->NewScope());
}
// init temporary inputs
PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first");
std::vector<std::string> memories;
std::vector<std::string> pre_memories;
std::transform(arg_.memories.begin(), arg_.memories.end(),
std::back_inserter(memories),
[](const rnn::MemoryAttr& m) { return m.var; });
std::transform(arg_.memories.begin(), arg_.memories.end(),
std::back_inserter(pre_memories),
[](const rnn::MemoryAttr& m) { return m.pre_var; });
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& scope = cache_.GetScope(step);
detail::CreateVariables(scope, arg_.inlinks);
detail::CreateVariables(scope, arg_.outlinks);
detail::CreateVariables(scope, memories);
detail::CreateVariables(scope, pre_memories);
}
}
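// Pack each output TensorArray back into a single LoDTensor, using the
// DySeqMeta recorded by SplitInputs, and share its data into the
// corresponding outlink variable.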
void DynamicRecurrentOp::ConcatOutputs() const {
// TODO(superjom) transform this to a config
int level = 0;
// TODO(superjom) pass in some lod
// just a placeholder
framework::LoD lod;
for (auto& item : step_outputs_) {
auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod);
auto& output = cache_.outlinks[item.first]->Get<LoDTensor>();
const_cast<LoDTensor*>(&output)->ShareDataWith<value_type>(tensor);
}
}
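// Wire up the recurrent states: every state tensor is allocated with the boot
// state's dims; the pre-state of step 0 shares the boot state's data, and the
// pre-state of step t (t > 0) shares the data of the state produced at
// step t-1.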
void DynamicRecurrentOp::InitStates() const {
// init the first state
// TODO(superjom) prepare for the scenario where the boot state does not exist
for (auto memory : arg_.memories) {
auto* boot_state_var = cache_.scope->FindVar(memory.boot_var);
PADDLE_ENFORCE_NOT_NULL(boot_state_var);
auto& boot_state = boot_state_var->Get<LoDTensor>();
const auto& dims = boot_state.dims();
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& cur_scope = cache_.GetScope(step);
// link pre-state to boot_state
// init state and pre-state
auto* pre_state = cur_scope.FindVar(memory.pre_var);
PADDLE_ENFORCE_NOT_NULL(pre_state);
pre_state->GetMutable<LoDTensor>();
auto* state = cur_scope.FindVar(memory.var);
PADDLE_ENFORCE_NOT_NULL(state);
state->GetMutable<LoDTensor>()->Resize(dims);
state->GetMutable<LoDTensor>()->mutable_data<value_type>(
platform::CPUPlace());
if (step == 0) {
auto* pre_state_tensor = pre_state->GetMutable<LoDTensor>();
pre_state_tensor->Resize(boot_state.dims());
pre_state_tensor->ShareDataWith<value_type>(boot_state);
} else {
auto& pre_scope = cache_.GetScope(step - 1);
auto* state_pre = pre_scope.FindVar(memory.var);
PADDLE_ENFORCE_NOT_NULL(state_pre);
pre_state->GetMutable<LoDTensor>()->ShareDataWith<value_type>(
*state_pre->GetMutable<LoDTensor>());
}
}
}
}
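// Gather everything Run needs repeatedly: parse the rnn::Argument from the
// op's inputs/outputs/attributes, fetch the step-scope vector held by the
// step_scopes output variable, and cache the inlink/outlink variables found
// in the parent scope.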
void DynamicRecurrentOp::ArgCache::Init(
const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
const paddle::framework::Scope& scope, rnn::Argument* arg) {
this->scope = &scope;
InitArgument(name, op, arg);
CacheScopes(scope, *arg);
CacheInlinks(scope, arg->inlinks);
CacheOutlinks(scope, arg->outlinks);
}
void DynamicRecurrentOp::ArgCache::InitArgument(const rnn::ArgumentName& name,
const OperatorBase& op,
rnn::Argument* arg) {
rnn::InitArgument(name, arg, op, false /*is_grad*/);
}
void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope,
const rnn::Argument& arg) {
auto scopes_var = scope.FindVar(arg.step_scopes);
PADDLE_ENFORCE(scopes_var != nullptr,
"the step_scopes output argument [%s] should be created first "
"by framework.",
arg.step_scopes);
this->scopes = scopes_var->GetMutable<std::vector<Scope*>>();
}
void DynamicRecurrentOp::ArgCache::CacheInlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
inlinks[name] = var;
}
}
void DynamicRecurrentOp::ArgCache::CacheOutlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
outlinks[name] = var;
}
}
Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
const std::string& name) {
auto* var = scope.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] not exist in scope", name);
return var;
}
const rnn::ArgumentName DynamicRecurrentOp::kArgName{
"step_net", "step_scopes", "inlinks", "outlinks",
"memories", "pre_memories", "boot_memories"};
void DynamicRecurrentGradientOp::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {}
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
dynamic_recurrent, paddle::operators::DynamicRecurrentOp,
paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker);
paddle/operators/dynamic_recurrent_op.h

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_TESTING
#include "gtest/gtest.h"
#endif
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor_array.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace paddle {
namespace operators {
class DynamicRecurrentOp : public framework::OperatorBase {
public:
static const rnn::ArgumentName kArgName;
using value_type = float;
DynamicRecurrentOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
DynamicRecurrentOp(const DynamicRecurrentOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}
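/*
* Run the forward pass: split the inputs into step tensors, create the step
* scopes, write the step inputs, initialize the states, run the stepnet once
* per time step, and finally collect and concatenate the step outputs.
*/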
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
/*
* Split the inputs (LoDTensors) into segments for each time step.
*/
void SplitInputs() const;
/*
* Create step-scopes to store temporary outputs in each time step.
*/
void CreateScopes() const;
/*
* Link TensorArray steps to the corresponding variables located in
* step-scopes.
*/
void WriteStepInputs() const;
/*
* Write output of each step to the corresponding TensorArray.
*/
void WriteStepOutputs() const;
/*
* Initialize the states. Each state has a corresponding pre-state that shares
* its memory with the state of the previous time step. The pre-state of the
* first time step is initialized with a zero tensor, or with a tensor from the
* parent scope if one is provided.
*/
void InitStates() const;
/*
* Concatenate the outputs of all time steps and generate a LoDTensor.
*/
void ConcatOutputs() const;
/*
* Set a stepnet that is created according to a RecurrentOp's stepnet.
*/
void SetStepNet(std::unique_ptr<OperatorBase> net) {
PADDLE_ENFORCE_NOT_NULL(net);
stepnet_ = std::move(net);
}
const OperatorBase& GetStepNet() const { return *stepnet_; }
protected:
struct ArgCache {
framework::Scope const* scope;
std::vector<framework::Scope*>* scopes;
std::map<std::string, framework::Variable*> inlinks;
std::map<std::string, framework::Variable*> outlinks;
size_t num_steps{0};
void Init(const rnn::ArgumentName& name, const OperatorBase& op,
const framework::Scope& scope, rnn::Argument* arg);
framework::Scope& GetScope(size_t index) {
PADDLE_ENFORCE_LT(index, num_steps);
return *scopes->at(index);
}
private:
void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op,
rnn::Argument* arg);
void CacheScopes(const framework::Scope& scope, const rnn::Argument& arg);
void CacheInlinks(const framework::Scope& scope,
const std::vector<std::string>& names);
void CacheOutlinks(const framework::Scope& scope,
const std::vector<std::string>& names);
framework::Variable* GetVariable(const framework::Scope& scope,
const std::string& name);
};
private:
std::unique_ptr<OperatorBase> stepnet_;
mutable framework::TensorArray states_;
mutable std::map<std::string, framework::TensorArray> step_inputs_;
mutable std::map<std::string, framework::TensorArray> step_outputs_;
mutable std::map<std::string, std::vector<framework::DySeqMeta>>
dy_seq_metas_;
mutable rnn::Argument arg_;
mutable ArgCache cache_;
#ifdef PADDLE_WITH_TESTING
friend class DynamicRecurrentOpTestHelper;
FRIEND_TEST(DynamicRecurrentOpTestHelper, SplitInputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateCache);
FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateScopes);
FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepInputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepOutputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, InitStates);
FRIEND_TEST(DynamicRecurrentOpTestHelper, ConcatOutputs);
#endif
};
class DynamicRecurrentGradientOp : public framework::OperatorBase {
public:
DynamicRecurrentGradientOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
};
} // namespace operators
} // namespace paddle
#include "paddle/operators/dynamic_recurrent_op.h"
#include <gtest/gtest.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
class TestOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
void OpDescNewVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
}
}
// create a LoD tensor in scope with specific dims
LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims,
const platform::Place& place) {
auto* var = scope.NewVar(name);
auto* tensor = var->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(place);
return tensor;
}
class DynamicRecurrentOpTestHelper : public ::testing::Test {
protected:
const rnn::ArgumentName argname = DynamicRecurrentOp::kArgName;
virtual void SetUp() override {
CreateGlobalVariables();
auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc);
dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
InitCacheManually();
InitStepNet();
}
framework::OpDesc CreateOpDesc() {
// create op
paddle::framework::OpDesc op_desc;
op_desc.set_type("dynamic_recurrent");
OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs());
OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs());
OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs());
OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs());
// set pre-memories
auto pre_memories = op_desc.mutable_attrs()->Add();
pre_memories->set_name(argname.pre_memories);
pre_memories->set_type(paddle::framework::AttrType::STRINGS);
auto pre_memories_item = pre_memories->add_strings();
*pre_memories_item = "mem@pre";
// set memories
auto memories = op_desc.mutable_attrs()->Add();
memories->set_name(argname.memories);
memories->set_type(paddle::framework::AttrType::STRINGS);
auto memories_item = memories->add_strings();
*memories_item = "mem";
return op_desc;
}
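// Create the variables the operator reads and writes in the parent scope: the
// step_scopes output, a [10, 20] boot memory, a [10, 20] output placeholder,
// and a [10, 8] input whose LoD {0, 4, 7, 9, 10} marks four sequences of
// lengths 4, 3, 2 and 1.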
void CreateGlobalVariables() {
platform::CPUPlace place;
scope.NewVar("step_scopes");
CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place);
// auto* out0 =
CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place);
// 10 instances in 4 sequences, whose lengths are 4, 3, 2 and 1 respectively.
framework::LoD in0_lod(1);
for (int x : std::vector<int>{0, 4, 7, 9, 10}) {
in0_lod[0].push_back(x);
}
in0->set_lod(in0_lod);
in0->Resize(framework::make_ddim({10, 8}));
// set the content; the value of every element in sequence `seqid` at batch
// position `batchid` is seqid + batchid * 0.1 (seqid starts from 0)
int start = 0;
for (size_t seqid = 0; seqid < in0_lod[0].size() - 1; seqid++) {
for (size_t batchid = 0;
batchid < in0_lod[0][seqid + 1] - in0_lod[0][seqid]; batchid++) {
float v = seqid + batchid * 0.1;
for (size_t dim = 0; dim < 8; dim++) {
in0->data<float>()[start * 8 + dim] = v;
}
start++;
}
}
}
void InitCacheManually() {
dop->cache_.Init(DynamicRecurrentOp::kArgName, *dop, scope, &dop->arg_);
}
void InitStepNet() {
std::unique_ptr<framework::OperatorBase> stepnet{new NetOp};
dynamic_cast<NetOp*>(stepnet.get())
->AppendOp(std::unique_ptr<TestOp>(new TestOp(
"test", {{"inlinks", {"in0"}}, {"boot_memories", {"boot_mem"}}},
{{"outlinks", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
dop->SetStepNet(std::move(stepnet));
}
protected:
DynamicRecurrentOp* dop;
std::unique_ptr<framework::OperatorBase> op;
paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope;
};
TEST_F(DynamicRecurrentOpTestHelper, CreateCache) {
const rnn::Argument& arg = dop->arg_;
ASSERT_EQ(arg.inlinks.size(), 1UL);
ASSERT_EQ(arg.outlinks.size(), 1UL);
}
TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
dop->SplitInputs();
auto& in0_ta = dop->step_inputs_["in0"];
ASSERT_EQ(in0_ta.size(), 4UL);
const auto& batch0 = in0_ta.Read(0);
const auto& batch1 = in0_ta.Read(1);
const auto& batch2 = in0_ta.Read(2);
const auto& batch3 = in0_ta.Read(3);
EXPECT_EQ(batch0.dims()[0], 4);
EXPECT_EQ(batch1.dims()[0], 3);
EXPECT_EQ(batch2.dims()[0], 2);
EXPECT_EQ(batch3.dims()[0], 1);
}
TEST_F(DynamicRecurrentOpTestHelper, CreateScopes) {
dop->SplitInputs();
dop->CreateScopes();
ASSERT_EQ(dop->cache_.num_steps, 4UL);
ASSERT_EQ(dop->cache_.scopes->size(), 4UL);
}
TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
for (auto name : std::vector<std::string>({"in0"})) {
ASSERT_TRUE(scope.FindVar(name) != nullptr);
}
}
}
TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
dop->WriteStepOutputs();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
for (auto name : std::vector<std::string>({"out0"})) {
ASSERT_TRUE(scope.FindVar(name));
}
}
}
TEST_F(DynamicRecurrentOpTestHelper, ConcatOutputs) {
// Let's leave this test to python unittest.
}
TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
dop->WriteStepOutputs();
dop->InitStates();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
auto state = scope.FindVar("mem");
ASSERT_TRUE(state != nullptr);
auto* pre_state = scope.FindVar("mem@pre");
ASSERT_TRUE(pre_state != nullptr);
auto* boot_state = scope.FindVar("boot_mem");
ASSERT_TRUE(boot_state != nullptr);
if (step == 0) {
// check pre_state is a reference of boot_state
ASSERT_EQ(boot_state->Get<LoDTensor>().data<float>(),
pre_state->Get<LoDTensor>().data<float>());
}
}
}
}  // namespace operators
} // namespace paddle