Commit 658c5432 authored by: S Superjomn

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into design/lite2

......@@ -25,8 +25,9 @@ endif()
if(ANAKIN_FOUND)
message(STATUS "Current ANAKIN header is ${ANAKIN_INCLUDE_DIR}/anakin_config.h. ")
include_directories(${ANAKIN_ROOT})
include_directories(${ANAKIN_ROOT}/include)
include_directories(${ANAKIN_ROOT}/include/saber)
include_directories(${ANAKIN_ROOT}/saber)
link_directories(${ANAKIN_ROOT})
add_definitions(-DPADDLE_WITH_ANAKIN)
endif()
......@@ -77,6 +77,7 @@ else(WIN32)
ENDIF(WIN32)
MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}")
get_filename_component(WARPCTC_LIBRARY_PATH ${WARPCTC_LIBRARIES} DIRECTORY)
INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its headers.
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include warpctc headers.
......
......@@ -3,6 +3,8 @@ set(PADDLE_VERSION $ENV{PADDLE_VERSION})
set(tmp_version "HEAD")
set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+")
set(LATEST_PADDLE_VERSION "latest")
while ("${PADDLE_VERSION}" STREQUAL "")
# Check current branch name
execute_process(
......@@ -23,8 +25,8 @@ while ("${PADDLE_VERSION}" STREQUAL "")
if (${GIT_BRANCH_NAME} MATCHES "release/${TAG_VERSION_REGEX}")
# Check the tag is a correct version
if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}")
# if no tag was found, set PADDLE_VERSION to 0.0.0 to represent latest
set(PADDLE_VERSION "0.0.0")
# if no tag was found, set PADDLE_VERSION to "latest"
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME})
else() # otherwise, get the previous git tag name.
......@@ -42,19 +44,19 @@ while ("${PADDLE_VERSION}" STREQUAL "")
if (${GIT_EXACT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_EXACT_TAG_NAME})
else()
set(PADDLE_VERSION "0.0.0")
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
endif()
else()
# otherwise, we always set PADDLE_VERSION to 0.0.0 to represent latest
set(PADDLE_VERSION "0.0.0")
# otherwise, we always set PADDLE_VERSION to "latest"
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
endif()
endif()
else()
set(PADDLE_VERSION "0.0.0")
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
message(WARNING "Cannot add paddle version from git tag")
endif()
else()
set(PADDLE_VERSION "0.0.0")
set(PADDLE_VERSION "${LATEST_PADDLE_VERSION}")
message(WARNING "Cannot add paddle version for wrong git branch result")
endif()
endwhile()
......
This diff is collapsed.
......@@ -455,6 +455,8 @@ void MultiSlotDataFeed::Init(
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
total_dims_without_inductive_.resize(all_slot_num);
inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
......@@ -462,14 +464,20 @@ void MultiSlotDataFeed::Init(
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
total_dims_without_inductive_[i] = 1;
inductive_shape_index_[i] = -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
// for batch size holder if is_dense
if (slot.shape(0) > 0) {
local_shape.push_back(0);
for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i);
}
if (slot.shape(i) == -1) {
inductive_shape_index_[i] = i;
}
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
......@@ -762,7 +770,10 @@ void MultiSlotDataFeed::PutToFeedVec(
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
use_slots_shape_[i][0] = batch_size_;
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
total_instance / total_dims_without_inductive_[i];
}
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
......@@ -785,6 +796,8 @@ void MultiSlotInMemoryDataFeed::Init(
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
total_dims_without_inductive_.resize(all_slot_num);
inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
......@@ -797,8 +810,13 @@ void MultiSlotInMemoryDataFeed::Init(
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
if (slot.shape(0) > 0) {
local_shape.push_back(0);
for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i);
}
if (slot.shape(i) == -1) {
inductive_shape_index_[i] = i;
}
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
......@@ -960,7 +978,10 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
use_slots_shape_[i][0] = batch_size_;
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
total_instance / total_dims_without_inductive_[i];
}
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
......
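As an aside, a minimal standalone C++ sketch of the shape inference used in PutToFeedVec above: the dimension marked -1 (the "inductive" dimension) is filled in from the total number of instances divided by the product of the fixed dimensions. The slot shape and instance count below are hypothetical values chosen for illustration, not taken from this diff.

#include <cassert>
#include <vector>

int main() {
  // Hypothetical dense slot declared with shape {-1, 3, 4}; dim 0 is inductive.
  std::vector<int> use_slots_shape = {-1, 3, 4};
  int inductive_shape_index = 0;       // index of the -1 entry, or -1 if none
  int total_dims_without_inductive = 1;
  for (int d : use_slots_shape) {
    if (d > 0) total_dims_without_inductive *= d;  // 3 * 4 = 12
  }
  int total_instance = 24;             // total elements fed for this slot in a batch
  if (inductive_shape_index != -1) {
    use_slots_shape[inductive_shape_index] =
        total_instance / total_dims_without_inductive;  // 24 / 12 = 2
  }
  assert(use_slots_shape[0] == 2);     // the inferred batch dimension
  return 0;
}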
......@@ -143,6 +143,8 @@ class DataFeed {
std::vector<std::string> all_slots_;
std::vector<std::string> all_slots_type_;
std::vector<std::vector<int>> use_slots_shape_;
std::vector<int> inductive_shape_index_;
std::vector<int> total_dims_without_inductive_;
std::vector<int>
use_slots_index_; // -1: not used; >=0: the index of use_slots_
......
......@@ -121,6 +121,16 @@ int64_t product(const DDim& ddim) {
return ddim.apply_visitor(ProductVisitor());
}
bool contain_unknown_dim(const DDim& ddim) {
for (int i = 0; i < ddim.size(); ++i) {
if (ddim[i] < 0) {
return true;
}
}
return false;
}
DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0 && end <= dim.size(),
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
......
......@@ -182,6 +182,8 @@ std::vector<int> vectorize2int(const DDim& ddim);
int64_t product(const DDim& ddim);
bool contain_unknown_dim(const DDim& ddim);
/**
* \brief Slice a ddim
*
......
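For illustration, a self-contained sketch of the check that contain_unknown_dim performs, written against a plain std::vector<int64_t> instead of DDim. The helper name and the shapes below are hypothetical and only show the convention that a negative extent marks an unknown dimension.

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors contain_unknown_dim: any negative extent is treated as unknown.
static bool ContainUnknownDim(const std::vector<int64_t>& dims) {
  for (int64_t d : dims) {
    if (d < 0) return true;
  }
  return false;
}

int main() {
  assert(ContainUnknownDim({-1, 3, 4}));   // -1 denotes a dimension to be inferred
  assert(!ContainUnknownDim({8, 3, 4}));   // fully specified shape
  return 0;
}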
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
......@@ -27,7 +17,7 @@ if(WITH_DISTRIBUTE)
endif()
endif()
set(all_reduce_deps all_reduce_op_handle)
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
......@@ -37,7 +27,6 @@ if(WITH_GPU)
if(WITH_DGC)
nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
lod_tensor ddim memory dynload_cuda variable_visitor dgc all_reduce_op_handle)
set(all_reduce_deps sparse_all_reduce_op_handle)
endif()
if(WITH_DISTRIBUTE)
......@@ -68,34 +57,12 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${all_reduce_deps} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is the upper limit on the memory size "
"of one group of parameters' gradients, which is the input "
"of the communication call (e.g. NCCLAllReduce). "
"The default value is 0, which means groups are not "
"formed according to memory_size.");
DEFINE_int32(
fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the size of one group of parameters' gradients. "
"If fuse_parameter_groups_size is 1, the group size is "
"the number of parameters' gradients. If fuse_parameter_groups_size is "
"-1, there is only one group. The default value is 3, which is "
"an experimental value.");
namespace paddle {
namespace framework {
namespace details {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// tests, because it is invalid to set 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' directly in a unit test.
void SetFuseParameterGroupsSize(int group_size) {
FLAGS_fuse_parameter_groups_size = group_size;
}
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(uint64_t memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
void AllocContinuousSpaceForGradPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
ResetAttribute<ParamsAndGrads>(kParamsAndGrads, &result);
ResetAttribute<GroupGradsAndParams>(kGroupGradsAndParams, &result);
// NOTE: The operator nodes should be in topology order.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
for (auto &node : topo_nodes) {
RecordParamsAndGrads(node, &params_grads);
}
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return;
}
std::unordered_map<std::string, ir::Node *> vars;
for (ir::Node *node : result.Nodes()) {
if (node->IsVar() && node->Var()) {
// Note: The graph may contain nodes with the same name. For example, a
// parameter is both the input of an operator and the output of the optimizer.
vars.emplace(node->Var()->Name(), node);
}
}
auto &group_grads_params =
result.Get<GroupGradsAndParams>(kGroupGradsAndParams);
// Note: the order of params_grads may be changed by SetGroupGradsAndParams.
SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
params_grads.clear();
for (auto &group_p_g : group_grads_params) {
params_grads.insert(params_grads.begin(), group_p_g.begin(),
group_p_g.end());
}
for (auto &p_g : params_grads) {
std::swap(p_g.first, p_g.second);
}
// Set Gradients as Persistable to prevent this var becoming reusable.
auto dtype = kDefaultDtype;
for (auto &p_g : params_grads) {
// Get gradient var
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
iter->second->Var()->SetPersistable(true);
PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
// Get Dtype
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create a FusedVarsSet to avoid duplicating names for fused_var in other
// pass.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
}
// kFusedGrads is used by fuse_optimizer_op_pass.
result.Set(kFusedGrads, new FusedGrads);
// fused_var_name should be unique, so params_grads.begin()->second is
// appended to it.
auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars, fused_var_name,
params_grads);
}
template <typename AttrType>
void AllocContinuousSpaceForGradPass::ResetAttribute(
const std::string &attr_name, ir::Graph *graph) const {
if (graph->Has(attr_name)) {
VLOG(10) << attr_name << " is reset.";
graph->Erase(attr_name);
}
graph->Set(attr_name, new AttrType);
}
void AllocContinuousSpaceForGradPass::SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
void AllocContinuousSpaceForGradPass::SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const {
std::unordered_map<std::string, std::vector<int>> layer_params;
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) {
layer_params[std::string(kUnKnow)].emplace_back(i);
} else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
}
}
group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
std::string key = kUnKnow;
if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos);
}
auto iter = layer_params.find(key);
if (iter == layer_params.end()) continue;
group_grads_params->emplace_back();
auto &local_group_grads_params = group_grads_params->back();
for (auto &idx : iter->second) {
local_group_grads_params.emplace_back(
std::make_pair(params_grads[idx].second, params_grads[idx].first));
}
layer_params.erase(iter);
}
VLOG(10) << "SetGroupAccordingToLayers: ";
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
void AllocContinuousSpaceForGradPass::SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
return;
}
GroupGradsAndParams local_group_grads_params;
size_t j = 0;
while (j < group_grads_params->size()) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
size_t local_group_memory_size = 0;
while (j < group_grads_params->size()) {
std::for_each(
group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
[&local_group_memory_size,
&var_nodes](const std::pair<std::string, std::string> &g_p) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size =
framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
local_group_memory_size += size;
});
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (local_group_memory_size >= group_memory_size) {
break;
}
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToMemorySize(memory_size: %d):",
group_memory_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
}
}
void AllocContinuousSpaceForGradPass::SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const {
if (GetFuseParameterGroupsSize() == 1) {
return;
}
const int group_size = GetFuseParameterGroupsSize() == -1
? static_cast<int>(group_grads_params->size())
: GetFuseParameterGroupsSize();
PADDLE_ENFORCE_GT(group_size, 1);
size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
GroupGradsAndParams local_group_grads_params;
local_group_grads_params.reserve(groups);
size_t j = 0;
for (size_t i = 0; i < groups; ++i) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
group_p_g.reserve(group_size);
while (j < group_grads_params->size()) {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (j % group_size == 0) break;
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
group_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
bool AllocContinuousSpaceForGradPass::IsSupportedVarType(
const proto::VarType::Type &type) const {
// Currently only LOD_TENSOR is supported.
return type == proto::VarType::LOD_TENSOR;
}
void AllocContinuousSpaceForGradPass::RecordParamsAndGrads(
ir::Node *node, ParamsAndGrads *params_grads) const {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) return;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
auto backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));
for (size_t i = 0; i < backward_vars.size(); i += 2) {
VLOG(10) << "Trainable parameter: " << backward_vars[i]
<< ", gradient: " << backward_vars[i + 1];
params_grads->emplace_back(std::make_pair(backward_vars[i] /*param*/,
backward_vars[i + 1] /*grad*/));
}
} catch (boost::bad_get e) {
}
}
void AllocContinuousSpaceForGradPass::InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const ParamsAndGrads &params_grads) const {
// Init Gradients and FusedVars
VLOG(10) << "Init FusedVars and Gradients.";
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has existed in scope.", fused_var_name);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
for (auto &p_g : params_grads) {
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
proto::VarType::LOD_TENSOR);
scope->Var(p_g.second)->GetMutable<LoDTensor>();
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
params_name.reserve(params_grads.size());
for (auto &p_g : params_grads) {
params_name.emplace_back(p_g.first);
grads_name.emplace_back(p_g.second);
}
framework::ProgramDesc program_desc;
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void AllocContinuousSpaceForGradPass::AppendAllocSpaceForVarsOp(
const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name, BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
paddle::framework::details::AllocContinuousSpaceForGradPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
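To make the grouping controlled by the two flags above more concrete, here is a small standalone sketch of grouping gradients by a memory budget, in the spirit of SetGroupAccordingToMemorySize. The gradient names, sizes, and the 4096-byte budget are hypothetical and only illustrate the flag semantics; they are not taken from this pass.

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical (gradient name, size in bytes) pairs.
  std::vector<std::pair<std::string, uint64_t>> grads = {
      {"fc_0.w@GRAD", 4096}, {"fc_0.b@GRAD", 64},
      {"fc_1.w@GRAD", 8192}, {"fc_1.b@GRAD", 128}};
  const uint64_t group_memory_size = 4096;  // plays the role of FLAGS_fuse_parameter_memory_size

  std::vector<std::vector<std::string>> groups;
  size_t j = 0;
  while (j < grads.size()) {
    groups.emplace_back();
    uint64_t local_group_memory_size = 0;
    while (j < grads.size()) {
      local_group_memory_size += grads[j].second;
      groups.back().push_back(grads[j].first);
      ++j;
      if (local_group_memory_size >= group_memory_size) break;  // budget reached, close group
    }
  }
  for (size_t i = 0; i < groups.size(); ++i) {
    std::cout << "group " << i << ":";
    for (const auto& name : groups[i]) std::cout << " " << name;
    std::cout << "\n";
  }
  return 0;
}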
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
template <typename AttrType>
void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const;
void SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const ParamsAndGrads &params_grads,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
GroupGradsAndParams *group_grads_params) const;
private:
bool IsSupportedVarType(const proto::VarType::Type &type) const;
void RecordParamsAndGrads(ir::Node *node, ParamsAndGrads *params_grads) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const ParamsAndGrads &params_grads) const;
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -17,15 +17,14 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
namespace paddle {
namespace framework {
......@@ -173,10 +172,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
const std::string graph_path =
string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
"_multi_devices_graph");
multi_devices_print_pass->Set<std::string>(kGraphvizPath,
multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
new std::string(graph_path));
multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
"graph_printer", new details::GraphvizSSAGraphPrinter);
multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
"graph_printer", new ir::GraphvizSSAGraphPrinter);
}
// experiments show that the program will be faster if append
......@@ -240,7 +239,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
}
bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
......@@ -263,13 +262,13 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLossVarName);
pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
pass->Erase(ir::kLossVarName);
pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
pass->Erase(kNRanks);
pass->Set<size_t>(kNRanks, new size_t(nranks));
pass->Erase(ir::kNRanks);
pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
......@@ -312,8 +311,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
continue;
}
} else if (pass->Type() == "inplace_pass") {
pass->Erase(kUseCuda);
pass->Set<bool>(kUseCuda, new bool(use_cuda));
pass->Erase(ir::kUseCuda);
pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
......
......@@ -31,7 +31,7 @@ namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place,
const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts)
ir::AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names.begin(), var_names.end()),
......
......@@ -20,7 +20,7 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
......@@ -34,7 +34,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
const platform::Place &place,
const std::unordered_set<std::string> &var_names,
GarbageCollector *gc,
AtomicReferenceCountMap *ref_cnts);
ir::AtomicReferenceCountMap *ref_cnts);
~EagerDeletionOpHandle();
......@@ -56,7 +56,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
const Scope *scope_;
std::vector<std::string> var_names_;
GarbageCollector *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own
ir::AtomicReferenceCountMap *ref_cnts_; // not own
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr};
......
......@@ -63,8 +63,7 @@ void FetchOpHandle::RunImpl() {
auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA
TensorCopy(t, cpu, *dev_ctxes_.at(t.place()), &tensors_[i]);
dev_ctxes_.at(t.place())->Wait();
TensorCopy(t, cpu, &tensors_[i]);
#endif
} else {
tensors_[i].ShareDataWith(t);
......
......@@ -27,7 +27,7 @@ namespace paddle {
namespace framework {
namespace details {
constexpr char kLocalExecScopeName[] = "@LOCAL_SCOPE@";
constexpr char kLocalExecScopeName[] = "@LOCAL_EXE_SCOPE@";
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
class RecordSkipMemoryOptVarsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
graph->Set(kMemOptSkipVars, new MemOptSkipVars);
auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
// NOTE(zcd): Insert OpRoleVars into SkipVarSet to prevent these vars from
// being renamed by the memory optimize pass.
InsertOpRoleVarsToSkipVarSet(graph, &skip_vars);
}
void InsertOpRoleVarsToSkipVarSet(const ir::Graph* graph,
MemOptSkipVars* skip_vars) const {
for (auto& node : graph->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
if (node->IsOp() && node->Op()) {
try {
auto op_role_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
for (size_t i = 0; i < op_role_vars.size(); i += 2) {
auto& g_name = op_role_vars[i + 1];
skip_vars->insert(g_name);
}
} catch (boost::bad_get e) {
}
}
}
}
};
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(record_skip_memory_opt_vars_pass,
paddle::framework::details::RecordSkipMemoryOptVarsPass);
......@@ -68,22 +68,33 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
WaitComputationalStreams();
DropLocalExeScopes();
}
if (eptr) {
std::rethrow_exception(eptr);
} else {
return fetch_data;
}
}
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
drop_scope_counter_ = 0;
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
VLOG(3) << "Drop local execution scope: " << local_scope;
}
}
drop_scope_counter_ = 0;
}
if (eptr) {
std::rethrow_exception(eptr);
} else {
return fetch_data;
}
bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() {
return drop_scope_counter_ == 0;
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -47,17 +47,12 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
inline void WaitComputationalStreams() {
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
}
void DropLocalExeScopes();
bool NeedCreateLocalExeScope();
private:
size_t drop_scope_counter_{0};
ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
std::vector<Scope*> local_scopes_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
// FIXME(zjl): Inserting dependencies between some distributed ops may cause
// multi_devices_graph_pass to fail, so we skip these ops here.
// Indeed, we probably should not insert dependencies between these ops
// casually, since doing so can easily cause deadlock.
// More skipped distributed ops should be added when errors are found in
// multi_devices_graph_pass.
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops = graph->Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s", op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -425,6 +425,7 @@ void DownpourWorker::TrainFiles() {
}
VLOG(3) << "push dense gradient done.";
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t tmp_push_dense_wait_times = -1;
......
......@@ -19,7 +19,7 @@
#include <unordered_map>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
......
......@@ -18,7 +18,7 @@
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -33,7 +33,7 @@ namespace framework {
std::unique_ptr<ir::Pass> CreateInplacePass() {
auto pass = ir::PassRegistry::Instance().Get("inplace_pass");
pass->Set<bool>(details::kUseCuda, new bool(true));
pass->Set<bool>(ir::kUseCuda, new bool(true));
return pass;
}
......@@ -225,7 +225,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
FakeSuccData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
......@@ -241,7 +241,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
FakeNoInplaceData(&prog);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
g = test_SingleOpInplaceInToOut(std::move(g));
auto op_node = GetNodeFromGraph(g.get(), "single_op");
......@@ -274,7 +274,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
......@@ -310,7 +310,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
auto pass = CreateInplacePass();
pass->Apply(g.get());
auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
......
......@@ -3,6 +3,9 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList
file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
# Usage: pass_library(target inference) will append to paddle_inference_pass.h
unset(INFER_IR_PASSES CACHE) # clear the global variable
......@@ -34,7 +37,6 @@ function(pass_library TARGET DEST)
endif()
endfunction()
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
......@@ -43,6 +45,8 @@ cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
......@@ -71,6 +75,7 @@ pass_library(runtime_context_cache_pass base)
pass_library(expected_kernel_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
if(ANAKIN_FOUND)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is the upper limit on the memory size "
"of one group of parameters' gradients, which is the input "
"of the communication call (e.g. NCCLAllReduce). "
"The default value is 0, which means groups are not "
"formed according to memory_size.");
DEFINE_int32(
fuse_parameter_groups_size, 3,
"fuse_parameter_groups_size is the size of one group of parameters' gradients. "
"If fuse_parameter_groups_size is 1, the group size is "
"the number of parameters' gradients. If fuse_parameter_groups_size is "
"-1, there is only one group. The default value is 3, which is "
"an experimental value.");
namespace paddle {
namespace framework {
namespace ir {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// tests, because it is invalid to set 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' directly in a unit test.
void SetFuseParameterGroupsSize(int group_size) {
FLAGS_fuse_parameter_groups_size = group_size;
}
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }
void SetFuseParameterMemorySize(uint64_t memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
ResetAttribute<details::ParamsAndGrads>(details::kParamsAndGrads, &result);
ResetAttribute<details::GroupGradsAndParams>(details::kGroupGradsAndParams,
&result);
// NOTE: The operator nodes should be in topology order.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
for (auto &node : topo_nodes) {
RecordParamsAndGrads(node, &params_grads);
}
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return;
}
std::unordered_map<std::string, ir::Node *> vars;
for (ir::Node *node : result.Nodes()) {
if (node->IsVar() && node->Var()) {
// Note: The graph may contain nodes with the same name. For example, a
// parameter is both the input of an operator and the output of the optimizer.
vars.emplace(node->Var()->Name(), node);
}
}
auto &group_grads_params =
result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
// Note: the order of params_grads may be changed by SetGroupGradsAndParams.
SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
params_grads.clear();
for (auto &group_p_g : group_grads_params) {
params_grads.insert(params_grads.begin(), group_p_g.begin(),
group_p_g.end());
}
for (auto &p_g : params_grads) {
std::swap(p_g.first, p_g.second);
}
// Set Gradients as Persistable to prevent this var becoming reusable.
auto dtype = kDefaultDtype;
for (auto &p_g : params_grads) {
// Get gradient var
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
iter->second->Var()->SetPersistable(true);
PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
// Get Dtype
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create a FusedVarsSet to avoid duplicating names for fused_var in other
// pass.
if (!result.Has(details::kFusedVars)) {
result.Set(details::kFusedVars, new details::FusedVars);
}
// kFusedGrads is used by fuse_optimizer_op_pass.
result.Set(details::kFusedGrads, new details::FusedGrads);
// fused_var_name should be unique, so params_grads.begin()->second is
// appended to it.
auto fused_var_name = std::string(details::kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<details::FusedGrads>(details::kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
fused_var_name, params_grads);
}
template <typename AttrType>
void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const {
if (graph->Has(attr_name)) {
VLOG(10) << attr_name << " is reset.";
graph->Erase(attr_name);
}
graph->Set(attr_name, new AttrType);
}
void SetGroupGradsAndParams(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const details::ParamsAndGrads &params_grads,
details::GroupGradsAndParams *group_grads_params) const {
SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}
void SetGroupAccordingToLayers(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
const details::ParamsAndGrads &params_grads,
details::GroupGradsAndParams *group_grads_params) const {
std::unordered_map<std::string, std::vector<int>> layer_params;
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
if (pos == std::string::npos) {
layer_params[std::string(kUnKnow)].emplace_back(i);
} else {
layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
}
}
group_grads_params->reserve(layer_params.size());
for (size_t i = 0; i < params_grads.size(); ++i) {
auto pos = params_grads[i].first.find_first_of(".");
std::string key = kUnKnow;
if (pos != std::string::npos) {
key = params_grads[i].first.substr(0, pos);
}
auto iter = layer_params.find(key);
if (iter == layer_params.end()) continue;
group_grads_params->emplace_back();
auto &local_group_grads_params = group_grads_params->back();
for (auto &idx : iter->second) {
local_group_grads_params.emplace_back(
std::make_pair(params_grads[idx].second, params_grads[idx].first));
}
layer_params.erase(iter);
}
VLOG(10) << "SetGroupAccordingToLayers: ";
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
return;
}
details::GroupGradsAndParams local_group_grads_params;
size_t j = 0;
while (j < group_grads_params->size()) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
size_t local_group_memory_size = 0;
while (j < group_grads_params->size()) {
std::for_each(
group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
[&local_group_memory_size,
&var_nodes](const std::pair<std::string, std::string> &g_p) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size =
framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
local_group_memory_size += size;
});
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (local_group_memory_size >= group_memory_size) {
break;
}
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &g_p : group_grads_params->at(i)) {
auto iter = var_nodes.find(g_p.second);
PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
auto shape = iter->second->Var()->GetShape();
size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
std::for_each(shape.begin(), shape.end(),
[&size](const int64_t &n) { size *= n; });
out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
}
VLOG(10) << out.str();
}
}
void SetGroupAccordingToGroupSize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
if (GetFuseParameterGroupsSize() == 1) {
return;
}
const int group_size = GetFuseParameterGroupsSize() == -1
? static_cast<int>(group_grads_params->size())
: GetFuseParameterGroupsSize();
PADDLE_ENFORCE_GT(group_size, 1);
size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
details::GroupGradsAndParams local_group_grads_params;
local_group_grads_params.reserve(groups);
size_t j = 0;
for (size_t i = 0; i < groups; ++i) {
local_group_grads_params.emplace_back();
auto &group_p_g = local_group_grads_params.back();
group_p_g.reserve(group_size);
while (j < group_grads_params->size()) {
group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
group_grads_params->at(j).end());
++j;
if (j % group_size == 0) break;
}
}
std::swap(*group_grads_params, local_group_grads_params);
VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
group_size);
for (size_t i = 0; i < group_grads_params->size(); ++i) {
VLOG(10) << "group " << i;
std::stringstream out;
for (auto &p_g : group_grads_params->at(i)) {
out << "(" << p_g.second << ", " << p_g.first << "), ";
}
VLOG(10) << out.str();
}
}
private:
bool IsSupportedVarType(const proto::VarType::Type &type) const {
// Currently only LOD_TENSOR is supported.
return type == proto::VarType::LOD_TENSOR;
}
void RecordParamsAndGrads(ir::Node *node,
details::ParamsAndGrads *params_grads) const {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) return;
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
auto backward_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));
for (size_t i = 0; i < backward_vars.size(); i += 2) {
VLOG(10) << "Trainable parameter: " << backward_vars[i]
<< ", gradient: " << backward_vars[i + 1];
params_grads->emplace_back(std::make_pair(
backward_vars[i] /*param*/, backward_vars[i + 1] /*grad*/));
}
} catch (boost::bad_get e) {
}
}
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::unordered_map<std::string, ir::Node *> &vars,
const std::string &fused_var_name,
const details::ParamsAndGrads &params_grads) const {
// Init Gradients and FusedVars
VLOG(10) << "Init FusedVars and Gradients.";
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has existed in scope.", fused_var_name);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
for (auto &p_g : params_grads) {
auto iter = vars.find(p_g.second);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
proto::VarType::LOD_TENSOR);
scope->Var(p_g.second)->GetMutable<LoDTensor>();
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
params_name.reserve(params_grads.size());
for (auto &p_g : params_grads) {
params_name.emplace_back(p_g.first);
grads_name.emplace_back(p_g.second);
}
framework::ProgramDesc program_desc;
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(alloc_continuous_space_for_grad_pass,
paddle::framework::ir::AllocContinuousSpaceForGradPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
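A minimal standalone sketch of the bucketing behaviour implemented above, under hypothetical names and sizes: gradients are appended to the current group until the accumulated byte size reaches the memory limit (as in SetGroupAccordingToMemorySize) or the group reaches the configured element count (as in SetGroupAccordingToGroupSize).
// Illustrative sketch only, not part of this commit; names and sizes are
// hypothetical. It mirrors SetGroupAccordingToMemorySize: pairs are pushed
// into the current group until the accumulated byte size reaches a threshold.
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  using GradParam = std::pair<std::string, size_t>;  // (grad name, bytes)
  std::vector<GradParam> grads = {{"w0@GRAD", 512}, {"w1@GRAD", 2048},
                                  {"w2@GRAD", 1024}, {"w3@GRAD", 4096}};
  const size_t kGroupMemorySize = 3000;  // hypothetical threshold in bytes

  std::vector<std::vector<GradParam>> groups(1);
  size_t local_size = 0;
  for (size_t i = 0; i < grads.size(); ++i) {
    groups.back().push_back(grads[i]);
    local_size += grads[i].second;
    if (local_size >= kGroupMemorySize && i + 1 < grads.size()) {
      groups.emplace_back();  // start a new group once the limit is reached
      local_size = 0;
    }
  }
  for (size_t i = 0; i < groups.size(); ++i) {
    std::cout << "group " << i << ":";
    for (auto &g : groups[i]) std::cout << " " << g.first;
    std::cout << "\n";
  }
  return 0;
}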
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -11,21 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include <algorithm>
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();
class ReferenceCountPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -48,17 +48,37 @@ void FCFusePass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
auto base_op_desc = mul->Op();
// Create an FC Node.
// OpDesc desc(base_op_desc, nullptr);
OpDesc desc;
std::string fc_x_in = subgraph.at(x)->Name();
std::string fc_Y_in = w->Name();
std::string fc_bias_in = fc_bias->Name();
std::string fc_out_out = fc_out->Name();
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
// For Anakin subgraph int8:
// When in Anakin subgraph int8 mode, the pattern "fake_quant + mul + fake_dequant"
// can be detected by the quant_dequant_fuse_pass. That pass adds "input_scale" and
// "weight_scale" (extracted from the fake_quant and fake_dequant ops) to the mul op,
// and then deletes the fake_quant and fake_dequant ops from the graph. If the mul op
// carries this scale info, we should add it to the fused fc op as well.
if (base_op_desc->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
}
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
......
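For orientation, a tiny standalone sketch (not part of the diff) of the computation the fused fc op stands for: the mul (X*W) plus the elementwise_add (+ bias) it replaces; shapes and values below are made up.
// Illustrative sketch only, not part of this commit: the math the fused "fc"
// op performs, i.e. the mul (X*W) and elementwise_add (+bias) it replaces.
#include <iostream>
#include <vector>

int main() {
  const int M = 2, K = 3, N = 2;  // X: MxK, W: KxN, bias: N
  std::vector<float> X = {1, 2, 3, 4, 5, 6};
  std::vector<float> W = {1, 0, 0, 1, 1, 1};
  std::vector<float> bias = {0.5f, -0.5f};
  std::vector<float> out(M * N, 0.f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      for (int k = 0; k < K; ++k) out[m * N + n] += X[m * K + k] * W[k * N + n];
      out[m * N + n] += bias[n];  // the fused bias add
    }
  for (float v : out) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}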
cc_library(fuse_optimizer_op_pass SRCS fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc DEPS fuse_optimizer_op_pass)
......@@ -16,16 +16,13 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseAdamOpPass : public FuseOptimizerOpPass {
private:
......@@ -203,10 +200,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
}
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::details::FuseAdamOpPass)
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -16,14 +16,13 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseMomentumOpPass : public FuseOptimizerOpPass {
private:
......@@ -84,11 +83,10 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_momentum_op_pass,
paddle::framework::details::FuseMomentumOpPass)
REGISTER_PASS(fuse_momentum_op_pass, paddle::framework::ir::FuseMomentumOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -20,13 +20,13 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
const std::string fuse_op_type = GetOpType();
std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
......@@ -47,24 +47,24 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
return;
}
if (result.Has(kFusedOptType)) {
if (result.Has(details::kFusedOptType)) {
VLOG(6) << "Currently only support fusing one type optimizer op. Has fused "
<< result.Get<FusedOptType>(kFusedOptType);
<< result.Get<details::FusedOptType>(details::kFusedOptType);
return;
} else {
result.Set(kFusedOptType, new FusedOptType);
result.Set(details::kFusedOptType, new details::FusedOptType);
}
result.Get<FusedOptType>(kFusedOptType) = fuse_op_type;
result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type;
// Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be
// initialized in scopes before execution.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
if (!result.Has(details::kFusedVars)) {
result.Set(details::kFusedVars, new details::FusedVars);
}
std::unordered_map<std::string, std::string> fused_vars_name;
fused_vars_name.reserve(aux_var_names.size());
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
const std::string prefix(kFusedVarNamePrefix);
auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
const std::string prefix(details::kFusedVarNamePrefix);
// NOTE: the fused_var_name should be unique.
for (auto &var_name : aux_var_names) {
auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
......@@ -77,8 +77,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// Step 3: Get the fused Gradient's name
bool grad_fused = false;
if (result.Has(kParamsAndGrads)) {
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
if (result.Has(details::kParamsAndGrads)) {
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
PADDLE_ENFORCE_EQ(
params_grads.size(), aux_var_set.at(kGrad).size(),
"The number of gradients and optimizer ops is not equal.");
......@@ -94,13 +95,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
// NOTE(zcd): the gradient of kParamsAndGrads may be different from the
// kGrad.
if (same_grad_num == aux_var_set.at(kGrad).size()) {
if (!result.Has(kFusedGrads)) {
if (!result.Has(details::kFusedGrads)) {
PADDLE_THROW(
"The alloc_continuous_space_for_grad_pass should be called before "
"this pass.");
}
auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
auto &fused_vars = result.Get<FusedVars>(kFusedVars);
auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
PADDLE_ENFORCE(iter != fused_vars.end(), "Cannot find the fused_grad.");
fused_vars_name[kGrad] = fused_grad;
......@@ -323,6 +324,6 @@ void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
outputs.end());
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -25,7 +25,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
constexpr char kGrad[] = "Grad";
constexpr char kParam[] = "Param";
......@@ -90,6 +90,6 @@ class FuseOptimizerOpPass : public ir::Pass {
const std::string &fused_var_name) const;
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -14,18 +14,13 @@
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
......@@ -66,10 +61,10 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
InserInputAndOutputForOptOps(sgd_ops, sgd_node);
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::details::FuseSgdOpPass)
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
......@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
......@@ -61,7 +66,16 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
var->outputs.push_back(node);
}
// For output args, always create a new var.
std::unordered_set<std::string> out_arg_set;
for (auto &each_var_name : op->OutputArgumentNames()) {
if (each_var_name != kEmptyVarName) {
PADDLE_ENFORCE(out_arg_set.count(each_var_name) == 0,
"Program is wrong. %s occurs in output of %s several "
"times.",
each_var_name, op->Type());
out_arg_set.insert(each_var_name);
}
ir::Node *var = nullptr;
if (all_vars.count(each_var_name) != 0) {
var = CreateVarNode(all_vars.at(each_var_name));
......
......@@ -1640,7 +1640,8 @@ PDNode *patterns::FillConstantElementWiseMulFuse::operator()(
void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times) {
int times,
const std::string &quant_type) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
......@@ -1648,22 +1649,20 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
// there is always exactly one quant op.
auto quant_op_in_scale =
pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input("fake_quantize_range_abs_max", "InScale")
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input(quant_type, "InScale")
->AsInput();
auto quant_op = pattern->NewNode(GetNodeName("quant_op"))
->assert_is_op("fake_quantize_range_abs_max");
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
auto quant_op_out_scale =
pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output("fake_quantize_range_abs_max", "OutScale")
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto quant_op_out =
pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output("fake_quantize_range_abs_max", "Out")
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output(quant_type, "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
......@@ -1707,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
}
}
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
auto reshape1_op =
pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("transpose2")
->AsIntermediate();
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto transpose_out = pattern->NewNode(transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("reshape2")
->AsIntermediate();
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
->assert_is_op_output("reshape2", "Out")
->AsOutput();
reshape1_op->LinksFrom({reshape1_in});
reshape1_out->LinksFrom({reshape1_op});
transpose_op->LinksFrom({reshape1_out});
transpose_out->LinksFrom({transpose_op});
reshape2_op->LinksFrom({transpose_out});
reshape2_out->LinksFrom({reshape2_op});
}
} // namespace ir
} // namespace framework
} // namespace paddle
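As background for ShuffleChannelPattern above, a small standalone sketch (not part of the diff) of the channel-shuffle index mapping that the matched reshape2 -> transpose2 -> reshape2 chain realizes, assuming NCHW layout and a channel count divisible by the group number.
// Illustrative sketch only: the channel permutation produced by
// reshape(N, g, C/g, H, W) -> transpose(0, 2, 1, 3, 4) -> reshape(N, C, H, W).
// Values below are hypothetical.
#include <iostream>

int main() {
  const int channels = 6, group = 2;  // assumes channels % group == 0
  const int channels_per_group = channels / group;
  for (int c = 0; c < channels; ++c) {
    // After the shuffle, output channel c reads from this input channel.
    int src = (c % group) * channels_per_group + c / group;
    std::cout << "out channel " << c << " <- in channel " << src << "\n";
  }
  return 0;
}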
......@@ -880,7 +880,8 @@ struct QuantDequantOpFuse : public PatternBase {
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times = 1);
const std::string& weight_name, int times,
const std::string& quant_type);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
......@@ -891,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
}
};
struct ShuffleChannelPattern : public PatternBase {
ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
void operator()(PDNode* reshape1_in);
PATTERN_DECL_NODE(reshape1_op);
PATTERN_DECL_NODE(reshape1_out);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
if(WITH_GPU)
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
......@@ -27,11 +27,11 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
// op -> variables which can be deleted after op runs
using OpToVarNameSetMap =
std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
using OpToVarNameSetMap = std::unordered_map<details::ComputationOpHandle *,
std::unordered_set<std::string>>;
static std::map<size_t, std::unordered_set<std::string>> VarsGroupByScopeIdx(
const OpToVarNameSetMap &map) {
......@@ -53,7 +53,8 @@ static bool IsLoDTensor(VarDesc *var) {
// Get memory size of LoDTensor
static int64_t GetMemorySize(
const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
const std::unordered_map<std::string, std::vector<details::VarHandle *>>
&vars,
const std::string &var_name) {
auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
PADDLE_ENFORCE_NOT_NULL(var_desc);
......@@ -69,13 +70,13 @@ static int64_t GetMemorySize(
// Partial GC is based on static analysis of the memory size of each variable,
// so we should skip SelectedRows and LoDTensorArray here
static void SplitIntoLoDTensorAndNonLoDTensorVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const OpToVarNameSetMap &m, const details::GraphVars &vars,
OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
lod_tensors->clear();
other_vars->clear();
for (auto &op_vars_pair : m) {
for (auto &var_name : op_vars_pair.second) {
for (auto var_name : op_vars_pair.second) {
auto *var_desc = TryGetLatestVarDesc(
vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
if (IsLoDTensor(var_desc)) {
......@@ -89,7 +90,7 @@ static void SplitIntoLoDTensorAndNonLoDTensorVars(
struct GCVarInfo {
GCVarInfo(const std::string &name, int64_t memory_size,
ComputationOpHandle *op, size_t scope_idx)
details::ComputationOpHandle *op, size_t scope_idx)
: name_(name),
memory_size_(memory_size),
op_(op),
......@@ -97,7 +98,8 @@ struct GCVarInfo {
std::string name_; // variable name
int64_t memory_size_; // memory size
ComputationOpHandle *op_; // op after which the variable could be deleted
details::ComputationOpHandle
*op_; // op after which the variable could be deleted
size_t scope_idx_; // scope index where the variable locates
int64_t AbsMemorySize() const { return std::abs(memory_size_); }
......@@ -105,7 +107,7 @@ struct GCVarInfo {
// NOTE: delete_lod_tensor_only is not used currently
static OpToVarNameSetMap ShrinkGCVars(
const OpToVarNameSetMap &m, const GraphVars &vars,
const OpToVarNameSetMap &m, const details::GraphVars &vars,
const std::vector<platform::Place> &places, double fraction_of_memory_size,
bool delete_lod_tensor_only = false) {
// Do not perform gc when fraction_of_memory_size = 0
......@@ -192,7 +194,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE(ref_cnts.empty(),
"kRuntimeReferenceCount should be initialized here!");
const auto &vars = graph->Get<GraphVars>(kGraphVars);
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
ref_cnts.resize(vars.size());
const auto &last_live_ops =
......@@ -222,27 +224,31 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle(
auto *eager_deletion_op = new details::EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
gcs.at(places[op->GetScopeIdx()]).get(),
&(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
op->Outputs().begin(), op->Outputs().end(),
[](details::VarHandleBase *var) {
return dynamic_cast<details::DummyVarHandle *>(var) != nullptr;
});
if (it != op->Outputs().end()) {
eager_deletion_op->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
auto *dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
op->AddOutput(dep_var);
eager_deletion_op->AddInput(dep_var);
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
auto *dummy_leaf =
new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf);
}
......@@ -262,15 +268,14 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
while_op_eager_deletion_pass->Apply(graph);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(eager_deletion_pass,
paddle::framework::details::EagerDeletionPass)
.RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector);
REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
.RequirePassAttr(paddle::framework::ir::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kAllPlaces)
.RequirePassAttr(paddle::framework::ir::kGarbageCollector);
USE_PASS(while_op_eager_deletion_pass);
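A side note on GetMemorySize above: it estimates a LoDTensor's footprint from its static dims, and a -1 (inferred) dimension makes the product negative, which is presumably why GCVarInfo exposes AbsMemorySize(). A standalone sketch, not part of the diff:
// Illustrative sketch only: byte-size estimate from a static shape.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

int64_t EstimateBytes(const std::vector<int64_t> &shape, size_t dtype_size) {
  int64_t numel = 1;
  for (int64_t d : shape) numel *= d;  // a -1 dim flips the sign
  return numel * static_cast<int64_t>(dtype_size);
}

int main() {
  std::cout << EstimateBytes({-1, 128, 768}, 4) << "\n";            // negative: unknown batch
  std::cout << std::abs(EstimateBytes({-1, 128, 768}, 4)) << "\n";  // absolute estimate
  return 0;
}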
......@@ -16,9 +16,9 @@
#include <queue>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -52,7 +52,7 @@ DECLARE_string(memory_optimize_debug);
namespace paddle {
namespace framework {
namespace details {
namespace ir {
// clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT
......@@ -111,10 +111,14 @@ class InplacePass : public ir::Pass {
// Check whether all `ops` is the preceding ops of `op`
bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
// Find nodes whose name are equal to the given name
// Find nodes whose names are equal to the given name
static std::unordered_set<ir::Node *> FindNodesByName(
const std::string &name, const std::vector<ir::Node *> &nodes);
// Collect input arguments of op_desc
static void CollectInputArgsOfOpDesc(
const OpDesc *op_desc, std::unordered_multiset<std::string> *in_args);
// Get all versions vars named var_name
std::vector<ir::Node *> *AllVersionVars(const std::string &var_name) const;
......@@ -195,43 +199,12 @@ bool InplacePass::CheckOpDeps(ir::Node *op,
void InplacePass::CollectSkipVars(ir::Graph *graph,
const std::vector<ir::Node *> &ops) const {
// 1. Collect op role vars
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars),
"Graph should have attr %s", details::kMemOptSkipVars);
PADDLE_ENFORCE(graph->Has(kMemOptSkipVars), "Graph should have attr %s",
kMemOptSkipVars);
auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto &var : mem_opt_whitelist) {
skip_vars_.emplace(var);
}
// 2. Track the nodes which are used by the parameter server.
// These nodes cannot be inplaced; otherwise the trainer and
// pserver cannot find each other's variables by name.
// Also check the ops which have sub-blocks
auto update_skip_set = [&](ir::Node *node) {
for (auto &in : node->inputs) {
if (in->IsVar() && in->Var() != nullptr) {
skip_vars_.emplace(in->Name());
}
}
for (auto &out : node->outputs) {
if (out->IsVar() && out->Var() != nullptr) {
skip_vars_.emplace(out->Name());
}
}
};
for (auto *node : ops) {
if (!node->IsOp()) continue;
// avoid optimizing the variable used in sub-blocks
if (OpHasSubBlock(node->Op())) {
update_skip_set(node);
continue;
}
auto node_name = node->Name();
if (node_name == "send" || node_name == "recv" || node_name == "prefetch") {
update_skip_set(node);
}
}
}
void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var,
......@@ -301,6 +274,14 @@ std::unordered_set<ir::Node *> InplacePass::FindNodesByName(
return ret;
}
void InplacePass::CollectInputArgsOfOpDesc(
const OpDesc *op_desc, std::unordered_multiset<std::string> *in_args) {
in_args->clear();
for (auto &in_name : op_desc->InputArgumentNames()) {
in_args->insert(in_name);
}
}
void InplacePass::ApplyImpl(ir::Graph *graph) const {
// Step 1: topo sort ops, collect skip vars
auto ops = ir::TopologySortOperations(*graph);
......@@ -346,6 +327,11 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
}
auto in_to_outs = infer_inplace(*op_desc, use_cuda);
if (in_to_outs.empty()) continue;
std::unordered_multiset<std::string> all_in_args;
CollectInputArgsOfOpDesc(op_desc, &all_in_args);
for (auto &pair : in_to_outs) {
auto &in_param = pair.first;
auto &out_param = pair.second;
......@@ -387,6 +373,14 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue;
}
size_t in_arg_occur_times = all_in_args.count(in_arg);
if (in_arg_occur_times > 1) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " occurs " << in_arg_occur_times << " times in input of op "
<< op_type;
continue;
}
auto in_nodes = FindNodesByName(in_arg, op_node->inputs);
PADDLE_ENFORCE(!in_nodes.empty(), "Input(%s)=%s cannot be found in op %s",
in_param, in_arg, op_type);
......@@ -458,8 +452,7 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
continue;
}
if (details::NodeSize(*in_node->Var()) !=
details::NodeSize(*out_node->Var()) &&
if (NodeSize(*in_node->Var()) != NodeSize(*out_node->Var()) &&
kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
<< " is not the same size with "
......@@ -482,9 +475,9 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass)
.RequirePassAttr(paddle::framework::details::kUseCuda);
REGISTER_PASS(inplace_pass, paddle::framework::ir::InplacePass)
.RequirePassAttr(paddle::framework::ir::kUseCuda);
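A toy illustration (not part of the diff) of the duplicate-input guard added above: when one variable feeds several input slots of an op, in-placing it with an output would alias the other slots, so CollectInputArgsOfOpDesc counts occurrences first and the pass skips such inputs.
// Illustrative sketch only; op and variable names are hypothetical.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // Inputs of a hypothetical op: "x" is fed to two different input slots.
  std::vector<std::string> in_args = {"x", "y", "x"};
  std::unordered_multiset<std::string> counter(in_args.begin(), in_args.end());
  for (const auto &arg : {"x", "y"}) {
    if (counter.count(arg) > 1) {
      std::cout << arg << " occurs " << counter.count(arg)
                << " times; skip in-place reuse\n";
    } else {
      std::cout << arg << " is a candidate for in-place reuse\n";
    }
  }
  return 0;
}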
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm>
#include <deque>
#include <functional>
......@@ -32,14 +32,15 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
using paddle::framework::VarDesc;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kStaleProgramOpDescs),
PADDLE_ENFORCE(graph.Has(details::kStaleProgramOpDescs),
"Graph has no attribute of kStaleProgramOpDescs.");
// 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
auto& op_descs =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
// 2. topology sort order
auto nodes = graph.Nodes();
......@@ -563,6 +564,6 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
return found_node;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -29,7 +29,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
/// this attribute is used to avoid some core variables being removed/reused
/// in memory-optimize-related passes
......@@ -184,6 +184,6 @@ void FilterVariables(const Container& nodes, Callback callback) {
FilterVariableImpl<Container, Callback>()(nodes, callback);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -11,8 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include <algorithm>
#include <iostream>
#include <iterator>
......@@ -32,7 +31,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
TEST(OrderedSet, Normal) {
OrderedSet pool;
......@@ -153,7 +152,7 @@ TEST(OrderedSet, FindBestFitNode) {
ASSERT_TRUE(cache == nullptr);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -188,7 +187,7 @@ REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
namespace paddle {
namespace framework {
namespace details {
namespace ir {
inline static ProgramDesc FillProgramDesc() {
ProgramDesc prog;
......@@ -521,6 +520,6 @@ TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
#include <algorithm>
#include <atomic>
#include <deque>
......@@ -42,12 +42,12 @@ DEFINE_string(memory_optimize_debug, "",
namespace paddle {
namespace framework {
namespace details {
namespace ir {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
CollectSkipVarsSet(graph);
cfg_.reset(new details::ControlFlowGraph(*graph));
cfg_.reset(new ControlFlowGraph(*graph));
cfg_->LiveVariableAnalysis();
InitSSAGraphNodes();
......@@ -205,30 +205,10 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
// fill skip_set_
PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars));
PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var);
auto update_skip_set = [&](OpDesc* op_desc) {
auto inputs = op_desc->InputArgumentNames();
auto outputs = op_desc->OutputArgumentNames();
skip_set_.insert(inputs.begin(), inputs.end());
skip_set_.insert(outputs.begin(), outputs.end());
};
auto nodes = graph->Nodes();
for (auto& op : nodes) {
if (!op->IsOp() || op->Op() == nullptr) continue;
auto* op_desc = op->Op();
// NOTE(dzhwinter):
// the current block cannot reuse variables of the next-level block.
if (OpHasSubBlock(op_desc)) update_skip_set(op_desc);
// NOTE(dzhwinter):
// distributed ops' input/output names need to
// stay the same between trainer/pserver
if (op_desc->Type() == "send") update_skip_set(op_desc);
if (op_desc->Type() == "recv") update_skip_set(op_desc);
if (op_desc->Type() == "prefetch") update_skip_set(op_desc);
for (const auto& var : mem_opt_whitelist) {
skip_set_.emplace(var);
}
}
......@@ -336,10 +316,9 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(memory_optimize_pass,
paddle::framework::details::MemoryOptimizePass)
REGISTER_PASS(memory_optimize_pass, paddle::framework::ir::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -26,13 +26,13 @@
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class MemoryOptimizePass : public ir::Pass {
protected:
......@@ -67,6 +67,6 @@ class MemoryOptimizePass : public ir::Pass {
mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
namespace ir {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
OpGraphView::OpGraphView(const std::vector<details::OpHandleBase *> &ops) {
Build(ops);
}
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) {
preceding_ops_.clear();
pending_ops_.clear();
for (auto &op : ops) {
......@@ -40,8 +42,8 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
"There are duplicate ops in graph.");
}
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<details::OpHandleBase *> ret;
ret.reserve(preceding_ops_.size());
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
......@@ -49,21 +51,21 @@ std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
bool OpGraphView::HasOp(details::OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps(
details::OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -22,39 +22,42 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
explicit OpGraphView(const std::vector<details::OpHandleBase *> &ops);
std::unordered_set<OpHandleBase *> AllOps() const;
std::unordered_set<details::OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
const std::unordered_set<details::OpHandleBase *> &PendingOps(
details::OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
bool HasOp(details::OpHandleBase *op) const;
// Use a visitor to visit all pending ops of op
// Stop when callback returns false
template <typename Callback>
bool VisitAllPendingOps(OpHandleBase *op, Callback &&callback) const;
bool VisitAllPendingOps(details::OpHandleBase *op, Callback &&callback) const;
private:
void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const;
void Build(const std::vector<details::OpHandleBase *> &ops);
void EnforceHasOp(details::OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
std::unordered_map<details::OpHandleBase *,
std::unordered_set<details::OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
std::unordered_map<details::OpHandleBase *,
std::unordered_set<details::OpHandleBase *>>
pending_ops_;
};
template <typename Callback>
bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
bool OpGraphView::VisitAllPendingOps(details::OpHandleBase *op,
Callback &&callback) const {
EnforceHasOp(op);
std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q;
std::unordered_set<details::OpHandleBase *> visited;
std::queue<details::OpHandleBase *> q;
q.push(op);
while (!q.empty()) {
op = q.front();
......@@ -72,6 +75,6 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
return true;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
class RecordSkipMemoryOptVarsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
PADDLE_ENFORCE(!graph->Has(kMemOptSkipVars));
graph->Set(kMemOptSkipVars, new MemOptSkipVars);
auto& skip_vars = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
std::vector<ir::Node*> op_nodes;
for (auto& node : graph->Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node, "The node should not be nullptr.");
if (node->IsOp() && node->Op()) {
op_nodes.emplace_back(node);
}
}
// Insert kEmptyVarName to avoid optimizing empty variable
skip_vars.insert(framework::kEmptyVarName);
// NOTE(zcd): Insert OpRoleVars into SkipVarSet to prevent the vars from being
// renamed in the memory optimize pass.
InsertOpRoleVarsToSkipVarSet(op_nodes, &skip_vars);
InsertSkipMemOptOpInOutToSkipVarSet(op_nodes, &skip_vars);
}
private:
static void InsertOpRoleVarsToSkipVarSet(const std::vector<ir::Node*>& ops,
MemOptSkipVars* skip_vars) {
for (auto& node : ops) {
try {
auto op_role_vars =
boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(op_role_vars.size() % 2, 0);
for (size_t i = 0; i < op_role_vars.size(); i += 2) {
auto& g_name = op_role_vars[i + 1];
skip_vars->insert(g_name);
}
} catch (boost::bad_get& e) {
}
}
}
static void UpdateSkipVarSet(
MemOptSkipVars* skip_vars,
const std::vector<std::vector<std::string>>& var_names) {
for (auto& var_name : var_names) {
skip_vars->insert(var_name.begin(), var_name.end());
}
}
static std::vector<std::string> ToGradVarName(
const std::vector<std::string>& names) {
std::vector<std::string> ret;
ret.reserve(names.size());
for (auto& name : names) {
if (name != framework::kEmptyVarName) {
ret.emplace_back(framework::GradVarName(name));
}
}
return ret;
}
static void InsertSkipMemOptOpInOutToSkipVarSet(
const std::vector<ir::Node*>& ops, MemOptSkipVars* skip_vars) {
static std::unordered_set<std::string> kSkipMemOptOps{
"send", "recv", "prefetch", "send_barrier", "fetch_barrier"};
for (auto& node : ops) {
auto* op_desc = node->Op();
// Some ops (while, conditional_block, recurrent, etc.) have sub-blocks.
// These ops often use variables from their parent or forward blocks.
// Optimizing the in/out of such ops would make these variables
// impossible to find when running sub-block ops.
if (OpHasSubBlock(op_desc)) {
UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
op_desc->OutputArgumentNames()});
}
// Skip ops that are related to parameter server.
// In distributed mode, trainers and the parameter server use the same
// variable names to track the same variables. We cannot change the
// names of these variables, otherwise trainers or parameter
// server would not find them.
if (kSkipMemOptOps.count(op_desc->Type()) > 0) {
UpdateSkipVarSet(skip_vars, {op_desc->InputArgumentNames(),
op_desc->OutputArgumentNames()});
}
// FIXME(zjl): some ops use variables that are not from their
// inputs or outputs. We do not have a nice method to solve this
// issue yet. Currently, we should skip these variables when
// memory optimization is enabled.
auto op_type = op_desc->Type();
if (op_type == "while_grad") {
// In while_grad, framework::GradVarName(Input("X")) is visited
// without being any input/output of while_grad. while_grad uses
// these variables to accumulate the gradient of X across time steps.
UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("X"))});
} else if (op_type == "conditional_block_grad") {
// In conditional_block_grad, framework::GradVarName(Input("Input",
// "Cond")) is visited without being any in/out of
// conditional_block_grad. Conditional_block_grad uses these
// variables to accumulate gradient of Input/Cond across time steps.
UpdateSkipVarSet(skip_vars, {ToGradVarName(op_desc->Input("Input")),
ToGradVarName(op_desc->Input("Cond"))});
} else if (op_type == "recurrent" || op_type == "recurrent_grad") {
// Recurrent and recurrent_grad ops are implemented in a very tricky
// way. Attr("states", "ex_states") is visited without being any
// input/output of the op. This is because these variables come from sub-blocks,
// not the main block. Adding these variables to the inputs would make recurrent
// fail since "states" and "ex_states" cannot be found in the main block.
// When memory optimization is enabled, "states", "ex_states" and their
// gradients should be skipped.
auto& ex_states =
boost::get<std::vector<std::string>>(op_desc->GetAttr("ex_states"));
auto& states =
boost::get<std::vector<std::string>>(op_desc->GetAttr("states"));
if (op_type == "recurrent") {
UpdateSkipVarSet(skip_vars, {ex_states, states});
} else {
// In recurrent_grad, framework::GradVarName(Input("parameters",
// "input")) is visited without being any in/out of recurrent_grad.
// Recurrent_grad uses these variables to accumulate gradient of
// parameters/input across time steps.
UpdateSkipVarSet(
skip_vars,
{ToGradVarName(op_desc->Input("parameters")),
ToGradVarName(op_desc->Input("input")), ex_states, states,
ToGradVarName(ex_states), ToGradVarName(states)});
}
}
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(record_skip_memory_opt_vars_pass,
paddle::framework::ir::RecordSkipMemoryOptVarsPass);
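For readers unfamiliar with the convention relied on by ToGradVarName above: framework::GradVarName(name) appends the gradient suffix (kGradVarSuffix, i.e. "@GRAD") to a variable name. A minimal sketch, not the framework implementation:
// Illustrative sketch only, mirroring the "@GRAD" naming convention.
#include <iostream>
#include <string>

std::string GradVarName(const std::string &name) {
  return name + "@GRAD";  // mirrors framework::kGradVarSuffix
}

int main() {
  std::cout << GradVarName("fc_0.w_0") << "\n";  // prints "fc_0.w_0@GRAD"
  return 0;
}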
......@@ -24,14 +24,20 @@
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class ReferenceCountPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
};
// A functor to shrink/remove operators that depend on other operators in a set
class ShrinkDepsOpFunctor {
......@@ -39,19 +45,21 @@ class ShrinkDepsOpFunctor {
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
public:
explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops)
explicit ShrinkDepsOpFunctor(
const std::vector<details::OpHandleBase *> &all_ops)
: graph_(all_ops) {}
template <typename OpSet>
OpSet operator()(const OpSet &op_set) const {
using KeyType = typename OpSet::key_type;
static_assert(
std::is_base_of<OpHandleBase,
std::is_base_of<details::OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
"Key type of OpSet must be details::OpHandleBase, or derived of "
"details::OpHandleBase");
if (op_set.size() <= 1) return op_set;
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
std::vector<details::OpHandleBase *> ops(op_set.begin(), op_set.end());
OpSet ret;
auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; };
......@@ -65,8 +73,8 @@ class ShrinkDepsOpFunctor {
private:
std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> &ops) const {
std::unordered_map<OpHandleBase *, size_t> op_to_idx;
const std::vector<details::OpHandleBase *> &ops) const {
std::unordered_map<details::OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
op_to_idx[ops[i]] = i;
......@@ -81,7 +89,7 @@ class ShrinkDepsOpFunctor {
size_t found_num = ops.size();
size_t total_num = ops.size() * ops.size();
auto visitor = [&](OpHandleBase *op, size_t i) {
auto visitor = [&](details::OpHandleBase *op, size_t i) {
auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) {
size_t j = it->second;
......@@ -98,7 +106,9 @@ class ShrinkDepsOpFunctor {
};
for (size_t i = 0; i < ops.size(); ++i) {
auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); };
auto sub_visitor = [&, i](details::OpHandleBase *op) {
return visitor(op, i);
};
if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
break;
}
......@@ -133,8 +143,8 @@ class ShrinkDepsOpFunctor {
*/
static bool ShrinkNoNeedBufferVarOpDependency(
const std::string &var_name,
std::unordered_set<ComputationOpHandle *> *op_handles) {
std::vector<ComputationOpHandle *> skip_ops;
std::unordered_set<details::ComputationOpHandle *> *op_handles) {
std::vector<details::ComputationOpHandle *> skip_ops;
for (auto *op_handle : *op_handles) {
auto *op_base = op_handle->GetOp();
auto &inferer = op_base->Info().NoNeedBufferVarsInferer();
......@@ -195,15 +205,15 @@ static bool ShrinkNoNeedBufferVarOpDependency(
* Find the nearest downstream computation op handle. If the op is a
* computation op, just return itself.
*/
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q;
std::unordered_set<OpHandleBase *> visited;
static details::ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
details::OpHandleBase *op, size_t scope_idx) {
std::queue<details::OpHandleBase *> q;
std::unordered_set<details::OpHandleBase *> visited;
q.push(op);
while (!q.empty()) {
auto *op = q.front();
q.pop();
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
return compute_op;
}
......@@ -220,13 +230,13 @@ static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };
static std::unordered_set<ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
static std::unordered_set<details::ComputationOpHandle *>
ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
const std::string &var_name,
const ShrinkDepsOpFunctor &shrink_func,
LastLiveOpSearchStatus *status) {
// stage one. Get last op for variable.
std::unordered_set<OpHandleBase *> candidates;
std::unordered_set<details::OpHandleBase *> candidates;
{
if (var->PendingOps().empty() && var->GeneratedOp()) {
// No operator depends on this variable. So the last operator is the op
......@@ -251,7 +261,7 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
// Some op handles may operate on many DeviceContexts; however, our garbage
// collector can only wait on one DeviceContext for now. So currently, we wait
// for the nearest compute op.
std::unordered_set<ComputationOpHandle *> computation_op;
std::unordered_set<details::ComputationOpHandle *> computation_op;
{
for (auto *op : candidates) {
auto *compute_op =
......@@ -293,13 +303,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
"Last Live Ops and Reference Counts of vars should be "
"initialized at here.");
const auto &vars = graph->Get<GraphVars>(kGraphVars);
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
last_live_ops_of_vars.resize(vars.size());
ref_cnts.resize(vars.size());
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<OpHandleBase>(*graph));
ir::FilterByNodeWrapper<details::OpHandleBase>(*graph));
VLOG(1) << "Place number: " << vars.size();
for (size_t i = 0; i < vars.size(); ++i) {
......@@ -360,11 +370,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
}
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reference_count_pass,
paddle::framework::details::ReferenceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);
REGISTER_PASS(reference_count_pass, paddle::framework::ir::ReferenceCountPass)
.RequirePassAttr(paddle::framework::ir::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars);
......@@ -12,23 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars) {
VarDesc *var_desc = nullptr;
std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
std::find_if(vars.rbegin(), vars.rend(),
[&](details::VarHandle *var_handle) -> bool {
var_desc = var_handle->Node()->Var();
return var_desc != nullptr;
});
return var_desc;
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -22,17 +22,16 @@
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
class VarDesc;
class VarHandle;
namespace details {
class ComputationOpHandle;
namespace ir {
using ReferenceCountMap = std::unordered_map<std::string, size_t>;
......@@ -48,11 +47,12 @@ const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
std::unordered_map<std::string,
std::unordered_set<details::ComputationOpHandle *>>;
const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";
VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars);
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -19,19 +19,19 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
// Find all while_op and while_grad_op
std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
std::vector<OperatorBase *>>>
target_ops;
for (auto *op : all_ops) {
auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
if (compute_op->Name() == "while") {
......@@ -52,9 +52,9 @@ class WhileOpEagerDeletionPass : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(while_op_eager_deletion_pass,
paddle::framework::details::WhileOpEagerDeletionPass);
paddle::framework::ir::WhileOpEagerDeletionPass);
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
set(ALL_REDUCE_OP_HANDLES all_reduce_op_handle)
if(WITH_GPU AND WITH_DGC)
list(APPEND ALL_REDUCE_OP_HANDLES sparse_all_reduce_op_handle)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......@@ -23,7 +23,6 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -31,17 +30,18 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class AllReduceDepsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
std::vector<AllReduceOpHandle*> all_reduce_op_handles =
std::vector<details::AllReduceOpHandle*> all_reduce_op_handles =
GetSortedAllReduceOps(*graph);
for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
auto* dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dep_var);
all_reduce_op_handles[i - 1]->AddOutput(dep_var);
all_reduce_op_handles[i]->AddInput(dep_var);
}
......@@ -51,16 +51,16 @@ class AllReduceDepsPass : public ir::Pass {
}
}
std::vector<AllReduceOpHandle*> GetSortedAllReduceOps(
std::vector<details::AllReduceOpHandle*> GetSortedAllReduceOps(
const ir::Graph& graph) const {
std::vector<AllReduceOpHandle*> all_reduce_op_handles;
std::unordered_map<OpHandleBase*, size_t> pending_ops;
std::unordered_set<OpHandleBase*> ready_ops;
std::unordered_set<OpHandleBase*> next_ready_ops;
std::vector<details::AllReduceOpHandle*> all_reduce_op_handles;
std::unordered_map<details::OpHandleBase*, size_t> pending_ops;
std::unordered_set<details::OpHandleBase*> ready_ops;
std::unordered_set<details::OpHandleBase*> next_ready_ops;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph);
auto op_handles = ir::FilterByNodeWrapper<details::OpHandleBase>(graph);
size_t num_of_ops = op_handles.size();
for (OpHandleBase* op : op_handles) {
for (details::OpHandleBase* op : op_handles) {
size_t not_ready_vars = op->NotReadyInputSize();
if (not_ready_vars) {
pending_ops.insert({op, not_ready_vars});
......@@ -94,11 +94,12 @@ class AllReduceDepsPass : public ir::Pass {
}
void GetSortedAllReduceOps(
const std::unordered_set<OpHandleBase*>& ready_ops,
std::vector<AllReduceOpHandle*>* all_reduce_op_handles) const {
std::vector<AllReduceOpHandle*> current_all_reduce_op_handles;
const std::unordered_set<details::OpHandleBase*>& ready_ops,
std::vector<details::AllReduceOpHandle*>* all_reduce_op_handles) const {
std::vector<details::AllReduceOpHandle*> current_all_reduce_op_handles;
for (auto& op_handle : ready_ops) {
auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle*>(op_handle);
auto all_reduce_op_handle =
dynamic_cast<details::AllReduceOpHandle*>(op_handle);
if (all_reduce_op_handle) {
current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
}
......@@ -109,10 +110,12 @@ class AllReduceDepsPass : public ir::Pass {
// Sort the current_all_reduce_op_handles according to the name of input.
sort(current_all_reduce_op_handles.begin(),
current_all_reduce_op_handles.end(),
[](const AllReduceOpHandle* left,
const AllReduceOpHandle* right) -> bool {
auto left_in_vars = DynamicCast<VarHandle>(left->Inputs());
auto right_in_vars = DynamicCast<VarHandle>(right->Inputs());
[](const details::AllReduceOpHandle* left,
const details::AllReduceOpHandle* right) -> bool {
auto left_in_vars =
details::DynamicCast<details::VarHandle>(left->Inputs());
auto right_in_vars =
details::DynamicCast<details::VarHandle>(right->Inputs());
PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
return left_in_vars[0]->Name() > right_in_vars[0]->Name();
......@@ -123,15 +126,15 @@ class AllReduceDepsPass : public ir::Pass {
current_all_reduce_op_handles.end());
}
void DebugString(
const ir::Graph& graph,
const std::vector<AllReduceOpHandle*>& all_reduce_op_handles) const {
void DebugString(const ir::Graph& graph,
const std::vector<details::AllReduceOpHandle*>&
all_reduce_op_handles) const {
// get vars order
std::map<int, std::vector<std::string>> vars =
GetSoredGradientsFromStaleProgram(graph);
std::stringstream out;
size_t grads_of_stale_program = 0;
out << "Get Order From kStaleProgramOpDescs: ";
out << "Get Order From details::kStaleProgramOpDescs: ";
for (auto& var : vars) {
out << "Order " << var.first << " [";
for (auto& var_name : var.second) {
......@@ -147,7 +150,7 @@ class AllReduceDepsPass : public ir::Pass {
for (auto& op : all_reduce_op_handles) {
bool find_valid_input = false;
for (auto& in_var : op->Inputs()) {
if (dynamic_cast<VarHandle*>(in_var)) {
if (dynamic_cast<details::VarHandle*>(in_var)) {
out2 << in_var->Name() << ", ";
find_valid_input = true;
break;
......@@ -165,7 +168,8 @@ class AllReduceDepsPass : public ir::Pass {
std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
const ir::Graph& graph) const {
std::map<int, std::vector<std::string>> vars;
auto ops = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
auto ops =
graph.Get<const std::vector<OpDesc*>>(details::kStaleProgramOpDescs);
int order = 0;
for (auto* op_desc : ops) {
try {
......@@ -193,10 +197,9 @@ class AllReduceDepsPass : public ir::Pass {
return vars;
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(all_reduce_deps_pass,
paddle::framework::details::AllReduceDepsPass)
REGISTER_PASS(all_reduce_deps_pass, paddle::framework::ir::AllReduceDepsPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
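What the chaining loop above does, stripped of the namespace churn: after sorting, each pair of neighbouring all-reduce ops is linked through a fresh dummy control variable, so the ops can only execute in the sorted order. A minimal, self-contained sketch of that chaining pattern follows; the Op and Var structs here are illustrative stand-ins, not the Paddle handle classes.

#include <memory>
#include <string>
#include <vector>

// Illustrative stand-ins for OpHandleBase / DummyVarHandle (assumed, not Paddle's API).
struct Var { std::string name; };
struct Op {
  std::vector<Var*> inputs, outputs;
  void AddInput(Var* v) { inputs.push_back(v); }
  void AddOutput(Var* v) { outputs.push_back(v); }
};

// Chain already-sorted ops so ops[i] cannot start before ops[i - 1] has finished:
// ops[i - 1] produces a control token that ops[i] consumes.
void ChainWithControlDeps(const std::vector<Op*>& ops,
                          std::vector<std::unique_ptr<Var>>* dep_vars) {
  for (size_t i = 1; i < ops.size(); ++i) {
    dep_vars->push_back(
        std::unique_ptr<Var>(new Var{"control_dep_" + std::to_string(i)}));
    Var* dep = dep_vars->back().get();
    ops[i - 1]->AddOutput(dep);
    ops[i]->AddInput(dep);
  }
}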
......@@ -24,21 +24,22 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class FuseAllReduceOpPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_ctxs = &Get<platform::NCCLContextMap>(kNCCLCtxs);
auto *nccl_ctxs = &Get<platform::NCCLContextMap>(details::kNCCLCtxs);
#endif
std::unordered_set<std::string> grads;
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
auto &params_grads =
result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
size_t num_of_all_reduce = params_grads.size();
grads.reserve(num_of_all_reduce);
for (auto p_g : params_grads) {
......@@ -50,11 +51,12 @@ class FuseAllReduceOpPass : public ir::Pass {
all_reduce_ops.reserve(grads.size());
for (auto &node : result.Nodes()) {
if (node->IsOp()) {
PADDLE_ENFORCE(node->IsWrappedBy<OpHandleBase>());
auto *all_reduce_op_handle =
dynamic_cast<AllReduceOpHandle *>(&node->Wrapper<OpHandleBase>());
PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
auto *all_reduce_op_handle = dynamic_cast<details::AllReduceOpHandle *>(
&node->Wrapper<details::OpHandleBase>());
if (all_reduce_op_handle) {
auto inputs = DynamicCast<VarHandle>(all_reduce_op_handle->Inputs());
auto inputs = details::DynamicCast<details::VarHandle>(
all_reduce_op_handle->Inputs());
PADDLE_ENFORCE_EQ(inputs.size(), num_place);
// The inputs' name should be the same.
auto &grad_name = inputs[0]->name();
......@@ -80,7 +82,7 @@ class FuseAllReduceOpPass : public ir::Pass {
VLOG(10) << "Insert fused_all_reduce";
auto &group_grads_params =
graph->Get<GroupGradsAndParams>(kGroupGradsAndParams);
graph->Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
for (auto &group_g_p : group_grads_params) {
size_t group_size = group_g_p.size();
......@@ -108,24 +110,25 @@ class FuseAllReduceOpPass : public ir::Pass {
const platform::NCCLContextMap *nccl_ctxs,
#endif
ir::Graph *result) const {
std::vector<VarHandleBase *> inputs;
std::vector<VarHandleBase *> outputs;
std::vector<details::VarHandleBase *> inputs;
std::vector<details::VarHandleBase *> outputs;
for (auto &op : all_reduce_ops) {
auto &op_handle = op->Wrapper<OpHandleBase>();
auto &op_handle = op->Wrapper<details::OpHandleBase>();
inputs.insert(inputs.end(), op_handle.Inputs().begin(),
op_handle.Inputs().end());
// Remove output
for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
[&op_handle](VarHandleBase *var_handle) {
[&op_handle](details::VarHandleBase *var_handle) {
var_handle->RemoveOutput(&op_handle, op_handle.Node());
});
outputs.insert(outputs.end(), op_handle.Outputs().begin(),
op_handle.Outputs().end());
// Remove Input
for_each(
op_handle.Outputs().begin(), op_handle.Outputs().end(),
[](VarHandleBase *var_handle) { var_handle->ClearGeneratedOp(); });
for_each(op_handle.Outputs().begin(), op_handle.Outputs().end(),
[](details::VarHandleBase *var_handle) {
var_handle->ClearGeneratedOp();
});
result->RemoveNode(op_handle.Node());
}
......@@ -140,8 +143,9 @@ class FuseAllReduceOpPass : public ir::Pass {
}
private:
void CreateFusedAllReduceOp(const std::vector<VarHandleBase *> &inputs,
const std::vector<VarHandleBase *> &outputs,
void CreateFusedAllReduceOp(
const std::vector<details::VarHandleBase *> &inputs,
const std::vector<details::VarHandleBase *> &outputs,
const size_t num_of_all_reduce,
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
......@@ -150,11 +154,11 @@ class FuseAllReduceOpPass : public ir::Pass {
#endif
ir::Graph *result) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *op_handle = new FusedAllReduceOpHandle(
auto *op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce, nccl_ctxs);
#else
auto *op_handle = new FusedAllReduceOpHandle(
auto *op_handle = new details::FusedAllReduceOpHandle(
result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
local_scopes, places, num_of_all_reduce);
#endif
......@@ -176,8 +180,9 @@ class FuseAllReduceOpPass : public ir::Pass {
#endif
}
void SetCommunicationContext(const std::vector<platform::Place> &places,
FusedAllReduceOpHandle *op_handle) const {
void SetCommunicationContext(
const std::vector<platform::Place> &places,
details::FusedAllReduceOpHandle *op_handle) const {
for (size_t i = 0; i < places.size(); ++i) {
op_handle->SetDeviceContext(
places[i], platform::DeviceContextPool::Instance().Get(places[i]));
......@@ -185,9 +190,9 @@ class FuseAllReduceOpPass : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_all_reduce_op_pass,
paddle::framework::details::FuseAllReduceOpPass);
paddle::framework::ir::FuseAllReduceOpPass);
......@@ -12,21 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) {
details::ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
......@@ -34,11 +33,13 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
return true;
}
void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
......@@ -48,11 +49,11 @@ void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
<< compute_op->DebugString();
}
}
}
} // namespace details
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
paddle::framework::ir::ModifyOpLockAndRecordEventPass);
......@@ -19,7 +19,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
......@@ -28,19 +28,19 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
bool IsValidGraph(const ir::Graph *graph) const {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars;
std::unordered_set<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
std::unordered_map<details::OpHandleBase *, size_t> pending_ops;
std::unordered_set<details::VarHandleBase *> pending_vars;
std::unordered_set<details::VarHandleBase *> ready_vars;
std::unordered_set<details::OpHandleBase *> ready_ops;
auto insert_pending_var = [&](VarHandleBase *var) {
auto insert_pending_var = [&](details::VarHandleBase *var) {
pending_vars.insert(var);
if (var->GeneratedOp() == nullptr) {
ready_vars.emplace(var);
}
};
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &var_map : graph->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair);
......@@ -48,11 +48,12 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
}
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
for (auto &var :
graph->Get<details::GraphDepVars>(details::kGraphDepVars)) {
insert_pending_var(var);
}
for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
for (auto *op : ir::FilterByNodeWrapper<details::OpHandleBase>(*graph)) {
if (op->Inputs().empty()) {
ready_ops.insert(op);
} else {
......@@ -60,7 +61,7 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
}
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
auto run_all_ops = [&](std::unordered_set<details::OpHandleBase *> &set) {
for (auto *op : set) {
for (auto out : op->Outputs()) {
ready_vars.emplace(out);
......@@ -91,11 +92,11 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker)
paddle::framework::ir::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars);
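The IsValidGraph check above is a readiness-propagation sweep split across several hunks; as a self-contained reference, a minimal sketch of the same idea over a plain adjacency list (illustrative types, not the Paddle var/op handles) looks like this:

#include <cstddef>
#include <queue>
#include <vector>

// Returns true if every op can eventually be scheduled, i.e. the dependency
// graph is acyclic and no op waits on an input that is never produced.
// deps[i] lists the ops that depend on op i; in_degree[i] is the number of
// not-yet-ready inputs of op i.
bool AllOpsSchedulable(const std::vector<std::vector<int>>& deps,
                       std::vector<int> in_degree) {
  std::queue<int> ready;
  for (size_t i = 0; i < in_degree.size(); ++i)
    if (in_degree[i] == 0) ready.push(static_cast<int>(i));
  size_t visited = 0;
  while (!ready.empty()) {
    int op = ready.front();
    ready.pop();
    ++visited;
    for (int next : deps[op])
      if (--in_degree[next] == 0) ready.push(next);
  }
  return visited == in_degree.size();  // false => cycle or unreachable op
}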
......@@ -31,7 +31,7 @@ class NCCLContextMap;
namespace framework {
class Scope;
namespace details {
namespace ir {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";
......@@ -69,7 +69,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
ir::Node *out_var_node, size_t loss_scale,
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
details::VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
size_t dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
......@@ -89,7 +89,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
void SetCommunicationContext(OpHandleBase *op_handle,
void SetCommunicationContext(details::OpHandleBase *op_handle,
const platform::Place &p) const;
void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
......@@ -103,7 +103,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_;
mutable BuildStrategy strategy_;
mutable details::BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
};
......@@ -209,6 +209,6 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
......@@ -21,11 +21,21 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class SSAGraghBuilderWithPrinterPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
}
};
template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &each : graph.Get<details::GraphVars>(details::kGraphVars)) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
......@@ -33,7 +43,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
for (auto &var : graph.Get<details::GraphDepVars>(details::kGraphDepVars)) {
callback(*var);
}
}
......@@ -41,14 +51,14 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;
std::unordered_map<const details::VarHandleBase *, size_t> vars;
sout << "digraph G {\n";
IterAllVar(graph, [&](const VarHandleBase &var) {
IterAllVar(graph, [&](const details::VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
auto *var_handle_ptr = dynamic_cast<const details::VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const details::DummyVarHandle *>(var_ptr);
size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;
......@@ -65,7 +75,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
for (auto &op : ir::FilterByNodeWrapper<details::OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
......@@ -82,10 +92,10 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
sout << "}\n";
}
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_devices_print_pass,
paddle::framework::details::SSAGraghBuilderWithPrinter)
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
paddle::framework::ir::SSAGraghBuilderWithPrinterPass)
.RequirePassAttr(paddle::framework::ir::kGraphvizPath);
......@@ -24,7 +24,7 @@
namespace paddle {
namespace framework {
namespace details {
namespace ir {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
......@@ -39,16 +39,6 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
void Print(const ir::Graph& graph, std::ostream& sout) const override;
};
class SSAGraghBuilderWithPrinter : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
}
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
class SequentialExecutionPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
// FIXME(zjl): Insert dependencies between some distributed ops may cause
// the multi_devices_graph_pass fails. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually, which may cause deadlock easily.
// We should add more skipped distributed ops when found errors in
// multi_devices_graph_pass
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops =
graph->Get<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s",
op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::ir::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
......@@ -25,7 +25,8 @@ namespace framework {
namespace ir {
void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::string op_type) {
const std::string& op_type,
const std::string& quant_type) {
const std::string pattern_name = "quant_dequant_fuse";
// FusePassBase::Init(pattern_name, graph);
const int kNumFields = 5;
......@@ -38,7 +39,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("fake_quantize_range_abs_max", "X")
->assert_is_op_input(quant_type, "X")
->AsInput();
std::string quantized_op_type = "";
......@@ -46,6 +47,9 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
if (op_type == "conv2d") {
quantized_op_type = "conv2d";
weight_name = "Filter";
} else if (op_type == "depthwise_conv2d") {
quantized_op_type = "depthwise_conv2d";
weight_name = "Filter";
} else if (op_type == "conv2d_fusion") {
quantized_op_type = "conv2d_fusion";
weight_name = "Filter";
......@@ -62,7 +66,7 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
}
patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, quantized_op_type, weight_name, times);
pattern(x, quantized_op_type, weight_name, times, quant_type);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......@@ -103,7 +107,6 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
std::unordered_set<const Node*> delete_nodes;
for (int i = 0; i < times; i++) {
// max_range = (range * range) / weight_scale
float max_range = boost::get<float>(
nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr("max_range"));
float weight_scale = (range * range) / max_range;
......@@ -118,7 +121,8 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
new_op_desc.SetType(quantized_op_type);
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion") {
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Output", {new_output});
} else if (quantized_op_type == "fc") {
......@@ -156,11 +160,17 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "quant_dequant_fuse";
FusePassBase::Init(pattern_name, graph);
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul"};
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {"conv2d", "mul",
"depthwise_conv2d"};
auto* scope = param_scope();
for (auto& quant_type : quant_types) {
for (auto& op_type : quantized_op_types) {
for (int i = 1; i <= 6; i++) {
RunQuantDequant(graph, scope, i, op_type);
for (int i = 6; i >= 1; i--) {
RunQuantDequant(graph, scope, i, op_type, quant_type);
}
}
}
}
......
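The restored comment in the hunk above records the relation max_range = (range * range) / weight_scale, which the pass inverts to recover the weight scale. A small hedged sketch of that arithmetic follows; the default bit length of 8 (so range = 127) is an assumption for illustration, since the pass itself takes range from earlier in the file, outside this excerpt.

#include <cassert>

// Recover the per-weight scale stored implicitly in the dequant op's
// "max_range" attribute. bit_length = 8 (range = 127) is assumed here for
// illustration; the fuse pass reads the actual range elsewhere in the file.
float RecoverWeightScale(float max_range, int bit_length = 8) {
  const float range = static_cast<float>((1 << (bit_length - 1)) - 1);
  assert(max_range > 0.0f);
  return (range * range) / max_range;  // inverts max_range = range^2 / scale
}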
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(reshape1_op); \
GET_IR_NODE(reshape1_out); \
GET_IR_NODE(transpose_op); \
GET_IR_NODE(transpose_out); \
GET_IR_NODE(reshape2_op); \
GET_IR_NODE(reshape2_out);
void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "shufflechannel_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("reshape2", "X")
->AsInput();
patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
PADDLE_ENFORCE(subgraph.count(x));
auto* input_node = subgraph.at(x);
auto reshape1_desc = reshape1_op->Op();
auto reshape2_desc = reshape2_op->Op();
std::string input_name = input_node->Name();
std::string output_name = reshape2_out->Name();
auto reshape1_shape =
boost::get<std::vector<int>>(reshape1_desc->GetAttr("shape"));
auto reshape2_shape =
boost::get<std::vector<int>>(reshape2_desc->GetAttr("shape"));
int i_c = reshape1_shape[2];
int o_c = reshape2_shape[1];
int group = o_c / i_c;
framework::OpDesc new_op_desc;
new_op_desc.SetType("shuffle_channel");
new_op_desc.SetInput("X", {input_name});
new_op_desc.SetOutput("Out", {output_name});
new_op_desc.SetAttr("group", group);
new_op_desc.Flush();
// Create a new node for the fused op.
auto* new_op = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(input_node, new_op);
IR_NODE_LINK_TO(new_op, reshape2_out);
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op,
transpose_out, reshape2_op});
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(shuffle_channel_detect_pass,
paddle::framework::ir::ShuffleChannelDetectPass);
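For reference, the reshape→transpose→reshape chain the detector matches is equivalent to a single channel shuffle with group = o_c / i_c as computed in the handler above. The sketch below spells out that equivalence on NCHW data; the index convention is inferred from the reshape shapes used in the handler and is an illustration, not Paddle's shuffle_channel kernel.

#include <vector>

// Reference semantics: reshape C into (group, C/group), swap those two axes,
// and flatten back. per_group corresponds to i_c in the pass above.
std::vector<float> ShuffleChannel(const std::vector<float>& in,
                                  int n, int c, int h, int w, int group) {
  std::vector<float> out(in.size());
  const int per_group = c / group;
  const int hw = h * w;
  for (int b = 0; b < n; ++b) {
    for (int oc = 0; oc < c; ++oc) {
      const int ic = (oc % group) * per_group + oc / group;  // permuted source channel
      for (int i = 0; i < hw; ++i) {
        out[(b * c + oc) * hw + i] = in[(b * c + ic) * hw + i];
      }
    }
  }
  return out;
}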
......@@ -13,19 +13,22 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace details {
namespace ir {
class ShuffleChannelDetectPass : public FusePassBase {
public:
virtual ~ShuffleChannelDetectPass() {}
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const {
class SyncBatchNormPass : public Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override {
VLOG(3) << "Use synchronous batch norm";
for (const Node* n : graph->Nodes()) {
for (const Node *n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
auto *op = n->Op();
if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm");
}
......@@ -34,8 +36,8 @@ void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const {
}
}
}
}
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
......
......@@ -1148,7 +1148,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == data_type || data_type == dafault_data_type,
"DataType of Paddle Op %s %s must be the same. Get (%d) != (%d)",
"DataType of Paddle Op %s %s must be the same. Get (%s) != (%s)",
Type(), input.first, DataTypeToString(data_type),
DataTypeToString(tmp));
data_type = tmp;
......
......@@ -386,7 +386,8 @@ class ExecutionContext {
template <typename T>
T& GetKernelConfig(int idx) const {
PADDLE_ENFORCE(kernel_configs_ && kernel_configs_->size() > idx,
PADDLE_ENFORCE(
kernel_configs_ && kernel_configs_->size() > static_cast<size_t>(idx),
"%s selected kernel doesn't have kernel config %lu <= %d",
op_.Type().c_str(), kernel_configs_->size(), idx);
return *boost::get<std::shared_ptr<T>>(kernel_configs_->at(idx));
......
......@@ -23,11 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef WITH_GPERFTOOLS
......@@ -46,6 +46,7 @@ static std::once_flag gProfileOnce;
#ifdef WITH_GPERFTOOLS
static bool gProfileStarted = false;
#endif
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
......@@ -110,9 +111,9 @@ class ParallelExecutorPrivate {
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// then keeps unchanged
// Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
std::vector<details::ReferenceCountMap> global_ref_cnts_;
std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
details::GarbageCollectorMap gcs_;
std::vector<ir::ReferenceCountMap> global_ref_cnts_;
std::vector<ir::AtomicReferenceCountMap> runtime_ref_cnts_;
ir::GarbageCollectorMap gcs_;
};
ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
......@@ -150,25 +151,23 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
}
if (!gcs_.empty()) {
std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
&global_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
ref_cnt_pass->SetNotOwned(ir::kGlobalReferenceCount, &global_ref_cnts_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(graph);
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
eager_deletion_pass->SetNotOwned(ir::kRuntimeReferenceCount,
&runtime_ref_cnts_);
eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
eager_deletion_pass->SetNotOwned(ir::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(graph);
VLOG(10) << "EagerDeletionPass Applied";
}
......@@ -179,6 +178,20 @@ std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
return member_->local_scopes_;
}
void ParallelExecutor::DropLocalExeScopes() {
auto executor = dynamic_cast<details::ScopeBufferedSSAGraphExecutor *>(
member_->executor_.get());
if (executor) {
executor->DropLocalExeScopes();
}
}
bool ParallelExecutor::NeedCreateLocalExeScope() {
auto executor = dynamic_cast<details::ScopeBufferedSSAGraphExecutor *>(
member_->executor_.get());
return executor && executor->NeedCreateLocalExeScope();
}
ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
const std::vector<std::string> &bcast_vars,
const std::string &loss_var_name,
......@@ -333,7 +346,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, 1,
member_->use_cuda_);
for (int i = 1; i < member_->places_.size(); ++i) {
for (size_t i = 1; i < member_->places_.size(); ++i) {
graphs[i] = build_strategy.Apply(
graphs[i], {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1, member_->use_cuda_);
......@@ -344,8 +357,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
member_->local_scopes_, member_->nranks_,
member_->use_cuda_);
}
#endif
auto max_memory_size = GetEagerDeletionThreshold();
VLOG(10) << "Eager Deletion Threshold "
<< static_cast<float>(max_memory_size) / (1 << 30);
......
......@@ -58,6 +58,11 @@ class ParallelExecutor {
std::vector<Scope *> &GetLocalScopes();
void DropLocalExeScopes();
// This API is used to check whether DropLocalExeScopes work.
bool NeedCreateLocalExeScope();
/**
* Feed tensors to local scopes. The size of tensors should be equal to the
* size of local scopes.
......
......@@ -25,8 +25,9 @@ inline const T* Tensor::data() const {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d",
DataTypeToString(type_));
PADDLE_ENFORCE(
valid, "Tensor holds the wrong type, it holds %s, but desires to be %s",
DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType));
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
......@@ -39,7 +40,9 @@ inline T* Tensor::data() {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
PADDLE_ENFORCE(
valid, "Tensor holds the wrong type, it holds %s, but desires to be %s",
DataTypeToString(type_), DataTypeToString(DataTypeTrait<T>::DataType));
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
......
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc DEPS anakin_engine framework_proto scope op_registry)
cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc
roi_align.cc shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto
scope op_registry gtest)
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL)
cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
......@@ -14,5 +19,5 @@ cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter fla
cc_test(test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op SERIAL)
cc_test(test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op SERIAL)
cc_test(test_anakin_dropout SRCS test_dropout_op.cc DEPS anakin_op_converter dropout_op SERIAL)
#cc_test(test_anakin_im2sequence SRCS test_im2sequence_op.cc DEPS anakin_op_converter im2sequence_op im2col)
cc_test(test_anakin_sum SRCS test_sum_op.cc DEPS anakin_op_converter sum_op selected_rows_functor SERIAL)
cc_test(test_anakin_affine_channel SRCS test_affine_channel_op.cc DEPS anakin_op_converter affine_channel_op SERIAL)
......@@ -16,16 +16,13 @@
#include <algorithm>
#include <map>
using anakin::graph::GraphGlobalMem;
using anakin::AK_FLOAT;
using anakin::saber::NV;
using anakin::saber::Shape;
namespace paddle {
namespace inference {
namespace anakin {
ActivationOpConverter::ActivationOpConverter(const std::string &op_type)
template <typename TargetT, ::anakin::Precision PrecisionT>
ActivationOpConverter<TargetT, PrecisionT>::ActivationOpConverter(
const std::string &op_type)
: op_type_(op_type) {
auto it = anakin_op_types_.find(op_type_);
PADDLE_ENFORCE(it != anakin_op_types_.end(),
......@@ -33,10 +30,10 @@ ActivationOpConverter::ActivationOpConverter(const std::string &op_type)
anakin_op_type_ = it->second;
}
void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::BlockDesc &block_desc,
const framework::Scope &scope,
bool test_mode) {
template <typename TargetT, ::anakin::Precision PrecisionT>
void ActivationOpConverter<TargetT, PrecisionT>::operator()(
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
const framework::Scope &scope, bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
......@@ -44,8 +41,17 @@ void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
auto input_name = op_desc.Input("X").front();
auto output_name = op_desc.Output("Out").front();
engine_->AddOp(op_name, "Activation", {input_name}, {output_name});
engine_->AddOpAttr(op_name, "type", anakin_op_type_);
this->engine_->AddOp(op_name, "Activation", {input_name}, {output_name});
this->engine_->AddOpAttr(op_name, "type", anakin_op_type_);
if (op_type_ == "swish") {
float beta = boost::get<float>(op_desc.GetAttr("beta"));
this->engine_->AddOpAttr(op_name, "clip_relu_num", beta);
}
if (op_type_ == "relu6") {
float threshold = boost::get<float>(op_desc.GetAttr("threshold"));
this->engine_->AddOpAttr(op_name, "clip_relu_num", threshold);
}
}
} // namespace anakin
......@@ -54,3 +60,5 @@ void ActivationOpConverter::operator()(const framework::proto::OpDesc &op,
REGISTER_ANAKIN_OP_CONVERTER(sigmoid, SigmoidOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(tanh, TanhOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(swish, SwishOpConverter);
REGISTER_ANAKIN_OP_CONVERTER(relu6, Relu6OpConverter);
......@@ -22,7 +22,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ActivationOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ActivationOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
explicit ActivationOpConverter(const std::string &op_type);
......@@ -36,18 +37,36 @@ class ActivationOpConverter : public AnakinOpConverter {
std::string op_type_;
std::string anakin_op_type_;
std::map<std::string, std::string> anakin_op_types_{{"tanh", "TanH"},
{"sigmoid", "Sigmoid"}};
{"sigmoid", "Sigmoid"},
{"relu6", "ClippedRelu"},
{"swish", "Swish"}};
};
class TanhOpConverter : public ActivationOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class TanhOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
TanhOpConverter() : ActivationOpConverter("tanh") {}
TanhOpConverter() : ActivationOpConverter<TargetT, PrecisionT>("tanh") {}
};
class SigmoidOpConverter : public ActivationOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class SigmoidOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
SigmoidOpConverter() : ActivationOpConverter("sigmoid") {}
SigmoidOpConverter()
: ActivationOpConverter<TargetT, PrecisionT>("sigmoid") {}
};
template <typename TargetT, ::anakin::Precision PrecisionT>
class Relu6OpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
Relu6OpConverter() : ActivationOpConverter<TargetT, PrecisionT>("relu6") {}
};
template <typename TargetT, ::anakin::Precision PrecisionT>
class SwishOpConverter : public ActivationOpConverter<TargetT, PrecisionT> {
public:
SwishOpConverter() : ActivationOpConverter<TargetT, PrecisionT>("swish") {}
};
} // namespace anakin
} // namespace inference
} // namespace paddle
(2 collapsed file diffs omitted.)
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class BatchNormOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class BatchNormOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
BatchNormOpConverter() = default;
......
......@@ -20,7 +20,8 @@ namespace paddle {
namespace inference {
namespace anakin {
class ConcatOpConverter : public AnakinOpConverter {
template <typename TargetT, ::anakin::Precision PrecisionT>
class ConcatOpConverter : public AnakinOpConverter<TargetT, PrecisionT> {
public:
ConcatOpConverter() = default;
......
(43 collapsed file diffs omitted.)