Commit 658c5432 authored by: S Superjomn

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into design/lite2

......@@ -25,8 +25,9 @@ endif()
if(ANAKIN_FOUND)
message(STATUS "Current ANAKIN header is ${ANAKIN_INCLUDE_DIR}/anakin_config.h. ")
include_directories(${ANAKIN_ROOT})
include_directories(${ANAKIN_ROOT}/include)
include_directories(${ANAKIN_ROOT}/include/saber)
include_directories(${ANAKIN_ROOT}/saber)
link_directories(${ANAKIN_ROOT})
add_definitions(-DPADDLE_WITH_ANAKIN)
endif()
......@@ -77,6 +77,7 @@ else(WIN32)
ENDIF(WIN32)
MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}")
get_filename_component(WARPCTC_LIBRARY_PATH ${WARPCTC_LIBRARIES} DIRECTORY)
INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its headers.
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include warpctc headers.
......
......@@ -3,6 +3,8 @@ set(PADDLE_VERSION $ENV{PADDLE_VERSION})
set(tmp_version "HEAD")
set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+")
set(LATEST_PADDLE_VERSION "latest")
while ("${PADDLE_VERSION}" STREQUAL "")
# Check current branch name
execute_process(
......@@ -23,8 +25,8 @@ while ("${PADDLE_VERSION}" STREQUAL "")
if (${GIT_BRANCH_NAME} MATCHES "release/${TAG_VERSION_REGEX}")
# Check the tag is a correct version
if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}")
# if no tag was found, set PADDLE_VERSION to 0.0.0 to represent latest
set(PADDLE_VERSION "0.0.0")
# if no tag was found, set PADDLE_VERSION to "latest"
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME})
else() # otherwise, get the previous git tag name.
......@@ -42,19 +44,19 @@ while ("${PADDLE_VERSION}" STREQUAL "")
if (${GIT_EXACT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_EXACT_TAG_NAME})
else()
set(PADDLE_VERSION "0.0.0")
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
endif()
else()
# otherwise, we always set PADDLE_VERSION to 0.0.0 to represent latest
set(PADDLE_VERSION "0.0.0")
# otherwise, we always set PADDLE_VERSION to "latest"
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
endif()
endif()
else()
set(PADDLE_VERSION "0.0.0")
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
message(WARNING "Cannot add paddle version from git tag")
endif()
else()
set(PADDLE_VERSION "0.0.0")
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
message(WARNING "Cannot add paddle version for wrong git branch result")
endif()
endwhile()
......
This diff has been collapsed.
......@@ -455,6 +455,8 @@ void MultiSlotDataFeed::Init(
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
total_dims_without_inductive_.resize(all_slot_num);
inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
......@@ -462,14 +464,20 @@ void MultiSlotDataFeed::Init(
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
total_dims_without_inductive_[i] = 1;
inductive_shape_index_[i] = -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
// reserve a placeholder for the batch size when the slot is dense
if (slot.shape(0) > 0) {
local_shape.push_back(0);
for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i);
}
if (slot.shape(i) == -1) {
inductive_shape_index_[i] = i;
}
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
......@@ -762,7 +770,10 @@ void MultiSlotDataFeed::PutToFeedVec(
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
use_slots_shape_[i][0] = batch_size_;
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
total_instance / total_dims_without_inductive_[i];
}
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
......@@ -785,6 +796,8 @@ void MultiSlotInMemoryDataFeed::Init(
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
total_dims_without_inductive_.resize(all_slot_num);
inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
......@@ -797,8 +810,13 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
if (slot.shape(0) > 0) {
local_shape.push_back(0);
for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i);
}
if (slot.shape(i) == -1) {
inductive_shape_index_[i] = i;
}
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
......@@ -960,7 +978,10 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
use_slots_shape_[i][0] = batch_size_;
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
total_instance / total_dims_without_inductive_[i];
}
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
......
......@@ -143,6 +143,8 @@ class DataFeed {
std::vector<std::string> all_slots_;
std::vector<std::string> all_slots_type_;
std::vector<std::vector<int>> use_slots_shape_;
std::vector<int> inductive_shape_index_;
std::vector<int> total_dims_without_inductive_;
std::vector<int>
use_slots_index_; // -1: not used; >=0: the index of use_slots_
......
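The new DataFeed members above (total_dims_without_inductive_ and inductive_shape_index_) cache, for each dense slot, the product of all fixed shape dimensions and the position of the single -1 ("inductive") dimension, so that PutToFeedVec can derive the unknown dimension from the number of instances in the current batch instead of hard-coding dimension 0 as the batch size. Below is a minimal standalone sketch of that resolution; the ResolveInductiveShape helper is hypothetical and not part of this diff.

// Standalone sketch (not Paddle code): fill the single -1 entry of a dense
// slot's shape from the total number of elements fed for that slot.
#include <cassert>
#include <vector>

std::vector<int> ResolveInductiveShape(std::vector<int> shape,
                                       int total_instance) {
  int fixed_product = 1;     // plays the role of total_dims_without_inductive_
  int inductive_index = -1;  // plays the role of inductive_shape_index_
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] > 0) fixed_product *= shape[i];
    if (shape[i] == -1) inductive_index = static_cast<int>(i);
  }
  if (inductive_index != -1) {
    assert(total_instance % fixed_product == 0);
    shape[inductive_index] = total_instance / fixed_product;
  }
  return shape;
}

// Example: ResolveInductiveShape({-1, 3, 224, 224}, 32 * 3 * 224 * 224)
// yields {32, 3, 224, 224}, i.e. the batch size 32 is inferred.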
......@@ -121,6 +121,16 @@ int64_t product(const DDim& ddim) {
return ddim.apply_visitor(ProductVisitor());
}
bool contain_unknown_dim(const DDim& ddim) {
for (int i = 0; i < ddim.size(); ++i) {
if (ddim[i] < 0) {
return true;
}
}
return false;
}
DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0 && end <= dim.size(),
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
......
......@@ -182,6 +182,8 @@ std::vector<int> vectorize2int(const DDim& ddim);
int64_t product(const DDim& ddim);
bool contain_unknown_dim(const DDim& ddim);
/**
* \brief Slice a ddim
*
......
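The contain_unknown_dim helper added above reports whether any entry of a DDim is negative; Paddle uses -1 for dimensions that are only known at runtime, typically the batch size. The following is a hedged standalone analogue over a plain vector (constructing a real DDim is beside the point here), with a hypothetical ContainsUnknownDim name used only for illustration.

// Standalone analogue (not Paddle code) of contain_unknown_dim: a shape
// "contains an unknown dim" when any entry is negative.
#include <cstdint>
#include <vector>

bool ContainsUnknownDim(const std::vector<int64_t>& dims) {
  for (int64_t d : dims) {
    if (d < 0) return true;
  }
  return false;
}

// ContainsUnknownDim({-1, 3, 224, 224}) == true
// ContainsUnknownDim({32, 3, 224, 224}) == false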
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
......@@ -27,7 +17,7 @@ if(WITH_DISTRIBUTE)
endif()
endif()
set(all_reduce_deps all_reduce_op_handle)
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
......@@ -37,7 +27,6 @@ if(WITH_GPU)
if(WITH_DGC)
nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
lod_tensor ddim memory dynload_cuda variable_visitor dgc all_reduce_op_handle)
set(all_reduce_deps sparse_all_reduce_op_handle)
endif()
if(WITH_DISTRIBUTE)
......@@ -68,34 +57,12 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${all_reduce_deps} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is the upper limit of the memory size "
"of one group of parameters' gradients, which is the input "
"of a communication call (e.g. NCCLAllReduce). "
"The default value is 0, which means that groups are "
"not formed according to memory_size.");
DEFINE_int32(
fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the size of one group of parameters' gradients. "
"If fuse_parameter_groups_size is 1, the number of groups equals "
"the number of parameters' gradients. If fuse_parameter_groups_size is "
"-1, there is only one group. The default value is 3, which is "
"an experimental value.");
namespace paddle {
namespace framework {
namespace details {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// tests, because setting 'FLAGS_fuse_parameter_memory_size' and
// 'FLAGS_fuse_parameter_groups_size' directly in a unit test is invalid.
void SetFuseParameterGroupsSize(int group_size) {
FLAGS_fuse_parameter_groups_size = group_size;
}
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(uint64_t memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
void AllocContinuousSpaceForGradPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
ResetAttribute<ParamsAndGrads>(kParamsAndGrads, &result);
ResetAttribute<GroupGradsAndParams>(kGroupGradsAndParams, &result);
// NOTE: The operator nodes should be in topology order.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
for (auto &node : topo_nodes) {
RecordParamsAndGrads(node, &params_grads);
}
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return;
}
std::unordered_map<std::string, ir::Node *> vars;
for (ir::Node *node : result.Nodes()) {
if (node->IsVar() && node->Var()) {
// Note: The graph may contain nodes with the same name. For example, a
// parameter is both an input of an operator and an output of the optimizer;
vars.emplace(node->Var()->Name(), node);
}
}
auto &group_grads_params =
result.Get<GroupGradsAndParams>(kGroupGradsAndParams);
// Note: the order of params_grads may be changed by SetGroupGradsAndParams.
SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
params_grads.clear();
for (auto &group_p_g : group_grads_params) {
params_grads.insert(params_grads.begin(), group_p_g.begin(),
group_p_g.end());
}
for (auto &p_g : params_grads) {
std::swap(p_g.first, p_g.second);
}
// Set Gradients as Persistable to prevent this var becoming reusable.
auto dtype = kDefaultDtype;
for (auto &p_g : params_grads) {
// Get gradient var
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
iter->second->Var()->SetPersistable(true);
PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
// Get Dtype
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create a FusedVarsSet to avoid duplicating names for fused_var in other
// pass.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
}
// kFusedGrads is used by fuse_optimizer_op_pass.
result.Set(kFusedGrads, new FusedGrads);
// the fused_var_name should be unique, so it appends
// params_grads.begin()->second.
auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars, fused_var_name,
params_grads);
}
template <typename AttrType>
void AllocContinuousSpaceForGradPass::ResetAttribute(
const std::string &attr_name, ir::Graph *graph) const {
if (graph->Has(attr_name)) {
VLOG(10) << attr_name << " is reset.";
graph->Erase(attr_name);
}
graph->Set(attr_name, new AttrType);
}
void AllocContinuousSpaceForGradPass::SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
void AllocContinuousSpaceForGradPass::SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const {
std::unordered_map<std::string, std::vector<int>> layer_params;
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) {
layer_params[std::string(kUnKnow)].emplace_back(i);
} else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
}
}
group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
std::string key = kUnKnow;
if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos);
}
auto iter = layer_params.find(key);
if (iter == layer_params.end()) continue;
group_grads_params->emplace_back();
auto &local_group_grads_params = group_grads_params->back();
for (auto &idx : iter->second) {
local_group_grads_params.emplace_back(
std::make_pair(params_grads[idx].second, params_grads[idx].first));
}
layer_params.erase(iter);
}
VLOG(10) << "SetGroupAccordingToLayers: ";
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
void AllocContinuousSpaceForGradPass::SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
return;
}
GroupGradsAndParams local_group_grads_params;
size_t j = 0;
while (j < group_grads_params->size()) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
size_t local_group_memory_size = 0;
while (j < group_grads_params->size()) {
std::for_each(
group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
[&local_group_memory_size,
&var_nodes](const std::pair<std::string, std::string> &g_p) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size =
framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
local_group_memory_size += size;
});
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (local_group_memory_size >= group_memory_size) {
break;
}
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToMemorySize(memory_size: %d):",
group_memory_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
}
}
void AllocContinuousSpaceForGradPass::SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const {
if (GetFuseParameterGroupsSize() == 1) {
return;
}
const int group_size = GetFuseParameterGroupsSize() == -1
? static_cast<int>(group_grads_params->size())
: GetFuseParameterGroupsSize();
PADDLE_ENFORCE_GT(group_size, 1);
size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
GroupGradsAndParams local_group_grads_params;
local_group_grads_params.reserve(groups);
size_t j = 0;
for (size_t i = 0; i < groups; ++i) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
group_p_g.reserve(group_size);
while (j < group_grads_params->size()) {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (j % group_size == 0) break;
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
group_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
bool AllocContinuousSpaceForGradPass::IsSupportedVarType(
const proto::VarType::Type &type) const {
// Currently, only LOD_TENSOR is supported.
return type == proto::VarType::LOD_TENSOR;
}
void AllocContinuousSpaceForGradPass::RecordParamsAndGrads(
ir::Node *node, ParamsAndGrads *params_grads) const {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) return;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
auto backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));
for (size_t i = 0; i < backward_vars.size(); i += 2) {
VLOG(10) << "Trainable parameter: " << backward_vars[i]
<< ", gradient: " << backward_vars[i + 1];
params_grads->emplace_back(std::make_pair(backward_vars[i] /*param*/,
backward_vars[i + 1] /*grad*/));
}
} catch (boost::bad_get e) {
}
}
void AllocContinuousSpaceForGradPass::InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const ParamsAndGrads &params_grads) const {
// Init Gradients and FusedVars
VLOG(10) << "Init FusedVars and Gradients.";
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has existed in scope.", fused_var_name);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
for (auto &p_g : params_grads) {
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
proto::VarType::LOD_TENSOR);
scope->Var(p_g.second)->GetMutable<LoDTensor>();
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
params_name.reserve(params_grads.size());
for (auto &p_g : params_grads) {
params_name.emplace_back(p_g.first);
grads_name.emplace_back(p_g.second);
}
framework::ProgramDesc program_desc;
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void AllocContinuousSpaceForGradPass::AppendAllocSpaceForVarsOp(
const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name, BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
paddle::framework::details::AllocContinuousSpaceForGradPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
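SetGroupAccordingToGroupSize above merges the per-layer groups produced by SetGroupAccordingToLayers into larger groups of at most fuse_parameter_groups_size members each (1 keeps the per-layer grouping, -1 collapses everything into one group). A minimal generic sketch of that chunking follows; the Regroup helper and the Group alias are introduced only for illustration.

// Sketch (not Paddle code): concatenate consecutive groups so that every
// resulting super-group contains at most group_size of the original groups.
#include <string>
#include <utility>
#include <vector>

using Group = std::vector<std::pair<std::string, std::string>>;  // (grad, param)

std::vector<Group> Regroup(const std::vector<Group>& groups, int group_size) {
  if (groups.empty() || group_size == 1) return groups;  // keep per-layer grouping
  if (group_size == -1) group_size = static_cast<int>(groups.size());
  std::vector<Group> result;
  result.reserve((groups.size() + group_size - 1) / group_size);
  size_t j = 0;
  while (j < groups.size()) {
    result.emplace_back();
    while (j < groups.size()) {
      result.back().insert(result.back().end(), groups[j].begin(),
                           groups[j].end());
      ++j;
      if (j % group_size == 0) break;  // this super-group is full
    }
  }
  return result;
}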
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
template <typename AttrType>
void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const;
void SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
private:
bool IsSupportedVarType(const proto::VarType::Type &type) const;
void RecordParamsAndGrads(ir::Node *node, ParamsAndGrads *params_grads) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const ParamsAndGrads &params_grads) const;
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -17,15 +17,14 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
namespace paddle {
namespace framework {
......@@ -173,10 +172,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
const std::string graph_path =
string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
"_multi_devices_graph");
multi_devices_print_pass->Set<std::string>(kGraphvizPath,
multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
new std::string(graph_path));
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter);
multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
"graph_printer", new ir::GraphvizSSAGraphPrinter);
}
// experimental shows that the program will be faster if append
......@@ -240,7 +239,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
}
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
......@@ -263,13 +262,13 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLossVarName);
pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
pass->Erase(ir::kLossVarName);
pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks));
pass->Erase(ir::kNRanks);
pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
......@@ -312,8 +311,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
continue;
}
} else if (pass->Type() == "inplace_pass") {
pass->Erase(kUseCuda);
pass->Set<bool>(kUseCuda, new bool(use_cuda));
pass->Erase(ir::kUseCuda);
pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
......
......@@ -31,7 +31,7 @@ namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place,
const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts)
ir::AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names.begin(), var_names.end()),
......
......@@ -20,7 +20,7 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
......@@ -34,7 +34,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
const platform::Place &place,
const std::unordered_set<std::string> &var_names,
GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts);
ir::AtomicReferenceCountMap *ref_cnts);
~EagerDeletionOpHandle();
......@@ -55,8 +55,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
const Scope *scope_;
std::vector<std::string> var_names_;
GarbageCollector *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
GarbageCollector *gc_; // not own
ir::AtomicReferenceCountMap *ref_cnts_; // not own
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr};
......
......@@ -63,8 +63,7 @@ void FetchOpHandle::RunImpl() {
auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA
TensorCopy(t, cpu, *dev_ctxes_.at(t.place()), &tensors_[i]);
dev_ctxes_.at(t.place())->Wait();
TensorCopy(t, cpu, &tensors_[i]);
#endif
} else {
tensors_[i].ShareDataWith(t);
......
......@@ -27,7 +27,7 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kLocalExecScopeName[] = "@LOCAL_SCOPE@";
constexpr char kLocalExecScopeName[] = "@LOCAL_EXE_SCOPE@";
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
class RecordSkipMemoryOptVarsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
graph->Set(kMemOptSkipVars, new MemOptSkipVars);
auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
// NOTE(zcd): Insert OpRoleVars into SkipVarSet to prevent these vars from
// being renamed in the memory optimize pass.
InsertOpRoleVarsToSkipVarSet(graph, &skip_vars);
}
void InsertOpRoleVarsToSkipVarSet(const ir::Graph* graph,
MemOptSkipVars* skip_vars) const {
for (auto& node : graph->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
if (node->IsOp() && node->Op()) {
try {
auto op_role_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
for (size_t i = 0; i < op_role_vars.size(); i += 2) {
auto& g_name = op_role_vars[i + 1];
skip_vars->insert(g_name);
}
} catch (boost::bad_get e) {
}
}
}
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(record_skip_memory_opt_vars_pass,
paddle::framework::details::RecordSkipMemoryOptVarsPass);
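RecordSkipMemoryOptVarsPass above reads the OpRoleVar attribute, which is a flat list of interleaved parameter and gradient names, and inserts every gradient name into the skip set so that the memory-optimize pass will not rename or reuse those variables. Below is a small sketch of that extraction step under the assumption (enforced by the PADDLE_ENFORCE_EQ above) that the list length is even; the CollectGradNames helper is hypothetical.

// Sketch (not Paddle code): the attribute is {param0, grad0, param1, grad1, ...};
// collect every name at an odd index, i.e. the gradients.
#include <string>
#include <unordered_set>
#include <vector>

void CollectGradNames(const std::vector<std::string>& op_role_vars,
                      std::unordered_set<std::string>* skip_vars) {
  for (size_t i = 0; i + 1 < op_role_vars.size(); i += 2) {
    skip_vars->insert(op_role_vars[i + 1]);  // the gradient name
  }
}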
......@@ -68,15 +68,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
WaitComputationalStreams();
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
}
drop_scope_counter_ = 0;
DropLocalExeScopes();
}
if (eptr) {
std::rethrow_exception(eptr);
......@@ -84,6 +76,25 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
return fetch_data;
}
}
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
drop_scope_counter_ = 0;
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
VLOG(3) << "Drop local execution scope: " << local_scope;
}
}
bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() {
return drop_scope_counter_ == 0;
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -47,17 +47,12 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
inline void WaitComputationalStreams() {
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
}
void DropLocalExeScopes();
bool NeedCreateLocalExeScope();
private:
size_t drop_scope_counter_{0};
ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
std::vector<Scope*> local_scopes_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
// FIXME(zjl): Inserting dependencies between some distributed ops may cause
// the multi_devices_graph_pass to fail, so we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually at all, since doing so can easily cause deadlock.
// More skipped distributed ops should be added when errors are found in
// multi_devices_graph_pass.
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops = graph->Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s", op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -425,6 +425,7 @@ void DownpourWorker::TrainFiles() {
}
VLOG(3) << "push dense gradient done.";
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t tmp_push_dense_wait_times = -1;
......
......@@ -19,7 +19,7 @@
#include <unordered_map>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
......
......@@ -18,7 +18,7 @@
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -33,7 +33,7 @@ namespace framework {
std::unique_ptr<ir::Pass> CreateInplacePass() {
auto pass = ir::PassRegistry::Instance().Get("inplace_pass");
pass->Set<bool>(details::kUseCuda, new bool(true));
pass->Set<bool>(ir::kUseCuda, new bool(true));
return pass;
}
......@@ -225,7 +225,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
FakeSuccData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
......@@ -241,7 +241,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
FakeNoInplaceData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
......@@ -274,7 +274,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
......@@ -310,7 +310,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
......
......@@ -3,6 +3,9 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList
file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
# Usage: pass_library(target inference) will append to paddle_inference_pass.h
unset(INFER_IR_PASSES CACHE) # clear the global variable
......@@ -34,7 +37,6 @@ function(pass_library TARGET DEST)
endif()
endfunction()
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
......@@ -43,6 +45,8 @@ cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
......@@ -71,6 +75,7 @@ pass_library(runtime_context_cache_pass base)
pass_library(expected_kernel_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
if(ANAKIN_FOUND)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is the upper limit of the memory size "
"of one group of parameters' gradients, which is the input "
"of a communication call (e.g. NCCLAllReduce). "
"The default value is 0, which means that groups are "
"not formed according to memory_size.");
DEFINE_int32(
fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the size of one group of parameters' gradients. "
"If fuse_parameter_groups_size is 1, the number of groups equals "
"the number of parameters' gradients. If fuse_parameter_groups_size is "
"-1, there is only one group. The default value is 3, which is "
"an experimental value.");
namespace paddle {
namespace framework {
namespace ir {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// tests, because setting 'FLAGS_fuse_parameter_memory_size' and
// 'FLAGS_fuse_parameter_groups_size' directly in a unit test is invalid.
void SetFuseParameterGroupsSize(int group_size) {
FLAGS_fuse_parameter_groups_size = group_size;
}
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(uint64_t memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
ResetAttribute<details::ParamsAndGrads>(details::kParamsAndGrads, &result);
ResetAttribute<details::GroupGradsAndParams>(details::kGroupGradsAndParams,
&result);
// NOTE: The operator nodes should be in topology order.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
for (auto &node : topo_nodes) {
RecordParamsAndGrads(node, &params_grads);
}
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return;
}
std::unordered_map<std::string, ir::Node *> vars;
for (ir::Node *node : result.Nodes()) {
if (node->IsVar() && node->Var()) {
// Note: The graph may contain nodes with the same name. For example, a
// parameter is both an input of an operator and an output of the optimizer;
vars.emplace(node->Var()->Name(), node);
}
}
auto &group_grads_params =
result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
// Note: the order of params_grads may be changed by SetGroupGradsAndParams.
SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
params_grads.clear();
for (auto &group_p_g : group_grads_params) {
params_grads.insert(params_grads.begin(), group_p_g.begin(),
group_p_g.end());
}
for (auto &p_g : params_grads) {
std::swap(p_g.first, p_g.second);
}
// Set Gradients as Persistable to prevent this var becoming reusable.
auto dtype = kDefaultDtype;
for (auto &p_g : params_grads) {
// Get gradient var
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
iter->second->Var()->SetPersistable(true);
PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
// Get Dtype
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create a FusedVarsSet to avoid duplicating names for fused_var in other
// pass.
if (!result.Has(details::kFusedVars)) {
result.Set(details::kFusedVars, new details::FusedVars);
}
// kFusedGrads is used by fuse_optimizer_op_pass.
result.Set(details::kFusedGrads, new details::FusedGrads);
// the fused_var_name should be unique, so it appends
// params_grads.begin()->second.
auto fused_var_name = std::string(details::kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<details::FusedGrads>(details::kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
fused_var_name, params_grads);
}
template <typename AttrType>
void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const {
if (graph->Has(attr_name)) {
VLOG(10) << attr_name << " is reset.";
graph->Erase(attr_name);
}
graph->Set(attr_name, new AttrType);
}
void SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const details::ParamsAndGrads &params_grads,
details::GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
void SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const details::ParamsAndGrads &params_grads,
details::GroupGradsAndParams *group_grads_params) const {
std::unordered_map<std::string, std::vector<int>> layer_params;
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) {
layer_params[std::string(kUnKnow)].emplace_back(i);
} else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
}
}
group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
std::string key = kUnKnow;
if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos);
}
auto iter = layer_params.find(key);
if (iter == layer_params.end()) continue;
group_grads_params->emplace_back();
auto &local_group_grads_params = group_grads_params->back();
for (auto &idx : iter->second) {
local_group_grads_params.emplace_back(
std::make_pair(params_grads[idx].second, params_grads[idx].first));
}
layer_params.erase(iter);
}
VLOG(10) << "SetGroupAccordingToLayers: ";
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
return;
}
details::GroupGradsAndParams local_group_grads_params;
size_t j = 0;
while (j < group_grads_params->size()) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
size_t local_group_memory_size = 0;
while (j < group_grads_params->size()) {
std::for_each(
group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
[&local_group_memory_size,
&var_nodes](const std::pair<std::string, std::string> &g_p) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size =
framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
local_group_memory_size += size;
});
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (local_group_memory_size >= group_memory_size) {
break;
}
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
}
}
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
if (GetFuseParameterGroupsSize() == 1) {
return;
}
const int group_size = GetFuseParameterGroupsSize() == -1
? static_cast<int>(group_grads_params->size())
: GetFuseParameterGroupsSize();
PADDLE_ENFORCE_GT(group_size, 1);
size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
details::GroupGradsAndParams local_group_grads_params;
local_group_grads_params.reserve(groups);
size_t j = 0;
for (size_t i = 0; i < groups; ++i) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
group_p_g.reserve(group_size);
while (j < group_grads_params->size()) {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (j % group_size == 0) break;
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
group_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
private:
bool IsSupportedVarType(const proto::VarType::Type &type) const {
// Currently, only LOD_TENSOR is supported.
return type == proto::VarType::LOD_TENSOR;
}
void RecordParamsAndGrads(ir::Node *node,
details::ParamsAndGrads *params_grads) const {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) return;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
auto backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));
for (size_t i = 0; i < backward_vars.size(); i += 2) {
VLOG(10) << "Trainable parameter: " << backward_vars[i]
<< ", gradient: " << backward_vars[i + 1];
params_grads->emplace_back(std::make_pair(
backward_vars[i] /*param*/, backward_vars[i + 1] /*grad*/));
}
} catch (boost::bad_get e) {
}
}
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const details::ParamsAndGrads &params_grads) const {
// Init Gradients and FusedVars
VLOG(10) << "Init FusedVars and Gradients.";
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has existed in scope.", fused_var_name);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
for (auto &p_g : params_grads) {
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
proto::VarType::LOD_TENSOR);
scope->Var(p_g.second)->GetMutable<LoDTensor>();
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
params_name.reserve(params_grads.size());
for (auto &p_g : params_grads) {
params_name.emplace_back(p_g.first);
grads_name.emplace_back(p_g.second);
}
framework::ProgramDesc program_desc;
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
paddle::framework::ir::AllocContinuousSpaceForGradPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
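The overall flow of this pass is: record (parameter, gradient) pairs from backward ops, group them, create the fused variable in every local scope, and run a one-op program (alloc_continuous_space) so that all gradients become views into a single contiguous buffer. A rough standalone sketch of the "one buffer, many views" layout it produces (offsets simplified, no alignment padding):

#include <cstdint>
#include <vector>

struct View { int64_t offset; int64_t numel; };

// Lay out several tensors back to back inside one contiguous buffer.
std::vector<View> PlanFusedBuffer(const std::vector<int64_t> &numels, int64_t *total) {
  std::vector<View> views;
  int64_t offset = 0;
  for (auto n : numels) {
    views.push_back({offset, n});
    offset += n;
  }
  *total = offset;
  return views;
}

int main() {
  int64_t total = 0;
  auto views = PlanFusedBuffer({256, 1024, 64}, &total);
  std::vector<float> fused(total);  // single allocation backing all gradients
  // gradient i lives at fused.data() + views[i].offset with length views[i].numel
  float *grad0 = fused.data() + views[0].offset;
  (void)grad0;
  return 0;
}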
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -11,21 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include <algorithm>
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
class ReferenceCountPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -48,17 +48,37 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
auto base_op_desc = mul->Op();
// Create an FC Node.
// OpDesc desc(base_op_desc, nullptr);
OpDesc desc;
std::string fc_x_in = subgraph.at(x)->Name();
std::string fc_Y_in = w->Name();
std::string fc_bias_in = fc_bias->Name();
std::string fc_out_out = fc_out->Name();
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
// For anakin subgraph int8
// When in anakin subgraph int8 mode, a pattern like "fake_quant + mul +
// fake_dequant" can be detected by the quant_dequant_fuse_pass. That pass
// will add "input_scale" and "weight_scale", which are extracted from the
// fake_quant op and fake_dequant op, to the mul op, and then delete the
// fake_quant op and fake_dequant op in the graph. If the mul op has the
// scale info, we should add those to the fused fc.
if (base_op_desc->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
}
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
......
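The fusion collapses the mul + elementwise_add pair into a single fc op, so fc(Input, W, Bias) must compute Input * W + Bias exactly as the two original ops did; the int8 scale attributes are simply carried along when the mul op has them. A tiny numeric check of that equivalence (shapes and values invented for illustration):

#include <array>
#include <cassert>

int main() {
  // 1x2 input, 2x2 row-major weight, length-2 bias
  std::array<float, 2> x{1.f, 2.f};
  std::array<float, 4> W{0.5f, 1.f, -1.f, 2.f};
  std::array<float, 2> b{0.1f, -0.1f};
  std::array<float, 2> mul_out{}, add_out{}, fc_out{};
  for (int j = 0; j < 2; ++j) {
    mul_out[j] = x[0] * W[0 * 2 + j] + x[1] * W[1 * 2 + j];            // mul op
    add_out[j] = mul_out[j] + b[j];                                    // elementwise_add op
    fc_out[j] = x[0] * W[0 * 2 + j] + x[1] * W[1 * 2 + j] + b[j];      // fused fc
  }
  assert(add_out[0] == fc_out[0] && add_out[1] == fc_out[1]);
  return 0;
}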
cc_library(fuse_optimizer_op_pass SRCS fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc DEPS fuse_optimizer_op_pass)
......@@ -16,16 +16,13 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseAdamOpPass : public FuseOptimizerOpPass {
private:
......@@ -203,10 +200,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
}
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::details::FuseAdamOpPass)
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -16,14 +16,13 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseMomentumOpPass : public FuseOptimizerOpPass {
private:
......@@ -84,11 +83,10 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_momentum_op_pass,
paddle::framework::details::FuseMomentumOpPass)
REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -20,13 +20,13 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
const std::string fuse_op_type = GetOpType();
std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
......@@ -47,24 +47,24 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
return;
}
if (result.Has(kFusedOptType)) {
if (result.Has(details::kFusedOptType)) {
VLOG(6) << "Currently only support fusing one type optimizer op. Has fused "
<< result.Get<FusedOptType>(kFusedOptType);
<< result.Get<details::FusedOptType>(details::kFusedOptType);
return;
} else {
result.Set(kFusedOptType, new FusedOptType);
result.Set(details::kFusedOptType, new details::FusedOptType);
}
result.Get<FusedOptType>(kFusedOptType) = fuse_op_type;
result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type;
// Step 2: Insert fused_var_name into FusedVars, and the FusedVars need to be
// initialized in scopes before execution.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
if (!result.Has(details::kFusedVars)) {
result.Set(details::kFusedVars, new details::FusedVars);
}
std::unordered_map<std::string, std::string> fused_vars_name;
fused_vars_name.reserve(aux_var_names.size());
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
const std::string prefix(kFusedVarNamePrefix);
auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
const std::string prefix(details::kFusedVarNamePrefix);
// NOTE: the fused_var_name should be unique.
for (auto &var_name : aux_var_names) {
auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
......@@ -77,8 +77,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// Step 3: Get the fused Gradient's name
bool grad_fused = false;
if (result.Has(kParamsAndGrads)) {
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
if (result.Has(details::kParamsAndGrads)) {
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
PADDLE_ENFORCE_EQ(
params_grads.size(), aux_var_set.at(kGrad).size(),
"The number of gradients and optimizer ops is not equal.");
......@@ -94,13 +95,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// NOTE(zcd): the gradient of kParamsAndGrads may be different from the
// kGrad.
if (same_grad_num == aux_var_set.at(kGrad).size()) {
if (!result.Has(kFusedGrads)) {
if (!result.Has(details::kFusedGrads)) {
PADDLE_THROW(
"The alloc_continuous_space_for_grad_pass should be called before "
"this pass.");
}
auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
auto &fused_vars = result.Get<FusedVars>(kFusedVars);
auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
fused_vars_name[kGrad] = fused_grad;
......@@ -323,6 +324,6 @@ void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
outputs.end());
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -25,7 +25,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
constexpr char kGrad[] = "Grad";
constexpr char kParam[] = "Param";
......@@ -90,6 +90,6 @@ class FuseOptimizerOpPass : public ir::Pass {
const std::string &fused_var_name) const;
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -14,18 +14,13 @@
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
......@@ -66,10 +61,10 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
InserInputAndOutputForOptOps(sgd_ops, sgd_node);
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::details::FuseSgdOpPass)
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
......@@ -61,7 +66,16 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
var->outputs.push_back(node);
}
// For output args, always create a new var.
std::unordered_set<std::string> out_arg_set;
for (auto &each_var_name : op->OutputArgumentNames()) {
if (each_var_name != kEmptyVarName) {
PADDLE_ENFORCE(out_arg_set.count(each_var_name) == 0,
"Program is wrong. %s occurs in output of %s several "
"times.",
each_var_name, op->Type());
out_arg_set.insert(each_var_name);
}
ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
......
......@@ -1640,7 +1640,8 @@ PDNode *patterns::FillConstantElementWiseMulFuse::operator()(
void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times) {
int times,
const std::string &quant_type) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
......@@ -1648,24 +1649,22 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
// there is always exactly one quant op.
auto quant_op_in_scale =
pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input("fake_quantize_range_abs_max", "InScale")
->AsInput();
auto quant_op = pattern->NewNode(GetNodeName("quant_op"))
->assert_is_op("fake_quantize_range_abs_max");
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input(quant_type, "InScale")
->AsInput();
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
auto quant_op_out_scale =
pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output("fake_quantize_range_abs_max", "OutScale")
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto quant_op_out =
pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output("fake_quantize_range_abs_max", "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output(quant_type, "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
// there are 'times' pairs of quantized and dequant ops
std::vector<PDNode *> nodes;
......@@ -1707,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
}
}
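The pattern being matched is fake_quantize -> quantized op -> fake_dequantize, repeated 'times' times; the fusion keeps only the scales. A hedged sketch of the quantize/dequantize round trip those ops implement (8-bit signed range assumed, scale value invented):

#include <cmath>
#include <cstdio>

int main() {
  const float scale = 6.4f;   // max abs value recorded by the quant op
  const int range = 127;      // 2^(8-1) - 1 for int8
  float x = 1.57f;
  float q = std::round(x / scale * range);  // fake_quantize
  float x_rec = q * scale / range;          // fake_dequantize_max_abs
  std::printf("q = %.0f, reconstructed x = %f\n", q, x_rec);
  return 0;
}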
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
auto reshape1_op =
pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("transpose2")
->AsIntermediate();
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto transpose_out = pattern->NewNode(transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("reshape2")
->AsIntermediate();
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
->assert_is_op_output("reshape2", "Out")
->AsOutput();
reshape1_op->LinksFrom({reshape1_in});
reshape1_out->LinksFrom({reshape1_op});
transpose_op->LinksFrom({reshape1_out});
transpose_out->LinksFrom({transpose_op});
reshape2_op->LinksFrom({transpose_out});
reshape2_out->LinksFrom({reshape2_op});
}
} // namespace ir
} // namespace framework
} // namespace paddle
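The reshape2 -> transpose2 -> reshape2 chain matched by ShuffleChannelPattern is the usual channel-shuffle trick: view the C channels as (g, C/g), transpose to (C/g, g), then flatten back. A small sketch of the resulting channel permutation (group count chosen only for illustration):

#include <cstdio>
#include <vector>

// Output position a * g + b in the flattened (C/g, g) layout reads input
// channel b * (C/g) + a.
std::vector<int> ShuffleChannels(int C, int g) {
  std::vector<int> perm(C);
  int cpg = C / g;  // channels per group
  for (int a = 0; a < cpg; ++a)
    for (int b = 0; b < g; ++b)
      perm[a * g + b] = b * cpg + a;
  return perm;
}

int main() {
  for (int c : ShuffleChannels(6, 2)) std::printf("%d ", c);  // 0 3 1 4 2 5
  std::printf("\n");
  return 0;
}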
......@@ -880,7 +880,8 @@ struct QuantDequantOpFuse : public PatternBase {
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times = 1);
const std::string& weight_name, int times,
const std::string& quant_type);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
......@@ -891,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
}
};
struct ShuffleChannelPattern : public PatternBase {
ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
void operator()(PDNode* reshape1_in);
PATTERN_DECL_NODE(reshape1_op);
PATTERN_DECL_NODE(reshape1_out);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
......@@ -27,11 +27,11 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
// op -> variables which can be deleted after op runs
using OpToVarNameSetMap =
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
using OpToVarNameSetMap = std::unordered_map<details::ComputationOpHandle *,
std::unordered_set<std::string>>;
static std::map<size_t, std::unordered_set<std::string>> VarsGroupByScopeIdx(
const OpToVarNameSetMap &map) {
......@@ -53,7 +53,8 @@ static bool IsLoDTensor(VarDesc *var) {
// Get memory size of LoDTensor
static int64_t GetMemorySize(
const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
const std::unordered_map<std::string, std::vector<details::VarHandle *>>
&vars,
const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc);
......@@ -69,13 +70,13 @@ static int64_t GetMemorySize(
// Partial GC is based on static analysis of the memory size of each variable,
// so we should skip SelectedRows and LoDTensorArray here
static void SplitIntoLoDTensorAndNonLoDTensorVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const OpToVarNameSetMap &m, const details::GraphVars &vars,
OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
lod_tensors->clear();
other_vars->clear();
for (auto &op_vars_pair : m) {
for (auto &var_name : op_vars_pair.second) {
for (auto var_name : op_vars_pair.second) {
auto *var_desc = TryGetLatestVarDesc(
vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
if (IsLoDTensor(var_desc)) {
......@@ -89,23 +90,24 @@ static void SplitIntoLoDTensorAndNonLoDTensorVars(
struct GCVarInfo {
GCVarInfo(const std::string &name, int64_t memory_size,
ComputationOpHandle *op, size_t scope_idx)
details::ComputationOpHandle *op, size_t scope_idx)
: name_(name),
memory_size_(memory_size),
op_(op),
scope_idx_(scope_idx) {}
std::string name_; // variable name
int64_t memory_size_; // memory size
ComputationOpHandle *op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates
std::string name_; // variable name
int64_t memory_size_; // memory size
details::ComputationOpHandle
*op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates
int64_t AbsMemorySize() const { return std::abs(memory_size_); }
};
// NOTE: delete_lod_tensor_only is not used currently
static OpToVarNameSetMap ShrinkGCVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const OpToVarNameSetMap &m, const details::GraphVars &vars,
const std::vector<platform::Place> &places, double fraction_of_memory_size,
bool delete_lod_tensor_only = false) {
// Do not perform gc when fraction_of_memory_size = 0
......@@ -192,7 +194,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE(ref_cnts.empty(),
"kRuntimeReferenceCount should be initialized here!");
const auto &vars = graph->Get<GraphVars>(kGraphVars);
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
ref_cnts.resize(vars.size());
const auto &last_live_ops =
......@@ -222,27 +224,31 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle(
auto *eager_deletion_op = new details::EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
gcs.at(places[op->GetScopeIdx()]).get(),
&(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
op->Outputs().begin(), op->Outputs().end(),
[](details::VarHandleBase *var) {
return dynamic_cast<details::DummyVarHandle *>(var) != nullptr;
});
if (it != op->Outputs().end()) {
eager_deletion_op->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
auto *dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
op->AddOutput(dep_var);
eager_deletion_op->AddInput(dep_var);
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
auto *dummy_leaf =
new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf);
}
......@@ -262,15 +268,14 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
while_op_eager_deletion_pass->Apply(graph);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(eager_deletion_pass,
paddle::framework::details::EagerDeletionPass)
.RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
.RequirePassAttr(paddle::framework::ir::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kAllPlaces)
.RequirePassAttr(paddle::framework::ir::kGarbageCollector);
USE_PASS(while_op_eager_deletion_pass);
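The pass wires an eager_deletion op after the last op that touches each variable; at runtime, roughly, a per-scope reference count is decremented as those ops finish and the garbage collector frees a variable once its count reaches zero. A stripped-down sketch of that bookkeeping (variable names and counts invented):

#include <cstdio>
#include <string>
#include <unordered_map>

int main() {
  // number of ops that still need each variable
  std::unordered_map<std::string, int> ref_cnt{{"x", 2}, {"y", 1}};
  auto on_op_finished = [&](const std::string &var) {
    if (--ref_cnt[var] == 0) std::printf("free %s\n", var.c_str());
  };
  on_op_finished("y");  // frees y
  on_op_finished("x");
  on_op_finished("x");  // frees x
  return 0;
}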
......@@ -16,9 +16,9 @@
#include <queue>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -52,7 +52,7 @@ DECLARE_string(memory_optimize_debug);
namespace paddle {
namespace framework {
namespace details {
namespace ir {
// clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT
......@@ -111,10 +111,14 @@ class InplacePass : public ir::Pass {
// Check whether all `ops` are preceding ops of `op`
bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
// Find nodes whose name are equal to the given name
// Find nodes whose names are equal to the given name
static std::unordered_set<ir::Node *> FindNodesByName(
const std::string &name, const std::vector<ir::Node *> &nodes);
// Collect input argument names of op_desc
static void CollectInputArgsOfOpDesc(
const OpDesc *op_desc, std::unordered_multiset<std::string> *in_args);
// Get all versions vars named var_name
std::vector<ir::Node *> *AllVersionVars(const std::string &var_name) const;
......@@ -195,43 +199,12 @@ bool InplacePass::CheckOpDeps(ir::Node *op,
void InplacePass::CollectSkipVars(ir::Graph *graph,
const std::vector<ir::Node *> &ops) const {
// 1. Collect op role vars
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars),
"Graph should have attr %s", details::kMemOptSkipVars);
PADDLE_ENFORCE(graph->Has(kMemOptSkipVars), "Graph should have attr %s",
kMemOptSkipVars);
auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto &var : mem_opt_whitelist) {
skip_vars_.emplace(var);
}
// 2. track the nodes which are used by the parameter server.
// these nodes can not be inplaced, otherwise the trainer and
// pserver can not find each other's variables by name.
// Also check the ops which have sub-blocks
auto update_skip_set = [&](ir::Node *node) {
for (auto &in : node->inputs) {
if (in->IsVar() && in->Var() != nullptr) {
skip_vars_.emplace(in->Name());
}
}
for (auto &out : node->outputs) {
if (out->IsVar() && out->Var() != nullptr) {
skip_vars_.emplace(out->Name());
}
}
};
for (auto *node : ops) {
if (!node->IsOp()) continue;
// avoid optimizing the variable used in sub-blocks
if (OpHasSubBlock(node->Op())) {
update_skip_set(node);
continue;
}
auto node_name = node->Name();
if (node_name == "send" || node_name == "recv" || node_name == "prefetch") {
update_skip_set(node);
}
}
}
void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var,
......@@ -301,6 +274,14 @@ std::unordered_set<ir::Node *> InplacePass::FindNodesByName(
return ret;
}
void InplacePass::CollectInputArgsOfOpDesc(
const OpDesc *op_desc, std::unordered_multiset<std::string> *in_args) {
in_args->clear();
for (auto &in_name : op_desc->InputArgumentNames()) {
in_args->insert(in_name);
}
}
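CollectInputArgsOfOpDesc exists so the pass can refuse to inplace an input argument that the op reads more than once, as the occurrence check further down shows. A sketch of that check with std::unordered_multiset (argument names invented):

#include <cstdio>
#include <string>
#include <unordered_set>

int main() {
  // an op whose inputs are, say, X = ["a"] and Y = ["a", "b"] reads "a" twice
  std::unordered_multiset<std::string> in_args{"a", "a", "b"};
  const std::string in_arg = "a";
  if (in_args.count(in_arg) > 1) {
    std::printf("cannot inplace %s: it occurs %zu times in the op inputs\n",
                in_arg.c_str(), in_args.count(in_arg));
  }
  return 0;
}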
void InplacePass::ApplyImpl(ir::Graph *graph) const {
// Step 1: topo sort ops, collect skip vars
auto ops = ir::TopologySortOperations(*graph);
......@@ -346,6 +327,11 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
}
auto in_to_outs = infer_inplace(*op_desc, use_cuda);
if (in_to_outs.empty()) continue;
std::unordered_multiset<std::string> all_in_args;
CollectInputArgsOfOpDesc(op_desc, &all_in_args);
for (auto &pair : in_to_outs) {
auto &in_param = pair.first;
auto &out_param = pair.second;
......@@ -387,6 +373,14 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue;
}
size_t in_arg_occur_times = all_in_args.count(in_arg);
if (in_arg_occur_times > 1) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " occurs " << in_arg_occur_times << " times in input of op "
<< op_type;
continue;
}
auto in_nodes = FindNodesByName(in_arg, op_node->inputs);
PADDLE_ENFORCE(!in_nodes.empty(), "Input(%s)=%s cannot be found in op %s",
in_param, in_arg, op_type);
......@@ -458,8 +452,7 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue;
}
if (details::NodeSize(*in_node->Var()) !=
details::NodeSize(*out_node->Var()) &&
if (NodeSize(*in_node->Var()) != NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same size with "
......@@ -482,9 +475,9 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass)
.RequirePassAttr(paddle::framework::details::kUseCuda);
REGISTER_PASS(inplace_pass, paddle::framework::ir::InplacePass)
.RequirePassAttr(paddle::framework::ir::kUseCuda);
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm>
#include <deque>
#include <functional>
......@@ -32,14 +32,15 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
using paddle::framework::VarDesc;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kStaleProgramOpDescs),
PADDLE_ENFORCE(graph.Has(details::kStaleProgramOpDescs),
"Graph has no attribute of kStaleProgramOpDescs.");
// 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
auto& op_descs =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
// 2. topology sort order
auto nodes = graph.Nodes();
......@@ -563,6 +564,6 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
return found_node;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -29,7 +29,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
/// this attribute is used to avoid some core variables being removed/reused
/// in memory optimize related passes
......@@ -184,6 +184,6 @@ void FilterVariables(const Container& nodes, Callback callback) {
FilterVariableImpl<Container, Callback>()(nodes, callback);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -11,8 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm>
#include <iostream>
#include <iterator>
......@@ -32,7 +31,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
TEST(OrderedSet, Normal) {
OrderedSet pool;
......@@ -153,7 +152,7 @@ TEST(OrderedSet, FindBestFitNode) {
ASSERT_TRUE(cache == nullptr);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -188,7 +187,7 @@ REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
namespace paddle {
namespace framework {
namespace details {
namespace ir {
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
......@@ -521,6 +520,6 @@ TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include <algorithm>
#include <atomic>
#include <deque>
......@@ -42,12 +42,12 @@ DEFINE_string(memory_optimize_debug, "",
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
CollectSkipVarsSet(graph);
cfg_.reset(new details::ControlFlowGraph(*graph));
cfg_.reset(new ControlFlowGraph(*graph));
cfg_->LiveVariableAnalysis();
InitSSAGraphNodes();
......@@ -205,30 +205,10 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
// fill skip_set_
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars));
PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
auto update_skip_set = [&](OpDesc* op_desc) {
auto inputs = op_desc->InputArgumentNames();
auto outputs = op_desc->OutputArgumentNames();
skip_set_.insert(inputs.begin(), inputs.end());
skip_set_.insert(outputs.begin(), outputs.end());
};
auto nodes = graph->Nodes();
for (auto& op : nodes) {
if (!op->IsOp() || op->Op() == nullptr) continue;
auto* op_desc = op->Op();
// NOTE(dzhwinter):
// current block can not reuse next level block vars.
if (OpHasSubBlock(op_desc)) update_skip_set(op_desc);
// NOTE(dzhwinter):
// distributed ops' input/output names need to
// stay the same between trainer/pserver
if (op_desc->Type() == "send") update_skip_set(op_desc);
if (op_desc->Type() == "recv") update_skip_set(op_desc);
if (op_desc->Type() == "prefetch") update_skip_set(op_desc);
for (const auto& var : mem_opt_whitelist) {
skip_set_.emplace(var);
}
}
......@@ -336,10 +316,9 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(memory_optimize_pass,
paddle::framework::details::MemoryOptimizePass)
REGISTER_PASS(memory_optimize_pass, paddle::framework::ir::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -26,13 +26,13 @@
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class MemoryOptimizePass : public ir::Pass {
protected:
......@@ -67,6 +67,6 @@ class MemoryOptimizePass : public ir::Pass {
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
namespace ir {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
OpGraphView::OpGraphView(const std::vector<details::OpHandleBase *> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) {
preceding_ops_.clear();
pending_ops_.clear();
for (auto &op : ops) {
......@@ -40,8 +42,8 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
"There are duplicate ops in graph.");
}
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<details::OpHandleBase *> ret;
ret.reserve(preceding_ops_.size());
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
......@@ -49,21 +51,21 @@ std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
bool OpGraphView::HasOp(details::OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps(
details::OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -22,39 +22,42 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
explicit OpGraphView(const std::vector<details::OpHandleBase *> &ops);
std::unordered_set<OpHandleBase *> AllOps() const;
std::unordered_set<details::OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
const std::unordered_set<details::OpHandleBase *> &PendingOps(
details::OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
bool HasOp(details::OpHandleBase *op) const;
// Use a visitor to visit all pending ops of op
// Stop when callback returns false
template <typename Callback>
bool VisitAllPendingOps(OpHandleBase *op, Callback &&callback) const;
bool VisitAllPendingOps(details::OpHandleBase *op, Callback &&callback) const;
private:
void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const;
void Build(const std::vector<details::OpHandleBase *> &ops);
void EnforceHasOp(details::OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
std::unordered_map<details::OpHandleBase *,
std::unordered_set<details::OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
std::unordered_map<details::OpHandleBase *,
std::unordered_set<details::OpHandleBase *>>
pending_ops_;
};
template <typename Callback>
bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
bool OpGraphView::VisitAllPendingOps(details::OpHandleBase *op,
Callback &&callback) const {
EnforceHasOp(op);
std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q;
std::unordered_set<details::OpHandleBase *> visited;
std::queue<details::OpHandleBase *> q;
q.push(op);
while (!q.empty()) {
op = q.front();
......@@ -72,6 +75,6 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
return true;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
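VisitAllPendingOps is a plain breadth-first walk over the successor map that stops as soon as the callback returns false. A standalone sketch of the same traversal with integer ids standing in for OpHandleBase pointers:

#include <cstdio>
#include <functional>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

bool VisitAllPending(int start,
                     const std::unordered_map<int, std::vector<int>> &next,
                     const std::function<bool(int)> &callback) {
  std::unordered_set<int> visited;
  std::queue<int> q;
  q.push(start);
  while (!q.empty()) {
    int op = q.front();
    q.pop();
    auto it = next.find(op);
    if (it == next.end()) continue;
    for (int succ : it->second) {
      if (visited.insert(succ).second) {
        if (!callback(succ)) return false;  // early stop
        q.push(succ);
      }
    }
  }
  return true;
}

int main() {
  std::unordered_map<int, std::vector<int>> next{{0, {1, 2}}, {1, {2}}};
  bool finished = VisitAllPending(
      0, next, [](int op) { std::printf("visit %d\n", op); return op != 2; });
  std::printf("finished = %d\n", finished);  // 0: stopped at op 2
  return 0;
}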
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
class RecordSkipMemoryOptVarsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
graph->Set(kMemOptSkipVars, new MemOptSkipVars);
auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
std::vector<ir::Node*> op_nodes;
for (auto& node : graph->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
if (node->IsOp() && node->Op()) {
op_nodes.emplace_back(node);
}
}
// Insert kEmptyVarName to avoid optimizing empty variable
skip_vars.insert(framework::kEmptyVarName);
// NOTE(zcd): Insert OpRoleVars into SkipVarSet to prevent the vars from being
// renamed in memory optimize passes.
InsertOpRoleVarsToSkipVarSet(op_nodes, &skip_vars);
InsertSkipMemOptOpInOutToSkipVarSet(op_nodes, &skip_vars);
}
private:
static void InsertOpRoleVarsToSkipVarSet(const std::vector<ir::Node*>& ops,
MemOptSkipVars* skip_vars) {
for (auto& node : ops) {
try {
auto op_role_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
for (size_t i = 0; i < op_role_vars.size(); i += 2) {
auto& g_name = op_role_vars[i + 1];
skip_vars->insert(g_name);
}
} catch (boost::bad_get& e) {
}
}
}
static void UpdateSkipVarSet(
MemOptSkipVars* skip_vars,
const std::vector<std::vector<std::string>>& var_names) {
for (auto& var_name : var_names) {
skip_vars->insert(var_name.begin(), var_name.end());
}
}
static std::vector<std::string> ToGradVarName(
const std::vector<std::string>& names) {
std::vector<std::string> ret;
ret.reserve(names.size());
for (auto& name : names) {
if (name != framework::kEmptyVarName) {
ret.emplace_back(framework::GradVarName(name));
}
}
return ret;
}
static void InsertSkipMemOptOpInOutToSkipVarSet(
const std::vector<ir::Node*>& ops, MemOptSkipVars* skip_vars) {
static std::unordered_set<std::string> kSkipMemOptOps{
"send", "recv", "prefetch", "send_barrier", "fetch_barrier"};
for (auto& node : ops) {
auto* op_desc = node->Op();
// Some ops (while, conditional_block, recurrent, etc.) have sub-blocks.
// These ops often use variables from their parent or forward blocks.
// Optimizing the in/out variables of such ops would make them impossible
// to find when running the sub-block ops.
if (OpHasSubBlock(op_desc)) {
UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
op_desc->OutputArgumentNames()});
}
// Skip ops that are related to parameter server.
// In distributed mode, trainers and the parameter server use the same
// variable names to track the same variables. We cannot change the
// names of these variables, otherwise trainers or parameter
// server would not find them.
if (kSkipMemOptOps.count(op_desc->Type()) > 0) {
UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
op_desc->OutputArgumentNames()});
}
// FIXME(zjl): some ops use variables that are not from their
// inputs or outputs. We do not have a nice method to solve this
// issue yet. Currently, we should skip these variables when
// memory optimization is enabled.
auto op_type = op_desc->Type();
if (op_type == "while_grad") {
// In while_grad, framework::GradVarName(Input("X")) is visited
// without being listed as an in/out of while_grad. While_grad uses
// these variables to accumulate the gradient of X across time steps.
UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("X"))});
} else if (op_type == "conditional_block_grad") {
// In conditional_block_grad, framework::GradVarName(Input("Input",
// "Cond")) is visited without being any in/out of
// conditional_block_grad. Conditional_block_grad uses these
// variables to accumulate gradient of Input/Cond across time steps.
UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("Input")),
ToGradVarName(op_desc->Input("Cond"))});
} else if (op_type == "recurrent" || op_type == "recurrent_grad") {
// Recurrent and recurrent_grad ops are implemented in a very tricky
// way. Attr("states", "ex_states") is visited without being any
// in/out of op. It is because these variables are from sub blocks,
// not main block. Adding these variables to input would make recurrent
// fail since "states" and "ex_states" cannot be found in main block.
// When memory optimization is enabled, "states", "ex_states" and their
// gradient should be skipped.
auto& ex_states =
boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
auto& states =
boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
if (op_type == "recurrent") {
UpdateSkipVarSet(skip_vars, {ex_states, states});
} else {
// In recurrent_grad, framework::GradVarName(Input("parameters",
// "input")) is visited without being any in/out of recurrent_grad.
// Recurrent_grad uses these variables to accumulate gradient of
// parameters/input across time steps.
UpdateSkipVarSet(
skip_vars,
{ToGradVarName(op_desc->Input("parameters")),
ToGradVarName(op_desc->Input("input")), ex_states, states,
ToGradVarName(ex_states), ToGradVarName(states)});
}
}
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(record_skip_memory_opt_vars_pass,
paddle::framework::ir::RecordSkipMemoryOptVarsPass);
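ToGradVarName above relies on Paddle's convention that a gradient variable's name is the forward name plus a gradient suffix ("@GRAD"), which framework::GradVarName appends. A tiny sketch of that mapping (variable names invented):

#include <cstdio>
#include <string>
#include <vector>

std::string GradName(const std::string &name) { return name + "@GRAD"; }

int main() {
  for (const std::string &n : std::vector<std::string>{"hidden", "cell"})
    std::printf("%s -> %s\n", n.c_str(), GradName(n).c_str());
  return 0;
}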
......@@ -24,14 +24,20 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class ReferenceCountPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
};
// A functor to shrink/remove operators who depend on other operators in a set
class ShrinkDepsOpFunctor {
......@@ -39,19 +45,21 @@ class ShrinkDepsOpFunctor {
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
public:
explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops)
explicit ShrinkDepsOpFunctor(
const std::vector<details::OpHandleBase *> &all_ops)
: graph_(all_ops) {}
template <typename OpSet>
OpSet operator()(const OpSet &op_set) const {
using KeyType = typename OpSet::key_type;
static_assert(
std::is_base_of<OpHandleBase,
std::is_base_of<details::OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
"Key type of OpSet must be details::OpHandleBase, or derived of "
"details::OpHandleBase");
if (op_set.size() <= 1) return op_set;
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
std::vector<details::OpHandleBase *> ops(op_set.begin(), op_set.end());
OpSet ret;
auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; };
......@@ -65,8 +73,8 @@ class ShrinkDepsOpFunctor {
private:
std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> &ops) const {
std::unordered_map<OpHandleBase *, size_t> op_to_idx;
const std::vector<details::OpHandleBase *> &ops) const {
std::unordered_map<details::OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
op_to_idx[ops[i]] = i;
......@@ -81,7 +89,7 @@ class ShrinkDepsOpFunctor {
size_t found_num = ops.size();
size_t total_num = ops.size() * ops.size();
auto visitor = [&](OpHandleBase *op, size_t i) {
auto visitor = [&](details::OpHandleBase *op, size_t i) {
auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) {
size_t j = it->second;
......@@ -98,7 +106,9 @@ class ShrinkDepsOpFunctor {
};
for (size_t i = 0; i < ops.size(); ++i) {
auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); };
auto sub_visitor = [&, i](details::OpHandleBase *op) {
return visitor(op, i);
};
if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
break;
}
......@@ -133,8 +143,8 @@ class ShrinkDepsOpFunctor {
*/
static bool ShrinkNoNeedBufferVarOpDependency(
const std::string &var_name,
std::unordered_set<ComputationOpHandle *> *op_handles) {
std::vector<ComputationOpHandle *> skip_ops;
std::unordered_set<details::ComputationOpHandle *> *op_handles) {
std::vector<details::ComputationOpHandle *> skip_ops;
for (auto *op_handle : *op_handles) {
auto *op_base = op_handle->GetOp();
auto &inferer = op_base->Info().NoNeedBufferVarsInferer();
......@@ -195,15 +205,15 @@ static bool ShrinkNoNeedBufferVarOpDependency(
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
static details::ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
details::OpHandleBase *op, size_t scope_idx) {
std::queue<details::OpHandleBase *> q;
std::unordered_set<details::OpHandleBase *> visited;
q.push(op);
while (!q.empty()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
......@@ -220,13 +230,13 @@ static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
static std::unordered_set<details::ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
const std::string &var_name,
const ShrinkDepsOpFunctor &shrink_func,
LastLiveOpSearchStatus *status) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
std::unordered_set<details::OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op
......@@ -251,7 +261,7 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
// some op handles may operate on many DeviceContexts, however, our garbage
// collector can only wait for one DeviceContext for now. So currently, we wait
// for the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op;
std::unordered_set<details::ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
......@@ -293,13 +303,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
"Last Live Ops and Reference Counts of vars should be "
"initialized at here.");
const auto &vars = graph->Get<GraphVars>(kGraphVars);
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
last_live_ops_of_vars.resize(vars.size());
ref_cnts.resize(vars.size());
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph));
ir::FilterByNodeWrapper<details::OpHandleBase>(*graph));
VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) {
......@@ -360,11 +370,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);
REGISTER_PASS(reference_count_pass, paddle::framework::ir::ReferenceCountPass)
.RequirePassAttr(paddle::framework::ir::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars);
......@@ -12,23 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
std::find_if(vars.rbegin(), vars.rend(),
[&](details::VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -22,17 +22,16 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
class VarDesc;
class VarHandle;
namespace details {
class ComputationOpHandle;
namespace ir {
using ReferenceCountMap = std::unordered_map<std::string, size_t>;
......@@ -48,11 +47,12 @@ const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
std::unordered_map<std::string,
std::unordered_set<details::ComputationOpHandle *>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars);
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -19,19 +19,19 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
// Find all while_op and while_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>>
target_ops;
for (auto *op : all_ops) {
auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "while") {
......@@ -52,9 +52,9 @@ class WhileOpEagerDeletionPass : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(while_op_eager_deletion_pass,
paddle::framework::details::WhileOpEagerDeletionPass);
paddle::framework::ir::WhileOpEagerDeletionPass);
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
set(ALL_REDUCE_OP_HANDLES all_reduce_op_handle)
if(WITH_GPU AND WITH_DGC)
list(APPEND ALL_REDUCE_OP_HANDLES sparse_all_reduce_op_handle)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......@@ -23,7 +23,6 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -31,17 +30,18 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class AllReduceDepsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
std::vector<AllReduceOpHandle*> all_reduce_op_handles =
std::vector<details::AllReduceOpHandle*> all_reduce_op_handles =
GetSortedAllReduceOps(*graph);
for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
auto* dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
all_reduce_op_handles[i - 1]->AddOutput(dep_var);
all_reduce_op_handles[i]->AddInput(dep_var);
}
......@@ -51,16 +51,16 @@ class AllReduceDepsPass : public ir::Pass {
}
}
std::vector<AllReduceOpHandle*> GetSortedAllReduceOps(
std::vector<details::AllReduceOpHandle*> GetSortedAllReduceOps(
const ir::Graph& graph) const {
std::vector<AllReduceOpHandle*> all_reduce_op_handles;
std::unordered_map<OpHandleBase*, size_t> pending_ops;
std::unordered_set<OpHandleBase*> ready_ops;
std::unordered_set<OpHandleBase*> next_ready_ops;
std::vector<details::AllReduceOpHandle*> all_reduce_op_handles;
std::unordered_map<details::OpHandleBase*, size_t> pending_ops;
std::unordered_set<details::OpHandleBase*> ready_ops;
std::unordered_set<details::OpHandleBase*> next_ready_ops;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph);
auto op_handles = ir::FilterByNodeWrapper<details::OpHandleBase>(graph);
size_t num_of_ops = op_handles.size();
for (OpHandleBase* op : op_handles) {
for (details::OpHandleBase* op : op_handles) {
size_t not_ready_vars = op->NotReadyInputSize();
if (not_ready_vars) {
pending_ops.insert({op, not_ready_vars});
......@@ -94,11 +94,12 @@ class AllReduceDepsPass : public ir::Pass {
}
void GetSortedAllReduceOps(
const std::unordered_set<OpHandleBase*>& ready_ops,
std::vector<AllReduceOpHandle*>* all_reduce_op_handles) const {
std::vector<AllReduceOpHandle*> current_all_reduce_op_handles;
const std::unordered_set<details::OpHandleBase*>& ready_ops,
std::vector<details::AllReduceOpHandle*>* all_reduce_op_handles) const {
std::vector<details::AllReduceOpHandle*> current_all_reduce_op_handles;
for (auto& op_handle : ready_ops) {
auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle*>(op_handle);
auto all_reduce_op_handle =
dynamic_cast<details::AllReduceOpHandle*>(op_handle);
if (all_reduce_op_handle) {
current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
}
......@@ -109,10 +110,12 @@ class AllReduceDepsPass : public ir::Pass {
// Sort the current_all_reduce_op_handles according to the name of their first input.
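// A name-based (deterministic) order here presumably ensures every device sees
// the all_reduce ops in the same sequence, which the control-dependency vars
// added in ApplyImpl then enforce at run time.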
sort(current_all_reduce_op_handles.begin(),
current_all_reduce_op_handles.end(),
[](const AllReduceOpHandle* left,
const AllReduceOpHandle* right) -> bool {
auto left_in_vars = DynamicCast<VarHandle>(left->Inputs());
auto right_in_vars = DynamicCast<VarHandle>(right->Inputs());
[](const details::AllReduceOpHandle* left,
const details::AllReduceOpHandle* right) -> bool {
auto left_in_vars =
details::DynamicCast<details::VarHandle>(left->Inputs());
auto right_in_vars =
details::DynamicCast<details::VarHandle>(right->Inputs());
PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
return left_in_vars[0]->Name() > right_in_vars[0]->Name();
......@@ -123,15 +126,15 @@ class AllReduceDepsPass : public ir::Pass {
current_all_reduce_op_handles.end());
}
void DebugString(
const ir::Graph& graph,
const std::vector<AllReduceOpHandle*>& all_reduce_op_handles) const {
void DebugString(const ir::Graph& graph,
const std::vector<details::AllReduceOpHandle*>&
all_reduce_op_handles) const {
// get vars order
std::map<int, std::vector<std::string>> vars =
GetSoredGradientsFromStaleProgram(graph);
std::stringstream out;
size_t grads_of_stale_program = 0;
out << "Get Order From kStaleProgramOpDescs: ";
out << "Get Order From details::kStaleProgramOpDescs: ";
for (auto& var : vars) {
out << "Order " << var.first << " [";
for (auto& var_name : var.second) {
......@@ -147,7 +150,7 @@ class AllReduceDepsPass : public ir::Pass {
for (auto& op : all_reduce_op_handles) {
bool find_valid_input = false;
for (auto& in_var : op->Inputs()) {
if (dynamic_cast<VarHandle*>(in_var)) {
if (dynamic_cast<details::VarHandle*>(in_var)) {
out2 << in_var->Name() << ", ";
find_valid_input = true;
break;
......@@ -165,7 +168,8 @@ class AllReduceDepsPass : public ir::Pass {
std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
const ir::Graph& graph) const {
std::map<int, std::vector<std::string>> vars;
auto ops = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
auto ops =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
int order = 0;
for (auto* op_desc : ops) {
try {
......@@ -193,10 +197,9 @@ class AllReduceDepsPass : public ir::Pass {
return vars;
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(all_reduce_deps_pass,
paddle::framework::details::AllReduceDepsPass)
REGISTER_PASS(all_reduce_deps_pass, paddle::framework::ir::AllReduceDepsPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -24,21 +24,22 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseAllReduceOpPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_ctxs = &Get<platform::NCCLContextMap>(kNCCLCtxs);
auto *nccl_ctxs = &Get<platform::NCCLContextMap>(details::kNCCLCtxs);
#endif
std::unordered_set<std::string> grads;
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
size_t num_of_all_reduce = params_grads.size();
grads.reserve(num_of_all_reduce);
for (auto p_g : params_grads) {
......@@ -50,11 +51,12 @@ class FuseAllReduceOpPass : public ir::Pass {
all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) {
if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<OpHandleBase>());
auto *all_reduce_op_handle =
dynamic_cast<AllReduceOpHandle *>(&node->Wrapper<OpHandleBase>());
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
&node->Wrapper<details::OpHandleBase>());
if (all_reduce_op_handle) {
auto inputs = DynamicCast<VarHandle>(all_reduce_op_handle->Inputs());
auto inputs = details::DynamicCast<details::VarHandle>(
all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place);
// The inputs' name should be the same.
auto &grad_name = inputs[0]->name();
......@@ -80,7 +82,7 @@ class FuseAllReduceOpPass : public ir::Pass {
VLOG(10) << "Insert fused_all_reduce";
auto &group_grads_params =
graph->Get<GroupGradsAndParams>(kGroupGradsAndParams);
graph->Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
for (auto &group_g_p : group_grads_params) {
size_t group_size = group_g_p.size();
......@@ -108,24 +110,25 @@ class FuseAllReduceOpPass : public ir::Pass {
const platform::NCCLContextMap *nccl_ctxs,
#endif
ir::Graph *result) const {
std::vector<VarHandleBase *> inputs;
std::vector<VarHandleBase *> outputs;
std::vector<details::VarHandleBase *> inputs;
std::vector<details::VarHandleBase *> outputs;
for (auto &op : all_reduce_ops) {
auto &op_handle = op->Wrapper<OpHandleBase>();
auto &op_handle = op->Wrapper<details::OpHandleBase>();
inputs.insert(inputs.end(), op_handle.Inputs().begin(),
op_handle.Inputs().end());
// Remove output
for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
[&op_handle](VarHandleBase *var_handle) {
[&op_handle](details::VarHandleBase *var_handle) {
var_handle->RemoveOutput(&op_handle, op_handle.Node());
});
outputs.insert(outputs.end(), op_handle.Outputs().begin(),
op_handle.Outputs().end());
// Remove Input
for_each(
op_handle.Outputs().begin(), op_handle.Outputs().end(),
[](VarHandleBase *var_handle) { var_handle->ClearGeneratedOp(); });
for_each(op_handle.Outputs().begin(), op_handle.Outputs().end(),
[](details::VarHandleBase *var_handle) {
var_handle->ClearGeneratedOp();
});
result->RemoveNode(op_handle.Node());
}
......@@ -140,21 +143,22 @@ class FuseAllReduceOpPass : public ir::Pass {
}
private:
void CreateFusedAllReduceOp(const std::vector<VarHandleBase *> &inputs,
const std::vector<VarHandleBase *> &outputs,
const size_t num_of_all_reduce,
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
void CreateFusedAllReduceOp(
const std::vector<details::VarHandleBase *> &inputs,
const std::vector<details::VarHandleBase *> &outputs,
const size_t num_of_all_reduce,
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::NCCLContextMap *nccl_ctxs,
const platform::NCCLContextMap *nccl_ctxs,
#endif
ir::Graph *result) const {
ir::Graph *result) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *op_handle = new FusedAllReduceOpHandle(
auto *op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, nccl_ctxs);
#else
auto *op_handle = new FusedAllReduceOpHandle(
auto *op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce);
#endif
......@@ -176,8 +180,9 @@ class FuseAllReduceOpPass : public ir::Pass {
#endif
}
void SetCommunicationContext(const std::vector<platform::Place> &places,
FusedAllReduceOpHandle *op_handle) const {
void SetCommunicationContext(
const std::vector<platform::Place> &places,
details::FusedAllReduceOpHandle *op_handle) const {
for (size_t i = 0; i < places.size(); ++i) {
op_handle->SetDeviceContext(
places[i], platform::DeviceContextPool::Instance().Get(places[i]));
......@@ -185,9 +190,9 @@ class FuseAllReduceOpPass : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_all_reduce_op_pass,
paddle::framework::details::FuseAllReduceOpPass);
paddle::framework::ir::FuseAllReduceOpPass);
......@@ -12,21 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
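// Condition checked below: an op may skip the device lock and CUDA event
// recording only when it runs on a GPU place and all of its downstream
// (pending) ops are computation ops on that same place.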
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) {
details::ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
......@@ -34,25 +33,27 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
return true;
}
void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(10) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
}
}
}
}
} // namespace details
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
paddle::framework::ir::ModifyOpLockAndRecordEventPass);
......@@ -19,7 +19,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
......@@ -28,19 +28,19 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
bool IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
std::unordered_map<details::OpHandleBase *, size_t> pending_ops;
std::unordered_set<details::VarHandleBase *> pending_vars;
std::unordered_set<details::VarHandleBase *> ready_vars;
std::unordered_set<details::OpHandleBase *> ready_ops;
auto insert_pending_var = [&](VarHandleBase *var) {
auto insert_pending_var = [&](details::VarHandleBase *var) {
pending_vars.insert(var);
if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var);
}
};
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &var_map : graph->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair);
......@@ -48,11 +48,12 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
}
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
for (auto &var :
graph->Get<details::GraphDepVars>(details::kGraphDepVars)) {
insert_pending_var(var);
}
for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
for (auto *op : ir::FilterByNodeWrapper<details::OpHandleBase>(*graph)) {
if (op->Inputs().empty()) {
ready_ops.insert(op);
} else {
......@@ -60,7 +61,7 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
}
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
auto run_all_ops = [&](std::unordered_set<details::OpHandleBase *> &set) {
for (auto *op : set) {
for (auto out : op->Outputs()) {
ready_vars.emplace(out);
......@@ -91,11 +92,11 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker)
paddle::framework::ir::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars);
......@@ -31,7 +31,7 @@ class NCCLContextMap;
namespace framework {
class Scope;
namespace details {
namespace ir {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";
......@@ -69,8 +69,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
ir::Node *out_var_node, size_t loss_scale,
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
size_t dst_dev_id) const;
details::VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
size_t dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
size_t dev_id) const;
......@@ -89,7 +89,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
void SetCommunicationContext(OpHandleBase *op_handle,
void SetCommunicationContext(details::OpHandleBase *op_handle,
const platform::Place &p) const;
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
......@@ -103,7 +103,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable BuildStrategy strategy_;
mutable details::BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
};
......@@ -209,6 +209,6 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
......@@ -21,11 +21,21 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class SSAGraghBuilderWithPrinterPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
}
};
template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &each : graph.Get<details::GraphVars>(details::kGraphVars)) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
......@@ -33,7 +43,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
for (auto &var : graph.Get<details::GraphDepVars>(details::kGraphDepVars)) {
callback(*var);
}
}
......@@ -41,14 +51,14 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
std::unordered_map<const details::VarHandleBase *, size_t> vars;
sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) {
IterAllVar(graph, [&](const details::VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
auto *var_handle_ptr = dynamic_cast<const details::VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const details::DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;
......@@ -65,7 +75,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
for (auto &op : ir::FilterByNodeWrapper<details::OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......@@ -82,10 +92,10 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
sout << "}\n";
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_devices_print_pass,
paddle::framework::details::SSAGraghBuilderWithPrinter)
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
paddle::framework::ir::SSAGraghBuilderWithPrinterPass)
.RequirePassAttr(paddle::framework::ir::kGraphvizPath);
......@@ -24,7 +24,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
......@@ -39,16 +39,6 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
void Print(const ir::Graph& graph, std::ostream& sout) const override;
};
class SSAGraghBuilderWithPrinter : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
class SequentialExecutionPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
// FIXME(zjl): Inserting dependencies between some distributed ops may cause
// the multi_devices_graph_pass to fail. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually, which may easily cause deadlock.
// We should add more skipped distributed ops when errors are found in
// multi_devices_graph_pass.
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops =
graph->Get<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s",
op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
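// Chain the surviving op nodes with control-dependency vars so that they are
// forced to execute one after another, following the stale-program order
// recovered above.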
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::ir::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -25,7 +25,8 @@ namespace framework {
namespace ir {
void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::string op_type) {
const std::string& op_type,
const std::string& quant_type) {
const std::string pattern_name = "quant_dequant_fuse";
// FusePassBase::Init(pattern_name, graph);
const int kNumFields = 5;
......@@ -38,7 +39,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("fake_quantize_range_abs_max", "X")
->assert_is_op_input(quant_type, "X")
->AsInput();
std::string quantized_op_type = "";
......@@ -46,6 +47,9 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
if (op_type == "conv2d") {
quantized_op_type = "conv2d";
weight_name = "Filter";
} else if (op_type == "depthwise_conv2d") {
quantized_op_type = "depthwise_conv2d";
weight_name = "Filter";
} else if (op_type == "conv2d_fusion") {
quantized_op_type = "conv2d_fusion";
weight_name = "Filter";
......@@ -62,7 +66,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
}
patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, quantized_op_type, weight_name, times);
pattern(x, quantized_op_type, weight_name, times, quant_type);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......@@ -103,7 +107,6 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::unordered_set<const Node*> delete_nodes;
for (int i = 0; i < times; i++) {
// max_range = (range * range) / weight_scale
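// so weight_scale is recovered below by inverting that relation; `range` is
// assumed to be the quantization range (e.g. (1 << (bit_length - 1)) - 1)
// computed earlier in this function and not shown in this hunk.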
float max_range = boost::get<float>(
nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr("max_range"));
float weight_scale = (range * range) / max_range;
......@@ -118,7 +121,8 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
new_op_desc.SetType(quantized_op_type);
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion") {
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Output", {new_output});
} else if (quantized_op_type == "fc") {
......@@ -156,11 +160,17 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "quant_dequant_fuse";
FusePassBase::Init(pattern_name, graph);
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul"};
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
"depthwise_conv2d"};
auto* scope = param_scope();
for (auto& op_type : quantized_op_types) {
for (int i = 1; i <= 6; i++) {
RunQuantDequant(graph, scope, i, op_type);
for (auto& quant_type : quant_types) {
for (auto& op_type : quantized_op_types) {
for (int i = 6; i >= 1; i--) {
RunQuantDequant(graph, scope, i, op_type, quant_type);
}
}
}
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(reshape1_op); \
GET_IR_NODE(reshape1_out); \
GET_IR_NODE(transpose_op); \
GET_IR_NODE(transpose_out); \
GET_IR_NODE(reshape2_op); \
GET_IR_NODE(reshape2_out);
void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "shufflechannel_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("reshape2", "X")
->AsInput();
patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
PADDLE_ENFORCE(subgraph.count(x));
auto* input_node = subgraph.at(x);
auto reshape1_desc = reshape1_op->Op();
auto reshape2_desc = reshape2_op->Op();
std::string input_name = input_node->Name();
std::string output_name = reshape2_out->Name();
auto reshape1_shape =
boost::get<std::vector<int>>(reshape1_desc->GetAttr("shape"));
auto reshape2_shape =
boost::get<std::vector<int>>(reshape2_desc->GetAttr("shape"));
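// In the reshape1 -> transpose -> reshape2 pattern, reshape1 is expected to
// split the channel dim into [group, channels/group] while reshape2 restores
// the full channel count, so the group size can be recovered from the two
// shapes below.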
int i_c = reshape1_shape[2];
int o_c = reshape2_shape[1];
int group = o_c / i_c;
framework::OpDesc new_op_desc;
new_op_desc.SetType("shuffle_channel");
new_op_desc.SetInput("X", {input_name});
new_op_desc.SetOutput("Out", {output_name});
new_op_desc.SetAttr("group", group);
new_op_desc.Flush();
// Create a new node for the fused op.
auto* new_op = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(input_node, new_op);
IR_NODE_LINK_TO(new_op, reshape2_out);
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op,
transpose_out, reshape2_op});
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(shuffle_channel_detect_pass,
paddle::framework::ir::ShuffleChannelDetectPass);
......@@ -13,19 +13,22 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class ShuffleChannelDetectPass : public FusePassBase {
public:
virtual ~ShuffleChannelDetectPass() {}
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,30 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Use synchronous batch norm";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm");
}
if (op->Type() == "batch_norm_grad") {
op->SetType("sync_batch_norm_grad");
class SyncBatchNormPass : public Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
VLOG(3) << "Use synchronous batch norm";
for (const Node *n : graph->Nodes()) {
if (n->IsOp()) {
auto *op = n->Op();
if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm");
}
if (op->Type() == "batch_norm_grad") {
op->SetType("sync_batch_norm_grad");
}
}
}
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
......
......@@ -1148,7 +1148,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == data_type || data_type == dafault_data_type,
"DataType of Paddle Op %s %s must be the same. Get (%d) != (%d)",
"DataType of Paddle Op %s %s must be the same. Get (%s) != (%s)",
Type(), input.first, DataTypeToString(data_type),
DataTypeToString(tmp));
data_type = tmp;
......
......@@ -386,9 +386,10 @@ class ExecutionContext {
template <typename T>
T& GetKernelConfig(int idx) const {
PADDLE_ENFORCE(kernel_configs_ && kernel_configs_->size() > idx,
"%s selected kernel doesn't have kernel config %lu <= %d",
op_.Type().c_str(), kernel_configs_->size(), idx);
PADDLE_ENFORCE(
kernel_configs_ && kernel_configs_->size() > static_cast<size_t>(idx),
"%s selected kernel doesn't have kernel config %lu <= %d",
op_.Type().c_str(), kernel_configs_->size(), idx);
return *boost::get<std::shared_ptr<T>>(kernel_configs_->at(idx));
}
......
......@@ -23,11 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef WITH_GPERFTOOLS
......@@ -46,6 +46,7 @@ static std::once_flag gProfileOnce;
#ifdef WITH_GPERFTOOLS
static bool gProfileStarted = false;
#endif
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
......@@ -57,7 +58,7 @@ class ParallelExecutorPrivate {
gProfileStarted = true;
#else
LOG(WARNING) << "Paddle is not compiled with gperftools. "
"FLAGS_pe_profile_fname will be ignored";
"FLAGS_pe_profile_fname will be ignored";
#endif
});
}
......@@ -110,9 +111,9 @@ class ParallelExecutorPrivate {
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// then keeps unchanged
// Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
std::vector<details::ReferenceCountMap> global_ref_cnts_;
std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
details::GarbageCollectorMap gcs_;
std::vector<ir::ReferenceCountMap> global_ref_cnts_;
std::vector<ir::AtomicReferenceCountMap> runtime_ref_cnts_;
ir::GarbageCollectorMap gcs_;
};
ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
......@@ -150,25 +151,23 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
}
if (!gcs_.empty()) {
std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
&global_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
ref_cnt_pass->SetNotOwned(ir::kGlobalReferenceCount, &global_ref_cnts_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(graph);
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
eager_deletion_pass->SetNotOwned(ir::kRuntimeReferenceCount,
&runtime_ref_cnts_);
eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
eager_deletion_pass->SetNotOwned(ir::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(graph);
VLOG(10) << "EagerDeletionPass Applied";
}
......@@ -179,6 +178,20 @@ std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
return member_->local_scopes_;
}
void ParallelExecutor::DropLocalExeScopes() {
auto executor = dynamic_cast<details::ScopeBufferedSSAGraphExecutor *>(
member_->executor_.get());
if (executor) {
executor->DropLocalExeScopes();
}
}
bool ParallelExecutor::NeedCreateLocalExeScope() {
auto executor = dynamic_cast<details::ScopeBufferedSSAGraphExecutor *>(
member_->executor_.get());
return executor && executor->NeedCreateLocalExeScope();
}
ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
const std::vector<std::string> &bcast_vars,
const std::string &loss_var_name,
......@@ -333,7 +346,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, 1,
member_->use_cuda_);
for (int i = 1; i < member_->places_.size(); ++i) {
for (size_t i = 1; i < member_->places_.size(); ++i) {
graphs[i] = build_strategy.Apply(
graphs[i], {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1, member_->use_cuda_);
......@@ -344,8 +357,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
member_->local_scopes_, member_->nranks_,
member_->use_cuda_);
}
#endif
auto max_memory_size = GetEagerDeletionThreshold();
VLOG(10) << "Eager Deletion Threshold "
<< static_cast<float>(max_memory_size) / (1 << 30);
......
......@@ -58,6 +58,11 @@ class ParallelExecutor {
std::vector<Scope *> &GetLocalScopes();
void DropLocalExeScopes();
// This API is used to check whether DropLocalExeScopes works.
bool NeedCreateLocalExeScope();
/**
* Feed tensors to local scopes. The size of tensors should be equal to the
* size of local scopes.
......
......@@ -25,8 +25,9 @@ inline const T* Tensor::data() const {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d",
DataTypeToString(type_));
PADDLE_ENFORCE(
valid, "Tensor holds the wrong type, it holds %s, but desires to be %s",
DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType));
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
......@@ -39,7 +40,9 @@ inline T* Tensor::data() {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
PADDLE_ENFORCE(
valid, "Tensor holds the wrong type, it holds %s, but desires to be %s",
DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType));
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
......
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc
roi_align.cc shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto
scope op_registry gtest)
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL)
cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
......@@ -14,5 +19,5 @@ cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter fla
cc_test(test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op SERIAL)
cc_test(test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op SERIAL)
cc_test(test_anakin_dropout SRCS test_dropout_op.cc DEPS anakin_op_converter dropout_op SERIAL)
#cc_test(test_anakin_im2sequence SRCS test_im2sequence_op.cc DEPS anakin_op_converter im2sequence_op im2col)
cc_test(test_anakin_sum SRCS test_sum_op.cc DEPS anakin_op_converter sum_op selected_rows_functor SERIAL)
cc_test(test_anakin_affine_channel SRCS test_affine_channel_op.cc DEPS anakin_op_converter affine_channel_op SERIAL)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace anakin {
template <typename TargetT, ::anakin::Precision PrecisionT>
class AffineChannelOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
AffineChannelOpConverter() = default;
virtual void operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) override;
virtual ~AffineChannelOpConverter() {}
private:
};
} // namespace anakin
} // namespace inference
} // namespace paddle