Unverified commit 04bd413a authored by chengduo, committed by GitHub

Code Clean: Move all passes to paddle::framework::ir (#17228)

* move pass to ir

* polish code
test=develop

* fix dependency
test=develop
Parent 648320bb
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
......@@ -27,7 +17,7 @@ if(WITH_DISTRIBUTE)
endif()
endif()
set(all_reduce_deps all_reduce_op_handle)
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
......@@ -37,7 +27,6 @@ if(WITH_GPU)
if(WITH_DGC)
nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
lod_tensor ddim memory dynload_cuda variable_visitor dgc all_reduce_op_handle)
set(all_reduce_deps sparse_all_reduce_op_handle)
endif()
if(WITH_DISTRIBUTE)
......@@ -68,34 +57,12 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${all_reduce_deps} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0,  // 0 KB
"fuse_parameter_memory_size is the upper bound of the memory "
"size of one group of parameters' gradients, which are the "
"input of a communication call (e.g. NCCLAllReduce). "
"The default value is 0, which means groups are not "
"formed according to memory_size.");
DEFINE_int32(
fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the number of parameters' gradients in "
"one group. If fuse_parameter_groups_size is 1, each gradient forms its "
"own group, i.e. the number of groups equals the number of parameters' "
"gradients. If it is -1, there is only one group. The default value of "
"3 is an experimental result.");
namespace paddle {
namespace framework {
namespace details {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// tests, because setting 'FLAGS_fuse_parameter_memory_size' and
// 'FLAGS_fuse_parameter_groups_size' directly is invalid in unit tests.
void SetFuseParameterGroupsSize(int group_size) {
FLAGS_fuse_parameter_groups_size = group_size;
}
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(uint64_t memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
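// A hedged sketch of a unit test driving these setters instead of writing to
// the FLAGS directly (the test name and values are hypothetical; gtest is
// assumed to be available):
#include "gtest/gtest.h"
TEST(AllocContinuousSpaceForGradPassTest, SetAndGetFuseFlags) {
  SetFuseParameterGroupsSize(2);
  EXPECT_EQ(GetFuseParameterGroupsSize(), 2);
  SetFuseParameterMemorySize(static_cast<uint64_t>(1) << 20);  // 1 MiB
  EXPECT_EQ(GetFuseParameterMemorySize(), static_cast<uint64_t>(1) << 20);
}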
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
void AllocContinuousSpaceForGradPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
ResetAttribute<ParamsAndGrads>(kParamsAndGrads, &result);
ResetAttribute<GroupGradsAndParams>(kGroupGradsAndParams, &result);
// NOTE: The operator nodes should be in topology order.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
for (auto &node : topo_nodes) {
RecordParamsAndGrads(node, &params_grads);
}
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return;
}
std::unordered_map<std::string, ir::Node *> vars;
for (ir::Node *node : result.Nodes()) {
if (node->IsVar() && node->Var()) {
// Note: The graph may contain nodes with the same name. For example, a
// parameter is both an input of an operator and an output of the optimizer.
vars.emplace(node->Var()->Name(), node);
}
}
auto &group_grads_params =
result.Get<GroupGradsAndParams>(kGroupGradsAndParams);
// Note: the order of params_grads may be changed by SetGroupGradsAndParams.
SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
params_grads.clear();
for (auto &group_p_g : group_grads_params) {
params_grads.insert(params_grads.begin(), group_p_g.begin(),
group_p_g.end());
}
for (auto &p_g : params_grads) {
std::swap(p_g.first, p_g.second);
}
// Set gradients as persistable to prevent these vars from being reused.
auto dtype = kDefaultDtype;
for (auto &p_g : params_grads) {
// Get gradient var
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
iter->second->Var()->SetPersistable(true);
PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
// Get Dtype
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create a FusedVars set to avoid duplicated fused_var names across other
// passes.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
}
// kFusedGrads is used by fuse_optimizer_op_pass.
result.Set(kFusedGrads, new FusedGrads);
// The fused_var_name should be unique, so params_grads.begin()->second is
// appended to it.
auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars, fused_var_name,
params_grads);
}
template <typename AttrType>
void AllocContinuousSpaceForGradPass::ResetAttribute(
const std::string &attr_name, ir::Graph *graph) const {
if (graph->Has(attr_name)) {
VLOG(10) << attr_name << " is reset.";
graph->Erase(attr_name);
}
graph->Set(attr_name, new AttrType);
}
void AllocContinuousSpaceForGradPass::SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
void AllocContinuousSpaceForGradPass::SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const {
std::unordered_map<std::string, std::vector<int>> layer_params;
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) {
layer_params[std::string(kUnKnow)].emplace_back(i);
} else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
}
}
group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
std::string key = kUnKnow;
if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos);
}
auto iter = layer_params.find(key);
if (iter == layer_params.end()) continue;
group_grads_params->emplace_back();
auto &local_group_grads_params = group_grads_params->back();
for (auto &idx : iter->second) {
local_group_grads_params.emplace_back(
std::make_pair(params_grads[idx].second, params_grads[idx].first));
}
layer_params.erase(iter);
}
VLOG(10) << "SetGroupAccordingToLayers: ";
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
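// A standalone sketch of the layer-grouping rule implemented above: names are
// bucketed by the prefix before the first '.', and names without a '.' fall
// into the kUnKnow bucket. The parameter names below are illustrative only.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<std::string> params = {"fc_0.w_0", "fc_0.b_0", "fc_1.w_0",
                                     "learning_rate"};
  std::unordered_map<std::string, std::vector<std::string>> layer_params;
  for (auto &p : params) {
    auto pos = p.find_first_of(".");
    std::string key =
        (pos == std::string::npos) ? "@UNKNOW@" : p.substr(0, pos);
    layer_params[key].push_back(p);
  }
  // Prints three buckets: fc_0 -> {fc_0.w_0, fc_0.b_0}, fc_1 -> {fc_1.w_0},
  // and @UNKNOW@ -> {learning_rate} (iteration order is unspecified).
  for (auto &kv : layer_params) {
    std::cout << kv.first << " ->";
    for (auto &p : kv.second) std::cout << " " << p;
    std::cout << "\n";
  }
  return 0;
}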
void AllocContinuousSpaceForGradPass::SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
return;
}
GroupGradsAndParams local_group_grads_params;
size_t j = 0;
while (j < group_grads_params->size()) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
size_t local_group_memory_size = 0;
while (j < group_grads_params->size()) {
std::for_each(
group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
[&local_group_memory_size,
&var_nodes](const std::pair<std::string, std::string> &g_p) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size =
framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
local_group_memory_size += size;
});
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (local_group_memory_size >= group_memory_size) {
break;
}
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToMemorySize(memory_size: %d):",
group_memory_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
}
}
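// A standalone sketch of the per-variable size computation used above:
// bytes = element size * product of the shape dims. The shape and element
// size below are illustrative, not read from a real graph.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> shape = {32, 128};  // e.g. a [32, 128] FP32 gradient
  size_t size = 4;                         // sizeof(float)
  for (int64_t n : shape) size *= n;
  std::cout << size << " bytes\n";         // 32 * 128 * 4 = 16384 bytes
  return 0;
}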
void AllocContinuousSpaceForGradPass::SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const {
if (GetFuseParameterGroupsSize() == 1) {
return;
}
const int group_size = GetFuseParameterGroupsSize() == -1
? static_cast<int>(group_grads_params->size())
: GetFuseParameterGroupsSize();
PADDLE_ENFORCE_GT(group_size, 1);
size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
GroupGradsAndParams local_group_grads_params;
local_group_grads_params.reserve(groups);
size_t j = 0;
for (size_t i = 0; i < groups; ++i) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
group_p_g.reserve(group_size);
while (j < group_grads_params->size()) {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (j % group_size == 0) break;
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
group_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
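// A standalone sketch of the regrouping arithmetic above: with group_size = 3
// (the flag's default) and, say, 8 per-layer groups, ceil division yields
// (8 + 3 - 1) / 3 = 3 merged groups holding 3, 3 and 2 of the old groups.
#include <cstddef>
#include <iostream>

int main() {
  const int group_size = 3;   // FLAGS_fuse_parameter_groups_size default
  const size_t n_groups = 8;  // illustrative number of per-layer groups
  const size_t merged = (n_groups + group_size - 1) / group_size;
  std::cout << merged << " merged groups\n";  // prints: 3 merged groups
  return 0;
}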
bool AllocContinuousSpaceForGradPass::IsSupportedVarType(
const proto::VarType::Type &type) const {
// Currently only LOD_TENSOR is supported.
return type == proto::VarType::LOD_TENSOR;
}
void AllocContinuousSpaceForGradPass::RecordParamsAndGrads(
ir::Node *node, ParamsAndGrads *params_grads) const {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) return;
// Currently, we assume that once a gradient is generated, it can be
// broadcast, and each gradient is broadcast only once.
auto backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));
for (size_t i = 0; i < backward_vars.size(); i += 2) {
VLOG(10) << "Trainable parameter: " << backward_vars[i]
<< ", gradient: " << backward_vars[i + 1];
params_grads->emplace_back(std::make_pair(backward_vars[i] /*param*/,
backward_vars[i + 1] /*grad*/));
}
} catch (const boost::bad_get &e) {
}
}
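// A standalone sketch of how the flattened OpRoleVar attribute is unpacked
// above: the vector interleaves parameter and gradient names as
// [p0, g0, p1, g1, ...]. The names below are illustrative.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> backward_vars = {"fc_0.w_0", "fc_0.w_0@GRAD",
                                            "fc_0.b_0", "fc_0.b_0@GRAD"};
  std::vector<std::pair<std::string, std::string>> params_grads;
  for (size_t i = 0; i < backward_vars.size(); i += 2) {
    params_grads.emplace_back(backward_vars[i] /*param*/,
                              backward_vars[i + 1] /*grad*/);
  }
  for (auto &pg : params_grads)
    std::cout << pg.first << " -> " << pg.second << "\n";
  return 0;
}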
void AllocContinuousSpaceForGradPass::InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const ParamsAndGrads &params_grads) const {
// Init Gradients and FusedVars
VLOG(10) << "Init FusedVars and Gradients.";
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s already exists in scope.", fused_var_name);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
for (auto &p_g : params_grads) {
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
proto::VarType::LOD_TENSOR);
scope->Var(p_g.second)->GetMutable<LoDTensor>();
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
params_name.reserve(params_grads.size());
for (auto &p_g : params_grads) {
params_name.emplace_back(p_g.first);
grads_name.emplace_back(p_g.second);
}
framework::ProgramDesc program_desc;
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void AllocContinuousSpaceForGradPass::AppendAllocSpaceForVarsOp(
const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name, BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
paddle::framework::details::AllocContinuousSpaceForGradPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
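// A minimal usage sketch (assuming a Paddle build), mirroring what
// BuildStrategy::Apply does elsewhere in this change: fetch the registered
// pass by name, supply the two required attributes, then apply it. The
// helper name is hypothetical.
namespace paddle {
namespace framework {
static ir::Graph *ApplyAllocContinuousSpaceForGradPass(
    ir::Graph *graph, const std::vector<platform::Place> &places,
    const std::vector<Scope *> &local_scopes) {
  auto pass = ir::PassRegistry::Instance().Get(
      "alloc_continuous_space_for_grad_pass");
  pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces,
                                                        &places);
  pass->SetNotOwned<const std::vector<Scope *>>(details::kLocalScopes,
                                                &local_scopes);
  return pass->Apply(graph);
}
}  // namespace framework
}  // namespace paddle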
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
template <typename AttrType>
void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const;
void SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
private:
bool IsSupportedVarType(const proto::VarType::Type &type) const;
void RecordParamsAndGrads(ir::Node *node, ParamsAndGrads *params_grads) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const ParamsAndGrads &params_grads) const;
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -17,15 +17,14 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
namespace paddle {
namespace framework {
......@@ -173,10 +172,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
const std::string graph_path =
string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
"_multi_devices_graph");
multi_devices_print_pass->Set<std::string>(kGraphvizPath,
multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
new std::string(graph_path));
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter);
multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
"graph_printer", new ir::GraphvizSSAGraphPrinter);
}
// Experiments show that the program will be faster if we append
......@@ -240,7 +239,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
}
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
......@@ -263,13 +262,13 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLossVarName);
pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
pass->Erase(ir::kLossVarName);
pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks));
pass->Erase(ir::kNRanks);
pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
......@@ -312,8 +311,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
continue;
}
} else if (pass->Type() == "inplace_pass") {
pass->Erase(kUseCuda);
pass->Set<bool>(kUseCuda, new bool(use_cuda));
pass->Erase(ir::kUseCuda);
pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
......
......@@ -31,7 +31,7 @@ namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place,
const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts)
ir::AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names.begin(), var_names.end()),
......
......@@ -20,7 +20,7 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
......@@ -34,7 +34,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
const platform::Place &place,
const std::unordered_set<std::string> &var_names,
GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts);
ir::AtomicReferenceCountMap *ref_cnts);
~EagerDeletionOpHandle();
......@@ -55,8 +55,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
const Scope *scope_;
std::vector<std::string> var_names_;
GarbageCollector *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
GarbageCollector *gc_; // not own
ir::AtomicReferenceCountMap *ref_cnts_; // not own
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr};
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
// FIXME(zjl): Inserting dependencies between some distributed ops may cause
// multi_devices_graph_pass to fail, so we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually, since that may easily cause deadlock.
// We should add more skipped distributed ops when errors are found in
// multi_devices_graph_pass.
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops = graph->Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s", op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class SequentialExecutionPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -19,7 +19,7 @@
#include <unordered_map>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
......
......@@ -18,7 +18,7 @@
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -33,7 +33,7 @@ namespace framework {
std::unique_ptr<ir::Pass> CreateInplacePass() {
auto pass = ir::PassRegistry::Instance().Get("inplace_pass");
pass->Set<bool>(details::kUseCuda, new bool(true));
pass->Set<bool>(ir::kUseCuda, new bool(true));
return pass;
}
......@@ -225,7 +225,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
FakeSuccData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
......@@ -241,7 +241,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
FakeNoInplaceData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
......@@ -274,7 +274,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
......@@ -310,7 +310,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
......
......@@ -3,6 +3,9 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList
file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
# Usage: pass_library(target inference) will append to paddle_inference_pass.h
unset(INFER_IR_PASSES CACHE) # clear the global variable
......@@ -34,7 +37,6 @@ function(pass_library TARGET DEST)
endif()
endfunction()
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
......@@ -43,6 +45,8 @@ cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0,  // 0 KB
"fuse_parameter_memory_size is the upper bound of the memory "
"size of one group of parameters' gradients, which are the "
"input of a communication call (e.g. NCCLAllReduce). "
"The default value is 0, which means groups are not "
"formed according to memory_size.");
DEFINE_int32(
fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the number of parameters' gradients in "
"one group. If fuse_parameter_groups_size is 1, each gradient forms its "
"own group, i.e. the number of groups equals the number of parameters' "
"gradients. If it is -1, there is only one group. The default value of "
"3 is an experimental result.");
namespace paddle {
namespace framework {
namespace ir {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// tests, because setting 'FLAGS_fuse_parameter_memory_size' and
// 'FLAGS_fuse_parameter_groups_size' directly is invalid in unit tests.
void SetFuseParameterGroupsSize(int group_size) {
FLAGS_fuse_parameter_groups_size = group_size;
}
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(uint64_t memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
ResetAttribute<details::ParamsAndGrads>(details::kParamsAndGrads, &result);
ResetAttribute<details::GroupGradsAndParams>(details::kGroupGradsAndParams,
&result);
// NOTE: The operator nodes should be in topology order.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
for (auto &node : topo_nodes) {
RecordParamsAndGrads(node, &params_grads);
}
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return;
}
std::unordered_map<std::string, ir::Node *> vars;
for (ir::Node *node : result.Nodes()) {
if (node->IsVar() && node->Var()) {
// Note: The graph may contain nodes with the same name. For example, a
// parameter is both an input of an operator and an output of the optimizer.
vars.emplace(node->Var()->Name(), node);
}
}
auto &group_grads_params =
result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
// Note: the order of params_grads may be changed by SetGroupGradsAndParams.
SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
params_grads.clear();
for (auto &group_p_g : group_grads_params) {
params_grads.insert(params_grads.begin(), group_p_g.begin(),
group_p_g.end());
}
for (auto &p_g : params_grads) {
std::swap(p_g.first, p_g.second);
}
// Set gradients as persistable to prevent these vars from being reused.
auto dtype = kDefaultDtype;
for (auto &p_g : params_grads) {
// Get gradient var
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
iter->second->Var()->SetPersistable(true);
PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
// Get Dtype
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create a FusedVars set to avoid duplicated fused_var names across other
// passes.
if (!result.Has(details::kFusedVars)) {
result.Set(details::kFusedVars, new details::FusedVars);
}
// kFusedGrads is used by fuse_optimizer_op_pass.
result.Set(details::kFusedGrads, new details::FusedGrads);
// The fused_var_name should be unique, so params_grads.begin()->second is
// appended to it.
auto fused_var_name = std::string(details::kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<details::FusedGrads>(details::kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
fused_var_name, params_grads);
}
template <typename AttrType>
void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const {
if (graph->Has(attr_name)) {
VLOG(10) << attr_name << " is reset.";
graph->Erase(attr_name);
}
graph->Set(attr_name, new AttrType);
}
void SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const details::ParamsAndGrads &params_grads,
details::GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
void SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const details::ParamsAndGrads &params_grads,
details::GroupGradsAndParams *group_grads_params) const {
std::unordered_map<std::string, std::vector<int>> layer_params;
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) {
layer_params[std::string(kUnKnow)].emplace_back(i);
} else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
}
}
group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
std::string key = kUnKnow;
if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos);
}
auto iter = layer_params.find(key);
if (iter == layer_params.end()) continue;
group_grads_params->emplace_back();
auto &local_group_grads_params = group_grads_params->back();
for (auto &idx : iter->second) {
local_group_grads_params.emplace_back(
std::make_pair(params_grads[idx].second, params_grads[idx].first));
}
layer_params.erase(iter);
}
VLOG(10) << "SetGroupAccordingToLayers: ";
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
return;
}
details::GroupGradsAndParams local_group_grads_params;
size_t j = 0;
while (j < group_grads_params->size()) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
size_t local_group_memory_size = 0;
while (j < group_grads_params->size()) {
std::for_each(
group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
[&local_group_memory_size,
&var_nodes](const std::pair<std::string, std::string> &g_p) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size =
framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
local_group_memory_size += size;
});
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (local_group_memory_size >= group_memory_size) {
break;
}
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
}
}
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
if (GetFuseParameterGroupsSize() == 1) {
return;
}
const int group_size = GetFuseParameterGroupsSize() == -1
? static_cast<int>(group_grads_params->size())
: GetFuseParameterGroupsSize();
PADDLE_ENFORCE_GT(group_size, 1);
size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
details::GroupGradsAndParams local_group_grads_params;
local_group_grads_params.reserve(groups);
size_t j = 0;
for (size_t i = 0; i < groups; ++i) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
group_p_g.reserve(group_size);
while (j < group_grads_params->size()) {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (j % group_size == 0) break;
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
group_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
private:
bool IsSupportedVarType(const proto::VarType::Type &type) const {
// Currently only LOD_TENSOR is supported.
return type == proto::VarType::LOD_TENSOR;
}
void RecordParamsAndGrads(ir::Node *node,
details::ParamsAndGrads *params_grads) const {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) return;
// Currently, we assume that once a gradient is generated, it can be
// broadcast, and each gradient is broadcast only once.
auto backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));
for (size_t i = 0; i < backward_vars.size(); i += 2) {
VLOG(10) << "Trainable parameter: " << backward_vars[i]
<< ", gradient: " << backward_vars[i + 1];
params_grads->emplace_back(std::make_pair(
backward_vars[i] /*param*/, backward_vars[i + 1] /*grad*/));
}
} catch (const boost::bad_get &e) {
}
}
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const details::ParamsAndGrads &params_grads) const {
// Init Gradients and FusedVars
VLOG(10) << "Init FusedVars and Gradients.";
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s already exists in scope.", fused_var_name);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
for (auto &p_g : params_grads) {
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
proto::VarType::LOD_TENSOR);
scope->Var(p_g.second)->GetMutable<LoDTensor>();
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
params_name.reserve(params_grads.size());
for (auto &p_g : params_grads) {
params_name.emplace_back(p_g.first);
grads_name.emplace_back(p_g.second);
}
framework::ProgramDesc program_desc;
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
paddle::framework::ir::AllocContinuousSpaceForGradPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -11,21 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include <algorithm>
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
class ReferenceCountPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
cc_library(fuse_optimizer_op_pass SRCS fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc DEPS fuse_optimizer_op_pass)
......@@ -16,16 +16,13 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseAdamOpPass : public FuseOptimizerOpPass {
private:
......@@ -203,10 +200,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
}
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::details::FuseAdamOpPass)
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -16,14 +16,13 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseMomentumOpPass : public FuseOptimizerOpPass {
private:
......@@ -84,11 +83,10 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_momentum_op_pass,
paddle::framework::details::FuseMomentumOpPass)
REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -20,13 +20,13 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
const std::string fuse_op_type = GetOpType();
std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
......@@ -47,24 +47,24 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
return;
}
if (result.Has(kFusedOptType)) {
if (result.Has(details::kFusedOptType)) {
VLOG(6) << "Currently only support fusing one type optimizer op. Has fused "
<< result.Get<FusedOptType>(kFusedOptType);
<< result.Get<details::FusedOptType>(details::kFusedOptType);
return;
} else {
result.Set(kFusedOptType, new FusedOptType);
result.Set(details::kFusedOptType, new details::FusedOptType);
}
result.Get<FusedOptType>(kFusedOptType) = fuse_op_type;
result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type;
// Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be
// initialized in scopes before execution.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
if (!result.Has(details::kFusedVars)) {
result.Set(details::kFusedVars, new details::FusedVars);
}
std::unordered_map<std::string, std::string> fused_vars_name;
fused_vars_name.reserve(aux_var_names.size());
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
const std::string prefix(kFusedVarNamePrefix);
auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
const std::string prefix(details::kFusedVarNamePrefix);
// NOTE: the fused_var_name should be unique.
for (auto &var_name : aux_var_names) {
auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
......@@ -77,8 +77,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// Step 3: Get the fused Gradient's name
bool grad_fused = false;
if (result.Has(kParamsAndGrads)) {
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
if (result.Has(details::kParamsAndGrads)) {
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
PADDLE_ENFORCE_EQ(
params_grads.size(), aux_var_set.at(kGrad).size(),
"The number of gradients and optimizer ops is not equal.");
......@@ -94,13 +95,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// NOTE(zcd): the gradients in kParamsAndGrads may be different from
// kGrad.
if (same_grad_num == aux_var_set.at(kGrad).size()) {
if (!result.Has(kFusedGrads)) {
if (!result.Has(details::kFusedGrads)) {
PADDLE_THROW(
"The alloc_continuous_space_for_grad_pass should be called before "
"this pass.");
}
auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
auto &fused_vars = result.Get<FusedVars>(kFusedVars);
auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
PADDLE_ENFORCE(iter != fused_vars.end(), "Cannot find the fused_grad.");
fused_vars_name[kGrad] = fused_grad;
......@@ -323,6 +324,6 @@ void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
outputs.end());
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -25,7 +25,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
constexpr char kGrad[] = "Grad";
constexpr char kParam[] = "Param";
......@@ -90,6 +90,6 @@ class FuseOptimizerOpPass : public ir::Pass {
const std::string &fused_var_name) const;
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -14,18 +14,13 @@
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
......@@ -66,10 +61,10 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
InserInputAndOutputForOptOps(sgd_ops, sgd_node);
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::details::FuseSgdOpPass)
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
......@@ -27,11 +27,11 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
// op -> variables which can be deleted after op runs
using OpToVarNameSetMap =
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
using OpToVarNameSetMap = std::unordered_map<details::ComputationOpHandle *,
std::unordered_set<std::string>>;
static std::map<size_t, std::unordered_set<std::string>> VarsGroupByScopeIdx(
const OpToVarNameSetMap &map) {
......@@ -53,7 +53,8 @@ static bool IsLoDTensor(VarDesc *var) {
// Get memory size of LoDTensor
static int64_t GetMemorySize(
const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
const std::unordered_map<std::string, std::vector<details::VarHandle *>>
&vars,
const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc);
......@@ -69,13 +70,13 @@ static int64_t GetMemorySize(
// Since partial GC is based on static analysis of memory size of each variable
// So we should skip SelectedRows and LoDTensorArray here
static void SplitIntoLoDTensorAndNonLoDTensorVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const OpToVarNameSetMap &m, const details::GraphVars &vars,
OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
lod_tensors->clear();
other_vars->clear();
for (auto &op_vars_pair : m) {
for (auto &var_name : op_vars_pair.second) {
for (auto var_name : op_vars_pair.second) {
auto *var_desc = TryGetLatestVarDesc(
vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
if (IsLoDTensor(var_desc)) {
......@@ -89,23 +90,24 @@ static void SplitIntoLoDTensorAndNonLoDTensorVars(
struct GCVarInfo {
GCVarInfo(const std::string &name, int64_t memory_size,
ComputationOpHandle *op, size_t scope_idx)
details::ComputationOpHandle *op, size_t scope_idx)
: name_(name),
memory_size_(memory_size),
op_(op),
scope_idx_(scope_idx) {}
std::string name_; // variable name
int64_t memory_size_; // memory size
ComputationOpHandle *op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates
std::string name_; // variable name
int64_t memory_size_; // memory size
details::ComputationOpHandle
*op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates
int64_t AbsMemorySize() const { return std::abs(memory_size_); }
};
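// Sketch (illustrative, not part of this commit): one way GCVarInfo's
// AbsMemorySize is meant to be used, ranking GC candidates by absolute memory
// size so that a fraction-based policy such as ShrinkGCVars below can keep
// only the largest ones. This helper is a hypothetical illustration, not the
// pass's exact policy.
static std::vector<GCVarInfo> RankBySizeDescending(std::vector<GCVarInfo> infos) {
  std::sort(infos.begin(), infos.end(),
            [](const GCVarInfo &lhs, const GCVarInfo &rhs) {
              return lhs.AbsMemorySize() > rhs.AbsMemorySize();
            });
  return infos;
}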
// NOTE: delete_lod_tensor_only is not used currently
static OpToVarNameSetMap ShrinkGCVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const OpToVarNameSetMap &m, const details::GraphVars &vars,
const std::vector<platform::Place> &places, double fraction_of_memory_size,
bool delete_lod_tensor_only = false) {
// Do not perform gc when fraction_of_memory_size = 0
......@@ -192,7 +194,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE(ref_cnts.empty(),
"kRuntimeReferenceCount should be initialized here!");
const auto &vars = graph->Get<GraphVars>(kGraphVars);
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
ref_cnts.resize(vars.size());
const auto &last_live_ops =
......@@ -222,27 +224,31 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle(
auto *eager_deletion_op = new details::EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
gcs.at(places[op->GetScopeIdx()]).get(),
&(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
op->Outputs().begin(), op->Outputs().end(),
[](details::VarHandleBase *var) {
return dynamic_cast<details::DummyVarHandle *>(var) != nullptr;
});
if (it != op->Outputs().end()) {
eager_deletion_op->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
auto *dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
op->AddOutput(dep_var);
eager_deletion_op->AddInput(dep_var);
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
auto *dummy_leaf =
new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf);
}
......@@ -262,15 +268,14 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
while_op_eager_deletion_pass->Apply(graph);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(eager_deletion_pass,
paddle::framework::details::EagerDeletionPass)
.RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
.RequirePassAttr(paddle::framework::ir::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kAllPlaces)
.RequirePassAttr(paddle::framework::ir::kGarbageCollector);
USE_PASS(while_op_eager_deletion_pass);
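// Sketch (illustrative, not part of this commit): the control-dependency
// idiom used by the pass above, isolated. A DummyVarHandle carries no data;
// it exists only to order two op handles. All calls mirror lines in the diff.
static void AddControlDep(ir::Graph *graph, details::OpHandleBase *before,
                          details::OpHandleBase *after) {
  auto *dep = new details::DummyVarHandle(graph->CreateControlDepVar());
  graph->Get<details::GraphDepVars>(details::kGraphDepVars).emplace(dep);
  before->AddOutput(dep);  // `before` must finish before `dep` is ready
  after->AddInput(dep);    // `after` waits on `dep`
}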
......@@ -16,9 +16,9 @@
#include <queue>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -52,7 +52,7 @@ DECLARE_string(memory_optimize_debug);
namespace paddle {
namespace framework {
namespace details {
namespace ir {
// clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT
......@@ -199,8 +199,8 @@ bool InplacePass::CheckOpDeps(ir::Node *op,
void InplacePass::CollectSkipVars(ir::Graph *graph,
const std::vector<ir::Node *> &ops) const {
// 1. Collect op role vars
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars),
"Graph should have attr %s", details::kMemOptSkipVars);
PADDLE_ENFORCE(graph->Has(kMemOptSkipVars), "Graph should have attr %s",
kMemOptSkipVars);
auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto &var : mem_opt_whitelist) {
skip_vars_.emplace(var);
......@@ -452,8 +452,7 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue;
}
if (details::NodeSize(*in_node->Var()) !=
details::NodeSize(*out_node->Var()) &&
if (NodeSize(*in_node->Var()) != NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same size with "
......@@ -476,9 +475,9 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass)
.RequirePassAttr(paddle::framework::details::kUseCuda);
REGISTER_PASS(inplace_pass, paddle::framework::ir::InplacePass)
.RequirePassAttr(paddle::framework::ir::kUseCuda);
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm>
#include <deque>
#include <functional>
......@@ -32,14 +32,15 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
using paddle::framework::VarDesc;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kStaleProgramOpDescs),
PADDLE_ENFORCE(graph.Has(details::kStaleProgramOpDescs),
"Graph has no attribute of kStaleProgramOpDescs.");
// 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
auto& op_descs =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
// 2. topology sort order
auto nodes = graph.Nodes();
......@@ -563,6 +564,6 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
return found_node;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
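// Sketch (illustrative, not part of this commit): typical use of
// SortOpLikeDescOrder, walking op nodes in original program order rather than
// graph-topology order. Assumes the graph still carries
// details::kStaleProgramOpDescs, as the PADDLE_ENFORCE above requires.
static void VisitOpsInProgramOrder(const ir::Graph &graph) {
  for (ir::Node *op : SortOpLikeDescOrder(graph)) {
    VLOG(4) << "op in stale-program order: " << op->Name();
  }
}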
......@@ -29,7 +29,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
/// this attribute is used to avoid some core variables removed/reused
/// in memory optimize related passes
......@@ -184,6 +184,6 @@ void FilterVariables(const Container& nodes, Callback callback) {
FilterVariableImpl<Container, Callback>()(nodes, callback);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -11,8 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm>
#include <iostream>
#include <iterator>
......@@ -32,7 +31,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
TEST(OrderedSet, Normal) {
OrderedSet pool;
......@@ -153,7 +152,7 @@ TEST(OrderedSet, FindBestFitNode) {
ASSERT_TRUE(cache == nullptr);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -188,7 +187,7 @@ REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
namespace paddle {
namespace framework {
namespace details {
namespace ir {
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
......@@ -521,6 +520,6 @@ TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include <algorithm>
#include <atomic>
#include <deque>
......@@ -42,12 +42,12 @@ DEFINE_string(memory_optimize_debug, "",
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
CollectSkipVarsSet(graph);
cfg_.reset(new details::ControlFlowGraph(*graph));
cfg_.reset(new ControlFlowGraph(*graph));
cfg_->LiveVariableAnalysis();
InitSSAGraphNodes();
......@@ -205,7 +205,7 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
// fill skip_set_
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars));
PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto& var : mem_opt_whitelist) {
skip_set_.emplace(var);
......@@ -316,10 +316,9 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(memory_optimize_pass,
paddle::framework::details::MemoryOptimizePass)
REGISTER_PASS(memory_optimize_pass, paddle::framework::ir::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -26,13 +26,13 @@
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class MemoryOptimizePass : public ir::Pass {
protected:
......@@ -67,6 +67,6 @@ class MemoryOptimizePass : public ir::Pass {
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
namespace ir {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
OpGraphView::OpGraphView(const std::vector<details::OpHandleBase *> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) {
preceding_ops_.clear();
pending_ops_.clear();
for (auto &op : ops) {
......@@ -40,8 +42,8 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
"There are duplicate ops in graph.");
}
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<details::OpHandleBase *> ret;
ret.reserve(preceding_ops_.size());
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
......@@ -49,21 +51,21 @@ std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
bool OpGraphView::HasOp(details::OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps(
details::OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -22,39 +22,42 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
explicit OpGraphView(const std::vector<details::OpHandleBase *> &ops);
std::unordered_set<OpHandleBase *> AllOps() const;
std::unordered_set<details::OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
const std::unordered_set<details::OpHandleBase *> &PendingOps(
details::OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
bool HasOp(details::OpHandleBase *op) const;
// Use a visitor to visit all pending ops of op
// Stop when callback returns false
template <typename Callback>
bool VisitAllPendingOps(OpHandleBase *op, Callback &&callback) const;
bool VisitAllPendingOps(details::OpHandleBase *op, Callback &&callback) const;
private:
void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const;
void Build(const std::vector<details::OpHandleBase *> &ops);
void EnforceHasOp(details::OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
std::unordered_map<details::OpHandleBase *,
std::unordered_set<details::OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
std::unordered_map<details::OpHandleBase *,
std::unordered_set<details::OpHandleBase *>>
pending_ops_;
};
template <typename Callback>
bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
bool OpGraphView::VisitAllPendingOps(details::OpHandleBase *op,
Callback &&callback) const {
EnforceHasOp(op);
std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q;
std::unordered_set<details::OpHandleBase *> visited;
std::queue<details::OpHandleBase *> q;
q.push(op);
while (!q.empty()) {
op = q.front();
......@@ -72,6 +75,6 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
return true;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
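// Sketch (illustrative, not part of this commit): minimal OpGraphView usage,
// counting every op handle reachable downstream of `root`. `all_ops` would
// come from ir::FilterByNodeWrapper<details::OpHandleBase>(graph) as in the
// passes of this commit; returning true from the callback keeps the BFS going.
static size_t CountDownstreamOps(
    const std::vector<details::OpHandleBase *> &all_ops,
    details::OpHandleBase *root) {
  OpGraphView view(all_ops);
  size_t n = 0;
  view.VisitAllPendingOps(root, [&n](details::OpHandleBase *) {
    ++n;
    return true;  // continue traversal
  });
  return n;
}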
......@@ -15,16 +15,16 @@
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class RecordSkipMemoryOptVarsPass : public ir::Pass {
protected:
......@@ -162,9 +162,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(record_skip_memory_opt_vars_pass,
paddle::framework::details::RecordSkipMemoryOptVarsPass);
paddle::framework::ir::RecordSkipMemoryOptVarsPass);
......@@ -24,14 +24,20 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class ReferenceCountPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
};
// A functor to shrink/remove operators that depend on other operators in a set
class ShrinkDepsOpFunctor {
......@@ -39,19 +45,21 @@ class ShrinkDepsOpFunctor {
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
public:
explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops)
explicit ShrinkDepsOpFunctor(
const std::vector<details::OpHandleBase *> &all_ops)
: graph_(all_ops) {}
template <typename OpSet>
OpSet operator()(const OpSet &op_set) const {
using KeyType = typename OpSet::key_type;
static_assert(
std::is_base_of<OpHandleBase,
std::is_base_of<details::OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
"Key type of OpSet must be details::OpHandleBase, or derived of "
"details::OpHandleBase");
if (op_set.size() <= 1) return op_set;
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
std::vector<details::OpHandleBase *> ops(op_set.begin(), op_set.end());
OpSet ret;
auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; };
......@@ -65,8 +73,8 @@ class ShrinkDepsOpFunctor {
private:
std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> &ops) const {
std::unordered_map<OpHandleBase *, size_t> op_to_idx;
const std::vector<details::OpHandleBase *> &ops) const {
std::unordered_map<details::OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
op_to_idx[ops[i]] = i;
......@@ -81,7 +89,7 @@ class ShrinkDepsOpFunctor {
size_t found_num = ops.size();
size_t total_num = ops.size() * ops.size();
auto visitor = [&](OpHandleBase *op, size_t i) {
auto visitor = [&](details::OpHandleBase *op, size_t i) {
auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) {
size_t j = it->second;
......@@ -98,7 +106,9 @@ class ShrinkDepsOpFunctor {
};
for (size_t i = 0; i < ops.size(); ++i) {
auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); };
auto sub_visitor = [&, i](details::OpHandleBase *op) {
return visitor(op, i);
};
if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
break;
}
......@@ -133,8 +143,8 @@ class ShrinkDepsOpFunctor {
*/
static bool ShrinkNoNeedBufferVarOpDependency(
const std::string &var_name,
std::unordered_set<ComputationOpHandle *> *op_handles) {
std::vector<ComputationOpHandle *> skip_ops;
std::unordered_set<details::ComputationOpHandle *> *op_handles) {
std::vector<details::ComputationOpHandle *> skip_ops;
for (auto *op_handle : *op_handles) {
auto *op_base = op_handle->GetOp();
auto &inferer = op_base->Info().NoNeedBufferVarsInferer();
......@@ -195,15 +205,15 @@ static bool ShrinkNoNeedBufferVarOpDependency(
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
static details::ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
details::OpHandleBase *op, size_t scope_idx) {
std::queue<details::OpHandleBase *> q;
std::unordered_set<details::OpHandleBase *> visited;
q.push(op);
while (!q.empty()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
......@@ -220,13 +230,13 @@ static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
static std::unordered_set<details::ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
const std::string &var_name,
const ShrinkDepsOpFunctor &shrink_func,
LastLiveOpSearchStatus *status) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
std::unordered_set<details::OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op
......@@ -251,7 +261,7 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
// some op handles may operate on many DeviceContexts; however, our garbage
// collector can only wait on one DeviceContext for now. So currently, we
// wait on the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op;
std::unordered_set<details::ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
......@@ -293,13 +303,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
"Last Live Ops and Reference Counts of vars should be "
"initialized at here.");
const auto &vars = graph->Get<GraphVars>(kGraphVars);
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
last_live_ops_of_vars.resize(vars.size());
ref_cnts.resize(vars.size());
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph));
ir::FilterByNodeWrapper<details::OpHandleBase>(*graph));
VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) {
......@@ -360,11 +370,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);
REGISTER_PASS(reference_count_pass, paddle::framework::ir::ReferenceCountPass)
.RequirePassAttr(paddle::framework::ir::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars);
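// Sketch (illustrative, not part of this commit): what ShrinkDepsOpFunctor
// contributes to the pass above. Ops that run strictly before another op in
// the same set are redundant as "last live ops", because waiting on the later
// op already implies the earlier one finished. The wrapper function is
// hypothetical; the functor usage mirrors the pass.
static std::unordered_set<details::OpHandleBase *> ShrinkLastLiveOps(
    const std::vector<details::OpHandleBase *> &all_ops,
    const std::unordered_set<details::OpHandleBase *> &candidates) {
  ShrinkDepsOpFunctor shrink(all_ops);
  return shrink(candidates);  // drops ops that precede others in the set
}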
......@@ -12,23 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
std::find_if(vars.rbegin(), vars.rend(),
[&](details::VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -22,17 +22,16 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
class VarDesc;
class VarHandle;
namespace details {
class ComputationOpHandle;
namespace ir {
using ReferenceCountMap = std::unordered_map<std::string, size_t>;
......@@ -48,11 +47,12 @@ const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
std::unordered_map<std::string,
std::unordered_set<details::ComputationOpHandle *>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars);
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -19,19 +19,19 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
// Find all while_op and while_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>>
target_ops;
for (auto *op : all_ops) {
auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "while") {
......@@ -52,9 +52,9 @@ class WhileOpEagerDeletionPass : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(while_op_eager_deletion_pass,
paddle::framework::details::WhileOpEagerDeletionPass);
paddle::framework::ir::WhileOpEagerDeletionPass);
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
set(ALL_REDUCE_OP_HANDLES all_reduce_op_handle)
if(WITH_GPU AND WITH_DGC)
list(APPEND ALL_REDUCE_OP_HANDLES sparse_all_reduce_op_handle)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......@@ -23,7 +23,6 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -31,17 +30,18 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class AllReduceDepsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
std::vector<AllReduceOpHandle*> all_reduce_op_handles =
std::vector<details::AllReduceOpHandle*> all_reduce_op_handles =
GetSortedAllReduceOps(*graph);
for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
auto* dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
all_reduce_op_handles[i - 1]->AddOutput(dep_var);
all_reduce_op_handles[i]->AddInput(dep_var);
}
......@@ -51,16 +51,16 @@ class AllReduceDepsPass : public ir::Pass {
}
}
std::vector<AllReduceOpHandle*> GetSortedAllReduceOps(
std::vector<details::AllReduceOpHandle*> GetSortedAllReduceOps(
const ir::Graph& graph) const {
std::vector<AllReduceOpHandle*> all_reduce_op_handles;
std::unordered_map<OpHandleBase*, size_t> pending_ops;
std::unordered_set<OpHandleBase*> ready_ops;
std::unordered_set<OpHandleBase*> next_ready_ops;
std::vector<details::AllReduceOpHandle*> all_reduce_op_handles;
std::unordered_map<details::OpHandleBase*, size_t> pending_ops;
std::unordered_set<details::OpHandleBase*> ready_ops;
std::unordered_set<details::OpHandleBase*> next_ready_ops;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph);
auto op_handles = ir::FilterByNodeWrapper<details::OpHandleBase>(graph);
size_t num_of_ops = op_handles.size();
for (OpHandleBase* op : op_handles) {
for (details::OpHandleBase* op : op_handles) {
size_t not_ready_vars = op->NotReadyInputSize();
if (not_ready_vars) {
pending_ops.insert({op, not_ready_vars});
......@@ -94,11 +94,12 @@ class AllReduceDepsPass : public ir::Pass {
}
void GetSortedAllReduceOps(
const std::unordered_set<OpHandleBase*>& ready_ops,
std::vector<AllReduceOpHandle*>* all_reduce_op_handles) const {
std::vector<AllReduceOpHandle*> current_all_reduce_op_handles;
const std::unordered_set<details::OpHandleBase*>& ready_ops,
std::vector<details::AllReduceOpHandle*>* all_reduce_op_handles) const {
std::vector<details::AllReduceOpHandle*> current_all_reduce_op_handles;
for (auto& op_handle : ready_ops) {
auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle*>(op_handle);
auto all_reduce_op_handle =
dynamic_cast<details::AllReduceOpHandle*>(op_handle);
if (all_reduce_op_handle) {
current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
}
......@@ -109,10 +110,12 @@ class AllReduceDepsPass : public ir::Pass {
// Sort the current_all_reduce_op_handles according to the name of their first input.
sort(current_all_reduce_op_handles.begin(),
current_all_reduce_op_handles.end(),
[](const AllReduceOpHandle* left,
const AllReduceOpHandle* right) -> bool {
auto left_in_vars = DynamicCast<VarHandle>(left->Inputs());
auto right_in_vars = DynamicCast<VarHandle>(right->Inputs());
[](const details::AllReduceOpHandle* left,
const details::AllReduceOpHandle* right) -> bool {
auto left_in_vars =
details::DynamicCast<details::VarHandle>(left->Inputs());
auto right_in_vars =
details::DynamicCast<details::VarHandle>(right->Inputs());
PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
return left_in_vars[0]->Name() > right_in_vars[0]->Name();
......@@ -123,15 +126,15 @@ class AllReduceDepsPass : public ir::Pass {
current_all_reduce_op_handles.end());
}
void DebugString(
const ir::Graph& graph,
const std::vector<AllReduceOpHandle*>& all_reduce_op_handles) const {
void DebugString(const ir::Graph& graph,
const std::vector<details::AllReduceOpHandle*>&
all_reduce_op_handles) const {
// get vars order
std::map<int, std::vector<std::string>> vars =
GetSoredGradientsFromStaleProgram(graph);
std::stringstream out;
size_t grads_of_stale_program = 0;
out << "Get Order From kStaleProgramOpDescs: ";
out << "Get Order From details::kStaleProgramOpDescs: ";
for (auto& var : vars) {
out << "Order " << var.first << " [";
for (auto& var_name : var.second) {
......@@ -147,7 +150,7 @@ class AllReduceDepsPass : public ir::Pass {
for (auto& op : all_reduce_op_handles) {
bool find_valid_input = false;
for (auto& in_var : op->Inputs()) {
if (dynamic_cast<VarHandle*>(in_var)) {
if (dynamic_cast<details::VarHandle*>(in_var)) {
out2 << in_var->Name() << ", ";
find_valid_input = true;
break;
......@@ -165,7 +168,8 @@ class AllReduceDepsPass : public ir::Pass {
std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
const ir::Graph& graph) const {
std::map<int, std::vector<std::string>> vars;
auto ops = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
auto ops =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
int order = 0;
for (auto* op_desc : ops) {
try {
......@@ -193,10 +197,9 @@ class AllReduceDepsPass : public ir::Pass {
return vars;
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(all_reduce_deps_pass,
paddle::framework::details::AllReduceDepsPass)
REGISTER_PASS(all_reduce_deps_pass, paddle::framework::ir::AllReduceDepsPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -24,21 +24,22 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseAllReduceOpPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_ctxs = &Get<platform::NCCLContextMap>(kNCCLCtxs);
auto *nccl_ctxs = &Get<platform::NCCLContextMap>(details::kNCCLCtxs);
#endif
std::unordered_set<std::string> grads;
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
size_t num_of_all_reduce = params_grads.size();
grads.reserve(num_of_all_reduce);
for (auto p_g : params_grads) {
......@@ -50,11 +51,12 @@ class FuseAllReduceOpPass : public ir::Pass {
all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) {
if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<OpHandleBase>());
auto *all_reduce_op_handle =
dynamic_cast<AllReduceOpHandle *>(&node->Wrapper<OpHandleBase>());
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
&node->Wrapper<details::OpHandleBase>());
if (all_reduce_op_handle) {
auto inputs = DynamicCast<VarHandle>(all_reduce_op_handle->Inputs());
auto inputs = details::DynamicCast<details::VarHandle>(
all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place);
// The inputs' name should be the same.
auto &grad_name = inputs[0]->name();
......@@ -80,7 +82,7 @@ class FuseAllReduceOpPass : public ir::Pass {
VLOG(10) << "Insert fused_all_reduce";
auto &group_grads_params =
graph->Get<GroupGradsAndParams>(kGroupGradsAndParams);
graph->Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
for (auto &group_g_p : group_grads_params) {
size_t group_size = group_g_p.size();
......@@ -108,24 +110,25 @@ class FuseAllReduceOpPass : public ir::Pass {
const platform::NCCLContextMap *nccl_ctxs,
#endif
ir::Graph *result) const {
std::vector<VarHandleBase *> inputs;
std::vector<VarHandleBase *> outputs;
std::vector<details::VarHandleBase *> inputs;
std::vector<details::VarHandleBase *> outputs;
for (auto &op : all_reduce_ops) {
auto &op_handle = op->Wrapper<OpHandleBase>();
auto &op_handle = op->Wrapper<details::OpHandleBase>();
inputs.insert(inputs.end(), op_handle.Inputs().begin(),
op_handle.Inputs().end());
// Remove output
for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
[&op_handle](VarHandleBase *var_handle) {
[&op_handle](details::VarHandleBase *var_handle) {
var_handle->RemoveOutput(&op_handle, op_handle.Node());
});
outputs.insert(outputs.end(), op_handle.Outputs().begin(),
op_handle.Outputs().end());
// Remove Input
for_each(
op_handle.Outputs().begin(), op_handle.Outputs().end(),
[](VarHandleBase *var_handle) { var_handle->ClearGeneratedOp(); });
for_each(op_handle.Outputs().begin(), op_handle.Outputs().end(),
[](details::VarHandleBase *var_handle) {
var_handle->ClearGeneratedOp();
});
result->RemoveNode(op_handle.Node());
}
......@@ -140,21 +143,22 @@ class FuseAllReduceOpPass : public ir::Pass {
}
private:
void CreateFusedAllReduceOp(const std::vector<VarHandleBase *> &inputs,
const std::vector<VarHandleBase *> &outputs,
const size_t num_of_all_reduce,
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
void CreateFusedAllReduceOp(
const std::vector<details::VarHandleBase *> &inputs,
const std::vector<details::VarHandleBase *> &outputs,
const size_t num_of_all_reduce,
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::NCCLContextMap *nccl_ctxs,
const platform::NCCLContextMap *nccl_ctxs,
#endif
ir::Graph *result) const {
ir::Graph *result) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *op_handle = new FusedAllReduceOpHandle(
auto *op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, nccl_ctxs);
#else
auto *op_handle = new FusedAllReduceOpHandle(
auto *op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce);
#endif
......@@ -176,8 +180,9 @@ class FuseAllReduceOpPass : public ir::Pass {
#endif
}
void SetCommunicationContext(const std::vector<platform::Place> &places,
FusedAllReduceOpHandle *op_handle) const {
void SetCommunicationContext(
const std::vector<platform::Place> &places,
details::FusedAllReduceOpHandle *op_handle) const {
for (size_t i = 0; i < places.size(); ++i) {
op_handle->SetDeviceContext(
places[i], platform::DeviceContextPool::Instance().Get(places[i]));
......@@ -185,9 +190,9 @@ class FuseAllReduceOpPass : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_all_reduce_op_pass,
paddle::framework::details::FuseAllReduceOpPass);
paddle::framework::ir::FuseAllReduceOpPass);
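// Sketch (illustrative, not part of this commit): the detach idiom used when
// folding per-gradient all_reduce ops into one fused op, isolated. Unhook an
// op handle from its input/output vars before removing its node; every call
// below mirrors a line in the hunk above.
static void DetachOpHandle(ir::Graph *graph, details::OpHandleBase *op) {
  for (auto *in : op->Inputs()) {
    in->RemoveOutput(op, op->Node());  // var no longer feeds this op
  }
  for (auto *out : op->Outputs()) {
    out->ClearGeneratedOp();  // var no longer produced by this op
  }
  graph->RemoveNode(op->Node());
}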
......@@ -12,21 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) {
details::ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
......@@ -34,25 +33,27 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
return true;
}
void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
}
}
}
}
} // namespace details
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
paddle::framework::ir::ModifyOpLockAndRecordEventPass);
......@@ -19,7 +19,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
......@@ -28,19 +28,19 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
bool IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
std::unordered_map<details::OpHandleBase *, size_t> pending_ops;
std::unordered_set<details::VarHandleBase *> pending_vars;
std::unordered_set<details::VarHandleBase *> ready_vars;
std::unordered_set<details::OpHandleBase *> ready_ops;
auto insert_pending_var = [&](VarHandleBase *var) {
auto insert_pending_var = [&](details::VarHandleBase *var) {
pending_vars.insert(var);
if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var);
}
};
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &var_map : graph->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair);
......@@ -48,11 +48,12 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
}
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
for (auto &var :
graph->Get<details::GraphDepVars>(details::kGraphDepVars)) {
insert_pending_var(var);
}
for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
for (auto *op : ir::FilterByNodeWrapper<details::OpHandleBase>(*graph)) {
if (op->Inputs().empty()) {
ready_ops.insert(op);
} else {
......@@ -60,7 +61,7 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
}
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
auto run_all_ops = [&](std::unordered_set<details::OpHandleBase *> &set) {
for (auto *op : set) {
for (auto out : op->Outputs()) {
ready_vars.emplace(out);
......@@ -91,11 +92,11 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker)
paddle::framework::ir::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars);
......@@ -31,7 +31,7 @@ class NCCLContextMap;
namespace framework {
class Scope;
namespace details {
namespace ir {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";
......@@ -69,8 +69,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
ir::Node *out_var_node, size_t loss_scale,
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
size_t dst_dev_id) const;
details::VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
size_t dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
size_t dev_id) const;
......@@ -89,7 +89,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
void SetCommunicationContext(OpHandleBase *op_handle,
void SetCommunicationContext(details::OpHandleBase *op_handle,
const platform::Place &p) const;
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
......@@ -103,7 +103,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable BuildStrategy strategy_;
mutable details::BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
};
......@@ -209,6 +209,6 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
......@@ -21,11 +21,21 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class SSAGraghBuilderWithPrinterPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
}
};
template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &each : graph.Get<details::GraphVars>(details::kGraphVars)) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
......@@ -33,7 +43,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
for (auto &var : graph.Get<details::GraphDepVars>(details::kGraphDepVars)) {
callback(*var);
}
}
......@@ -41,14 +51,14 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
std::unordered_map<const details::VarHandleBase *, size_t> vars;
sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) {
IterAllVar(graph, [&](const details::VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
auto *var_handle_ptr = dynamic_cast<const details::VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const details::DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;
......@@ -65,7 +75,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
for (auto &op : ir::FilterByNodeWrapper<details::OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......@@ -82,10 +92,10 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
sout << "}\n";
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_devices_print_pass,
paddle::framework::details::SSAGraghBuilderWithPrinter)
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
paddle::framework::ir::SSAGraghBuilderWithPrinterPass)
.RequirePassAttr(paddle::framework::ir::kGraphvizPath);
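// Sketch (illustrative, not part of this commit): driving the printer pass.
// Both the graphviz path and the "graph_printer" attr read in ApplyImpl above
// must be set first; the output path and wrapper function are assumptions.
static void DumpSSAGraph(ir::Graph *graph) {
  auto pass = ir::PassRegistry::Instance().Get("multi_devices_print_pass");
  pass->Set<std::string>(kGraphvizPath, new std::string("/tmp/ssa_graph.dot"));
  pass->Set<GraphvizSSAGraphPrinter>("graph_printer",
                                     new GraphvizSSAGraphPrinter);
  pass->Apply(graph);
}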
......@@ -24,7 +24,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
......@@ -39,16 +39,6 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
void Print(const ir::Graph& graph, std::ostream& sout) const override;
};
class SSAGraghBuilderWithPrinter : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
class SequentialExecutionPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
// FIXME(zjl): Inserting dependencies between some distributed ops may cause
// the multi_devices_graph_pass to fail. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually, which may easily cause deadlock.
// We should add more skipped distributed ops when errors are found in
// multi_devices_graph_pass.
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops =
graph->Get<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s",
op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
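// Chain the retained ops with control-dependency variables so the executor
// is forced to run them strictly in program order.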
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::ir::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
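A minimal usage sketch, assuming only what the registration above guarantees (the registry name and the required kStaleProgramOpDescs graph attribute); this is illustrative, not part of the change:
paddle::framework::ir::Graph *MakeSequential(
    paddle::framework::ir::Graph *graph) {
  // The graph must already carry kStaleProgramOpDescs, as enforced by
  // RequireGraphAttr above; Apply then inserts the control-dep chain.
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "sequential_execution_pass");
  return pass->Apply(graph);
}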
......@@ -12,30 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Use synchronous batch norm";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm");
}
if (op->Type() == "batch_norm_grad") {
op->SetType("sync_batch_norm_grad");
class SyncBatchNormPass : public Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
VLOG(3) << "Use synchronous batch norm";
for (const Node *n : graph->Nodes()) {
if (n->IsOp()) {
auto *op = n->Op();
if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm");
}
if (op->Type() == "batch_norm_grad") {
op->SetType("sync_batch_norm_grad");
}
}
}
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
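Usage remains a single Apply call; a sketch, assuming the pass stays registered under the name "sync_batch_norm_pass" (its REGISTER_PASS line is elided from this excerpt):
void UseSyncBatchNorm(paddle::framework::ir::Graph *graph) {
  // Sketch only; the registry name is an assumption.
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "sync_batch_norm_pass");
  graph = pass->Apply(graph);
  // Every "batch_norm" op node now reports type "sync_batch_norm", and
  // likewise for the grad op.
}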
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class SyncBatchNormPass : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
......
......@@ -23,11 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef WITH_GPERFTOOLS
......@@ -110,9 +110,9 @@ class ParallelExecutorPrivate {
// global_ref_cnts_ is initialized only when the ParallelExecutor is
// constructed, and remains unchanged afterwards.
// Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_.
std::vector<details::ReferenceCountMap> global_ref_cnts_;
std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
details::GarbageCollectorMap gcs_;
std::vector<ir::ReferenceCountMap> global_ref_cnts_;
std::vector<ir::AtomicReferenceCountMap> runtime_ref_cnts_;
ir::GarbageCollectorMap gcs_;
};
ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
......@@ -150,25 +150,23 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
}
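// Note: the reference-count maps and the GC map below remain owned by
// ParallelExecutorPrivate; the passes only borrow them via SetNotOwned, so
// this object must stay alive for as long as the passes may run.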
if (!gcs_.empty()) {
std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
&global_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
ref_cnt_pass->SetNotOwned(ir::kGlobalReferenceCount, &global_ref_cnts_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(graph);
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
eager_deletion_pass->SetNotOwned(ir::kRuntimeReferenceCount,
&runtime_ref_cnts_);
eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
eager_deletion_pass->SetNotOwned(ir::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(graph);
VLOG(10) << "EagerDeletionPass Applied";
}
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
......@@ -34,7 +34,7 @@ void BindConstValue(pybind11::module* m) {
m->def("kControlDepVarName",
[] { return framework::ir::Node::kControlDepVarName; });
m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
m->def("kMemOptSkipVars", [] { return framework::details::kMemOptSkipVars; });
m->def("kMemOptSkipVars", [] { return framework::ir::kMemOptSkipVars; });
auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker");
......
......@@ -21,11 +21,11 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -170,9 +170,9 @@ PYBIND11_MODULE(core, m) {
m.def("_set_eager_deletion_mode", &paddle::framework::SetEagerDeletionMode);
m.def("_set_fuse_parameter_group_size",
&paddle::framework::details::SetFuseParameterGroupsSize);
&paddle::framework::ir::SetFuseParameterGroupsSize);
m.def("_set_fuse_parameter_memory_size",
&paddle::framework::details::SetFuseParameterMemorySize);
&paddle::framework::ir::SetFuseParameterMemorySize);
m.add_object("_cleanup",
py::capsule([]() { ScopePool::Instance().Clear(); }));
......