提交 4c8254e3 编写于 作者: S sneaxiy

revert some loop op revision

test=develop
上级 16f09947
include(operators) include(operators)
register_operators(DEPS naive_executor) register_operators(DEPS naive_executor)
cc_library(loop_op_helper SRCS loop_op_helper.cc DEPS operator) cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
...@@ -18,21 +18,28 @@ ...@@ -18,21 +18,28 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/loop_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static constexpr char kCondition[] = "Condition";
static constexpr char kStepScopes[] = "StepScopes";
static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
using StepScopeVar = std::vector<framework::Scope *>; using StepScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
namespace { // NOLINT
static std::string GetSkipEagerDeletionVarsDebugString(
const std::vector<std::string> &vars) {
std::string str = "Skip " + std::to_string(vars.size()) +
" var(s) in eager deletion mode: ";
for (auto &var : vars) {
str.append(var);
str.push_back(' ');
}
return str;
}
} // NOLINT
class WhileOp : public framework::OperatorBase { class WhileOp : public framework::OperatorBase {
public: public:
WhileOp(const std::string &type, const framework::VariableNameMap &inputs, WhileOp(const std::string &type, const framework::VariableNameMap &inputs,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace operators {
// OpVariant is a wrapper class of OpDesc and OperatorBase
// So that API would be the same.
class OpVariant {
struct InputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Inputs());
}
};
struct OutputsVisitor
: public boost::static_visitor<const framework::VariableNameMap *> {
template <typename OpType>
const framework::VariableNameMap *operator()(const OpType *op) const {
return &(op->Outputs());
}
};
struct AttributeMapVisitor
: public boost::static_visitor<const framework::AttributeMap *> {
const framework::AttributeMap *operator()(
const framework::OpDesc *op) const {
return &(op->GetAttrMap());
}
const framework::AttributeMap *operator()(
const framework::OperatorBase *op) const {
return &(op->Attrs());
}
};
struct RawPointerVisitor : public boost::static_visitor<const void *> {
template <typename OpType>
const void *operator()(const OpType *op) const {
return op;
}
};
public:
OpVariant(const framework::OperatorBase *op) : op_(op) {} // NOLINT
OpVariant(const framework::OpDesc *op) : op_(op) {} // NOLINT
const framework::VariableNameMap &Inputs() const {
return *boost::apply_visitor(InputsVisitor(), op_);
}
const framework::VariableNameMap &Outputs() const {
return *boost::apply_visitor(OutputsVisitor(), op_);
}
const framework::AttributeMap &Attrs() const {
return *boost::apply_visitor(AttributeMapVisitor(), op_);
}
template <typename AttrType>
const AttrType &Attr(const std::string &name) const {
auto &attrs = Attrs();
auto it = attrs.find(name);
PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
return boost::get<AttrType>(it->second);
}
bool operator==(const OpVariant &other) const {
return RawPointer() == other.RawPointer();
}
const void *RawPointer() const {
return boost::apply_visitor(RawPointerVisitor(), op_);
}
int which() const { return static_cast<int>(op_.which()); }
struct Hasher {
size_t operator()(const OpVariant &op) const {
return reinterpret_cast<size_t>(op.RawPointer());
}
};
private:
const boost::variant<const framework::OperatorBase *,
const framework::OpDesc *>
op_;
};
static std::string GetDebugString(const std::vector<std::string> &names) {
if (names.empty()) return "";
std::string ret = names[0];
for (size_t i = 1; i < names.size(); ++i) {
ret += (" " + names[i]);
}
return ret;
}
// Set skip variables of while_op and while_grad_op
// These variables should be skipped when eager deletion enables.
// It is because:
// 1. while_grad_op needs some variables defined in while_op.
// 2. while_grad_op needs variables from the previous time step.
static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
VLOG(2) << "Prepare to skip " << attr.size()
<< " var(s): " << GetDebugString(attr);
attrs[kSkipEagerDeletionVars] = std::move(attr);
}
// Check whether the forward while_op and while_grad_op match
// The program may have many while_ops.
static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
const OpVariant &grad_op) {
return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
}
// Test whether the variable is skippable in forward while_op
// The variable is skippable in while_op when the variable used in while_grad
// is not from grad_block.
static bool IsSkippableVar(const std::string &name,
framework::BlockDesc *grad_block) {
return name != framework::kEmptyVarName && !grad_block->HasVar(name);
}
static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
const OpVariant &bwd_op) {
auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);
// Find all skippable variables in forward while_op
std::unordered_set<std::string> forward_skip_vars;
for (auto *op_desc : grad_block->AllOps()) {
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
if (IsSkippableVar(in_arg_name, grad_block)) {
forward_skip_vars.insert(in_arg_name);
}
}
for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
if (IsSkippableVar(out_arg_name, grad_block)) {
forward_skip_vars.insert(out_arg_name);
}
}
}
SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
forward_skip_vars.end()));
// Find all skippable variables in while_grad_op
// The skipped variables are those which would be used across time steps.
auto &fwd_input = fwd_op.Inputs().at(kX);
auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
PADDLE_ENFORCE_EQ(
fwd_input.size(), in_grads.size(),
"Backward input gradient number does not match forward input number.");
std::unordered_set<std::string> backward_skip_vars;
for (size_t i = 0; i < in_grads.size(); ++i) {
if (in_grads[i] == framework::kEmptyVarName) {
continue;
}
backward_skip_vars.insert(in_grads[i]);
backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
}
SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
backward_skip_vars.end()));
}
// Find all while_ops and while_grad_ops in the graph or program
// The while_grad_op and while_op may located in different blocks
// So we should traverse all blocks in the program and find them out.
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
std::vector<OpVariant> *while_grad_ops) {
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
if (while_ops->empty()) return;
const auto *program =
while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
for (size_t i = 1; i < program->Size(); ++i) {
auto &block = program->Block(i);
for (size_t j = 0; j < block.OpSize(); ++j) {
auto *op = block.Op(j);
if (op->Type() == "while") {
while_ops->emplace_back(op);
} else if (op->Type() == "while_grad") {
while_grad_ops->emplace_back(op);
}
}
}
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
"There are extra while_grad ops in the graph or program");
}
static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
VLOG(2) << "Found while op num: " << while_ops->size()
<< ", while grad op num: " << while_grad_ops->size();
if (while_grad_ops->empty()) {
return;
}
std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
while_ops->begin(), while_ops->end());
for (auto &bwd_op : *while_grad_ops) {
const OpVariant *matched_fwd_op = nullptr;
for (auto &fwd_op : while_op_set) {
if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
PADDLE_ENFORCE(matched_fwd_op == nullptr,
"Found multiple matched while ops");
matched_fwd_op = &fwd_op;
}
}
PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
"Cannot find matched forward while op.");
ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
while_op_set.erase(*matched_fwd_op);
}
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
// If block_id is not 0, returns
// This is because all while_ops and while_grad_ops in the whole program
// would be processed when block_id is 0 (i.e. when Executor::Run() or
// ParallelExecutor constructs).
// What's more, all while_ops and while_grad_ops must be processed when
// block_id is zero. If not, while_op may run first and erase variables
// used in while_grad_op, and in this moment, while_grad_ops may be not
// constructed yet.
if (block_id != 0) return;
std::vector<OpVariant> fwd_ops, bwd_ops;
for (auto &op : all_ops) {
if (op->Type() == "while") {
fwd_ops.emplace_back(op.get());
} else if (op->Type() == "while_grad") {
bwd_ops.emplace_back(op.get());
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops) {
std::vector<OpVariant> fwd_ops, bwd_ops;
fwd_ops.reserve(while_ops.size());
for (auto *op : while_ops) {
fwd_ops.emplace_back(op);
}
bwd_ops.reserve(while_grad_ops.size());
for (auto *op : while_grad_ops) {
bwd_ops.emplace_back(op);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace operators {
static constexpr char kStepBlock[] = "sub_block";
static constexpr char kCondition[] = "Condition";
static constexpr char kStepScopes[] = "StepScopes";
static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
int block_id,
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const std::vector<framework::OperatorBase *> &while_ops,
const std::vector<framework::OperatorBase *> &while_grad_ops);
} // namespace operators
} // namespace paddle
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/interpolate_op.h" #include "paddle/fluid/operators/interpolate_op.h"
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -209,6 +210,9 @@ class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker { ...@@ -209,6 +210,9 @@ class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc()); std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType(ForwardOp().Type() + "_grad"); op->SetType(ForwardOp().Type() + "_grad");
op->SetInput("X", Input("X")); op->SetInput("X", Input("X"));
if (ForwardOp().Inputs().count("OutSize") > 0) {
op->SetInput("OutSize", Input("OutSize"));
}
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs()); op->SetAttrMap(Attrs());
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/lstm_op.h" #include "paddle/fluid/operators/lstm_op.h"
#include <memory>
#include <string> #include <string>
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/margin_rank_loss_op.h" #include "paddle/fluid/operators/margin_rank_loss_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/mean_op.h" #include "paddle/fluid/operators/mean_op.h"
#include <memory>
#include <string> #include <string>
#include <unordered_map>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/multiplex_op.h" #include "paddle/fluid/operators/multiplex_op.h"
#include <memory>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
......
...@@ -15,24 +15,24 @@ limitations under the License. */ ...@@ -15,24 +15,24 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/controlflow/loop_op_helper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
constexpr char kInputs[] = "inputs";
using recurrent::kInputs; constexpr char kInitialStates[] = "initial_states";
using recurrent::kInitialStates; constexpr char kParameters[] = "parameters";
using recurrent::kParameters; constexpr char kOutputs[] = "outputs";
using recurrent::kOutputs; constexpr char kStepScopes[] = "step_scopes";
using recurrent::kStepScopes; constexpr char kExStates[] = "ex_states";
using recurrent::kExStates; constexpr char kStates[] = "states";
using recurrent::kStates; constexpr char kStepBlock[] = "sub_block";
using recurrent::kReverse; constexpr char kReverse[] = "reverse";
using recurrent::kIsTrain; constexpr char kIsTrain[] = "is_train";
using recurrent::kInputGrads; #define GRAD_SUFFIX "@GRAD"
using recurrent::kOutputGrads; constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
using recurrent::kParamGrads; constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
using recurrent::kInitStateGrads; constexpr char kParamGrads[] = "parameters" GRAD_SUFFIX;
constexpr char kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
using StepScopeVar = std::vector<framework::Scope *>; using StepScopeVar = std::vector<framework::Scope *>;
...@@ -249,9 +249,6 @@ class RecurrentOp : public RecurrentBase { ...@@ -249,9 +249,6 @@ class RecurrentOp : public RecurrentBase {
framework::Executor executor(place); framework::Executor executor(place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto &keep_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(keep_vars);
auto *program = block->Program(); auto *program = block->Program();
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
...@@ -286,7 +283,8 @@ class RecurrentOp : public RecurrentBase { ...@@ -286,7 +283,8 @@ class RecurrentOp : public RecurrentBase {
// Every inputs are linked now, execute! // Every inputs are linked now, execute!
executor.Run(*program, &cur_scope, block->ID(), executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/, true /*create_vars*/, false /*create_local_scope*/, true /*create_vars*/,
keep_vars); std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool &pool =
...@@ -343,9 +341,6 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -343,9 +341,6 @@ class RecurrentGradOp : public RecurrentBase {
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto &keep_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(keep_vars);
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
...@@ -406,7 +401,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -406,7 +401,8 @@ class RecurrentGradOp : public RecurrentBase {
// Run step block with cur_scope // Run step block with cur_scope
executor.Run(*program, &cur_scope, block->ID(), executor.Run(*program, &cur_scope, block->ID(),
false /*create_local_scope*/, true /*create_vars*/, false /*create_local_scope*/, true /*create_vars*/,
keep_vars); std::vector<std::string>() /*skip_ref_cnt_vars*/,
true /*force_disable_gc*/);
VLOG(5) << "executor.Run finished "; VLOG(5) << "executor.Run finished ";
...@@ -583,10 +579,6 @@ if reverse is True ...@@ -583,10 +579,6 @@ if reverse is True
o o o o o o o o
)DOC").SetDefault(false); )DOC").SetDefault(false);
AddAttr<bool>(kIsTrain, "").SetDefault(true); AddAttr<bool>(kIsTrain, "").SetDefault(true);
AddAttr<std::vector<std::string>>(kSkipEagerDeletionVars,
"Skip vars that would "
"be used in backward ops")
.SetDefault(std::vector<std::string>());
AddComment(R"DOC( AddComment(R"DOC(
Static Length Recurrent Operator. Static Length Recurrent Operator.
...@@ -622,11 +614,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -622,11 +614,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
this->OutputGrad(output_param)); this->OutputGrad(output_param));
} }
} }
grad->SetAttrMap(this->Attrs());
auto attrs = this->Attrs();
attrs.insert({kSkipEagerDeletionVars, std::vector<std::string>()});
grad->SetAttrMap(attrs);
grad->SetBlockAttr(kStepBlock, grad_block_[0]); grad->SetBlockAttr(kStepBlock, grad_block_[0]);
return std::unique_ptr<framework::OpDesc>(grad); return std::unique_ptr<framework::OpDesc>(grad);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/scatter_op.h" #include "paddle/fluid/operators/scatter_op.h"
#include <memory>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
namespace paddle { namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册