Unverified commit a802da65, authored by Zeng Jinle, committed by GitHub

Feature/mem opt pass refactor (#18735)

* first version memory optimize pass, test=develop

* remove move_tensor_sharing_pass, test=develop

* refine code comments, add unittests, test=develop

* turn off memory_optimize by default, test=develop

* follow huihuang's comments, test=develop

* follow chengduoZH's comments, test=develop

* fix grammar error, add const qualifier, fix pass_test exception message, test=develop

* follow chengduoZH's comments 2nd, test=develop
Parent c5f47c21
@@ -3,7 +3,10 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
+cc_library(share_tensor_buffer_functor SRCS share_tensor_buffer_functor.cc DEPS framework_proto scope place operator op_registry)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope computation_op_handle share_tensor_buffer_functor)
 cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
@@ -59,12 +62,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
 cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
-cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope)
-set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass buffer_shared_inplace_op_pass)
-if (WITH_GPU)
-list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
-endif()
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass buffer_shared_inplace_op_pass buffer_shared_cross_op_memory_reuse_pass)
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
 #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
......
@@ -108,11 +108,6 @@ struct BuildStrategy {
   // FLAGS_use_mkldnn=false
   std::unordered_set<std::string> mkldnn_enabled_op_types_;
-  // FIXME(liuwei1031) disable memory_optimzie and enable_inplace in 1.4
-  // to open them by default, we need to solve the fetch variable issue
-  // TODO(liuwei1031): memory_optimize depends on kStaleProgramOpDescs,
-  // it is not appropriate, because kStaleProgramOpDescs will be removed in the
-  // near future.
   bool memory_optimize_{false};
   // Turn on inplace by default.
......
@@ -108,6 +108,8 @@ class OpHandleBase {
   ir::Node *Node() { return node_; }
+  const ir::Node *Node() const { return node_; }
   void SetLocalExecScopes(
       const std::unordered_map<Scope *, Scope *> &scope_map);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace details {
// TODO(zjl): support SelectedRows
static inline const Tensor &GetTensorFromVar(const Variable *var) {
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>();
} else {
PADDLE_THROW("Variable must be type of LoDTensor");
}
}
static inline Tensor *GetMutableTensorFromVar(Variable *var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable must be type of LoDTensor");
}
}
ShareTensorBufferFunctor::ShareTensorBufferFunctor(
Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names)
: scope_(scope),
scope_idx_(scope_idx),
op_type_(op_type),
in_var_infos_(in_var_infos),
out_var_names_(out_var_names) {
PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size());
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
AddReuseVarPair(in_var_infos_[i], out_var_names_[i]);
}
}
std::unordered_map<std::string, std::string>
ShareTensorBufferFunctor::ReusedVars() const {
std::unordered_map<std::string, std::string> result;
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
result.insert({in_var_infos_[i]->Name(), out_var_names_[i]});
}
return result;
}
void ShareTensorBufferFunctor::AddReuseVarPair(
const ir::MemOptVarInfo *in_var_info, const std::string &out_var_name) {
PADDLE_ENFORCE_NOT_NULL(in_var_info, "in_var_info cannot be nullptr");
PADDLE_ENFORCE_NE(in_var_info->Name(), out_var_name,
"in/out cannot have same name: %s", out_var_name);
in_var_infos_.emplace_back(in_var_info);
out_var_names_.emplace_back(out_var_name);
}
void ShareTensorBufferFunctor::CallOnce() {
PADDLE_ENFORCE(in_out_vars_.empty(), "in_out_vars_ must be initialized here");
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
auto *in_var = exec_scope_->FindVar(in_var_infos_[i]->Name());
auto *out_var = exec_scope_->FindVar(out_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(in_var);
PADDLE_ENFORCE_NOT_NULL(out_var);
PADDLE_ENFORCE_NE(in_var, out_var);
in_out_vars_.emplace_back(in_var, out_var);
}
}
void ShareTensorBufferFunctor::operator()(Scope *exec_scope) {
if (!exec_scope_) {
PADDLE_ENFORCE_NOT_NULL(exec_scope);
exec_scope_ = exec_scope;
CallOnce();
} else {
PADDLE_ENFORCE(exec_scope_ == exec_scope, "Scope must be the same");
}
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
const auto &in_tensor = GetTensorFromVar(in_out_vars_[i].first);
auto *out_tensor = GetMutableTensorFromVar(in_out_vars_[i].second);
auto *in_var_info = in_var_infos_[i];
if (UNLIKELY(in_var_info->IsSkipped())) {
// If in_var is inplaced in the previous batch and we want to fetch
// in_var in the current batch, we have to reset memory of out_var
// to avoid wrong calculation result.
if (in_tensor.Holder() == out_tensor->Holder()) {
VLOG(1) << "Clear " << out_var_names_[i]
<< " because you may want to fetch an inplaced variable "
<< in_var_info->Name()
<< " in previous batch: " << in_var_info->Name() << " -> "
<< out_var_names_[i];
out_tensor->clear();
}
} else {
out_tensor->ShareBufferWith(in_tensor);
VLOG(2) << "Share tensor buffer when running " << op_type_ << " : "
<< in_var_info->Name() << " -> " << out_var_names_[i];
}
}
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace framework {
namespace details {
// NOTE(paddle-dev): ShareTensorBufferFunctor is responsible for
// performing memory reuse in run-time. ShareTensorBufferOpHandle
// is only a wrapper of ShareTensorBufferFunctor.
// Once we find that the run-time memory reuse strategy is time-consuming in
// scheduling, we may need a pass to move ShareTensorBufferFunctor into
// each ComputationOpHandle. ShareTensorBufferFunctor is kept as a separate
// class to allow for that possible change.
class ShareTensorBufferFunctor {
public:
ShareTensorBufferFunctor(
Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names);
void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info,
const std::string &out_var_name);
void operator()(Scope *exec_scope);
std::unordered_map<std::string, std::string> ReusedVars() const;
size_t GetScopeIdx() const { return scope_idx_; }
Scope *GetScope() { return scope_; }
private:
void CallOnce();
private:
Scope *scope_;
Scope *exec_scope_{nullptr};
size_t scope_idx_;
std::string op_type_;
std::vector<const ir::MemOptVarInfo *> in_var_infos_;
std::vector<std::string> out_var_names_;
std::vector<std::pair<const Variable *, Variable *>> in_out_vars_;
};
} // namespace details
} // namespace framework
} // namespace paddle
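The sketch below is an illustration only (it is not part of this commit) of how a caller is expected to drive ShareTensorBufferFunctor, based solely on the interface declared above. The op type "elementwise_add", the variable names, and the way the MemOptVarInfo pointer is obtained are assumptions made for the example; in the real passes the pointers come from the kMemOptVarInfoMapList graph attribute.
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
namespace paddle {
namespace framework {
namespace details {
// Hypothetical helper: pair one reusable input with one output of an op and
// perform the share right before the op runs in `exec_scope`.
void ShareOneBufferForDemo(Scope *scope, Scope *exec_scope, size_t scope_idx,
                           const ir::MemOptVarInfo *in_info,
                           const std::string &out_name) {
  // "elementwise_add" is only a placeholder op type used for logging.
  ShareTensorBufferFunctor share_functor(scope, scope_idx, "elementwise_add",
                                         {in_info}, {out_name});
  // The first call binds exec_scope and resolves the Variable pointers; then
  // it either shares the input buffer with the output, or clears the output
  // if the input is marked as skipped (e.g. it will be fetched).
  share_functor(exec_scope);
}
}  // namespace details
}  // namespace framework
}  // namespace paddle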
@@ -25,55 +25,42 @@ namespace paddle {
 namespace framework {
 namespace details {
-// TODO(zjl): support SelectedRows
-static inline const Tensor &GetTensorFromVar(const Variable *var) {
-  if (var->IsType<LoDTensor>()) {
-    return var->Get<LoDTensor>();
-  } else {
-    PADDLE_THROW("Variable must be type of LoDTensor");
-  }
-}
-
-static inline Tensor *GetMutableTensorFromVar(Variable *var) {
-  if (var->IsType<LoDTensor>()) {
-    return var->GetMutable<LoDTensor>();
-  } else {
-    PADDLE_THROW("Variable must be type of LoDTensor");
-  }
-}
+ComputationOpHandle *GetUniquePendingComputationOpHandle(
+    ShareTensorBufferOpHandle *share_tensor_op) {
+  ComputationOpHandle *result_op = nullptr;
+  for (ir::Node *out_var : share_tensor_op->Node()->outputs) {
+    for (ir::Node *pending_op : out_var->outputs) {
+      auto &op = pending_op->Wrapper<OpHandleBase>();
+      auto *compute_op = dynamic_cast<ComputationOpHandle *>(&op);
+      PADDLE_ENFORCE_NOT_NULL(compute_op);
+      if (result_op == nullptr) {
+        result_op = compute_op;
+      } else {
+        PADDLE_ENFORCE_EQ(result_op, compute_op);
+      }
+    }
+  }
+  PADDLE_ENFORCE_NOT_NULL(result_op);
+  return result_op;
+}
 ShareTensorBufferOpHandle::ShareTensorBufferOpHandle(
     ir::Node *node, Scope *scope, size_t scope_idx, const std::string &op_type,
-    const std::vector<ir::MemOptVarInfo *> &in_var_infos,
+    const std::vector<const ir::MemOptVarInfo *> &in_var_infos,
     const std::vector<std::string> &out_var_names)
     : OpHandleBase(node),
-      scope_(scope),
-      scope_idx_(scope_idx),
-      op_type_(op_type),
-      in_var_infos_(in_var_infos),
-      out_var_names_(out_var_names) {
-  PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size());
-  for (size_t i = 0; i < in_var_infos_.size(); ++i) {
-    Add(in_var_infos_[i], out_var_names_[i]);
-  }
-}
+      functor_(scope, scope_idx, op_type, in_var_infos, out_var_names) {}
-std::unordered_set<std::string> ShareTensorBufferOpHandle::ReusedVarSet()
-    const {
-  std::unordered_set<std::string> result;
-  for (auto &in_var_info : in_var_infos_) {
-    result.insert(in_var_info->Name());
-  }
-  return result;
-}
+std::unordered_map<std::string, std::string>
+ShareTensorBufferOpHandle::ReusedVars() const {
+  return functor_.ReusedVars();
+}
-void ShareTensorBufferOpHandle::Add(ir::MemOptVarInfo *in_var_info,
-                                    const std::string &out_var_name) {
-  PADDLE_ENFORCE_NOT_NULL(in_var_info, "in_var_info cannot be nullptr");
-  PADDLE_ENFORCE_NE(in_var_info->Name(), out_var_name,
-                    "in/out cannot have same name: %s", out_var_name);
-  in_var_infos_.emplace_back(in_var_info);
-  out_var_names_.emplace_back(out_var_name);
-}
+void ShareTensorBufferOpHandle::AddReuseVarPair(
+    const ir::MemOptVarInfo *in_var_info, const std::string &out_var_name) {
+  functor_.AddReuseVarPair(in_var_info, out_var_name);
+}
 void ShareTensorBufferOpHandle::InitCUDA() {
@@ -84,49 +71,7 @@ void ShareTensorBufferOpHandle::InitCUDA() {
 #endif
 }
-void ShareTensorBufferOpHandle::CallOnce() {
-  PADDLE_ENFORCE(in_out_vars_.empty(), "in_out_vars_ must be initialized here");
-  Scope *exec_scope = local_exec_scopes_[0];
-  for (size_t i = 0; i < in_var_infos_.size(); ++i) {
-    auto *in_var = exec_scope->FindVar(in_var_infos_[i]->Name());
-    auto *out_var = exec_scope->FindVar(out_var_names_[i]);
-    PADDLE_ENFORCE_NOT_NULL(in_var);
-    PADDLE_ENFORCE_NOT_NULL(out_var);
-    PADDLE_ENFORCE_NE(in_var, out_var);
-    in_out_vars_.emplace_back(in_var, out_var);
-  }
-}
-
-void ShareTensorBufferOpHandle::RunImpl() {
-  if (in_var_infos_.size() != in_out_vars_.size()) {
-    CallOnce();
-  }
-  for (size_t i = 0; i < in_var_infos_.size(); ++i) {
-    const auto &in_tensor = GetTensorFromVar(in_out_vars_[i].first);
-    auto *out_tensor = GetMutableTensorFromVar(in_out_vars_[i].second);
-    auto *in_var_info = in_var_infos_[i];
-    if (UNLIKELY(in_var_info->IsSkipped())) {
-      // If in_var is inplaced in the previous batch and we want to fetch
-      // in_var in the current batch, we have to reset memory of out_var
-      // to avoid wrong calculation result.
-      if (in_tensor.Holder() == out_tensor->Holder()) {
-        VLOG(1) << "Clear " << out_var_names_[i]
-                << " because you may want to fetch an inplaced variable "
-                << in_var_info->Name()
-                << " in previous batch: " << in_var_info->Name() << " -> "
-                << out_var_names_[i];
-        out_tensor->clear();
-      }
-    } else {
-      out_tensor->ShareBufferWith(in_tensor);
-      VLOG(2) << "Share tensor buffer when running " << op_type_ << " : "
-              << in_var_info->Name() << " -> " << out_var_names_[i];
-    }
-  }
-}
+void ShareTensorBufferOpHandle::RunImpl() { functor_(local_exec_scopes_[0]); }
 } // namespace details
 } // namespace framework
......
@@ -14,22 +14,15 @@
 #pragma once
 #include <string>
-#include <unordered_set>
+#include <unordered_map>
 #include <utility>
 #include <vector>
+#include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
 namespace paddle {
 namespace framework {
-class Variable;
-class Scope;
-class Tensor;
-
-namespace ir {
-class MemOptVarInfo;
-}  // namespace ir
 namespace details {
 class ShareTensorBufferOpHandle : public OpHandleBase {
@@ -37,16 +30,19 @@ class ShareTensorBufferOpHandle : public OpHandleBase {
   ShareTensorBufferOpHandle(
       ir::Node *node, Scope *scope, size_t scope_idx,
       const std::string &op_type,
-      const std::vector<ir::MemOptVarInfo *> &in_vars_infos,
+      const std::vector<const ir::MemOptVarInfo *> &in_vars_infos,
      const std::vector<std::string> &out_var_names);
-  std::unordered_set<std::string> ReusedVarSet() const;
+  std::unordered_map<std::string, std::string> ReusedVars() const;
   Priority GetPriority() const override { return Priority::kHighest; }
-  size_t GetScopeIdx() const { return scope_idx_; }
+  size_t GetScopeIdx() const { return functor_.GetScopeIdx(); }
+
+  void AddReuseVarPair(const ir::MemOptVarInfo *in_var_info,
+                       const std::string &out_var_name);
-  void Add(ir::MemOptVarInfo *in_var_info, const std::string &ou_var_name);
+  const ShareTensorBufferFunctor &Functor() const { return functor_; }
  protected:
   std::string Name() const override { return "buffer_share"; }
@@ -55,20 +51,17 @@ class ShareTensorBufferOpHandle : public OpHandleBase {
   void InitCUDA() override;
-  std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
+  std::vector<Scope *> GetLocalScopes() override {
+    return {functor_.GetScope()};
+  }
  private:
-  void CallOnce();
-  Scope *scope_;
-  size_t scope_idx_;
-  std::string op_type_;
-  std::vector<ir::MemOptVarInfo *> in_var_infos_;
-  std::vector<std::string> out_var_names_;
-  std::vector<std::pair<const Variable *, Variable *>> in_out_vars_;
+  ShareTensorBufferFunctor functor_;
 };
+ComputationOpHandle *GetUniquePendingComputationOpHandle(
+    ShareTensorBufferOpHandle *share_tensor_op);
 } // namespace details
 } // namespace framework
 } // namespace paddle
@@ -21,6 +21,7 @@
 #include <utility>
 #include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/place.h"
 namespace paddle {
@@ -74,12 +75,16 @@ struct VarHandleBase {
   OpHandleBase* GeneratedOp() { return generated_op_; }
+  const OpHandleBase* GeneratedOp() const { return generated_op_; }
   const std::unordered_set<OpHandleBase*>& PendingOps() const {
     return pending_ops_;
   }
   ir::Node* Node() { return node_; }
+  const ir::Node* Node() const { return node_; }
  protected:
   // The operator who generate this variable. nullptr if the variable
   // is a root node.
@@ -96,6 +101,9 @@ struct VarHandleBase {
 //
 // NOTE: runtime variables have place.
 struct VarHandle : public VarHandleBase {
+  DISABLE_COPY_AND_ASSIGN(VarHandle);
+
+ public:
   virtual ~VarHandle();
   std::string DebugString() const override;
......
@@ -19,6 +19,7 @@
 #include <vector>
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/ir/pass_builder.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_registry.h"
......
@@ -22,3 +22,4 @@ cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pas
 cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass)
 cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass)
+cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
using OpHandleBase = details::OpHandleBase;
using ComputationOpHandle = details::ComputationOpHandle;
using VarHandle = details::VarHandle;
using VarHandleBase = details::VarHandleBase;
using DummyVarHandle = details::DummyVarHandle;
enum NodeDependency { kSame = 0, kNoDep = 1, kBefore = 2, kAfter = 3 };
static NodeDependency ReverseNodeDependency(NodeDependency dep) {
return dep == NodeDependency::kBefore
? NodeDependency::kAfter
: (dep == NodeDependency::kAfter ? NodeDependency::kBefore : dep);
}
class BufferSharedCrossOpMemoryReusePass : public MemoryReusePass {
protected:
std::string ReuseType() const override { return "cross_op_memory_reuse"; }
void Run(Graph *graph) const override;
private:
void RunOnScopeIdx(size_t idx) const;
// Toposort ops. Different strategies can be used in the future.
std::vector<OpHandleBase *> SortOp(const OpGraphView &graph_view) const;
// Build the initial dependency matrix and initialize all fields,
// including `ops_`, `op_to_idx_`, `deps_`
void BuildOpDependencyMap() const;
// Get op index inside `ops_`, used to find dependency inside `deps_`
size_t OpIndex(const ComputationOpHandle *op) const;
size_t ResolveDependencyBetween(
ComputationOpHandle *op,
const std::unordered_set<ComputationOpHandle *> &prev_ops) const;
// Get dependency relationship between op1 and op2
// Notice: GetOpDep(op1, op2) == ReverseNodeDependency(GetOpDep(op2, op1))
NodeDependency GetOpDep(const ComputationOpHandle *op1,
const ComputationOpHandle *op2) const;
void SetOpDep(const ComputationOpHandle *op1, const ComputationOpHandle *op2,
NodeDependency dep) const;
private:
mutable Graph *graph_;
// All ops in the graph, grouped by scope index
mutable std::vector<std::vector<ComputationOpHandle *>> ops_;
// Index of each op in `ops_`, grouped by scope index.
// Index of each op is the index inside `deps_`.
mutable std::vector<std::unordered_map<const ComputationOpHandle *, size_t>>
op_to_idx_;
// Dependency matrix between any 2 ops
// If deps_[scope_idx][i][j] is equal to:
// 1. kSame, Op(i) and Op(j) are the same ops, only when i == j.
// 2. kNoDep, Op(i) and Op(j) have no dependency between each other.
// 3. kBefore, Op(i) is the preceding op of Op(j).
// 4. kAfter, Op(i) is the pending op of Op(j).
mutable std::vector<std::vector<std::vector<NodeDependency>>> deps_;
};
void BufferSharedCrossOpMemoryReusePass::Run(Graph *graph) const {
graph_ = graph;
BuildOpDependencyMap();
for (size_t i = 0; i < ScopeNum(); ++i) {
RunOnScopeIdx(i);
}
}
// Note(zjl): The reason why I separate SortOp from BuildOpDependencyMap()
// is that we can use different sorting strategies in the future to
// evaluate the effects of different sorting strategies.
// Currently, I use BFS, but we can use other kinds of sorting strategy
// in the future, as long as the new strategy reaches higher memory reuse
// ratio.
std::vector<OpHandleBase *> BufferSharedCrossOpMemoryReusePass::SortOp(
const OpGraphView &graph_view) const {
std::vector<OpHandleBase *> sorted_ops;
sorted_ops.reserve(graph_view.OpNumber());
graph_view.BreadthFirstVisit(
[&](OpHandleBase *cur_op) { sorted_ops.emplace_back(cur_op); });
PADDLE_ENFORCE_EQ(sorted_ops.size(), graph_view.OpNumber(),
"There are unvisited ops");
return sorted_ops;
}
/**
* Try to reuse unlived vars.
*
* What we do is: traverse all outputs of each op, and find a suitable
* unused var, and then reuse its memory as output.
*
* How to determine unused vars?
*
* Case 1: unlived vars after all preceding ops run. In this case, no extra
* edge would be added to the graph.
*
* Case 2: unlived vars after all preceding ops and all no-dep ops run. In
* this case, the reused var is from no-dep ops, so that we have to add
* extra edge to resolve data hazard.
*
*
* If Case 2 occurs, what we should do to resolve data hazard?
*
* - Step 1: add a dep var between reused_op and share_tensor_buffer_op,
* that is: reused_op -> dep_var -> share_tensor_buffer_op.
*
* - Step 2: Update deps_, all preceding ops of reused_op should be
* preceding ops of op.
*/
void BufferSharedCrossOpMemoryReusePass::RunOnScopeIdx(size_t idx) const {
auto &ops = ops_[idx];
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars)[idx];
// Build a reverse map of `last_live_ops_of_vars`,
// i.e., VarHandle -> last lived ops of VarHandle
std::unordered_map<VarHandle *, std::unordered_set<ComputationOpHandle *>>
var_to_ops;
for (auto &pair : last_live_ops_of_vars) {
for (auto *op : pair.second.ops()) {
var_to_ops[pair.second.var()].insert(op);
}
}
// Deep copy of `var_to_ops`, used to get last lived ops of each unlived var
auto original_var_to_ops = var_to_ops;
// Memory size of VarHandle -> list<VarHandle>
std::map<int64_t, std::list<VarHandle *>> unlived_var_pool;
size_t reuse_num = 0;
for (auto *op : ops) {
// Traverse all output args of op to find whether there is an unlived var
// that can be reused.
auto out_args = op->Node()->Op()->OutputArgumentNames();
for (auto &out_arg : out_args) {
auto out_nodes = this->FindNodesByName(out_arg, op->Node()->outputs);
// If out_arg is kEmptyVarName, it may not be found in output nodes.
if (out_nodes.size() != 1) {
continue;
}
auto *out_node = *(out_nodes.begin());
auto *out_var =
dynamic_cast<VarHandle *>(&(out_node->Wrapper<VarHandleBase>()));
PADDLE_ENFORCE_NOT_NULL(out_var);
// If out_arg is not reusable, skip it
if (!IsOutVarReusable(*out_var)) {
continue;
}
auto mem_size = GetMemorySize(*out_var);
// Special case: if memory size of out_var is 0, skip it
if (mem_size == 0) {
continue;
}
// Find a suitable unlived var from `unlived_var_pool`
// Here, we use `find`, but we can perform `lower_bound` if
// it is better in the future.
auto iter = unlived_var_pool.find(std::abs(mem_size));
if (iter == unlived_var_pool.end()) {
continue;
}
// Obtain candidate_vars that can be reused.
auto &candidate_vars = iter->second;
for (auto var_iter = candidate_vars.begin();
var_iter != candidate_vars.end(); ++var_iter) {
bool success = this->TryReuseVar(*var_iter, out_var);
if (!success) continue;
// If memory reuse is successful, we should do some post-processing.
++reuse_num;
auto &prev_ops = original_var_to_ops.at(*var_iter);
// Add extra dependencies between `op` and last lived ops of reused var
// (i.e. prev_ops) if needed.
// All `prev_ops` must be preceding ops of op to avoid data hazard.
size_t new_added_dep_num = ResolveDependencyBetween(op, prev_ops);
VLOG(3) << "Variable can be reused between: " << (*var_iter)->Name()
<< " -> " << out_var->Name() << " when running op "
<< op->Name() << ", add extra dependency " << new_added_dep_num
<< "/" << prev_ops.size();
// erase reused var from `original_var_to_ops`
original_var_to_ops.erase(*var_iter);
// erase reused var from `candidate_vars`
candidate_vars.erase(var_iter);
if (candidate_vars.empty()) {
// erase reused var from `unlived_var_pool` if there is no other vars
// which has same size with reused var.
unlived_var_pool.erase(iter);
}
break;
}
}
// After all output args have been traversed, we should check whether
// there is new unlived var after `op` runs.
for (auto op_iter = var_to_ops.begin(); op_iter != var_to_ops.end();) {
// erase op from `var_to_ops` first
op_iter->second.erase(op);
if (op_iter->second.empty()) {
// there is an unlived var, since all of its last-live ops have run
VarHandle *unlived_var = op_iter->first;
var_to_ops.erase(op_iter++);
if (IsInVarReusable(*unlived_var)) {
auto mem_size = GetMemorySize(*unlived_var);
if (mem_size != 0) {
unlived_var_pool[std::abs(mem_size)].push_front(unlived_var);
}
}
} else {
++op_iter;
}
}
}
VLOG(4) << "Reuse " << reuse_num << " variable(s) in Scope " << idx;
}
size_t BufferSharedCrossOpMemoryReusePass::ResolveDependencyBetween(
ComputationOpHandle *op,
const std::unordered_set<ComputationOpHandle *> &prev_ops) const {
size_t new_added_dep_num = 0;
size_t op_idx = OpIndex(op);
auto &deps = deps_[op->GetScopeIdx()];
for (auto *prev_op : prev_ops) {
auto op_dep = GetOpDep(prev_op, op);
if (op_dep == NodeDependency::kBefore) continue;
PADDLE_ENFORCE_EQ(op_dep, NodeDependency::kNoDep,
"The graph has circle, this may be a bug");
auto iter =
std::find_if(prev_op->Outputs().begin(), prev_op->Outputs().end(),
[](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (iter != prev_op->Outputs().end()) {
op->AddInput(*iter);
} else {
auto *dep_var = new DummyVarHandle(graph_->CreateControlDepVar());
graph_->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
prev_op->AddOutput(dep_var);
op->AddInput(dep_var);
}
// All preceding ops of `prev_op` should be preceding ops of `op`
size_t prev_op_idx = OpIndex(prev_op);
for (size_t i = 0; i < deps[prev_op_idx].size(); ++i) {
if (deps[prev_op_idx][i] != NodeDependency::kAfter) {
continue;
}
deps[i][op_idx] = NodeDependency::kBefore;
deps[op_idx][i] = NodeDependency::kAfter;
}
// All pending ops of `op` should be pending ops of `prev_op`.
for (size_t i = 0; i < deps[op_idx].size(); ++i) {
if (deps[op_idx][i] != NodeDependency::kBefore) {
continue;
}
deps[i][prev_op_idx] = NodeDependency::kAfter;
deps[prev_op_idx][i] = NodeDependency::kBefore;
}
// `prev_op` is one of preceding op of `op`
SetOpDep(prev_op, op, NodeDependency::kBefore);
++new_added_dep_num;
}
return new_added_dep_num;
}
void BufferSharedCrossOpMemoryReusePass::BuildOpDependencyMap() const {
PADDLE_ENFORCE(ops_.empty(), "ops_ must be initialized here");
PADDLE_ENFORCE(op_to_idx_.empty(), "op_to_idx_ must be initialized here");
PADDLE_ENFORCE(deps_.empty(), "deps_ must be initialized here");
// Toposort ops
OpGraphView graph_view(ir::FilterByNodeWrapper<OpHandleBase>(*graph_));
auto ops = SortOp(graph_view);
size_t scope_num = this->ScopeNum();
size_t op_num = ops.size();
// A map to record all preceding ops of each op
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
preceding_ops;
// BFS to fill `preceding_ops`
graph_view.BreadthFirstVisit([&](OpHandleBase *cur_op) {
// All preceding ops of cur_op should be:
// - ops connected to cur_op directly (its direct preceding ops)
// - all preceding ops of `direct preceding ops of cur_op`
auto &all_preceding_ops_of_cur_op = preceding_ops[cur_op];
for (auto &preceding_op : graph_view.PrecedingOps(cur_op)) {
all_preceding_ops_of_cur_op.insert(preceding_op);
auto &prev_preceding_ops = preceding_ops[preceding_op];
all_preceding_ops_of_cur_op.insert(prev_preceding_ops.begin(),
prev_preceding_ops.end());
}
});
PADDLE_ENFORCE_EQ(preceding_ops.size(), op_num);
// Find out ComputationOpHandles only
ops_.resize(scope_num);
op_to_idx_.resize(scope_num);
for (auto *op : ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
size_t scope_idx = compute_op->GetScopeIdx();
ops_[scope_idx].emplace_back(compute_op);
op_to_idx_[scope_idx].emplace(compute_op, op_to_idx_[scope_idx].size());
}
// Fill deps_ according to `preceding_ops`
deps_.resize(scope_num);
for (size_t i = 0; i < deps_.size(); ++i) {
deps_[i].resize(ops_[i].size());
for (auto &item : deps_[i]) {
item.assign(ops_[i].size(), NodeDependency::kNoDep);
}
}
for (auto &ops_on_each_device : ops_) {
for (auto *op : ops_on_each_device) {
SetOpDep(op, op, NodeDependency::kSame);
for (auto *preceding_op : preceding_ops[op]) {
auto *compute_preceding_op =
dynamic_cast<ComputationOpHandle *>(preceding_op);
if (compute_preceding_op != nullptr &&
compute_preceding_op->GetScopeIdx() == op->GetScopeIdx()) {
SetOpDep(compute_preceding_op, op, NodeDependency::kBefore);
}
}
}
}
}
size_t BufferSharedCrossOpMemoryReusePass::OpIndex(
const ComputationOpHandle *op) const {
auto iter = op_to_idx_[op->GetScopeIdx()].find(op);
PADDLE_ENFORCE(iter != op_to_idx_[op->GetScopeIdx()].end());
return iter->second;
}
NodeDependency BufferSharedCrossOpMemoryReusePass::GetOpDep(
const ComputationOpHandle *op1, const ComputationOpHandle *op2) const {
PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx());
return deps_[op1->GetScopeIdx()][OpIndex(op1)][OpIndex(op2)];
}
void BufferSharedCrossOpMemoryReusePass::SetOpDep(
const ComputationOpHandle *op1, const ComputationOpHandle *op2,
NodeDependency dep) const {
PADDLE_ENFORCE_EQ(op1->GetScopeIdx(), op2->GetScopeIdx());
if (op1 == op2) {
PADDLE_ENFORCE(dep == NodeDependency::kSame);
auto idx = OpIndex(op1);
deps_[op1->GetScopeIdx()][idx][idx] = NodeDependency::kSame;
} else {
auto idx1 = OpIndex(op1);
auto idx2 = OpIndex(op2);
PADDLE_ENFORCE(dep != NodeDependency::kSame && idx1 != idx2);
deps_[op1->GetScopeIdx()][idx1][idx2] = dep;
deps_[op1->GetScopeIdx()][idx2][idx1] = ReverseNodeDependency(dep);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(buffer_shared_cross_op_memory_reuse_pass,
paddle::framework::ir::BufferSharedCrossOpMemoryReusePass)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kUseCuda);
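To make the size-bucketed lookup in RunOnScopeIdx above easier to follow, here is a stand-alone sketch of the same idea. It is an illustration only, not code from this commit: plain strings stand in for VarHandle, and the variable names and sizes are invented. Unlived variables are grouped by their memory size, and a new output reuses the first unlived variable whose size matches exactly.
#include <cstdint>
#include <iostream>
#include <list>
#include <map>
#include <string>

int main() {
  // Pool of unlived variables keyed by memory size (mirrors unlived_var_pool).
  std::map<int64_t, std::list<std::string>> unlived_var_pool;
  unlived_var_pool[4096].push_front("relu_0.tmp_0");
  unlived_var_pool[1024].push_front("fc_0.tmp_1");

  // An op produces an output that needs 4096 units; look for an exact match.
  int64_t out_size = 4096;
  auto iter = unlived_var_pool.find(out_size);
  if (iter != unlived_var_pool.end() && !iter->second.empty()) {
    std::string reused = iter->second.front();
    iter->second.pop_front();
    // Drop the bucket once no variable of this size is left, as the pass does.
    if (iter->second.empty()) unlived_var_pool.erase(iter);
    std::cout << "Reuse " << reused << " for the new output" << std::endl;
  } else {
    std::cout << "No unlived variable of that size; allocate new memory" << std::endl;
  }
  return 0;
}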
@@ -50,11 +50,11 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
   for (auto &pair : each_scope_ops) {
     // If variable has more than 1 last lived ops, this variable cannot
     // be inplaced.
-    if (pair.second.size() != 1) {
+    if (pair.second.ops().size() != 1) {
       continue;
     }
-    auto *op = *(pair.second.begin());
+    auto *op = *(pair.second.ops().begin());
     const std::string &op_type = op->GetOp()->Type();
     const framework::OpDesc *op_desc = op->Node()->Op();
     PADDLE_ENFORCE_NOT_NULL(op_desc);
@@ -141,7 +141,7 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const {
               << out_var_handle_ptr->Name()
               << ". Debug String is: " << op->GetOp()->DebugString();
     } else {
-      VLOG(4) << "Inplace failed in op " << op_type << ": "
+      VLOG(3) << "Inplace failed in op " << op_type << ": "
               << in_var_handle_ptr->Name() << " -> "
               << out_var_handle_ptr->Name();
     }
......
@@ -205,7 +205,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
   for (auto &var_ops_map : last_live_ops) {
     for (auto &var_ops_pair : var_ops_map) {
       const std::string &var_name = var_ops_pair.first;
-      for (auto *op : var_ops_pair.second) {
+      for (auto *op : var_ops_pair.second.ops()) {
         op_vars_map[op].insert(var_name);
       }
     }
......
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/op_info.h"
......
@@ -36,8 +36,6 @@ namespace ir {
 constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@";
 typedef std::unordered_set<std::string> MemOptSkipVars;
-constexpr char kUseCuda[] = "use_cuda";
 std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph);
 // NOTE(dzh): A ordered set for node reuse in memory optimize.
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
+#include <functional>
 #include <map>
 #include <string>
 #include <unordered_map>
@@ -23,37 +24,16 @@ namespace paddle {
 namespace framework {
 namespace ir {
-// Each ShareTensorBufferOpHandle should only have one pending
-// ComputationOpHandle
-static details::ComputationOpHandle *GetUniquePendingComputationOpHandle(
-    details::ShareTensorBufferOpHandle *share_tensor_op) {
-  details::ComputationOpHandle *result_op = nullptr;
-  for (Node *out_var : share_tensor_op->Node()->outputs) {
-    for (Node *pending_op : out_var->outputs) {
-      auto &op = pending_op->Wrapper<details::OpHandleBase>();
-      auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(&op);
-      PADDLE_ENFORCE_NOT_NULL(compute_op);
-      if (result_op == nullptr) {
-        result_op = compute_op;
-      } else {
-        PADDLE_ENFORCE_EQ(result_op, compute_op);
-      }
-    }
-  }
-  PADDLE_ENFORCE_NOT_NULL(result_op);
-  return result_op;
-}
 void MemoryReusePass::ApplyImpl(Graph *graph) const {
   graph_ = graph;
-  use_cuda_ = Get<bool>(kUseCuda);
   all_vars_ = &(graph_->Get<details::GraphVars>(details::kGraphVars));
   var_infos_ = &(Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList));
   last_live_ops_of_vars_ =
       &(Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars));
-  reused_var_names_.resize(all_vars_->size());
+  reused_in_var_names_.resize(all_vars_->size());
+  reused_out_var_names_.resize(all_vars_->size());
   var_descs_.resize(all_vars_->size());
   // Collect the existing ShareTensorBufferOpHandles.
@@ -82,7 +62,7 @@ bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var,
   auto *op =
       dynamic_cast<details::ComputationOpHandle *>(out_var->GeneratedOp());
   PADDLE_ENFORCE_NOT_NULL(op);
-  if (IsVarsReusable(in_var, out_var)) {
+  if (IsVarPairReusable(*in_var, *out_var)) {
     AddReuseVar(op, in_var, out_var);
     return true;
   } else {
@@ -101,28 +81,37 @@ std::unordered_set<Node *> MemoryReusePass::FindNodesByName(
   return ret;
 }
-VarDesc *MemoryReusePass::GetVarDesc(details::VarHandle *var) const {
-  auto iter = var_descs_[var->scope_idx()].find(var->Name());
-  if (iter == var_descs_[var->scope_idx()].end()) {
-    PADDLE_ENFORCE((*all_vars_)[var->scope_idx()].count(var->Name()),
-                   "Variable %s not found", var->Name());
-    auto *desc =
-        TryGetLatestVarDesc((*all_vars_)[var->scope_idx()].at(var->Name()));
+VarDesc *MemoryReusePass::GetVarDesc(const details::VarHandle &var) const {
+  const auto var_name = var.Name();
+  size_t scope_idx = var.scope_idx();
+  auto iter = var_descs_[scope_idx].find(var_name);
+  if (iter == var_descs_[scope_idx].end()) {
+    PADDLE_ENFORCE((*all_vars_)[scope_idx].count(var_name),
+                   "Variable %s not found", var_name);
+    auto *desc = TryGetLatestVarDesc((*all_vars_)[scope_idx].at(var_name));
     PADDLE_ENFORCE_NOT_NULL(desc);
-    var_descs_[var->scope_idx()].emplace(var->Name(), desc);
+    var_descs_[scope_idx].emplace(var_name, desc);
     return desc;
   } else {
     return iter->second;
   }
 }
+int64_t MemoryReusePass::GetMemorySize(const details::VarHandle &var) const {
+  auto *var_desc = GetVarDesc(var);
+  auto shapes = var_desc->GetShape();
+  return std::accumulate(shapes.begin(), shapes.end(), static_cast<int64_t>(1),
+                         std::multiplies<int64_t>());
+}
 void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
   auto all_ops = FilterByNodeWrapper<details::OpHandleBase>(*graph_);
   for (auto *op : all_ops) {
     auto *share_buffer_op =
         dynamic_cast<details::ShareTensorBufferOpHandle *>(op);
     if (share_buffer_op != nullptr) {
-      auto *compute_op = GetUniquePendingComputationOpHandle(share_buffer_op);
+      auto *compute_op =
+          details::GetUniquePendingComputationOpHandle(share_buffer_op);
       PADDLE_ENFORCE(ops_.count(compute_op) == 0);
       ops_.emplace(compute_op, share_buffer_op);
     }
@@ -131,14 +120,28 @@ void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
 void MemoryReusePass::CollectReusedVars() const {
   for (auto &pair : ops_) {
-    auto reused_vars = pair.second->ReusedVarSet();
-    reused_var_names_[pair.first->GetScopeIdx()].insert(reused_vars.begin(),
-                                                        reused_vars.end());
+    auto reused_vars = pair.second->ReusedVars();
+    for (auto &reused_var_pair : reused_vars) {
+      reused_in_var_names_[pair.first->GetScopeIdx()].insert(
+          reused_var_pair.first);
+      reused_out_var_names_[pair.first->GetScopeIdx()].insert(
+          reused_var_pair.second);
+    }
   }
 }
-bool MemoryReusePass::IsVarAlreadyReused(details::VarHandle *var) const {
-  return reused_var_names_[var->scope_idx()].count(var->Name()) > 0;
-}
+bool MemoryReusePass::IsInVarAlreadyReused(
+    const details::VarHandle &in_var) const {
+  const auto var_name = in_var.Name();
+  size_t scope_idx = in_var.scope_idx();
+  return reused_in_var_names_[scope_idx].count(var_name) > 0;
+}
+
+bool MemoryReusePass::IsOutVarAlreadyReused(
+    const details::VarHandle &out_var) const {
+  const auto var_name = out_var.Name();
+  size_t scope_idx = out_var.scope_idx();
+  return reused_out_var_names_[scope_idx].count(var_name) > 0;
+}
 details::ShareTensorBufferOpHandle *
@@ -171,57 +174,118 @@ MemoryReusePass::InsertShareTensorBufferOpHandleToGraph(
   return buffer_share_op;
 }
-bool MemoryReusePass::IsVarsReusable(details::VarHandle *in_var,
-                                     details::VarHandle *out_var) const {
-  const auto in_name = in_var->Name();
-  const auto out_name = out_var->Name();
-  if (in_name == out_name) {
-    return false;
-  }
-  if (in_name == kEmptyVarName || out_name == kEmptyVarName) {
-    return false;
-  }
-  if (IsVarAlreadyReused(in_var)) {
-    return false;
-  }
-  // out_var must be the first version!!!
-  auto out_var_iter = (*all_vars_)[out_var->scope_idx()].find(out_name);
-  PADDLE_ENFORCE(out_var_iter != (*all_vars_)[out_var->scope_idx()].end() &&
-                     !out_var_iter->second.empty(),
-                 "Cannot find variable %s", out_name);
-  if (out_var_iter->second[0] != out_var) {
-    return false;
-  }
-  const VarDesc *in_var_desc = GetVarDesc(in_var);
-  const VarDesc *out_var_desc = GetVarDesc(out_var);
-  if (in_var_desc->Persistable() || out_var_desc->Persistable()) {
-    return false;
-  }
-  if (in_var_desc->GetType() != proto::VarType::LOD_TENSOR ||
-      out_var_desc->GetType() != proto::VarType::LOD_TENSOR) {
-    return false;
-  }
-  if (!FindNodesByName(in_name, out_var->GeneratedOp()->Node()->outputs)
-           .empty()) {
-    return false;
-  }
-  if (!FindNodesByName(out_name, out_var->GeneratedOp()->Node()->inputs)
-           .empty()) {
-    return false;
-  }
-  auto all_input_args =
-      out_var->GeneratedOp()->Node()->Op()->InputArgumentNames();
+/**
+ * Input var is reusable only when:
+ *  - it is not an empty var.
+ *  - it has not been reused. If an input var is reused twice or more,
+ *    the calculation result may be wrong.
+ *  - it is not a persistable var.
+ *  - it is LoDTensor. We can support SelectedRows in the future.
+ */
+bool MemoryReusePass::IsInVarReusable(const details::VarHandle &in_var) const {
+  if (in_var.Name() == kEmptyVarName) {
+    return false;
+  }
+  if (IsInVarAlreadyReused(in_var)) {
+    return false;
+  }
+  const VarDesc *in_var_desc = GetVarDesc(in_var);
+  if (in_var_desc->Persistable()) {
+    return false;
+  }
+  if (in_var_desc->GetType() != proto::VarType::LOD_TENSOR) {
+    return false;
+  }
+  return true;
+}
+
+/**
+ * Output var is reusable only when:
+ *  - it is not an empty var.
+ *  - it is the first version var. Otherwise, the var may be overwritten
+ *    in the second batch, which results in wrong calculation result.
+ *    It is critical especially when
+ *    ExecutionStrategy::num_iteration_per_drop_scope_ > 1.
+ *  - it has not reused other var's memory. It is not necessary to do memory
+ *    reuse twice for the same var.
+ *  - it is not a persistable var.
+ *  - it is LoDTensor. We can support SelectedRows in the future.
+ *  - it does not occur in inputs of the generated op. It would happen when
+ *    op has the same var as both input and output.
+ */
+bool MemoryReusePass::IsOutVarReusable(
+    const details::VarHandle &out_var) const {
+  PADDLE_ENFORCE_NOT_NULL(dynamic_cast<const details::ComputationOpHandle *>(
+      out_var.GeneratedOp()));
+  const auto out_name = out_var.Name();
+  if (out_name == kEmptyVarName) {
+    return false;
+  }
+  // out_var must be the first version!!!
+  auto out_var_iter = (*all_vars_)[out_var.scope_idx()].find(out_name);
+  PADDLE_ENFORCE(out_var_iter != (*all_vars_)[out_var.scope_idx()].end() &&
+                     !out_var_iter->second.empty(),
+                 "Cannot find variable %s", out_name);
+  if (out_var_iter->second[0] != &out_var) {
+    return false;
+  }
+  if (IsOutVarAlreadyReused(out_var)) {
+    return false;
+  }
+  const VarDesc *out_var_desc = GetVarDesc(out_var);
+  if (out_var_desc->Persistable()) {
+    return false;
+  }
+  if (out_var_desc->GetType() != proto::VarType::LOD_TENSOR) {
+    return false;
+  }
+  // If out_name occurs in input of the generated op, it cannot reuse others.
+  if (!FindNodesByName(out_name, out_var.GeneratedOp()->Node()->inputs)
+           .empty()) {
+    return false;
+  }
+  return true;
+}
+
+/**
+ * Input-Output pair can be reused only when:
+ *  - they are not the same var.
+ *  - they are both reusable.
+ *  - input var does not occur in output of op.
+ *  - input var does not occur in input of op for multiple times.
+ */
+bool MemoryReusePass::IsVarPairReusable(
+    const details::VarHandle &in_var, const details::VarHandle &out_var) const {
+  auto *op =
+      dynamic_cast<const details::ComputationOpHandle *>(out_var.GeneratedOp());
+  PADDLE_ENFORCE_NOT_NULL(op);
+  const auto in_name = in_var.Name();
+  if (in_name == out_var.Name()) {
+    return false;
+  }
+  if (!IsInVarReusable(in_var) || !IsOutVarReusable(out_var)) {
+    return false;
+  }
+  if (!FindNodesByName(in_name, op->Node()->outputs).empty()) {
+    return false;
+  }
+  auto all_input_args = op->Node()->Op()->InputArgumentNames();
   if (std::count(all_input_args.begin(), all_input_args.end(), in_name) > 1) {
     return false;
   }
@@ -249,10 +313,11 @@ void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
     share_buffer_op->AddInput(in_var);
   }
-  share_buffer_op->Add(
+  share_buffer_op->AddReuseVarPair(
       (*var_infos_)[op->GetScopeIdx()].at(in_var->Name()).get(),
       out_var->Name());
-  reused_var_names_[op->GetScopeIdx()].insert(in_var->Name());
+  reused_in_var_names_[op->GetScopeIdx()].insert(in_var->Name());
+  reused_out_var_names_[op->GetScopeIdx()].insert(out_var->Name());
   UpdateLastLiveOpOfVar(op, in_var, out_var);
 }
@@ -265,14 +330,21 @@ void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
   size_t scope_idx = op->GetScopeIdx();
   auto out_var_op_iter =
       (*last_live_ops_of_vars_)[scope_idx].find(out_var->Name());
-  PADDLE_ENFORCE(out_var_op_iter != (*last_live_ops_of_vars_)[scope_idx].end(),
-                 "Cannot find variable %s", out_var->Name());
-  PADDLE_ENFORCE(!out_var_op_iter->second.empty());
-  auto &last_live_ops_of_in_var =
-      (*last_live_ops_of_vars_)[scope_idx][in_var->Name()];
-  last_live_ops_of_in_var.clear();
-  last_live_ops_of_in_var.insert(*(out_var_op_iter->second.begin()));
+  // In Reduce mode, some output variable(gradient of parameter) does not have
+  // last live ops
+  details::ComputationOpHandle *last_live_op_of_in_var = nullptr;
+  if (out_var_op_iter == (*last_live_ops_of_vars_)[scope_idx].end()) {
+    last_live_op_of_in_var = op;
+  } else {
+    PADDLE_ENFORCE(!out_var_op_iter->second.ops().empty());
+    last_live_op_of_in_var = *(out_var_op_iter->second.ops().begin());
+  }
+  auto *last_live_ops_of_in_var =
+      (*last_live_ops_of_vars_)[scope_idx][in_var->Name()].mutable_ops();
+  last_live_ops_of_in_var->clear();
+  last_live_ops_of_in_var->insert(last_live_op_of_in_var);
   auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name());
   PADDLE_ENFORCE(in_var_info_iter != (*var_infos_)[scope_idx].end(),
......
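GetMemorySize() above reduces a variable's shape to one signed number (the product of all dimensions, where an unknown dimension such as a dynamic batch size is conventionally recorded as -1), and RunOnScopeIdx keys its reuse pool on std::abs of that value. The following stand-alone snippet, which is an illustration and not code from this commit, shows that reduction and why the absolute value is taken; the shape used here is an invented example.
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // A shape with a dynamic batch dimension, commonly recorded as -1.
  std::vector<int64_t> shape = {-1, 128, 768};
  // Same reduction as GetMemorySize(): the product of all dimensions
  // (an element count, not a byte count).
  int64_t size = std::accumulate(shape.begin(), shape.end(),
                                 static_cast<int64_t>(1),
                                 std::multiplies<int64_t>());
  // size is negative here because of the -1 dimension, so the pass buckets
  // variables by std::abs(size) when matching candidates of equal size.
  std::cout << size << " -> bucket key " << std::abs(size) << std::endl;
  return 0;
}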
@@ -81,18 +81,26 @@ class MemoryReusePass : public Pass {
   bool TryReuseVar(details::VarHandle *in_var,
                    details::VarHandle *out_var) const;
-  std::unordered_set<ir::Node *> FindNodesByName(
-      const std::string &name, const std::vector<ir::Node *> &nodes) const;
+  bool IsInVarReusable(const details::VarHandle &in_var) const;
+
+  bool IsOutVarReusable(const details::VarHandle &out_var) const;
+
+  std::unordered_set<Node *> FindNodesByName(
+      const std::string &name, const std::vector<Node *> &nodes) const;
   size_t ScopeNum() const { return all_vars_->size(); }
+  int64_t GetMemorySize(const details::VarHandle &var) const;
  private:
-  VarDesc *GetVarDesc(details::VarHandle *var) const;
-  bool IsVarsReusable(details::VarHandle *in_var,
-                      details::VarHandle *out_var) const;
-  bool IsVarAlreadyReused(details::VarHandle *var) const;
+  VarDesc *GetVarDesc(const details::VarHandle &var) const;
+  bool IsVarPairReusable(const details::VarHandle &in_var,
+                         const details::VarHandle &out_var) const;
+  bool IsInVarAlreadyReused(const details::VarHandle &in_var) const;
+  bool IsOutVarAlreadyReused(const details::VarHandle &out_var) const;
   details::ShareTensorBufferOpHandle *InsertShareTensorBufferOpHandleToGraph(
       details::ComputationOpHandle *op) const;
@@ -110,15 +118,19 @@ class MemoryReusePass : public Pass {
  private:
   mutable Graph *graph_;
-  mutable bool use_cuda_;
   mutable details::GraphVars *all_vars_;
   mutable MemOptVarInfoMapList *var_infos_;
   mutable std::vector<LastLiveOpsOfVars> *last_live_ops_of_vars_;
   mutable std::unordered_map<details::ComputationOpHandle *,
                              details::ShareTensorBufferOpHandle *>
       ops_;
-  mutable std::vector<std::unordered_set<std::string>> reused_var_names_;
+  mutable std::vector<std::unordered_set<std::string>> reused_in_var_names_;
+  mutable std::vector<std::unordered_set<std::string>> reused_out_var_names_;
   mutable std::vector<std::unordered_map<std::string, VarDesc *>> var_descs_;
 };
......
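The refactored header splits the old pair-wise IsVarsReusable check into per-variable checks (IsInVarReusable / IsOutVarReusable) plus a pair check, and tracks reused input and output names in separate sets. The Python sketch below shows one way such a split can fit together; the eligibility rules here are simplified placeholders (only a persistable flag), not the actual conditions the C++ pass applies:

```python
class ReuseTracker:
    """Simplified mirror of the bookkeeping suggested by MemoryReusePass."""

    def __init__(self):
        self.reused_in_var_names = set()
        self.reused_out_var_names = set()

    def is_in_var_reusable(self, var):
        # Placeholder rule: persistable variables must keep their own memory.
        return not var.get("persistable", False)

    def is_out_var_reusable(self, var):
        return not var.get("persistable", False)

    def is_pair_reusable(self, in_var, out_var):
        # A pair is only considered when both sides pass their own check and
        # neither name has already been handed out to another reuse pair.
        return (self.is_in_var_reusable(in_var)
                and self.is_out_var_reusable(out_var)
                and in_var["name"] not in self.reused_in_var_names
                and out_var["name"] not in self.reused_out_var_names)

    def try_reuse(self, in_var, out_var):
        if not self.is_pair_reusable(in_var, out_var):
            return False
        self.reused_in_var_names.add(in_var["name"])
        self.reused_out_var_names.add(out_var["name"])
        return True


tracker = ReuseTracker()
a = {"name": "relu_in", "persistable": False}
b = {"name": "relu_out", "persistable": False}
assert tracker.try_reuse(a, b)       # first pair succeeds
assert not tracker.try_reuse(a, b)   # same names cannot be reused twice
```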
@@ -66,6 +66,24 @@ const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps(
   return pending_ops_.at(op);
 }
 
+const std::unordered_set<details::OpHandleBase *> &OpGraphView::PrecedingOps(
+    details::OpHandleBase *op) const {
+  EnforceHasOp(op);
+  return preceding_ops_.at(op);
+}
+
+std::unordered_map<details::OpHandleBase *, size_t>
+OpGraphView::GetPrecedingDepNum() const {
+  std::unordered_map<details::OpHandleBase *, size_t> result;
+  result.reserve(preceding_ops_.size());
+  for (auto &pair : preceding_ops_) {
+    result.emplace(pair.first, pair.second.size());
+  }
+  return result;
+}
+
+size_t OpGraphView::OpNumber() const { return preceding_ops_.size(); }
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
@@ -33,13 +33,24 @@ class OpGraphView {
   const std::unordered_set<details::OpHandleBase *> &PendingOps(
       details::OpHandleBase *op) const;
 
+  const std::unordered_set<details::OpHandleBase *> &PrecedingOps(
+      details::OpHandleBase *op) const;
+
+  std::unordered_map<details::OpHandleBase *, size_t> GetPrecedingDepNum()
+      const;
+
   bool HasOp(details::OpHandleBase *op) const;
 
+  size_t OpNumber() const;
+
   // Use a visitor to visit all pending ops of op
   // Stop when callback returns false
   template <typename Callback>
   bool VisitAllPendingOps(details::OpHandleBase *op, Callback &&callback) const;
 
+  template <typename Callback>
+  void BreadthFirstVisit(Callback &&callback) const;
+
  private:
   void Build(const std::vector<details::OpHandleBase *> &ops);
   void EnforceHasOp(details::OpHandleBase *op) const;
@@ -75,6 +86,52 @@ bool OpGraphView::VisitAllPendingOps(details::OpHandleBase *op,
   return true;
 }
 
+template <typename Callback>
+void OpGraphView::BreadthFirstVisit(Callback &&callback) const {
+  auto op_deps = GetPrecedingDepNum();
+  size_t op_num = op_deps.size();
+
+  std::unordered_set<details::OpHandleBase *> visited_ops;
+  std::queue<details::OpHandleBase *> ready_ops;
+  size_t num_calls = 0;
+  for (auto iter = op_deps.begin(); iter != op_deps.end();) {
+    if (iter->second != 0) {
+      ++iter;
+      continue;
+    }
+
+    visited_ops.insert(iter->first);
+    ready_ops.push(iter->first);
+    callback(iter->first);
+    ++num_calls;
+    op_deps.erase(iter++);
+  }
+
+  while (!ready_ops.empty()) {
+    auto *cur_op = ready_ops.front();
+    ready_ops.pop();
+
+    auto &pending_ops = PendingOps(cur_op);
+    for (auto *pending_op : pending_ops) {
+      if (visited_ops.count(pending_op) > 0) {
+        continue;
+      }
+
+      if (--op_deps.at(pending_op) == 0) {
+        visited_ops.insert(pending_op);
+        op_deps.erase(pending_op);
+        ready_ops.push(pending_op);
+        callback(pending_op);
+        ++num_calls;
+      }
+    }
+  }
+
+  PADDLE_ENFORCE_EQ(num_calls, op_num, "There are unvisited ops");
+  PADDLE_ENFORCE_EQ(visited_ops.size(), op_num, "There are unvisited ops");
+  PADDLE_ENFORCE(op_deps.empty(), "There are unvisited ops");
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
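BreadthFirstVisit added above is a Kahn-style topological traversal: ops whose preceding-dependency count is zero are seeded into a queue, and each visit decrements the counts of its pending ops. A compact Python sketch of the same traversal over a plain adjacency dict (not the real OpGraphView), with a final cycle check mirroring the PADDLE_ENFORCE assertions:

```python
from collections import deque


def breadth_first_visit(pending_ops, callback):
    """Visit every op in topological (BFS) order.

    `pending_ops` maps each op to the set of ops that depend on it, i.e. the
    same direction as OpGraphView::PendingOps.
    """
    # Equivalent of GetPrecedingDepNum(): count unfinished predecessors.
    dep_num = {op: 0 for op in pending_ops}
    for succs in pending_ops.values():
        for succ in succs:
            dep_num[succ] += 1

    ready = deque(op for op, deps in dep_num.items() if deps == 0)
    visited = set(ready)
    for op in ready:
        callback(op)

    while ready:
        cur = ready.popleft()
        for succ in pending_ops[cur]:
            if succ in visited:
                continue
            dep_num[succ] -= 1
            if dep_num[succ] == 0:
                visited.add(succ)
                ready.append(succ)
                callback(succ)

    # A cycle would leave some ops with nonzero dependency counts, unvisited.
    if len(visited) != len(pending_ops):
        raise RuntimeError("There are unvisited ops (the graph has a cycle)")


order = []
graph = {"read": {"fc"}, "fc": {"relu"}, "relu": set()}
breadth_first_visit(graph, order.append)
assert order == ["read", "fc", "relu"]
```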
@@ -346,6 +346,8 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
         // Seldomly, some vars may have no pending or preceding computation ops
         // Just break;
         if (status == LastLiveOpSearchStatus::kFailure) {
+          VLOG(1) << "Cannot find last live ops of variable " << var_name
+                  << " in scope " << (*iter)->scope_idx();
           break;
         }
@@ -362,7 +364,9 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
         VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
         var_infos[i][var_name].reset(
             new MemOptVarInfo(var_name, result.size()));
-        last_live_ops_of_vars[i].emplace(var_name, std::move(result));
+        auto &last_live_ops_of_var = last_live_ops_of_vars[i][var_name];
+        last_live_ops_of_var.set_var(*iter);
+        *(last_live_ops_of_var.mutable_ops()) = std::move(result);
         break;
       }
......
@@ -39,10 +39,28 @@ using GarbageCollectorMap =
 const char kMemOptVarInfoMapList[] = "mem_opt_var_info_map_list";
 const char kGarbageCollector[] = "garbage_collector";
 const char kAllPlaces[] = "all_places";
+const char kUseCuda[] = "use_cuda";
 
-using LastLiveOpsOfVars =
-    std::unordered_map<std::string,
-                       std::unordered_set<details::ComputationOpHandle *>>;
+class LastLiveOpOfVarInfo {
+ public:
+  details::VarHandle *var() { return var_; }
+
+  void set_var(details::VarHandle *var) { var_ = var; }
+
+  const std::unordered_set<details::ComputationOpHandle *> &ops() const {
+    return ops_;
+  }
+
+  std::unordered_set<details::ComputationOpHandle *> *mutable_ops() {
+    return &ops_;
+  }
+
+ private:
+  details::VarHandle *var_{nullptr};
+  std::unordered_set<details::ComputationOpHandle *> ops_;
+};
+
+using LastLiveOpsOfVars = std::unordered_map<std::string, LastLiveOpOfVarInfo>;
+
 const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
 
 VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars);
......
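The old typedef mapped a variable name straight to its set of last live ops; the new LastLiveOpOfVarInfo also remembers the variable's VarHandle so later passes can get back to it. A hypothetical Python equivalent, purely to show the shape of the data:

```python
from dataclasses import dataclass, field
from typing import Optional, Set


@dataclass
class LastLiveOpOfVarInfo:
    # Mirrors the new C++ helper: the variable handle plus the set of ops
    # after which the variable is no longer needed.
    var: Optional[str] = None          # stand-in for details::VarHandle *
    ops: Set[str] = field(default_factory=set)


# LastLiveOpsOfVars: variable name -> LastLiveOpOfVarInfo (one map per scope).
last_live_ops_of_vars = {
    "fc_out": LastLiveOpOfVarInfo(var="fc_out@scope0", ops={"relu_op"}),
}
assert "relu_op" in last_live_ops_of_vars["fc_out"].ops
```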
@@ -99,7 +99,7 @@ class Node {
   // Test if the Node is wrapped by type T.
   template <typename T>
-  bool IsWrappedBy() {
+  bool IsWrappedBy() const {
     return std::type_index(typeid(T)) == wrapper_type_;
   }
......
@@ -36,7 +36,8 @@ Graph* Pass::Apply(Graph* graph) const {
   ApplyImpl(graph);
   // TODO(panyx0718): Add more verifications.
   PADDLE_ENFORCE(!HasCircle(*graph),
-                 "Illegal Pass. Generated graph shouldn't has cycle.");
+                 "Illegal Pass %s. Generated graph shouldn't have cycle.",
+                 Type());
   PADDLE_ENFORCE(VarDescIsConsistency(*graph),
                  "The VarDescs of persistable variable are not consistency.");
   applied_ = true;
......
@@ -99,7 +99,7 @@ TEST(PassTest, TestPassAttrCheck) {
   } catch (paddle::platform::EnforceNotMet e) {
     exception = std::string(e.what());
   }
-  ASSERT_TRUE(exception.find("shouldn't has cycle") != exception.npos);
+  ASSERT_TRUE(exception.find("shouldn't have cycle") != exception.npos);
 }
 
 }  // namespace ir
......
@@ -252,7 +252,18 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
     VLOG(10) << "buffer_shared_inplace_pass Applied";
   }
 
-  // TODO(zjl): refactor MemoryOptimizePass as well!!!
+  if (build_strategy_.memory_optimize_) {
+    auto cross_op_memory_reuse_pass = ir::PassRegistry::Instance().Get(
+        "buffer_shared_cross_op_memory_reuse_pass");
+    cross_op_memory_reuse_pass->SetNotOwned(ir::kMemOptVarInfoMapList,
+                                            &mem_opt_var_infos_);
+    cross_op_memory_reuse_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
+                                            &last_live_ops_of_vars);
+    cross_op_memory_reuse_pass->SetNotOwned(ir::kUseCuda, &use_cuda_);
+    VLOG(10) << "Start to apply buffer_shared_cross_op_memory_reuse_pass";
+    graph = cross_op_memory_reuse_pass->Apply(graph);
+    VLOG(10) << "buffer_shared_cross_op_memory_reuse_pass Applied";
+  }
 
   if (GetEagerDeletionThreshold() < 0) {
     return graph;
@@ -780,3 +791,4 @@ bool ParallelExecutor::EnableParallelGraphExecution(
 USE_PASS(reference_count_pass);
 USE_PASS(eager_deletion_pass);
 USE_PASS(buffer_shared_inplace_pass);
+USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
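With the pass registered and wired into ApplyMemoryOptimizePass, cross-op memory reuse is gated on BuildStrategy.memory_optimize, which this PR turns off by default. A minimal usage sketch, assuming the fluid 1.x Python API that the tests below exercise; the toy network and feed data are placeholders:

```python
import numpy as np
import paddle.fluid as fluid

# A tiny placeholder network.
image = fluid.layers.data(name='image', shape=[784], dtype='float32')
hidden = fluid.layers.fc(input=image, size=100, act='relu')
loss = fluid.layers.mean(fluid.layers.fc(input=hidden, size=10))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# memory_optimize is off by default after this PR; turn it on explicitly to
# run buffer_shared_cross_op_memory_reuse_pass (enable_inplace stays on).
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = True
build_strategy.enable_inplace = True

compiled_prog = fluid.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(
        loss_name=loss.name,
        build_strategy=build_strategy,
        places=fluid.cpu_places(1))

feed = {'image': np.random.random([4, 784]).astype('float32')}
loss_val, = exe.run(compiled_prog, feed=feed, fetch_list=[loss.name])
```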
@@ -204,16 +204,6 @@ class CompiledProgram(object):
         else:
             self._places = None
         self._build_strategy.is_distribution = _is_pserver_mode(self._program)
-
-        # FIXME(dzhwinter): enable_inplace should be after memory_optimize
-        # if turn on python memory optimize, turn off the inplace_pass.
-        # memory_optimize and enable_inplace default are True, but we can disable them on purpose
-        if self._program:
-            if self._program._is_mem_optimized:
-                self._build_strategy.memory_optimize = False
-            if self._build_strategy.memory_optimize:
-                self._build_strategy._use_legacy_memory_optimize_strategy = True
-
         return self
 
     def with_inference_optimize(self, config):
......
@@ -287,4 +287,4 @@ endif()
 set_tests_properties(test_recordio_reader test_parallel_executor_test_while_train test_parallel_executor_mnist
         test_parallel_executor_seresnext test_parallel_executor_crf test_sync_batch_norm_op
         test_parallel_executor_crf_auto_growth
-        test_buffer_shared_inplace_pass PROPERTIES LABELS "RUN_TYPE=DIST")
+        test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
@@ -42,7 +42,7 @@ class TestParallelExecutorBase(unittest.TestCase):
                        seed=None,
                        use_parallel_executor=True,
                        use_reduce=False,
-                       use_ir_memory_optimize=False,
+                       use_ir_memory_optimize=True,
                        enable_inplace=True,
                        fuse_elewise_add_act_ops=False,
                        fuse_all_optimizer_ops=False,
@@ -66,8 +66,9 @@ class TestParallelExecutorBase(unittest.TestCase):
             main.random_seed = seed
 
             loss = method(use_feed=feed_dict is not None)
-            if memory_opt or use_ir_memory_optimize:
-                loss.persistable = True
+            # NOTE(zjl): memory_optimize/inplace pass would not require
+            # that loss.persistable = True
+            loss.persistable = memory_opt
             if optimizer:
                 optimizer().minimize(loss)
@@ -92,10 +93,10 @@ class TestParallelExecutorBase(unittest.TestCase):
                 if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
             build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
             build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
-            build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize
             build_strategy.fuse_all_optimizer_ops = fuse_all_optimizer_ops
             build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops
+            build_strategy.memory_optimize = use_ir_memory_optimize
             build_strategy.enable_inplace = enable_inplace
             build_strategy.enable_sequential_execution = enable_sequential_execution
......
@@ -81,25 +81,20 @@ class InplaceTestBase(unittest.TestCase):
             return
 
         prog1, scope1, exe, loss1 = self.build_program_and_scope()
-        prog2, scope2, _, loss2 = self.build_program_and_scope()
-        prog3, scope3, _, loss3 = self.build_program_and_scope()
-
-        build_strategy2 = fluid.BuildStrategy()
-        build_strategy2.memory_optimize = False
-        build_strategy2.enable_inplace = True
-
-        compiled_prog2 = fluid.CompiledProgram(prog2).with_data_parallel(
-            loss_name=loss2.name,
-            build_strategy=build_strategy2,
-            places=self.place)
-
-        build_strategy3 = fluid.BuildStrategy()
-        build_strategy3.memory_optimize = False
-        build_strategy3.enable_inplace = False
-        compiled_prog3 = fluid.CompiledProgram(prog3).with_data_parallel(
-            loss_name=loss2.name,
-            build_strategy=build_strategy3,
-            places=self.place)
+        scopes = []
+        compiled_programs = []
+        for memory_optimize in [False, True]:
+            for enable_inplace in [False, True]:
+                prog, scope, _, loss = self.build_program_and_scope()
+                scopes.append(scope)
+                build_strategy = fluid.BuildStrategy()
+                build_strategy.memory_optimize = memory_optimize
+                build_strategy.enable_inplace = enable_inplace
+                compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
+                    loss_name=loss.name,
+                    build_strategy=build_strategy,
+                    places=self.place)
+                compiled_programs.append(compiled_prog)
 
         all_vars_name = self.get_all_vars(prog1)
         repeated_var_names = all_vars_name * 4
@@ -112,60 +107,56 @@ class InplaceTestBase(unittest.TestCase):
                                           feed=feed_dict,
                                           fetch_list=[fetch_var])
 
-                with fluid.scope_guard(scope2):
-                    fetch_val2, = exe.run(compiled_prog2,
-                                          feed=feed_dict,
-                                          fetch_list=[fetch_var])
-
-                with fluid.scope_guard(scope3):
-                    fetch_val3, = exe.run(compiled_prog3,
-                                          feed=feed_dict,
-                                          fetch_list=[fetch_var])
-
-                self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
-                self.assertTrue(np.array_equal(fetch_val1, fetch_val3))
+                for scope, compiled_prog in zip(scopes, compiled_programs):
+                    with fluid.scope_guard(scope):
+                        fetch_val2, = exe.run(compiled_prog,
+                                              feed=feed_dict,
+                                              fetch_list=[fetch_var])
+                    self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
 
     def test_multi_card_fetch_var(self):
         if self.is_invalid_test():
             return
 
         prog1, scope1, exe, loss1 = self.build_program_and_scope()
-        prog2, scope2, _, loss2 = self.build_program_and_scope()
-
-        build_strategy1 = fluid.BuildStrategy()
-        build_strategy1.memory_optimize = False
-        build_strategy1.enable_inplace = True
-
-        build_strategy2 = fluid.BuildStrategy()
-        build_strategy2.memory_optimize = False
-        build_strategy2.enable_inplace = False
+        scopes = []
+        compiled_programs = []
 
         if self.use_cuda:
             places = fluid.cuda_places()
         else:
            places = fluid.cpu_places(self.device_count)
 
-        compiled_prog1 = fluid.CompiledProgram(prog1).with_data_parallel(
-            loss_name=loss1.name, build_strategy=build_strategy1, places=places)
-        compiled_prog2 = fluid.CompiledProgram(prog2).with_data_parallel(
-            loss_name=loss2.name, build_strategy=build_strategy2, places=places)
+        for memory_optimize in [False, True]:
+            for enable_inplace in [False, True]:
+                prog, scope, _, loss = self.build_program_and_scope()
+                scopes.append(scope)
+                build_strategy = fluid.BuildStrategy()
+                build_strategy.memory_optimize = memory_optimize
+                build_strategy.enable_inplace = enable_inplace
+                compiled_program = fluid.CompiledProgram(
+                    prog).with_data_parallel(
+                        loss_name=loss.name,
+                        build_strategy=build_strategy,
+                        places=places)
+                compiled_programs.append(compiled_program)
 
         repeated_var_names = self.get_all_vars(prog1) * 4
         random.shuffle(repeated_var_names)  # add some random
 
         for fetch_var in repeated_var_names:
             for _ in range(4):
-                with fluid.scope_guard(scope1):
-                    fetch_val1, = exe.run(compiled_prog1,
-                                          feed=feed_dict,
-                                          fetch_list=[fetch_var])
-
-                with fluid.scope_guard(scope2):
-                    fetch_val2, = exe.run(compiled_prog2,
-                                          feed=feed_dict,
-                                          fetch_list=[fetch_var])
-
-                self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
+                fetch_vals = []
+                for scope, compiled_prog in zip(scopes, compiled_programs):
+                    with fluid.scope_guard(scope):
+                        fetch_val, = exe.run(compiled_prog,
+                                             feed=feed_dict,
+                                             fetch_list=[fetch_var])
+                        fetch_vals.append(fetch_val)
+
+                for item in fetch_vals:
+                    self.assertTrue(np.array_equal(fetch_vals[0], item))
 
 
 class CPUInplaceTest(InplaceTestBase):
......