提交 d41b623a 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into quan_ck

test=develop
......@@ -193,6 +193,12 @@ if(WITH_GPU)
include(tensorrt)
include(anakin_subgraph)
endif()
# Pull in the dgc external project (external/dgc). It is only included for
# GPU builds on non-Windows platforms.
if(WITH_GPU AND NOT WIN32)
message(STATUS "add dgc lib.")
include(external/dgc)
endif()
if(WITH_MKL OR WITH_MKLML)
include(external/anakin)
elseif()
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Builds the dgc library from the PaddlePaddle/Fleet repository as an external
# project and exposes it to the rest of the build as the imported target `dgc`.
#
# Variables defined here:
#   DGC_SOURCES_DIR  - checkout/build location of the Fleet sources
#   DGC_INSTALL_DIR  - install prefix for the dgc artifacts
#   DGC_INCLUDE_DIR  - (cache) include root; the header lands in <root>/dgc/dgc.h
#   DGC_LIBRARIES    - (cache) full path of the installed static archive
INCLUDE(ExternalProject)

SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc")
SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc")
SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE)
SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE)

# The include directory does not exist until extern_dgc has been installed, so
# it is added at directory scope rather than as an INTERFACE property of the
# imported target (CMake rejects non-existent interface include paths).
INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})

ExternalProject_Add(
    extern_dgc
    ${EXTERNAL_PROJECT_LOG_ARGS}
    GIT_REPOSITORY "https://github.com/PaddlePaddle/Fleet"
    GIT_TAG "2d04dc3800cdd0601f1b65d547dabcc60b0cf9dc"
    SOURCE_DIR "${DGC_SOURCES_DIR}"
    # The upstream project has no configure step; it is built with plain make
    # from its `collective` sub-directory.
    CONFIGURE_COMMAND ""
    BUILD_COMMAND cd collective && make -j
    # Install by hand: create the destination directories, then copy the
    # archive and the public header. Each step is a separate COMMAND so the
    # install does not rely on shell operators being passed through.
    INSTALL_COMMAND ${CMAKE_COMMAND} -E make_directory ${DGC_INSTALL_DIR}/lib ${DGC_INCLUDE_DIR}/dgc
    COMMAND ${CMAKE_COMMAND} -E copy ${DGC_SOURCES_DIR}/collective/build/lib/libdgc.a ${DGC_LIBRARIES}
    COMMAND ${CMAKE_COMMAND} -E copy ${DGC_SOURCES_DIR}/collective/build/include/dgc.h ${DGC_INCLUDE_DIR}/dgc/dgc.h
    BUILD_IN_SOURCE 1
)

# libdgc.a is a static archive, so the imported target must be STATIC; declaring
# it SHARED would make CMake treat the archive as a shared object.
ADD_LIBRARY(dgc STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION "${DGC_LIBRARIES}")
# Make sure the archive has been built and installed before anything uses `dgc`.
ADD_DEPENDENCIES(dgc extern_dgc)

LIST(APPEND external_project_dependencies dgc)
......@@ -57,20 +57,25 @@ SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
ExternalProject_Add(
${NGRAPH_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_PROJECT} ${MKLML_PROJECT}
GIT_REPOSITORY ${NGRAPH_GIT_REPO}
GIT_TAG ${NGRAPH_GIT_TAG}
PREFIX ${NGRAPH_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${NGRAPH_INSTALL_DIR}
CMAKE_ARGS -DNGRAPH_UNIT_TEST_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_TOOLS_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_INTERPRETER_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_DEX_ONLY=TRUE
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
CMAKE_ARGS -DMKLDNN_INCLUDE_DIR=${MKLDNN_INC_DIR}
CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}
CMAKE_ARGS -DMKLML_LIB_DIR=${MKLML_INSTALL_DIR}/lib
DEPENDS ${MKLDNN_PROJECT} ${MKLML_PROJECT}
GIT_REPOSITORY ${NGRAPH_GIT_REPO}
GIT_TAG ${NGRAPH_GIT_TAG}
PREFIX ${NGRAPH_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_GENERATOR ${CMAKE_GENERATOR}
CMAKE_GENERATOR_PLATFORM ${CMAKE_GENERATOR_PLATFORM}
CMAKE_GENERATOR_TOOLSET ${CMAKE_GENERATOR_TOOLSET}
CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${NGRAPH_INSTALL_DIR}
CMAKE_ARGS -DNGRAPH_UNIT_TEST_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_TOOLS_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_INTERPRETER_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_DEX_ONLY=TRUE
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
CMAKE_ARGS -DMKLDNN_INCLUDE_DIR=${MKLDNN_INC_DIR}
CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}
CMAKE_ARGS -DMKLML_LIB_DIR=${MKLML_INSTALL_DIR}/lib
)
add_dependencies(ngraph ${NGRAPH_PROJECT})
......
......@@ -131,6 +131,15 @@ elseif (NOT CBLAS_FOUND OR WIN32)
)
endif ()
# Copy the installed dgc `lib` and `include` directories into the fluid install
# tree under third_party/install/dgc. Guarded by the same condition under which
# the dgc external project is included (GPU build, not Windows).
if (WITH_GPU AND NOT WIN32)
set(dgc_dir "${FLUID_INSTALL_DIR}/third_party/install/dgc")
copy(dgc_lib
SRCS ${DGC_INSTALL_DIR}/lib ${DGC_INSTALL_DIR}/include
DSTS ${dgc_dir} ${dgc_dir}
DEPS dgc)
endif()
if (WITH_MKLDNN)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mkldnn")
copy(mkldnn_lib
......
......@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -483,6 +483,11 @@ paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['sel
paddle.fluid.optimizer.LarsMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.LarsMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.optimizer.DGCMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'rampup_begin_step', 'rampup_step', 'sparsity', 'use_nesterov', 'local_grad_clip_norm', 'num_trainers', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1, [0.999], False, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.DGCMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '1a79bd7d10ae54ca763ec81bca36ba24'))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......
......@@ -134,6 +134,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
out_layout =
out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(
pool.Get(expected_kernel_type.place_));
auto& cpu_engine = dev_ctx->GetEngine();
std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
std::vector<int> out_tz = in_tz;
......@@ -142,25 +147,29 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
// output tensor has the same dims as input. Reorder don't change dims
out->Resize(in.dims());
// tempory mem pd fr out , to make reorder
auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(out->dims()),
mkldnn::memory::format::blocked, out_type);
if (in.get_mkldnn_prim_desc() != out_mem_pd) {
if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type);
auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
auto in_memory = memory(in.get_mkldnn_prim_desc(), in_data);
auto out_memory = memory(out_mem_pd, out_data);
auto in_memory =
memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
platform::Reorder(in_memory, out_memory);
} else {
out->ShareDataWith(in);
}
out->set_layout(out_layout);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(memory::format::format_undef);
#endif
}
......
......@@ -51,31 +51,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
#ifdef PADDLE_WITH_MKLDNN
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
ToMKLDNNFormat(lin));
out.ShareDataWith(input_tensor);
// TODO(jczaja): Remove that once all mkldnn ops
// are modified to work with mkldnn_blocked
auto mkldnn_fmt = [&](int rank) {
switch (rank) {
case 5:
return mkldnn::memory::format::ncdhw;
case 4:
return mkldnn::memory::format::nchw;
case 3:
return mkldnn::memory::format::ncw;
case 2:
return mkldnn::memory::format::nc;
case 1:
return mkldnn::memory::format::x;
default:
return mkldnn::memory::format::blocked;
}
};
auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(out.dims()),
mkldnn_fmt(out.dims().size()));
out.set_mkldnn_prim_desc(out_mem_pd);
out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format);
#endif
} else {
// Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel
......
......@@ -10,7 +10,10 @@ cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framewor
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
......@@ -23,7 +26,7 @@ endif()
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
dynload_cuda variable_visitor dgc)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
if(WITH_DISTRIBUTE)
......@@ -104,5 +107,7 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass)
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass)
......@@ -42,8 +42,7 @@ VarHandle* GetValidInput(const OpHandleBase* a) {
return nullptr;
}
std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// get vars order
......@@ -86,7 +85,8 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
}
}
VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl;
VLOG(10) << "dist_ops size:" << dist_ops.size()
<< ", outputs size:" << vars.size() << ", ops size:" << ops.size();
std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
OpHandleBase* op2) {
......@@ -99,6 +99,10 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
auto l_it = vars.find(i0->name());
auto r_it = vars.find(i1->name());
PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(),
"can't find var's name %s and %s in opdesc", i0->name(),
i1->name());
if (l_it->second < r_it->second) return true;
if (l_it->second == r_it->second) {
......@@ -126,8 +130,6 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
VLOG(10) << "pre_op:" << pre_op->DebugString()
<< ", op:" << op->DebugString();
}
return graph;
}
} // namespace details
......
......@@ -24,8 +24,7 @@ namespace details {
// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -16,6 +16,13 @@
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "dgc/dgc.h"
#endif
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
// asynchronous nccl allreduce or synchronous issue:
......@@ -33,11 +40,14 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
const platform::NCCLContextMap *ctxs,
bool is_encoded, int nranks)
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
nccl_ctxs_(ctxs),
is_encoded_(is_encoded),
nranks_(nranks) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
......@@ -51,7 +61,185 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// Runs the encoded (dgc) all-reduce path.
//
// For every local scope it looks up the encoded input tensor (variable named
// "<original grad name>" + g_dgc_encoded) and the output tensor named by the
// matching output var handle, queues one sparseAllGReduce call per place, runs
// the queued calls, and optionally synchronizes the NCCL streams when
// FLAGS_sync_nccl_allreduce is set.
void AllReduceOpHandle::RunImplEncoded() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
// Collect, per local scope, the encoded input tensor and the output tensor.
std::vector<const LoDTensor *> ins;
std::vector<LoDTensor *> outs;
// k is read once (from the first input handle's name) and reused for all
// places; -1 means "not read yet".
int k = -1;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &local_scope =
local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto original_name =
paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
auto encode_var_name = original_name + g_dgc_encoded;
auto *in_var = local_scope->FindVar(encode_var_name);
PADDLE_ENFORCE_NOT_NULL(in_var);
auto &in = in_var->Get<LoDTensor>();
ins.emplace_back(&in);
auto *out = local_scope->FindVar(out_var_handles[i]->name())
->GetMutable<LoDTensor>();
outs.emplace_back(out);
if (k < 0) {
k = GetKValue(in_var_handles[i]->name());
}
}
// Only the first input/output tensors are checked for GPU placement.
PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
// dtype / in_numel / out_numel are latched from the first tensor that sets
// them and are not re-read for later places.
int dtype = -1;
size_t in_numel = 0;
size_t out_numel = 0;
PADDLE_ENFORCE(nranks_ > 1);
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &place = places_[i];
auto &in = *ins[i];
void *in_tensor_buf = const_cast<void *>(in.data<void>());
auto &out = *outs[i];
float *out_tensor_buf = out.data<float>();
// dtype is only used in the log line below.
dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
// The encoded input must hold exactly 2 * k elements.
PADDLE_ENFORCE(in_numel % 2 == 0);
PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
int encode_size = 2 * k * sizeof(int);
// dgc uses ncclAllGather to collect the encoded data of every rank, so the
// gather buffer needs room for nranks_ encoded blocks.
int buf_size = nranks_ * encode_size;
// NOTE(review): tmp_ious_data goes out of scope at the end of this iteration
// while gather_buff is only used later by the queued call; presumably the
// temporary allocator keeps the memory valid until the stream has consumed
// it - confirm.
auto tmp_ious_data = allocator.Allocate(buf_size);
void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
<< ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
<< ", k:" << k << ", place:" << place << ", dtype:" << dtype;
// Defer the actual call; everything it needs is captured by value.
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
stream));
});
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device
all_reduce_calls[0]();
} else {
// Several places: issue all calls inside one NCCL group guard.
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
}
});
// Optional synchronous mode: wait for every place's NCCL stream and abort
// (LOG(FATAL)) on any CUDA error.
if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
}
cudaError_t e_get = cudaGetLastError();
if (e_get != 0) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno:" << e_get;
}
}
}
}
int AllReduceOpHandle::GetKValue(const std::string &grad_name) {
auto original_name = paddle::framework::GradOriginalVarName(grad_name);
auto var_name = original_name + g_dgc_k;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto var = local_scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
auto tensor = var->Get<LoDTensor>().data<float>();
return *tensor;
}
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// Tells whether the current run should take the encoded (dgc) all-reduce path.
//
// Returns false when the handle was not created with is_encoded, or when the
// step counter (variable g_dgc_counter_name) is still below the ramp-up begin
// step (variable g_dgc_rampup_begin_step). Both variables are looked up in the
// local execution scope of the first local scope and read as floats.
// Throws if either variable cannot be found.
bool AllReduceOpHandle::IsEncoded() {
  if (!is_encoded_) {
    return false;
  }
  auto counter_name = g_dgc_counter_name;
  auto step_name = g_dgc_rampup_begin_step;
  PADDLE_ENFORCE(local_scopes_.size() > 0);

  auto *scope = local_scopes_[0];
  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
  auto count_var = local_scope->FindVar(counter_name);
  auto step_var = local_scope->FindVar(step_name);
  if (count_var == nullptr || step_var == nullptr) {
    // Report the variable *names*; the previous code passed the step_var
    // pointer to the second %s, which printed a pointer instead of the name.
    PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
                 step_name);
  }

  float count = *count_var->Get<LoDTensor>().data<float>();
  float step = *step_var->Get<LoDTensor>().data<float>();
  // Compare as whole steps: both values are truncated to int first.
  if (static_cast<int>(count) < static_cast<int>(step)) {
    VLOG(10) << "in all_reduce currentstep:" << count
             << " < rampup_begin_step:" << step
             << " so not use sparse all reduce";
    return false;
  }
  return true;
}
#else
// Without CUDA (or on Windows) the encoded path is never available.
bool AllReduceOpHandle::IsEncoded() { return false; }
#endif
// Entry point of the handle: dispatches to the encoded path when IsEncoded()
// says so, otherwise to the normal all-reduce path.
void AllReduceOpHandle::RunImpl() {
  if (IsEncoded()) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    RunImplEncoded();
#else
    PADDLE_THROW("Not compiled with CUDA");
#endif
    return;
  }
  RunImplNormal();
}
void AllReduceOpHandle::RunImplNormal() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
......@@ -72,6 +260,8 @@ void AllReduceOpHandle::RunImpl() {
auto &lod_tensor =
local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor);
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
<< ", out_name:" << out_var_handles[i]->name();
PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal.");
}
......@@ -99,13 +289,17 @@ void AllReduceOpHandle::RunImpl() {
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
<< ", dev_id:" << dev_id << ", dtype:" << dtype
<< ", place:" << p;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
});
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device
......
......@@ -28,11 +28,19 @@ namespace paddle {
namespace framework {
namespace details {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
constexpr char g_dgc_encoded[] = "__dgc_encoded__";
constexpr char g_dgc_k[] = "__dgc_k__";
#endif
struct AllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
const platform::NCCLContextMap *ctxs,
bool is_encoded = false, int nranks = -1);
#else
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
......@@ -50,8 +58,14 @@ struct AllReduceOpHandle : public OpHandleBase {
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void RunImplEncoded();
const platform::NCCLContextMap *nccl_ctxs_;
bool is_encoded_{false};
int nranks_{-1};
int GetKValue(const std::string &grad_name);
#endif
void RunImplNormal();
bool IsEncoded();
};
} // namespace details
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint32(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is up limited memory size "
"of one group parameters' gradient which is the input "
......@@ -46,8 +47,7 @@ static framework::proto::VarType::Type kDefaultDtype =
class AllocContinuousSpaceForGradPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
......@@ -65,7 +65,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
if (params_grads.size() == 0) {
VLOG(10) << "Doesn't find gradients";
return std::move(graph);
return;
}
std::unordered_map<std::string, ir::Node *> vars;
......@@ -106,26 +106,33 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype);
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype);
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create the fused variable name.
// Create a FusedVarsSet to avoid duplicating names for fused_var in other
// pass.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
}
const std::string prefix(kFusedVarNamePrefix);
// The fused_var_name should be unique.
auto fused_var_name = prefix + "GRAD@" + params_grads[0].second;
// the kFusedGrads is used by fuse_optimizer_op_pass.
result.Set(kFusedGrads, new FusedGrads);
// the fused_var_name should be unique, so it appends
// params_grads.begin()->second.
auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
fused_var_name, params_grads);
return std::move(graph);
}
template <typename AttrType>
......@@ -298,17 +305,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
return type == proto::VarType::LOD_TENSOR;
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
void RecordParamsAndGrads(ir::Node *node,
ParamsAndGrads *params_grads) const {
try {
......@@ -361,6 +357,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
......@@ -373,7 +370,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
// Run Only Once Programs
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
......@@ -381,6 +377,17 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
}
}
// Appends an "alloc_continuous_space" op to `global_block`.
// `params_name` becomes the op's "Input", `grads_name` its "Output", and
// `fused_var_name` its single "FusedOutput".
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
                               const std::vector<std::string> &grads_name,
                               const std::string &fused_var_name,
                               BlockDesc *global_block) const {
  auto *alloc_op = global_block->AppendOp();
  alloc_op->SetType("alloc_continuous_space");
  // Inputs first, then the two outputs.
  alloc_op->SetInput("Input", params_name);
  alloc_op->SetOutput("Output", grads_name);
  alloc_op->SetOutput("FusedOutput", {fused_var_name});
}
};
} // namespace details
......
......@@ -27,20 +27,17 @@ void BroadcastOpHandle::RunImpl() {
if (places_.size() == 1) return;
// The input and output may have dummy vars.
VarHandle *in_var_handle;
{
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
"The number of input should be one.");
in_var_handle = in_var_handles[0];
}
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
"The number of input should be one.");
PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");
VarHandle *in_var_handle = in_var_handles[0];
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
......@@ -82,23 +81,43 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("inplace_pass");
}
if (strategy.fuse_elewise_add_act_ops_) {
if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(10) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass");
}
// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if (strategy.fuse_all_reduce_ops_) {
if (strategy_.fuse_all_reduce_ops_) {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
}
if (strategy_.fuse_all_optimizer_ops_) {
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
strategy_.is_distribution_) {
VLOG(3)
<< "Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
} else {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(10) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(10) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
}
}
// Add a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
......@@ -118,14 +137,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if (strategy.memory_optimize_) {
if (strategy_.memory_optimize_) {
VLOG(10) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
AppendMultiDevPass(strategy);
AppendMultiDevPass(strategy_);
if (strategy.fuse_all_reduce_ops_) {
if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(10) << "Add fuse_all_reduce_op_pass";
......@@ -151,7 +170,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("all_reduce_deps_pass");
}
if (SeqOnlyAllReduceOps(strategy)) {
if (SeqOnlyAllReduceOps(strategy_)) {
VLOG(10) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
......@@ -165,7 +184,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass = nullptr;
if (strategy_.is_distribution_) {
if (strategy.is_distribution_) {
VLOG(10) << "Add dist_multi_devices_pass";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
......@@ -204,15 +223,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}
std::unique_ptr<ir::Graph> BuildStrategy::Apply(
std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const size_t &nranks,
ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const {
#else
const bool use_cuda) const {
const bool use_cuda) const {
#endif
// Create a default one if not finalized by user.
CreatePassesFromStrategy(false);
......@@ -234,17 +254,22 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "fuse_all_reduce_op_pass") {
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
}
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......@@ -265,7 +290,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
}
}
VLOG(3) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(std::move(graph));
graph = pass->Apply(graph);
VLOG(3) << "Finish Apply Pass " << pass->Type();
}
return graph;
......@@ -293,4 +318,6 @@ USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
......@@ -18,7 +18,6 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
......@@ -76,6 +75,8 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
bool fuse_all_optimizer_ops_{false};
bool fuse_all_reduce_ops_{false};
bool fuse_relu_depthwise_conv_{false};
......@@ -120,16 +121,15 @@ struct BuildStrategy {
// Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
#else
const bool use_cuda) const;
const bool use_cuda) const;
#endif
// If set true, ParallelExecutor would build the main_program into multiple
......
......@@ -170,12 +170,10 @@ static OpToVarNameSetMap ShrinkGCVars(
class EagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
};
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
PADDLE_ENFORCE(ref_cnts.empty(),
......@@ -240,7 +238,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
auto while_op_eager_deletion_pass =
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
return while_op_eager_deletion_pass->Apply(std::move(graph));
while_op_eager_deletion_pass->Apply(graph);
}
} // namespace details
......
......@@ -31,9 +31,10 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
local_scopes_(local_scopes),
places_(places),
graph_(graph),
fetch_ctxs_(places),
pool_(strategy.num_threads_),
prepare_pool_(1), // add one more thread for generate op_deps
fetch_ctxs_(places) {
// add one more thread for generate op_deps
prepare_pool_(1) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep);
......
......@@ -14,7 +14,9 @@
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
......@@ -37,6 +39,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const ir::Graph &Graph() const override;
private:
// Note(zcd): the ThreadPool members must be declared last so that they are
// destroyed first.
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
......@@ -45,21 +49,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unordered_map<OpHandleBase *, int> op_deps_;
std::vector<OpHandleBase *> bootstrap_ops_;
::ThreadPool pool_;
::ThreadPool prepare_pool_;
platform::DeviceContextPool fetch_ctxs_;
std::atomic<int> remaining_;
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
ExceptionHolder exception_;
::ThreadPool pool_;
::ThreadPool prepare_pool_;
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps();
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
ExceptionHolder exception_;
};
} // namespace details
} // namespace framework
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_adam_op_pass.h"
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
// Returns the type name of the optimizer operators handled by this pass.
const std::string FuseAdamOpPass::GetOpType() const {
  return std::string("adam");
}
// Returns the names of the adam input slots whose variables are collected for
// fusion.
const std::vector<std::string> FuseAdamOpPass::GetAuxiliaryVarNames() const {
  const std::vector<std::string> slot_names = {"Param", "Moment1", "Moment2",
                                               "Beta1Pow", "Beta2Pow"};
  return slot_names;
}
// Inserts the fused adam op first, then fuses the scale ops attached to the
// "Beta1Pow" and "Beta2Pow" inputs of the given adam ops (in that order).
// Both maps must contain the "Beta1Pow" and "Beta2Pow" keys; at() throws
// otherwise.
void FuseAdamOpPass::FuseOptimizerOps(
    const std::unordered_map<std::string, std::vector<std::string>>
        &aux_var_set,
    const std::unordered_map<std::string, std::string> &fused_vars_name,
    const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
  FuseAdamOps(aux_var_set, fused_vars_name, adam_ops, graph);

  const auto &beta1_pow_vars = aux_var_set.at("Beta1Pow");
  const auto &fused_beta1_pow = fused_vars_name.at("Beta1Pow");
  FuseScaleOps(beta1_pow_vars, fused_beta1_pow, adam_ops, graph);

  const auto &beta2_pow_vars = aux_var_set.at("Beta2Pow");
  const auto &fused_beta2_pow = fused_vars_name.at("Beta2Pow");
  FuseScaleOps(beta2_pow_vars, fused_beta2_pow, adam_ops, graph);
}
// Creates one "adam" op that works on the fused variables and connects it to
// the inputs/outputs of all ops in `adam_ops`.
//
// Every op in `adam_ops` must carry the same op role, beta1, beta2, epsilon,
// lazy_mode and min_row_size_to_use_multithread attributes as the first op;
// an enforce fails otherwise. `adam_ops` must not be empty. `vars_set` is not
// read here. The original adam nodes are not removed by this function.
void FuseAdamOpPass::FuseAdamOps(
    const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
    const std::unordered_map<std::string, std::string> &fused_vars_name,
    const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
  PADDLE_ENFORCE_GT(adam_ops.size(), static_cast<size_t>(0));

  // The first op is the reference for attributes and for the shared inputs.
  auto *first_desc = adam_ops[0]->Op();
  const auto role_attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();

  // Check attributions
  // NOTE: If new attribution is added, the following code maybe need change.
  int op_role = boost::get<int>(first_desc->GetAttr(role_attr_name));
  float beta1 = boost::get<float>(first_desc->GetAttr("beta1"));
  float beta2 = boost::get<float>(first_desc->GetAttr("beta2"));
  float epsilon = boost::get<float>(first_desc->GetAttr("epsilon"));
  bool lazy_mode = boost::get<bool>(first_desc->GetAttr("lazy_mode"));
  int64_t min_row_size_to_use_multithread = boost::get<int64_t>(
      first_desc->GetAttr("min_row_size_to_use_multithread"));

  for (ir::Node *adam_op : adam_ops) {
    auto *cur_desc = adam_op->Op();
    PADDLE_ENFORCE_EQ(beta1, boost::get<float>(cur_desc->GetAttr("beta1")));
    PADDLE_ENFORCE_EQ(beta2, boost::get<float>(cur_desc->GetAttr("beta2")));
    PADDLE_ENFORCE_EQ(epsilon,
                      boost::get<float>(cur_desc->GetAttr("epsilon")));
    PADDLE_ENFORCE_EQ(lazy_mode,
                      boost::get<bool>(cur_desc->GetAttr("lazy_mode")));
    PADDLE_ENFORCE_EQ(min_row_size_to_use_multithread,
                      boost::get<int64_t>(cur_desc->GetAttr(
                          "min_row_size_to_use_multithread")));
    PADDLE_ENFORCE_EQ(op_role,
                      boost::get<int>(cur_desc->GetAttr(role_attr_name)));
  }

  // NOTE: the fused variables only exist in the scopes, so the graph has no
  // node for them.
  VLOG(10) << "Insert adam to graph ";
  OpDesc fused_adam(first_desc->Block());
  fused_adam.SetType("adam");

  // Wraps the fused variable name of a slot into a one-element argument list.
  const auto fused_arg = [&fused_vars_name](const std::string &slot) {
    return std::vector<std::string>{fused_vars_name.at(slot)};
  };
  fused_adam.SetInput("Param", fused_arg("Param"));
  fused_adam.SetInput("Grad", fused_arg("Grad"));
  fused_adam.SetInput("Moment1", fused_arg("Moment1"));
  fused_adam.SetInput("Moment2", fused_arg("Moment2"));
  // TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal.
  // Only the first op's arguments are used for these three inputs.
  fused_adam.SetInput("LearningRate", first_desc->Input("LearningRate"));
  fused_adam.SetInput("Beta1Pow", first_desc->Input("Beta1Pow"));
  fused_adam.SetInput("Beta2Pow", first_desc->Input("Beta2Pow"));

  fused_adam.SetOutput("ParamOut", fused_arg("Param"));
  fused_adam.SetOutput("Moment1Out", fused_arg("Moment1"));
  fused_adam.SetOutput("Moment2Out", fused_arg("Moment2"));

  fused_adam.SetAttr("beta1", beta1);
  fused_adam.SetAttr("beta2", beta2);
  fused_adam.SetAttr("epsilon", epsilon);
  fused_adam.SetAttr("lazy_mode", lazy_mode);
  fused_adam.SetAttr("min_row_size_to_use_multithread",
                     min_row_size_to_use_multithread);
  fused_adam.SetAttr(role_attr_name, op_role);

  InserInputAndOutputForOptOps(adam_ops, graph->CreateOpNode(&fused_adam));
}
// Fuses the "scale" ops that consume one beta-power input of the adam ops.
//
// For each adam op, the input variable node named beta_name[i] is looked up
// and the first "scale" op among that node's outputs is collected. All
// collected scale ops are then replaced by a single scale op that reads and
// writes `fused_var_name`.
//
// beta_name:      beta-power variable names; beta_name[i] belongs to
//                 adam_ops[i], so both vectors must have the same size.
// fused_var_name: name of the fused variable used as X and Out of the new op.
// adam_ops:       the adam ops being fused; must not be empty.
// graph:          graph that receives the new node and loses the old ones.
//
// An enforce fails when the sizes differ, when `adam_ops` is empty, when a
// beta-power input or its scale op cannot be found, or when the scale ops
// disagree on scale, bias, bias_after_scale or op role.
void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
                                  const std::string &fused_var_name,
                                  const std::vector<ir::Node *> &adam_ops,
                                  ir::Graph *graph) const {
  PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size());
  // scale_ops[0] is dereferenced below, so an empty op list is rejected here.
  PADDLE_ENFORCE_GT(adam_ops.size(), static_cast<size_t>(0));
  const std::string scale_op_name = "scale";

  // Get the scale_ops of dealing the adam's beta var.
  std::vector<ir::Node *> scale_ops;
  scale_ops.reserve(beta_name.size());
  for (size_t i = 0; i < adam_ops.size(); ++i) {
    auto &beta_1_pow_name = beta_name[i];
    // Find the variable node of this op's beta-power input.
    auto beta_pow_iter = std::find_if(
        adam_ops[i]->inputs.begin(), adam_ops[i]->inputs.end(),
        [&beta_1_pow_name](ir::Node *var_node) -> bool {
          return var_node->Var() && var_node->Var()->Name() == beta_1_pow_name;
        });
    PADDLE_ENFORCE(beta_pow_iter != adam_ops[i]->inputs.end());

    // Find the scale op that consumes that variable.
    auto beta_pow_node = *beta_pow_iter;
    auto scale_op_iter = std::find_if(
        beta_pow_node->outputs.begin(), beta_pow_node->outputs.end(),
        [&scale_op_name](ir::Node *op_node) -> bool {
          return op_node->Op() && op_node->Op()->Type() == scale_op_name;
        });
    PADDLE_ENFORCE(scale_op_iter != beta_pow_node->outputs.end());
    scale_ops.emplace_back(*scale_op_iter);
  }
  PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size());

  // Check attributions
  // NOTE: If new attribution is added, the following code maybe need change.
  int op_role = boost::get<int>(
      scale_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
  float scale = boost::get<float>(scale_ops[0]->Op()->GetAttr("scale"));
  float bias = boost::get<float>(scale_ops[0]->Op()->GetAttr("bias"));
  bool bias_after_scale =
      boost::get<bool>(scale_ops[0]->Op()->GetAttr("bias_after_scale"));
  for (auto &scale_op : scale_ops) {
    PADDLE_ENFORCE_EQ(scale,
                      boost::get<float>(scale_op->Op()->GetAttr("scale")));
    PADDLE_ENFORCE_EQ(bias, boost::get<float>(scale_op->Op()->GetAttr("bias")));
    PADDLE_ENFORCE_EQ(
        bias_after_scale,
        boost::get<bool>(scale_op->Op()->GetAttr("bias_after_scale")));
    PADDLE_ENFORCE_EQ(op_role, boost::get<int>(scale_op->Op()->GetAttr(
                                   OpProtoAndCheckerMaker::OpRoleAttrName())));
  }

  // NOTE: the fused variable only exists in the scopes, so the graph has no
  // node for it.
  VLOG(10) << "Insert fused scale to graph.";
  OpDesc scale_desc(scale_ops[0]->Op()->Block());
  scale_desc.SetType("scale");
  scale_desc.SetInput("X", {fused_var_name});
  scale_desc.SetOutput("Out", {fused_var_name});
  scale_desc.SetAttr("scale", scale);
  scale_desc.SetAttr("bias", bias);
  scale_desc.SetAttr("bias_after_scale", bias_after_scale);
  scale_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
  auto scale_node = graph->CreateOpNode(&scale_desc);

  // Re-wire every edge of the old scale ops to the fused scale node.
  for (auto scale_op : scale_ops) {
    // set inputs
    scale_node->inputs.insert(scale_node->inputs.begin(),
                              scale_op->inputs.begin(), scale_op->inputs.end());
    for (auto &input : scale_op->inputs) {
      std::replace(input->outputs.begin(), input->outputs.end(), scale_op,
                   scale_node);
    }
    // set outputs
    scale_node->outputs.insert(scale_node->outputs.begin(),
                               scale_op->outputs.begin(),
                               scale_op->outputs.end());
    for (auto &output : scale_op->outputs) {
      std::replace(output->inputs.begin(), output->inputs.end(), scale_op,
                   scale_node);
    }
  }

  // Delete scale_ops
  for (auto &scale_op : scale_ops) {
    graph->RemoveNode(scale_op);
  }
}
} // namespace details
} // namespace framework
} // namespace paddle
// Registers FuseAdamOpPass under the name "fuse_adam_op_pass" and declares the
// kPlaces and kLocalScopes pass attributes as required.
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::details::FuseAdamOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
// Optimizer-fusion pass specialised for "adam" ops. It implements the hooks
// declared by FuseOptimizerOpPass.
class FuseAdamOpPass : public FuseOptimizerOpPass {
 private:
  // Type name of the optimizer ops this pass fuses.
  const std::string GetOpType() const override;

  // Names of the op input slots whose variables are fused.
  const std::vector<std::string> GetAuxiliaryVarNames() const override;

  // Fuse Adam Ops and Scale Ops which are used to update "Beta1Pow", "Beta2Pow"
  void FuseOptimizerOps(
      const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
      const std::unordered_map<std::string, std::string> &fused_vars_name,
      const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const override;

  // Inserts one adam op that works on the fused variables.
  void FuseAdamOps(
      const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
      const std::unordered_map<std::string, std::string> &fused_vars_name,
      const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const;

  // Replaces the scale ops attached to one beta-power input of `adam_ops` by a
  // single scale op on `fused_var_name`. `aux_var_set` holds the per-op names
  // of that beta-power variable.
  void FuseScaleOps(const std::vector<std::string> &aux_var_set,
                    const std::string &fused_var_name,
                    const std::vector<ir::Node *> &adam_ops,
                    ir::Graph *graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -28,8 +28,7 @@ namespace details {
class FuseAllReduceOpPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
......@@ -71,7 +70,7 @@ class FuseAllReduceOpPass : public ir::Pass {
VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
if (all_reduce_ops.size() == 0) {
return std::move(graph);
return;
}
PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
......@@ -99,7 +98,6 @@ class FuseAllReduceOpPass : public ir::Pass {
group_all_reduce_ops, &result);
#endif
}
return std::move(graph);
}
void InsertFusedAllReduce(const std::vector<platform::Place> &places,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
// Fuses all optimizer ops of type GetOpType() found in `graph`.
//
// The pass reads the kPlaces and kLocalScopes pass attributes and the
// kParamsAndGrads and kFusedGrads graph attributes. It returns without
// fusing anything when no op of the requested type exists, or when the graph
// already records a fused optimizer type in kFusedOptType.
//
// Otherwise it records the fused type, registers one fused variable name per
// auxiliary slot (plus the fused gradient), sorts the collected variables and
// ops to follow the order of kParamsAndGrads, allocates the fused variables in
// every local scope, lets the subclass insert the fused op(s), and finally
// removes the original optimizer ops from the graph.
void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
  ir::Graph &result = *graph;

  auto &places = Get<const std::vector<platform::Place>>(kPlaces);
  auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);

  const std::string fuse_op_type = GetOpType();
  const std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();

  // Step 1: Get the specified op and auxiliary variables.
  std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
  std::unordered_map<std::string, std::vector<std::string>> aux_var_set;
  std::vector<ir::Node *> opt_ops;
  for (auto &node : topo_nodes) {
    GetSpecifiedOpsAndVars(fuse_op_type, aux_var_names, node, &opt_ops,
                           &aux_var_set);
  }

  VLOG(10) << "Find " << fuse_op_type << " operators: " << opt_ops.size();
  if (opt_ops.size() == 0) {
    return;
  }

  if (result.Has(kFusedOptType)) {
    VLOG(10)
        << "Currently only support fusing one type optimizer op. Has fused "
        << result.Get<FusedOptType>(kFusedOptType);
    return;
  } else {
    result.Set(kFusedOptType, new FusedOptType);
  }
  result.Get<FusedOptType>(kFusedOptType) = fuse_op_type;

  // Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be
  // initialized in scopes before execution.
  if (!result.Has(kFusedVars)) {
    result.Set(kFusedVars, new FusedVars);
  }
  std::unordered_map<std::string, std::string> fused_vars_name;
  fused_vars_name.reserve(aux_var_names.size() + 1);
  auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
  const std::string prefix(kFusedVarNamePrefix);
  // NOTE: the fused_var_name should be unique.
  for (auto &var_name : aux_var_names) {
    // The first collected variable of the slot makes the name unique.
    auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
                          aux_var_set[var_name][0];
    VLOG(10) << fused_var_name;
    fused_vars_name.emplace(var_name, fused_var_name);
    PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
    fused_var_set.insert(fused_var_name);
  }

  // Step 3: Get the fused Gradient's name
  // The attribute is verified before it is fetched, so that this check is the
  // one that reports a missing kParamsAndGrads.
  PADDLE_ENFORCE(result.Has(kParamsAndGrads), "Does't find kParamsAndGrads.");
  auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
  if (!result.Has(kFusedGrads)) {
    PADDLE_THROW(
        "The alloc_continuous_space_for_grad_pass should be called before this "
        "pass.");
  }
  auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
  auto &fused_vars = result.Get<FusedVars>(kFusedVars);
  auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
  PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
  fused_vars_name.emplace("Grad", fused_grad);

  // Step 4: Sort the parameters and auxiliary variables according
  // to parameters' name to make variables' name correspond correctly.
  PADDLE_ENFORCE_EQ(params_grads.size(), aux_var_set.begin()->second.size(),
                    "The size of params_grads and aux_var_set are not equal.");
  SortParametersAndAuxVars(params_grads, &aux_var_set, &opt_ops);

  // Step 5: Alloc continuous space for Parameters and AuxiliaryVar(e.g.
  // Moment1, Moment2, Beta1Pow, Beta2Pow) of all the optimizer ops separately.
  InitFusedVarsAndAllocSpaceForVars(places, local_scopes, aux_var_names,
                                    aux_var_set, fused_vars_name);

  // Step 6: Fuse optimizer Ops and Scale Ops
  FuseOptimizerOps(aux_var_set, fused_vars_name, opt_ops, &result);

  // Step 7: Remove optimizer Ops
  for (auto &opt_op : opt_ops) {
    graph->RemoveNode(opt_op);
  }
}
// Creates the fused variables in every local scope and runs
// "alloc_continuous_space" ops to allocate them.
//
// places:          places[i] is the place the ops are run on for
//                  local_scopes[i]; must hold at least as many entries as
//                  `local_scopes`.
// local_scopes:    scopes that receive the fused variables.
// aux_var_names:   slot names to process.
// aux_var_set:     slot name -> names of the variables to be fused.
// fused_vars_name: slot name -> name of the fused variable.
//
// An enforce fails when there are fewer places than scopes or when a fused
// variable already exists in one of the scopes.
void FuseOptimizerOpPass::InitFusedVarsAndAllocSpaceForVars(
    const std::vector<platform::Place> &places,
    const std::vector<Scope *> &local_scopes,
    const std::vector<std::string> &aux_var_names,
    const std::unordered_map<std::string, std::vector<std::string>>
        &aux_var_set,
    const std::unordered_map<std::string, std::string> &fused_vars_name) const {
  // places[i] is read for every scope index below, so verify the bound before
  // any scope is modified.
  PADDLE_ENFORCE(places.size() >= local_scopes.size(),
                 "The number of places(%d) is less than the number of local "
                 "scopes(%d).",
                 places.size(), local_scopes.size());

  VLOG(10) << "Init FusedVars.";
  // Alloc parameters and auxiliary vars in the respective scope.
  // Scopes are visited from the last to the first; idx is the real index of
  // the scope, so the error message below names the right one.
  for (size_t idx = local_scopes.size(); idx-- > 0;) {
    auto &scope = local_scopes[idx];
    for (auto &var_name : aux_var_names) {
      auto fused_var_name = fused_vars_name.at(var_name);
      VLOG(10) << "Init " << fused_var_name;
      PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
                     "%s has exist in scope[%d]", fused_var_name, idx);
      scope->Var(fused_var_name)->GetMutable<LoDTensor>();
    }
  }

  // Build a temporary program that holds one alloc_continuous_space op per
  // slot (with copy_data enabled).
  ProgramDesc program_desc;
  auto *global_block = program_desc.MutableBlock(0);
  for (auto &var_name : aux_var_names) {
    AppendAllocContinuousSpace(aux_var_set.at(var_name),
                               fused_vars_name.at(var_name), true,
                               global_block);
  }

  // Run those ops once in every scope, on the matching place.
  for (size_t i = 0; i < local_scopes.size(); ++i) {
    for (auto &op_desc : global_block->AllOps()) {
      auto op = OpRegistry::CreateOp(*op_desc);
      op->Run(*local_scopes[i], places[i]);
    }
  }
}
void FuseOptimizerOpPass::SortParametersAndAuxVars(
const std::vector<std::pair<std::string, std::string>> &params_grads,
std::unordered_map<std::string, std::vector<std::string>> *aux_vars_set,
std::vector<ir::Node *> *ops) const {
PADDLE_ENFORCE_NE(aux_vars_set->count("Param"), static_cast<size_t>(0));
auto &param_vec = aux_vars_set->at("Param");
std::vector<size_t> param_sort_idx;
param_sort_idx.reserve(param_vec.size());
for (auto &p_g : params_grads) {
auto iter = std::find(param_vec.begin(), param_vec.end(), p_g.first);
PADDLE_ENFORCE(iter != param_vec.end());
auto idx = std::distance(param_vec.begin(), iter);
param_sort_idx.emplace_back(idx);
}
for (auto &aux_vars : *aux_vars_set) {
std::vector<std::string> sorted_vars;
sorted_vars.reserve(aux_vars.second.size());
for (size_t i = 0; i < aux_vars.second.size(); ++i) {
sorted_vars.emplace_back(aux_vars.second.at(param_sort_idx[i]));
}
std::swap(aux_vars.second, sorted_vars);
std::stringstream out;
for (auto &var_name : aux_vars.second) {
out << var_name << " ";
}
VLOG(10) << aux_vars.first << ": " << out.str();
}
std::vector<ir::Node *> sorted_ops;
sorted_ops.reserve(ops->size());
for (size_t i = 0; i < ops->size(); ++i) {
sorted_ops.emplace_back(ops->at(param_sort_idx[i]));
}
std::swap(*ops, sorted_ops);
}
// If `node` is an op of type `op_type`, appends it to `ops` and appends the
// single argument of each input slot listed in `aux_vars_name` to the matching
// list in `aux_args_name`. Nodes of any other type are ignored. An enforce
// fails when a listed slot does not hold exactly one argument.
void FuseOptimizerOpPass::GetSpecifiedOpsAndVars(
    const std::string &op_type, const std::vector<std::string> &aux_vars_name,
    ir::Node *node, std::vector<ir::Node *> *ops,
    std::unordered_map<std::string, std::vector<std::string>> *aux_args_name)
    const {
  auto *op_desc = node->Op();
  if (op_desc->Type() != op_type) {
    return;
  }

  for (const auto &slot : aux_vars_name) {
    const auto slot_args = op_desc->Input(slot);
    PADDLE_ENFORCE_EQ(slot_args.size(), static_cast<size_t>(1));
    const auto &arg_name = slot_args[0];
    (*aux_args_name)[slot].emplace_back(arg_name);
    VLOG(10) << slot << ", " << arg_name;
  }
  ops->emplace_back(node);
}
// Appends an "alloc_continuous_space" op to `global_block`.
//
// `args` is used for both the "Input" and the "Output" slot of the op,
// `out_arg` is the single "FusedOutput" argument, and `copy_data` is forwarded
// to the attribute of the same name. The "check_name" attribute is always set
// to true.
void FuseOptimizerOpPass::AppendAllocContinuousSpace(
    const std::vector<std::string> &args, const std::string &out_arg,
    bool copy_data, BlockDesc *global_block) const {
  auto alloc_op = global_block->AppendOp();
  alloc_op->SetType("alloc_continuous_space");

  // Same variable names on both sides of the op.
  alloc_op->SetInput("Input", args);
  alloc_op->SetOutput("Output", args);
  alloc_op->SetOutput("FusedOutput", {out_arg});

  alloc_op->SetAttr("copy_data", copy_data);
  alloc_op->SetAttr("check_name", true);
}
// Connects the fused op node `opt_node` to every input and output variable
// node of the ops in `opt_ops`.
//
// Each variable node that referenced one of the old ops is re-pointed to
// `opt_node`. The variable nodes are gathered in sets first, so a node shared
// by several ops is attached to `opt_node` only once. The old op nodes are not
// removed here.
void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
    const std::vector<ir::Node *> &opt_ops, ir::Node *opt_node) const {
  std::unordered_set<ir::Node *> inputs;
  std::unordered_set<ir::Node *> outputs;
  for (auto opt_op : opt_ops) {
    // set inputs
    inputs.insert(opt_op->inputs.begin(), opt_op->inputs.end());
    for (auto &input : opt_op->inputs) {
      // Qualified call: do not depend on argument-dependent lookup to find
      // std::replace.
      std::replace(input->outputs.begin(), input->outputs.end(), opt_op,
                   opt_node);
    }
    // set outputs
    outputs.insert(opt_op->outputs.begin(), opt_op->outputs.end());
    for (auto &output : opt_op->outputs) {
      std::replace(output->inputs.begin(), output->inputs.end(), opt_op,
                   opt_node);
    }
  }
  opt_node->inputs.insert(opt_node->inputs.begin(), inputs.begin(),
                          inputs.end());
  opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
                           outputs.end());
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseOptimizerOpPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
protected:
virtual void SortParametersAndAuxVars(
const std::vector<std::pair<std::string, std::string>> &params_grads,
std::unordered_map<std::string, std::vector<std::string>> *aux_var_set,
std::vector<ir::Node *> *ops) const;
void InserInputAndOutputForOptOps(const std::vector<ir::Node *> &opt_ops,
ir::Node *opt_node) const;
private:
virtual const std::string GetOpType() const = 0;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const = 0;
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const = 0;
void GetSpecifiedOpsAndVars(
const std::string &op_type, const std::vector<std::string> &aux_vars_name,
ir::Node *node, std::vector<ir::Node *> *ops,
std::unordered_map<std::string, std::vector<std::string>> *aux_args_name)
const;
void AppendAllocContinuousSpace(const std::vector<std::string> &args,
const std::string &out_arg, bool copy_data,
BlockDesc *global_block) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &aux_var_names,
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name)
const;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_sgd_op_pass.h"
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
// Returns the type name of the optimizer operators handled by this pass.
const std::string FuseSgdOpPass::GetOpType() const {
  return std::string("sgd");
}
// Returns the names of the sgd input slots whose variables are collected for
// fusion; only "Param" is needed.
const std::vector<std::string> FuseSgdOpPass::GetAuxiliaryVarNames() const {
  const std::vector<std::string> slot_names = {"Param"};
  return slot_names;
}
// Hook called by the base pass: forwards all arguments unchanged to
// FuseSgdOps, which inserts the fused sgd op.
void FuseSgdOpPass::FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
FuseSgdOps(aux_var_set, fused_vars_name, sgd_ops, graph);
}
// Creates one "sgd" op that works on the fused "Param" and "Grad" variables
// and connects it to the inputs/outputs of all ops in `sgd_ops`.
//
// The op role and the "LearningRate" input are taken from the first op.
// `sgd_ops` must not be empty. `vars_set` is not read here. The original sgd
// nodes are not removed by this function.
void FuseSgdOpPass::FuseSgdOps(
    const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
    const std::unordered_map<std::string, std::string> &fused_vars_name,
    const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
  PADDLE_ENFORCE_GT(sgd_ops.size(), static_cast<size_t>(0));

  // The first op is the reference for the role and the learning rate.
  auto *first_desc = sgd_ops[0]->Op();
  const auto role_attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
  int op_role = boost::get<int>(first_desc->GetAttr(role_attr_name));

  // NOTE: the fused variables only exist in the scopes, so the graph has no
  // node for them.
  VLOG(10) << "Insert sgd to graph ";
  // Add fused sgd
  OpDesc fused_sgd(first_desc->Block());
  fused_sgd.SetType("sgd");

  const std::vector<std::string> fused_param = {fused_vars_name.at("Param")};
  const std::vector<std::string> fused_grad = {fused_vars_name.at("Grad")};
  fused_sgd.SetInput("Param", fused_param);
  fused_sgd.SetInput("Grad", fused_grad);
  fused_sgd.SetOutput("ParamOut", fused_param);

  // TODO(zcd): The LearningRate of all sgd ops should be equal; only the
  // first op's LearningRate is used here.
  fused_sgd.SetInput("LearningRate", first_desc->Input("LearningRate"));

  // NOTE: multi_devices_pass requires that every op should have a role.
  fused_sgd.SetAttr(role_attr_name, op_role);

  InserInputAndOutputForOptOps(sgd_ops, graph->CreateOpNode(&fused_sgd));
}
} // namespace details
} // namespace framework
} // namespace paddle
// Registers FuseSgdOpPass under the name "fuse_sgd_op_pass" and declares the
// kPlaces and kLocalScopes pass attributes as required.
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::details::FuseSgdOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
// Optimizer-fusion pass specialised for "sgd" ops. It implements the hooks
// declared by FuseOptimizerOpPass.
class FuseSgdOpPass : public FuseOptimizerOpPass {
 private:
  // Type name of the optimizer ops this pass fuses.
  const std::string GetOpType() const override;

  // Names of the op input slots whose variables are fused.
  const std::vector<std::string> GetAuxiliaryVarNames() const override;

  // Fuse Sgd Ops
  void FuseOptimizerOps(
      const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
      const std::unordered_map<std::string, std::string> &fused_vars_name,
      const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const override;

  // Inserts one sgd op that works on the fused variables.
  void FuseSgdOps(
      const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
      const std::unordered_map<std::string, std::string> &fused_vars_name,
      const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -24,6 +24,19 @@ namespace paddle {
namespace framework {
namespace details {
// Note(zcd): Addresses should be aligned, otherwise, the results may have
// diff.
//
// Rounds `size` up to the next multiple of the minimum chunk size for `place`:
// 256 B when `place` is a GPU place, 4 KB otherwise. A `size` that is already
// a multiple of the chunk size is returned unchanged.
static size_t Alignment(size_t size, const platform::Place &place) {
  // Minimum chunk size: 256 B (1 << 8) on GPU, 4 KB (1 << 12) elsewhere.
  const size_t chunk = platform::is_gpu_place(place) ? (1 << 8) : (1 << 12);
  const size_t tail = size % chunk;
  if (tail == 0) {
    return size;
  }
  // Pad by whatever is missing to reach the next chunk boundary.
  return size + (chunk - tail);
}
typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
GradientAndLoDTensor;
......@@ -111,10 +124,11 @@ void FusedAllReduceOpHandle::RunImpl() {
return grad1.second->data<void>() < grad2.second->data<void>();
});
size_t size_of_dtype = framework::SizeOfType(dtype);
for (size_t k = 1; k < g_tensor.size(); ++k) {
const void *cur_address = g_tensor.at(k - 1).second->data<void>();
int64_t len = g_tensor.at(k - 1).second->numel();
auto offset = len * framework::SizeOfType(dtype);
auto offset = Alignment(len * size_of_dtype, places_[0]);
void *infer_next_address = reinterpret_cast<void *>(
reinterpret_cast<uintptr_t>(cur_address) + offset);
const void *next_address = g_tensor.at(k).second->data<void>();
......@@ -228,18 +242,21 @@ void FusedAllReduceOpHandle::GetDTypeAndNumel(
const std::vector<std::pair<std::string, const LoDTensor *>> &grad_tensor,
proto::VarType::Type *dtype, int64_t *numel) const {
*numel = 0;
size_t size_of_dtype = 0;
for (size_t i = 0; i < grad_tensor.size(); ++i) {
// Get element number
int64_t len = grad_tensor.at(i).second->numel();
PADDLE_ENFORCE_GT(len, 0);
*numel += len;
// Get dtype
auto ele_type = grad_tensor.at(i).second->type();
if (i == 0) {
*dtype = ele_type;
size_of_dtype = framework::SizeOfType(ele_type);
}
PADDLE_ENFORCE_EQ(ele_type, *dtype);
// Get element number
int64_t len = grad_tensor.at(i).second->numel();
PADDLE_ENFORCE_GT(len, 0);
// Alignment(len)
*numel += Alignment(len * size_of_dtype, places_[0]) / size_of_dtype;
}
}
......
......@@ -144,10 +144,9 @@ void InplacePass::InitSSAGraphNodes() const {
}
}
std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void InplacePass::ApplyImpl(ir::Graph* graph) const {
var_nodes_.clear();
view_.Build(graph.get());
view_.Build(graph);
InitSSAGraphNodes();
auto cnt = 0;
......@@ -155,11 +154,9 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
continue;
TryInplaceOpInputOutput(op, graph.get());
TryInplaceOpInputOutput(op, graph);
}
// graph->ResolveHazard(var_nodes_);
return graph;
}
void InplacePass::InplaceModifyDesc(const std::string& var,
......
......@@ -69,8 +69,7 @@ class InplacePass : public ir::Pass {
InplacePass();
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
void InitSSAGraphNodes() const;
......
......@@ -44,8 +44,7 @@ namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
auto nodes = graph->Nodes();
CollectSkipVarsSet(nodes);
......@@ -113,7 +112,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
RenameVarInGraphDesc(var_name, cache_name, idx);
RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
RenameVarInGraphNode(var_name, cache_name, idx, graph);
pool_.Erase(cache_name);
}
}
......@@ -128,8 +127,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
}
}
graph->ResolveHazard(var_nodes_);
return graph;
}
void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
......
......@@ -21,6 +21,7 @@
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -35,8 +36,7 @@ namespace details {
class MemoryOptimizePass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
// fill the variable map(var_nodes) by version.
void InitSSAGraphNodes() const;
......
......@@ -34,8 +34,7 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
......@@ -49,7 +48,6 @@ std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -23,10 +23,8 @@ namespace details {
class SSAGraghBuilderWithChecker : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;
void ApplyImpl(ir::Graph *graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph));
}
bool IsValidGraph(const ir::Graph *graph) const {
......
......@@ -32,6 +32,7 @@
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace framework {
......@@ -152,8 +153,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
Init();
CheckGraph(*graph);
std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
......@@ -209,7 +209,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name
<< " op_type " << node->Op()->Type();
if (NeedCollectiveForGrad(g_name, sorted_ops)) {
InsertCollectiveOp(&result, p_name, g_name);
}
......@@ -234,7 +235,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
AddOutputToLeafOps(&result);
result.Erase(kGraphOps);
return graph;
}
void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
......@@ -414,8 +414,9 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
CreateOpHandleIOs(result, node, dev_id);
}
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, const std::string &og) const {
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
const std::string &og,
bool is_encoded) const {
OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&](
......@@ -424,7 +425,9 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, nccl_ctxs_));
scopes, places, nccl_ctxs_, is_encoded,
static_cast<int>(strategy_.trainers_endpoints_.size()) *
places_.size()));
#else
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
......@@ -446,12 +449,15 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad);
VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString();
auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, places_[i]);
vars.emplace_back(var);
op_handle->AddOutput(var);
VLOG(10) << "all_reduce_op_handle add output " << og
<< ", handle:" << var->DebugString();
}
}
......@@ -941,6 +947,17 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return op_dev_id;
}
// Reports whether the parameter `p_name` has a companion variable named
// `p_name + "__dgc_u__"` registered in `all_vars_`. Returns true when that
// variable exists; otherwise logs at verbosity 10 and returns false.
bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
  const std::string companion_name = p_name + "__dgc_u__";
  if (all_vars_.find(companion_name) != all_vars_.end()) {
    return true;
  }
  VLOG(10) << "can't find u_name, so it's not encoded:" << companion_name;
  return false;
}
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name,
const std::string &g_name) const {
......@@ -956,7 +973,11 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
CreateAllReduceOp(result, g_name, IsEncoded(p_name));
#else
PADDLE_ENFORCE(false, "Compiled withoud cuda!");
#endif
}
break;
default:
......
......@@ -20,7 +20,6 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -34,10 +33,13 @@ namespace framework {
class Scope;
namespace details {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
virtual void Init() const;
......@@ -75,7 +77,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool IsSparseGradient(const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og,
bool is_encoded = false) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;
......@@ -171,6 +174,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
bool IsEncoded(const std::string &p_name) const;
};
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
......
......@@ -13,7 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......
......@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <fstream>
#include <iosfwd>
#include <memory>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
......@@ -40,13 +41,11 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
class SSAGraghBuilderWithPrinter : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph* graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<std::string>(kGraphvizPath)));
PADDLE_ENFORCE(fout->good());
Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
return graph;
}
};
......
......@@ -20,7 +20,6 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
......@@ -41,22 +40,25 @@ namespace details {
// `std::vector<VarHandle*>` is the version of variables.
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
constexpr char kGraphVars[] = "vars";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
constexpr char kGraphDepVars[] = "dep_vars";
typedef std::unordered_set<std::string> FusedVars;
constexpr char kFusedVars[] = "fused_vars";
constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
typedef std::string FusedOptType;
constexpr char kFusedOptType[] = "fused_opt_type";
typedef std::string FusedGrads;
constexpr char kFusedGrads[] = "fused_gradients";
typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
constexpr char kParamsAndGrads[] = "params_grads";
......@@ -65,8 +67,6 @@ typedef std::vector<std::vector<std::pair<std::string, std::string>>>
GroupGradsAndParams;
constexpr char kGroupGradsAndParams[] = "group_grads_params";
constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -96,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
for (size_t i = 0; i < graphs_.size(); ++i) {
graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
}
// set the correct size of thread pool to each device.
......
......@@ -266,8 +266,7 @@ static bool ShrinkNoNeedBufferVarOpDependency(
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
......@@ -335,14 +334,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
var_name);
ref_cnts[i].emplace(var_name, result.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
break;
}
// Seldomly, all preceding trying failed.
// Just skip this corner case
}
}
return graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class ReferenceCountPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -29,8 +29,7 @@ static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
// FIXME(zjl): Insert dependencies between some distributed ops may cause
// the multi_devices_graph_pass fails. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
......@@ -98,7 +97,6 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
return graph;
}
} // namespace details
......
......@@ -23,8 +23,7 @@ namespace details {
class SequentialExecutionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
......
......@@ -24,13 +24,13 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: graph_(graph),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr),
prepare_pool_(1),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
strategy_(strategy) {
strategy_(strategy),
prepare_pool_(1),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr) {
PrepareOpDeps();
CopyOpDeps();
}
......
......@@ -63,13 +63,20 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op);
private:
// Note(zcd): the ThreadPool members should be declared last so that they are
// destroyed first.
ir::Graph *graph_;
std::unique_ptr<::ThreadPool> pool_;
::ThreadPool prepare_pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_;
ExceptionHolder exception_holder_;
std::unique_ptr<OpDependentData> op_deps_;
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
::ThreadPool prepare_pool_;
std::unique_ptr<::ThreadPool> pool_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const;
......@@ -88,14 +95,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void PrepareOpDeps();
void CopyOpDeps();
private:
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
std::unique_ptr<OpDependentData> op_deps_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
};
} // namespace details
......
......@@ -24,7 +24,8 @@ VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
// Builds a human-readable description of this variable handle in the form
// "name:<name>, place:<place>, version:<version>, scope_idx:<scope_idx>".
std::string VarHandle::DebugString() const {
  std::stringstream ss;
  // Emit each field exactly once, with a label. The unlabelled
  // "<name>:<place>" prefix was dropped because it duplicated name and place
  // and ran straight into the labelled fields with no separator.
  ss << "name:" << name_ << ", place:" << place_ << ", version:" << version_
     << ", scope_idx:" << scope_idx_;
  return ss.str();
}
......
......@@ -23,8 +23,7 @@ namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
void ApplyImpl(ir::Graph *graph) const override {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// Find all while_op and while_grad_op
......@@ -50,7 +49,6 @@ class WhileOpEagerDeletionPass : public ir::Pass {
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
while_ops, while_grad_ops);
}
return graph;
}
};
......
......@@ -29,10 +29,9 @@ namespace ir {
GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul_out);
std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -69,12 +68,11 @@ std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
IR_NODE_LINK_TO(scale_op, elementwise_mul_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(),
GraphSafeRemoveNodes(graph,
{fill_constant, fill_constant_out, elementwise_mul});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -26,8 +26,7 @@ class AnakinFillconstantElementwisemulFuse : public FusePassBase {
virtual ~AnakinFillconstantElementwisemulFuse() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -253,8 +254,7 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
// Parameters
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void AttentionLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
PDPattern external_pattern, subblock_pattern;
// Use the following variables to tell whether this model is RNN1.
......@@ -269,12 +269,11 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
}
}
if (count < specified_vars.size()) {
return graph;
return;
}
// Continue to fuse.
FindWhileOp(graph.get());
return graph;
FindWhileOp(graph);
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class AttentionLSTMFusePass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -77,10 +77,9 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
weights_array_2d.colwise() *= scale_array;
}
std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -139,7 +138,7 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel});
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
......@@ -147,16 +146,14 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
found_conv_ac_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_ac_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -199,7 +196,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
GraphSafeRemoveNodes(graph.get(),
GraphSafeRemoveNodes(graph,
{ac_scale, ac_bias, affine_channel, eltwise_out});
IR_NODE_LINK_TO(eltwise, ac_out);
......@@ -207,9 +204,8 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
found_conv_ac_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_ac_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvAffineChannelFusePass : public FusePassBase {
virtual ~ConvAffineChannelFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_affine_channel_fuse"};
};
......@@ -41,8 +40,7 @@ class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
virtual ~ConvEltwiseAddAffineChannelFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph*) const override;
const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
};
......
......@@ -101,10 +101,9 @@ void recompute_bias_and_weights(const Scope* scope,
weights_array_2d.colwise() *= variance_array;
}
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -187,7 +186,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
graph,
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
......@@ -203,10 +202,9 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(
graph.get(),
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance});
GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
batch_norm, bn_mean_out, bn_variance_out,
bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
......@@ -215,16 +213,14 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
}
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bn_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -274,7 +270,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
graph,
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
......@@ -283,10 +279,9 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
found_conv_bn_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bn_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvBNFusePass : public FusePassBase {
virtual ~ConvBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_bn_fuse"};
};
......@@ -41,8 +40,7 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
virtual ~ConvEltwiseAddBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
};
......
......@@ -50,10 +50,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
......@@ -95,7 +94,6 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
elementwise_add_out});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
......
......@@ -51,10 +51,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add2_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
......@@ -92,12 +91,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
// Delete the unneeded nodes.
GraphSafeRemoveNodes(
graph.get(),
{conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out, elementwise_add_out_1, act_op});
graph, {conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out, elementwise_add_out_1, act_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAdd2ActFusePass : public FusePassBase {
virtual ~ConvElementwiseAdd2ActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -48,10 +48,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -88,12 +87,11 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAddActFusePass : public FusePassBase {
virtual ~ConvElementwiseAddActFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -30,10 +30,9 @@ namespace ir {
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out);
std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add_fuse";
FusePassBase::Init(pattern_name, graph.get());
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
......@@ -76,11 +75,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
IR_NODE_LINK_TO(new_conv_op, elementwise_add_out); // Output
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op});
GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op});
};
gpd(graph.get(), handler);
return graph;
gpd(graph, handler);
}
} // namespace ir
......
......@@ -25,8 +25,7 @@ class ConvElementwiseAddFusePass : public FusePassBase {
virtual ~ConvElementwiseAddFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/blas.h"
......@@ -201,7 +203,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Remove unneeded nodes.
// TODO(jczaja): Proper removing of lookup table
std::unordered_set<const Node*> marked_nodes(
//{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
// {lookup_table, mul, lstm, elementwise_add, fc_bias, W});
{mul, lstm, elementwise_add, fc_bias});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
......@@ -224,15 +226,13 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -32,8 +32,7 @@ class EmbeddingFCLSTMFusePass : public FusePassBase {
virtual ~EmbeddingFCLSTMFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"embedding_fc_lstm_fuse"};
};
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -22,10 +23,9 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("fc_fuse", graph.get());
void FCFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("fc_fuse", graph);
std::unordered_set<Node*> nodes2delete;
......@@ -61,7 +61,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
PADDLE_ENFORCE(subgraph.count(x));
IR_NODE_LINK_TO(subgraph.at(x), fc_node);
......@@ -72,10 +72,9 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
found_fc_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_fc_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class FCFusePass : public FusePassBase {
virtual ~FCFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -73,7 +73,7 @@ TEST(FCFusePass, basic) {
int pre_nodes = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int after_nodes = graph->Nodes().size();
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -39,7 +40,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
// Create New OpDesc
auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
Node* bias, Node* hidden, Node* fc_bias) {
OpDesc op_desc;
op_desc.SetType("fusion_gru");
......@@ -155,26 +155,22 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> MulGRUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
false /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -30,8 +30,7 @@ class FCGRUFusePass : public FusePassBase {
virtual ~FCGRUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_gru_fuse"};
};
......@@ -42,8 +41,7 @@ class MulGRUFusePass : public FusePassBase {
virtual ~MulGRUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_nobias_gru_fuse"};
};
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
......@@ -157,26 +158,22 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
return fusion_count;
}
std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void MulLstmFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
false /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
......
......@@ -32,8 +32,7 @@ class FCLstmFusePass : public FusePassBase {
virtual ~FCLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_lstm_fuse"};
};
......@@ -43,8 +42,7 @@ class MulLstmFusePass : public FusePassBase {
virtual ~MulLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_nobias_lstm_fuse"};
};
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -23,29 +25,25 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
std::unordered_set<std::string> act_types = {"relu", "scale"};
graph = FuseActElewiseAdd(std::move(graph), act_types);
graph = FuseElewiseAddAct(std::move(graph), act_types);
graph = FuseActElewiseAdd(graph, act_types);
graph = FuseElewiseAddAct(graph, act_types);
// backward
{
std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types);
graph = FuseElewiseAddActInplaceGrad(graph, in_place_act_types);
}
// Remove the removable intermediate_out.
RemoveIntermediateOut(graph.get());
return graph;
RemoveIntermediateOut(graph);
}
// ele_add(x, act(y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("elewise_add_act", graph);
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
......@@ -86,18 +84,17 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
// act(ele_add(x,y))
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("act_elewise_add", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("act_elewise_add", graph);
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
......@@ -137,7 +134,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
......@@ -146,11 +143,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
// the backward of act(ele_add(x,y))
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("elewise_add_act_grad", graph.get());
ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("elewise_add_act_grad", graph);
GraphPatternDetector gpd;
auto *d_act_out = gpd.mutable_pattern()
......@@ -217,7 +213,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
found_elewise_add_act_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
......
......@@ -14,6 +14,8 @@
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -32,20 +34,16 @@ class FuseElewiseAddActPass : public FusePassBase {
virtual ~FuseElewiseAddActPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph *graph) const override;
std::unique_ptr<ir::Graph> FuseElewiseAddAct(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseElewiseAddAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseActElewiseAdd(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseActElewiseAdd(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
std::unique_ptr<ir::Graph> graph,
const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
/**
* Remove the removable intermediate_out.
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -23,20 +24,18 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
graph = FuseReluDepthwiseConv(std::move(graph), true);
graph = FuseReluDepthwiseConv(std::move(graph), false);
return graph;
void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
graph = FuseReluDepthwiseConv(graph, true);
graph = FuseReluDepthwiseConv(graph, false);
}
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const {
PADDLE_ENFORCE(graph.get());
ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
ir::Graph *graph, bool only_forward) const {
PADDLE_ENFORCE(graph);
if (only_forward)
FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
else
FusePassBase::Init("relu_depthwise_conv", graph.get());
FusePassBase::Init("relu_depthwise_conv", graph);
/*
x ---act--> y ---layer-> z
+----------+
......@@ -144,10 +143,9 @@ std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
}
count++;
};
gpd(graph.get(), handler);
GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
gpd(graph, handler);
GraphSafeRemoveNodes(graph, need_removed_nodes);
AddStatis(count);
return graph;
}
......
......@@ -32,10 +32,8 @@ class FuseReluDepthwiseConvPass : public FusePassBase {
virtual ~FuseReluDepthwiseConvPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const;
void ApplyImpl(ir::Graph* graph) const override;
ir::Graph* FuseReluDepthwiseConv(ir::Graph* graph, bool only_forward) const;
};
} // namespace ir
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -26,8 +28,7 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
std::unique_ptr<Graph> graph) const {
void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
// Remove the unneeded variables after memory optimization.
std::unordered_set<std::string> vars2remove;
if (graph->Has(kGraphToProgramVarsToRemove)) {
......@@ -73,7 +74,6 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
}
program.CopyFrom(*program_pb);
return graph;
}
} // namespace ir
......
......@@ -26,7 +26,7 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
class GraphToProgramPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -84,7 +86,7 @@ TEST(GraphToProgramPass, Basic) {
ProgramDesc compiled_prog;
pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled_prog);
pass->Apply(std::move(g));
pass->Apply(g.get());
std::vector<OpDesc*> ops = compiled_prog.Block(0).AllOps();
EXPECT_EQ(ops[0]->Type(), "op1");
EXPECT_EQ(ops[1]->Type(), "op2");
......
......@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h"
......@@ -38,8 +38,7 @@ std::string FormatName(const Node* node) {
}
} // namespace
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
......@@ -82,7 +81,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
{Dot::Attr("style", "filled,rounded"), Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "yellow")});
auto marked_nodes = ConsumeMarkedNodes(graph.get());
auto marked_nodes = ConsumeMarkedNodes(graph);
// Create nodes
for (const Node* n : graph->Nodes()) {
std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")";
......@@ -115,8 +114,6 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
}
sout << dot.Build();
return graph;
}
GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
......@@ -135,4 +132,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
\ No newline at end of file
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -34,8 +35,7 @@ class GraphVizPass : public Pass {
using marked_nodes_t = std::unordered_set<const Node*>;
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
// Tell whether there are any marked nodes in the graph. Consume the
// corresponding attribute.
......
......@@ -20,9 +20,8 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("identity_scale_op_clean", graph.get());
void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("identity_scale_op_clean", graph);
// pre_op -> scale_in -> scale_op -> scale_out
// ->
......@@ -72,8 +71,7 @@ std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
IR_NODE_LINK_TO(pre_op_var, scale_out_var);
};
detector(graph.get(), handler);
return graph;
detector(graph, handler);
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class IdentityScaleOpCleanPass : public FusePassBase {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
private:
virtual ~IdentityScaleOpCleanPass() = default;
......
......@@ -26,9 +26,9 @@ class InferCleanGraphPass : public FusePassBase {
virtual ~InferCleanGraphPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init("original_graph", graph.get());
PADDLE_ENFORCE(graph.get());
void ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("original_graph", graph);
PADDLE_ENFORCE(graph);
auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
......@@ -46,11 +46,9 @@ class InferCleanGraphPass : public FusePassBase {
}
}
GraphSafeRemoveNodes(graph.get(), invalid_nodes);
GraphSafeRemoveNodes(graph, invalid_nodes);
AddStatis(valid_op);
return graph;
}
void CleanEdges(std::vector<Node*>* nodes,
......
......@@ -20,8 +20,7 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void IsTestPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
"for activations and pooling.";
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
......@@ -47,7 +46,6 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
}
}
}
return graph;
}
} // namespace ir
......
......@@ -22,8 +22,7 @@ namespace ir {
class IsTestPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -97,7 +97,7 @@ TEST(IsTestPass, basic) {
auto pass = PassRegistry::Instance().Get("is_test_pass");
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
......
......@@ -32,9 +32,8 @@ const char kSumGradOpName[] = "sum";
// other optimizers later.
const char kOptimizerType[] = "sgd";
std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
// We could collect all weights' name from SGD, where
// W1 <- SGD(W0, Grad0)
......@@ -92,14 +91,14 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
// find the forward op related to the backward op
ir::Node* forward_op =
FindForwardOpViaBackwardOp(graph.get(), backward_op);
FindForwardOpViaBackwardOp(graph, backward_op);
VLOG(3) << "Found forward_op " << forward_op->Name();
PADDLE_ENFORCE(forward_op);
Node* new_optimizer_node = CreateNewSGDNode(
graph.get(), forward_op, backward_op, node, opt_node);
graph, forward_op, backward_op, node, opt_node);
PADDLE_ENFORCE(new_optimizer_node);
}
......@@ -140,8 +139,6 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
}
}
}
return graph;
}
ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
......
......@@ -60,8 +60,7 @@ class LockFreeOptimizePass : public Pass {
virtual ~LockFreeOptimizePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
private:
// Create a new sgd node via current optimizer node
......
......@@ -38,10 +38,9 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
return vec_y;
}
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
......@@ -99,7 +98,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
conv->Op()->SetOutput("Output",
std::vector<std::string>({eltwise_out->Name()}));
GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
GraphSafeRemoveNodes(graph, {eltwise, conv_out});
IR_NODE_LINK_TO(conv, eltwise_out);
} else {
......@@ -123,14 +122,13 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
GraphSafeRemoveNodes(graph, {conv, eltwise, conv_out});
}
found_conv_bias_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_bias_count);
return graph;
}
} // namespace ir
} // namespace framework
......
......@@ -29,8 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
virtual bool is_conv3d() const { return false; }
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
/*
......
......@@ -13,10 +13,10 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
......@@ -103,7 +103,7 @@ void MainTest(bool convWithExistingBias) {
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -16,8 +16,8 @@
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
......@@ -327,17 +327,15 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
get_node_from_elementwise_add);
}
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get());
void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph);
auto fused_graph_with_stats = FuseConvAsY(
name_scope_,
FuseConvAsX(
name_scope_,
FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0))));
FuseConvAsX(name_scope_,
FuseProjectionConv(name_scope_, std::make_pair(graph, 0))));
std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
AddStatis(fused_graph_with_stats.second);
return graph;
}
} // namespace ir
} // namespace framework
......
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <utility>
......@@ -27,7 +28,7 @@ namespace paddle {
namespace framework {
namespace ir {
using graph_ptr = std::unique_ptr<ir::Graph>;
using graph_ptr = ir::Graph*;
using GraphWithStats = std::pair<ir::Graph*, int>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
......@@ -124,7 +125,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
virtual ~ResidualConnectionMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
void ApplyImpl(graph_ptr graph) const;
const std::string name_scope_{"residual_connection_fuse_pass"};
};
......
......@@ -148,7 +148,7 @@ void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)(from, to));
......@@ -258,7 +258,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "g"));
......
......@@ -21,10 +21,9 @@ namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("conv_relu_mkldnn_fuse", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
......@@ -56,7 +55,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
desc->SetAttr("fuse_relu", true);
GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
GraphSafeRemoveNodes(graph, {relu, conv_out});
PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(conv, relu_out);
......@@ -64,10 +63,9 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
found_conv_relu_count++;
};
gpd(graph.get(), handler);
gpd(graph, handler);
AddStatis(found_conv_relu_count);
return graph;
}
} // namespace ir
......
......@@ -31,8 +31,7 @@ class ConvReLUFusePass : public FusePassBase {
virtual ~ConvReLUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
......
......@@ -88,7 +88,7 @@ TEST(ConvReLUFusePass, basic) {
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......
......@@ -216,19 +216,16 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count);
}
std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
PADDLE_ENFORCE(param_scope());
QuantizeConv(graph.get(), false /* with_residual_data */);
QuantizeConv(graph.get(), true /* with_residual_data */);
QuantizePool(graph.get());
return graph;
QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */);
QuantizePool(graph);
}
} // namespace ir
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册