From 4c8254e3bf426a044ef51d661193bd9a720dc204 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Wed, 27 Mar 2019 10:53:01 +0000
Subject: [PATCH] revert some loop op revision

test=develop
---
 .../operators/controlflow/CMakeLists.txt      |   2 +-
 .../fluid/operators/controlflow/while_op.cc   |  21 +-
 .../operators/controlflow/while_op_helper.cc  | 291 ++++++++++++++++++
 .../operators/controlflow/while_op_helper.h   |  43 +++
 paddle/fluid/operators/interpolate_op.cc      |   4 +
 paddle/fluid/operators/lstm_op.cc             |   1 +
 paddle/fluid/operators/margin_rank_loss_op.cc |   1 +
 paddle/fluid/operators/mean_op.cc             |   3 +
 paddle/fluid/operators/multiplex_op.cc        |   1 +
 paddle/fluid/operators/recurrent_op.cc        |  52 ++--
 paddle/fluid/operators/scatter_op.cc          |   1 +
 11 files changed, 380 insertions(+), 40 deletions(-)
 create mode 100644 paddle/fluid/operators/controlflow/while_op_helper.cc
 create mode 100644 paddle/fluid/operators/controlflow/while_op_helper.h

diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt
index 4782e9d5f..7aa1c44ea 100644
--- a/paddle/fluid/operators/controlflow/CMakeLists.txt
+++ b/paddle/fluid/operators/controlflow/CMakeLists.txt
@@ -1,5 +1,5 @@
 include(operators)
 register_operators(DEPS naive_executor)
-cc_library(loop_op_helper SRCS loop_op_helper.cc DEPS operator)
+cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
 
 file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc
index a07a732d8..b32192088 100644
--- a/paddle/fluid/operators/controlflow/while_op.cc
+++ b/paddle/fluid/operators/controlflow/while_op.cc
@@ -18,21 +18,28 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/operators/controlflow/loop_op_helper.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 
 namespace paddle {
 namespace operators {
 
-static constexpr char kCondition[] = "Condition";
-static constexpr char kStepScopes[] = "StepScopes";
-static constexpr char kX[] = "X";
-static constexpr char kXGRAD[] = "X@GRAD";
-static constexpr char kOutputs[] = "Out";
-
 using StepScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
 
+namespace {  // NOLINT
+static std::string GetSkipEagerDeletionVarsDebugString(
+    const std::vector<std::string> &vars) {
+  std::string str = "Skip " + std::to_string(vars.size()) +
+                    " var(s) in eager deletion mode: ";
+  for (auto &var : vars) {
+    str.append(var);
+    str.push_back(' ');
+  }
+  return str;
+}
+}  // NOLINT
+
 class WhileOp : public framework::OperatorBase {
  public:
   WhileOp(const std::string &type, const framework::VariableNameMap &inputs,
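For orientation: the GetSkipEagerDeletionVarsDebugString helper above is consumed later inside WhileOp::RunImpl and WhileGradOp::RunImpl (not part of this hunk). A minimal sketch of the intended call pattern, assuming the Executor::Run overload used elsewhere in this patch accepts the skip list as its skip_ref_cnt_vars argument:

    // Sketch only: read the skip list prepared by while_op_helper, log it,
    // then forward it so the listed variables survive eager deletion.
    auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
    executor.Run(*program, &current_scope, block->ID(),
                 false /*create_local_scope*/, true /*create_vars*/, skip_vars);
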
diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc
new file mode 100644
index 000000000..2cbd94a06
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/while_op_helper.cc
@@ -0,0 +1,291 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace operators {
+
+// OpVariant is a wrapper class of OpDesc and OperatorBase pointers
+// so that both can be queried through the same API.
+class OpVariant {
+  struct InputsVisitor
+      : public boost::static_visitor<const framework::VariableNameMap *> {
+    template <typename OpType>
+    const framework::VariableNameMap *operator()(const OpType *op) const {
+      return &(op->Inputs());
+    }
+  };
+
+  struct OutputsVisitor
+      : public boost::static_visitor<const framework::VariableNameMap *> {
+    template <typename OpType>
+    const framework::VariableNameMap *operator()(const OpType *op) const {
+      return &(op->Outputs());
+    }
+  };
+
+  struct AttributeMapVisitor
+      : public boost::static_visitor<const framework::AttributeMap *> {
+    const framework::AttributeMap *operator()(
+        const framework::OpDesc *op) const {
+      return &(op->GetAttrMap());
+    }
+
+    const framework::AttributeMap *operator()(
+        const framework::OperatorBase *op) const {
+      return &(op->Attrs());
+    }
+  };
+
+  struct RawPointerVisitor : public boost::static_visitor<const void *> {
+    template <typename OpType>
+    const void *operator()(const OpType *op) const {
+      return op;
+    }
+  };
+
+ public:
+  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT
+
+  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT
+
+  const framework::VariableNameMap &Inputs() const {
+    return *boost::apply_visitor(InputsVisitor(), op_);
+  }
+
+  const framework::VariableNameMap &Outputs() const {
+    return *boost::apply_visitor(OutputsVisitor(), op_);
+  }
+
+  const framework::AttributeMap &Attrs() const {
+    return *boost::apply_visitor(AttributeMapVisitor(), op_);
+  }
+
+  template <typename AttrType>
+  const AttrType &Attr(const std::string &name) const {
+    auto &attrs = Attrs();
+    auto it = attrs.find(name);
+    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
+    return boost::get<AttrType>(it->second);
+  }
+
+  bool operator==(const OpVariant &other) const {
+    return RawPointer() == other.RawPointer();
+  }
+
+  const void *RawPointer() const {
+    return boost::apply_visitor(RawPointerVisitor(), op_);
+  }
+
+  int which() const { return static_cast<int>(op_.which()); }
+
+  struct Hasher {
+    size_t operator()(const OpVariant &op) const {
+      return reinterpret_cast<size_t>(op.RawPointer());
+    }
+  };
+
+ private:
+  const boost::variant<const framework::OperatorBase *,
+                       const framework::OpDesc *>
+      op_;
+};
+
+static std::string GetDebugString(const std::vector<std::string> &names) {
+  if (names.empty()) return "";
+  std::string ret = names[0];
+  for (size_t i = 1; i < names.size(); ++i) {
+    ret += (" " + names[i]);
+  }
+  return ret;
+}
+
+// Set the skip variables of while_op and while_grad_op.
+// These variables should be skipped when eager deletion is enabled,
+// because:
+// 1. while_grad_op needs some variables defined in while_op.
+// 2. while_grad_op needs variables from the previous time step.
+static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
+  auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
+  VLOG(2) << "Prepare to skip " << attr.size()
+          << " var(s): " << GetDebugString(attr);
+  attrs[kSkipEagerDeletionVars] = std::move(attr);
+}
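The OpVariant class above uses boost::variant plus static visitors so that OpDesc (the compile-time graph description) and OperatorBase (the runtime operator) can be queried through one interface. A self-contained sketch of the same visitor pattern, with hypothetical types A and B standing in for OpDesc and OperatorBase:

    #include <boost/variant.hpp>
    #include <string>

    struct A { std::string Name() const { return "A"; } };
    struct B { std::string Name() const { return "B"; } };

    // The visitor's return type is declared via static_visitor<T>;
    // apply_visitor dispatches on whichever pointer the variant holds.
    struct NameVisitor : public boost::static_visitor<std::string> {
      template <typename T>
      std::string operator()(const T *v) const { return v->Name(); }
    };

    std::string NameOf(const boost::variant<const A *, const B *> &v) {
      return boost::apply_visitor(NameVisitor(), v);
    }
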
+
+// Check whether the forward while_op and while_grad_op match;
+// the program may contain many while_ops.
+static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
+                                           const OpVariant &grad_op) {
+  return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
+         fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
+}
+
+// Test whether a variable is skippable in the forward while_op.
+// A variable is skippable when it is used in while_grad
+// but does not come from grad_block.
+static bool IsSkippableVar(const std::string &name,
+                           framework::BlockDesc *grad_block) {
+  return name != framework::kEmptyVarName && !grad_block->HasVar(name);
+}
+
+static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
+                                            const OpVariant &bwd_op) {
+  auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);
+
+  // Find all skippable variables in the forward while_op
+  std::unordered_set<std::string> forward_skip_vars;
+  for (auto *op_desc : grad_block->AllOps()) {
+    for (auto &in_arg_name : op_desc->InputArgumentNames()) {
+      if (IsSkippableVar(in_arg_name, grad_block)) {
+        forward_skip_vars.insert(in_arg_name);
+      }
+    }
+
+    for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
+      if (IsSkippableVar(out_arg_name, grad_block)) {
+        forward_skip_vars.insert(out_arg_name);
+      }
+    }
+  }
+
+  SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
+                                               forward_skip_vars.end()));
+
+  // Find all skippable variables in while_grad_op.
+  // The skipped variables are those which are used across time steps.
+  auto &fwd_input = fwd_op.Inputs().at(kX);
+  auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
+  PADDLE_ENFORCE_EQ(
+      fwd_input.size(), in_grads.size(),
+      "Backward input gradient number does not match forward input number.");
+
+  std::unordered_set<std::string> backward_skip_vars;
+  for (size_t i = 0; i < in_grads.size(); ++i) {
+    if (in_grads[i] == framework::kEmptyVarName) {
+      continue;
+    }
+    backward_skip_vars.insert(in_grads[i]);
+    backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
+  }
+
+  SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
+                                               backward_skip_vars.end()));
+}
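A note on the matching key used above: framework::GradVarName appends the "@GRAD" suffix, so a forward while_op with inputs X = {"x", "w"} pairs with the while_grad op whose X input repeats {"x", "w"} and whose X@GRAD outputs are {"x@GRAD", "w@GRAD"}. A behavioral sketch of the suffix rule (the real helper lives in the framework; this only mirrors its observable output):

    #include <string>

    // Mirrors framework::GradVarName: gradient variables are named by suffix.
    inline std::string GradVarNameSketch(const std::string &name) {
      return name + "@GRAD";
    }
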
+
+// Find all while_ops and while_grad_ops in the graph or program.
+// The while_grad_op and while_op may be located in different blocks,
+// so we traverse all blocks in the program to find them.
+static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
+                                       std::vector<OpVariant> *while_grad_ops) {
+  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
+
+  if (while_ops->empty()) return;
+
+  const auto *program =
+      while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
+  for (size_t i = 1; i < program->Size(); ++i) {
+    auto &block = program->Block(i);
+    for (size_t j = 0; j < block.OpSize(); ++j) {
+      auto *op = block.Op(j);
+      if (op->Type() == "while") {
+        while_ops->emplace_back(op);
+      } else if (op->Type() == "while_grad") {
+        while_grad_ops->emplace_back(op);
+      }
+    }
+  }
+
+  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
+                    "There are extra while_grad ops in the graph or program");
+}
+
+static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
+    std::vector<OpVariant> *while_ops,
+    std::vector<OpVariant> *while_grad_ops) {
+  FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
+
+  VLOG(2) << "Found while op num: " << while_ops->size()
+          << ", while grad op num: " << while_grad_ops->size();
+
+  if (while_grad_ops->empty()) {
+    return;
+  }
+
+  std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
+      while_ops->begin(), while_ops->end());
+
+  for (auto &bwd_op : *while_grad_ops) {
+    const OpVariant *matched_fwd_op = nullptr;
+    for (auto &fwd_op : while_op_set) {
+      if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
+        PADDLE_ENFORCE(matched_fwd_op == nullptr,
+                       "Found multiple matched while ops");
+        matched_fwd_op = &fwd_op;
+      }
+    }
+    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
+                            "Cannot find matched forward while op.");
+    ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
+    while_op_set.erase(*matched_fwd_op);
+  }
+}
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
+  // If block_id is not 0, return directly.
+  // All while_ops and while_grad_ops in the whole program are processed
+  // when block_id is 0 (i.e., when Executor::Run() is called or the
+  // ParallelExecutor is constructed).
+
+  // Moreover, all while_ops and while_grad_ops must be processed when
+  // block_id is zero. Otherwise, while_op may run first and erase variables
+  // used in while_grad_op before the while_grad_ops have been constructed.
+  if (block_id != 0) return;
+
+  std::vector<OpVariant> fwd_ops, bwd_ops;
+  for (auto &op : all_ops) {
+    if (op->Type() == "while") {
+      fwd_ops.emplace_back(op.get());
+    } else if (op->Type() == "while_grad") {
+      bwd_ops.emplace_back(op.get());
+    }
+  }
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+}
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const std::vector<framework::OperatorBase *> &while_ops,
+    const std::vector<framework::OperatorBase *> &while_grad_ops) {
+  std::vector<OpVariant> fwd_ops, bwd_ops;
+  fwd_ops.reserve(while_ops.size());
+  for (auto *op : while_ops) {
+    fwd_ops.emplace_back(op);
+  }
+
+  bwd_ops.reserve(while_grad_ops.size());
+  for (auto *op : while_grad_ops) {
+    bwd_ops.emplace_back(op);
+  }
+
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/controlflow/while_op_helper.h b/paddle/fluid/operators/controlflow/while_op_helper.h
new file mode 100644
index 000000000..456ba8642
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/while_op_helper.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/variant.h"
+
+namespace paddle {
+namespace operators {
+
+static constexpr char kStepBlock[] = "sub_block";
+static constexpr char kCondition[] = "Condition";
+static constexpr char kStepScopes[] = "StepScopes";
+static constexpr char kX[] = "X";
+static constexpr char kXGRAD[] = "X@GRAD";
+static constexpr char kOutputs[] = "Out";
+static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const std::vector<framework::OperatorBase *> &while_ops,
+    const std::vector<framework::OperatorBase *> &while_grad_ops);
+
+}  // namespace operators
+}  // namespace paddle
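Typical wiring of the two entry points declared above, sketched under the assumption that the caller (the executor or a graph pass) already holds the operator list in the types shown in the signatures; PrepareLoopOps is a hypothetical wrapper name, not part of the patch:

    #include "paddle/fluid/operators/controlflow/while_op_helper.h"

    // Sketch: run the preparation pass before executing block 0. Calls for
    // sub-blocks are no-ops, since the block-0 call covers the whole program.
    void PrepareLoopOps(
        int block_id,
        const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
            &ops) {
      paddle::operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
          block_id, ops);
    }
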
diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc
index cfded65f0..edee8c08d 100644
--- a/paddle/fluid/operators/interpolate_op.cc
+++ b/paddle/fluid/operators/interpolate_op.cc
@@ -10,6 +10,7 @@
    limitations under the License. */
 
 #include "paddle/fluid/operators/interpolate_op.h"
+#include <memory>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
@@ -209,6 +210,9 @@ class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
     std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
     op->SetType(ForwardOp().Type() + "_grad");
     op->SetInput("X", Input("X"));
+    if (ForwardOp().Inputs().count("OutSize") > 0) {
+      op->SetInput("OutSize", Input("OutSize"));
+    }
     op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
     op->SetAttrMap(Attrs());
diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc
index 30c3945cb..52e4e8be2 100644
--- a/paddle/fluid/operators/lstm_op.cc
+++ b/paddle/fluid/operators/lstm_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/lstm_op.h"
+#include <memory>
 #include <string>
 
 namespace paddle {
diff --git a/paddle/fluid/operators/margin_rank_loss_op.cc b/paddle/fluid/operators/margin_rank_loss_op.cc
index b3d9733a9..fca353255 100644
--- a/paddle/fluid/operators/margin_rank_loss_op.cc
+++ b/paddle/fluid/operators/margin_rank_loss_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/margin_rank_loss_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc
index 26d86afed..2b2f84507 100644
--- a/paddle/fluid/operators/mean_op.cc
+++ b/paddle/fluid/operators/mean_op.cc
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/mean_op.h"
+#include <memory>
 #include <string>
+#include <unordered_map>
+
 namespace paddle {
 namespace operators {
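The interpolate_op.cc change guards the optional OutSize input: the generated grad op should reference that slot only when the forward op actually declares it. The same guard, written as a hypothetical standalone helper (ForwardOptionalInput is illustrative, not part of the patch):

    // Copy an optional input slot from the forward op to the grad op only
    // if the forward op declares that slot; otherwise leave it unset.
    static void ForwardOptionalInput(
        const framework::VariableNameMap &fwd_inputs, const std::string &slot,
        const std::vector<std::string> &names, framework::OpDesc *grad_op) {
      if (fwd_inputs.count(slot) > 0) {
        grad_op->SetInput(slot, names);
      }
    }
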
diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc
index b3d0423b7..7cb213e89 100644
--- a/paddle/fluid/operators/multiplex_op.cc
+++ b/paddle/fluid/operators/multiplex_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/multiplex_op.h"
+#include <memory>
 #include <vector>
 
 namespace paddle {
diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc
index 45c87bb08..2898a62dd 100644
--- a/paddle/fluid/operators/recurrent_op.cc
+++ b/paddle/fluid/operators/recurrent_op.cc
@@ -15,24 +15,24 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/controlflow/loop_op_helper.h"
 
 namespace paddle {
 namespace operators {
-
-using recurrent::kInputs;
-using recurrent::kInitialStates;
-using recurrent::kParameters;
-using recurrent::kOutputs;
-using recurrent::kStepScopes;
-using recurrent::kExStates;
-using recurrent::kStates;
-using recurrent::kReverse;
-using recurrent::kIsTrain;
-using recurrent::kInputGrads;
-using recurrent::kOutputGrads;
-using recurrent::kParamGrads;
-using recurrent::kInitStateGrads;
+constexpr char kInputs[] = "inputs";
+constexpr char kInitialStates[] = "initial_states";
+constexpr char kParameters[] = "parameters";
+constexpr char kOutputs[] = "outputs";
+constexpr char kStepScopes[] = "step_scopes";
+constexpr char kExStates[] = "ex_states";
+constexpr char kStates[] = "states";
+constexpr char kStepBlock[] = "sub_block";
+constexpr char kReverse[] = "reverse";
+constexpr char kIsTrain[] = "is_train";
+#define GRAD_SUFFIX "@GRAD"
+constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
+constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
+constexpr char kParamGrads[] = "parameters" GRAD_SUFFIX;
+constexpr char kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
 
 using StepScopeVar = std::vector<framework::Scope *>;
 
@@ -249,9 +249,6 @@ class RecurrentOp : public RecurrentBase {
     framework::Executor executor(place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
 
-    auto &keep_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
-    VLOG(2) << GetSkipEagerDeletionVarsDebugString(keep_vars);
-
     auto *program = block->Program();
 
     for (size_t i = 0; i < seq_len; ++i) {
@@ -286,7 +283,8 @@ class RecurrentOp : public RecurrentBase {
       // Every inputs are linked now, execute!
       executor.Run(*program, &cur_scope, block->ID(),
                    false /*create_local_scope*/, true /*create_vars*/,
-                   keep_vars);
+                   std::vector<std::string>() /*skip_ref_cnt_vars*/,
+                   true /*force_disable_gc*/);
 
       // get device context from pool
       platform::DeviceContextPool &pool =
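A side note on the restored constants above: adjacent string literals concatenate at compile time, so "inputs" GRAD_SUFFIX yields the single literal "inputs@GRAD". A standalone check (not part of the patch):

    #include <cassert>
    #include <cstring>

    #define GRAD_SUFFIX "@GRAD"
    constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;

    int main() {
      // Adjacent string literals are merged into one during compilation.
      assert(std::strcmp(kInputGrads, "inputs@GRAD") == 0);
      return 0;
    }
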
@@ -343,9 +341,6 @@ class RecurrentGradOp : public RecurrentBase {
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
 
     auto *program = block->Program();
-    auto &keep_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
-
-    VLOG(2) << GetSkipEagerDeletionVarsDebugString(keep_vars);
 
     // get device context from pool
     platform::DeviceContextPool &pool =
         platform::DeviceContextPool::Instance();
@@ -406,7 +401,8 @@ class RecurrentGradOp : public RecurrentBase {
       // Run step block with cur_scope
       executor.Run(*program, &cur_scope, block->ID(),
                    false /*create_local_scope*/, true /*create_vars*/,
-                   keep_vars);
+                   std::vector<std::string>() /*skip_ref_cnt_vars*/,
+                   true /*force_disable_gc*/);
 
       VLOG(5) << "executor.Run finished ";
 
@@ -583,10 +579,6 @@ if reverse is True
       o          o          o         o
 )DOC").SetDefault(false);
     AddAttr<bool>(kIsTrain, "").SetDefault(true);
-    AddAttr<std::vector<std::string>>(kSkipEagerDeletionVars,
-                                      "Skip vars that would "
-                                      "be used in backward ops")
-        .SetDefault(std::vector<std::string>());
     AddComment(R"DOC(
 Static Length Recurrent Operator.
 
@@ -622,11 +614,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
             this->OutputGrad(output_param));
       }
     }
-
-    auto attrs = this->Attrs();
-    attrs.insert({kSkipEagerDeletionVars, std::vector<std::string>()});
-    grad->SetAttrMap(attrs);
-
+    grad->SetAttrMap(this->Attrs());
     grad->SetBlockAttr(kStepBlock, grad_block_[0]);
 
     return std::unique_ptr<framework::OpDesc>(grad);
diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc
index 1c2670750..8e0e3bd60 100644
--- a/paddle/fluid/operators/scatter_op.cc
+++ b/paddle/fluid/operators/scatter_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/scatter_op.h"
+#include <memory>
 #include "paddle/fluid/framework/ddim.h"
 
 namespace paddle {
--
GitLab
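For reference, the Executor::Run overload targeted by the reverted call sites appears, from the call sites alone, to have roughly this shape; parameter names past block_id are reconstructed from the inline comments in this patch, so treat them as assumptions:

    // void Executor::Run(const ProgramDesc &program, Scope *scope,
    //                    int block_id, bool create_local_scope,
    //                    bool create_vars,
    //                    const std::vector<std::string> &skip_ref_cnt_vars,
    //                    bool force_disable_gc);

Passing an empty skip list together with force_disable_gc=true switches garbage collection off inside the recurrent step blocks entirely, rather than tracking per-variable skip lists; that is the conservative behavior this revert restores for recurrent_op, while while_op keeps the skip-list mechanism via while_op_helper.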