Commit 8925295a authored by: D dangqingqing

follow comments.

Parent b89d15a3
@@ -30,11 +30,14 @@ namespace rnn {
 void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                    const std::vector<Link>& inlinks,
                    const size_t seq_len,
-                   bool infer_shape) {
+                   bool infer_shape_mode) {
   PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
   for (size_t i = 0; i < inlinks.size(); ++i) {
-    Tensor* input =
-        step_scopes[0]->GetVariable(inlinks[i].external)->GetMutable<Tensor>();
+    auto input_var = step_scopes[0]->GetVariable(inlinks[i].external);
+    PADDLE_ENFORCE(input_var != nullptr,
+                   "input link [%s] is not in scope.",
+                   inlinks[i].external);
+    Tensor* input = input_var->GetMutable<Tensor>();
     DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
@@ -43,7 +46,7 @@ void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
       Tensor* step_input = step_scopes[j]
                                ->CreateVariable(inlinks[i].internal)
                                ->GetMutable<Tensor>();
-      if (!infer_shape) {
+      if (!infer_shape_mode) {
         *step_input = input->Slice<float>(j, j + 1);
       }
       step_input->Resize(step_dims);
@@ -54,12 +57,14 @@ void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
 void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                    const std::vector<Link>& outlinks,
                    const size_t seq_len,
-                   bool infer_shape) {
+                   bool infer_shape_mode) {
   for (size_t i = 0; i < outlinks.size(); i++) {
+    PADDLE_ENFORCE(step_scopes[0]->HasVariable(outlinks[i].external),
+                   "output link [%s] is not in scope.",
+                   outlinks[i].external);
     Tensor* output =
         step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
-    if (infer_shape) {
+    if (infer_shape_mode) {
       DDim step_dims = step_scopes[0]
                            ->GetVariable(outlinks[i].internal)
                            ->GetMutable<Tensor>()
@@ -69,8 +74,6 @@ void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
       output->Resize(make_ddim(dims_vec));
     } else {
       output->mutable_data<float>(platform::CPUPlace());
-    }
       for (size_t j = 0; j < seq_len; j++) {
         Tensor* step_output = step_scopes[j]
                                   ->GetVariable(outlinks[i].internal)
@@ -81,13 +84,14 @@ void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
             .CopyFrom<float>(*step_output, platform::CPUPlace());
       }
     }
+  }
 }

 void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
                   const std::vector<rnn::MemoryAttr>& memories,
                   const size_t step_id,
                   const int offset,
-                  bool infer_shape) {
+                  bool infer_shape_mode) {
   PADDLE_ENFORCE(step_id < scopes.size(),
                  "step [%d] is out of range of step scopes' size [%d]",
                  step_id,
@@ -107,7 +111,7 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
     auto mem = scope->GetVariable(attr.pre_var)->GetMutable<Tensor>();
     // maybe share variable is better?
     auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable<Tensor>();
-    if (infer_shape) {
+    if (infer_shape_mode) {
       mem->Resize(linked_mem->dims());
     } else {
       mem->ShareDataWith<float>(*linked_mem);
@@ -179,43 +183,39 @@ void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
                  ->GetMutable<Tensor>()
                  ->dims()[0];
   CreateScopes(scope);
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true);
-  InitMemories(step_scopes[0], true);
-  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
-                 "stepnet [%s] is not in scope.",
-                 arg_->step_net);
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
+  InitMemories(step_scopes[0], true /*infer_shape_mode*/);
   Variable* net = scope->GetVariable(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (size_t i = 0; i < seq_len_; i++) {
     if (i > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->InferShape(step_scopes[i]);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
 }

 void RecurrentAlgorithm::Run(const std::shared_ptr<Scope>& scope,
                              const platform::DeviceContext& dev_ctx) const {
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false);
-  InitMemories(step_scopes[0], false);
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
+  InitMemories(step_scopes[0], false /*infer_shape_mode*/);
   Variable* net = scope->GetVariable(arg_->step_net);
   for (size_t step_id = 0; step_id < seq_len_; step_id++) {
     if (step_id > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, false);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
 }

 void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
@@ -227,7 +227,6 @@ void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
   if (seq_len_ > step_scopes->size()) {
     for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
       std::shared_ptr<Scope> step_scope = std::make_shared<Scope>(scope);
       // Now all variables in scope must be created outside of op.
       auto net_op = scope->GetVariable(arg_->step_net)->GetMutable<NetOp>();
       for (auto& input : net_op->inputs_) {
@@ -237,14 +236,13 @@ void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
       for (auto& output : net_op->outputs_) {
         step_scope->CreateVariable(output);
       }
       step_scopes->push_back(std::make_shared<Scope>(step_scope));
     }
   }
 }

 void RecurrentAlgorithm::InitMemories(std::shared_ptr<Scope> step_scope,
-                                      bool infer_shape) const {
+                                      bool infer_shape_mode) const {
   for (auto& attr : arg_->memories) {
     Tensor* pre_mem =
         step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
@@ -254,7 +252,7 @@ void RecurrentAlgorithm::InitMemories(std::shared_ptr<Scope> step_scope,
                    attr.boot_var);
     Tensor* boot_mem =
         step_scope->GetVariable(attr.boot_var)->GetMutable<Tensor>();
-    if (infer_shape) {
+    if (infer_shape_mode) {
       pre_mem->Resize(boot_mem->dims());
     } else {
       pre_mem->ShareDataWith<float>(*boot_mem);
@@ -320,23 +318,23 @@ void RecurrentGradientAlgorithm::Run(
     const std::shared_ptr<Scope>& scope,
     const platform::DeviceContext& dev_ctx) const {
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false);
-  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
-                 "step net is not in scope.");
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
   Variable* net = scope->GetVariable(arg_->step_net);
-  PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, false);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
   }
   LinkBootMemoryGradients(step_scopes[0], false);
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, false);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
 }

 void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
-    std::shared_ptr<Scope> step_scope, bool infer_shape) const {
+    std::shared_ptr<Scope> step_scope, bool infer_shape_mode) const {
   for (auto& attr : arg_->memories) {
     Tensor* mem_grad =
         step_scope->CreateVariable(attr.var)->GetMutable<Tensor>();
@@ -346,7 +344,7 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
                    attr.boot_var);
     Tensor* boot_mem_grad =
         step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
-    if (infer_shape) {
+    if (infer_shape_mode) {
       boot_mem_grad->Resize(mem_grad->dims());
     } else {
       boot_mem_grad->ShareDataWith<float>(*mem_grad);
@@ -360,21 +358,20 @@ void RecurrentGradientAlgorithm::InferShape(
                  ->GetMutable<Tensor>()
                  ->dims()[0];
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true);
-  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
-                 "step net is not in scope.");
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
   Variable* net = scope->GetVariable(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, true);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->InferShape(step_scopes[step_id]);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true);
-  LinkBootMemoryGradients(step_scopes[0], true);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
+  LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
 }

 void RecurrentGradientOp::Init() {
......
@@ -73,7 +73,7 @@ struct ArgumentName {
 void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                    const std::vector<Link>& inlinks,
                    const size_t seq_len,
-                   bool infer_shape);
+                   bool infer_shape_mode);

 /**
  * Process outputs of step nets and merge to variables.
@@ -81,13 +81,13 @@ void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
 void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                    const std::vector<Link>& outlinks,
                    const size_t seq_len,
-                   bool infer_shape);
+                   bool infer_shape_mode);

 void LinkMemories(std::vector<std::shared_ptr<Scope>>& step_scopes,
                   const std::vector<MemoryAttr>& memories,
                   const size_t step_id,
                   const int offset,
-                  bool infer_shape);
+                  bool infer_shape_mode);

 void InitArgument(const ArgumentName& name, Argument* arg);
@@ -128,7 +128,8 @@ protected:
         ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
   }

-  void InitMemories(std::shared_ptr<Scope> step_scopes, bool infer_shape) const;
+  void InitMemories(std::shared_ptr<Scope> step_scopes,
+                    bool infer_shape_mode) const;

 private:
   std::unique_ptr<rnn::Argument> arg_;
@@ -153,7 +154,7 @@ public:
            const platform::DeviceContext& dev_ctx) const;

   void LinkBootMemoryGradients(std::shared_ptr<Scope> step_scopes,
-                               bool infer_shape) const;
+                               bool infer_shape_mode) const;

   /**
    * InferShape must be called before Run.
......
@@ -298,7 +298,10 @@ protected:
     std::vector<std::shared_ptr<Scope>>* step_scopes =
         scope_->GetVariable("step_scopes")
             ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
-    rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10, true);
+    rnn::SegmentInputs(*step_scopes,
+                       std::vector<rnn::Link>{inlink},
+                       10,
+                       true /*infer_shape_mode*/);
   }

   void LinkeMemories() {
@@ -313,7 +316,8 @@ protected:
         scope_->GetVariable("step_scopes")
             ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
     for (int i = 1; i < 10; ++i) {
-      rnn::LinkMemories(*step_scopes, memories, i, -1, true);
+      rnn::LinkMemories(
+          *step_scopes, memories, i, -1, true /*infer_shape_mode*/);
     }
   }
@@ -343,7 +347,7 @@ TEST(RecurrentOp, LinkMemories) {
     auto tensor = scope->CreateVariable("h")->GetMutable<Tensor>();
     float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace());
     for (int j = 0; j < 15 * 20; ++j) {
-      data[i] = rand() * (1. / (double)RAND_MAX);
+      data[j] = rand() * (1. / (double)RAND_MAX);
     }
     step_scopes.push_back(scope);
   }
@@ -357,7 +361,7 @@ TEST(RecurrentOp, LinkMemories) {
   memories.push_back(mem_attr);

   for (int i = 1; i < len; ++i) {
-    rnn::LinkMemories(step_scopes, memories, i, -1, false);
+    rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/);
   }
   // check
   for (int i = 0; i < len - 1; ++i) {
@@ -373,7 +377,7 @@ TEST(RecurrentOp, LinkMemories) {
   }

   for (int i = len - 2; i >= 0; --i) {
-    rnn::LinkMemories(step_scopes, memories, i, 1, false);
+    rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/);
   }
   // check
   for (int i = len - 2; i >= 0; --i) {
......
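The change that repeats throughout the diff is the renamed infer_shape_mode flag: when it is true, the rnn helpers only propagate tensor shapes (Resize), and when it is false they touch real data (Slice, ShareDataWith, CopyFrom). The sketch below is not Paddle code; it is a minimal, self-contained C++ illustration of that two-mode pattern, with a hypothetical FakeTensor type standing in for the real Tensor.

// Minimal sketch of the infer_shape_mode pattern. Assumption: FakeTensor is a
// hypothetical stand-in, not paddle::framework::Tensor.
#include <cassert>
#include <memory>
#include <vector>

struct FakeTensor {
  std::vector<int> dims;                     // shape only
  std::shared_ptr<std::vector<float>> data;  // lazily allocated buffer

  void Resize(const std::vector<int>& d) { dims = d; }
  void ShareDataWith(const FakeTensor& other) {
    dims = other.dims;
    data = other.data;  // alias the same buffer, like Tensor::ShareDataWith
  }
};

// Link a step's "pre-memory" to the previous step's memory. In infer-shape
// mode only the dims are propagated; in run mode the buffer is shared too.
void LinkMemorySketch(FakeTensor& pre_mem, const FakeTensor& prev_mem,
                      bool infer_shape_mode) {
  if (infer_shape_mode) {
    pre_mem.Resize(prev_mem.dims);
  } else {
    pre_mem.ShareDataWith(prev_mem);
  }
}

int main() {
  FakeTensor prev;
  prev.dims = {1, 20};
  prev.data = std::make_shared<std::vector<float>>(20, 1.0f);

  FakeTensor pre;
  LinkMemorySketch(pre, prev, true /*infer_shape_mode*/);   // InferShape pass
  assert(pre.data == nullptr);                              // no data touched
  LinkMemorySketch(pre, prev, false /*infer_shape_mode*/);  // Run pass
  assert(pre.data == prev.data);                            // buffer aliased
  return 0;
}

Written this way, an InferShape pass can walk every unrolled step without allocating or copying any step data, while the Run pass reuses the same call sites with the flag flipped to false.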