From 89bc3fd841631a26b38fe424107688728493527f Mon Sep 17 00:00:00 2001
From: Huihuang Zheng
Date: Fri, 19 Jul 2019 20:55:01 +0800
Subject: [PATCH] Support memory eager deletion on recurrent OP (#17710)

Tested PaddingRNN on a V100 GPU device.

Test configuration: large model, padding mode (the mode that uses
recurrent_op), one GPU.

GPU memory (MiB):  6414 (this PR) vs 6837 (without this PR)
Speed (steps/s):   10.28 (this PR) vs 9.89 (without this PR)
---
 paddle/fluid/framework/CMakeLists.txt         |   2 +-
 paddle/fluid/framework/executor.cc            |   3 +
 .../ir/memory_optimize_pass/CMakeLists.txt    |   4 +-
 .../eager_deletion_pass.cc                    |   5 +
 .../recurrent_op_eager_deletion_pass.cc       |  76 ++
 .../recurrent_op_eager_deletion_pass.h        |  43 +
 .../operators/controlflow/CMakeLists.txt      |   4 +-
 .../fluid/operators/controlflow/op_variant.cc |  72 ++
 .../fluid/operators/controlflow/op_variant.h  |  69 ++
 .../controlflow/recurrent_op_helper.cc        | 265 +++++
 .../controlflow/recurrent_op_helper.h         |  52 +
 .../operators/controlflow/while_op_helper.cc  | 101 +-
 paddle/fluid/operators/recurrent_op.cc        | 947 ++++++++----
 paddle/fluid/operators/recurrent_op.h         | 226 +++++
 .../fluid/operators/rnn_memory_helper_op.cc   |   1 +
 paddle/fluid/string/string_helper.h           |  10 +-
 .../test_eager_deletion_padding_rnn.py        | 657 ++++++++++++
 .../test_eager_deletion_recurrent_op.py       | 683 +++++++++++++
 18 files changed, 2605 insertions(+), 615 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
 create mode 100644 paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h
 create mode 100644 paddle/fluid/operators/controlflow/op_variant.cc
 create mode 100644 paddle/fluid/operators/controlflow/op_variant.h
 create mode 100644 paddle/fluid/operators/controlflow/recurrent_op_helper.cc
 create mode 100644 paddle/fluid/operators/controlflow/recurrent_op_helper.h
 create mode 100644 paddle/fluid/operators/recurrent_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 79c7bde040..5dc6e74b8f 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -193,7 +193,7 @@ else()
     cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
 endif()
 
-target_link_libraries(executor while_op_helper executor_gc_helper)
+target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_helper)
 
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS
         threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index e36871e8d8..cfab2f5f4c 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -30,6 +30,7 @@ limitations under the License.
*/ #include "paddle/fluid/framework/trainer_factory.h" #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/platform/place.h" @@ -410,6 +411,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, if (gc && ctx->prog_.Size() > 1) { operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_, ctx->ops_); + operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( + ctx->block_id_, ctx->ops_); } } diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index 615245e725..070ea9aad0 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -1,5 +1,6 @@ cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base) cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle) +cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle) cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) @@ -14,7 +15,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_ cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry) -cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper) +cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle + eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper) cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc index bbef21908d..452255a699 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc @@ -272,6 +272,10 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { auto while_op_eager_deletion_pass = ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass"); while_op_eager_deletion_pass->Apply(graph); + + auto recurrent_op_eager_deletion_pass = + ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass"); + recurrent_op_eager_deletion_pass->Apply(graph); } } // namespace ir @@ -285,3 +289,4 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) 
                 .RequirePassAttr(paddle::framework::ir::kGarbageCollector);
 
 USE_PASS(while_op_eager_deletion_pass);
+USE_PASS(recurrent_op_eager_deletion_pass);
diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
new file mode 100644
index 0000000000..40e07ce8b6
--- /dev/null
+++ b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc
@@ -0,0 +1,76 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h"
+
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/string/string_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using paddle::operators::OpVariant;
+using paddle::operators::OpVariantSet;
+using paddle::operators::OpAndGradOpPair;
+
+void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
+  // Find all recurrent_op and recurrent_grad_op in the graph.
+  // Note that the graph only contains ops and block 0.
+  std::unordered_map<size_t, OpAndGradOpPair> target_ops =
+      DeviceIdToRecurrentAndRecurrentGradOp(*graph);
+
+  for (auto &entry : target_ops) {
+    // Prepare safe eager deletion on each device separately because the
+    // garbage collection may differ across devices
+    OpAndGradOpPair &op_pair = entry.second;
+    PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
+  }
+}
+
+// Returns a std::unordered_map mapping from device id to the recurrent op and
+// grad op pair
+std::unordered_map<size_t, OpAndGradOpPair>
+RecurrentOpEagerDeletionPass::DeviceIdToRecurrentAndRecurrentGradOp(
+    const Graph &graph) const {
+  std::unordered_map<size_t, OpAndGradOpPair> ret;
+  std::vector<details::OpHandleBase *> all_ops =
+      FilterByNodeWrapper<details::OpHandleBase>(graph);
+
+  for (auto *op : all_ops) {
+    auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
+    if (compute_op == nullptr) continue;
+
+    if (compute_op->Name() == "recurrent") {
+      // GetScopeIdx() returns the device/place id
+      ret[compute_op->GetScopeIdx()].first.emplace(compute_op->GetOp());
+    } else if (compute_op->Name() == "recurrent_grad") {
+      // GetScopeIdx() returns the device/place id
+      ret[compute_op->GetScopeIdx()].second.emplace(compute_op->GetOp());
+    }
+  }
+  return ret;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(recurrent_op_eager_deletion_pass,
+              paddle::framework::ir::RecurrentOpEagerDeletionPass);
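For intuition about what these skip vars buy: an eager reference-count GC frees each variable right after its last reader in the current block finishes, which is unsafe for variables that later time steps or recurrent_grad still read. A minimal self-contained sketch (toy types and names, not Paddle APIs) of that exemption logic:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>

    struct ToyGC {
      std::unordered_map<std::string, int> ref_cnt;  // remaining readers
      std::unordered_set<std::string> skip_vars;     // never collected eagerly

      // Called after an op consumed `name`; frees it when no readers remain.
      void Consumed(const std::string &name) {
        if (skip_vars.count(name)) return;  // exempt: used across time steps
        if (--ref_cnt[name] <= 0) std::cout << "free " << name << "\n";
      }
    };

    int main() {
      ToyGC gc;
      gc.ref_cnt = {{"x_t", 1}, {"h_prev", 1}, {"h", 1}};
      gc.skip_vars = {"h", "h_prev"};  // RNN states: kept for the next step
      gc.Consumed("x_t");     // freed eagerly -> the memory win of this PR
      gc.Consumed("h_prev");  // skipped
      gc.Consumed("h");       // skipped
    }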
diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h
new file mode 100644
index 0000000000..9c39a9faf2
--- /dev/null
+++ b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <unordered_map>
+
+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/operators/controlflow/op_variant.h"
+#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Pass that sets the skip-eager-deletion vars of recurrent ops
+class RecurrentOpEagerDeletionPass : public Pass {
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+
+ private:
+  // Returns a std::unordered_map mapping from device id to the recurrent op
+  // and grad op pair
+  std::unordered_map<size_t, operators::OpAndGradOpPair>
+  DeviceIdToRecurrentAndRecurrentGradOp(const Graph &graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt
index 7aa1c44eaa..f7281a2d1a 100644
--- a/paddle/fluid/operators/controlflow/CMakeLists.txt
+++ b/paddle/fluid/operators/controlflow/CMakeLists.txt
@@ -1,5 +1,7 @@
 include(operators)
 register_operators(DEPS naive_executor)
 
-cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
+cc_library(op_variant SRCS op_variant.cc DEPS operator proto_desc)
+cc_library(recurrent_op_helper SRCS recurrent_op_helper.cc DEPS operator op_variant recurrent_op)
+cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator op_variant)
 
 file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
diff --git a/paddle/fluid/operators/controlflow/op_variant.cc b/paddle/fluid/operators/controlflow/op_variant.cc
new file mode 100644
index 0000000000..d6eea8c4c8
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/op_variant.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/controlflow/op_variant.h"
+
+namespace paddle {
+namespace operators {
+
+struct InputsVisitor
+    : public boost::static_visitor<const framework::VariableNameMap *> {
+  template <typename OpType>
+  const framework::VariableNameMap *operator()(const OpType *op) const {
+    return &(op->Inputs());
+  }
+};
+
+struct OutputsVisitor
+    : public boost::static_visitor<const framework::VariableNameMap *> {
+  template <typename OpType>
+  const framework::VariableNameMap *operator()(const OpType *op) const {
+    return &(op->Outputs());
+  }
+};
+
+struct AttributeMapVisitor
+    : public boost::static_visitor<const framework::AttributeMap *> {
+  const framework::AttributeMap *operator()(
+      const framework::OpDesc *op) const {
+    return &(op->GetAttrMap());
+  }
+
+  const framework::AttributeMap *operator()(
+      const framework::OperatorBase *op) const {
+    return &(op->Attrs());
+  }
+};
+
+struct RawPointerVisitor : public boost::static_visitor<const void *> {
+  template <typename OpType>
+  const void *operator()(const OpType *op) const {
+    return op;
+  }
+};
+
+const framework::VariableNameMap &OpVariant::Inputs() const {
+  return *boost::apply_visitor(InputsVisitor(), op_);
+}
+
+const framework::VariableNameMap &OpVariant::Outputs() const {
+  return *boost::apply_visitor(OutputsVisitor(), op_);
+}
+
+const framework::AttributeMap &OpVariant::Attrs() const {
+  return *boost::apply_visitor(AttributeMapVisitor(), op_);
+}
+
+const void *OpVariant::RawPointer() const {
+  return boost::apply_visitor(RawPointerVisitor(), op_);
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/controlflow/op_variant.h b/paddle/fluid/operators/controlflow/op_variant.h
new file mode 100644
index 0000000000..26c70589f2
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/op_variant.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/platform/variant.h"
+
+namespace paddle {
+namespace operators {
+
+// OpVariant is a wrapper class of OpDesc and OperatorBase pointers,
+// so that the API is the same for both.
+class OpVariant {
+ public:
+  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT
+
+  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT
+
+  const framework::VariableNameMap &Inputs() const;
+
+  const framework::VariableNameMap &Outputs() const;
+
+  const framework::AttributeMap &Attrs() const;
+
+  const void *RawPointer() const;
+
+  template <typename AttrType>
+  const AttrType &Attr(const std::string &name) const {
+    auto &attrs = Attrs();
+    auto it = attrs.find(name);
+    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
+    return boost::get<AttrType>(it->second);
+  }
+
+  bool operator==(const OpVariant &other) const {
+    return RawPointer() == other.RawPointer();
+  }
+
+  int which() const { return static_cast<int>(op_.which()); }
+
+  struct Hasher {
+    size_t operator()(const OpVariant &op) const {
+      return reinterpret_cast<size_t>(op.RawPointer());
+    }
+  };
+
+ private:
+  const boost::variant<const framework::OperatorBase *,
+                       const framework::OpDesc *>
+      op_;
+};
+
+}  // namespace operators
+}  // namespace paddle
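As a self-contained illustration of the wrapper pattern above — one visitor serving both pointee types — here is an analogous sketch using std::variant in place of boost::variant (toy op types, not the Paddle classes):

    #include <iostream>
    #include <map>
    #include <string>
    #include <variant>

    // Toy stand-ins for OpDesc / OperatorBase: both expose Inputs(), but are
    // unrelated types, so a variant + visitor gives them one common API.
    struct ToyOpDesc {
      std::map<std::string, std::string> ins;
      const std::map<std::string, std::string> &Inputs() const { return ins; }
    };
    struct ToyOperatorBase {
      std::map<std::string, std::string> ins;
      const std::map<std::string, std::string> &Inputs() const { return ins; }
    };

    class ToyOpVariant {
     public:
      ToyOpVariant(const ToyOpDesc *op) : op_(op) {}        // NOLINT
      ToyOpVariant(const ToyOperatorBase *op) : op_(op) {}  // NOLINT

      const std::map<std::string, std::string> &Inputs() const {
        // One generic visitor handles both alternatives, like InputsVisitor.
        return std::visit(
            [](const auto *op) -> decltype(auto) { return op->Inputs(); }, op_);
      }

     private:
      std::variant<const ToyOpDesc *, const ToyOperatorBase *> op_;
    };

    int main() {
      ToyOpDesc desc{{{"X", "x0"}}};
      ToyOpVariant v(&desc);
      std::cout << v.Inputs().at("X") << "\n";  // prints x0
    }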
diff --git a/paddle/fluid/operators/controlflow/recurrent_op_helper.cc b/paddle/fluid/operators/controlflow/recurrent_op_helper.cc
new file mode 100644
index 0000000000..6925086679
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/recurrent_op_helper.cc
@@ -0,0 +1,265 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
+
+#include <algorithm>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/recurrent_op.h"
+
+namespace paddle {
+namespace operators {
+
+static bool IsMatchedRecurrentOpAndRecurrentGradOp(const OpVariant &fwd_op,
+                                                   const OpVariant &grad_op) {
+  return fwd_op.Inputs().at(RecurrentBase::kInputs) ==
+             grad_op.Inputs().at(RecurrentBase::kInputs) &&
+         fwd_op.Outputs().at(RecurrentBase::kOutputs) ==
+             grad_op.Inputs().at(RecurrentBase::kOutputs);
+}
+
+// Returns whether the variable is skippable in the forward recurrent op.
+// A variable is skippable in recurrent_op when it is used in recurrent_grad
+// but does not come from grad_block.
+static bool IsSkippableVar(const std::string &name,
+                           framework::BlockDesc *grad_block) {
+  return name != framework::kEmptyVarName && !grad_block->HasVar(name);
+}
+
+static void ClearSkipVars(const OpVariant &op) {
+  auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
+  std::vector<std::string> &attr_skip_vars =
+      boost::get<std::vector<std::string>>(
+          attrs[RecurrentBase::kSkipEagerDeletionVars]);
+  attr_skip_vars.clear();
+}
+
+// Add skip vars into op's attribute
+template <typename Container>
+static void AddSkipVars(const OpVariant &op, const Container &skip_vars) {
+  auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
+  VLOG(2) << "Prepare to add " << skip_vars.size()
+          << " skip var(s): " << paddle::string::join_strings(skip_vars, ' ');
+  std::vector<std::string> &attr_skip_vars =
+      boost::get<std::vector<std::string>>(
+          attrs[RecurrentBase::kSkipEagerDeletionVars]);
+  attr_skip_vars.insert(attr_skip_vars.end(), skip_vars.cbegin(),
+                        skip_vars.cend());
+}
+
+// Find all ops and grad ops with the given type name. The ops and grad ops
+// may be located in different blocks, so we traverse all blocks in the
+// program to find them.
+static void FindAllOpAndGradOp(OpAndGradOpPair *op_and_grad_op,
+                               const std::string &type_name,
+                               const std::string &backward_type_name) {
+  OpVariantSet &ops = op_and_grad_op->first;
+  OpVariantSet &grad_ops = op_and_grad_op->second;
+
+  PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
+                    "There are extra grad ops in the graph or program");
+
+  if (ops.empty()) return;
+
+  const auto *program =
+      ops.begin()
+          ->Attr<framework::BlockDesc *>(RecurrentBase::kStepBlock)
+          ->Program();
+  for (size_t i = 1; i < program->Size(); ++i) {
+    auto &block = program->Block(i);
+    for (size_t j = 0; j < block.OpSize(); ++j) {
+      auto *op = block.Op(j);
+      if (op->Type() == type_name) {
+        ops.emplace(op);
+      } else if (op->Type() == backward_type_name) {
+        grad_ops.emplace(op);
+      }
+    }
+  }
+
+  PADDLE_ENFORCE_GE(ops.size(), grad_ops.size(),
+                    "There are extra grad ops in the graph or program");
+}
+
+// Returns GradVarName of the input var names
+static std::vector<std::string> GradVarLists(
+    const std::vector<std::string> &var_names) {
+  std::vector<std::string> retv;
+  retv.reserve(var_names.size());
+  std::transform(var_names.begin(), var_names.end(), std::back_inserter(retv),
+                 framework::GradVarName);
+  return retv;
+}
+
+// Add memory vars in recurrent op as skip vars.
+static void AddOpMemVarsAsSkip(const OpVariant &op, bool set_grad_mem_vars) {
+  bool has_state = op.Attr<bool>(RecurrentBase::kHasStates);
+  if (has_state) {
+    std::unordered_set<std::string> skip_vars;
+
+    auto &mem_vars =
+        op.Attr<std::vector<std::string>>(RecurrentBase::kStates);
+    skip_vars.insert(mem_vars.begin(), mem_vars.end());
+
+    auto &pre_mem_vars =
+        op.Attr<std::vector<std::string>>(RecurrentBase::kExStates);
+    skip_vars.insert(pre_mem_vars.begin(), pre_mem_vars.end());
+
+    if (set_grad_mem_vars) {
+      auto mem_grad_vars = GradVarLists(mem_vars);
+      skip_vars.insert(mem_grad_vars.begin(), mem_grad_vars.end());
+      auto pre_mem_grad_vars = GradVarLists(pre_mem_vars);
+      skip_vars.insert(pre_mem_grad_vars.begin(), pre_mem_grad_vars.end());
+    }
+    AddSkipVars(op, skip_vars);
+  }
+}
+
+// Set the outputs and memory vars of the input forward op as skip vars
+static void SetRecurrentForwardOpOnlySkipVarAttr(const OpVariant &fwd_op) {
+  ClearSkipVars(fwd_op);
+
+  AddOpMemVarsAsSkip(fwd_op, /* set_grad_mem_vars = */ false);
+  auto &output_vars = fwd_op.Outputs().at(RecurrentBase::kOutputs);
+  AddSkipVars(fwd_op, output_vars);
+}
+
+// Set the skip vars of the matched recurrent op and recurrent_grad op
+static void SetRecurrentOpAndRecurrentGradOpSkipVarAttr(
+    const OpVariant &fwd_op, const OpVariant &bwd_op) {
+  // Find all skippable variables in the forward recurrent_op
+  ClearSkipVars(fwd_op);
+  AddOpMemVarsAsSkip(fwd_op, /* set_grad_mem_vars = */ false);
+
+  auto *grad_block =
+      bwd_op.Attr<framework::BlockDesc *>(RecurrentBase::kStepBlock);
+  std::unordered_set<std::string> fwd_skip_vars;
+  for (auto *op_desc : grad_block->AllOps()) {
+    for (auto &in_arg_name : op_desc->InputArgumentNames()) {
+      if (IsSkippableVar(in_arg_name, grad_block)) {
+        fwd_skip_vars.insert(in_arg_name);
+      }
+    }
+    for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
+      if (IsSkippableVar(out_arg_name, grad_block)) {
+        fwd_skip_vars.insert(out_arg_name);
+      }
+    }
+  }
+  AddSkipVars(fwd_op, fwd_skip_vars);
+
+  // Find all skippable variables in recurrent_grad_op
+  // The skippable variables are those which are used across time steps
+  ClearSkipVars(bwd_op);
+  AddOpMemVarsAsSkip(bwd_op, /* set_grad_mem_vars = */ true);
+  std::unordered_set<std::string> bwd_skip_vars;
+
+  auto &fwd_input = fwd_op.Inputs().at(RecurrentBase::kInputs);
+  auto &in_grads =
+      bwd_op.Outputs().at(framework::GradVarName(RecurrentBase::kInputs));
+
+  PADDLE_ENFORCE_EQ(
+      fwd_input.size(), in_grads.size(),
+      "Backward input gradient number does not match forward input number.");
+  for (size_t i = 0; i < in_grads.size(); ++i) {
+    if (in_grads[i] == framework::kEmptyVarName) {
+      continue;
+    }
+    bwd_skip_vars.insert(in_grads[i]);
+    bwd_skip_vars.insert(framework::GradVarName(fwd_input[i]));
+  }
+
+  auto &fwd_param = fwd_op.Inputs().at(RecurrentBase::kParameters);
+  auto &param_grads =
+      bwd_op.Outputs().at(framework::GradVarName(RecurrentBase::kParameters));
+  PADDLE_ENFORCE_EQ(fwd_param.size(), param_grads.size(),
+                    "Backward parameter gradient number does not match forward "
+                    "parameter number.");
+  for (size_t i = 0; i < fwd_param.size(); ++i) {
+    if (param_grads[i] == framework::kEmptyVarName) {
+      continue;
+    }
+    bwd_skip_vars.insert(param_grads[i]);
+    bwd_skip_vars.insert(framework::GradVarName(fwd_param[i]));
+  }
+
+  AddSkipVars(bwd_op, bwd_skip_vars);
+}
+
+void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
+  // If block_id is not 0, return.
+  // All recurrent_ops and recurrent_grad_ops in the whole program are
+  // processed when block_id is 0 (i.e. when Executor::Run() or the
+  // ParallelExecutor constructor runs).
+  //
+  // Moreover, all recurrent_ops and recurrent_grad_ops must be processed
+  // when block_id is zero. Otherwise, recurrent_op may run first and erase
+  // variables used by recurrent_grad_op, at a moment when the
+  // recurrent_grad_ops may not have been constructed yet.
+  if (block_id != 0) return;
+
+  OpAndGradOpPair op_pair;
+  for (auto &op : all_ops) {
+    if (op->Type() == "recurrent") {
+      op_pair.first.emplace(op.get());
+    } else if (op->Type() == "recurrent_grad") {
+      op_pair.second.emplace(op.get());
+    }
+  }
+  PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
+}
+
+void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
+    OpAndGradOpPair *op_pair) {
+  // Find all ops and grad ops in all blocks
+  FindAllOpAndGradOp(op_pair, "recurrent", "recurrent_grad");
+
+  OpVariantSet &recurrent_ops = op_pair->first;
+  OpVariantSet &recurrent_grad_ops = op_pair->second;
+
+  VLOG(2) << "Found recurrent op num: " << recurrent_ops.size()
+          << ", recurrent grad op num: " << recurrent_grad_ops.size();
+
+  if (recurrent_ops.empty()) {
+    return;
+  }
+
+  for (auto &bwd_op : recurrent_grad_ops) {
+    const OpVariant *matched_fwd_op = nullptr;
+    for (auto &fwd_op : recurrent_ops) {
+      if (IsMatchedRecurrentOpAndRecurrentGradOp(fwd_op, bwd_op)) {
+        PADDLE_ENFORCE(matched_fwd_op == nullptr,
+                       "Found multiple matched recurrent op");
+        matched_fwd_op = &fwd_op;
+      }
+    }
+    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, "Cannot find matched forward op");
+    SetRecurrentOpAndRecurrentGradOpSkipVarAttr(*matched_fwd_op, bwd_op);
+    recurrent_ops.erase(*matched_fwd_op);
+  }
+
+  for (auto &fwd_op : recurrent_ops) {
+    SetRecurrentForwardOpOnlySkipVarAttr(fwd_op);
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
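AddOpMemVarsAsSkip and GradVarLists above lean on the @GRAD naming convention (GRAD_SUFFIX in recurrent_op.cc). A self-contained sketch of that convention (toy function, not framework::GradVarName itself):

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <string>
    #include <vector>

    // Toy analogue of framework::GradVarName / GradVarLists: gradient
    // variables are named by appending "@GRAD" to the forward variable name.
    static std::string ToyGradVarName(const std::string &name) {
      return name + "@GRAD";
    }

    int main() {
      std::vector<std::string> states = {"h", "c"};
      std::vector<std::string> grads;
      grads.reserve(states.size());
      std::transform(states.begin(), states.end(), std::back_inserter(grads),
                     ToyGradVarName);
      for (const auto &g : grads) std::cout << g << "\n";  // h@GRAD, c@GRAD
    }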
diff --git a/paddle/fluid/operators/controlflow/recurrent_op_helper.h b/paddle/fluid/operators/controlflow/recurrent_op_helper.h
new file mode 100644
index 0000000000..b1e6e662c0
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/recurrent_op_helper.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/controlflow/op_variant.h"
+#include "paddle/fluid/operators/recurrent_op.h"
+#include "paddle/fluid/platform/variant.h"
+#include "paddle/fluid/string/string_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using OpVariantSet = std::unordered_set<OpVariant, OpVariant::Hasher>;
+using OpAndGradOpPair = std::pair<OpVariantSet, OpVariantSet>;
+
+// Set vars to skip eager deletion on the input recurrent and recurrent_grad
+// ops, to prepare for safe eager deletion. The input contains all recurrent
+// and recurrent_grad ops at block 0; the function finds all recurrent and
+// recurrent_grad ops across the blocks.
+void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
+    OpAndGradOpPair *op_pair);
+
+// Set vars to skip eager deletion on the input recurrent and recurrent_grad
+// ops, to prepare for safe eager deletion. The input block_id must be 0 and
+// the caller can pass all ops of that block. The function finds all recurrent
+// and recurrent_grad ops across the blocks.
+void PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc
index 2cbd94a061..009bc5796c 100644
--- a/paddle/fluid/operators/controlflow/while_op_helper.cc
+++ b/paddle/fluid/operators/controlflow/while_op_helper.cc
@@ -13,109 +13,18 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/controlflow/while_op_helper.h"
+
 #include <string>
 #include <utility>
 #include <vector>
+
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/controlflow/op_variant.h"
+#include "paddle/fluid/string/string_helper.h"
 
 namespace paddle {
 namespace operators {
 
-// OpVariant is a wrapper class of OpDesc and OperatorBase
-// So that API would be the same.
-class OpVariant {
-  struct InputsVisitor
-      : public boost::static_visitor<const framework::VariableNameMap *> {
-    template <typename OpType>
-    const framework::VariableNameMap *operator()(const OpType *op) const {
-      return &(op->Inputs());
-    }
-  };
-
-  struct OutputsVisitor
-      : public boost::static_visitor<const framework::VariableNameMap *> {
-    template <typename OpType>
-    const framework::VariableNameMap *operator()(const OpType *op) const {
-      return &(op->Outputs());
-    }
-  };
-
-  struct AttributeMapVisitor
-      : public boost::static_visitor<const framework::AttributeMap *> {
-    const framework::AttributeMap *operator()(
-        const framework::OpDesc *op) const {
-      return &(op->GetAttrMap());
-    }
-
-    const framework::AttributeMap *operator()(
-        const framework::OperatorBase *op) const {
-      return &(op->Attrs());
-    }
-  };
-
-  struct RawPointerVisitor : public boost::static_visitor<const void *> {
-    template <typename OpType>
-    const void *operator()(const OpType *op) const {
-      return op;
-    }
-  };
-
- public:
-  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT
-
-  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT
-
-  const framework::VariableNameMap &Inputs() const {
-    return *boost::apply_visitor(InputsVisitor(), op_);
-  }
-
-  const framework::VariableNameMap &Outputs() const {
-    return *boost::apply_visitor(OutputsVisitor(), op_);
-  }
-
-  const framework::AttributeMap &Attrs() const {
-    return *boost::apply_visitor(AttributeMapVisitor(), op_);
-  }
-
-  template <typename AttrType>
-  const AttrType &Attr(const std::string &name) const {
-    auto &attrs = Attrs();
-    auto it = attrs.find(name);
-    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
-    return boost::get<AttrType>(it->second);
-  }
-
-  bool operator==(const OpVariant &other) const {
-    return RawPointer() == other.RawPointer();
-  }
-
-  const void *RawPointer() const {
-    return boost::apply_visitor(RawPointerVisitor(), op_);
-  }
-
-  int which() const { return static_cast<int>(op_.which()); }
-
-  struct Hasher {
-    size_t operator()(const OpVariant &op) const {
-      return reinterpret_cast<size_t>(op.RawPointer());
-    }
-  };
-
- private:
-  const boost::variant<const framework::OperatorBase *,
-                       const framework::OpDesc *>
-      op_;
-};
-
-static std::string GetDebugString(const std::vector<std::string> &names) {
-  if (names.empty()) return "";
-  std::string ret = names[0];
-  for (size_t i = 1; i < names.size(); ++i) {
-    ret += (" " + names[i]);
-  }
-  return ret;
-}
-
 // Set skip variables of while_op and while_grad_op
 // These variables should be skipped when eager deletion enables.
 // It is because:
@@ -124,7 +33,7 @@ static std::string GetDebugString(const std::vector<std::string> &names) {
 static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
   auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
   VLOG(2) << "Prepare to skip " << attr.size()
-          << " var(s): " << GetDebugString(attr);
+          << " var(s): " << string::join_strings(attr, ' ');
   attrs[kSkipEagerDeletionVars] = std::move(attr);
 }
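The deleted GetDebugString is replaced by string::join_strings; judging by the diffstat, paddle/fluid/string/string_helper.h is touched by this patch as well. A toy analogue of the joining behavior (illustrative only, not the Paddle implementation):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Toy join: concatenates items with a delimiter, like the GetDebugString
    // helper this patch deletes in favor of string::join_strings.
    template <typename Container>
    std::string ToyJoinStrings(const Container &items, char delim) {
      std::ostringstream os;
      bool first = true;
      for (const auto &item : items) {
        if (!first) os << delim;
        os << item;
        first = false;
      }
      return os.str();
    }

    int main() {
      std::vector<std::string> vars = {"h@GRAD", "c@GRAD", "out"};
      std::cout << ToyJoinStrings(vars, ' ') << "\n";  // "h@GRAD c@GRAD out"
    }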
diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc
index b3bb1abf4d..d26a85fb93 100644
--- a/paddle/fluid/operators/recurrent_op.cc
+++ b/paddle/fluid/operators/recurrent_op.cc
@@ -12,31 +12,34 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <vector>
-#include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/recurrent_op.h"
+
+#include <vector>
+#include "paddle/fluid/string/string_helper.h"
 
 namespace paddle {
 namespace operators {
 
-constexpr char kInputs[] = "inputs";
-constexpr char kInitialStates[] = "initial_states";
-constexpr char kParameters[] = "parameters";
-constexpr char kOutputs[] = "outputs";
-constexpr char kStepScopes[] = "step_scopes";
-constexpr char kHasStates[] = "has_states";
-constexpr char kExStates[] = "ex_states";
-constexpr char kStates[] = "states";
-constexpr char kStepBlock[] = "sub_block";
-constexpr char kReverse[] = "reverse";
-constexpr char kIsTrain[] = "is_train";
-#define GRAD_SUFFIX "@GRAD"
-constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
-constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
-constexpr char kParamGrads[] = "parameters" GRAD_SUFFIX;
-constexpr char kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
 using StepScopeVar = std::vector<framework::Scope *>;
 
+const char RecurrentBase::kInputs[] = "inputs";
+const char RecurrentBase::kInitialStates[] = "initial_states";
+const char RecurrentBase::kParameters[] = "parameters";
+const char RecurrentBase::kOutputs[] = "outputs";
+const char RecurrentBase::kStepScopes[] = "step_scopes";
+const char RecurrentBase::kHasStates[] = "has_states";
+const char RecurrentBase::kExStates[] = "ex_states";
+const char RecurrentBase::kStates[] = "states";
+const char RecurrentBase::kStepBlock[] = "sub_block";
+const char RecurrentBase::kReverse[] = "reverse";
+const char RecurrentBase::kIsTrain[] = "is_train";
+const char RecurrentBase::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
+#define GRAD_SUFFIX "@GRAD"
+const char RecurrentBase::kInputGrads[] = "inputs" GRAD_SUFFIX;
+const char RecurrentBase::kOutputGrads[] = "outputs" GRAD_SUFFIX;
+const char RecurrentBase::kParamGrads[] = "parameters" GRAD_SUFFIX;
+const char RecurrentBase::kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
+
 static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
                             framework::Scope *parent_scope,
                             StepScopeVar *step_scopes) {
@@ -65,534 +68,440 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
 //   reversely access scopes
 // else
 //   access scopes from begin to end.
-class StepScopes {
- public:
-  StepScopes(const platform::DeviceContext &dev_ctx,
-             const framework::Scope &parent, StepScopeVar *scopes,
-             bool is_train, size_t seq_len, bool is_backward = false)
-      : counter_(is_backward ? seq_len - 1 : 0UL),
-        scopes_(scopes),
-        is_train_(is_train),
-        is_backward_(is_backward) {
-    size_t num_step_scopes = is_train ? seq_len : 2;
-    PADDLE_ENFORCE(is_train || !is_backward,
-                   "Cannot backward when is not training");
-    if (!is_backward_) {
-      ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
-      scopes->reserve(static_cast<size_t>(num_step_scopes));
-      for (size_t i = 0; i < num_step_scopes; ++i) {
-        scopes->emplace_back(&parent.NewScope());
-      }
+StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
+                       const framework::Scope &parent, StepScopeVar *scopes,
+                       bool is_train, size_t seq_len, bool is_backward)
+    : counter_(is_backward ? seq_len - 1 : 0UL),
+      scopes_(scopes),
+      is_train_(is_train),
+      is_backward_(is_backward) {
+  size_t num_step_scopes = is_train ? seq_len : 2;
+  PADDLE_ENFORCE(is_train || !is_backward,
+                 "Cannot backward when is not training");
+  if (!is_backward_) {
+    ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
+    scopes->reserve(static_cast<size_t>(num_step_scopes));
+    for (size_t i = 0; i < num_step_scopes; ++i) {
+      scopes->emplace_back(&parent.NewScope());
     }
   }
+}
+
+framework::Scope &StepScopes::CurScope() { return GetScope(counter_); }
 
-  framework::Scope &CurScope() { return GetScope(counter_); }
+framework::Scope &StepScopes::ExScope() {
+  auto &scope = GetScope(is_backward_ ? counter_ + 1 : counter_ - 1);
+  return scope;
+}
 
-  framework::Scope &ExScope() {
-    auto &scope = GetScope(is_backward_ ? counter_ + 1 : counter_ - 1);
-    return scope;
+void StepScopes::Next() {
+  if (is_backward_) {
+    --counter_;
+  } else {
+    ++counter_;
   }
+}
 
-  void Next() {
-    if (is_backward_) {
-      --counter_;
-    } else {
-      ++counter_;
-    }
+framework::Scope &StepScopes::GetScope(size_t scope_id) const {
+  if (!is_train_) {
+    scope_id %= 2;
   }
+  PADDLE_ENFORCE_LT(scope_id, scopes_->size());
+  return *(*scopes_)[scope_id];
+}
 
- private:
-  framework::Scope &GetScope(size_t scope_id) const {
-    if (!is_train_) {
-      scope_id %= 2;
+RecurrentBase::RecurrentBase(const std::string &type,
+                             const framework::VariableNameMap &inputs,
+                             const framework::VariableNameMap &outputs,
+                             const framework::AttributeMap &attrs)
+    : OperatorBase(type, inputs, outputs, attrs) {}
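A self-contained sketch of the scope-indexing policy StepScopes implements (toy struct, not the Paddle class): training keeps one scope per time step because backward needs them all, while inference double-buffers just two scopes, so memory stays constant:

    #include <cassert>
    #include <cstddef>

    // Toy model of StepScopes::GetScope: in inference scope_id wraps
    // modulo 2 (double buffering); in training each step has its own scope.
    struct ToyStepScopes {
      bool is_train;
      std::size_t num_scopes;  // seq_len when training, 2 otherwise

      std::size_t GetScopeId(std::size_t scope_id) const {
        if (!is_train) scope_id %= 2;
        assert(scope_id < num_scopes);
        return scope_id;
      }
    };

    int main() {
      ToyStepScopes train{true, 8}, infer{false, 2};
      assert(train.GetScopeId(5) == 5);  // step 5 keeps its own scope
      assert(infer.GetScopeId(5) == 1);  // step 5 reuses buffer 1
    }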
+
+// Get SequenceLength from Scope
+//   The sequence length is got from the input tensor. The input tensor's
+//   dimension should be [SEQ_LEN, ..., ...]. The first of the tensor's shape
+//   is SEQ_LEN. The second of the tensor's shape could be the batch size or
+//   nested sequence length.
+int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
+  // Dim format SEQ_LEN, BATCH_SIZE, ...
+  int64_t seq_len = -1;
+  auto &all_inputs = Inputs(kInputs);
+  PADDLE_ENFORCE(!all_inputs.empty());
+  for (auto &iname : all_inputs) {
+    auto *var = scope.FindVar(iname);
+    PADDLE_ENFORCE(var != nullptr);
+    PADDLE_ENFORCE(var->IsType<framework::LoDTensor>());
+    auto &dim = var->Get<framework::LoDTensor>().dims();
+    if (seq_len == -1) {
+      seq_len = dim[0];
+    } else {
+      PADDLE_ENFORCE_EQ(seq_len, dim[0]);
    }
  }
+  return seq_len;
+}
 
-  size_t counter_;
-  StepScopeVar *scopes_;
-  bool is_train_;
-  bool is_backward_;
-};
+// for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
+//                                   map(dst_scope.Var, dst_vars)):
+//   dst_tensor.ShareDataWith(src_tensor)
+void RecurrentBase::LinkTensor(const framework::Scope &src_scope,
+                               const std::vector<std::string> &src_vars,
+                               framework::Scope *dst_scope,
+                               const std::vector<std::string> &dst_vars) {
+  LinkTensorWithCallback(
+      src_scope, src_vars, dst_scope, dst_vars,
+      [&](const framework::Tensor &src, framework::Tensor *dst) {
+        dst->ShareDataWith(src);
+      });
+}
 
-// Base class for RecurrentOp/RecurrentGradOp
-//   Some common protected functions for RecurrentOp/RecurrentGradOp
-class RecurrentBase : public framework::OperatorBase {
- public:
-  RecurrentBase(const std::string &type,
-                const framework::VariableNameMap &inputs,
-                const framework::VariableNameMap &outputs,
-                const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
+// (seq_len, shape) -> return [seq_len] + list(shape)
+framework::DDim RecurrentBase::PrependDims(size_t seq_len,
+                                           const framework::DDim &src) {
+  auto dims = framework::vectorize(src);
+  dims.insert(dims.begin(), static_cast<int64_t>(seq_len));
+  return framework::make_ddim(dims);
+}
 
- protected:
-  // Get SequenceLength from Scope
-  //   The sequence length is got from input tensor. The input tensor's
-  //   dimension should be [SEQ_LEN, ..., ...]. The first of the tensor's shape
-  //   is SEQ_LEN. The second of the tensor's shape could be the batch size or
-  //   nested sequence length.
-  int64_t GetSequenceLength(const framework::Scope &scope) const {
-    // Dim format SEQ_LEN, BATCH_SIZE, ...
-    int64_t seq_len = -1;
-    auto &all_inputs = Inputs(kInputs);
-    PADDLE_ENFORCE(!all_inputs.empty());
-    for (auto &iname : all_inputs) {
-      auto *var = scope.FindVar(iname);
-      PADDLE_ENFORCE(var != nullptr);
-      PADDLE_ENFORCE(var->IsType<framework::LoDTensor>());
-      auto &dim = var->Get<framework::LoDTensor>().dims();
-      if (seq_len == -1) {
-        seq_len = dim[0];
-      } else {
-        PADDLE_ENFORCE_EQ(seq_len, dim[0]);
-      }
-    }
-    return seq_len;
-  }
+RecurrentOp::RecurrentOp(const std::string &type,
+                         const framework::VariableNameMap &inputs,
+                         const framework::VariableNameMap &outputs,
+                         const framework::AttributeMap &attrs)
+    : RecurrentBase(type, inputs, outputs, attrs) {}
 
-  // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
-  //                                   map(dst_scope.Var, dst_vars)):
-  //   dst_tensor.ShareDataWith(src_tensor)
-  static void LinkTensor(const framework::Scope &src_scope,
-                         const std::vector<std::string> &src_vars,
-                         framework::Scope *dst_scope,
-                         const std::vector<std::string> &dst_vars) {
-    LinkTensorWithCallback(
-        src_scope, src_vars, dst_scope, dst_vars,
-        [&](const framework::Tensor &src, framework::Tensor *dst) {
-          dst->ShareDataWith(src);
-        });
-  }
+void RecurrentOp::RunImpl(const framework::Scope &scope,
+                          const platform::Place &place) const {
+  bool has_state = Attr<bool>(kHasStates);
+  auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
 
-  // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
-  //                                   map(dst_scope.Var, dst_vars)):
-  //   callback(src_tensor, &dst_tensor)
-  template <typename Callback>
-  static void LinkTensorWithCallback(const framework::Scope &src_scope,
-                                     const std::vector<std::string> &src_vars,
-                                     framework::Scope *dst_scope,
-                                     const std::vector<std::string> &dst_vars,
-                                     Callback callback,
-                                     bool is_backward = false) {
-    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
-    for (size_t i = 0; i < dst_vars.size(); ++i) {
-      VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
-      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
-                   is_backward);
-    }
-  }
+  // get device context from pool
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(place);
 
-  // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars),
-  //                                   map(dst_scope.FindVar, dst_vars)):
-  //   callback(src_tensor, &dst_tensor)
-  template <typename Callback>
-  static void LinkTensorWithCallback(const framework::Scope &src_scope,
-                                     const std::vector<std::string> &src_vars,
-                                     const framework::Scope &dst_scope,
-                                     const std::vector<std::string> &dst_vars,
-                                     Callback callback,
-                                     bool is_backward = false) {
-    PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size());
-    for (size_t i = 0; i < dst_vars.size(); ++i) {
-      VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i];
-      AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback,
-                   is_backward);
-    }
-  }
+  VLOG(3) << "Static RNN input sequence length = " << seq_len;
+  StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
+  auto reverse = Attr<bool>(kReverse);
 
-  // (seq_len, shape) -> return [seq_len] + list(shape)
-  static framework::DDim PrependDims(size_t seq_len,
-                                     const framework::DDim &src) {
-    auto dims = framework::vectorize(src);
-    dims.insert(dims.begin(), static_cast<int64_t>(seq_len));
-    return framework::make_ddim(dims);
-  }
+  framework::Executor executor(place);
+  auto *block = Attr<framework::BlockDesc *>(kStepBlock);
 
- private:
-  template <typename Callback>
-  static void AccessTensor(const framework::Scope &src_scope,
-                           const std::string &src_var_name,
-                           framework::Scope *dst_scope,
-                           const std::string &dst_var_name, Callback callback,
-                           bool is_backward = false) {
-    auto *src_var = src_scope.FindVar(src_var_name);
-    if (is_backward && src_var == nullptr) {
-      return;
-    }
-    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
-    auto &src_tensor = src_var->Get<framework::LoDTensor>();
+  auto *program = block->Program();
+  auto ctx = executor.Prepare(
+      *program, block->ID(),
+      Attr<std::vector<std::string>>(
+          kSkipEagerDeletionVars) /*skip_ref_cnt_vars*/);
 
-    auto *dst_var = dst_scope->Var(dst_var_name);
-    auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
-    callback(src_tensor, dst_tensor);
-  }
+  for (size_t i = 0; i < seq_len; ++i) {
+    size_t seq_offset = reverse ? seq_len - i - 1 : i;
+    VLOG(3) << "Recurrent operate at the time step " << seq_offset;
 
-  template <typename Callback>
-  static void AccessTensor(const framework::Scope &src_scope,
-                           const std::string &src_var_name,
-                           const framework::Scope &dst_scope,
-                           const std::string &dst_var_name, Callback callback,
-                           bool is_backward = false) {
-    auto *dst_var = dst_scope.FindVar(dst_var_name);
-    if (is_backward && dst_var == nullptr) {
-      return;
-    }
-    auto *src_var = src_scope.FindVar(src_var_name);
-    PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name);
-    auto &src_tensor = src_var->Get<framework::LoDTensor>();
-    PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name);
-    auto *dst_tensor = dst_var->GetMutable<framework::LoDTensor>();
-    callback(src_tensor, dst_tensor);
-  }
-};
+    auto &cur_scope = scopes.CurScope();
 
-class RecurrentOp : public RecurrentBase {
- public:
-  RecurrentOp(const std::string &type, const framework::VariableNameMap &inputs,
-              const framework::VariableNameMap &outputs,
-              const framework::AttributeMap &attrs)
-      : RecurrentBase(type, inputs, outputs, attrs) {}
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    bool has_state = Attr<bool>(kHasStates);
-    auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
-
-    // get device context from pool
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-
-    VLOG(3) << "Static RNN input sequence length = " << seq_len;
-    StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
-    auto reverse = Attr<bool>(kReverse);
-
-    framework::Executor executor(place);
-    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
-
-    auto *program = block->Program();
-    auto ctx = executor.Prepare(
-        *program, block->ID(), std::vector<std::string>() /*skip_ref_cnt_vars*/,
-        true /*force_disable_gc*/);
-
-    for (size_t i = 0; i < seq_len; ++i) {
-      size_t seq_offset = reverse ? seq_len - i - 1 : i;
-      VLOG(3) << "Recurrent operate at the time step " << seq_offset;
-
-      auto &cur_scope = scopes.CurScope();
-
-      // Link outside::input --> inside::input
-      //   inside::input = outside::input[seq_offset: seq_offset+1]
-      LinkTensorWithCallback(
-          scope, Inputs(kInputs), &cur_scope, Inputs(kInputs),
-          [&seq_offset](const framework::Tensor &outside,
-                        framework::Tensor *inside) {
-            inside->ShareDataWith(outside.Slice(seq_offset, seq_offset + 1));
-            auto dims = framework::vectorize(inside->dims());
-            dims.erase(dims.begin());
-            inside->Resize(framework::make_ddim(dims));
-          });
-
-      if (has_state) {
-        if (i == 0) {
-          // Link initial states  --> ex_states
-          LinkTensor(scope, Inputs(kInitialStates), &cur_scope,
-                     Attr<std::vector<std::string>>(kExStates));
-        } else {
-          auto &ex_scope = scopes.ExScope();
-          // Link ex_scope::state --> cur_scope::ex_state
-          LinkTensor(ex_scope, Attr<std::vector<std::string>>(kStates),
-                     &cur_scope, Attr<std::vector<std::string>>(kExStates));
-        }
-      }
-
-      // Every inputs are linked now, execute!
-      executor.RunPreparedContext(ctx.get(), &cur_scope,
-                                  false /*create_local_scope*/,
-                                  true /*create_vars*/, true /* keep_kids */);
-
-      // Copy inside::output -> outside::output
-      //    outside::output[seq_offset: seq_offset + 1] = inside::output
-      this->LinkTensorWithCallback(
-          cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
-          [&](const framework::LoDTensor &src_tensor,
-              framework::LoDTensor *dst_tensor) {
-            if (i == 0) {  // create output tensor at begin
-              dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims()));
-              dst_tensor->mutable_data(place, src_tensor.type());
-            }
-
-            auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
-            // Explicit copy output since the local RNN scope can be destroyed
-            // early.
-            framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
-          });
-
-      scopes.Next();
-    }
-  }
-
- private:
-  StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
-                              const framework::Scope &scope,
-                              size_t seq_len) const {
-    auto *var = scope.FindVar(Output(kStepScopes));
-    PADDLE_ENFORCE(var != nullptr);
-    return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
-                      Attr<bool>(kIsTrain), seq_len);
-  }
-};
+    // Link outside::input --> inside::input
+    //   inside::input = outside::input[seq_offset: seq_offset+1]
+    LinkTensorWithCallback(
+        scope, Inputs(kInputs), &cur_scope, Inputs(kInputs),
+        [&seq_offset](const framework::Tensor &outside,
+                      framework::Tensor *inside) {
+          inside->ShareDataWith(outside.Slice(seq_offset, seq_offset + 1));
+          auto dims = framework::vectorize(inside->dims());
+          dims.erase(dims.begin());
+          inside->Resize(framework::make_ddim(dims));
+        });
+
+    if (has_state) {
+      if (i == 0) {
+        // Link initial states  --> ex_states
+        LinkTensor(scope, Inputs(kInitialStates), &cur_scope,
+                   Attr<std::vector<std::string>>(kExStates));
+      } else {
+        auto &ex_scope = scopes.ExScope();
+        // Link ex_scope::state --> cur_scope::ex_state
+        LinkTensor(ex_scope, Attr<std::vector<std::string>>(kStates),
+                   &cur_scope, Attr<std::vector<std::string>>(kExStates));
+      }
+    }
+
+    // Every input is linked now, execute!
+    executor.RunPreparedContext(ctx.get(), &cur_scope,
+                                false /*create_local_scope*/,
+                                true /*create_vars*/, true /* keep_kids */);
+
+    // Copy inside::output -> outside::output
+    //    outside::output[seq_offset: seq_offset + 1] = inside::output
+    this->LinkTensorWithCallback(
+        cur_scope, Outputs(kOutputs), scope, Outputs(kOutputs),
+        [&](const framework::LoDTensor &src_tensor,
+            framework::LoDTensor *dst_tensor) {
+          if (i == 0) {  // create output tensor at begin
+            dst_tensor->Resize(PrependDims(seq_len, src_tensor.dims()));
+            dst_tensor->mutable_data(place, src_tensor.type());
+          }
+
+          auto dst_out = dst_tensor->Slice(seq_offset, seq_offset + 1);
+          // Explicit copy output since the local RNN scope can be destroyed
+          // early.
+          framework::TensorCopy(src_tensor, place, dev_ctx, &dst_out);
+        });
+
+    scopes.Next();
+  }
+}
+
+StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
+                                         const framework::Scope &scope,
+                                         size_t seq_len) const {
+  auto *var = scope.FindVar(Output(kStepScopes));
+  PADDLE_ENFORCE(var != nullptr);
+  return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
+                    Attr<bool>(kIsTrain), seq_len);
+}
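A self-contained sketch of the per-step slicing that RunImpl performs (plain arrays standing in for LoDTensor): each step views the row at seq_offset of a [SEQ_LEN, ...] buffer and drops the leading dim; PrependDims re-adds [seq_len] when outputs are assembled:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      const std::size_t seq_len = 4, width = 3;
      std::vector<float> outside(seq_len * width, 0.f);  // flattened [4, 3]

      const bool reverse = false;
      for (std::size_t i = 0; i < seq_len; ++i) {
        std::size_t seq_offset = reverse ? seq_len - i - 1 : i;
        // "inside" shares data with one step's slice: row seq_offset, shape [3]
        float *inside = outside.data() + seq_offset * width;
        for (std::size_t j = 0; j < width; ++j) {
          inside[j] = static_cast<float>(i);  // step body writes in place
        }
      }
      assert(outside[0 * width] == 0.f && outside[3 * width] == 3.f);
    }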
-class RecurrentGradOp : public RecurrentBase {
- public:
-  RecurrentGradOp(const std::string &type,
-                  const framework::VariableNameMap &inputs,
-                  const framework::VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs)
-      : RecurrentBase(type, inputs, outputs, attrs) {}
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    bool has_state = Attr<bool>(kHasStates);
-    const size_t seq_len = static_cast<size_t>(GetSequenceLength(scope));
-
-    // get device context from pool
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-
-    StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
-    auto reverse = Attr<bool>(kReverse);
-
-    framework::Executor executor(place);
-    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
-    auto *program = block->Program();
-    auto ctx = executor.Prepare(
-        *program, block->ID(), std::vector<std::string>() /*skip_ref_cnt_vars*/,
-        true /*force_disable_gc*/);
-
-    for (size_t step_id = 0; step_id < seq_len; ++step_id) {
-      size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
-      VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
-      auto &cur_scope = scopes.CurScope();
-
-      // Link outside::output_grads --> inside::output_grads
-      //   inside::output_grad = outside::output_grad[seq_offset:seq_offset+1]
-      LinkTensorWithCallback(
-          scope, Inputs(kOutputGrads), &cur_scope, Inputs(kOutputGrads),
-          [&](const framework::Tensor &outside, framework::Tensor *inside) {
-            inside->ShareDataWith(outside.Slice(seq_offset, seq_offset + 1));
-            auto dims = framework::vectorize(inside->dims());
-            dims.erase(dims.begin());
-            inside->Resize(framework::make_ddim(dims));
-          },
-          true /*is_backward*/);
-      auto og_set = List2Set(Inputs(kOutputGrads));
-
-      if (VLOG_IS_ON(10)) {
-        std::ostringstream sout;
-        std::copy(og_set.begin(), og_set.end(),
-                  std::ostream_iterator<std::string>(sout, ","));
-        VLOG(10) << " RNN output gradients = [" << sout.str() << "]";
-      }
-
-      if (has_state) {
-        // Link states
-        //   if cur_scope::cur_state_grad in out_grads:
-        //     cur_scope::cur_state_grad += ex_scope::ex_state_grad
-        //   else:
-        //     ex_scope::ex_state_grad --> cur_scope::cur_state_grad
-        if (step_id != 0) {  // not at beginning
-          auto &ex_scope = scopes.ExScope();
-          auto ex_state_grads =
-              GradVarLists(Attr<std::vector<std::string>>(kExStates));
-          auto cur_state_grads =
-              GradVarLists(Attr<std::vector<std::string>>(kStates));
-
-          PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size());
-          for (size_t i = 0; i < ex_state_grads.size(); ++i) {
-            auto &cur_grad = cur_state_grads[i];
-            auto &ex_grad = ex_state_grads[i];
-            auto &ex_tensor =
-                ex_scope.FindVar(ex_grad)->Get<framework::LoDTensor>();
-
-            VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad;
-            auto *cur_grad_var = cur_scope.Var(cur_grad);
-            auto cur_grad_tensor =
-                cur_grad_var->GetMutable<framework::LoDTensor>();
-            framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor);
-          }
-        }
-      }
+RecurrentGradOp::RecurrentGradOp(const std::string &type,
+                                 const framework::VariableNameMap &inputs,
+                                 const framework::VariableNameMap &outputs,
+                                 const framework::AttributeMap &attrs)
+    : RecurrentBase(type, inputs, outputs, attrs) {}
+
+void RecurrentGradOp::RunImpl(const framework::Scope &scope,
+                              const platform::Place &place) const {
+  bool has_state = Attr<bool>(kHasStates);
+  const size_t seq_len = static_cast<size_t>(GetSequenceLength(scope));
+
+  // get device context from pool
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(place);
+
+  StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
+  auto reverse = Attr<bool>(kReverse);
+
+  framework::Executor executor(place);
+  auto *block = Attr<framework::BlockDesc *>(kStepBlock);
+  auto *program = block->Program();
+  auto ctx = executor.Prepare(
+      *program, block->ID(),
+      Attr<std::vector<std::string>>(
+          kSkipEagerDeletionVars) /*skip_ref_cnt_vars*/);
+
+  for (size_t step_id = 0; step_id < seq_len; ++step_id) {
+    size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
+    VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
+    auto &cur_scope = scopes.CurScope();
+
+    // Link outside::output_grads --> inside::output_grads
+    //   inside::output_grad = outside::output_grad[seq_offset:seq_offset+1]
+    LinkTensorWithCallback(
+        scope, Inputs(kOutputGrads), &cur_scope, Inputs(kOutputGrads),
+        [&](const framework::Tensor &outside, framework::Tensor *inside) {
+          inside->ShareDataWith(outside.Slice(seq_offset, seq_offset + 1));
+          auto dims = framework::vectorize(inside->dims());
+          dims.erase(dims.begin());
+          inside->Resize(framework::make_ddim(dims));
+        },
+        true /*is_backward*/);
+    auto og_set = List2Set(Inputs(kOutputGrads));
+
+    if (VLOG_IS_ON(10)) {
+      std::ostringstream sout;
+      std::copy(og_set.begin(), og_set.end(),
+                std::ostream_iterator<std::string>(sout, ","));
+      VLOG(10) << " RNN output gradients = [" << sout.str() << "]";
+    }
+
+    if (has_state) {
+      // Link states
+      //   if cur_scope::cur_state_grad in out_grads:
+      //     cur_scope::cur_state_grad += ex_scope::ex_state_grad
+      //   else:
+      //     ex_scope::ex_state_grad --> cur_scope::cur_state_grad
+      if (step_id != 0) {  // not at beginning
+        auto &ex_scope = scopes.ExScope();
+        auto ex_state_grads =
+            GradVarLists(Attr<std::vector<std::string>>(kExStates));
+        auto cur_state_grads =
+            GradVarLists(Attr<std::vector<std::string>>(kStates));
+
+        PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size());
+        for (size_t i = 0; i < ex_state_grads.size(); ++i) {
+          auto &cur_grad = cur_state_grads[i];
+          auto &ex_grad = ex_state_grads[i];
+          auto &ex_tensor =
+              ex_scope.FindVar(ex_grad)->Get<framework::LoDTensor>();
+
+          VLOG(10) << " RNN link " << cur_grad << " from " << ex_grad;
+          auto *cur_grad_var = cur_scope.Var(cur_grad);
+          auto cur_grad_tensor =
+              cur_grad_var->GetMutable<framework::LoDTensor>();
+          framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor);
+        }
+      }
+    }
+
+    VLOG(5) << "Recurrent memory linking finished ";
+    // Run step block with cur_scope
+    executor.RunPreparedContext(ctx.get(), &cur_scope,
+                                false /*create_local_scope*/,
+                                true /*create_vars*/, true /* keep_kids */);
+
+    VLOG(5) << "executor.Run finished ";
+
+    auto local_var_names = LocalVarNames(cur_scope);
+
+    // Accumulate params
+    //   if (step == 0):
+    //      outside::param_grad = 0.0
+    //   outside::param_grad += inside::param_grad
+    {
+      auto &pg_names = Outputs(kParamGrads);
+      auto &p_names = Inputs(kParameters);
+      PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
+
+      for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
+        auto inside_grad_name = framework::GradVarName(p_names[param_id]);
+
+        // If the gradient of that variable is not computed inside the RNN,
+        // just continue
+        if (local_var_names.find(inside_grad_name) ==
+            local_var_names.end()) {
+          continue;
+        }
+
+        // zero gradient variable in step 0
+        if (step_id == 0) {
+          auto &inside_tensor =
+              cur_scope.FindVar(inside_grad_name)->Get<framework::LoDTensor>();
+          framework::AttributeMap attrs;
+          attrs["dtype"] = inside_tensor.type();
+          attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
+          attrs["value"] = 0.0f;
+
+          auto zero_op = framework::OpRegistry::CreateOp(
+              "fill_constant", framework::VariableNameMap{},
+              {{"Out", {pg_names[param_id]}}}, attrs);
+          zero_op->Run(scope, place);
+        }
+
+        auto new_inside_name = cur_scope.Rename(inside_grad_name);
+
+        // sum gradient
+        auto sum_op = framework::OpRegistry::CreateOp(
+            "sum", {{"X", {pg_names[param_id], new_inside_name}}},
+            {{"Out", {pg_names[param_id]}}},
+            framework::AttributeMap{{"use_mkldnn", {false}}});
+        sum_op->Run(cur_scope, place);
+
+        cur_scope.Rename(new_inside_name, inside_grad_name);
+      }
+    }
+    VLOG(5) << "Accumulate Parameter finished ";
+
+    // Copy input gradient from inside to outside
+    //   outside::input_grad[seq_offset: seq_offset + 1] = inside::input_grad
+    LinkTensorWithCallback(
+        cur_scope, GradVarLists(Inputs(kInputs)), scope, Outputs(kInputGrads),
+        [&](const framework::LoDTensor &inside,
+            framework::LoDTensor *outside) {
+          if (inside.memory_size() == 0) {  // IG is not created.
+ return; + } + if (step_id == 0) { // alloc memory + outside->Resize(PrependDims(seq_len, inside.dims())); + outside->mutable_data(place, inside.type()); + } + + auto dst = outside->Slice(seq_offset, seq_offset + 1); + framework::TensorCopy(inside, place, dev_ctx, &dst); + }, + true /*is_backward*/); + VLOG(5) << "Link outside gradient finished "; + + if (has_state) { + if (step_id + 1 == seq_len) { // at_end + // copy initialize states gradient from inside to outside + LinkTensorWithCallback( + cur_scope, GradVarLists(Attr>(kExStates)), + scope, Outputs(kInitStateGrads), + [&](const framework::LoDTensor &inside, + framework::LoDTensor *outside) { + outside->Resize(inside.dims()); outside->mutable_data(place, inside.type()); - } - - auto dst = outside->Slice(seq_offset, seq_offset + 1); - framework::TensorCopy(inside, place, dev_ctx, &dst); - }, - true /*is_backward*/); - VLOG(5) << "Link outside gradient finished "; - - if (has_state) { - if (step_id + 1 == seq_len) { // at_end - // copy initialize states gradient from inside to outside - LinkTensorWithCallback( - cur_scope, - GradVarLists(Attr>(kExStates)), scope, - Outputs(kInitStateGrads), - [&](const framework::LoDTensor &inside, - framework::LoDTensor *outside) { - outside->Resize(inside.dims()); - outside->mutable_data(place, inside.type()); - framework::TensorCopy(inside, place, dev_ctx, outside); - }, - true /*is_backward*/); - VLOG(5) << "Link initialize state gradient finished "; - } + framework::TensorCopy(inside, place, dev_ctx, outside); + }, + true /*is_backward*/); + VLOG(5) << "Link initialize state gradient finished "; } - scopes.Next(); } - // Delete the scope of StepScopes - auto *var = scope.FindVar(Input(kStepScopes)); - PADDLE_ENFORCE(var != nullptr); - auto *step_scopes = var->GetMutable(); - ClearStepScopes(dev_ctx, const_cast(&scope), - step_scopes); + scopes.Next(); } + // Delete the scope of StepScopes + auto *var = scope.FindVar(Input(kStepScopes)); + PADDLE_ENFORCE(var != nullptr); + auto *step_scopes = var->GetMutable(); + ClearStepScopes(dev_ctx, const_cast(&scope), step_scopes); +} - private: - StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx, - const framework::Scope &scope, - size_t seq_len) const { - auto *var = scope.FindVar(Input(kStepScopes)); - PADDLE_ENFORCE(var != nullptr); - return StepScopes(dev_ctx, scope, var->GetMutable(), - Attr(kIsTrain), seq_len, true /*is_backward*/); - } +StepScopes RecurrentGradOp::CreateStepScopes( + const platform::DeviceContext &dev_ctx, const framework::Scope &scope, + size_t seq_len) const { + auto *var = scope.FindVar(Input(kStepScopes)); + PADDLE_ENFORCE(var != nullptr); + return StepScopes(dev_ctx, scope, var->GetMutable(), + Attr(kIsTrain), seq_len, true /*is_backward*/); +} - std::unordered_set List2Set( - const std::vector &list) const { - std::unordered_set local_var_name_set; - local_var_name_set.reserve(list.size()); - for (auto &each : list) { - local_var_name_set.insert(each); - } - return local_var_name_set; +std::unordered_set RecurrentGradOp::List2Set( + const std::vector &list) const { + std::unordered_set local_var_name_set; + local_var_name_set.reserve(list.size()); + for (auto &each : list) { + local_var_name_set.insert(each); } + return local_var_name_set; +} - std::unordered_set LocalVarNames( - const framework::Scope &scope) const { - return this->List2Set(scope.LocalVarNames()); - } - static std::vector GradVarLists( - const std::vector &var_names) { - std::vector retv; - retv.reserve(var_names.size()); - 
std::transform(var_names.begin(), var_names.end(), std::back_inserter(retv), - framework::GradVarName); - return retv; - } -}; +std::unordered_set RecurrentGradOp::LocalVarNames( + const framework::Scope &scope) const { + return this->List2Set(scope.LocalVarNames()); +} +std::vector RecurrentGradOp::GradVarLists( + const std::vector &var_names) { + std::vector retv; + retv.reserve(var_names.size()); + std::transform(var_names.begin(), var_names.end(), std::back_inserter(retv), + framework::GradVarName); + return retv; +} class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput(kInputs, "rnn inputs").AsDuplicable(); - AddInput(kInitialStates, "rnn initial states").AsDuplicable(); - AddInput(kParameters, + AddInput(RecurrentBase::kInputs, "rnn inputs").AsDuplicable(); + AddInput(RecurrentBase::kInitialStates, "rnn initial states") + .AsDuplicable(); + AddInput(RecurrentBase::kParameters, "Parameters are used by step block as its input. However, the " "input is not a sequence tensor. Every time step, each operator " "in step block just use the parameter directly.") .AsDuplicable(); - AddOutput(kOutputs, + AddOutput(RecurrentBase::kOutputs, "The output sequence of RNN. The sequence length must be same.") .AsDuplicable(); - AddOutput(kStepScopes, + AddOutput(RecurrentBase::kStepScopes, "StepScopes contain all local variables in each time step."); - AddAttr(kHasStates, "Whether has states.").SetDefault(false); - AddAttr>(kExStates, - string::Sprintf( - R"DOC(The ex-state variable names. + AddAttr(RecurrentBase::kHasStates, "Whether has states.") + .SetDefault(false); + AddAttr>( + RecurrentBase::kExStates, + string::Sprintf( + R"DOC(The ex-state variable names. The ex-state means the state value in the ex-timestep or the previous time step [%s, %s, %s] must be the same order)DOC", - kExStates, kStates, kInitStateGrads)); + RecurrentBase::kExStates, RecurrentBase::kStates, + RecurrentBase::kInitStateGrads)); AddAttr>( - kStates, + RecurrentBase::kStates, string::Sprintf( "The state variable names. [%s, %s, %s] must be the same order", - kExStates, kStates, kInitStateGrads)); - AddAttr(kStepBlock, "The step block inside RNN"); - AddAttr(kReverse, R"DOC(Calculate RNN reversely or not. + RecurrentBase::kExStates, RecurrentBase::kStates, + RecurrentBase::kInitStateGrads)); + AddAttr(RecurrentBase::kStepBlock, + "The step block inside RNN"); + AddAttr(RecurrentBase::kReverse, R"DOC(Calculate RNN reversely or not. By default reverse=False Assume the input data is [A, B, C, D] @@ -617,7 +526,12 @@ if reverse is True v v v v o o o o )DOC").SetDefault(false); - AddAttr(kIsTrain, "").SetDefault(true); + AddAttr(RecurrentBase::kIsTrain, "").SetDefault(true); + AddAttr>(RecurrentBase::kSkipEagerDeletionVars, + "Vars that would skip eager deletion." + "Users should not set this manually.") + .SetDefault(std::vector()); + AddComment(R"DOC( Static Length Recurrent Operator. 
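An aside on the parameter-gradient handling moved into RecurrentGradOp::RunImpl above: the scheme is "zero once, then sum every step". On the first backward step (step_id == 0) the outside gradient is zero-filled by a fill_constant op; after that, each step's inside gradient is folded in by a sum op. A minimal NumPy sketch of the same accumulation (illustrative only, not Paddle code; accumulate_param_grads is a made-up name):

import numpy as np

def accumulate_param_grads(inside_grads):
    # inside_grads: one gradient array per time step for a single parameter,
    # in the order the backward pass visits the steps.
    outside_grad = None
    for step_id, grad in enumerate(inside_grads):
        if step_id == 0:
            outside_grad = np.zeros_like(grad)  # the fill_constant(value=0) op
        outside_grad = outside_grad + grad      # the sum op
    return outside_grad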
@@ -643,7 +557,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
     }
 
     for (auto &output_param : this->OutputNames()) {
-      if (output_param == kStepScopes) {
+      if (output_param == RecurrentBase::kStepScopes) {
         grad->SetInput(output_param, this->Output(output_param));
         grad->SetInput(framework::GradVarName(output_param),
                        this->Output(output_param));
@@ -654,7 +568,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
       }
     }
     grad->SetAttrMap(this->Attrs());
-    grad->SetBlockAttr(kStepBlock, grad_block_[0]);
+    grad->SetBlockAttr(RecurrentBase::kStepBlock, grad_block_[0]);
 
     return std::unique_ptr<framework::OpDesc>(grad);
   }
@@ -663,46 +577,55 @@
 class RecurrentGradOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *ctx) const override {
-    std::vector<std::string> output{kOutputs};
+    std::vector<std::string> output{RecurrentBase::kOutputs};
     // In some case the kInitialStates is empty.
     // If the kInitialStates is empty, all the states should be empty.
-    if (!ctx->HasInputs(kInitialStates)) {
+    if (!ctx->HasInputs(RecurrentBase::kInitialStates)) {
       PADDLE_ENFORCE_EQ(
-          ctx->Attrs().Get<std::vector<std::string>>(kExStates).size(), 0,
-          "The Attr(%s) should be empty.", kExStates);
+          ctx->Attrs()
+              .Get<std::vector<std::string>>(RecurrentBase::kExStates)
+              .size(),
+          0, "The Attr(%s) should be empty.", RecurrentBase::kExStates);
       PADDLE_ENFORCE_EQ(
-          ctx->Attrs().Get<std::vector<std::string>>(kStates).size(), 0,
-          "The Attr(%s) should be empty.", kStates);
+          ctx->Attrs()
+              .Get<std::vector<std::string>>(RecurrentBase::kStates)
+              .size(),
+          0, "The Attr(%s) should be empty.", RecurrentBase::kStates);
     }
 
-    PADDLE_ENFORCE(ctx->HasInputs(kInputs),
-                   "The input(%s) should not be empty.", kInputs);
-    PADDLE_ENFORCE(ctx->HasInputs(kOutputs),
-                   "The input(%s) should not be empty.", kOutputs);
+    PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kInputs),
+                   "The input(%s) should not be empty.",
+                   RecurrentBase::kInputs);
+    PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kOutputs),
+                   "The input(%s) should not be empty.",
+                   RecurrentBase::kOutputs);
 
     // In some case the kInitialStates is empty.
-    if (ctx->HasInputs(kInitialStates)) {
-      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kInitialStates)),
+    if (ctx->HasInputs(RecurrentBase::kInitialStates)) {
+      PADDLE_ENFORCE(ctx->HasOutputs(
+                         framework::GradVarName(RecurrentBase::kInitialStates)),
                      "The output of(%s) should not be empty.",
-                     framework::GradVarName(kInitialStates));
-      ctx->SetOutputsDim(framework::GradVarName(kInitialStates),
-                         ctx->GetInputsDim(kInitialStates));
+                     framework::GradVarName(RecurrentBase::kInitialStates));
+      ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates),
+                         ctx->GetInputsDim(RecurrentBase::kInitialStates));
     }
 
-    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kInputs)),
-                   "The output of(%s) should not be empty.",
-                   framework::GradVarName(kInputs));
-    ctx->SetOutputsDim(framework::GradVarName(kInputs),
-                       ctx->GetInputsDim(kInputs));
+    PADDLE_ENFORCE(
+        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)),
+        "The output of(%s) should not be empty.",
+        framework::GradVarName(RecurrentBase::kInputs));
+    ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
+                       ctx->GetInputsDim(RecurrentBase::kInputs));
 
     // In some case the kParameters is empty.
-    if (ctx->HasInputs(kParameters)) {
-      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)),
-                     "The output of(%s) should not be empty.",
-                     framework::GradVarName(kParameters));
-      ctx->SetOutputsDim(framework::GradVarName(kParameters),
-                         ctx->GetInputsDim(kParameters));
+    if (ctx->HasInputs(RecurrentBase::kParameters)) {
+      PADDLE_ENFORCE(
+          ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
+          "The output of(%s) should not be empty.",
+          framework::GradVarName(RecurrentBase::kParameters));
+      ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
+                         ctx->GetInputsDim(RecurrentBase::kParameters));
     }
   }
 };
 
diff --git a/paddle/fluid/operators/recurrent_op.h b/paddle/fluid/operators/recurrent_op.h
new file mode 100644
index 0000000000..8da0fcacee
--- /dev/null
+++ b/paddle/fluid/operators/recurrent_op.h
@@ -0,0 +1,226 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+// StepScopes manages the scopes inside the RNN.
+// StepScopes::CurScope() gets the current scope.
+// StepScopes::ExScope() gets the ex-scope, i.e. the scope of the previous
+// time step.
+// StepScopes::Next() moves to the next time step.
+//
+// if is_train = False, then
+//   there are two scopes for the RNN and it just supports forward;
+// else
+//   len(scopes) == seq_len.
+//
+// if is_backward = True, then
+//   the scopes are accessed in reverse;
+// else
+//   the scopes are accessed from begin to end.
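An aside on the scope bookkeeping documented above: a rough Python model of the indexing rule, under the assumptions stated in the comment (training keeps one scope per time step, inference keeps just two scopes and reuses them alternately, and the backward pass walks the same list in reverse). Illustrative only; the real class declared below holds framework::Scope pointers:

class StepScopesSketch(object):
    def __init__(self, scopes, is_train, seq_len, is_backward=False):
        self.scopes = scopes  # stand-ins for framework::Scope*
        self.is_train = is_train
        self.is_backward = is_backward
        self.counter = seq_len - 1 if is_backward else 0

    def get_scope(self, scope_id):
        if not self.is_train:
            scope_id %= 2  # inference: only two scopes, reused in turn
        return self.scopes[scope_id]

    def cur_scope(self):
        return self.get_scope(self.counter)

    def ex_scope(self):
        # the scope of the previous time step
        ex = self.counter + 1 if self.is_backward else self.counter - 1
        return self.get_scope(ex)

    def next(self):
        self.counter += -1 if self.is_backward else 1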
+class StepScopes { + public: + StepScopes(const platform::DeviceContext &dev_ctx, + const framework::Scope &parent, + std::vector *scopes, bool is_train, + size_t seq_len, bool is_backward = false); + + framework::Scope &CurScope(); + + framework::Scope &ExScope(); + + void Next(); + + private: + framework::Scope &GetScope(size_t scope_id) const; + + size_t counter_; + std::vector *scopes_; + bool is_train_; + bool is_backward_; +}; + +// Base class for RecurrentOp/RecurrentGradOp +// Some common protected functions for RecurrentOp/RecurrentGradOp +class RecurrentBase : public framework::OperatorBase { + public: + static const char kInputs[]; + static const char kInitialStates[]; + static const char kParameters[]; + static const char kOutputs[]; + static const char kStepScopes[]; + static const char kHasStates[]; + static const char kExStates[]; + static const char kStates[]; + static const char kStepBlock[]; + static const char kReverse[]; + static const char kIsTrain[]; + static const char kSkipEagerDeletionVars[]; + static const char kInputGrads[]; + static const char kOutputGrads[]; + static const char kParamGrads[]; + static const char kInitStateGrads[]; + + RecurrentBase(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs); + + protected: + // Get SequenceLength from Scope + // The sequence length is got from input tensor. The input tensor's + // dimension should be [SEQ_LEN, ..., ...]. The first of the tensor's shape + // is SEQ_LEN. The second of the tensor's shape could be the batch size or + // nested sequence length. + int64_t GetSequenceLength(const framework::Scope &scope) const; + + // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars), + // map(dst_scope.Var, dst_vars)): + // dst_tensor.ShareDataWith(src_tensor) + static void LinkTensor(const framework::Scope &src_scope, + const std::vector &src_vars, + framework::Scope *dst_scope, + const std::vector &dst_vars); + + // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars), + // map(dst_scope.Var, dst_vars)): + // callback(src_tensor, &dst_tensor) + template + static void LinkTensorWithCallback(const framework::Scope &src_scope, + const std::vector &src_vars, + framework::Scope *dst_scope, + const std::vector &dst_vars, + Callback callback, + bool is_backward = false) { + PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size()); + for (size_t i = 0; i < dst_vars.size(); ++i) { + VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i]; + AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback, + is_backward); + } + } + + // for src_tensor, dst_tensor in zip(map(src_scope.FindVar, src_vars), + // map(dst_scope.FindVar, dst_vars)): + // callback(src_tensor, &dst_tensor) + template + static void LinkTensorWithCallback(const framework::Scope &src_scope, + const std::vector &src_vars, + const framework::Scope &dst_scope, + const std::vector &dst_vars, + Callback callback, + bool is_backward = false) { + PADDLE_ENFORCE_EQ(src_vars.size(), dst_vars.size()); + for (size_t i = 0; i < dst_vars.size(); ++i) { + VLOG(10) << "Link " << src_vars[i] << " to " << dst_vars[i]; + AccessTensor(src_scope, src_vars[i], dst_scope, dst_vars[i], callback, + is_backward); + } + } + + // (seq_len, shape) -> return [seq_len] + list(shape) + static framework::DDim PrependDims(size_t seq_len, + const framework::DDim &src); + + private: + template + static void AccessTensor(const framework::Scope &src_scope, + 
const std::string &src_var_name, + framework::Scope *dst_scope, + const std::string &dst_var_name, Callback callback, + bool is_backward = false) { + auto *src_var = src_scope.FindVar(src_var_name); + if (is_backward && src_var == nullptr) { + return; + } + PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name); + auto &src_tensor = src_var->Get(); + + auto *dst_var = dst_scope->Var(dst_var_name); + auto *dst_tensor = dst_var->GetMutable(); + callback(src_tensor, dst_tensor); + } + + template + static void AccessTensor(const framework::Scope &src_scope, + const std::string &src_var_name, + const framework::Scope &dst_scope, + const std::string &dst_var_name, Callback callback, + bool is_backward = false) { + auto *dst_var = dst_scope.FindVar(dst_var_name); + if (is_backward && dst_var == nullptr) { + return; + } + auto *src_var = src_scope.FindVar(src_var_name); + PADDLE_ENFORCE(src_var != nullptr, "%s is not found.", src_var_name); + auto &src_tensor = src_var->Get(); + PADDLE_ENFORCE(dst_var != nullptr, "%s is not found.", dst_var_name); + auto *dst_tensor = dst_var->GetMutable(); + callback(src_tensor, dst_tensor); + } +}; + +class RecurrentOp : public RecurrentBase { + public: + RecurrentOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs); + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override; + + private: + StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx, + const framework::Scope &scope, + size_t seq_len) const; +}; + +class RecurrentGradOp : public RecurrentBase { + public: + RecurrentGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs); + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override; + + StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx, + const framework::Scope &scope, + size_t seq_len) const; + + std::unordered_set List2Set( + const std::vector &list) const; + + std::unordered_set LocalVarNames( + const framework::Scope &scope) const; + + static std::vector GradVarLists( + const std::vector &var_names); +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/rnn_memory_helper_op.cc b/paddle/fluid/operators/rnn_memory_helper_op.cc index b00cc07dea..9f652480a2 100644 --- a/paddle/fluid/operators/rnn_memory_helper_op.cc +++ b/paddle/fluid/operators/rnn_memory_helper_op.cc @@ -91,6 +91,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { auto in_grad_var_name = Output(framework::GradVarName("X")); auto *in_grad_var = scope.FindVar(in_grad_var_name); + PADDLE_ENFORCE(in_grad_var != nullptr, "Cannot find in_grad_var in scope, name is %s", in_grad_var_name); diff --git a/paddle/fluid/string/string_helper.h b/paddle/fluid/string/string_helper.h index e2ded402b1..cc09088c7e 100644 --- a/paddle/fluid/string/string_helper.h +++ b/paddle/fluid/string/string_helper.h @@ -119,16 +119,18 @@ std::vector split_string(const std::string& str) { return list; } -template -std::string join_strings(const std::vector& strs, char delim) { +template +std::string join_strings(const Container& strs, char delim) { std::string str; - for (size_t i = 0; i < strs.size(); i++) { + int i = 0; + for (auto& elem : strs) { if (i > 0) { str += delim; } - str += boost::lexical_cast(strs[i]); + str += 
boost::lexical_cast(elem); + ++i; } return str; diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py new file mode 100644 index 0000000000..d16e7a95a6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_padding_rnn.py @@ -0,0 +1,657 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.layers as layers +import time + +from paddle.fluid import ParamAttr +from paddle.fluid.contrib.layers import basic_lstm +from paddle.fluid.executor import Executor +from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN + + +class RnnConfig(object): + def __init__(self, model_type, rnn_model): + self.model_type = model_type + self.rnn_model = rnn_model + + self.vocab_size = 10000 + if self.model_type == "test": + self.num_layers = 1 + self.batch_size = 2 + self.hidden_size = 10 + self.num_steps = 3 + self.init_scale = 0.1 + self.max_grad_norm = 5.0 + self.epoch_start_decay = 1 + self.max_epoch = 1 + self.dropout = 0.0 + self.lr_decay = 0.5 + self.base_learning_rate = 1.0 + elif self.model_type == "small": + self.num_layers = 2 + self.batch_size = 20 + self.hidden_size = 200 + self.num_steps = 20 + self.init_scale = 0.1 + self.max_grad_norm = 5.0 + self.epoch_start_decay = 4 + self.max_epoch = 13 + self.dropout = 0.0 + self.lr_decay = 0.5 + self.base_learning_rate = 1.0 + elif self.model_type == "medium": + self.num_layers = 2 + self.batch_size = 20 + self.hidden_size = 650 + self.num_steps = 35 + self.init_scale = 0.05 + self.max_grad_norm = 5.0 + self.epoch_start_decay = 6 + self.max_epoch = 39 + self.dropout = 0.5 + self.lr_decay = 0.8 + self.base_learning_rate = 1.0 + elif self.model_type == "large": + self.num_layers = 2 + self.batch_size = 20 + self.hidden_size = 1500 + self.num_steps = 35 + self.init_scale = 0.04 + self.max_grad_norm = 10.0 + self.epoch_start_decay = 14 + self.max_epoch = 55 + self.dropout = 0.65 + self.lr_decay = 1.0 / 1.15 + self.base_learning_rate = 1.0 + else: + raise ValueError('Unsupported model_type.') + + if rnn_model not in ('static', 'padding', 'cudnn', 'basic_lstm'): + raise ValueError('Unsupported rnn_model.') + + self.batch_size = 12 + self.max_epoch = 3 + self.random_seed = 123 + + +# Fake data reader for test +class Reader(object): + def get_data_iter(self, rnn_config): + for i in range(rnn_config.max_epoch): + x = np.zeros( + shape=(rnn_config.batch_size, rnn_config.num_steps), + dtype='int64') + y = np.ones( + shape=(rnn_config.batch_size, rnn_config.num_steps), + dtype='int64') + yield (x, y) + + +# Model from PaddleNLP/models/language_model/lm_model.py in Paddle Models repo +def lm_model(hidden_size, + vocab_size, + batch_size, + 
num_layers=2, + num_steps=20, + init_scale=0.1, + dropout=None, + rnn_model='static', + use_py_reader=False): + def padding_rnn(input_embedding, len=3, init_hidden=None, init_cell=None): + weight_1_arr = [] + weight_2_arr = [] + bias_arr = [] + hidden_array = [] + cell_array = [] + mask_array = [] + for i in range(num_layers): + weight_1 = layers.create_parameter( + [hidden_size * 2, hidden_size * 4], + dtype="float32", + name="fc_weight1_" + str(i), + default_initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)) + weight_1_arr.append(weight_1) + bias_1 = layers.create_parameter( + [hidden_size * 4], + dtype="float32", + name="fc_bias1_" + str(i), + default_initializer=fluid.initializer.Constant(0.0)) + bias_arr.append(bias_1) + + pre_hidden = layers.slice( + init_hidden, axes=[0], starts=[i], ends=[i + 1]) + pre_cell = layers.slice( + init_cell, axes=[0], starts=[i], ends=[i + 1]) + pre_hidden = layers.reshape(pre_hidden, shape=[-1, hidden_size]) + pre_cell = layers.reshape(pre_cell, shape=[-1, hidden_size]) + hidden_array.append(pre_hidden) + cell_array.append(pre_cell) + + input_embedding = layers.transpose(input_embedding, perm=[1, 0, 2]) + rnn = PaddingRNN() + + with rnn.step(): + input = rnn.step_input(input_embedding) + for k in range(num_layers): + pre_hidden = rnn.memory(init=hidden_array[k]) + pre_cell = rnn.memory(init=cell_array[k]) + weight_1 = weight_1_arr[k] + bias = bias_arr[k] + + nn = layers.concat([input, pre_hidden], 1) + gate_input = layers.matmul(x=nn, y=weight_1) + + gate_input = layers.elementwise_add(gate_input, bias) + i = layers.slice( + gate_input, axes=[1], starts=[0], ends=[hidden_size]) + j = layers.slice( + gate_input, + axes=[1], + starts=[hidden_size], + ends=[hidden_size * 2]) + f = layers.slice( + gate_input, + axes=[1], + starts=[hidden_size * 2], + ends=[hidden_size * 3]) + o = layers.slice( + gate_input, + axes=[1], + starts=[hidden_size * 3], + ends=[hidden_size * 4]) + + c = pre_cell * layers.sigmoid(f) + layers.sigmoid( + i) * layers.tanh(j) + m = layers.tanh(c) * layers.sigmoid(o) + + rnn.update_memory(pre_hidden, m) + rnn.update_memory(pre_cell, c) + + rnn.step_output(m) + rnn.step_output(c) + + input = m + + if dropout != None and dropout > 0.0: + input = layers.dropout( + input, + dropout_prob=dropout, + dropout_implementation='upscale_in_train') + + rnn.step_output(input) + rnnout = rnn() + + last_hidden_array = [] + last_cell_array = [] + real_res = rnnout[-1] + for i in range(num_layers): + m = rnnout[i * 2] + c = rnnout[i * 2 + 1] + m.stop_gradient = True + c.stop_gradient = True + last_h = layers.slice( + m, axes=[0], starts=[num_steps - 1], ends=[num_steps]) + last_hidden_array.append(last_h) + last_c = layers.slice( + c, axes=[0], starts=[num_steps - 1], ends=[num_steps]) + last_cell_array.append(last_c) + real_res = layers.transpose(x=real_res, perm=[1, 0, 2]) + last_hidden = layers.concat(last_hidden_array, 0) + last_cell = layers.concat(last_cell_array, 0) + + return real_res, last_hidden, last_cell + + def encoder_static(input_embedding, len=3, init_hidden=None, + init_cell=None): + + weight_1_arr = [] + weight_2_arr = [] + bias_arr = [] + hidden_array = [] + cell_array = [] + mask_array = [] + for i in range(num_layers): + weight_1 = layers.create_parameter( + [hidden_size * 2, hidden_size * 4], + dtype="float32", + name="fc_weight1_" + str(i), + default_initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)) + weight_1_arr.append(weight_1) + bias_1 = 
layers.create_parameter( + [hidden_size * 4], + dtype="float32", + name="fc_bias1_" + str(i), + default_initializer=fluid.initializer.Constant(0.0)) + bias_arr.append(bias_1) + + pre_hidden = layers.slice( + init_hidden, axes=[0], starts=[i], ends=[i + 1]) + pre_cell = layers.slice( + init_cell, axes=[0], starts=[i], ends=[i + 1]) + pre_hidden = layers.reshape( + pre_hidden, shape=[-1, hidden_size], inplace=True) + pre_cell = layers.reshape( + pre_cell, shape=[-1, hidden_size], inplace=True) + hidden_array.append(pre_hidden) + cell_array.append(pre_cell) + + res = [] + sliced_inputs = layers.split( + input_embedding, num_or_sections=len, dim=1) + + for index in range(len): + input = sliced_inputs[index] + input = layers.reshape(input, shape=[-1, hidden_size], inplace=True) + for k in range(num_layers): + pre_hidden = hidden_array[k] + pre_cell = cell_array[k] + weight_1 = weight_1_arr[k] + bias = bias_arr[k] + + nn = layers.concat([input, pre_hidden], 1) + gate_input = layers.matmul(x=nn, y=weight_1) + + gate_input = layers.elementwise_add(gate_input, bias) + i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1) + + try: + from paddle.fluid.contrib.layers import fused_elemwise_activation + # fluid.contrib.layers.fused_elemwise_activation can do a fused + # operation, like: + # 1) x + sigmoid(y); x + tanh(y) + # 2) tanh(x + y) + # Now the unary operation supported in this fused op is limit, and + # we will extent this operation to support more unary operations and + # do this kind of fusion automitically in future version of paddle.fluid. + # layers.sigmoid(i) * layers.tanh(j) + tmp0 = fused_elemwise_activation( + x=layers.tanh(j), + y=i, + functor_list=['elementwise_mul', 'sigmoid'], + save_intermediate_out=False) + # pre_cell * layers.sigmoid(f) + tmp1 = fused_elemwise_activation( + x=pre_cell, + y=f, + functor_list=['elementwise_mul', 'sigmoid'], + save_intermediate_out=False) + c = tmp0 + tmp1 + # layers.tanh(c) * layers.sigmoid(o) + m = fused_elemwise_activation( + x=layers.tanh(c), + y=o, + functor_list=['elementwise_mul', 'sigmoid'], + save_intermediate_out=False) + except ImportError: + c = pre_cell * layers.sigmoid(f) + layers.sigmoid( + i) * layers.tanh(j) + m = layers.tanh(c) * layers.sigmoid(o) + + hidden_array[k] = m + cell_array[k] = c + input = m + + if dropout != None and dropout > 0.0: + input = layers.dropout( + input, + dropout_prob=dropout, + dropout_implementation='upscale_in_train') + + res.append(input) + + last_hidden = layers.concat(hidden_array, 1) + last_hidden = layers.reshape( + last_hidden, shape=[-1, num_layers, hidden_size], inplace=True) + last_hidden = layers.transpose(x=last_hidden, perm=[1, 0, 2]) + + last_cell = layers.concat(cell_array, 1) + last_cell = layers.reshape( + last_cell, shape=[-1, num_layers, hidden_size]) + last_cell = layers.transpose(x=last_cell, perm=[1, 0, 2]) + + real_res = layers.concat(res, 0) + real_res = layers.reshape( + real_res, shape=[len, -1, hidden_size], inplace=True) + real_res = layers.transpose(x=real_res, perm=[1, 0, 2]) + + return real_res, last_hidden, last_cell + + batch_size_each = batch_size + if use_py_reader: + feed_shapes = [[batch_size_each, num_steps, 1], + [batch_size_each * num_steps, 1]] + py_reader = fluid.layers.py_reader( + capacity=16, shapes=feed_shapes, dtypes=['int64', 'int64']) + x, y = fluid.layers.read_file(py_reader) + else: + x = layers.data( + name="x", + shape=[batch_size_each, num_steps, 1], + dtype='int64', + append_batch_size=False) + y = layers.data( + name="y", + 
shape=[batch_size_each * num_steps, 1], + dtype='int64', + append_batch_size=False) + + init_hidden = layers.data( + name="init_hidden", + shape=[num_layers, batch_size_each, hidden_size], + dtype='float32', + append_batch_size=False) + init_cell = layers.data( + name="init_cell", + shape=[num_layers, batch_size_each, hidden_size], + dtype='float32', + append_batch_size=False) + + init_cell.persistable = True + init_hidden.persistable = True + + init_hidden_reshape = layers.reshape( + init_hidden, shape=[num_layers, -1, hidden_size]) + init_cell_reshape = layers.reshape( + init_cell, shape=[num_layers, -1, hidden_size]) + + x_emb = layers.embedding( + input=x, + size=[vocab_size, hidden_size], + dtype='float32', + is_sparse=False, + param_attr=fluid.ParamAttr( + name='embedding_para', + initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale))) + + x_emb = layers.reshape( + x_emb, shape=[-1, num_steps, hidden_size], inplace=True) + if dropout != None and dropout > 0.0: + x_emb = layers.dropout( + x_emb, + dropout_prob=dropout, + dropout_implementation='upscale_in_train') + + if rnn_model == "padding": + rnn_out, last_hidden, last_cell = padding_rnn( + x_emb, + len=num_steps, + init_hidden=init_hidden_reshape, + init_cell=init_cell_reshape) + elif rnn_model == "static": + rnn_out, last_hidden, last_cell = encoder_static( + x_emb, + len=num_steps, + init_hidden=init_hidden_reshape, + init_cell=init_cell_reshape) + elif rnn_model == "cudnn": + x_emb = layers.transpose(x_emb, perm=[1, 0, 2]) + rnn_out, last_hidden, last_cell = layers.lstm( + x_emb, + init_hidden_reshape, + init_cell_reshape, + num_steps, + hidden_size, + num_layers, + is_bidirec=False, + default_initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)) + rnn_out = layers.transpose(rnn_out, perm=[1, 0, 2]) + elif rnn_model == "basic_lstm": + rnn_out, last_hidden, last_cell = basic_lstm( x_emb, init_hidden, init_cell, hidden_size, \ + num_layers=num_layers, batch_first=True, dropout_prob=dropout, \ + param_attr = ParamAttr( initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale) ), \ + bias_attr = ParamAttr( initializer = fluid.initializer.Constant(0.0) ), \ + forget_bias = 0.0) + else: + print("type not support") + return + + rnn_out = layers.reshape( + rnn_out, shape=[-1, num_steps, hidden_size], inplace=True) + + softmax_weight = layers.create_parameter( + [hidden_size, vocab_size], + dtype="float32", + name="softmax_weight", + default_initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)) + softmax_bias = layers.create_parameter( + [vocab_size], + dtype="float32", + name='softmax_bias', + default_initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)) + + projection = layers.matmul(rnn_out, softmax_weight) + projection = layers.elementwise_add(projection, softmax_bias) + projection = layers.reshape( + projection, shape=[-1, vocab_size], inplace=True) + + loss = layers.softmax_with_cross_entropy( + logits=projection, label=y, soft_label=False) + + loss = layers.reshape(loss, shape=[-1, num_steps], inplace=True) + loss = layers.reduce_mean(loss, dim=[0]) + loss = layers.reduce_sum(loss) + + loss.persistable = True + last_cell.persistable = True + last_hidden.persistable = True + + # This will feed last_hidden, last_cell to init_hidden, init_cell, which + # can be used directly in next batch. 
This can avoid the fetching of + # last_hidden and last_cell and feeding of init_hidden and init_cell in + # each training step. + layers.assign(input=last_cell, output=init_cell) + layers.assign(input=last_hidden, output=init_hidden) + + feeding_list = ['x', 'y', 'init_hidden', 'init_cell'] + if use_py_reader: + return loss, last_hidden, last_cell, feeding_list, py_reader + else: + return loss, last_hidden, last_cell, feeding_list + + +class EagerDeletionPaddingRnnTest(unittest.TestCase): + def setUp(self): + self.reader = Reader() + + def prepare_program(self, config): + self.main_program = fluid.Program() + self.startup_program = fluid.Program() + self.startup_program.random_seed = config.random_seed + with fluid.program_guard(self.main_program, self.startup_program): + with fluid.unique_name.guard(): + res_vars = lm_model( + config.hidden_size, + config.vocab_size, + config.batch_size, + num_layers=config.num_layers, + num_steps=config.num_steps, + init_scale=config.init_scale, + dropout=config.dropout, + rnn_model=config.rnn_model, + use_py_reader=False) + self.loss, self.last_hidden, self.last_cell, self.feed_order = res_vars + + fluid.clip.set_gradient_clip( + clip=fluid.clip.GradientClipByGlobalNorm( + clip_norm=config.max_grad_norm)) + + self.learning_rate = fluid.layers.create_global_var( + name="learning_rate", + shape=[1], + value=1.0, + dtype='float32', + persistable=True) + + optimizer = fluid.optimizer.SGD( + learning_rate=self.learning_rate) + optimizer.minimize(self.loss) + self.exe = Executor(fluid.CPUPlace()) + self.exe.run(self.startup_program) + + self.device_count = 1 + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = self.device_count + exec_strategy.num_iteration_per_drop_scope = 100 + + build_strategy = fluid.BuildStrategy() + build_strategy.enable_inplace = True + build_strategy.memory_optimize = False + build_strategy.fuse_all_optimizer_ops = True + + self.train_program = fluid.compiler.CompiledProgram( + self.main_program).with_data_parallel( + loss_name=self.loss.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + def generate_init_data(self): + init_hidden = np.zeros( + (self.config.num_layers, self.config.batch_size, + self.config.hidden_size), + dtype='float32') + init_cell = np.zeros( + (self.config.num_layers, self.config.batch_size, + self.config.hidden_size), + dtype='float32') + return init_hidden, init_cell + + def generate_new_lr(self, epoch_id=0, device_count=1): + new_lr = self.config.base_learning_rate * (self.config.lr_decay**max( + epoch_id + 1 - self.config.epoch_start_decay, 0.0)) + lr = np.ones((self.device_count), dtype='float32') * new_lr + return lr + + def prepare_input(self, + batch, + init_hidden=None, + init_cell=None, + epoch_id=0, + with_lr=True, + device_count=1): + x, y = batch + x = x.reshape((-1, self.config.num_steps, 1)) + y = y.reshape((-1, 1)) + + res = {} + res['x'] = x + res['y'] = y + if init_hidden is not None: + res['init_hidden'] = init_hidden + if init_cell is not None: + res['init_cell'] = init_cell + if with_lr: + res['learning_rate'] = self.generate_new_lr(epoch_id, device_count) + return res + + def train_an_epoch(self, epoch_id, batch_times): + train_data_iter = self.reader.get_data_iter(self.config) + + total_loss = 0 + iters = 0 + + init_hidden, init_cell = self.generate_init_data() + ppl = np.zeros(shape=(0)) + for batch_id, batch in enumerate(train_data_iter): + input_data_feed = self.prepare_input( + batch, + init_hidden=init_hidden, + init_cell=init_cell, + 
epoch_id=epoch_id, + with_lr=True, + device_count=self.device_count) + + batch_start_time = time.time() + fetch_outs = self.exe.run(self.train_program, + feed=input_data_feed, + fetch_list=[ + self.loss.name, "learning_rate", + self.last_hidden.name, + self.last_cell.name + ], + use_program_cache=True) + batch_time = time.time() - batch_start_time + batch_times.append(batch_time) + + cost_train = np.array(fetch_outs[0]) + lr = np.array(fetch_outs[1]) + init_hidden = np.array(fetch_outs[2]) + init_cell = np.array(fetch_outs[3]) + + total_loss += cost_train + iters += self.config.num_steps + + batch_ppl = np.exp(total_loss / iters) + ppl = np.append(ppl, batch_ppl) + return ppl + + def train(self, config): + self.config = config + self.prepare_program(config) + total_time = 0.0 + ppl = np.zeros(shape=(0, config.batch_size)) + for epoch_id in range(config.max_epoch): + batch_times = [] + epoch_start_time = time.time() + train_ppl = self.train_an_epoch(epoch_id, batch_times) + epoch_time = time.time() - epoch_start_time + total_time += epoch_time + ppl = np.append(ppl, train_ppl) + return ppl + + def compare_padding_static_mode(self): + ''' + Test that train ppl of padding mode is same to that of static mode + ''' + config = RnnConfig('test', 'padding') + with fluid.scope_guard(fluid.Scope()): + padding_rnn_ppl = self.train(config) + config = RnnConfig('test', 'static') + with fluid.scope_guard(fluid.Scope()): + static_rnn_ppl = self.train(config) + self.assertTrue( + np.isclose( + padding_rnn_ppl, static_rnn_ppl, rtol=0.001).all()) + + def test_padding_mode_no_eager_deletion(self): + ''' + Test that train ppl of padding mode is same to that of static mode without eager deletion + ''' + fluid.core._set_eager_deletion_mode(-1.0, 1.0, True) + self.compare_padding_static_mode() + + def test_padding_mode_eager_deletion(self): + ''' + Test that train ppl of padding mode is same to that of static mode under eager deletion + ''' + fluid.core._set_eager_deletion_mode(0.0, 1.0, True) + self.compare_padding_static_mode() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py new file mode 100644 index 0000000000..8cd7371159 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py @@ -0,0 +1,683 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
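An aside before the second test file: its gradient checks compare the analytic gradients produced by append_backward against central-difference estimates (see get_numerical_gradient below, which perturbs each input element by +/- delta). A standalone sketch of the same scheme for any scalar-valued f (illustrative only):

import numpy as np

def numerical_grad(f, x, delta=0.005):
    # Central-difference estimate of df(x)/dx for scalar-valued f and
    # ndarray x, mirroring get_numerical_gradient in the tests below.
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + delta
        y_pos = f(x)
        x[idx] = orig - delta
        y_neg = f(x)
        x[idx] = orig  # restore
        grad[idx] = (y_pos - y_neg) / (2.0 * delta)
        it.iternext()
    return grad

# e.g. for f = lambda a: float((a ** 2).sum()), numerical_grad(f, x) ~= 2 * x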
+ +from __future__ import print_function + +import numpy as np +import paddle.fluid as fluid +import paddle.fluid.compiler as compiler +import paddle.fluid.core as core +import paddle.fluid.layers as layers +import unittest + +from paddle.fluid.framework import Program, grad_var_name +from paddle.fluid.executor import Executor +from paddle.fluid.backward import append_backward + +fluid.core._set_eager_deletion_mode(0.0, 1.0, True) + + +class PyRNNBase(object): + def __init__(self, input_shape, output_shape): + self.x = np.ones(shape=input_shape).astype("float32") + self.y = np.zeros(shape=output_shape).astype("float32") + + def step(self, step_id, x): + raise NotImplementedError + + def forward(self): + for step_id in range(self.x.shape[0]): + self.step(step_id, self.x[step_id]) + return np.array([np.mean(self.y)]) + + def segment_inputs(self): + return [self.x[i] for i in range(self.x.shape[0])] + + +class PySimpleRNN1(PyRNNBase): + def __init__(self, input_shape, output_shape): + super(PySimpleRNN1, self).__init__(input_shape, output_shape) + + seq_len, batch_size, input_dim = input_shape + self.h_boot = np.random.normal(size=(batch_size, + input_dim)).astype("float32") + + self.scale = 1.0 / 2.0 + men_dim = (seq_len, batch_size, input_dim) + self.mems = np.zeros(shape=men_dim).astype("float32") + + def step(self, step_id, x): + if step_id == 0: + pre_mem = self.h_boot + else: + pre_mem = self.mems[step_id - 1] + self.mems[step_id] = (pre_mem + x) * self.scale + self.y[step_id] = self.mems[step_id] + + +class PySimpleRNN2(PyRNNBase): + def __init__(self, input_shape, output_shape): + super(PySimpleRNN2, self).__init__(input_shape, output_shape) + + seq_len, batch_size, input_dim = input_shape + self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32") + self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32") + self.h_boot = np.ones(shape=(batch_size, input_dim)).astype("float32") + + men_dim = (seq_len, batch_size, input_dim) + self.mems = np.zeros(shape=men_dim).astype("float32") + + def step(self, step_id, x): + if step_id > 0: + pre_mem = self.mems[step_id - 1] + else: + pre_mem = self.h_boot + xW = np.matmul(x, self.W).astype("float32") + hU = np.matmul(pre_mem, self.U).astype("float32") + + def py_sigmoid(x): + return 1. / (1. 
+ np.exp(-x)) + + self.mems[step_id] = py_sigmoid(xW + hU) + self.y[step_id] = self.mems[step_id] + + +def create_tensor(np_data, place): + tensor = core.LoDTensor() + tensor.set(np_data, place) + return tensor + + +class EagerDeletionRecurrentOpTest1(unittest.TestCase): + ''' + Test RNNOp + equation: + h_t = ( x_t + h_{t-1} ) / scale + vars: + - x + memories: + - h + outputs: + - h + ''' + + input_dim = 2 + batch_size = 1 + sent_len = 1 + + def setup_program(self): + self.main_program = Program() + self.startup_program = Program() + self.place = core.CPUPlace() + + def setUp(self): + self.setup_program() + self.data_field = {"x", "h_boot"} + + self.input_shape = (self.sent_len, self.batch_size, self.input_dim) + self.output_shape = (self.sent_len, self.batch_size, self.input_dim) + self.py_rnn = PySimpleRNN1(self.input_shape, self.output_shape) + + with fluid.program_guard(self.main_program, self.startup_program): + self.output = layers.mean(self.create_rnn_op()) + + def create_rnn_op(self): + x = layers.data( + shape=[self.sent_len, self.batch_size, self.input_dim], + dtype='float32', + name='x', + append_batch_size=False) + x.stop_gradient = False + h_boot = layers.data( + shape=[self.input_dim], dtype='float32', name='h_boot') + h_boot.stop_gradient = False + + rnn = layers.StaticRNN() + with rnn.step(): + h_pre = rnn.memory(init=h_boot) + x_t = rnn.step_input(x) + + h = layers.scale( + x=layers.elementwise_add( + x=h_pre, y=x_t), + scale=self.py_rnn.scale) + + rnn.update_memory(h_pre, h) + rnn.output(h) + + return rnn() + + def forward(self): + self.feed_map = { + x: create_tensor(getattr(self.py_rnn, x), self.place) + for x in self.data_field + } + exe = Executor(self.place) + out = exe.run(self.main_program, + feed=self.feed_map, + fetch_list=[self.output]) + + return out[0] + + def backward(self): + self.feed_map = { + x: create_tensor(getattr(self.py_rnn, x), self.place) + for x in self.data_field + } + fetch_list = [ + self.main_program.global_block().var(grad_var_name(x)) + for x in self.data_field + ] + + exe = Executor(self.place) + return exe.run(self.main_program, + feed=self.feed_map, + fetch_list=fetch_list, + return_numpy=False) + + def test_backward(self, rtol=0.1): + self.check_forward() + + with fluid.program_guard(self.main_program, self.startup_program): + append_backward(self.output) + + ana_grad = [np.array(x) for x in self.backward()] + + num_grad = self.get_numerical_gradient() + for idx, name in enumerate(self.data_field): + self.assertEqual(num_grad[idx].shape, ana_grad[idx].shape) + self.assertTrue( + np.isclose( + num_grad[idx], ana_grad[idx], rtol=rtol).all(), + "num_grad (" + name + ") has diff at " + str(self.place) + + "\nExpect " + str(num_grad[idx]) + "\n" + "But Got" + + str(ana_grad[idx]) + " in class " + self.__class__.__name__) + + def check_forward(self): + pd_output = self.forward() + py_output = self.py_rnn.forward() + self.assertEqual(pd_output.shape, py_output.shape) + self.assertTrue(np.isclose(pd_output, py_output, rtol=0.1).all()) + + def get_numerical_gradient(self, delta=0.005): + dloss_dout = 1.0 + feed_list = [getattr(self.py_rnn, x) for x in self.data_field] + grad_list = [np.zeros_like(x) for x in feed_list] + for feed, grad in zip(feed_list, grad_list): + for f, g in np.nditer([feed, grad], op_flags=['readwrite']): + o = float(f) + f[...] = o + delta + y_pos = self.forward() + + f[...] = o - delta + y_neg = self.forward() + + f[...] = o + dout_dfeed = (y_pos - y_neg) / (delta * 2) + g[...] 
= dout_dfeed[0] + + return grad_list + + +class EagerDeletionRecurrentOpTest2(EagerDeletionRecurrentOpTest1): + ''' + Test RNNOp + equation: + h_t = \sigma (W x_t + U h_{t-1}) + weights: + - W + - U + vars: + - x + memories: + - h + outputs: + - h + ''' + + input_dim = 2 + batch_size = 10 + sent_len = 2 + + def setUp(self): + self.setup_program() + + self.data_field = {"x", "h_boot", "W", "U"} + + self.input_shape = (self.sent_len, self.batch_size, self.input_dim) + self.output_shape = (self.sent_len, self.batch_size, self.input_dim) + self.py_rnn = PySimpleRNN2(self.input_shape, self.output_shape) + + with fluid.program_guard(self.main_program, self.startup_program): + self.output = layers.mean(self.create_rnn_op()) + + def create_rnn_op(self): + x = layers.data( + shape=[self.sent_len, self.batch_size, self.input_dim], + dtype='float32', + name='x', + append_batch_size=False) + x.stop_gradient = False + h_boot = layers.data( + shape=[self.input_dim], dtype='float32', name='h_boot') + h_boot.stop_gradient = False + + rnn = layers.StaticRNN() + with rnn.step(): + h_pre = rnn.memory(init=h_boot) + x_t = rnn.step_input(x) + + temp_l = layers.fc(input=x_t, + size=self.input_dim, + param_attr='W', + bias_attr=False) + temp_r = layers.fc(input=h_pre, + size=self.input_dim, + param_attr='U', + bias_attr=False) + + h = layers.sigmoid(x=layers.elementwise_add(x=temp_l, y=temp_r)) + + rnn.update_memory(h_pre, h) + rnn.output(h) + + return rnn() + + def test_backward(self): + super(EagerDeletionRecurrentOpTest2, self).test_backward(rtol=0.2) + + +class EagerDeletionRecurrentOpMultipleMemoryTest(EagerDeletionRecurrentOpTest1): + ''' + Test RNNOp with two memories + equation: + h_1 = h_pre_1 + h_2 = h_pre_2 + y = h_1 + h_2 + vars: + - x + memories: + - h_1, h_2 + outputs: + - y + ''' + + class PySimpleRNN3(PyRNNBase): + def __init__(self, input_shape, output_shape): + super(EagerDeletionRecurrentOpMultipleMemoryTest.PySimpleRNN3, + self).__init__(input_shape, output_shape) + + seq_len, batch_size, input_dim = input_shape + self.h_boot1 = np.random.normal(size=(batch_size, + input_dim)).astype("float32") + self.h_boot2 = np.random.normal(size=(batch_size, + input_dim)).astype("float32") + + men_dim = (seq_len, batch_size, input_dim) + self.mems1 = np.zeros(shape=men_dim).astype("float32") + self.mems2 = np.zeros(shape=men_dim).astype("float32") + + def step(self, step_id, x): + if step_id == 0: + pre_mem1 = self.h_boot1 + pre_mem2 = self.h_boot2 + else: + pre_mem1 = self.mems1[step_id - 1] + pre_mem2 = self.mems2[step_id - 1] + self.mems1[step_id] = pre_mem1 + self.mems2[step_id] = pre_mem2 + self.y[step_id] = self.mems1[step_id] + self.mems2[step_id] + x + + input_dim = 1 + batch_size = 1 + sent_len = 2 + + def setUp(self): + self.setup_program() + + self.data_field = {"x", "h_boot1", "h_boot2"} + + self.input_shape = (self.sent_len, self.batch_size, self.input_dim) + self.output_shape = (self.sent_len, self.batch_size, self.input_dim) + self.py_rnn = EagerDeletionRecurrentOpMultipleMemoryTest.PySimpleRNN3( + self.input_shape, self.output_shape) + + with fluid.program_guard(self.main_program, self.startup_program): + self.output = layers.mean(self.create_rnn_op()) + + def create_rnn_op(self): + x = layers.data( + shape=[self.sent_len, self.batch_size, self.input_dim], + dtype='float32', + name='x', + append_batch_size=False) + x.stop_gradient = False + h_boot1 = layers.data( + shape=[self.batch_size, self.input_dim], + dtype='float32', + name='h_boot1', + append_batch_size=False) + 
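+        # stop_gradient=False keeps the boot states on the gradient path,
+        # so the backward pass emits gradients for h_boot1/h_boot2.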
h_boot1.stop_gradient = False + h_boot2 = layers.data( + shape=[self.batch_size, self.input_dim], + dtype='float32', + name='h_boot2', + append_batch_size=False) + h_boot2.stop_gradient = False + + rnn = layers.StaticRNN() + with rnn.step(): + h_pre1 = rnn.memory(init=h_boot1) + h_pre2 = rnn.memory(init=h_boot2) + x_t = rnn.step_input(x) + + mem1 = layers.scale(x=h_pre1, scale=1.0) + mem2 = layers.scale(x=h_pre2, scale=1.0) + out = layers.sums(input=[mem1, x_t, mem2]) + + rnn.update_memory(h_pre1, mem1) + rnn.update_memory(h_pre2, mem2) + rnn.output(out) + + return rnn() + + +class EagerDeletionRecurrentOpNoMemBootTest(EagerDeletionRecurrentOpTest1): + ''' + Test RNNOp without memory boot + equation: + mem = x + mem_pre + y = mem + vars: + - x + memories: + - mem + outputs: + - y + ''' + + class PySimpleRNN4(PyRNNBase): + def __init__(self, input_shape, output_shape): + super(EagerDeletionRecurrentOpNoMemBootTest.PySimpleRNN4, + self).__init__(input_shape, output_shape) + men_dim = input_shape + self.mems = np.zeros(shape=men_dim).astype("float32") + + def step(self, step_id, x): + if step_id == 0: + pre_mem = np.zeros_like(x) + else: + pre_mem = self.mems[step_id - 1] + self.mems[step_id] = pre_mem + x + self.y[step_id] = self.mems[step_id] + + input_dim = 1 + batch_size = 1 + sent_len = 2 + + def setUp(self): + self.setup_program() + + self.data_field = {"x"} + + self.input_shape = (self.sent_len, self.batch_size, self.input_dim) + self.output_shape = (self.sent_len, self.batch_size, self.input_dim) + self.py_rnn = EagerDeletionRecurrentOpNoMemBootTest.PySimpleRNN4( + self.input_shape, self.output_shape) + + with fluid.program_guard(self.main_program, self.startup_program): + self.output = layers.mean(self.create_rnn_op()) + + def create_rnn_op(self): + x = layers.data( + shape=[self.sent_len, self.batch_size, self.input_dim], + dtype='float32', + name='x', + append_batch_size=False) + x.stop_gradient = False + + rnn = layers.StaticRNN() + with rnn.step(): + mem_pre = rnn.memory(shape=[-1, self.input_dim], batch_ref=x) + x_t = rnn.step_input(x) + mem = layers.elementwise_add(x=mem_pre, y=x_t) + rnn.update_memory(mem_pre, mem) + rnn.output(mem) + + return rnn() + + +class EagerDeletionTwoRecurrentOpsTest(EagerDeletionRecurrentOpTest1): + ''' + Test RNNOp with two recurrent ops + equation: + first_rnn: + mem_inside = x + mem_pre_inside + first_inside_out = mem_inside + second_rnn: + mem = x + reduce_sum(rnn_inside_out) + y = mem + mem_pre + vars: + - x + memories: + - mem_inside + - mem + outputs: + - y + ''' + + class PySimpleRNN5(PyRNNBase): + def __init__(self, input_shape, output_shape): + super(EagerDeletionTwoRecurrentOpsTest.PySimpleRNN5, + self).__init__(input_shape, output_shape) + self.mem_0 = np.zeros(shape=input_shape).astype("float32") + self.mem_1 = np.zeros(shape=input_shape).astype("float32") + self.rnn_0_output = np.zeros(shape=input_shape).astype("float32") + + def step(self, step_id, x): + # First Rnn + for step in range(self.x.shape[0]): + x_t = self.x[step] + pre_mem = np.zeros_like(x_t) if step == 0 else self.mem_0[step - + 1] + self.mem_0[step] = x_t + pre_mem + self.rnn_0_output[step] = self.mem_0[step] + # Second RNN + pre_mem = np.zeros_like(x) if step_id == 0 else self.mem_1[step_id - + 1] + # print(np.sum(self.rnn_0_output)) + self.mem_1[step_id] = x + np.sum(self.rnn_0_output) + self.y[step_id] = self.mem_1[step_id] + pre_mem + + input_dim = 1 + batch_size = 1 + sent_len = 1 + + def setUp(self): + self.setup_program() + + self.data_field = {"x"} + + 
self.input_shape = (self.sent_len, self.batch_size, self.input_dim) + self.output_shape = (self.sent_len, self.batch_size, self.input_dim) + self.py_rnn = EagerDeletionTwoRecurrentOpsTest.PySimpleRNN5( + self.input_shape, self.output_shape) + + with fluid.program_guard(self.main_program, self.startup_program): + self.output = layers.mean(self.create_rnn_op()) + + def create_rnn_op(self): + x = layers.data( + shape=[self.sent_len, self.batch_size, self.input_dim], + dtype='float32', + name='x', + append_batch_size=False) + x.stop_gradient = False + + rnn_0 = layers.StaticRNN() + with rnn_0.step(): + x_t = rnn_0.step_input(x) + mem_pre = rnn_0.memory(shape=[-1, self.input_dim], batch_ref=x) + mem = layers.elementwise_add(x=mem_pre, y=x_t) + rnn_0.update_memory(mem_pre, mem) + rnn_0.output(mem) + + rnn_1 = layers.StaticRNN() + with rnn_1.step(): + mem_pre = rnn_1.memory(shape=[-1, self.input_dim], batch_ref=x) + x_t = rnn_1.step_input(x) + last_rnn_output = rnn_0() + last_rnn_sum = fluid.layers.reduce_sum(last_rnn_output) + mem = layers.elementwise_add(x=x_t, y=last_rnn_sum) + y = layers.elementwise_add(x=mem_pre, y=mem) + rnn_1.update_memory(mem_pre, mem) + rnn_1.output(y) + return rnn_1() + + +class EagerDeletionRecurrentOpParallelExecutorTest( + EagerDeletionRecurrentOpTest1): + ''' + Test RNNOp with ParallelExecutor + equation: + h_t = ( x_t + h_{t-1} ) / scale + vars: + - x + memories: + - h + outputs: + - h + ''' + + def forward(self): + self.feed_map = { + x: create_tensor(getattr(self.py_rnn, x), self.place) + for x in self.data_field + } + + build_strategy = fluid.BuildStrategy() + build_strategy.enable_inplace = True + exec_strategy = fluid.ExecutionStrategy() + parallel_exe = fluid.ParallelExecutor( + use_cuda=False, + main_program=self.main_program, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + out = parallel_exe.run(feed=self.feed_map, fetch_list=[self.output]) + return out[0] + + def backward(self): + self.feed_map = { + x: create_tensor(getattr(self.py_rnn, x), self.place) + for x in self.data_field + } + fetch_list = [ + self.main_program.global_block().var(grad_var_name(x)) + for x in self.data_field + ] + + build_strategy = fluid.BuildStrategy() + build_strategy.enable_inplace = True + exec_strategy = fluid.ExecutionStrategy() + parallel_exe = fluid.ParallelExecutor( + use_cuda=False, + main_program=self.main_program, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + return parallel_exe.run(feed=self.feed_map, + fetch_list=fetch_list, + return_numpy=False) + + +class EagerDeletionFarwardOnlyRnnAndBackwardRnnTest( + EagerDeletionRecurrentOpTest1): + ''' + Test one forward only RNN and one backward RNN in one program + ''' + + def setUp(self): + self.setup_program() + self.data_field = {"x", "h_boot"} + + self.input_shape = (self.sent_len, self.batch_size, self.input_dim) + self.output_shape = (self.sent_len, self.batch_size, self.input_dim) + self.py_rnn = PySimpleRNN1(self.input_shape, self.output_shape) + + with fluid.program_guard(self.main_program, self.startup_program): + x = layers.data( + shape=[self.sent_len, self.batch_size, self.input_dim], + dtype='float32', + name='x', + append_batch_size=False) + x.stop_gradient = False + h_boot = layers.data( + shape=[self.input_dim], dtype='float32', name='h_boot') + h_boot.stop_gradient = False + + forward_only_rnn = layers.StaticRNN() + with forward_only_rnn.step(): + h_pre = forward_only_rnn.memory(init=h_boot) + x_t = forward_only_rnn.step_input(x) + + h = layers.scale( + 
                    x=layers.elementwise_add(
+                        x=h_pre, y=x_t),
+                    scale=self.py_rnn.scale)
+
+                forward_only_rnn.update_memory(h_pre, h)
+                forward_only_rnn.output(h)
+            forward_only_output = forward_only_rnn()
+            forward_only_output.stop_gradient = True
+            self.forward_only_output = layers.mean(forward_only_output)
+
+            rnn = layers.StaticRNN()
+            with rnn.step():
+                h_pre = rnn.memory(init=h_boot)
+                x_t = rnn.step_input(x)
+
+                h = layers.scale(
+                    x=layers.elementwise_add(
+                        x=h_pre, y=x_t),
+                    scale=self.py_rnn.scale)
+
+                rnn.update_memory(h_pre, h)
+                rnn.output(h)
+
+            self.output = layers.mean(rnn())
+
+    def forward_two_rnn(self):
+        self.feed_map = {
+            x: create_tensor(getattr(self.py_rnn, x), self.place)
+            for x in self.data_field
+        }
+        exe = Executor(self.place)
+        out = exe.run(self.main_program,
+                      feed=self.feed_map,
+                      fetch_list=[self.forward_only_output, self.output])
+
+        return out[0], out[1]
+
+    def check_forward(self):
+        forward_only_output, pd_output = self.forward_two_rnn()
+        py_output = self.py_rnn.forward()
+        self.assertEqual(forward_only_output.shape, py_output.shape)
+        self.assertEqual(pd_output.shape, py_output.shape)
+        self.assertTrue(
+            np.isclose(
+                forward_only_output, py_output, rtol=0.1).all())
+        self.assertTrue(np.isclose(pd_output, py_output, rtol=0.1).all())
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
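A closing aside: both new test files steer the executor's garbage collector through fluid.core._set_eager_deletion_mode(...). Judging from the values used in the tests, the three arguments appear to map to FLAGS_eager_delete_tensor_gb (a memory threshold in GB; a negative value disables collection, 0.0 frees each variable as soon as it is dead), FLAGS_memory_fraction_of_eager_deletion, and FLAGS_fast_eager_deletion_mode:

import paddle.fluid as fluid

# As called in the tests above:
fluid.core._set_eager_deletion_mode(-1.0, 1.0, True)  # eager deletion disabled
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)   # fully eager deletion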