Commit 017182c6, authored by qingqing01, committed by GitHub

Merge pull request #3124 from qingqing01/rnn_infershape

Refine InferShape for recurrent_network_op
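The change threads a single boolean, infer_shape_mode, through the rnn:: helpers (SegmentInputs, ConcatOutputs, LinkMemories, InitMemories, LinkBootMemoryGradients): when true, the helpers only propagate tensor dimensions (Resize) so InferShape can run before any memory is allocated; when false, they share or copy the underlying data for the actual Run pass. Below is a minimal standalone sketch of that two-mode pattern; ToyTensor and LinkMemory are illustrative stand-ins, not Paddle's real Tensor API.

#include <cassert>
#include <memory>
#include <vector>

// Toy stand-in for a tensor: a shape plus optionally shared storage.
// Only illustrates the infer_shape_mode split; not Paddle's Tensor API.
struct ToyTensor {
  std::vector<int> dims;
  std::shared_ptr<std::vector<float>> data;  // null until allocated

  void Resize(const std::vector<int>& d) { dims = d; }
  void ShareDataWith(const ToyTensor& other) {
    dims = other.dims;
    data = other.data;  // share storage, no copy
  }
};

// Link a step's pre-memory to the previous step's memory.
// infer_shape_mode == true : shape-only propagation (safe before allocation).
// infer_shape_mode == false: share the actual buffer for the Run pass.
void LinkMemory(ToyTensor* pre_mem, const ToyTensor& prev_mem,
                bool infer_shape_mode) {
  if (infer_shape_mode) {
    pre_mem->Resize(prev_mem.dims);
  } else {
    pre_mem->ShareDataWith(prev_mem);
  }
}

int main() {
  ToyTensor prev, pre;
  prev.dims = {15, 20};

  LinkMemory(&pre, prev, true /*infer_shape_mode*/);   // InferShape pass
  assert(pre.data == nullptr);                         // no storage touched

  prev.data = std::make_shared<std::vector<float>>(15 * 20, 1.0f);
  LinkMemory(&pre, prev, false /*infer_shape_mode*/);  // Run pass
  assert(pre.data == prev.data);                       // storage now shared
  return 0;
}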
@@ -60,10 +60,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
 op_library(fc_op
     SRCS fc_op.cc
     DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
-op_library(recurrent_network_op
-    SRCS recurrent_network_op.cc
-    DEPS op_desc tensor net)
-cc_test(recurrent_network_op_test
-    SRCS recurrent_network_op_test.cc
-    DEPS recurrent_network_op mul_op add_op)
+op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net)
+cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/recurrent_network_op.h"
+#include "paddle/operators/recurrent_op.h"
 
 #include <glog/logging.h>
 #include <cstring>
@@ -29,11 +29,15 @@ namespace rnn {
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& inlinks,
-                   const size_t seq_len) {
+                   const size_t seq_len,
+                   bool infer_shape_mode) {
   PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
   for (size_t i = 0; i < inlinks.size(); ++i) {
-    Tensor* input =
-        step_scopes[0]->FindVar(inlinks[i].external)->GetMutable<Tensor>();
+    auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
+    PADDLE_ENFORCE(input_var != nullptr,
+                   "input link [%s] is not in scope.",
+                   inlinks[i].external);
+    Tensor* input = input_var->GetMutable<Tensor>();
     DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
@@ -41,7 +45,9 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
-      *step_input = input->Slice<float>(j, j + 1);
+      if (!infer_shape_mode) {
+        *step_input = input->Slice<float>(j, j + 1);
+      }
       step_input->Resize(step_dims);
     }
   }
@@ -49,36 +55,41 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& outlinks,
-                   const size_t seq_len) {
+                   const size_t seq_len,
+                   bool infer_shape_mode) {
   for (size_t i = 0; i < outlinks.size(); i++) {
-    Tensor* output =
-        step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
-    // TODO(qingiqng) remove following code after adding
-    // InferShape in RecurrentGradientOp
-    DDim step_dims = step_scopes[0]
-                         ->FindVar(outlinks[i].internal)
-                         ->GetMutable<Tensor>()
-                         ->dims();
-    std::vector<int> dims_vec = vectorize(step_dims);
-    dims_vec.insert(dims_vec.begin(), seq_len);
-    output->mutable_data<float>(make_ddim(dims_vec), platform::CPUPlace());
-    for (size_t j = 0; j < seq_len; j++) {
-      Tensor* step_output =
-          step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
-      // TODO(luotao02) data type and platform::DeviceContext() should set
-      // correctly
-      (output->Slice<float>(j, j + 1))
-          .CopyFrom<float>(*step_output, platform::CPUPlace());
+    auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
+    PADDLE_ENFORCE(output_var != nullptr,
+                   "output link [%s] is not in scope.",
+                   outlinks[i].external);
+    Tensor* output = output_var->GetMutable<Tensor>();
+    if (infer_shape_mode) {
+      DDim step_dims = step_scopes[0]
+                           ->FindVar(outlinks[i].internal)
+                           ->GetMutable<Tensor>()
+                           ->dims();
+      std::vector<int> dims_vec = vectorize(step_dims);
+      dims_vec.insert(dims_vec.begin(), seq_len);
+      output->Resize(make_ddim(dims_vec));
+    } else {
+      output->mutable_data<float>(platform::CPUPlace());
+      for (size_t j = 0; j < seq_len; j++) {
+        Tensor* step_output =
+            step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
+        // TODO(luotao02) data type and platform::DeviceContext() should set
+        // correctly
+        (output->Slice<float>(j, j + 1))
+            .CopyFrom<float>(*step_output, platform::CPUPlace());
+      }
     }
   }
 }
 
 void LinkMemories(const std::vector<Scope*>& scopes,
                   const std::vector<rnn::MemoryAttr>& memories,
-                  size_t step_id,
-                  int offset) {
+                  const size_t step_id,
+                  const int offset,
+                  bool infer_shape_mode) {
   PADDLE_ENFORCE(step_id < scopes.size(),
                  "step [%d] is out of range of step scopes' size [%d]",
                  step_id,
@@ -95,18 +106,13 @@ void LinkMemories(const std::vector<Scope*>& scopes,
   auto scope = scopes[step_id];
   auto linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
-    auto mem = scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
-    // maybe share variable is better?
+    auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
     auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
-    mem->ShareDataWith<float>(*linked_mem);
-
-    // TODO(qingqing) remove following code
-    // the memory of current step should be allocated in step net
-    auto m = scope->NewVar(attr.var)->GetMutable<Tensor>();
-    // for unit test, as addOp and mulOp are null currently, if not
-    // mutable_data, mem.data() in output will be error. We will
-    // remove this line after merge the correct addOp and mulOp.
-    m->mutable_data<float>(mem->dims(), platform::CPUPlace());
+    if (infer_shape_mode) {
+      mem->Resize(linked_mem->dims());
+    } else {
+      mem->ShareDataWith<float>(*linked_mem);
+    }
   }
 }
@@ -175,60 +181,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
                  ->dims()[0];
   CreateScopes(scope);
   auto step_scopes = GetStepScopes(scope);
-  // SegmentInputs is called in InferShape. The input must hold memory in
-  // SegmentInputs. But the other op only set dimension for the output in
-  // InferShape. That's a problem. Wether the RNN op needs InferShape or not?
-  // Wether the following functions (SegmentInputs, InitMemories, ...) need
-  // to rewrite for RNN op?
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
-  InitMemories(step_scopes[0]);
-  PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
-                 "stepnet [%s] is not in scope.",
-                 arg_->step_net);
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
+  InitMemories(step_scopes[0], true /*infer_shape_mode*/);
   Variable* net = scope.FindVar(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
-  // If the InferShape is called in OperatorBase's run function,
-  // the rnn op only needs to do InferShape for the first time step
   for (size_t i = 0; i < seq_len_; i++) {
     if (i > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
   }
-  auto outlinks = arg_->outlinks;
-  for (size_t i = 0; i < outlinks.size(); i++) {
-    DDim step_dims = step_scopes[0]
-                         ->FindVar(outlinks[i].internal)
-                         ->GetMutable<Tensor>()
-                         ->dims();
-    std::vector<int> dims_vec = vectorize(step_dims);
-    // now only support fixed length
-    dims_vec.insert(dims_vec.begin(), seq_len_);
-    Tensor* output =
-        step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
-    output->Resize(make_ddim(dims_vec));
-  }
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
 }
 
 void RecurrentAlgorithm::Run(const Scope& scope,
                              const platform::DeviceContext& dev_ctx) const {
   auto step_scopes = GetStepScopes(scope);
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
+  InitMemories(step_scopes[0], false /*infer_shape_mode*/);
   Variable* net = scope.FindVar(arg_->step_net);
   for (size_t step_id = 0; step_id < seq_len_; step_id++) {
-    // the link memory is done in InferShape
-    // maybe remove following code after testing
     if (step_id > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
 }
 
 void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
@@ -244,18 +229,19 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
       // Now all variables in scope must be created outside of op.
       auto net_op = scope.FindVar(arg_->step_net)->GetMutable<NetOp>();
       for (auto& input : net_op->inputs_) {
+        // the weight are located in parent scope
        if (!step_scope.FindVar(input)) step_scope.NewVar(input);
       }
       for (auto& output : net_op->outputs_) {
         step_scope.NewVar(output);
       }
       step_scopes->emplace_back(&step_scope);
     }
   }
 }
 
-void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
+void RecurrentAlgorithm::InitMemories(Scope* step_scope,
+                                      bool infer_shape_mode) const {
   for (auto& attr : arg_->memories) {
     Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
@@ -263,13 +249,11 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
                    attr.var,
                    attr.boot_var);
     Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>();
-    pre_mem->ShareDataWith<float>(*boot_mem);
-
-    // TODO(qingqing) remove following code
-    // the memory of current step should be allocated in step net
-    // here for unit test
-    auto cur_step_mem = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
-    cur_step_mem->mutable_data<float>(boot_mem->dims(), platform::CPUPlace());
+    if (infer_shape_mode) {
+      pre_mem->Resize(boot_mem->dims());
+    } else {
+      pre_mem->ShareDataWith<float>(*boot_mem);
+    }
   }
 }
@@ -307,13 +291,14 @@ public:
       : OpProtoAndCheckerMaker(proto, op_checker) {
     const auto& name = RecurrentOp::kArgName;
     // inputs and outputs stored in proto
-    AddInput(name.inlinks, "the input that need to be segmented for each step.")
+    AddInput(name.inlinks,
+             "the inputs that need to be segmented for each step.")
         .SetMultiple();
     AddInput(name.boot_memories, "variables to initialize memories.")
         .SetMultiple();
     AddInput(name.step_net, "network shared by all steps.");
-    AddOutput(name.outlinks, "the output that need to concated for all steps.")
+    AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
        .SetMultiple();
     AddOutput(name.step_scopes, "step scopes");
@@ -331,34 +316,39 @@ public:
 void RecurrentGradientAlgorithm::Run(
     const Scope& scope, const platform::DeviceContext& dev_ctx) const {
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
-  PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
-                 "step net is not in scope.");
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
   Variable* net = scope.FindVar(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
   }
-  LinkBootMemoryGradients(step_scopes[0]);
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
+  LinkBootMemoryGradients(step_scopes[0], false);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
 }
 
 void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
-    Scope* step_scope) const {
+    Scope* step_scope, bool infer_shape_mode) const {
   for (auto& attr : arg_->memories) {
-    Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
-    PADDLE_ENFORCE(mem_grad != nullptr,
-                   "boot_tensor should be retrieved before");
+    PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
+                   "memory variable [%s] does not exists",
+                   attr.var);
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
-                   "memory [%s]'s boot variable [%s] not exists",
-                   attr.var,
+                   "boot variable [%s] does not exists",
                    attr.boot_var);
+    Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
     Tensor* boot_mem_grad =
         step_scope->NewVar(attr.boot_var)->GetMutable<Tensor>();
-    boot_mem_grad->ShareDataWith<float>(*mem_grad);
+    if (infer_shape_mode) {
+      boot_mem_grad->Resize(mem_grad->dims());
+    } else {
+      boot_mem_grad->ShareDataWith<float>(*mem_grad);
+    }
   }
 }
@@ -367,34 +357,20 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
                  ->GetMutable<Tensor>()
                  ->dims()[0];
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
-  PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
-                 "step net is not in scope.");
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
   Variable* net = scope.FindVar(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
   }
-  auto outlinks = arg_->outlinks;
-  for (size_t i = 0; i < outlinks.size(); i++) {
-    DDim step_dims = step_scopes[0]
-                         ->FindVar(outlinks[i].internal)
-                         ->GetMutable<Tensor>()
-                         ->dims();
-    std::vector<int> dims_vec = vectorize(step_dims);
-    // now only support fixed length
-    dims_vec.insert(dims_vec.begin(), seq_len_);
-    Tensor* output =
-        step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
-    output->Resize(make_ddim(dims_vec));
-  }
-  LinkBootMemoryGradients(step_scopes[0]);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
+  LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
 }
 
 void RecurrentGradientOp::Init() {
...
@@ -72,19 +72,22 @@ struct ArgumentName {
  */
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& inlinks,
-                   const size_t seq_len);
+                   const size_t seq_len,
+                   bool infer_shape_mode);
 
 /**
  * Process outputs of step nets and merge to variables.
  */
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<Link>& outlinks,
-                   const size_t seq_len);
+                   const size_t seq_len,
+                   bool infer_shape_mode);
 
 void LinkMemories(const std::vector<Scope*>& step_scopes,
                   const std::vector<MemoryAttr>& memories,
-                  size_t step_id,
-                  int offset);
+                  const size_t step_id,
+                  const int offset,
+                  bool infer_shape_mode);
 
 void InitArgument(const ArgumentName& name, Argument* arg);
@@ -122,7 +125,7 @@ protected:
     return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
   }
 
-  void InitMemories(Scope* step_scopes) const;
+  void InitMemories(Scope* step_scopes, bool infer_shape_mode) const;
 
 private:
   std::unique_ptr<rnn::Argument> arg_;
@@ -145,7 +148,7 @@ public:
   void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
 
-  void LinkBootMemoryGradients(Scope* step_scopes) const;
+  void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const;
 
   /**
    * InferShape must be called before Run.
...
@@ -18,7 +18,7 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/tensor.h"
-#include "paddle/operators/recurrent_network_op.h"
+#include "paddle/operators/recurrent_op.h"
 
 namespace paddle {
 namespace operators {
@@ -55,7 +55,7 @@ protected:
     w->GetMutable<Tensor>()->mutable_data<float>(
         make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
 
-    for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) {
+    for (auto boot : std::vector<std::string>{"h_boot"}) {
      LOG(INFO) << "create global variable " << boot;
       Variable* h_boot = scope_.NewVar(boot);
       h_boot->GetMutable<Tensor>()->mutable_data<float>(
@@ -79,7 +79,6 @@ protected:
     op_desc.add_inputs("x0");
     op_desc.add_inputs("x1");
     // boot_memories 3
-    op_desc.add_inputs("x_boot");
     op_desc.add_inputs("h_boot");
     // step net 5
     op_desc.add_inputs("step_net");
@@ -91,7 +90,7 @@ protected:
     auto _input_format = std::vector<int>{
         0,  // in_link
         3,  // memories
-        5   // step_net
+        4   // step_net
     };
     auto input_format = op_desc.add_attrs();
     input_format->set_name("input_format");
@@ -129,12 +128,11 @@ protected:
       inlink_alias->add_strings(item);
     }
     // pre memories
-    for (const auto& item :
-         std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) {
+    for (const auto& item : std::vector<std::string>{"rnn/h@pre"}) {
       pre_memories->add_strings(item);
     }
     // memories
-    for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) {
+    for (const auto& item : std::vector<std::string>{"rnn/h"}) {
       memories->add_strings(item);
     }
     // output alias
@@ -151,14 +149,11 @@ protected:
     LOG(INFO) << "create variable step_net";
     Variable* var = scope_.NewVar("step_net");
     auto net = var->GetMutable<NetOp>();
-    // rnn/s is net's input or output?
-    net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
-    net->inputs_ = {"rnn/s", "rnn/h"};
     net->AddOp(
         OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));
 
     net->AddOp(
-        OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
+        OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {}));
     net->CompleteAddOp();
   }
@@ -297,7 +292,10 @@ protected:
     inlink.internal = "rnn/x";
     auto step_scopes =
         scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
-    rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
+    rnn::SegmentInputs(*step_scopes,
+                       std::vector<rnn::Link>{inlink},
+                       10,
+                       true /*infer_shape_mode*/);
   }
 
   void LinkeMemories() {
@@ -311,7 +309,8 @@ protected:
     auto step_scopes =
         scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
     for (int i = 1; i < 10; ++i) {
-      rnn::LinkMemories(*step_scopes, memories, i, -1);
+      rnn::LinkMemories(
+          *step_scopes, memories, i, -1, true /*infer_shape_mode*/);
     }
   }
@@ -333,14 +332,14 @@ TEST(RecurrentOp, LinkMemories) {
   using namespace paddle::operators;
 
   // create and init step scopes
-  int len = 10;
+  size_t len = 10;
   std::vector<Scope*> step_scopes;
-  for (int i = 0; i < len; ++i) {
+  for (size_t i = 0; i < len; ++i) {
     auto scope = new Scope();
     scope->NewVar("pre_h");
     auto tensor = scope->NewVar("h")->GetMutable<Tensor>();
     float* data = tensor->mutable_data<float>({15, 20}, CPUPlace());
-    for (int j = 0; j < 15 * 20; ++j) {
+    for (size_t j = 0; j < 15 * 20; ++j) {
       data[j] = rand() * (1. / (double)RAND_MAX);
     }
     step_scopes.push_back(scope);
@@ -354,24 +353,24 @@ TEST(RecurrentOp, LinkMemories) {
   std::vector<rnn::MemoryAttr> memories;
   memories.push_back(mem_attr);
 
-  for (int i = 1; i < len; ++i) {
-    rnn::LinkMemories(step_scopes, memories, i, -1);
+  for (size_t i = 1; i < len; ++i) {
+    rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/);
   }
   // check
-  for (int i = 0; i < len - 1; ++i) {
+  for (size_t i = 0; i < len - 1; ++i) {
     const float* a =
         step_scopes[i]->FindVar("h")->GetMutable<Tensor>()->data<float>();
     const float* b = step_scopes[i + 1]
                          ->FindVar("pre_h")
                          ->GetMutable<Tensor>()
                          ->data<float>();
-    for (size_t i = 0; i < 15 * 20; ++i) {
-      ASSERT_FLOAT_EQ(a[i], b[i]);
+    for (size_t j = 0; j < 15 * 20; ++j) {
+      ASSERT_FLOAT_EQ(a[j], b[j]);
     }
   }
 
   for (int i = len - 2; i >= 0; --i) {
-    rnn::LinkMemories(step_scopes, memories, i, 1);
+    rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/);
   }
   // check
   for (int i = len - 2; i >= 0; --i) {
@@ -379,8 +378,8 @@ TEST(RecurrentOp, LinkMemories) {
         step_scopes[i]->FindVar("pre_h")->GetMutable<Tensor>()->data<float>();
     const float* b =
         step_scopes[i + 1]->FindVar("h")->GetMutable<Tensor>()->data<float>();
-    for (size_t i = 0; i < 15 * 20; ++i) {
-      ASSERT_FLOAT_EQ(a[i], b[i]);
+    for (size_t j = 0; j < 15 * 20; ++j) {
+      ASSERT_FLOAT_EQ(a[j], b[j]);
     }
   }
@@ -391,9 +390,3 @@ TEST(RecurrentOp, LinkMemories) {
 
 USE_OP(add_two);
 USE_OP(mul);
-
-// int main() {
-//   //! TODO(yuyang18): Temporary disable this unit-test because implementation
-//   //! error.
-//   return 0;
-//}
\ No newline at end of file
@@ -6,4 +6,4 @@ cc_library(paddle_pybind SHARED
     add_op
     mean_op
     cross_entropy_op
-    recurrent_network_op)
+    recurrent_op)