未验证 提交 04bd413a 编写于 作者: C chengduo 提交者: GitHub

Code Clean: Move all pass to paddle::framework::ir (#17228)

* move pass to ir

* polish code
test=develop

* fix dependency
test=develop
上级 648320bb
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node) cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper) cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
...@@ -27,7 +17,7 @@ if(WITH_DISTRIBUTE) ...@@ -27,7 +17,7 @@ if(WITH_DISTRIBUTE)
endif() endif()
endif() endif()
set(all_reduce_deps all_reduce_op_handle)
if(WITH_GPU) if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor) dynload_cuda variable_visitor)
...@@ -37,7 +27,6 @@ if(WITH_GPU) ...@@ -37,7 +27,6 @@ if(WITH_GPU)
if(WITH_DGC) if(WITH_DGC)
nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
lod_tensor ddim memory dynload_cuda variable_visitor dgc all_reduce_op_handle) lod_tensor ddim memory dynload_cuda variable_visitor dgc all_reduce_op_handle)
set(all_reduce_deps sparse_all_reduce_op_handle)
endif() endif()
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
...@@ -68,34 +57,12 @@ endif() ...@@ -68,34 +57,12 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${all_reduce_deps} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
if (WITH_GPU) if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif() endif()
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
// Upper bound on the accumulated memory size of one gradient group (consumed
// by SetGroupAccordingToMemorySize below). 0 disables grouping by memory size.
DEFINE_uint64(fuse_parameter_memory_size, 0, // 0 KB
              "fuse_parameter_memory_size is up limited memory size "
              "of one group parameters' gradient which is the input "
              "of communication calling(e.g NCCLAllReduce). "
              "The default value is 0, it means that "
              "not set group according to memory_size.");
// Number of layer-groups merged into each final group (consumed by
// SetGroupAccordingToGroupSize below). 1 disables this step; -1 collapses
// everything into a single group.
DEFINE_int32(
    fuse_parameter_groups_size, 3,
    "fuse_parameter_groups_size is the size of one group parameters' gradient. "
    "The default value is a experimental result. If the "
    "fuse_parameter_groups_size is 1, it means that the groups size is "
    "the number of parameters' gradient. If the fuse_parameter_groups_size is "
    "-1, it means that there are only one group. The default value is 3, it is "
    "an experimental value.");
namespace paddle {
namespace framework {
namespace details {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// test, because it is invalid that setting 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' in unit test.

// Overrides FLAGS_fuse_parameter_groups_size (test hook).
void SetFuseParameterGroupsSize(int group_size) {
  FLAGS_fuse_parameter_groups_size = group_size;
}

// Returns the current group-count flag value.
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }

// Overrides FLAGS_fuse_parameter_memory_size (test hook).
void SetFuseParameterMemorySize(uint64_t memory_size) {
  FLAGS_fuse_parameter_memory_size = memory_size;
}

// Returns the current memory-budget flag value.
uint64_t GetFuseParameterMemorySize() {
  return FLAGS_fuse_parameter_memory_size;
}

// Bucket name for parameters whose name carries no '.'-separated layer prefix.
static const char kUnKnow[] = "@UNKNOW@";
// BOOL doubles as a "dtype not yet determined" sentinel in ApplyImpl; real
// gradients must never have BOOL dtype.
static framework::proto::VarType::Type kDefaultDtype =
    framework::proto::VarType::Type::VarType_Type_BOOL;
// Entry point of the pass: collects all (param, grad) pairs produced by
// backward ops, partitions them into groups, marks each gradient persistable,
// registers a unique fused variable name on the graph, and finally allocates
// one contiguous memory chunk covering all gradients in every local scope.
void AllocContinuousSpaceForGradPass::ApplyImpl(ir::Graph *graph) const {
  ir::Graph &result = *graph;
  auto &places = Get<const std::vector<platform::Place>>(kPlaces);
  auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
  // Clear stale results from any previous application of this pass.
  ResetAttribute<ParamsAndGrads>(kParamsAndGrads, &result);
  ResetAttribute<GroupGradsAndParams>(kGroupGradsAndParams, &result);
  // NOTE: The operator nodes should be in topology order.
  std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
  auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
  for (auto &node : topo_nodes) {
    RecordParamsAndGrads(node, &params_grads);
  }
  if (params_grads.size() == 0) {
    VLOG(10) << "Doesn't find gradients";
    return;
  }
  // Build a name -> var-node index for the lookups below.
  std::unordered_map<std::string, ir::Node *> vars;
  for (ir::Node *node : result.Nodes()) {
    if (node->IsVar() && node->Var()) {
      // Note: The graph may have the same name node. For example, parameter
      // is the input of operator and it also is the output of optimizer;
      vars.emplace(node->Var()->Name(), node);
    }
  }
  auto &group_grads_params =
      result.Get<GroupGradsAndParams>(kGroupGradsAndParams);
  // Note: the order of params_grads may be changed by SetGroupGradsAndParams.
  SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
  // Rebuild params_grads from the grouped result; inserting each group at the
  // front reverses the group order relative to group_grads_params.
  params_grads.clear();
  for (auto &group_p_g : group_grads_params) {
    params_grads.insert(params_grads.begin(), group_p_g.begin(),
                        group_p_g.end());
  }
  // Group entries are (grad, param); swap each back to (param, grad).
  for (auto &p_g : params_grads) {
    std::swap(p_g.first, p_g.second);
  }
  // Set Gradients as Persistable to prevent this var becoming reusable.
  auto dtype = kDefaultDtype;
  for (auto &p_g : params_grads) {
    // Get gradient var
    auto iter = vars.find(p_g.second);
    PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
    iter->second->Var()->SetPersistable(true);
    PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
    // Get Dtype: all fused gradients must share one data type; BOOL (the
    // kDefaultDtype sentinel) is rejected.
    auto ele_dtype = iter->second->Var()->GetDataType();
    if (dtype == kDefaultDtype) {
      dtype = ele_dtype;
      PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
                        "The data type should not be bool.");
    }
    PADDLE_ENFORCE_EQ(ele_dtype, dtype,
                      "The data type of input is not consistent.");
  }
  // Create a FusedVarsSet to avoid duplicating names for fused_var in other
  // pass.
  if (!result.Has(kFusedVars)) {
    result.Set(kFusedVars, new FusedVars);
  }
  // the kFusedGrads is used be fuse_optimizer_op_pass.
  result.Set(kFusedGrads, new FusedGrads);
  // the fused_var_name should be unique, so it appends
  // params_grads.begin()->second.
  auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
                        params_grads.begin()->second;
  result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
  auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
  PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
                    "%s is duplicate in FusedVars.", fused_var_name);
  fused_var_set.insert(fused_var_name);
  InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars, fused_var_name,
                                    params_grads);
}
// Removes any existing graph attribute called `attr_name` and installs a
// fresh, default-constructed AttrType in its place (the graph takes
// ownership of the new object).
template <typename AttrType>
void AllocContinuousSpaceForGradPass::ResetAttribute(
    const std::string &attr_name, ir::Graph *graph) const {
  const bool already_present = graph->Has(attr_name);
  if (already_present) {
    VLOG(10) << attr_name << " is reset.";
    graph->Erase(attr_name);
  }
  graph->Set(attr_name, new AttrType);
}
// Three-stage grouping pipeline over the collected (param, grad) pairs.
// Output groups hold (grad, param) pairs: first grouped by layer-name prefix,
// then merged by memory budget, finally merged by group count.
void AllocContinuousSpaceForGradPass::SetGroupGradsAndParams(
    const std::unordered_map<std::string, ir::Node *> &var_nodes,
    const ParamsAndGrads &params_grads,
    GroupGradsAndParams *group_grads_params) const {
  SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
  SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
  SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
// Groups gradients by the "layer" prefix of their parameter name (the part
// before the first '.'); names without a '.' share the "@UNKNOW@" bucket.
// Output pairs are (grad, param); groups appear in order of the first
// occurrence of their layer within params_grads.
void AllocContinuousSpaceForGradPass::SetGroupAccordingToLayers(
    const std::unordered_map<std::string, ir::Node *> &var_nodes,
    const ParamsAndGrads &params_grads,
    GroupGradsAndParams *group_grads_params) const {
  // Pass 1: bucket parameter indices by layer prefix.
  std::unordered_map<std::string, std::vector<int>> layer_params;
  for (size_t idx = 0; idx < params_grads.size(); ++idx) {
    const auto &param_name = params_grads[idx].first;
    auto dot_pos = param_name.find_first_of(".");
    const std::string layer_key = (dot_pos == std::string::npos)
                                      ? std::string(kUnKnow)
                                      : param_name.substr(0, dot_pos);
    layer_params[layer_key].emplace_back(idx);
  }

  group_grads_params->reserve(layer_params.size());
  // Pass 2: emit one group per layer in first-appearance order; each consumed
  // bucket is erased so later members of the same layer are skipped.
  for (size_t idx = 0; idx < params_grads.size(); ++idx) {
    const auto &param_name = params_grads[idx].first;
    auto dot_pos = param_name.find_first_of(".");
    std::string layer_key = kUnKnow;
    if (dot_pos != std::string::npos) {
      layer_key = param_name.substr(0, dot_pos);
    }
    auto bucket = layer_params.find(layer_key);
    if (bucket == layer_params.end()) continue;
    group_grads_params->emplace_back();
    auto &group = group_grads_params->back();
    for (auto &member_idx : bucket->second) {
      group.emplace_back(std::make_pair(params_grads[member_idx].second,
                                        params_grads[member_idx].first));
    }
    layer_params.erase(bucket);
  }

  VLOG(10) << "SetGroupAccordingToLayers: ";
  for (size_t i = 0; i < group_grads_params->size(); ++i) {
    VLOG(10) << "group " << i;
    std::stringstream out;
    for (auto &p_g : group_grads_params->at(i)) {
      out << "(" << p_g.second << ", " << p_g.first << "), ";
    }
    VLOG(10) << out.str();
  }
}
// Greedily merges consecutive groups until each merged group's accumulated
// variable size (sizeof(dtype) * product(shape), looked up via the node in
// var_nodes) reaches FLAGS_fuse_parameter_memory_size. No-op when the flag
// is 0.
void AllocContinuousSpaceForGradPass::SetGroupAccordingToMemorySize(
    const std::unordered_map<std::string, ir::Node *> &var_nodes,
    GroupGradsAndParams *group_grads_params) const {
  const uint64_t group_memory_size = GetFuseParameterMemorySize();
  if (group_memory_size == 0) {
    return;
  }
  GroupGradsAndParams local_group_grads_params;
  size_t j = 0;
  // Outer loop starts a new merged group; inner loop appends input groups
  // until the memory budget is reached or inputs run out.
  while (j < group_grads_params->size()) {
    local_group_grads_params.emplace_back();
    auto &group_p_g = local_group_grads_params.back();
    size_t local_group_memory_size = 0;
    while (j < group_grads_params->size()) {
      // Accumulate the byte size of every entry in input group j (shape and
      // dtype are read from the var node named by the pair's .second field).
      std::for_each(
          group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
          [&local_group_memory_size,
           &var_nodes](const std::pair<std::string, std::string> &g_p) {
            auto iter = var_nodes.find(g_p.second);
            PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
                           g_p.second);
            auto shape = iter->second->Var()->GetShape();
            size_t size =
                framework::SizeOfType(iter->second->Var()->GetDataType());
            std::for_each(shape.begin(), shape.end(),
                          [&size](const int64_t &n) { size *= n; });
            local_group_memory_size += size;
          });
      group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
                       group_grads_params->at(j).end());
      ++j;
      if (local_group_memory_size >= group_memory_size) {
        break;
      }
    }
  }
  std::swap(*group_grads_params, local_group_grads_params);
  VLOG(10) << string::Sprintf("SetGroupAccordingToMemorySize(memory_size: %d):",
                              group_memory_size);
  // Debug dump of the resulting groups with per-entry sizes.
  for (size_t i = 0; i < group_grads_params->size(); ++i) {
    VLOG(10) << "group " << i;
    std::stringstream out;
    for (auto &g_p : group_grads_params->at(i)) {
      auto iter = var_nodes.find(g_p.second);
      PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
      auto shape = iter->second->Var()->GetShape();
      size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
      std::for_each(shape.begin(), shape.end(),
                    [&size](const int64_t &n) { size *= n; });
      out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
    }
    VLOG(10) << out.str();
  }
}
// Re-partitions the groups so that `group_size` of the current groups are
// merged into each final group. group_size comes from
// FLAGS_fuse_parameter_groups_size; -1 means "collapse into a single group";
// 1 disables this step entirely.
void AllocContinuousSpaceForGradPass::SetGroupAccordingToGroupSize(
    const std::unordered_map<std::string, ir::Node *> &var_nodes,
    GroupGradsAndParams *group_grads_params) const {
  if (GetFuseParameterGroupsSize() == 1) {
    return;
  }
  const int group_size = GetFuseParameterGroupsSize() == -1
                             ? static_cast<int>(group_grads_params->size())
                             : GetFuseParameterGroupsSize();
  // NOTE(review): with flag == -1 and exactly one input group, group_size
  // becomes 1 and this enforce fires — confirm callers never hit that case.
  PADDLE_ENFORCE_GT(group_size, 1);
  // Ceiling division: number of output groups.
  size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
  GroupGradsAndParams local_group_grads_params;
  local_group_grads_params.reserve(groups);
  size_t j = 0;
  for (size_t i = 0; i < groups; ++i) {
    local_group_grads_params.emplace_back();
    auto &group_p_g = local_group_grads_params.back();
    group_p_g.reserve(group_size);
    while (j < group_grads_params->size()) {
      group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
                       group_grads_params->at(j).end());
      ++j;
      // j counts consumed input groups; close this output group after every
      // group_size inputs (j starts at 0, so the modulo test is exact).
      if (j % group_size == 0) break;
    }
  }
  std::swap(*group_grads_params, local_group_grads_params);
  VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
                              group_size);
  for (size_t i = 0; i < group_grads_params->size(); ++i) {
    VLOG(10) << "group " << i;
    std::stringstream out;
    for (auto &p_g : group_grads_params->at(i)) {
      out << "(" << p_g.second << ", " << p_g.first << "), ";
    }
    VLOG(10) << out.str();
  }
}
// Only dense LoD tensors can be fused into the contiguous buffer for now.
bool AllocContinuousSpaceForGradPass::IsSupportedVarType(
    const proto::VarType::Type &type) const {
  const bool is_lod_tensor = (type == proto::VarType::LOD_TENSOR);
  return is_lod_tensor;
}
// Appends the (param, grad) name pairs attached to a backward op node into
// `params_grads`. Ops that are not backward ops are skipped; ops whose role
// attributes cannot be read with the expected types are silently ignored
// (boost::bad_get is caught below).
void AllocContinuousSpaceForGradPass::RecordParamsAndGrads(
    ir::Node *node, ParamsAndGrads *params_grads) const {
  try {
    bool is_bk_op =
        static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                              OpProtoAndCheckerMaker::OpRoleAttrName())) &
                          static_cast<int>(OpRole::kBackward));
    if (!is_bk_op) return;

    // Currently, we assume that once gradient is generated, it can be
    // broadcast, and each gradient is only broadcast once.
    auto backward_vars =
        boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
            OpProtoAndCheckerMaker::OpRoleVarAttrName()));
    // The role-var attribute stores flattened (param, grad) pairs.
    PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));

    for (size_t i = 0; i < backward_vars.size(); i += 2) {
      VLOG(10) << "Trainable parameter: " << backward_vars[i]
               << ", gradient: " << backward_vars[i + 1];
      params_grads->emplace_back(std::make_pair(backward_vars[i] /*param*/,
                                                backward_vars[i + 1] /*grad*/));
    }
  } catch (const boost::bad_get &) {
    // Fix: catch by const reference instead of by value (the original copied
    // the exception object). An op without the expected attrs is not an
    // error — it simply contributes no pairs.
  }
}
// Creates the fused variable and every gradient variable in each local scope,
// then builds a one-op program (alloc_continuous_space) and runs it once per
// (scope, place) so the gradients share one contiguous memory block.
void AllocContinuousSpaceForGradPass::InitFusedVarsAndAllocSpaceForVars(
    const std::vector<platform::Place> &places,
    const std::vector<Scope *> &local_scopes,
    const std::unordered_map<std::string, ir::Node *> &vars,
    const std::string &fused_var_name,
    const ParamsAndGrads &params_grads) const {
  // Init Gradients and FusedVars
  VLOG(10) << "Init FusedVars and Gradients.";
  // Scopes are walked in reverse order; the fused var must not already exist
  // in any of them.
  for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
    auto &scope = *it;
    PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
                   "%s has existed in scope.", fused_var_name);
    scope->Var(fused_var_name)->GetMutable<LoDTensor>();
    for (auto &p_g : params_grads) {
      auto iter = vars.find(p_g.second);
      PADDLE_ENFORCE(iter != vars.end());
      PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
      PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
                        proto::VarType::LOD_TENSOR);
      // Ensure each gradient var exists in the scope as a LoDTensor.
      scope->Var(p_g.second)->GetMutable<LoDTensor>();
    }
  }
  // Alloc continuous space for vars.
  std::vector<std::string> grads_name;
  std::vector<std::string> params_name;
  grads_name.reserve(params_grads.size());
  params_name.reserve(params_grads.size());
  for (auto &p_g : params_grads) {
    params_name.emplace_back(p_g.first);
    grads_name.emplace_back(p_g.second);
  }
  // Build the single-op program and execute it directly in every scope/place.
  framework::ProgramDesc program_desc;
  AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
                            program_desc.MutableBlock(0));
  for (size_t i = 0; i < local_scopes.size(); ++i) {
    for (auto &op_desc : program_desc.Block(0).AllOps()) {
      auto op = OpRegistry::CreateOp(*op_desc);
      op->Run(*local_scopes[i], places[i]);
    }
  }
}
// Appends a single "alloc_continuous_space" op to global_block, wiring the
// parameter names as Input, the gradient names as Output, and the fused
// variable as FusedOutput.
void AllocContinuousSpaceForGradPass::AppendAllocSpaceForVarsOp(
    const std::vector<std::string> &params_name,
    const std::vector<std::string> &grads_name,
    const std::string &fused_var_name, BlockDesc *global_block) const {
  auto *alloc_op = global_block->AppendOp();
  alloc_op->SetType("alloc_continuous_space");
  alloc_op->SetInput("Input", params_name);
  alloc_op->SetOutput("Output", grads_name);
  alloc_op->SetOutput("FusedOutput", {fused_var_name});
}
} // namespace details
} // namespace framework
} // namespace paddle
// Register the pass; kPlaces and kLocalScopes must be supplied as pass
// attributes by the caller before the pass is applied.
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
              paddle::framework::details::AllocContinuousSpaceForGradPass)
    .RequirePassAttr(paddle::framework::details::kPlaces)
    .RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
// Test-only accessors for the fuse-parameter flags (defined in the .cc file):
// unit tests cannot set the FLAGS_* variables directly.
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();

void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();

// Graph pass that fuses all parameter gradients into one contiguous memory
// chunk. Requires the kPlaces and kLocalScopes pass attributes.
class AllocContinuousSpaceForGradPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override;

  // Erases attr_name from the graph (if present) and installs a fresh
  // default-constructed AttrType.
  template <typename AttrType>
  void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const;

  // Partitions params_grads into groups of (grad, param) pairs: by layer,
  // then by memory budget, then by group count.
  void SetGroupGradsAndParams(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const ParamsAndGrads &params_grads,
      GroupGradsAndParams *group_grads_params) const;

  // Groups by the '.'-separated layer prefix of the parameter name.
  void SetGroupAccordingToLayers(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const ParamsAndGrads &params_grads,
      GroupGradsAndParams *group_grads_params) const;

  // Merges groups until each reaches FLAGS_fuse_parameter_memory_size.
  void SetGroupAccordingToMemorySize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      GroupGradsAndParams *group_grads_params) const;

  // Merges every FLAGS_fuse_parameter_groups_size groups into one.
  void SetGroupAccordingToGroupSize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      GroupGradsAndParams *group_grads_params) const;

 private:
  // Only LOD_TENSOR gradients may be fused.
  bool IsSupportedVarType(const proto::VarType::Type &type) const;

  // Collects (param, grad) name pairs from a backward op node.
  void RecordParamsAndGrads(ir::Node *node, ParamsAndGrads *params_grads) const;

  // Creates vars in every scope and runs alloc_continuous_space per place.
  void InitFusedVarsAndAllocSpaceForVars(
      const std::vector<platform::Place> &places,
      const std::vector<Scope *> &local_scopes,
      const std::unordered_map<std::string, ir::Node *> &vars,
      const std::string &fused_var_name,
      const ParamsAndGrads &params_grads) const;

  // Appends a single alloc_continuous_space op to global_block.
  void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
                                 const std::vector<std::string> &grads_name,
                                 const std::string &fused_var_name,
                                 BlockDesc *global_block) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -17,15 +17,14 @@ limitations under the License. */ ...@@ -17,15 +17,14 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -173,10 +172,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -173,10 +172,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
const std::string graph_path = const std::string graph_path =
string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(), string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
"_multi_devices_graph"); "_multi_devices_graph");
multi_devices_print_pass->Set<std::string>(kGraphvizPath, multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
new std::string(graph_path)); new std::string(graph_path));
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>( multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter); "graph_printer", new ir::GraphvizSSAGraphPrinter);
} }
// experimental shows that the program will be faster if append // experimental shows that the program will be faster if append
...@@ -240,7 +239,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy( ...@@ -240,7 +239,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
} }
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const { bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0; return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
} }
ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...@@ -263,13 +262,13 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -263,13 +262,13 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
if (IsMultiDevPass(pass->Type())) { if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLossVarName); pass->Erase(ir::kLossVarName);
pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name); pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
pass->Erase(kLocalScopes); pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes, pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes); &local_scopes);
pass->Erase(kNRanks); pass->Erase(ir::kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks)); pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
...@@ -312,8 +311,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, ...@@ -312,8 +311,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
continue; continue;
} }
} else if (pass->Type() == "inplace_pass") { } else if (pass->Type() == "inplace_pass") {
pass->Erase(kUseCuda); pass->Erase(ir::kUseCuda);
pass->Set<bool>(kUseCuda, new bool(use_cuda)); pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
} }
VLOG(3) << "Start Apply Pass " << pass->Type(); VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph); graph = pass->Apply(graph);
......
...@@ -31,7 +31,7 @@ namespace details { ...@@ -31,7 +31,7 @@ namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle( EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place, ir::Node *node, const Scope *scope, const platform::Place &place,
const std::unordered_set<std::string> &var_names, GarbageCollector *gc, const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts) ir::AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node), : OpHandleBase(node),
scope_(scope), scope_(scope),
var_names_(var_names.begin(), var_names.end()), var_names_(var_names.begin(), var_names.end()),
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -34,7 +34,7 @@ class EagerDeletionOpHandle : public OpHandleBase { ...@@ -34,7 +34,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
const platform::Place &place, const platform::Place &place,
const std::unordered_set<std::string> &var_names, const std::unordered_set<std::string> &var_names,
GarbageCollector *gc, GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts); ir::AtomicReferenceCountMap *ref_cnts);
~EagerDeletionOpHandle(); ~EagerDeletionOpHandle();
...@@ -56,7 +56,7 @@ class EagerDeletionOpHandle : public OpHandleBase { ...@@ -56,7 +56,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
const Scope *scope_; const Scope *scope_;
std::vector<std::string> var_names_; std::vector<std::string> var_names_;
GarbageCollector *gc_; // not own GarbageCollector *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own ir::AtomicReferenceCountMap *ref_cnts_; // not own
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr}; platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr}; cudaEvent_t event_{nullptr};
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
// Graph pass that rewrites how op handles acquire the device lock and record
// events (presumably relaxing locking where ops cannot conflict — the rewrite
// logic lives in the matching .cc file; this header only declares the pass
// type so it can be registered).
class ModifyOpLockAndRecordEventPass : public ir::Pass {
 protected:
  // Mutates `graph` in place; invoked by the pass framework via Pass::Apply.
  void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
// Two op descs are considered the same operator when their type, input map
// and output map all match exactly.
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
  if (op1->Type() != op2->Type()) {
    return false;
  }
  return op1->Inputs() == op2->Inputs() && op1->Outputs() == op2->Outputs();
}
// Forces sequential execution of the graph: walks the ops in the order of the
// stale program's op descs and chains each consecutive pair with a control
// dependency variable, so the executor cannot schedule them concurrently.
// Some distributed ops are excluded from the chain (see FIXME below).
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
  // FIXME(zjl): Insert dependencies between some distributed ops may cause
  // the multi_devices_graph_pass fails. So we skip these ops here.
  // Indeed, maybe we should not insert dependencies between these ops
  // casually, which may cause deadlock easily.
  // We should add more skipped distributed ops when found errors in
  // multi_devices_graph_pass
  static std::unordered_set<std::string> skip_dist_ops{
      "send", "recv", "send_barrier", "fetch_barrier"};

  // Op descs of the original (stale) program; they define the target order.
  auto &ops = graph->Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
  std::vector<ir::Node *> op_node_list;
  op_node_list.reserve(ops.size());

  // op_deps[n]   : number of distinct op nodes that must run before n.
  // pending_ops[n]: op nodes that are waiting on n.
  // ready_ops    : op nodes with no unsatisfied predecessors.
  std::unordered_map<ir::Node *, size_t> op_deps;
  std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
  std::unordered_set<ir::Node *> ready_ops;

  // Phase 1: build the dependency counts from the var nodes between ops.
  for (ir::Node *node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    std::unordered_set<ir::Node *> preceding_ops;
    for (auto *in : node->inputs) {
      PADDLE_ENFORCE(in->IsVar(),
                     "Preceding Node of Op Nodes must be Var Node");
      // A var node with no inputs is a graph input, not produced by an op.
      if (in->inputs.empty()) continue;
      PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
                     "Preceding Op Node of Var Node must be unique");
      preceding_ops.insert(in->inputs[0]);
      pending_ops[in->inputs[0]].insert(node);
    }
    op_deps[node] = preceding_ops.size();
    if (preceding_ops.empty()) {
      ready_ops.insert(node);
    }
  }

  // Phase 2: consume the op descs in program order; each desc must match a
  // unique node in the current ready set, which guarantees the program order
  // is a valid topological order of the graph.
  for (auto *op_desc : ops) {
    ir::Node *found_node = nullptr;
    for (auto *node : ready_ops) {
      if (IsSameOpDesc(op_desc, node->Op())) {
        PADDLE_ENFORCE(found_node == nullptr,
                       "Found multiple op_desc in graph: %s", op_desc->Type());
        found_node = node;
      }
    }
    PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
                            op_desc->Type());
    // Release the successors of the matched node.
    for (auto *pending_op : pending_ops[found_node]) {
      if (--op_deps.at(pending_op) == 0) {
        ready_ops.insert(pending_op);
      }
    }
    ready_ops.erase(found_node);
    // Distributed ops are executed but excluded from the dependency chain.
    if (skip_dist_ops.count(op_desc->Type()) == 0) {
      op_node_list.push_back(found_node);
    }
  }

  // Phase 3: thread a control-dep var between each consecutive pair so the
  // scheduler sees op i depend on op i-1.
  for (size_t i = 1; i < op_node_list.size(); ++i) {
    auto *dep_var = graph->CreateControlDepVar();
    op_node_list[i]->inputs.push_back(dep_var);
    op_node_list[i - 1]->outputs.push_back(dep_var);
    dep_var->outputs.push_back(op_node_list[i]);
    dep_var->inputs.push_back(op_node_list[i - 1]);
    VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
             << " and " << op_node_list[i]->Name();
  }
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
// Graph pass that serializes op execution: it chains the graph's op nodes
// with control-dependency vars following the stale program's op order, so
// the executor runs them one at a time (see sequential_execution_pass.cc).
class SequentialExecutionPass : public ir::Pass {
 protected:
  // Mutates `graph` in place; invoked by the pass framework via Pass::Apply.
  void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -33,7 +33,7 @@ namespace framework { ...@@ -33,7 +33,7 @@ namespace framework {
std::unique_ptr<ir::Pass> CreateInplacePass() { std::unique_ptr<ir::Pass> CreateInplacePass() {
auto pass = ir::PassRegistry::Instance().Get("inplace_pass"); auto pass = ir::PassRegistry::Instance().Get("inplace_pass");
pass->Set<bool>(details::kUseCuda, new bool(true)); pass->Set<bool>(ir::kUseCuda, new bool(true));
return pass; return pass;
} }
...@@ -225,7 +225,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) { ...@@ -225,7 +225,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
FakeSuccData(&prog); FakeSuccData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>()); g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g)); g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op"); auto op_node = GetNodeFromGraph(g.get(), "single_op");
...@@ -241,7 +241,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) { ...@@ -241,7 +241,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
FakeNoInplaceData(&prog); FakeNoInplaceData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>()); g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g)); g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op"); auto op_node = GetNodeFromGraph(g.get(), "single_op");
...@@ -274,7 +274,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) { ...@@ -274,7 +274,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024}); prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>()); g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass(); auto pass = CreateInplacePass();
pass->Apply(g.get()); pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op"); auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
...@@ -310,7 +310,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) { ...@@ -310,7 +310,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024}); prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog)); std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>()); g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass(); auto pass = CreateInplacePass();
pass->Apply(g.get()); pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad"); auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
......
...@@ -3,6 +3,9 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList ...@@ -3,6 +3,9 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList
file(APPEND ${pass_file} "\#pragma once\n") file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
# Usage: pass_library(target inference) will append to paddle_inference_pass.h # Usage: pass_library(target inference) will append to paddle_inference_pass.h
unset(INFER_IR_PASSES CACHE) # clear the global variable unset(INFER_IR_PASSES CACHE) # clear the global variable
...@@ -34,7 +37,6 @@ function(pass_library TARGET DEST) ...@@ -34,7 +37,6 @@ function(pass_library TARGET DEST)
endif() endif()
endfunction() endfunction()
cc_library(node SRCS node.cc DEPS proto_desc) cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log) cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
...@@ -43,6 +45,8 @@ cc_library(graph_traits SRCS graph_traits.cc DEPS graph) ...@@ -43,6 +45,8 @@ cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
pass_library(graph_to_program_pass base) pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base) pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base) pass_library(lock_free_optimize_pass base)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is up limited memory size "
"of one group parameters' gradient which is the input "
"of communication calling(e.g NCCLAllReduce). "
"The default value is 0, it means that "
"not set group according to memory_size.");
DEFINE_int32(
fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the size of one group parameters' gradient. "
"The default value is a experimental result. If the "
"fuse_parameter_groups_size is 1, it means that the groups size is "
"the number of parameters' gradient. If the fuse_parameter_groups_size is "
"-1, it means that there are only one group. The default value is 3, it is "
"an experimental value.");
namespace paddle {
namespace framework {
namespace ir {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// test, because it is invalid that seting 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' in unit test.
// Test-only setter for FLAGS_fuse_parameter_groups_size (the number of
// parameter gradients fused into one group).
void SetFuseParameterGroupsSize(int group_size) {
  FLAGS_fuse_parameter_groups_size = group_size;
}

// Returns the current value of FLAGS_fuse_parameter_groups_size.
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }

// Test-only setter for FLAGS_fuse_parameter_memory_size (the memory-size
// upper bound, in bytes, of one group of fused gradients).
void SetFuseParameterMemorySize(uint64_t memory_size) {
  FLAGS_fuse_parameter_memory_size = memory_size;
}

// Returns the current value of FLAGS_fuse_parameter_memory_size.
uint64_t GetFuseParameterMemorySize() {
  return FLAGS_fuse_parameter_memory_size;
}
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
// Pass that collects every (parameter, gradient) pair produced by backward
// ops, groups the gradients (first by layer-name prefix, then by memory
// size, then by group count), and allocates one continuous memory chunk for
// all of them by running an `alloc_continuous_space` op in every local
// scope. The fused variable name is published via details::kFusedGrads so
// later passes (e.g. fuse_optimizer_op_pass) can reuse it.
//
// Fixes vs. previous revision: ApplyImpl is marked `override`, and
// boost::bad_get is caught by const reference instead of by value.
class AllocContinuousSpaceForGradPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override {
    ir::Graph &result = *graph;

    auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
    auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);

    ResetAttribute<details::ParamsAndGrads>(details::kParamsAndGrads, &result);
    ResetAttribute<details::GroupGradsAndParams>(details::kGroupGradsAndParams,
                                                 &result);

    // NOTE: The operator nodes should be in topology order.
    std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
    auto &params_grads =
        result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
    for (auto &node : topo_nodes) {
      RecordParamsAndGrads(node, &params_grads);
    }

    if (params_grads.size() == 0) {
      VLOG(10) << "Doesn't find gradients";
      return;
    }

    std::unordered_map<std::string, ir::Node *> vars;
    for (ir::Node *node : result.Nodes()) {
      if (node->IsVar() && node->Var()) {
        // Note: The graph may have the same name node. For example, parameter
        // is the input of operator and it also is the output of optimizer;
        vars.emplace(node->Var()->Name(), node);
      }
    }

    auto &group_grads_params =
        result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);

    // Note: the order of params_grads may be changed by
    // SetGroupGradsAndParams.
    SetGroupGradsAndParams(vars, params_grads, &group_grads_params);

    // Flatten the groups back into params_grads. Each group is prepended, so
    // the group order ends up reversed relative to group_grads_params; the
    // subsequent swap restores (param, grad) pair order within each entry.
    params_grads.clear();
    for (auto &group_p_g : group_grads_params) {
      params_grads.insert(params_grads.begin(), group_p_g.begin(),
                          group_p_g.end());
    }
    for (auto &p_g : params_grads) {
      std::swap(p_g.first, p_g.second);
    }

    // Set Gradients as Persistable to prevent this var becoming reusable.
    auto dtype = kDefaultDtype;
    for (auto &p_g : params_grads) {
      // Get gradient var
      auto iter = vars.find(p_g.second);
      PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
      iter->second->Var()->SetPersistable(true);

      PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));

      // Get Dtype: all fused gradients must share one data type. The first
      // gradient fixes `dtype`; kDefaultDtype (bool) is used as the "unset"
      // sentinel and is itself rejected.
      auto ele_dtype = iter->second->Var()->GetDataType();
      if (dtype == kDefaultDtype) {
        dtype = ele_dtype;
        PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
                          "The data type should not be bool.");
      }
      PADDLE_ENFORCE_EQ(ele_dtype, dtype,
                        "The data type of input is not consistent.");
    }

    // Create a FusedVarsSet to avoid duplicating names for fused_var in other
    // pass.
    if (!result.Has(details::kFusedVars)) {
      result.Set(details::kFusedVars, new details::FusedVars);
    }
    // the kFusedGrads is used be fuse_optimizer_op_pass.
    result.Set(details::kFusedGrads, new details::FusedGrads);

    // the fused_var_name should be unique, so it appends
    // params_grads.begin()->second.
    auto fused_var_name = std::string(details::kFusedVarNamePrefix) + "@GRAD@" +
                          params_grads.begin()->second;
    result.Get<details::FusedGrads>(details::kFusedGrads) = fused_var_name;
    auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
    PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
                      "%s is duplicate in FusedVars.", fused_var_name);
    fused_var_set.insert(fused_var_name);

    InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
                                      fused_var_name, params_grads);
  }

  // Drops any existing graph attribute named `attr_name` and installs a
  // fresh default-constructed one, so stale results from a previous run of
  // this pass cannot leak through.
  template <typename AttrType>
  void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const {
    if (graph->Has(attr_name)) {
      VLOG(10) << attr_name << " is reset.";
      graph->Erase(attr_name);
    }
    graph->Set(attr_name, new AttrType);
  }

  // Applies the three grouping strategies in order: layer prefix, memory
  // size, then group count. Output entries are (grad, param) pairs.
  void SetGroupGradsAndParams(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const details::ParamsAndGrads &params_grads,
      details::GroupGradsAndParams *group_grads_params) const {
    SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
    SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
    SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
  }

  // Buckets (param, grad) pairs by the parameter-name prefix before the
  // first '.'; parameters without a '.' fall into the kUnKnow bucket.
  // Each bucket becomes one group of (grad, param) pairs, emitted in the
  // order the buckets are first encountered in params_grads.
  void SetGroupAccordingToLayers(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const details::ParamsAndGrads &params_grads,
      details::GroupGradsAndParams *group_grads_params) const {
    std::unordered_map<std::string, std::vector<int>> layer_params;

    for (size_t i = 0; i < params_grads.size(); ++i) {
      auto pos = params_grads[i].first.find_first_of(".");
      if (pos == std::string::npos) {
        layer_params[std::string(kUnKnow)].emplace_back(i);
      } else {
        layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
      }
    }

    group_grads_params->reserve(layer_params.size());
    for (size_t i = 0; i < params_grads.size(); ++i) {
      auto pos = params_grads[i].first.find_first_of(".");
      std::string key = kUnKnow;
      if (pos != std::string::npos) {
        key = params_grads[i].first.substr(0, pos);
      }
      auto iter = layer_params.find(key);
      // Each bucket is emitted only once (erased below after emission).
      if (iter == layer_params.end()) continue;

      group_grads_params->emplace_back();
      auto &local_group_grads_params = group_grads_params->back();
      for (auto &idx : iter->second) {
        local_group_grads_params.emplace_back(
            std::make_pair(params_grads[idx].second, params_grads[idx].first));
      }
      layer_params.erase(iter);
    }

    VLOG(10) << "SetGroupAccordingToLayers: ";
    for (size_t i = 0; i < group_grads_params->size(); ++i) {
      VLOG(10) << "group " << i;
      std::stringstream out;
      for (auto &p_g : group_grads_params->at(i)) {
        out << "(" << p_g.second << ", " << p_g.first << "), ";
      }
      VLOG(10) << out.str();
    }
  }

  // Merges consecutive layer groups until each merged group's accumulated
  // gradient memory reaches FLAGS_fuse_parameter_memory_size. A value of 0
  // disables this step.
  void SetGroupAccordingToMemorySize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      details::GroupGradsAndParams *group_grads_params) const {
    const uint64_t group_memory_size = GetFuseParameterMemorySize();
    if (group_memory_size == 0) {
      return;
    }
    details::GroupGradsAndParams local_group_grads_params;
    size_t j = 0;
    while (j < group_grads_params->size()) {
      local_group_grads_params.emplace_back();
      auto &group_p_g = local_group_grads_params.back();
      size_t local_group_memory_size = 0;
      while (j < group_grads_params->size()) {
        std::for_each(
            group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
            [&local_group_memory_size,
             &var_nodes](const std::pair<std::string, std::string> &g_p) {
              auto iter = var_nodes.find(g_p.second);
              PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
                             g_p.second);
              auto shape = iter->second->Var()->GetShape();
              size_t size =
                  framework::SizeOfType(iter->second->Var()->GetDataType());
              std::for_each(shape.begin(), shape.end(),
                            [&size](const int64_t &n) { size *= n; });
              local_group_memory_size += size;
            });
        group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
                         group_grads_params->at(j).end());
        ++j;
        if (local_group_memory_size >= group_memory_size) {
          break;
        }
      }
    }
    std::swap(*group_grads_params, local_group_grads_params);

    VLOG(10) << string::Sprintf(
        "SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
    for (size_t i = 0; i < group_grads_params->size(); ++i) {
      VLOG(10) << "group " << i;
      std::stringstream out;
      for (auto &g_p : group_grads_params->at(i)) {
        auto iter = var_nodes.find(g_p.second);
        PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
        auto shape = iter->second->Var()->GetShape();
        size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
        std::for_each(shape.begin(), shape.end(),
                      [&size](const int64_t &n) { size *= n; });
        out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
      }
      VLOG(10) << out.str();
    }
  }

  // Merges `group_size` consecutive groups into one. A flag value of 1
  // disables this step; -1 merges everything into a single group.
  void SetGroupAccordingToGroupSize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      details::GroupGradsAndParams *group_grads_params) const {
    if (GetFuseParameterGroupsSize() == 1) {
      return;
    }
    const int group_size = GetFuseParameterGroupsSize() == -1
                               ? static_cast<int>(group_grads_params->size())
                               : GetFuseParameterGroupsSize();
    PADDLE_ENFORCE_GT(group_size, 1);
    size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
    details::GroupGradsAndParams local_group_grads_params;
    local_group_grads_params.reserve(groups);

    size_t j = 0;
    for (size_t i = 0; i < groups; ++i) {
      local_group_grads_params.emplace_back();
      auto &group_p_g = local_group_grads_params.back();
      group_p_g.reserve(group_size);
      while (j < group_grads_params->size()) {
        group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
                         group_grads_params->at(j).end());
        ++j;
        if (j % group_size == 0) break;
      }
    }
    std::swap(*group_grads_params, local_group_grads_params);
    VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
                                group_size);
    for (size_t i = 0; i < group_grads_params->size(); ++i) {
      VLOG(10) << "group " << i;
      std::stringstream out;
      for (auto &p_g : group_grads_params->at(i)) {
        out << "(" << p_g.second << ", " << p_g.first << "), ";
      }
      VLOG(10) << out.str();
    }
  }

 private:
  bool IsSupportedVarType(const proto::VarType::Type &type) const {
    // Current only support LOD_TENSOR.
    return type == proto::VarType::LOD_TENSOR;
  }

  // If `node` is a backward op, appends its (param, grad) pairs (taken from
  // the OpRoleVar attribute) to `params_grads`. Ops without the role
  // attributes are silently skipped via the bad_get catch.
  void RecordParamsAndGrads(ir::Node *node,
                            details::ParamsAndGrads *params_grads) const {
    try {
      bool is_bk_op =
          static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
                            static_cast<int>(OpRole::kBackward));
      if (!is_bk_op) return;

      // Currently, we assume that once gradient is generated, it can be
      // broadcast, and each gradient is only broadcast once.
      auto backward_vars =
          boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
              OpProtoAndCheckerMaker::OpRoleVarAttrName()));
      PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));

      for (size_t i = 0; i < backward_vars.size(); i += 2) {
        VLOG(10) << "Trainable parameter: " << backward_vars[i]
                 << ", gradient: " << backward_vars[i + 1];

        params_grads->emplace_back(std::make_pair(
            backward_vars[i] /*param*/, backward_vars[i + 1] /*grad*/));
      }
    } catch (const boost::bad_get &) {
      // Attribute missing or of an unexpected type: not a backward op we
      // care about, so skip it.
    }
  }

  // Creates the fused var and every gradient var in each local scope, then
  // runs an `alloc_continuous_space` op per place to back them with one
  // continuous chunk of memory.
  void InitFusedVarsAndAllocSpaceForVars(
      const std::vector<platform::Place> &places,
      const std::vector<Scope *> &local_scopes,
      const std::unordered_map<std::string, ir::Node *> &vars,
      const std::string &fused_var_name,
      const details::ParamsAndGrads &params_grads) const {
    // Init Gradients and FusedVars
    VLOG(10) << "Init FusedVars and Gradients.";
    for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
      auto &scope = *it;

      PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
                     "%s has existed in scope.", fused_var_name);
      scope->Var(fused_var_name)->GetMutable<LoDTensor>();

      for (auto &p_g : params_grads) {
        auto iter = vars.find(p_g.second);
        PADDLE_ENFORCE(iter != vars.end());
        PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
        PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
                          proto::VarType::LOD_TENSOR);
        scope->Var(p_g.second)->GetMutable<LoDTensor>();
      }
    }

    // Alloc continuous space for vars.
    std::vector<std::string> grads_name;
    std::vector<std::string> params_name;
    grads_name.reserve(params_grads.size());
    params_name.reserve(params_grads.size());
    for (auto &p_g : params_grads) {
      params_name.emplace_back(p_g.first);
      grads_name.emplace_back(p_g.second);
    }
    framework::ProgramDesc program_desc;
    AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
                              program_desc.MutableBlock(0));

    for (size_t i = 0; i < local_scopes.size(); ++i) {
      for (auto &op_desc : program_desc.Block(0).AllOps()) {
        auto op = OpRegistry::CreateOp(*op_desc);
        op->Run(*local_scopes[i], places[i]);
      }
    }
  }

  // Appends one `alloc_continuous_space` op desc to `global_block` that maps
  // `params_name` -> `grads_name` and emits the fused var.
  void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
                                 const std::vector<std::string> &grads_name,
                                 const std::string &fused_var_name,
                                 BlockDesc *global_block) const {
    auto op_desc = global_block->AppendOp();
    op_desc->SetType("alloc_continuous_space");
    op_desc->SetInput("Input", params_name);
    op_desc->SetOutput("Output", grads_name);
    op_desc->SetOutput("FusedOutput", {fused_var_name});
  }
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
paddle::framework::ir::AllocContinuousSpaceForGradPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -11,21 +11,19 @@ ...@@ -11,21 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <algorithm>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
class ReferenceCountPass : public ir::Pass { void SetFuseParameterMemorySize(uint64_t memory_size);
protected: uint64_t GetFuseParameterMemorySize();
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
cc_library(fuse_optimizer_op_pass SRCS fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc DEPS fuse_optimizer_op_pass)
...@@ -16,16 +16,13 @@ ...@@ -16,16 +16,13 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h" #include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class FuseAdamOpPass : public FuseOptimizerOpPass { class FuseAdamOpPass : public FuseOptimizerOpPass {
private: private:
...@@ -203,10 +200,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { ...@@ -203,10 +200,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
} }
} }
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::details::FuseAdamOpPass) REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces) .RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes); .RequirePassAttr(paddle::framework::details::kLocalScopes);
...@@ -16,14 +16,13 @@ ...@@ -16,14 +16,13 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h" #include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class FuseMomentumOpPass : public FuseOptimizerOpPass { class FuseMomentumOpPass : public FuseOptimizerOpPass {
private: private:
...@@ -84,11 +83,10 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass { ...@@ -84,11 +83,10 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
} }
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(fuse_momentum_op_pass, REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass)
paddle::framework::details::FuseMomentumOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces) .RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes); .RequirePassAttr(paddle::framework::details::kLocalScopes);
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h" #include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include <algorithm> #include <algorithm>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph; ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces); auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes); auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
const std::string fuse_op_type = GetOpType(); const std::string fuse_op_type = GetOpType();
std::vector<std::string> aux_var_names = GetAuxiliaryVarNames(); std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
...@@ -47,24 +47,24 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -47,24 +47,24 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
return; return;
} }
if (result.Has(kFusedOptType)) { if (result.Has(details::kFusedOptType)) {
VLOG(6) << "Currently only support fusing one type optimizer op. Has fused " VLOG(6) << "Currently only support fusing one type optimizer op. Has fused "
<< result.Get<FusedOptType>(kFusedOptType); << result.Get<details::FusedOptType>(details::kFusedOptType);
return; return;
} else { } else {
result.Set(kFusedOptType, new FusedOptType); result.Set(details::kFusedOptType, new details::FusedOptType);
} }
result.Get<FusedOptType>(kFusedOptType) = fuse_op_type; result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type;
// Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be // Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be
// initialized in scopes before execution. // initialized in scopes before execution.
if (!result.Has(kFusedVars)) { if (!result.Has(details::kFusedVars)) {
result.Set(kFusedVars, new FusedVars); result.Set(details::kFusedVars, new details::FusedVars);
} }
std::unordered_map<std::string, std::string> fused_vars_name; std::unordered_map<std::string, std::string> fused_vars_name;
fused_vars_name.reserve(aux_var_names.size()); fused_vars_name.reserve(aux_var_names.size());
auto &fused_var_set = result.Get<FusedVars>(kFusedVars); auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
const std::string prefix(kFusedVarNamePrefix); const std::string prefix(details::kFusedVarNamePrefix);
// NOTE: the fused_var_name should be unique. // NOTE: the fused_var_name should be unique.
for (auto &var_name : aux_var_names) { for (auto &var_name : aux_var_names) {
auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" + auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
...@@ -77,8 +77,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -77,8 +77,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// Step 3: Get the fused Gradient's name // Step 3: Get the fused Gradient's name
bool grad_fused = false; bool grad_fused = false;
if (result.Has(kParamsAndGrads)) { if (result.Has(details::kParamsAndGrads)) {
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads); auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
params_grads.size(), aux_var_set.at(kGrad).size(), params_grads.size(), aux_var_set.at(kGrad).size(),
"The number of gradients and optimizer ops is not equal."); "The number of gradients and optimizer ops is not equal.");
...@@ -94,13 +95,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { ...@@ -94,13 +95,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// NOTE(zcd): the gradient of kParamsAndGrads may be different with the // NOTE(zcd): the gradient of kParamsAndGrads may be different with the
// kGrad. // kGrad.
if (same_grad_num == aux_var_set.at(kGrad).size()) { if (same_grad_num == aux_var_set.at(kGrad).size()) {
if (!result.Has(kFusedGrads)) { if (!result.Has(details::kFusedGrads)) {
PADDLE_THROW( PADDLE_THROW(
"The alloc_continuous_space_for_grad_pass should be called before " "The alloc_continuous_space_for_grad_pass should be called before "
"this pass."); "this pass.");
} }
auto &fused_grad = result.Get<FusedGrads>(kFusedGrads); auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
auto &fused_vars = result.Get<FusedVars>(kFusedVars); auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad); auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad."); PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
fused_vars_name[kGrad] = fused_grad; fused_vars_name[kGrad] = fused_grad;
...@@ -323,6 +324,6 @@ void FuseOptimizerOpPass::InserInputAndOutputForOptOps( ...@@ -323,6 +324,6 @@ void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(), opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
outputs.end()); outputs.end());
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
constexpr char kGrad[] = "Grad"; constexpr char kGrad[] = "Grad";
constexpr char kParam[] = "Param"; constexpr char kParam[] = "Param";
...@@ -90,6 +90,6 @@ class FuseOptimizerOpPass : public ir::Pass { ...@@ -90,6 +90,6 @@ class FuseOptimizerOpPass : public ir::Pass {
const std::string &fused_var_name) const; const std::string &fused_var_name) const;
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,18 +14,13 @@ ...@@ -14,18 +14,13 @@
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class FuseSgdOpPass : public FuseOptimizerOpPass { class FuseSgdOpPass : public FuseOptimizerOpPass {
private: private:
...@@ -66,10 +61,10 @@ class FuseSgdOpPass : public FuseOptimizerOpPass { ...@@ -66,10 +61,10 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
InserInputAndOutputForOptOps(sgd_ops, sgd_node); InserInputAndOutputForOptOps(sgd_ops, sgd_node);
} }
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::details::FuseSgdOpPass) REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces) .RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes); .RequirePassAttr(paddle::framework::details::kLocalScopes);
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
...@@ -27,11 +27,11 @@ ...@@ -27,11 +27,11 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
// op -> variables which can be deleted after op runs // op -> variables which can be deleted after op runs
using OpToVarNameSetMap = using OpToVarNameSetMap = std::unordered_map<details::ComputationOpHandle *,
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>; std::unordered_set<std::string>>;
static std::map<size_t, std::unordered_set<std::string>> VarsGroupByScopeIdx( static std::map<size_t, std::unordered_set<std::string>> VarsGroupByScopeIdx(
const OpToVarNameSetMap &map) { const OpToVarNameSetMap &map) {
...@@ -53,7 +53,8 @@ static bool IsLoDTensor(VarDesc *var) { ...@@ -53,7 +53,8 @@ static bool IsLoDTensor(VarDesc *var) {
// Get memory size of LoDTensor // Get memory size of LoDTensor
static int64_t GetMemorySize( static int64_t GetMemorySize(
const std::unordered_map<std::string, std::vector<VarHandle *>> &vars, const std::unordered_map<std::string, std::vector<details::VarHandle *>>
&vars,
const std::string &var_name) { const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name)); auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc); PADDLE_ENFORCE_NOT_NULL(var_desc);
...@@ -69,13 +70,13 @@ static int64_t GetMemorySize( ...@@ -69,13 +70,13 @@ static int64_t GetMemorySize(
// Since partial GC is based on static analysis of memory size of each variable // Since partial GC is based on static analysis of memory size of each variable
// So we should skip SelectedRows and LoDTensorArray here // So we should skip SelectedRows and LoDTensorArray here
static void SplitIntoLoDTensorAndNonLoDTensorVars( static void SplitIntoLoDTensorAndNonLoDTensorVars(
const OpToVarNameSetMap &m, const GraphVars &vars, const OpToVarNameSetMap &m, const details::GraphVars &vars,
OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) { OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
lod_tensors->clear(); lod_tensors->clear();
other_vars->clear(); other_vars->clear();
for (auto &op_vars_pair : m) { for (auto &op_vars_pair : m) {
for (auto &var_name : op_vars_pair.second) { for (auto var_name : op_vars_pair.second) {
auto *var_desc = TryGetLatestVarDesc( auto *var_desc = TryGetLatestVarDesc(
vars[op_vars_pair.first->GetScopeIdx()].at(var_name)); vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
if (IsLoDTensor(var_desc)) { if (IsLoDTensor(var_desc)) {
...@@ -89,7 +90,7 @@ static void SplitIntoLoDTensorAndNonLoDTensorVars( ...@@ -89,7 +90,7 @@ static void SplitIntoLoDTensorAndNonLoDTensorVars(
struct GCVarInfo { struct GCVarInfo {
GCVarInfo(const std::string &name, int64_t memory_size, GCVarInfo(const std::string &name, int64_t memory_size,
ComputationOpHandle *op, size_t scope_idx) details::ComputationOpHandle *op, size_t scope_idx)
: name_(name), : name_(name),
memory_size_(memory_size), memory_size_(memory_size),
op_(op), op_(op),
...@@ -97,7 +98,8 @@ struct GCVarInfo { ...@@ -97,7 +98,8 @@ struct GCVarInfo {
std::string name_; // variable name std::string name_; // variable name
int64_t memory_size_; // memory size int64_t memory_size_; // memory size
ComputationOpHandle *op_; // op after which the variable could be deleted details::ComputationOpHandle
*op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates size_t scope_idx_; // scope index where the variable locates
int64_t AbsMemorySize() const { return std::abs(memory_size_); } int64_t AbsMemorySize() const { return std::abs(memory_size_); }
...@@ -105,7 +107,7 @@ struct GCVarInfo { ...@@ -105,7 +107,7 @@ struct GCVarInfo {
// Delete delete_lod_tensor_only is not used currently // Delete delete_lod_tensor_only is not used currently
static OpToVarNameSetMap ShrinkGCVars( static OpToVarNameSetMap ShrinkGCVars(
const OpToVarNameSetMap &m, const GraphVars &vars, const OpToVarNameSetMap &m, const details::GraphVars &vars,
const std::vector<platform::Place> &places, double fraction_of_memory_size, const std::vector<platform::Place> &places, double fraction_of_memory_size,
bool delete_lod_tensor_only = false) { bool delete_lod_tensor_only = false) {
// Do not perform gc when fraction_of_memory_size = 0 // Do not perform gc when fraction_of_memory_size = 0
...@@ -192,7 +194,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ...@@ -192,7 +194,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE(ref_cnts.empty(), PADDLE_ENFORCE(ref_cnts.empty(),
"kRuntimeReferenceCount should be initialized here!"); "kRuntimeReferenceCount should be initialized here!");
const auto &vars = graph->Get<GraphVars>(kGraphVars); const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
ref_cnts.resize(vars.size()); ref_cnts.resize(vars.size());
const auto &last_live_ops = const auto &last_live_ops =
...@@ -222,27 +224,31 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ...@@ -222,27 +224,31 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto *eager_deletion_node = auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation); graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle( auto *eager_deletion_op = new details::EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), var_names, eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
gcs.at(places[op->GetScopeIdx()]).get(), gcs.at(places[op->GetScopeIdx()]).get(),
&(ref_cnts[op->GetScopeIdx()])); &(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if( auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) { op->Outputs().begin(), op->Outputs().end(),
return dynamic_cast<DummyVarHandle *>(var) != nullptr; [](details::VarHandleBase *var) {
return dynamic_cast<details::DummyVarHandle *>(var) != nullptr;
}); });
if (it != op->Outputs().end()) { if (it != op->Outputs().end()) {
eager_deletion_op->AddInput(*it); eager_deletion_op->AddInput(*it);
} else { } else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); auto *dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var); graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
op->AddOutput(dep_var); op->AddOutput(dep_var);
eager_deletion_op->AddInput(dep_var); eager_deletion_op->AddInput(dep_var);
} }
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf =
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf); new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf); eager_deletion_op->AddOutput(dummy_leaf);
} }
...@@ -262,15 +268,14 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ...@@ -262,15 +268,14 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
while_op_eager_deletion_pass->Apply(graph); while_op_eager_deletion_pass->Apply(graph);
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(eager_deletion_pass, REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
paddle::framework::details::EagerDeletionPass) .RequirePassAttr(paddle::framework::ir::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount) .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars) .RequirePassAttr(paddle::framework::ir::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kAllPlaces) .RequirePassAttr(paddle::framework::ir::kGarbageCollector);
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
USE_PASS(while_op_eager_deletion_pass); USE_PASS(while_op_eager_deletion_pass);
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
#include <queue> #include <queue>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
...@@ -52,7 +52,7 @@ DECLARE_string(memory_optimize_debug); ...@@ -52,7 +52,7 @@ DECLARE_string(memory_optimize_debug);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
// clang-format off // clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT const std::string kInplacedOpWhiteList[] = { // NOLINT
...@@ -199,8 +199,8 @@ bool InplacePass::CheckOpDeps(ir::Node *op, ...@@ -199,8 +199,8 @@ bool InplacePass::CheckOpDeps(ir::Node *op,
void InplacePass::CollectSkipVars(ir::Graph *graph, void InplacePass::CollectSkipVars(ir::Graph *graph,
const std::vector<ir::Node *> &ops) const { const std::vector<ir::Node *> &ops) const {
// 1. Collect op role vars // 1. Collect op role vars
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars), PADDLE_ENFORCE(graph->Has(kMemOptSkipVars), "Graph should have attr %s",
"Graph should have attr %s", details::kMemOptSkipVars); kMemOptSkipVars);
auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars); auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto &var : mem_opt_whitelist) { for (const auto &var : mem_opt_whitelist) {
skip_vars_.emplace(var); skip_vars_.emplace(var);
...@@ -452,8 +452,7 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { ...@@ -452,8 +452,7 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue; continue;
} }
if (details::NodeSize(*in_node->Var()) != if (NodeSize(*in_node->Var()) != NodeSize(*out_node->Var()) &&
details::NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) { kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same size with " << " is not the same size with "
...@@ -476,9 +475,9 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { ...@@ -476,9 +475,9 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
} }
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass) REGISTER_PASS(inplace_pass, paddle::framework::ir::InplacePass)
.RequirePassAttr(paddle::framework::details::kUseCuda); .RequirePassAttr(paddle::framework::ir::kUseCuda);
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm> #include <algorithm>
#include <deque> #include <deque>
#include <functional> #include <functional>
...@@ -32,14 +32,15 @@ ...@@ -32,14 +32,15 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
using paddle::framework::VarDesc; using paddle::framework::VarDesc;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) { std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kStaleProgramOpDescs), PADDLE_ENFORCE(graph.Has(details::kStaleProgramOpDescs),
"Graph has no attribute of kStaleProgramOpDescs."); "Graph has no attribute of kStaleProgramOpDescs.");
// 1. get op desc order // 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs); auto& op_descs =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
// 2. topology sort order // 2. topology sort order
auto nodes = graph.Nodes(); auto nodes = graph.Nodes();
...@@ -563,6 +564,6 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name, ...@@ -563,6 +564,6 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
return found_node; return found_node;
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
/// this attribute is used to avoid some core variables removed/reused /// this attribute is used to avoid some core variables removed/reused
/// in memory optimize related passes /// in memory optimize related passes
...@@ -184,6 +184,6 @@ void FilterVariables(const Container& nodes, Callback callback) { ...@@ -184,6 +184,6 @@ void FilterVariables(const Container& nodes, Callback callback) {
FilterVariableImpl<Container, Callback>()(nodes, callback); FilterVariableImpl<Container, Callback>()(nodes, callback);
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <iterator> #include <iterator>
...@@ -32,7 +31,7 @@ ...@@ -32,7 +31,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
TEST(OrderedSet, Normal) { TEST(OrderedSet, Normal) {
OrderedSet pool; OrderedSet pool;
...@@ -153,7 +152,7 @@ TEST(OrderedSet, FindBestFitNode) { ...@@ -153,7 +152,7 @@ TEST(OrderedSet, FindBestFitNode) {
ASSERT_TRUE(cache == nullptr); ASSERT_TRUE(cache == nullptr);
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -188,7 +187,7 @@ REGISTER_OPERATOR(dummy, paddle::framework::DummyOp, ...@@ -188,7 +187,7 @@ REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
inline static ProgramDesc FillProgramDesc() { inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog; ProgramDesc prog;
...@@ -521,6 +520,6 @@ TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) { ...@@ -521,6 +520,6 @@ TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
} }
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <deque> #include <deque>
...@@ -42,12 +42,12 @@ DEFINE_string(memory_optimize_debug, "", ...@@ -42,12 +42,12 @@ DEFINE_string(memory_optimize_debug, "",
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const { void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
CollectSkipVarsSet(graph); CollectSkipVarsSet(graph);
cfg_.reset(new details::ControlFlowGraph(*graph)); cfg_.reset(new ControlFlowGraph(*graph));
cfg_->LiveVariableAnalysis(); cfg_->LiveVariableAnalysis();
InitSSAGraphNodes(); InitSSAGraphNodes();
...@@ -205,7 +205,7 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const { ...@@ -205,7 +205,7 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const { void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
// fill skip_set_ // fill skip_set_
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars)); PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars); auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto& var : mem_opt_whitelist) { for (const auto& var : mem_opt_whitelist) {
skip_set_.emplace(var); skip_set_.emplace(var);
...@@ -316,10 +316,9 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -316,10 +316,9 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
} }
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(memory_optimize_pass, REGISTER_PASS(memory_optimize_pass, paddle::framework::ir::MemoryOptimizePass)
paddle::framework::details::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs); .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
...@@ -26,13 +26,13 @@ ...@@ -26,13 +26,13 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class MemoryOptimizePass : public ir::Pass { class MemoryOptimizePass : public ir::Pass {
protected: protected:
...@@ -67,6 +67,6 @@ class MemoryOptimizePass : public ir::Pass { ...@@ -67,6 +67,6 @@ class MemoryOptimizePass : public ir::Pass {
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_; mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,17 +12,19 @@ ...@@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include <queue> #include <queue>
#include <utility> #include <utility>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); } OpGraphView::OpGraphView(const std::vector<details::OpHandleBase *> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) { void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) {
preceding_ops_.clear(); preceding_ops_.clear();
pending_ops_.clear(); pending_ops_.clear();
for (auto &op : ops) { for (auto &op : ops) {
...@@ -40,8 +42,8 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) { ...@@ -40,8 +42,8 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
"There are duplicate ops in graph."); "There are duplicate ops in graph.");
} }
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const { std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret; std::unordered_set<details::OpHandleBase *> ret;
ret.reserve(preceding_ops_.size()); ret.reserve(preceding_ops_.size());
for (auto &pair : preceding_ops_) { for (auto &pair : preceding_ops_) {
ret.insert(pair.first); ret.insert(pair.first);
...@@ -49,21 +51,21 @@ std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const { ...@@ -49,21 +51,21 @@ std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
return ret; return ret;
} }
bool OpGraphView::HasOp(OpHandleBase *op) const { bool OpGraphView::HasOp(details::OpHandleBase *op) const {
return preceding_ops_.count(op) != 0; return preceding_ops_.count(op) != 0;
} }
void OpGraphView::EnforceHasOp(OpHandleBase *op) const { void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView", PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString()); op == nullptr ? "nullptr" : op->DebugString());
} }
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps( const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const { details::OpHandleBase *op) const {
EnforceHasOp(op); EnforceHasOp(op);
return pending_ops_.at(op); return pending_ops_.at(op);
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -22,39 +22,42 @@ ...@@ -22,39 +22,42 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class OpGraphView { class OpGraphView {
public: public:
explicit OpGraphView(const std::vector<OpHandleBase *> &ops); explicit OpGraphView(const std::vector<details::OpHandleBase *> &ops);
std::unordered_set<OpHandleBase *> AllOps() const; std::unordered_set<details::OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const; const std::unordered_set<details::OpHandleBase *> &PendingOps(
details::OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const; bool HasOp(details::OpHandleBase *op) const;
// Use a visitor to visit all pending ops of op // Use a visitor to visit all pending ops of op
// Stop when callback returns false // Stop when callback returns false
template <typename Callback> template <typename Callback>
bool VisitAllPendingOps(OpHandleBase *op, Callback &&callback) const; bool VisitAllPendingOps(details::OpHandleBase *op, Callback &&callback) const;
private: private:
void Build(const std::vector<OpHandleBase *> &ops); void Build(const std::vector<details::OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const; void EnforceHasOp(details::OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>> std::unordered_map<details::OpHandleBase *,
std::unordered_set<details::OpHandleBase *>>
preceding_ops_; preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>> std::unordered_map<details::OpHandleBase *,
std::unordered_set<details::OpHandleBase *>>
pending_ops_; pending_ops_;
}; };
template <typename Callback> template <typename Callback>
bool OpGraphView::VisitAllPendingOps(OpHandleBase *op, bool OpGraphView::VisitAllPendingOps(details::OpHandleBase *op,
Callback &&callback) const { Callback &&callback) const {
EnforceHasOp(op); EnforceHasOp(op);
std::unordered_set<OpHandleBase *> visited; std::unordered_set<details::OpHandleBase *> visited;
std::queue<OpHandleBase *> q; std::queue<details::OpHandleBase *> q;
q.push(op); q.push(op);
while (!q.empty()) { while (!q.empty()) {
op = q.front(); op = q.front();
...@@ -72,6 +75,6 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op, ...@@ -72,6 +75,6 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
return true; return true;
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,16 +15,16 @@ ...@@ -15,16 +15,16 @@
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class RecordSkipMemoryOptVarsPass : public ir::Pass { class RecordSkipMemoryOptVarsPass : public ir::Pass {
protected: protected:
...@@ -162,9 +162,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass { ...@@ -162,9 +162,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
} }
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(record_skip_memory_opt_vars_pass, REGISTER_PASS(record_skip_memory_opt_vars_pass,
paddle::framework::details::RecordSkipMemoryOptVarsPass); paddle::framework::ir::RecordSkipMemoryOptVarsPass);
...@@ -24,14 +24,20 @@ ...@@ -24,14 +24,20 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class ReferenceCountPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
};
// A functor to shrink/remove operators who depend on other operators in a set // A functor to shrink/remove operators who depend on other operators in a set
class ShrinkDepsOpFunctor { class ShrinkDepsOpFunctor {
...@@ -39,19 +45,21 @@ class ShrinkDepsOpFunctor { ...@@ -39,19 +45,21 @@ class ShrinkDepsOpFunctor {
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 }; enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
public: public:
explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops) explicit ShrinkDepsOpFunctor(
const std::vector<details::OpHandleBase *> &all_ops)
: graph_(all_ops) {} : graph_(all_ops) {}
template <typename OpSet> template <typename OpSet>
OpSet operator()(const OpSet &op_set) const { OpSet operator()(const OpSet &op_set) const {
using KeyType = typename OpSet::key_type; using KeyType = typename OpSet::key_type;
static_assert( static_assert(
std::is_base_of<OpHandleBase, std::is_base_of<details::OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value, typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be OpHandleBase, or derived of OpHandleBase"); "Key type of OpSet must be details::OpHandleBase, or derived of "
"details::OpHandleBase");
if (op_set.size() <= 1) return op_set; if (op_set.size() <= 1) return op_set;
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end()); std::vector<details::OpHandleBase *> ops(op_set.begin(), op_set.end());
OpSet ret; OpSet ret;
auto rels = GetRelations(ops); auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; }; auto not_before = [](RelationShip r) { return r != kBefore; };
...@@ -65,8 +73,8 @@ class ShrinkDepsOpFunctor { ...@@ -65,8 +73,8 @@ class ShrinkDepsOpFunctor {
private: private:
std::vector<std::vector<RelationShip>> GetRelations( std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> &ops) const { const std::vector<details::OpHandleBase *> &ops) const {
std::unordered_map<OpHandleBase *, size_t> op_to_idx; std::unordered_map<details::OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) { for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph"); PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
op_to_idx[ops[i]] = i; op_to_idx[ops[i]] = i;
...@@ -81,7 +89,7 @@ class ShrinkDepsOpFunctor { ...@@ -81,7 +89,7 @@ class ShrinkDepsOpFunctor {
size_t found_num = ops.size(); size_t found_num = ops.size();
size_t total_num = ops.size() * ops.size(); size_t total_num = ops.size() * ops.size();
auto visitor = [&](OpHandleBase *op, size_t i) { auto visitor = [&](details::OpHandleBase *op, size_t i) {
auto it = op_to_idx.find(op); auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) { if (it != op_to_idx.end()) {
size_t j = it->second; size_t j = it->second;
...@@ -98,7 +106,9 @@ class ShrinkDepsOpFunctor { ...@@ -98,7 +106,9 @@ class ShrinkDepsOpFunctor {
}; };
for (size_t i = 0; i < ops.size(); ++i) { for (size_t i = 0; i < ops.size(); ++i) {
auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); }; auto sub_visitor = [&, i](details::OpHandleBase *op) {
return visitor(op, i);
};
if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) { if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
break; break;
} }
...@@ -133,8 +143,8 @@ class ShrinkDepsOpFunctor { ...@@ -133,8 +143,8 @@ class ShrinkDepsOpFunctor {
*/ */
static bool ShrinkNoNeedBufferVarOpDependency( static bool ShrinkNoNeedBufferVarOpDependency(
const std::string &var_name, const std::string &var_name,
std::unordered_set<ComputationOpHandle *> *op_handles) { std::unordered_set<details::ComputationOpHandle *> *op_handles) {
std::vector<ComputationOpHandle *> skip_ops; std::vector<details::ComputationOpHandle *> skip_ops;
for (auto *op_handle : *op_handles) { for (auto *op_handle : *op_handles) {
auto *op_base = op_handle->GetOp(); auto *op_base = op_handle->GetOp();
auto &inferer = op_base->Info().NoNeedBufferVarsInferer(); auto &inferer = op_base->Info().NoNeedBufferVarsInferer();
...@@ -195,15 +205,15 @@ static bool ShrinkNoNeedBufferVarOpDependency( ...@@ -195,15 +205,15 @@ static bool ShrinkNoNeedBufferVarOpDependency(
* Find the nearest downstream computation op handle. If the op is a * Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself. * computation op, just return itself.
*/ */
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( static details::ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) { details::OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q; std::queue<details::OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited; std::unordered_set<details::OpHandleBase *> visited;
q.push(op); q.push(op);
while (!q.empty()) { while (!q.empty()) {
auto *op = q.front(); auto *op = q.front();
q.pop(); q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op); auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) { if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op; return compute_op;
} }
...@@ -220,13 +230,13 @@ static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( ...@@ -220,13 +230,13 @@ static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede }; enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
static std::unordered_set<ComputationOpHandle *> static std::unordered_set<details::ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
const std::string &var_name, const std::string &var_name,
const ShrinkDepsOpFunctor &shrink_func, const ShrinkDepsOpFunctor &shrink_func,
LastLiveOpSearchStatus *status) { LastLiveOpSearchStatus *status) {
// stage one. Get last op for variable. // stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates; std::unordered_set<details::OpHandleBase *> candidates;
{ {
if (var->PendingOps().empty() && var->GeneratedOp()) { if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op // No operator depends on this variable. So the last operator is the op
...@@ -251,7 +261,7 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, ...@@ -251,7 +261,7 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
// some op handle may operate on many DeviceContext, however, our garbage // some op handle may operate on many DeviceContext, however, our garbage
// collector can only wait one DeviceContext for now. So currently, we wait // collector can only wait one DeviceContext for now. So currently, we wait
// the nearest compute op. // the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op; std::unordered_set<details::ComputationOpHandle *> computation_op;
{ {
for (auto *op : candidates) { for (auto *op : candidates) {
auto *compute_op = auto *compute_op =
...@@ -293,13 +303,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -293,13 +303,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
"Last Live Ops and Reference Counts of vars should be " "Last Live Ops and Reference Counts of vars should be "
"initialized at here."); "initialized at here.");
const auto &vars = graph->Get<GraphVars>(kGraphVars); const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
last_live_ops_of_vars.resize(vars.size()); last_live_ops_of_vars.resize(vars.size());
ref_cnts.resize(vars.size()); ref_cnts.resize(vars.size());
ShrinkDepsOpFunctor shrink_func( ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph)); ir::FilterByNodeWrapper<details::OpHandleBase>(*graph));
VLOG(1) << "Place number: " << vars.size(); VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) { for (size_t i = 0; i < vars.size(); ++i) {
...@@ -360,11 +370,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -360,11 +370,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
} }
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(reference_count_pass, REGISTER_PASS(reference_count_pass, paddle::framework::ir::ReferenceCountPass)
paddle::framework::details::ReferenceCountPass) .RequirePassAttr(paddle::framework::ir::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount) .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars);
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);
...@@ -12,23 +12,24 @@ ...@@ -12,23 +12,24 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) { VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars) {
VarDesc *var_desc = nullptr; VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool { std::find_if(vars.rbegin(), vars.rend(),
[&](details::VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var(); var_desc = var_handle->Node()->Var();
return var_desc != nullptr; return var_desc != nullptr;
}); });
return var_desc; return var_desc;
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -22,17 +22,16 @@ ...@@ -22,17 +22,16 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class VarDesc; class VarDesc;
class VarHandle;
namespace details { namespace ir {
class ComputationOpHandle;
using ReferenceCountMap = std::unordered_map<std::string, size_t>; using ReferenceCountMap = std::unordered_map<std::string, size_t>;
...@@ -48,11 +47,12 @@ const char kGarbageCollector[] = "garbage_collector"; ...@@ -48,11 +47,12 @@ const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places"; const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars = using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>; std::unordered_map<std::string,
std::unordered_set<details::ComputationOpHandle *>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var"; const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars); VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars);
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -19,19 +19,19 @@ ...@@ -19,19 +19,19 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class WhileOpEagerDeletionPass : public ir::Pass { class WhileOpEagerDeletionPass : public ir::Pass {
protected: protected:
void ApplyImpl(ir::Graph *graph) const override { void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph); auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
// Find all while_op and while_grad_op // Find all while_op and while_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>, std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>> std::vector<OperatorBase *>>>
target_ops; target_ops;
for (auto *op : all_ops) { for (auto *op : all_ops) {
auto compute_op = dynamic_cast<ComputationOpHandle *>(op); auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue; if (compute_op == nullptr) continue;
if (compute_op->Name() == "while") { if (compute_op->Name() == "while") {
...@@ -52,9 +52,9 @@ class WhileOpEagerDeletionPass : public ir::Pass { ...@@ -52,9 +52,9 @@ class WhileOpEagerDeletionPass : public ir::Pass {
} }
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(while_op_eager_deletion_pass, REGISTER_PASS(while_op_eager_deletion_pass,
paddle::framework::details::WhileOpEagerDeletionPass); paddle::framework::ir::WhileOpEagerDeletionPass);
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
set(ALL_REDUCE_OP_HANDLES all_reduce_op_handle)
if(WITH_GPU AND WITH_DGC)
list(APPEND ALL_REDUCE_OP_HANDLES sparse_all_reduce_op_handle)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
...@@ -31,17 +30,18 @@ ...@@ -31,17 +30,18 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class AllReduceDepsPass : public ir::Pass { class AllReduceDepsPass : public ir::Pass {
protected: protected:
void ApplyImpl(ir::Graph* graph) const override { void ApplyImpl(ir::Graph* graph) const override {
std::vector<AllReduceOpHandle*> all_reduce_op_handles = std::vector<details::AllReduceOpHandle*> all_reduce_op_handles =
GetSortedAllReduceOps(*graph); GetSortedAllReduceOps(*graph);
for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) { for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar()); auto* dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var); graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
all_reduce_op_handles[i - 1]->AddOutput(dep_var); all_reduce_op_handles[i - 1]->AddOutput(dep_var);
all_reduce_op_handles[i]->AddInput(dep_var); all_reduce_op_handles[i]->AddInput(dep_var);
} }
...@@ -51,16 +51,16 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -51,16 +51,16 @@ class AllReduceDepsPass : public ir::Pass {
} }
} }
std::vector<AllReduceOpHandle*> GetSortedAllReduceOps( std::vector<details::AllReduceOpHandle*> GetSortedAllReduceOps(
const ir::Graph& graph) const { const ir::Graph& graph) const {
std::vector<AllReduceOpHandle*> all_reduce_op_handles; std::vector<details::AllReduceOpHandle*> all_reduce_op_handles;
std::unordered_map<OpHandleBase*, size_t> pending_ops; std::unordered_map<details::OpHandleBase*, size_t> pending_ops;
std::unordered_set<OpHandleBase*> ready_ops; std::unordered_set<details::OpHandleBase*> ready_ops;
std::unordered_set<OpHandleBase*> next_ready_ops; std::unordered_set<details::OpHandleBase*> next_ready_ops;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph); auto op_handles = ir::FilterByNodeWrapper<details::OpHandleBase>(graph);
size_t num_of_ops = op_handles.size(); size_t num_of_ops = op_handles.size();
for (OpHandleBase* op : op_handles) { for (details::OpHandleBase* op : op_handles) {
size_t not_ready_vars = op->NotReadyInputSize(); size_t not_ready_vars = op->NotReadyInputSize();
if (not_ready_vars) { if (not_ready_vars) {
pending_ops.insert({op, not_ready_vars}); pending_ops.insert({op, not_ready_vars});
...@@ -94,11 +94,12 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -94,11 +94,12 @@ class AllReduceDepsPass : public ir::Pass {
} }
void GetSortedAllReduceOps( void GetSortedAllReduceOps(
const std::unordered_set<OpHandleBase*>& ready_ops, const std::unordered_set<details::OpHandleBase*>& ready_ops,
std::vector<AllReduceOpHandle*>* all_reduce_op_handles) const { std::vector<details::AllReduceOpHandle*>* all_reduce_op_handles) const {
std::vector<AllReduceOpHandle*> current_all_reduce_op_handles; std::vector<details::AllReduceOpHandle*> current_all_reduce_op_handles;
for (auto& op_handle : ready_ops) { for (auto& op_handle : ready_ops) {
auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle*>(op_handle); auto all_reduce_op_handle =
dynamic_cast<details::AllReduceOpHandle*>(op_handle);
if (all_reduce_op_handle) { if (all_reduce_op_handle) {
current_all_reduce_op_handles.emplace_back(all_reduce_op_handle); current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
} }
...@@ -109,10 +110,12 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -109,10 +110,12 @@ class AllReduceDepsPass : public ir::Pass {
// Sort the current_all_reduce_op_handles according to the name of input. // Sort the current_all_reduce_op_handles according to the name of input.
sort(current_all_reduce_op_handles.begin(), sort(current_all_reduce_op_handles.begin(),
current_all_reduce_op_handles.end(), current_all_reduce_op_handles.end(),
[](const AllReduceOpHandle* left, [](const details::AllReduceOpHandle* left,
const AllReduceOpHandle* right) -> bool { const details::AllReduceOpHandle* right) -> bool {
auto left_in_vars = DynamicCast<VarHandle>(left->Inputs()); auto left_in_vars =
auto right_in_vars = DynamicCast<VarHandle>(right->Inputs()); details::DynamicCast<details::VarHandle>(left->Inputs());
auto right_in_vars =
details::DynamicCast<details::VarHandle>(right->Inputs());
PADDLE_ENFORCE_GT(left_in_vars.size(), 0); PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size()); PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
return left_in_vars[0]->Name() > right_in_vars[0]->Name(); return left_in_vars[0]->Name() > right_in_vars[0]->Name();
...@@ -123,15 +126,15 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -123,15 +126,15 @@ class AllReduceDepsPass : public ir::Pass {
current_all_reduce_op_handles.end()); current_all_reduce_op_handles.end());
} }
void DebugString( void DebugString(const ir::Graph& graph,
const ir::Graph& graph, const std::vector<details::AllReduceOpHandle*>&
const std::vector<AllReduceOpHandle*>& all_reduce_op_handles) const { all_reduce_op_handles) const {
// get vars order // get vars order
std::map<int, std::vector<std::string>> vars = std::map<int, std::vector<std::string>> vars =
GetSoredGradientsFromStaleProgram(graph); GetSoredGradientsFromStaleProgram(graph);
std::stringstream out; std::stringstream out;
size_t grads_of_stale_program = 0; size_t grads_of_stale_program = 0;
out << "Get Order From kStaleProgramOpDescs: "; out << "Get Order From details::kStaleProgramOpDescs: ";
for (auto& var : vars) { for (auto& var : vars) {
out << "Order " << var.first << " ["; out << "Order " << var.first << " [";
for (auto& var_name : var.second) { for (auto& var_name : var.second) {
...@@ -147,7 +150,7 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -147,7 +150,7 @@ class AllReduceDepsPass : public ir::Pass {
for (auto& op : all_reduce_op_handles) { for (auto& op : all_reduce_op_handles) {
bool find_valid_input = false; bool find_valid_input = false;
for (auto& in_var : op->Inputs()) { for (auto& in_var : op->Inputs()) {
if (dynamic_cast<VarHandle*>(in_var)) { if (dynamic_cast<details::VarHandle*>(in_var)) {
out2 << in_var->Name() << ", "; out2 << in_var->Name() << ", ";
find_valid_input = true; find_valid_input = true;
break; break;
...@@ -165,7 +168,8 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -165,7 +168,8 @@ class AllReduceDepsPass : public ir::Pass {
std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram( std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
const ir::Graph& graph) const { const ir::Graph& graph) const {
std::map<int, std::vector<std::string>> vars; std::map<int, std::vector<std::string>> vars;
auto ops = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs); auto ops =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
int order = 0; int order = 0;
for (auto* op_desc : ops) { for (auto* op_desc : ops) {
try { try {
...@@ -193,10 +197,9 @@ class AllReduceDepsPass : public ir::Pass { ...@@ -193,10 +197,9 @@ class AllReduceDepsPass : public ir::Pass {
return vars; return vars;
} }
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(all_reduce_deps_pass, REGISTER_PASS(all_reduce_deps_pass, paddle::framework::ir::AllReduceDepsPass)
paddle::framework::details::AllReduceDepsPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs); .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
...@@ -24,21 +24,22 @@ ...@@ -24,21 +24,22 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class FuseAllReduceOpPass : public ir::Pass { class FuseAllReduceOpPass : public ir::Pass {
protected: protected:
void ApplyImpl(ir::Graph *graph) const override { void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph; ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces); auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes); auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_ctxs = &Get<platform::NCCLContextMap>(kNCCLCtxs); auto *nccl_ctxs = &Get<platform::NCCLContextMap>(details::kNCCLCtxs);
#endif #endif
std::unordered_set<std::string> grads; std::unordered_set<std::string> grads;
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads); auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
size_t num_of_all_reduce = params_grads.size(); size_t num_of_all_reduce = params_grads.size();
grads.reserve(num_of_all_reduce); grads.reserve(num_of_all_reduce);
for (auto p_g : params_grads) { for (auto p_g : params_grads) {
...@@ -50,11 +51,12 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -50,11 +51,12 @@ class FuseAllReduceOpPass : public ir::Pass {
all_reduce_ops.reserve(grads.size()); all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) { for (auto &node : result.Nodes()) {
if (node->IsOp()) { if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<OpHandleBase>()); PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
auto *all_reduce_op_handle = auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
dynamic_cast<AllReduceOpHandle *>(&node->Wrapper<OpHandleBase>()); &node->Wrapper<details::OpHandleBase>());
if (all_reduce_op_handle) { if (all_reduce_op_handle) {
auto inputs = DynamicCast<VarHandle>(all_reduce_op_handle->Inputs()); auto inputs = details::DynamicCast<details::VarHandle>(
all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place); PADDLE_ENFORCE_EQ(inputs.size(), num_place);
// The inputs' name should be the same. // The inputs' name should be the same.
auto &grad_name = inputs[0]->name(); auto &grad_name = inputs[0]->name();
...@@ -80,7 +82,7 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -80,7 +82,7 @@ class FuseAllReduceOpPass : public ir::Pass {
VLOG(10) << "Insert fused_all_reduce"; VLOG(10) << "Insert fused_all_reduce";
auto &group_grads_params = auto &group_grads_params =
graph->Get<GroupGradsAndParams>(kGroupGradsAndParams); graph->Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
for (auto &group_g_p : group_grads_params) { for (auto &group_g_p : group_grads_params) {
size_t group_size = group_g_p.size(); size_t group_size = group_g_p.size();
...@@ -108,24 +110,25 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -108,24 +110,25 @@ class FuseAllReduceOpPass : public ir::Pass {
const platform::NCCLContextMap *nccl_ctxs, const platform::NCCLContextMap *nccl_ctxs,
#endif #endif
ir::Graph *result) const { ir::Graph *result) const {
std::vector<VarHandleBase *> inputs; std::vector<details::VarHandleBase *> inputs;
std::vector<VarHandleBase *> outputs; std::vector<details::VarHandleBase *> outputs;
for (auto &op : all_reduce_ops) { for (auto &op : all_reduce_ops) {
auto &op_handle = op->Wrapper<OpHandleBase>(); auto &op_handle = op->Wrapper<details::OpHandleBase>();
inputs.insert(inputs.end(), op_handle.Inputs().begin(), inputs.insert(inputs.end(), op_handle.Inputs().begin(),
op_handle.Inputs().end()); op_handle.Inputs().end());
// Remove output // Remove output
for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(), for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
[&op_handle](VarHandleBase *var_handle) { [&op_handle](details::VarHandleBase *var_handle) {
var_handle->RemoveOutput(&op_handle, op_handle.Node()); var_handle->RemoveOutput(&op_handle, op_handle.Node());
}); });
outputs.insert(outputs.end(), op_handle.Outputs().begin(), outputs.insert(outputs.end(), op_handle.Outputs().begin(),
op_handle.Outputs().end()); op_handle.Outputs().end());
// Remove Input // Remove Input
for_each( for_each(op_handle.Outputs().begin(), op_handle.Outputs().end(),
op_handle.Outputs().begin(), op_handle.Outputs().end(), [](details::VarHandleBase *var_handle) {
[](VarHandleBase *var_handle) { var_handle->ClearGeneratedOp(); }); var_handle->ClearGeneratedOp();
});
result->RemoveNode(op_handle.Node()); result->RemoveNode(op_handle.Node());
} }
...@@ -140,8 +143,9 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -140,8 +143,9 @@ class FuseAllReduceOpPass : public ir::Pass {
} }
private: private:
void CreateFusedAllReduceOp(const std::vector<VarHandleBase *> &inputs, void CreateFusedAllReduceOp(
const std::vector<VarHandleBase *> &outputs, const std::vector<details::VarHandleBase *> &inputs,
const std::vector<details::VarHandleBase *> &outputs,
const size_t num_of_all_reduce, const size_t num_of_all_reduce,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
...@@ -150,11 +154,11 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -150,11 +154,11 @@ class FuseAllReduceOpPass : public ir::Pass {
#endif #endif
ir::Graph *result) const { ir::Graph *result) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *op_handle = new FusedAllReduceOpHandle( auto *op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, nccl_ctxs); local_scopes, places, num_of_all_reduce, nccl_ctxs);
#else #else
auto *op_handle = new FusedAllReduceOpHandle( auto *op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce); local_scopes, places, num_of_all_reduce);
#endif #endif
...@@ -176,8 +180,9 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -176,8 +180,9 @@ class FuseAllReduceOpPass : public ir::Pass {
#endif #endif
} }
void SetCommunicationContext(const std::vector<platform::Place> &places, void SetCommunicationContext(
FusedAllReduceOpHandle *op_handle) const { const std::vector<platform::Place> &places,
details::FusedAllReduceOpHandle *op_handle) const {
for (size_t i = 0; i < places.size(); ++i) { for (size_t i = 0; i < places.size(); ++i) {
op_handle->SetDeviceContext( op_handle->SetDeviceContext(
places[i], platform::DeviceContextPool::Instance().Get(places[i])); places[i], platform::DeviceContextPool::Instance().Get(places[i]));
...@@ -185,9 +190,9 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -185,9 +190,9 @@ class FuseAllReduceOpPass : public ir::Pass {
} }
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(fuse_all_reduce_op_pass, REGISTER_PASS(fuse_all_reduce_op_pass,
paddle::framework::details::FuseAllReduceOpPass); paddle::framework::ir::FuseAllReduceOpPass);
...@@ -12,21 +12,20 @@ ...@@ -12,21 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
static bool IsLockAndRecordEventFreeComputationOpHandle( static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) { details::ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false; if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) { for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op); auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) { if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false; return false;
} }
...@@ -34,11 +33,13 @@ static bool IsLockAndRecordEventFreeComputationOpHandle( ...@@ -34,11 +33,13 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
return true; return true;
} }
void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const { class ModifyOpLockAndRecordEventPass : public ir::Pass {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph); protected:
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
OpGraphView graph_view(all_ops); OpGraphView graph_view(all_ops);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op); auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue; if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free = bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view); IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
...@@ -48,11 +49,11 @@ void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const { ...@@ -48,11 +49,11 @@ void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
<< compute_op->DebugString(); << compute_op->DebugString();
} }
} }
} }
};
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass, REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass); paddle::framework::ir::ModifyOpLockAndRecordEventPass);
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class SSAGraghBuilderWithChecker : public ir::Pass { class SSAGraghBuilderWithChecker : public ir::Pass {
protected: protected:
...@@ -28,19 +28,19 @@ class SSAGraghBuilderWithChecker : public ir::Pass { ...@@ -28,19 +28,19 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
} }
bool IsValidGraph(const ir::Graph *graph) const { bool IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<details::OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<details::VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars; std::unordered_set<details::VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops; std::unordered_set<details::OpHandleBase *> ready_ops;
auto insert_pending_var = [&](VarHandleBase *var) { auto insert_pending_var = [&](details::VarHandleBase *var) {
pending_vars.insert(var); pending_vars.insert(var);
if (var->GeneratedOp() == nullptr) { if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var); ready_vars.emplace(var);
} }
}; };
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) { for (auto &var_map : graph->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair); insert_pending_var(version_pair);
...@@ -48,11 +48,12 @@ class SSAGraghBuilderWithChecker : public ir::Pass { ...@@ -48,11 +48,12 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
} }
} }
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) { for (auto &var :
graph->Get<details::GraphDepVars>(details::kGraphDepVars)) {
insert_pending_var(var); insert_pending_var(var);
} }
for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) { for (auto *op : ir::FilterByNodeWrapper<details::OpHandleBase>(*graph)) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op); ready_ops.insert(op);
} else { } else {
...@@ -60,7 +61,7 @@ class SSAGraghBuilderWithChecker : public ir::Pass { ...@@ -60,7 +61,7 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
} }
} }
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { auto run_all_ops = [&](std::unordered_set<details::OpHandleBase *> &set) {
for (auto *op : set) { for (auto *op : set) {
for (auto out : op->Outputs()) { for (auto out : op->Outputs()) {
ready_vars.emplace(out); ready_vars.emplace(out);
...@@ -91,11 +92,11 @@ class SSAGraghBuilderWithChecker : public ir::Pass { ...@@ -91,11 +92,11 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
} }
}; };
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_devices_check_pass, REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker) paddle::framework::ir::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars) .RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars); .RequireGraphAttr(paddle::framework::details::kGraphDepVars);
...@@ -31,7 +31,7 @@ class NCCLContextMap; ...@@ -31,7 +31,7 @@ class NCCLContextMap;
namespace framework { namespace framework {
class Scope; class Scope;
namespace details { namespace ir {
constexpr char kLossVarName[] = "loss_var_name"; constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy"; constexpr char kStrategy[] = "strategy";
...@@ -69,7 +69,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -69,7 +69,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
ir::Node *out_var_node, size_t loss_scale, ir::Node *out_var_node, size_t loss_scale,
proto::VarType::Type dtype) const; proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, details::VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
size_t dst_dev_id) const; size_t dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node, void CreateComputationalOp(ir::Graph *result, ir::Node *node,
...@@ -89,7 +89,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -89,7 +89,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
ir::Graph *result, ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const; const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
void SetCommunicationContext(OpHandleBase *op_handle, void SetCommunicationContext(details::OpHandleBase *op_handle,
const platform::Place &p) const; const platform::Place &p) const;
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
...@@ -103,7 +103,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -103,7 +103,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
mutable std::vector<platform::Place> places_; mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_; mutable std::vector<Scope *> local_scopes_;
mutable BuildStrategy strategy_; mutable details::BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_; mutable std::unordered_map<std::string, VarDesc *> all_vars_;
}; };
...@@ -209,6 +209,6 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder { ...@@ -209,6 +209,6 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
std::unordered_set<std::string> &MultiDevSSAGraphBuilder(); std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -21,11 +21,21 @@ ...@@ -21,11 +21,21 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
class SSAGraghBuilderWithPrinterPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
}
};
template <typename Callback> template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) { static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>(kGraphVars)) { for (auto &each : graph.Get<details::GraphVars>(details::kGraphVars)) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
callback(*pair2); callback(*pair2);
...@@ -33,7 +43,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) { ...@@ -33,7 +43,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
} }
} }
for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) { for (auto &var : graph.Get<details::GraphDepVars>(details::kGraphDepVars)) {
callback(*var); callback(*var);
} }
} }
...@@ -41,14 +51,14 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) { ...@@ -41,14 +51,14 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
std::ostream &sout) const { std::ostream &sout) const {
size_t var_id = 0; size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars; std::unordered_map<const details::VarHandleBase *, size_t> vars;
sout << "digraph G {\n"; sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) { IterAllVar(graph, [&](const details::VarHandleBase &var) {
auto *var_ptr = &var; auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr); auto *var_handle_ptr = dynamic_cast<const details::VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr); auto *dummy_ptr = dynamic_cast<const details::DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++; size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id; vars[var_ptr] = cur_var_id;
...@@ -65,7 +75,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -65,7 +75,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) { for (auto &op : ir::FilterByNodeWrapper<details::OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
...@@ -82,10 +92,10 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -82,10 +92,10 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
sout << "}\n"; sout << "}\n";
} }
} // namespace details } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_devices_print_pass, REGISTER_PASS(multi_devices_print_pass,
paddle::framework::details::SSAGraghBuilderWithPrinter) paddle::framework::ir::SSAGraghBuilderWithPrinterPass)
.RequirePassAttr(paddle::framework::details::kGraphvizPath); .RequirePassAttr(paddle::framework::ir::kGraphvizPath);
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace ir {
constexpr char kGraphvizPath[] = "debug_graphviz_path"; constexpr char kGraphvizPath[] = "debug_graphviz_path";
...@@ -39,16 +39,6 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { ...@@ -39,16 +39,6 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
void Print(const ir::Graph& graph, std::ostream& sout) const override; void Print(const ir::Graph& graph, std::ostream& sout) const override;
}; };
class SSAGraghBuilderWithPrinter : public ir::Pass { } // namespace ir
protected:
void ApplyImpl(ir::Graph* graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
}
};
} // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
class SequentialExecutionPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
// FIXME(zjl): Insert dependencies between some distributed ops may cause
// the multi_devices_graph_pass fails. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually, which may cause deadlock easily.
// We should add more skipped distributed ops when found errors in
// multi_devices_graph_pass
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops =
graph->Get<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s",
op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::ir::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
...@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const { class SyncBatchNormPass : public Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
VLOG(3) << "Use synchronous batch norm"; VLOG(3) << "Use synchronous batch norm";
for (const Node* n : graph->Nodes()) { for (const Node *n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp()) {
auto* op = n->Op(); auto *op = n->Op();
if (op->Type() == "batch_norm") { if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm"); op->SetType("sync_batch_norm");
} }
...@@ -34,8 +36,8 @@ void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const { ...@@ -34,8 +36,8 @@ void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const {
} }
} }
} }
} }
};
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class SyncBatchNormPass : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
......
...@@ -23,11 +23,11 @@ limitations under the License. */ ...@@ -23,11 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#ifdef WITH_GPERFTOOLS #ifdef WITH_GPERFTOOLS
...@@ -110,9 +110,9 @@ class ParallelExecutorPrivate { ...@@ -110,9 +110,9 @@ class ParallelExecutorPrivate {
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// then keeps unchanged // then keeps unchanged
// Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_ // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
std::vector<details::ReferenceCountMap> global_ref_cnts_; std::vector<ir::ReferenceCountMap> global_ref_cnts_;
std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_; std::vector<ir::AtomicReferenceCountMap> runtime_ref_cnts_;
details::GarbageCollectorMap gcs_; ir::GarbageCollectorMap gcs_;
}; };
ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts( ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
...@@ -150,25 +150,23 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts( ...@@ -150,25 +150,23 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
} }
if (!gcs_.empty()) { if (!gcs_.empty()) {
std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars; std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass = auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass"); ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, ref_cnt_pass->SetNotOwned(ir::kGlobalReferenceCount, &global_ref_cnts_);
&global_ref_cnts_); ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(graph); graph = ref_cnt_pass->Apply(graph);
VLOG(10) << "ReferenceCountPass Applied"; VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass = auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass"); ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount, eager_deletion_pass->SetNotOwned(ir::kRuntimeReferenceCount,
&runtime_ref_cnts_); &runtime_ref_cnts_);
eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_); eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars, eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
&last_live_ops_of_vars); &last_live_ops_of_vars);
eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_); eager_deletion_pass->SetNotOwned(ir::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(graph); graph = eager_deletion_pass->Apply(graph);
VLOG(10) << "EagerDeletionPass Applied"; VLOG(10) << "EagerDeletionPass Applied";
} }
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -34,7 +34,7 @@ void BindConstValue(pybind11::module* m) { ...@@ -34,7 +34,7 @@ void BindConstValue(pybind11::module* m) {
m->def("kControlDepVarName", m->def("kControlDepVarName",
[] { return framework::ir::Node::kControlDepVarName; }); [] { return framework::ir::Node::kControlDepVarName; });
m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; }); m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
m->def("kMemOptSkipVars", [] { return framework::details::kMemOptSkipVars; }); m->def("kMemOptSkipVars", [] { return framework::ir::kMemOptSkipVars; });
auto op_proto_and_checker_maker = auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker"); m->def_submodule("op_proto_and_checker_maker");
......
...@@ -21,11 +21,11 @@ limitations under the License. */ ...@@ -21,11 +21,11 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -170,9 +170,9 @@ PYBIND11_MODULE(core, m) { ...@@ -170,9 +170,9 @@ PYBIND11_MODULE(core, m) {
m.def("_set_eager_deletion_mode", &paddle::framework::SetEagerDeletionMode); m.def("_set_eager_deletion_mode", &paddle::framework::SetEagerDeletionMode);
m.def("_set_fuse_parameter_group_size", m.def("_set_fuse_parameter_group_size",
&paddle::framework::details::SetFuseParameterGroupsSize); &paddle::framework::ir::SetFuseParameterGroupsSize);
m.def("_set_fuse_parameter_memory_size", m.def("_set_fuse_parameter_memory_size",
&paddle::framework::details::SetFuseParameterMemorySize); &paddle::framework::ir::SetFuseParameterMemorySize);
m.add_object("_cleanup", m.add_object("_cleanup",
py::capsule([]() { ScopePool::Instance().Clear(); })); py::capsule([]() { ScopePool::Instance().Clear(); }));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册