From aee0d3ec5f7aaeeb35939b152b91c24e8a166920 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Fri, 28 Jul 2017 22:13:44 -0500 Subject: [PATCH] RecurrentOp implementation (#2890) * add rnn op interfaces * add Run * rename state -> memory * change state -> memory * make compilable * add .cc * init test * add op fake implementation * add CreateStepNet and CreateScopes implementation. * add TODO list * init memory attributes. * add LinkMemories * add PlainNet fake implementation * Use std::shared_ptr in the OpRunContext. * add test * disable mutable_data * finist segmentInput function * enable mutable_data with a trick * RNNOp test. * enable LinkMemories with mutable_data * update SegmentInput function with comments * finish ConcatOutput function * reformat inputs and attributes boot_memories * Refine unit test. * Refine unit test. * modify inlinks. * add OpDesc to Net * fix bug and update unit test. * move step scopes from inputs to outputs * fix merge conflict, update SegmentInput function * add RecurrentOpProtoAndCheckerMaker. * clean the codes * Abstract GetStepScopes and GetMaxSeqLen function * refine LinkMemories * Refine code and add some comments. * add backward core * update for develop branch. * add forward core * add forward algorithm * Add RecurrentGradientAlgorithm implenmention. * use CopyFrom and Slice function in RecurrentOp * add unit test for LinkMemories. * fix unit test. * use the latest tensor.h, solve conflict * add maker * move SegmentInput and ConcatOutput to details nameplace * unit test for RecurrentGradientAlgorithm. * apply OperatorBase * apply net operator. * move memorys to attributes * add RecurrentGradientOp * open test unit test in recurrent_network_op_test. * revert some files. * add RecurrentArgument and Link struct to simplify member variable. * rename. * move recurrent_op from framework to operators * add RecurrentGradientOp Init * fix name * fix Link.interal/external name * use namespace operators instead of framework * clean the code * use the latest add_op and mul_op, don't test backward now * Remove ScopePtr and OperatorPtr * add get_net to pybind * add test_recurrent_op.py * add random into gen_tensor * update to develop branch and refine some code. * add some comments. --- paddle/operators/CMakeLists.txt | 5 + paddle/operators/recurrent_network_op.cc | 418 ++++++++++++++++++ paddle/operators/recurrent_network_op.h | 216 +++++++++ paddle/operators/recurrent_network_op_test.cc | 400 +++++++++++++++++ paddle/pybind/CMakeLists.txt | 2 +- paddle/pybind/pybind.cc | 6 + .../v2/framework/tests/test_recurrent_op.py | 92 ++++ 7 files changed, 1138 insertions(+), 1 deletion(-) create mode 100644 paddle/operators/recurrent_network_op.cc create mode 100644 paddle/operators/recurrent_network_op.h create mode 100644 paddle/operators/recurrent_network_op_test.cc create mode 100644 python/paddle/v2/framework/tests/test_recurrent_op.py diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 0a14dc2114..5085e1b925 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -54,3 +54,8 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op softmax_op net) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) + +op_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc +tensor op_registry operator net) +cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS +recurrent_network_op gtest mul_op add_op) diff --git a/paddle/operators/recurrent_network_op.cc b/paddle/operators/recurrent_network_op.cc new file mode 100644 index 0000000000..0a86d4b9fb --- /dev/null +++ b/paddle/operators/recurrent_network_op.cc @@ -0,0 +1,418 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/recurrent_network_op.h" + +#include +#include +#include + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace operators { + +namespace rnn { + +void SegmentInputs(std::vector>& step_scopes, + const std::vector& inlinks, + const size_t seq_len) { + PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); + for (size_t i = 0; i < inlinks.size(); ++i) { + Tensor* input = + step_scopes[0]->GetVariable(inlinks[i].external)->GetMutable(); + DDim dims = input->dims(); + PADDLE_ENFORCE(static_cast(dims[0]) == seq_len, + "all the inlinks must have same length"); + DDim step_dims = slice_ddim(dims, 1, dims.size()); + for (size_t j = 0; j < seq_len; j++) { + Tensor* step_input = step_scopes[j] + ->CreateVariable(inlinks[i].internal) + ->GetMutable(); + *step_input = input->Slice(j, j + 1); + step_input->Resize(step_dims); + } + } +} + +void ConcatOutputs(std::vector>& step_scopes, + const std::vector& outlinks, + const size_t seq_len) { + for (size_t i = 0; i < outlinks.size(); i++) { + Tensor* output = + step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable(); + + // TODO(qingiqng) remove following code after adding + // InferShape in RecurrentGradientOp + DDim step_dims = step_scopes[0] + ->GetVariable(outlinks[i].internal) + ->GetMutable() + ->dims(); + std::vector dims_vec = vectorize(step_dims); + dims_vec.insert(dims_vec.begin(), seq_len); + output->mutable_data(make_ddim(dims_vec), platform::CPUPlace()); + + for (size_t j = 0; j < seq_len; j++) { + Tensor* step_output = step_scopes[j] + ->GetVariable(outlinks[i].internal) + ->GetMutable(); + // TODO data type and platform::DeviceContext() should set correctly + (output->Slice(j, j + 1)) + .CopyFrom(*step_output, platform::CPUDeviceContext()); + } + } +} + +void LinkMemories(std::vector>& scopes, + const std::vector& memories, + size_t step_id, + int offset) { + PADDLE_ENFORCE(step_id < scopes.size(), + "step [%d] is out of range of step scopes' size [%d]", + step_id, + scopes.size()); + PADDLE_ENFORCE(static_cast(step_id) + offset >= 0, + "offset [%d] must be large than -[%d]", + offset, + step_id); + PADDLE_ENFORCE(step_id + offset < scopes.size(), + "offset [%d] is out of range, it must be less than (%d - %d)", + offset, + scopes.size(), + step_id); + std::shared_ptr scope = scopes[step_id]; + std::shared_ptr linked_scope = scopes[step_id + offset]; + for (auto& attr : memories) { + auto mem = scope->CreateVariable(attr.pre_var)->GetMutable(); + // maybe share variable is better? + auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable(); + mem->ShareDataWith(*linked_mem); + + // TODO(qingqing) remove following code + // the memory of current step should be allocated in step net + auto m = scope->CreateVariable(attr.var)->GetMutable(); + // for unit test, as addOp and mulOp are null currently, if not + // mutable_data, mem.data() in output will be error. We will + // remove this line after merge the correct addOp and mulOp. + m->mutable_data(mem->dims(), platform::CPUPlace()); + } +} + +void InitArgument(const ArgumentName& name, + Argument* arg, + const OperatorBase& op) { + arg->step_net = op.Input(name.step_net); + arg->step_scopes = op.Output(name.step_scopes); + + auto inlinks = op.Inputs(name.inlinks); + auto inlink_alias = op.GetAttr>(name.inlink_alias); + PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(), + "the size of inlinks and inlink_alias don't match:%d,%d", + inlinks.size(), + inlink_alias.size()); + for (size_t i = 0; i < inlinks.size(); ++i) { + rnn::Link link; + link.external = inlinks[i]; + link.internal = inlink_alias[i]; + (arg->inlinks).push_back(link); + } + + auto outlinks = op.Outputs(name.outlinks); + auto outlink_alias = op.GetAttr>(name.outlink_alias); + PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(), + "the size of outlinks and outlink_alias don't match:%d,%d", + outlinks.size(), + outlink_alias.size()); + for (size_t i = 0; i < outlinks.size(); ++i) { + rnn::Link link; + link.external = outlinks[i]; + link.internal = outlink_alias[i]; + (arg->outlinks).push_back(link); + } + + auto boot_memories = op.Inputs(name.boot_memories); + + // attributes + auto memories = op.GetAttr>(name.memories); + auto pre_memories = op.GetAttr>(name.pre_memories); + + PADDLE_ENFORCE(memories.size() == boot_memories.size(), + "the size of memories, boot_memories don't match:%d,%d", + memories.size(), + boot_memories.size()); + PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(), + "the size of pre_memories, boot_memories don't match:%d,%d", + pre_memories.size(), + boot_memories.size()); + PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set"); + + for (size_t i = 0; i < memories.size(); ++i) { + rnn::MemoryAttr mem_attr; + mem_attr.var = memories[i]; + mem_attr.pre_var = pre_memories[i]; + mem_attr.boot_var = boot_memories[i]; + (arg->memories).push_back(mem_attr); + } +} + +} // namespace rnn + +void RecurrentAlgorithm::InferShape(const std::shared_ptr& scope) const { + seq_len_ = scope->GetVariable((arg_->inlinks[0]).external) + ->GetMutable() + ->dims()[0]; + CreateScopes(scope); + auto step_scopes = GetStepScopes(scope); + + // SegmentInputs is called in InferShape. The input must hold memory in + // SegmentInputs. But the other op only set dimension for the output in + // InferShape. That's a problem. Wether the RNN op needs InferShape or not? + // Wether the following functions (SegmentInputs, InitMemories, ...) need + // to rewrite for RNN op? + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); + + InitMemories(step_scopes[0]); + + PADDLE_ENFORCE(scope->HasVariable(arg_->step_net), + "stepnet [%s] is not in scope.", + arg_->step_net); + Variable* net = scope->GetVariable(arg_->step_net); + PADDLE_ENFORCE(net != nullptr, "failed to get step net"); + // If the InferShape is called in OperatorBase's run function, + // the rnn op only needs to do InferShape for the first time step + for (size_t i = 0; i < seq_len_; i++) { + if (i > 0) { + rnn::LinkMemories(step_scopes, arg_->memories, i, -1); + } + net->GetMutable()->InferShape(step_scopes[i]); + } + + auto outlinks = arg_->outlinks; + for (size_t i = 0; i < outlinks.size(); i++) { + DDim step_dims = step_scopes[0] + ->GetVariable(outlinks[i].internal) + ->GetMutable() + ->dims(); + std::vector dims_vec = vectorize(step_dims); + // now only support fixed length + dims_vec.insert(dims_vec.begin(), seq_len_); + Tensor* output = + step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable(); + output->Resize(make_ddim(dims_vec)); + } +} + +void RecurrentAlgorithm::Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const { + auto step_scopes = GetStepScopes(scope); + + Variable* net = scope->GetVariable(arg_->step_net); + for (size_t step_id = 0; step_id < seq_len_; step_id++) { + // the link memory is done in InferShape + // maybe remove following code after testing + if (step_id > 0) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1); + } + net->GetMutable()->Run(step_scopes[step_id], dev_ctx); + } + + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); +} + +void RecurrentAlgorithm::CreateScopes(std::shared_ptr scope) const { + // TODO(xxx) Only two scopes are needed for inference, this case will be + // supported later. + auto step_scopes = scope->GetVariable(arg_->step_scopes) + ->GetMutable>>(); + + if (seq_len_ > step_scopes->size()) { + for (size_t i = step_scopes->size(); i < seq_len_; ++i) { + std::shared_ptr step_scope = std::make_shared(scope); + + // Now all variables in scope must be created outside of op. + auto net_op = scope->GetVariable(arg_->step_net)->GetMutable(); + for (auto& input : net_op->inputs_) { + step_scope->CreateVariable(input); + } + for (auto& output : net_op->outputs_) { + step_scope->CreateVariable(output); + } + + step_scopes->push_back(std::make_shared(step_scope)); + } + } +} + +void RecurrentAlgorithm::InitMemories(std::shared_ptr step_scope) const { + for (auto& attr : arg_->memories) { + Tensor* pre_mem = + step_scope->CreateVariable(attr.pre_var)->GetMutable(); + PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var), + "memory [%s]'s boot variable [%s] not exists", + attr.var, + attr.boot_var); + Tensor* boot_mem = + step_scope->GetVariable(attr.boot_var)->GetMutable(); + pre_mem->ShareDataWith(*boot_mem); + + // TODO(qingqing) remove following code + // the memory of current step should be allocated in step net + // here for unit test + auto cur_step_mem = + step_scope->CreateVariable(attr.var)->GetMutable(); + cur_step_mem->mutable_data(boot_mem->dims(), platform::CPUPlace()); + } +} + +const rnn::ArgumentName RecurrentOp::kArgName{"step_net", + "step_scopes", + "inlinks", + "outlinks", + "inlink_alias", + "outlink_alias", + "memories", + "pre_memories", + "boot_memories"}; + +const rnn::ArgumentName RecurrentGradientOp::kArgName{"step_net", + "step_scopes", + "outlink@grad", + "inlink@grad", + "inlink_alias", + "outlink_alias", + "memories", + "pre_memories", + "boot_memories@grad"}; + +void RecurrentOp::Init() { + OperatorBase::Init(); + std::unique_ptr arg(new rnn::Argument()); + rnn::InitArgument(kArgName, arg.get(), *this); + alg_.Init(std::move(arg)); +} + +class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker { +public: + RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto, + OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + const auto& name = RecurrentOp::kArgName; + // inputs and outputs stored in proto + AddInputs(name.inlinks, + "the input that need to be segmented for each step."); + AddInputs(name.boot_memories, "variables to initialize memories."); + AddInput(name.step_net, "network shared by all steps."); + + AddOutputs(name.outlinks, + "the output that need to concated for all steps."); + AddOutput(name.step_scopes, "step scopes"); + + // Attributes stored in AttributeMap + AddAttr>(name.inlink_alias, "alias of inlinks"); + AddAttr>(name.outlink_alias, "alias of outlinks"); + AddAttr>(name.pre_memories, + "names of pre-memories"); + AddAttr>(name.memories, "names of memories"); + + AddComment("This is a recurrent group operator."); + } +}; + +void RecurrentGradientAlgorithm::Run( + const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const { + auto step_scopes = GetStepScopes(scope); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); + PADDLE_ENFORCE(scope->HasVariable(arg_->step_net), + "step net is not in scope."); + Variable* net = scope->GetVariable(arg_->step_net); + PADDLE_ENFORCE(net != nullptr, "failed to get step net"); + for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { + if (static_cast(step_id) != seq_len_ - 1) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + } + net->GetMutable()->Run(step_scopes[step_id], dev_ctx); + } + LinkBootMemoryGradients(step_scopes[0]); + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); +} + +void RecurrentGradientAlgorithm::LinkBootMemoryGradients( + std::shared_ptr step_scope) const { + for (auto& attr : arg_->memories) { + Tensor* mem_grad = + step_scope->CreateVariable(attr.var)->GetMutable(); + PADDLE_ENFORCE(mem_grad != nullptr, + "boot_tensor should be retrieved before"); + PADDLE_ENFORCE(step_scope->HasVariable(attr.boot_var), + "memory [%s]'s boot variable [%s] not exists", + attr.var, + attr.boot_var); + Tensor* boot_mem_grad = + step_scope->CreateVariable(attr.boot_var)->GetMutable(); + boot_mem_grad->ShareDataWith(*mem_grad); + } +} + +void RecurrentGradientAlgorithm::InferShape( + const std::shared_ptr& scope) const { + seq_len_ = scope->GetVariable((arg_->inlinks[0]).external) + ->GetMutable() + ->dims()[0]; + auto step_scopes = GetStepScopes(scope); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); + + PADDLE_ENFORCE(scope->HasVariable(arg_->step_net), + "step net is not in scope."); + Variable* net = scope->GetVariable(arg_->step_net); + PADDLE_ENFORCE(net != nullptr, "failed to get step net"); + + for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { + if (static_cast(step_id) != seq_len_ - 1) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + } + net->GetMutable()->InferShape(step_scopes[step_id]); + } + + auto outlinks = arg_->outlinks; + for (size_t i = 0; i < outlinks.size(); i++) { + DDim step_dims = step_scopes[0] + ->GetVariable(outlinks[i].internal) + ->GetMutable() + ->dims(); + std::vector dims_vec = vectorize(step_dims); + // now only support fixed length + dims_vec.insert(dims_vec.begin(), seq_len_); + Tensor* output = + step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable(); + output->Resize(make_ddim(dims_vec)); + } + LinkBootMemoryGradients(step_scopes[0]); +} + +void RecurrentGradientOp::Init() { + OperatorBase::Init(); + std::unique_ptr arg(new rnn::Argument()); + rnn::InitArgument(kArgName, arg.get(), *this); + alg_.Init(std::move(arg)); +} + +} // namespace operators +} // namespace paddle + +REGISTER_OP(recurrent_op, + paddle::operators::RecurrentOp, + paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); diff --git a/paddle/operators/recurrent_network_op.h b/paddle/operators/recurrent_network_op.h new file mode 100644 index 0000000000..8946c8ce38 --- /dev/null +++ b/paddle/operators/recurrent_network_op.h @@ -0,0 +1,216 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +using namespace paddle::framework; + +namespace rnn { + +/** + * Memory of a RNN (same as the role of `Momory` in PaddlePaddle). + * + * Memory attributes cached by this op, dims will be infered from + * boot memories in father scope. Other attributes are copied from Op's proto + * attributes. + */ +struct MemoryAttr { + // name of current state variable + std::string var; + // name of previous step's state variable + std::string pre_var; + // name of the variables to init this memory (same role of `boot_layer` in + // PaddlePaddle), which is store in father's scope. + std::string boot_var; +}; + +struct Link { + // input or output links name. + std::string internal; + // alias to avoid duplicate keys in scopes. + std::string external; +}; + +struct Argument { + std::string step_net; + std::string step_scopes; + std::vector inlinks; + std::vector outlinks; + std::vector memories; +}; + +struct ArgumentName { + std::string step_net; + std::string step_scopes; + std::string inlinks; + std::string outlinks; + std::string inlink_alias; // the alias of inlinks in step net. + std::string outlink_alias; // the alias of outlinks in step net. + std::string memories; // the memory name + std::string pre_memories; // the previous memory name + std::string boot_memories; // the boot memory name +}; + +/** + * Prepare inputs for each step net. + */ +void SegmentInputs(std::vector>& step_scopes, + const std::vector& inlinks, + const size_t seq_len); + +/** + * Process outputs of step nets and merge to variables. + */ +void ConcatOutputs(std::vector>& step_scopes, + const std::vector& outlinks, + const size_t seq_len); + +void LinkMemories(std::vector>& step_scopes, + const std::vector& memories, + size_t step_id, + int offset); + +void InitArgument(const ArgumentName& name, Argument* arg); + +}; // namespace rnn + +// The sequence format in RecurrentOp is Tensor now. +// TODO: +// 1. No-padding computing for sequences with indifinite length in one batch. +// 2. Hierarchical RNN for sequence with sub-sequence. +// 3. Internal Memory. +// 4. More Complex RNN architecture, such as Gated Feedback RNN. +// Refer to: https://arxiv.org/pdf/1502.02367.pdf + +class RecurrentAlgorithm { +public: + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const; + + void Init(std::unique_ptr arg) { arg_ = std::move(arg); } + + /** + * InferShape must be called before Run. + */ + void InferShape(const std::shared_ptr& scope) const; + +protected: + /* + * The step scopes will be stored in the father scope as a variable. + * + * NOTE the scopes are reused in both the forward and backward, so just + * create once and expand its size if more steps need. + */ + void CreateScopes(std::shared_ptr scope) const; + + inline const std::vector>& GetStepScopes( + std::shared_ptr scope) const { + return *(scope->GetVariable(arg_->step_scopes)) + ->GetMutable>>(); + } + + void InitMemories(std::shared_ptr step_scopes) const; + +private: + std::unique_ptr arg_; + mutable size_t seq_len_; +}; + +class RecurrentGradientAlgorithm { + /** + * RNN's backward alogorithm. + * + * To accelerate the development of RecurrentGradientOp, we decouple RNN's + * algorithm and `OperatorBase`'s implementation, the former contains the core + * implementation of a RNN, and will keep stable even if the framework changes + * a + * lot, and the latter is a wrapper acts like an dapter for it to make RNN an + * operator. + */ +public: + void Init(std::unique_ptr arg) { arg_ = std::move(arg); } + + void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const; + + void LinkBootMemoryGradients(std::shared_ptr step_scopes) const; + + /** + * InferShape must be called before Run. + */ + void InferShape(const std::shared_ptr& scope) const; + +protected: + inline const std::vector>& GetStepScopes( + std::shared_ptr scope) const { + return *(scope->GetVariable(arg_->step_scopes)) + ->GetMutable>>(); + } + +private: + std::unique_ptr arg_; + mutable size_t seq_len_; +}; + +class RecurrentOp final : public OperatorBase { +public: + void Init() override; + + /** + * InferShape must be called before Run. + */ + virtual void InferShape(const std::shared_ptr& scope) const override { + alg_.InferShape(scope); + } + + virtual void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override { + alg_.Run(scope, dev_ctx); + } + + static const rnn::ArgumentName kArgName; + +private: + RecurrentAlgorithm alg_; +}; + +class RecurrentGradientOp final : public OperatorBase { +public: + void Init() override; + + /** + * InferShape must be called before Run. + */ + virtual void InferShape(const std::shared_ptr& scope) const override { + alg_.InferShape(scope); + } + + virtual void Run(const std::shared_ptr& scope, + const platform::DeviceContext& dev_ctx) const override { + alg_.Run(scope, dev_ctx); + } + + static const rnn::ArgumentName kArgName; + +private: + RecurrentGradientAlgorithm alg_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/recurrent_network_op_test.cc b/paddle/operators/recurrent_network_op_test.cc new file mode 100644 index 0000000000..6784ac6001 --- /dev/null +++ b/paddle/operators/recurrent_network_op_test.cc @@ -0,0 +1,400 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include +#include + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" +#include "paddle/operators/recurrent_network_op.h" + +namespace paddle { +namespace operators { + +class RecurrentOpTest : public ::testing::Test { +protected: + virtual void SetUp() override { + CreateGlobalVariables(); + CreateStepNet(); + CreateRNNOp(); + } + + virtual void TearDown() override {} + + void CreateGlobalVariables() { + scope_ = std::make_shared(); + // create input, and init content + LOG(INFO) << "create global variable x"; + for (auto inlink : std::vector{"x", "x0", "x1", "h"}) { + Variable* x = scope_->CreateVariable(inlink); + DDim dims = make_ddim(std::vector{ + 10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); + x->GetMutable()->mutable_data(dims, platform::CPUPlace()); + } + // create output alias just for test + for (auto inlink : std::vector{"h@alias"}) { + Variable* x = scope_->CreateVariable(inlink); + DDim dims = + make_ddim(std::vector{20 /*batch size*/, 30 /*input dim*/}); + x->GetMutable()->mutable_data(dims, platform::CPUPlace()); + } + + LOG(INFO) << "create global variable w"; + Variable* w = scope_->CreateVariable("rnn/w"); + w->GetMutable()->mutable_data( + make_ddim(std::vector{30, 30}), platform::CPUPlace()); + + for (auto boot : std::vector{"x_boot", "h_boot"}) { + LOG(INFO) << "create global variable " << boot; + Variable* h_boot = scope_->CreateVariable(boot); + h_boot->GetMutable()->mutable_data( + make_ddim(std::vector{20 /*batch size*/, 30 /*input dim*/}), + platform::CPUPlace()); + } + + LOG(INFO) << "create variable step_scopes"; + scope_->CreateVariable("step_scopes"); + + LOG(INFO) << "create variable h"; + scope_->CreateVariable("h"); + } + + void CreateRNNOp() { + OpDesc op_desc; + + op_desc.set_type("recurrent_op"); + // inlinks 0 + op_desc.add_inputs("x"); + op_desc.add_inputs("x0"); + op_desc.add_inputs("x1"); + // boot_memories 3 + op_desc.add_inputs("x_boot"); + op_desc.add_inputs("h_boot"); + // step net 5 + op_desc.add_inputs("step_net"); + // outlinks 6 + op_desc.add_outputs("h"); + // step scopes 7 + op_desc.add_outputs("step_scopes"); + + auto _input_format = std::vector{ + 0, // in_link + 3, // memories + 5 // step_net + }; + auto input_format = op_desc.add_attrs(); + input_format->set_name("input_format"); + input_format->set_type(paddle::framework::AttrType::INTS); + for (auto i : _input_format) { + input_format->add_ints(i); + } + + auto output_format = op_desc.add_attrs(); + output_format->set_name("output_format"); + output_format->set_type(paddle::framework::AttrType::INTS); + for (auto i : std::vector{0, 1, 2}) { + output_format->add_ints(i); + } + + auto inlink_alias = op_desc.add_attrs(); + inlink_alias->set_name("inlink_alias"); + inlink_alias->set_type(paddle::framework::AttrType::STRINGS); + + auto outlink_alias = op_desc.add_attrs(); + outlink_alias->set_name("outlink_alias"); + outlink_alias->set_type(paddle::framework::AttrType::STRINGS); + + auto pre_memories = op_desc.add_attrs(); + pre_memories->set_name("pre_memories"); + pre_memories->set_type(paddle::framework::AttrType::STRINGS); + + auto memories = op_desc.add_attrs(); + memories->set_name("memories"); + memories->set_type(paddle::framework::AttrType::STRINGS); + + // create inlink_alias + for (const auto& item : + std::vector{"x@alias", "x0@alias", "x1@alias"}) { + inlink_alias->add_strings(item); + } + // pre memories + for (const auto& item : + std::vector{"rnn/x@pre", "rnn/h@pre"}) { + pre_memories->add_strings(item); + } + // memories + for (const auto& item : std::vector{"rnn/x", "rnn/h"}) { + memories->add_strings(item); + } + // output alias + for (const auto& item : std::vector{"h@alias"}) { + outlink_alias->add_strings(item); + } + + rnn_op_ = OpRegistry::CreateOp(op_desc); + + LOG(INFO) << "rnn_op finish init"; + } + + void CreateStepNet() { + LOG(INFO) << "create variable step_net"; + Variable* var = scope_->CreateVariable("step_net"); + auto net = var->GetMutable(); + // rnn/s is net's input or output? + net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"}; + net->inputs_ = {"rnn/s", "rnn/h"}; + net->AddOp( + OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {})); + + net->AddOp( + OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {})); + net->CompleteAddOp(); + } + + // father scope + std::shared_ptr scope_; + std::shared_ptr rnn_op_; +}; + +TEST_F(RecurrentOpTest, Run) { + platform::CPUDeviceContext ctx; + rnn_op_->InferShape(scope_); + rnn_op_->Run(scope_, ctx); +} + +class RecurrentGradientAlgorithmTest : public ::testing::Test { +protected: + virtual void SetUp() override { + CreateGlobalVariables(); + CreateStepScopes(); + CreateStepNet(); + CreateRNNGradientAlgorithm(); + + // segment inputs + SegmentInputs(); + // link forward memories + LinkeMemories(); + } + + virtual void TearDown() override {} + + void CreateGlobalVariables() { + scope_ = std::make_shared(); + // inputs: x + LOG(INFO) << "create global variable x"; + Variable* x = scope_->CreateVariable("x"); + DDim dims = + make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); + x->GetMutable()->mutable_data(dims, platform::CPUPlace()); + // inputs: h_boot + LOG(INFO) << "create global variable h_boot"; + Variable* h_boot = scope_->CreateVariable("h_boot"); + h_boot->GetMutable()->mutable_data( + make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace()); + // inputs: w + LOG(INFO) << "create global variable w"; + Variable* w = scope_->CreateVariable("rnn/w"); + w->GetMutable()->mutable_data(make_ddim({30, 30}), + platform::CPUPlace()); + // inputs: h_grad + LOG(INFO) << "create variable h_grad"; + Variable* dh = scope_->CreateVariable("h_grad"); + dh->GetMutable()->mutable_data(make_ddim({10, 20, 30}), + platform::CPUPlace()); + // inputs: step_scopes + LOG(INFO) << "create variable step_scopes"; + scope_->CreateVariable("step_scopes"); + // inputs: step_net + LOG(INFO) << "create variable step_net"; + scope_->CreateVariable("step_net"); + // outputs: w_grad + LOG(INFO) << "create global variable w_grad"; + scope_->CreateVariable("rnn/w_grad"); + // outputs: x_grad + LOG(INFO) << "create global variable x_grad"; + scope_->CreateVariable("x_grad"); + // outputs: h_boot_grad + LOG(INFO) << "create global variable h_boot_grad"; + scope_->CreateVariable("h_boot_grad"); + } + + void CreateStepScopes() { + std::vector>* step_scopes = + scope_->GetVariable("step_scopes") + ->GetMutable>>(); + for (int i = 0; i < 10; ++i) { + auto scope = std::make_shared(scope_); + auto pre_t = scope->CreateVariable("rnn/pre_h")->GetMutable(); + pre_t->mutable_data(make_ddim({20, 30}), platform::CPUPlace()); + auto tensor = scope->CreateVariable("rnn/h")->GetMutable(); + tensor->mutable_data(make_ddim({20, 30}), platform::CPUPlace()); + + // for unit test of ConcatOutputs + auto xg = scope->CreateVariable("rnn/x_grad")->GetMutable(); + xg->mutable_data(make_ddim({20, 30}), platform::CPUPlace()); + + step_scopes->push_back(scope); + } + + // last time step + auto g = (*step_scopes)[9] + ->CreateVariable("rnn/h_pre_grad") + ->GetMutable(); + g->mutable_data(make_ddim({20, 30}), platform::CPUPlace()); + } + + void CreateRNNGradientAlgorithm() { + std::unique_ptr arg(new rnn::Argument()); + arg->step_net = "step_net"; + arg->step_scopes = "step_scopes"; + rnn::Link inlink; + inlink.external = "h_grad"; + inlink.internal = "rnn/h_grad"; + arg->inlinks = std::vector{inlink}; + + rnn::Link outlink; + outlink.external = "x_grad"; + outlink.internal = "rnn/x_grad"; + arg->outlinks = std::vector{outlink}; + + rnn::MemoryAttr mem_attr; + mem_attr.pre_var = "rnn/h_pre_grad"; + mem_attr.var = "rnn/h_grad"; + mem_attr.boot_var = "h_boot_grad"; + arg->memories = std::vector{mem_attr}; + + rnn_grad_algo_.Init(std::move(arg)); + } + + void CreateStepNet() { + LOG(INFO) << "create variable step_net"; + Variable* var = scope_->CreateVariable("step_net"); + auto net = var->GetMutable(); + net->AddOp(OpRegistry::CreateOp("mul", + {"rnn/h_pre", "rnn/w", "rnn/s_grad"}, + {"rnn/h_pre_grad", "rnn/w_grad"}, + {})); + + net->AddOp(OpRegistry::CreateOp( + "add_two", {"rnn/h_grad"}, {"rnn/x_grad", "rnn/s_grad"}, {})); + net->CompleteAddOp(); + } + + void SegmentInputs() { + LOG(INFO) << "segment inputs"; + std::vector inlinks = {"x"}; + std::vector inlinks_alias = {"rnn/x"}; + + rnn::Link inlink; + inlink.external = "x"; + inlink.internal = "rnn/x"; + std::vector>* step_scopes = + scope_->GetVariable("step_scopes") + ->GetMutable>>(); + rnn::SegmentInputs(*step_scopes, std::vector{inlink}, 10); + } + + void LinkeMemories() { + LOG(INFO) << "link memories"; + rnn::MemoryAttr mem_attr; + mem_attr.pre_var = "rnn/h_pre"; + mem_attr.var = "rnn/h"; + mem_attr.boot_var = "boot_h"; + std::vector memories; + memories.push_back(mem_attr); + std::vector>* step_scopes = + scope_->GetVariable("step_scopes") + ->GetMutable>>(); + for (int i = 1; i < 10; ++i) { + rnn::LinkMemories(*step_scopes, memories, i, -1); + } + } + + std::shared_ptr scope_; + RecurrentGradientAlgorithm rnn_grad_algo_; +}; + +// TEST_F(RecurrentGradientAlgorithmTest, Run) { +// platform::CPUDeviceContext ctx; +// rnn_grad_algo_.Run(scope_, ctx); +// } + +} // namespace operators +} // namespace paddle + +TEST(RecurrentOp, LinkMemories) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators; + + // create and init step scopes + int len = 10; + std::vector> step_scopes; + for (int i = 0; i < len; ++i) { + auto scope = std::make_shared(); + scope->CreateVariable("pre_h"); + auto tensor = scope->CreateVariable("h")->GetMutable(); + float* data = tensor->mutable_data(make_ddim({15, 20}), CPUPlace()); + for (int i = 0; i < 15 * 20; ++i) { + data[i] = rand() * (1. / (double)RAND_MAX); + } + step_scopes.push_back(scope); + } + + // create MemoryAttr + rnn::MemoryAttr mem_attr; + mem_attr.pre_var = "pre_h"; + mem_attr.var = "h"; + mem_attr.boot_var = "boot_h"; + std::vector memories; + memories.push_back(mem_attr); + + for (int i = 1; i < len; ++i) { + rnn::LinkMemories(step_scopes, memories, i, -1); + } + // check + for (int i = 0; i < len - 1; ++i) { + const float* a = + step_scopes[i]->GetVariable("h")->GetMutable()->data(); + const float* b = step_scopes[i + 1] + ->GetVariable("pre_h") + ->GetMutable() + ->data(); + for (size_t i = 0; i < 15 * 20; ++i) { + ASSERT_FLOAT_EQ(a[i], b[i]); + } + } + + for (int i = len - 2; i >= 0; --i) { + rnn::LinkMemories(step_scopes, memories, i, 1); + } + // check + for (int i = len - 2; i >= 0; --i) { + const float* a = step_scopes[i] + ->GetVariable("pre_h") + ->GetMutable() + ->data(); + const float* b = step_scopes[i + 1] + ->GetVariable("h") + ->GetMutable() + ->data(); + for (size_t i = 0; i < 15 * 20; ++i) { + ASSERT_FLOAT_EQ(a[i], b[i]); + } + } +} + +USE_OP(add_two); +USE_OP(mul); diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index fd1a142b40..7d0e68a8f3 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op fc_op sgd_op cross_entropy_op) + add_op fc_op sgd_op cross_entropy_op recurrent_network_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index ccefcd2511..08a8bd0d8b 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -36,6 +36,7 @@ USE_OP(mul); USE_OP(sigmoid); USE_OP(softmax); USE_OP(rowwise_add); +USE_OP_WITHOUT_KERNEL(recurrent_op); template void ExposeOperator(ClassType& m) { @@ -94,6 +95,11 @@ All parameter, weight, gradient are variables in Paddle. [](pd::Variable& self) -> pd::Tensor* { return self.GetMutable(); }, + py::return_value_policy::reference) + .def("get_net", + [](pd::Variable& self) -> pd::NetOp* { + return self.GetMutable(); + }, py::return_value_policy::reference); py::class_>(m, "Scope") diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py new file mode 100644 index 0000000000..0457e3f16a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_recurrent_op.py @@ -0,0 +1,92 @@ +import paddle.v2.framework.core as core +import unittest +import numpy as np +import paddle.v2.framework.create_op_creation_methods as creation + +ops = creation.op_creations + + +def create_tensor(scope, name, shape): + tensor = scope.create_var(name).get_tensor() + tensor.set_dims(shape) + tensor.alloc_float() + tensor.set(np.random.random(shape)) + return tensor + + +class TestRNN(unittest.TestCase): + ''' + Test RNNOp + + equation: + h_t = \sigma (W x_t + U h_{t-1}) + weights: + - W + - U + vars: + - x + memories: + - h + outputs: + - h + ''' + + def init(self): + input_dim = 30 + batch_size = 50 + weight_dim = 15 + + self.scope = core.Scope(None) + + # create vars + create_tensor(self.scope, "x", [batch_size, input_dim]) + create_tensor(self.scope, "W", [input_dim, weight_dim]) + create_tensor(self.scope, "U", [weight_dim, weight_dim]) + create_tensor(self.scope, "h_boot", [batch_size, weight_dim]) + + x_alias = "x@alias" + y_alias = "y@alias" + memory = "h@alias" + prememory = "h@pre" + output = "rnn_out" + output_alias = "rnn_out@alias" + + # create step net + stepnet_var = self.scope.create_var("stepnet") + stepnet = stepnet_var.get_net() + # stepnet = core.Net.create() + x_fc_op = ops.fc(X=x_alias, W="W", Y="Wx") + h_fc_op = ops.fc(X=prememory, W="U", Y="Uh") + sum_op = ops.add_two(X="Wx", Y="Uh", Out="sum") + sig_op = ops.sigmoid(X="sum", Y=memory) + stepnet.add_op(x_fc_op) + stepnet.add_op(h_fc_op) + stepnet.add_op(sum_op) + stepnet.add_op(sig_op) + stepnet.complete_add_op(True) + + # create RNNOp + rnnop = ops.recurrent_op( + # inputs + inlinks=["x"], + boot_memories=["h_boot"], + step_net="stepnet", + # outputs + outlinks=[output], + step_scopes="step_scopes", + # attributes + inlink_alias=["x@alias"], + outlink_alias=[output_alias], + pre_memories=[prememory], + memories=[memory]) + + ctx = core.DeviceContext.cpu_context() + rnnop.infer_shape(self.scope) + rnnop.run(self.scope, ctx) + + def test_recurrent(self): + self.init() + + +if __name__ == '__main__': + unittest.main() -- GitLab