Commit 2f9e5621 authored by zhoukunsheng

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into linspace

@@ -193,6 +193,12 @@ if(WITH_GPU)
  include(tensorrt)
  include(anakin_subgraph)
endif()

+if(WITH_GPU AND NOT WIN32)
+  message(STATUS "add dgc lib.")
+  include(external/dgc)
+endif()
+
if(WITH_MKL OR WITH_MKLML)
  include(external/anakin)
elseif()
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc")
SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc")
SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE)
SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE)
INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})
ExternalProject_Add(
extern_dgc
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/PaddlePaddle/Fleet"
GIT_TAG "2d04dc3800cdd0601f1b65d547dabcc60b0cf9dc"
SOURCE_DIR "${DGC_SOURCES_DIR}"
CONFIGURE_COMMAND ""
BUILD_COMMAND cd collective && make -j
INSTALL_COMMAND mkdir -p ${DGC_INSTALL_DIR}/lib/ ${DGC_INCLUDE_DIR}/dgc
&& cp ${DGC_SOURCES_DIR}/collective/build/lib/libdgc.a ${DGC_LIBRARIES}
&& cp ${DGC_SOURCES_DIR}/collective/build/include/dgc.h ${DGC_INCLUDE_DIR}/dgc/
BUILD_IN_SOURCE 1
)
ADD_LIBRARY(dgc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
ADD_DEPENDENCIES(dgc extern_dgc)
LIST(APPEND external_project_dependencies dgc)
@@ -57,20 +57,25 @@ SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
ExternalProject_Add(
    ${NGRAPH_PROJECT}
    ${EXTERNAL_PROJECT_LOG_ARGS}
    DEPENDS ${MKLDNN_PROJECT} ${MKLML_PROJECT}
    GIT_REPOSITORY ${NGRAPH_GIT_REPO}
    GIT_TAG ${NGRAPH_GIT_TAG}
    PREFIX ${NGRAPH_SOURCES_DIR}
    UPDATE_COMMAND ""
+   CMAKE_GENERATOR ${CMAKE_GENERATOR}
+   CMAKE_GENERATOR_PLATFORM ${CMAKE_GENERATOR_PLATFORM}
+   CMAKE_GENERATOR_TOOLSET ${CMAKE_GENERATOR_TOOLSET}
+   CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+   CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
    CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${NGRAPH_INSTALL_DIR}
    CMAKE_ARGS -DNGRAPH_UNIT_TEST_ENABLE=FALSE
    CMAKE_ARGS -DNGRAPH_TOOLS_ENABLE=FALSE
    CMAKE_ARGS -DNGRAPH_INTERPRETER_ENABLE=FALSE
    CMAKE_ARGS -DNGRAPH_DEX_ONLY=TRUE
    CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
    CMAKE_ARGS -DMKLDNN_INCLUDE_DIR=${MKLDNN_INC_DIR}
    CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}
    CMAKE_ARGS -DMKLML_LIB_DIR=${MKLML_INSTALL_DIR}/lib
)

add_dependencies(ngraph ${NGRAPH_PROJECT})
......
@@ -131,6 +131,15 @@ elseif (NOT CBLAS_FOUND OR WIN32)
    )
endif ()

+if (WITH_GPU AND NOT WIN32)
+    set(dgc_dir "${FLUID_INSTALL_DIR}/third_party/install/dgc")
+    copy(dgc_lib
+        SRCS ${DGC_INSTALL_DIR}/lib ${DGC_INSTALL_DIR}/include
+        DSTS ${dgc_dir} ${dgc_dir}
+        DEPS dgc)
+endif()
+
if (WITH_MKLDNN)
    set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mkldnn")
    copy(mkldnn_lib
......
@@ -110,7 +110,7 @@ function(op_library TARGET)
    # Define operators that don't need pybind here.
    foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
        "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
-       "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op")
+       "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
    if ("${TARGET}" STREQUAL "${manual_pybind_op}")
        set(pybind_flag 1)
    endif()
......
@@ -211,7 +211,7 @@ paddle.fluid.layers.mean (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, ...
paddle.fluid.layers.mul (ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)), ('document', 'ccd37fa6b53f074adbfb732d738c4c2d'))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits (ArgSpec(args=['x', 'label', 'ignore_index', 'name', 'normalize'], varargs=None, keywords=None, defaults=(-100, None, False)), ('document', '180c284317ea45ef89a460d8d79c0b72'))
paddle.fluid.layers.maxout (ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '891870d069a6aea746d34cc53b61690c'))
-paddle.fluid.layers.space_to_depth (ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '5f207ae10589ebe38a63575ef6ff8e1e'))
+paddle.fluid.layers.space_to_depth (ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'a9221eaef53884a00654e028551b78e2'))
paddle.fluid.layers.affine_grid (ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '51def402b8910e163cbace9d0c0526ed'))
paddle.fluid.layers.sequence_reverse (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '77a6d80aa5551ca70324fc975c44507f'))
paddle.fluid.layers.affine_channel (ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name', 'act'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None, None)), ('document', 'ab84fdc6dc60f3ad9aa397e6007e3bf9'))
@@ -484,6 +484,11 @@ paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['sel...
paddle.fluid.optimizer.LarsMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.LarsMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'rampup_begin_step', 'rampup_step', 'sparsity', 'use_nesterov', 'local_grad_clip_norm', 'num_trainers', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1, [0.999], False, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '1a79bd7d10ae54ca763ec81bca36ba24'))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......
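The ArgSpec entries above fully determine the new optimizer's public signature: only learning_rate, momentum and rampup_begin_step are required; rampup_step defaults to 1 and sparsity to [0.999]. A minimal usage sketch follows; the tiny network and the hyper-parameter values are illustrative, not taken from this commit:

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=y_pred, label=y))

# Only learning_rate, momentum and rampup_begin_step are required (see the
# ArgSpec above); sparsity ramps toward its final value over rampup_step steps.
optimizer = fluid.optimizer.DGCMomentumOptimizer(
    learning_rate=0.001,
    momentum=0.9,
    rampup_begin_step=1252,  # step at which gradient compression starts
    rampup_step=100,         # illustrative ramp-up length
    sparsity=[0.75, 0.9375, 0.984375, 0.996, 0.999])  # illustrative schedule
optimizer.minimize(loss)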
@@ -23,7 +23,7 @@ endif()
if(WITH_GPU)
    nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
-           dynload_cuda variable_visitor)
+           dynload_cuda variable_visitor dgc)
    nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
            dynload_cuda variable_visitor)
    if(WITH_DISTRIBUTE)
......
@@ -42,8 +42,7 @@ VarHandle* GetValidInput(const OpHandleBase* a) {
  return nullptr;
}

-std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
  auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);

  // get vars order
@@ -86,7 +85,8 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
    }
  }

-  VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl;
+  VLOG(10) << "dist_ops size:" << dist_ops.size()
+           << ", outputs size:" << vars.size() << ", ops size:" << ops.size();

  std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
                                                  OpHandleBase* op2) {
@@ -99,6 +99,10 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
    auto l_it = vars.find(i0->name());
    auto r_it = vars.find(i1->name());

+   PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(),
+                  "can't find var's name %s and %s in opdesc", i0->name(),
+                  i1->name());
+
    if (l_it->second < r_it->second) return true;

    if (l_it->second == r_it->second) {
@@ -126,8 +130,6 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
      VLOG(10) << "pre_op:" << pre_op->DebugString()
               << ", op:" << op->DebugString();
    }
  }
-
-  return graph;
}

}  // namespace details
......
@@ -24,8 +24,7 @@ namespace details {
// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace details
......
@@ -16,6 +16,13 @@
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/framework/operator.h"
+
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#include "dgc/dgc.h"
+#endif
+
+#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"

// asynchronous nccl allreduce or synchronous issue:
@@ -33,11 +40,14 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                     const std::vector<Scope *> &local_scopes,
                                     const std::vector<platform::Place> &places,
-                                    const platform::NCCLContextMap *ctxs)
+                                    const platform::NCCLContextMap *ctxs,
+                                    bool is_encoded, int nranks)
    : OpHandleBase(node),
      local_scopes_(local_scopes),
      places_(places),
-      nccl_ctxs_(ctxs) {
+      nccl_ctxs_(ctxs),
+      is_encoded_(is_encoded),
+      nranks_(nranks) {
  if (nccl_ctxs_) {
    for (auto &p : places_) {
      this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
@@ -51,7 +61,185 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+void AllReduceOpHandle::RunImplEncoded() {
+  platform::RecordEvent record_event(Name());
+
+  WaitInputVarGenerated();
+
+  auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
+  auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
+  PADDLE_ENFORCE_EQ(
+      in_var_handles.size(), places_.size(),
+      "The NoDummyInputSize should be equal to the number of places.");
+  PADDLE_ENFORCE_EQ(
+      in_var_handles.size(), out_var_handles.size(),
+      "The NoDummyInputSize and NoDummyOutputSize should be equal.");
+
+  std::vector<const LoDTensor *> ins;
+  std::vector<LoDTensor *> outs;
+  int k = -1;
+  for (size_t i = 0; i < local_scopes_.size(); ++i) {
+    auto &local_scope =
+        local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
+    auto original_name =
+        paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
+    auto encode_var_name = original_name + g_dgc_encoded;
+    auto *in_var = local_scope->FindVar(encode_var_name);
+    PADDLE_ENFORCE_NOT_NULL(in_var);
+    auto &in = in_var->Get<LoDTensor>();
+    ins.emplace_back(&in);
+
+    auto *out = local_scope->FindVar(out_var_handles[i]->name())
+                    ->GetMutable<LoDTensor>();
+    outs.emplace_back(out);
+
+    if (k < 0) {
+      k = GetKValue(in_var_handles[i]->name());
+    }
+  }
+
+  PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
+  PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
+  PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
+
+  int dtype = -1;
+  size_t in_numel = 0;
+  size_t out_numel = 0;
+  PADDLE_ENFORCE(nranks_ > 1);
+  std::vector<std::function<void()>> all_reduce_calls;
+
+  for (size_t i = 0; i < local_scopes_.size(); ++i) {
+    auto &place = places_[i];
+    auto &in = *ins[i];
+    void *in_tensor_buf = const_cast<void *>(in.data<void>());
+
+    auto &out = *outs[i];
+    float *out_tensor_buf = out.data<float>();
+
+    dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
+    in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
+    PADDLE_ENFORCE(in_numel % 2 == 0);
+    PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
+    out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
+
+    int dev_id = boost::get<platform::CUDAPlace>(place).device;
+    auto &nccl_ctx = nccl_ctxs_->at(dev_id);
+    auto stream = nccl_ctx.stream();
+    auto comm = nccl_ctx.comm_;
+
+    auto &allocator =
+        platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
+    int encode_size = 2 * k * sizeof(int);
+    // dgc use ncclAllGather to get all the encoded data
+    // so the buffer need nranks.
+    int buf_size = nranks_ * encode_size;
+    auto tmp_ious_data = allocator.Allocate(buf_size);
+    void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
+
+    VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
+             << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
+             << ", k:" << k << ", place:" << place << ", dtype:" << dtype;
+
+    all_reduce_calls.emplace_back([=] {
+      PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
+          in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
+          stream));
+    });
+  }
+
+  this->RunAndRecordEvent([&] {
+    if (all_reduce_calls.size() == 1UL) {
+      // Do not use NCCLGroup when manage NCCL by per thread per device
+      all_reduce_calls[0]();
+    } else {
+      platform::NCCLGroupGuard guard;
+      for (auto &call : all_reduce_calls) {
+        call();
+      }
+    }
+  });
+
+  if (FLAGS_sync_nccl_allreduce) {
+    for (auto &p : places_) {
+      int dev_id = boost::get<platform::CUDAPlace>(p).device;
+      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
+      auto stream = nccl_ctx.stream();
+      cudaError_t e_sync = cudaStreamSynchronize(stream);
+      if (e_sync != 0) {
+        LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
+      }
+
+      cudaError_t e_get = cudaGetLastError();
+      if (e_get != 0) {
+        LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
+                   << " errno:" << e_get;
+      }
+    }
+  }
+}
+
+int AllReduceOpHandle::GetKValue(const std::string &grad_name) {
+  auto original_name = paddle::framework::GradOriginalVarName(grad_name);
+  auto var_name = original_name + g_dgc_k;
+  PADDLE_ENFORCE(local_scopes_.size() > 0);
+
+  auto *scope = local_scopes_[0];
+  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  auto var = local_scope->FindVar(var_name);
+  PADDLE_ENFORCE_NOT_NULL(var);
+  auto tensor = var->Get<LoDTensor>().data<float>();
+  return *tensor;
+}
+#endif
+
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+bool AllReduceOpHandle::IsEncoded() {
+  if (!is_encoded_) {
+    return false;
+  }
+  auto counter_name = g_dgc_counter_name;
+  auto step_name = g_dgc_rampup_begin_step;
+  PADDLE_ENFORCE(local_scopes_.size() > 0);
+
+  auto *scope = local_scopes_[0];
+  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  auto count_var = local_scope->FindVar(counter_name);
+  auto step_var = local_scope->FindVar(step_name);
+  if (count_var == nullptr || step_var == nullptr) {
+    PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
+                 step_var);
+  }
+
+  float count = *count_var->Get<LoDTensor>().data<float>();
+  float step = *step_var->Get<LoDTensor>().data<float>();
+  if (static_cast<int>(count) < static_cast<int>(step)) {
+    VLOG(10) << "in all_reduce currentstep:" << count
+             << " < rampup_begin_step:" << step
+             << " so not use sparse all reduce";
+    return false;
+  }
+
+  return true;
+}
+#else
+bool AllReduceOpHandle::IsEncoded() { return false; }
+#endif
+
void AllReduceOpHandle::RunImpl() {
+  if (!IsEncoded()) {
+    RunImplNormal();
+    return;
+  }
+
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+  RunImplEncoded();
+#else
+  PADDLE_THROW("Not compiled with CUDA");
+#endif
+}
+
+void AllReduceOpHandle::RunImplNormal() {
  platform::RecordEvent record_event(Name());

  WaitInputVarGenerated();
@@ -72,6 +260,8 @@ void AllReduceOpHandle::RunImpl() {
    auto &lod_tensor =
        local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
    lod_tensors.emplace_back(&lod_tensor);
+   VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
+            << ", out_name:" << out_var_handles[i]->name();
    PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
                      "The name of input and output should be equal.");
  }
@@ -99,13 +289,17 @@ void AllReduceOpHandle::RunImpl() {
      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
      auto stream = nccl_ctx.stream();
      auto comm = nccl_ctx.comm_;
+
+     VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
+              << ", dev_id:" << dev_id << ", dtype:" << dtype
+              << ", place:" << p;
+
      all_reduce_calls.emplace_back([=] {
        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
            comm, stream));
      });
    }
+
    this->RunAndRecordEvent([&] {
      if (all_reduce_calls.size() == 1UL) {
        // Do not use NCCLGroup when manage NCCL by per thread per device
......
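The gather-buffer sizing in RunImplEncoded above follows from the in-code comment: each rank contributes its top-k gradient entries as k (index, value) pairs, i.e. 2*k ints, and ncclAllGather collects one such slice per rank. A quick sketch of that arithmetic (the helper name and the example values are illustrative, not part of the commit):

# encode_size = 2 * k * sizeof(int); buf_size = nranks * encode_size,
# mirroring the allocation in RunImplEncoded.
def dgc_gather_buf_size(k, nranks, int_bytes=4):
    encode_size = 2 * k * int_bytes  # one rank's encoded (index, value) pairs
    return nranks * encode_size      # ncclAllGather needs one slice per rank

# The op also enforces in_numel == 2 * k; e.g. with k = 4096 on 8 ranks:
assert dgc_gather_buf_size(4096, 8) == 8 * 2 * 4096 * 4  # 256 KiB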
@@ -28,11 +28,19 @@ namespace paddle {
namespace framework {
namespace details {

+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
+constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
+constexpr char g_dgc_encoded[] = "__dgc_encoded__";
+constexpr char g_dgc_k[] = "__dgc_k__";
+#endif
+
struct AllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                    const std::vector<platform::Place> &places,
-                   const platform::NCCLContextMap *ctxs);
+                   const platform::NCCLContextMap *ctxs,
+                   bool is_encoded = false, int nranks = -1);
#else
  AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                    const std::vector<platform::Place> &places);
@@ -50,8 +58,14 @@ struct AllReduceOpHandle : public OpHandleBase {
  std::vector<Scope *> local_scopes_;
  std::vector<platform::Place> places_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+  void RunImplEncoded();
  const platform::NCCLContextMap *nccl_ctxs_;
+  bool is_encoded_{false};
+  int nranks_{-1};
+  int GetKValue(const std::string &grad_name);
#endif
+  void RunImplNormal();
+  bool IsEncoded();
};

}  // namespace details
......
@@ -46,8 +46,7 @@ static framework::proto::VarType::Type kDefaultDtype =
class AllocContinuousSpaceForGradPass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
    ir::Graph &result = *graph;

    auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@@ -65,7 +64,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
    if (params_grads.size() == 0) {
      VLOG(10) << "Doesn't find gradients";
-     return std::move(graph);
+     return;
    }

    std::unordered_map<std::string, ir::Node *> vars;
@@ -124,8 +123,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
    InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
                                      fused_var_name, params_grads);
-
-    return std::move(graph);
  }

  template <typename AttrType>
......
@@ -204,15 +204,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
  return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
}

-std::unique_ptr<ir::Graph> BuildStrategy::Apply(
-    std::unique_ptr<ir::Graph> graph,
-    const std::vector<platform::Place> &places,
-    const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
-    const size_t &nranks,
+ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
+                                const std::vector<platform::Place> &places,
+                                const std::string &loss_var_name,
+                                const std::vector<Scope *> &local_scopes,
+                                const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
+                                const bool use_cuda,
+                                platform::NCCLContextMap *nccl_ctxs) const {
#else
-    const bool use_cuda) const {
+                                const bool use_cuda) const {
#endif
  // Create a default one if not finalized by user.
  CreatePassesFromStrategy(false);
@@ -265,7 +266,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
      }
    }
    VLOG(3) << "Start Apply Pass " << pass->Type();
-   graph = pass->Apply(std::move(graph));
+   graph = pass->Apply(graph);
    VLOG(3) << "Finish Apply Pass " << pass->Type();
  }
  return graph;
......
@@ -120,16 +120,15 @@ struct BuildStrategy {
  // Apply the passes built by the pass_builder_. The passes will be
  // applied to the Program and output an ir::Graph.
-  std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
-                                   const std::vector<platform::Place> &places,
-                                   const std::string &loss_var_name,
-                                   const std::vector<Scope *> &local_scopes,
-                                   const size_t &nranks,
+  ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
+                   const std::string &loss_var_name,
+                   const std::vector<Scope *> &local_scopes,
+                   const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-                                   const bool use_cuda,
-                                   platform::NCCLContextMap *nccl_ctxs) const;
+                   const bool use_cuda,
+                   platform::NCCLContextMap *nccl_ctxs) const;
#else
-                                   const bool use_cuda) const;
+                   const bool use_cuda) const;
#endif

  // If set true, ParallelExecutor would build the main_program into multiple
......
@@ -170,12 +170,10 @@ static OpToVarNameSetMap ShrinkGCVars(
class EagerDeletionPass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;
};

-std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
  auto &ref_cnts =
      Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
  PADDLE_ENFORCE(ref_cnts.empty(),
@@ -240,7 +238,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
  auto while_op_eager_deletion_pass =
      ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
-  return while_op_eager_deletion_pass->Apply(std::move(graph));
+  while_op_eager_deletion_pass->Apply(graph);
}

}  // namespace details
......
@@ -28,8 +28,7 @@ namespace details {
class FuseAllReduceOpPass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
    ir::Graph &result = *graph;

    auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@@ -71,7 +70,7 @@ class FuseAllReduceOpPass : public ir::Pass {
    VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
    if (all_reduce_ops.size() == 0) {
-     return std::move(graph);
+     return;
    }

    PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
@@ -99,7 +98,6 @@ class FuseAllReduceOpPass : public ir::Pass {
                          group_all_reduce_ops, &result);
#endif
    }
-   return std::move(graph);
  }

  void InsertFusedAllReduce(const std::vector<platform::Place> &places,
......
@@ -144,10 +144,9 @@ void InplacePass::InitSSAGraphNodes() const {
  }
}

-std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void InplacePass::ApplyImpl(ir::Graph* graph) const {
  var_nodes_.clear();
-  view_.Build(graph.get());
+  view_.Build(graph);
  InitSSAGraphNodes();

  auto cnt = 0;
@@ -155,11 +154,9 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
    VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
    if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
      continue;
-   TryInplaceOpInputOutput(op, graph.get());
+   TryInplaceOpInputOutput(op, graph);
  }
  // graph->ResolveHazard(var_nodes_);
-
-  return graph;
}

void InplacePass::InplaceModifyDesc(const std::string& var,
......
@@ -69,8 +69,7 @@ class InplacePass : public ir::Pass {
  InplacePass();

 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;

  void InitSSAGraphNodes() const;
......
@@ -44,8 +44,7 @@ namespace paddle {
namespace framework {
namespace details {

-std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
  auto nodes = graph->Nodes();
  CollectSkipVarsSet(nodes);
@@ -113,7 +112,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
        cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
        RenameVarInGraphDesc(var_name, cache_name, idx);
-       RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
+       RenameVarInGraphNode(var_name, cache_name, idx, graph);
        pool_.Erase(cache_name);
      }
    }
@@ -128,8 +127,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
    }
  }
  graph->ResolveHazard(var_nodes_);
-
-  return graph;
}

void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
......
@@ -21,6 +21,7 @@
#include <set>
#include <string>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -35,8 +36,7 @@ namespace details {
class MemoryOptimizePass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
  // fill the variable map(var_nodes) by version.
  void InitSSAGraphNodes() const;
......
@@ -34,8 +34,7 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
  return true;
}

-std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> ir_graph) const {
+void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
  auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
  OpGraphView graph_view(all_ops);
  for (auto &op : all_ops) {
@@ -49,7 +48,6 @@ std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
               << compute_op->DebugString();
    }
  }
-  return ir_graph;
}

}  // namespace details
......
@@ -23,8 +23,7 @@ namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace details
......
@@ -23,10 +23,8 @@ namespace details {
class SSAGraghBuilderWithChecker : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
-    PADDLE_ENFORCE(IsValidGraph(graph.get()));
-    return graph;
+  void ApplyImpl(ir::Graph *graph) const override {
+    PADDLE_ENFORCE(IsValidGraph(graph));
  }

  bool IsValidGraph(const ir::Graph *graph) const {
......
@@ -32,6 +32,7 @@
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace framework {
@@ -152,8 +153,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
}

-std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
  Init();
  CheckGraph(*graph);
  std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
@@ -209,7 +209,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
      for (size_t i = 0; i < backward_vars.size(); i += 2) {
        auto &p_name = backward_vars[i];
        auto &g_name = backward_vars[i + 1];
-       VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+       VLOG(10) << "Bcast " << g_name << " for parameter " << p_name
+                << " op_type " << node->Op()->Type();
        if (NeedCollectiveForGrad(g_name, sorted_ops)) {
          InsertCollectiveOp(&result, p_name, g_name);
        }
@@ -234,7 +235,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
  AddOutputToLeafOps(&result);

  result.Erase(kGraphOps);
-  return graph;
}

void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
@@ -414,8 +414,9 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
  CreateOpHandleIOs(result, node, dev_id);
}

-void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
-    ir::Graph *result, const std::string &og) const {
+void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
+                                                    const std::string &og,
+                                                    bool is_encoded) const {
  OpHandleBase *op_handle = nullptr;

  auto append_allreduce_op = [&](
@@ -424,7 +425,9 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-       scopes, places, nccl_ctxs_));
+       scopes, places, nccl_ctxs_, is_encoded,
+       static_cast<int>(strategy_.trainers_endpoints_.size()) *
+           places_.size()));
#else
    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
@@ -446,12 +449,15 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
    PADDLE_ENFORCE(!vars.empty());
    auto &prev_grad = vars.back();
    op_handle->AddInput(prev_grad);
+   VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString();

    auto var =
        new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                      vars.size(), i, og, places_[i]);
    vars.emplace_back(var);
    op_handle->AddOutput(var);
+   VLOG(10) << "all_reduce_op_handle add output " << og
+            << ", handle:" << var->DebugString();
  }
}
@@ -941,6 +947,17 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
  return op_dev_id;
}

+bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
+  auto u_name = p_name + "__dgc_u__";
+  auto it = all_vars_.find(u_name);
+  if (it == all_vars_.end()) {
+    VLOG(10) << "can't find u_name, so it's not encoded:" << u_name;
+    return false;
+  }
+
+  return true;
+}
+
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
                                             const std::string &p_name,
                                             const std::string &g_name) const {
@@ -956,7 +973,11 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
      CreateReduceOp(result, g_name, 0);
      CreateBroadcastOp(result, g_name, 0);
    } else {
-     CreateAllReduceOp(result, g_name);
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+     CreateAllReduceOp(result, g_name, IsEncoded(p_name));
+#else
+     PADDLE_ENFORCE(false, "Compiled withoud cuda!");
+#endif
    }
    break;
  default:
......
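Note that the sparse path is selected purely by variable-name convention: DistSSAGraphBuilder::IsEncoded above probes for "<param>__dgc_u__", while AllReduceOpHandle looks up the g_dgc_* names declared in all_reduce_op_handle.h. Collected in one place for reference (the helper function itself is hypothetical; the name constants are the ones introduced by this patch):

# Hypothetical helper summarizing the DGC variable-name conventions used here.
def dgc_var_names(param_name, grad_original_name):
    return {
        'u': param_name + '__dgc_u__',                      # presence marks the grad as encoded
        'encoded': grad_original_name + '__dgc_encoded__',  # sparse-encoded gradient tensor
        'k': grad_original_name + '__dgc_k__',              # top-k count read by GetKValue
        'counter': '__g_dgc_counter__',                     # global step counter
        'rampup_begin_step': '__g_rampup_begin_step__',     # step at which DGC starts
    }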
@@ -36,8 +36,7 @@ namespace details {
class MultiDevSSAGraphBuilderBase : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;

  virtual void Init() const;
@@ -75,7 +74,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
  bool IsSparseGradient(const std::string &og) const;

-  void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
+  void CreateAllReduceOp(ir::Graph *result, const std::string &og,
+                         bool is_encoded = false) const;

  void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
                         size_t src_dev_id) const;
@@ -171,6 +171,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
  mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
  mutable bool need_broadcast_var_{false};

+  bool IsEncoded(const std::string &p_name) const;
};

std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
......
@@ -13,7 +13,9 @@
// limitations under the License.

#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include <memory>
#include <string>
+#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......
@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <fstream>
#include <iosfwd>
+#include <memory>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
@@ -40,13 +41,11 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
class SSAGraghBuilderWithPrinter : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph* graph) const override {
    std::unique_ptr<std::ostream> fout(
        new std::ofstream(Get<std::string>(kGraphvizPath)));
    PADDLE_ENFORCE(fout->good());
    Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
-
-    return graph;
  }
};
......
@@ -96,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
  auto seq_allreduce_pass =
      ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
  for (size_t i = 0; i < graphs_.size(); ++i) {
-   graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
+   graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
  }

  // set the correct size of thread pool to each device.
......
@@ -266,8 +266,7 @@ static bool ShrinkNoNeedBufferVarOpDependency(
  }
}

-std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
  auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
  auto &last_live_ops_of_vars =
      Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
@@ -335,14 +334,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
                     var_name);
      ref_cnts[i].emplace(var_name, result.size());
      last_live_ops_of_vars[i].emplace(var_name, std::move(result));
+     break;
    }

    // Seldomly, all preceding trying failed.
    // Just skip this corner case
    }
  }
-
-  return graph;
}

}  // namespace details
......
@@ -23,8 +23,7 @@ namespace details {
class ReferenceCountPass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace details
......
@@ -29,8 +29,7 @@ static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
         op1->Outputs() == op2->Outputs();
}

-std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
  // FIXME(zjl): Insert dependencies between some distributed ops may cause
  // the multi_devices_graph_pass fails. So we skip these ops here.
  // Indeed, maybe we should not insert dependencies between these ops
@@ -98,7 +97,6 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
    VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
             << " and " << op_node_list[i]->Name();
  }
-  return graph;
}

}  // namespace details
......
@@ -23,8 +23,7 @@ namespace details {
class SequentialExecutionPass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace details
......
@@ -24,7 +24,8 @@ VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }

std::string VarHandle::DebugString() const {
  std::stringstream ss;
-  ss << name_ << ":" << place_;
+  ss << "name:" << name_ << ", place:" << place_ << ", version:" << version_
+     << ", scope_idx:" << scope_idx_;
  return ss.str();
}
......
@@ -23,8 +23,7 @@ namespace details {
class WhileOpEagerDeletionPass : public ir::Pass {
 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
    auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);

    // Find all while_op and while_grad_op
@@ -50,7 +49,6 @@ class WhileOpEagerDeletionPass : public ir::Pass {
      operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
          while_ops, while_grad_ops);
    }
-   return graph;
  }
};
......
@@ -29,10 +29,9 @@ namespace ir {
  GET_IR_NODE(elementwise_mul);     \
  GET_IR_NODE(elementwise_mul_out);

-std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);

  GraphPatternDetector gpd;
  auto* x = gpd.mutable_pattern()
@@ -69,12 +68,11 @@ std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
    IR_NODE_LINK_TO(scale_op, elementwise_mul_out);  // Output

    // Delete the unneeded nodes.
-   GraphSafeRemoveNodes(graph.get(),
+   GraphSafeRemoveNodes(graph,
                         {fill_constant, fill_constant_out, elementwise_mul});
  };

-  gpd(graph.get(), handler);
-
-  return graph;
+  gpd(graph, handler);
}

}  // namespace ir
......
@@ -26,8 +26,7 @@ class AnakinFillconstantElementwisemulFuse : public FusePassBase {
  virtual ~AnakinFillconstantElementwisemulFuse() {}

 protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
......
@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"

#include <string>
+#include <unordered_set>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
@@ -253,8 +254,7 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,

// Parameters

-std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AttentionLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
  PDPattern external_pattern, subblock_pattern;

  // Use the following variables to tell whether this model is RNN1.
@@ -269,12 +269,11 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
    }
  }
  if (count < specified_vars.size()) {
-   return graph;
+   return;
  }

  // Continue to fuse.
-  FindWhileOp(graph.get());
-
-  return graph;
+  FindWhileOp(graph);
}

}  // namespace ir
......
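[Review note] The one behavioral subtlety of the void interface is early exits: a bare return replaces handing the graph back. A standalone before/after sketch (count, specified_vars, and FindWhileOp stand in for the names in the hunk above):

    #include <cstddef>
    #include <string>
    #include <vector>

    struct Graph {};
    void FindWhileOp(Graph*) {}

    void ApplyImpl(Graph* graph, std::size_t count,
                   const std::vector<std::string>& specified_vars) {
      if (count < specified_vars.size()) {
        return;  // was: return graph;
      }
      FindWhileOp(graph);  // was: FindWhileOp(graph.get()); return graph;
    }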
@@ -22,8 +22,7 @@ namespace ir {
 class AttentionLSTMFusePass : public FusePassBase {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -77,10 +77,9 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
   weights_array_2d.colwise() *= scale_array;
 }
-std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -139,7 +138,7 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
     desc.SetAttr("axis", 1);
     auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
-    GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel});
+    GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
     IR_NODE_LINK_TO(conv_out, eltwise_op);
     IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
@@ -147,16 +146,14 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
     found_conv_ac_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_ac_count);
-  return graph;
 }
-std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -199,7 +196,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
   eltwise->Op()->SetAttr("axis", 1);
   eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
-    GraphSafeRemoveNodes(graph.get(),
+    GraphSafeRemoveNodes(graph,
                          {ac_scale, ac_bias, affine_channel, eltwise_out});
     IR_NODE_LINK_TO(eltwise, ac_out);
@@ -207,9 +204,8 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
     found_conv_ac_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_ac_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -31,8 +31,7 @@ class ConvAffineChannelFusePass : public FusePassBase {
   virtual ~ConvAffineChannelFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph*) const override;
   const std::string name_scope_{"conv_affine_channel_fuse"};
 };
@@ -41,8 +40,7 @@ class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
   virtual ~ConvEltwiseAddAffineChannelFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph*) const override;
   const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
 };
...
@@ -101,10 +101,9 @@ void recompute_bias_and_weights(const Scope* scope,
   weights_array_2d.colwise() *= variance_array;
 }
-std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -187,7 +186,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
                             std::vector<std::string>({bn_out->Name()}));
     GraphSafeRemoveNodes(
-        graph.get(),
+        graph,
         {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
          bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
@@ -203,10 +202,9 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
     desc.SetAttr("axis", 1);
     auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
-    GraphSafeRemoveNodes(
-        graph.get(),
-        {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
-         bn_variance_out, bn_saved_mean, bn_saved_variance});
+    GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
+                                 batch_norm, bn_mean_out, bn_variance_out,
+                                 bn_saved_mean, bn_saved_variance});
     IR_NODE_LINK_TO(conv_out, eltwise_op);
     IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
@@ -215,16 +213,14 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
     }
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_bn_count);
-  return graph;
 }
-std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -274,7 +270,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
   eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
   GraphSafeRemoveNodes(
-      graph.get(),
+      graph,
       {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
        bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
@@ -283,10 +279,9 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
   found_conv_bn_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_bn_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -31,8 +31,7 @@ class ConvBNFusePass : public FusePassBase {
   virtual ~ConvBNFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"conv_bn_fuse"};
 };
@@ -41,8 +40,7 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
   virtual ~ConvEltwiseAddBNFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
 };
...
@@ -50,10 +50,9 @@ framework::proto::OpDesc PrepareOpDesc(
   return *desc.Proto();
 }
-std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add_act_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
@@ -95,7 +94,6 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
                          elementwise_add_out});
   };
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 }  // namespace ir
...
@@ -51,10 +51,9 @@ framework::proto::OpDesc PrepareOpDesc(
   return *desc.Proto();
 }
-std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add2_act_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
@@ -92,12 +91,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
     // Delete the unneeded nodes.
     GraphSafeRemoveNodes(
-        graph.get(),
-        {conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
-         elementwise_add_out, elementwise_add_out_1, act_op});
+        graph, {conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
+                elementwise_add_out, elementwise_add_out_1, act_op});
   };
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 }  // namespace ir
...
@@ -25,8 +25,7 @@ class ConvElementwiseAdd2ActFusePass : public FusePassBase {
   virtual ~ConvElementwiseAdd2ActFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -48,10 +48,9 @@ framework::proto::OpDesc PrepareOpDesc(
   return *desc.Proto();
 }
-std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add_act_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
@@ -88,12 +87,11 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
     IR_NODE_LINK_TO(new_conv_op, act_out);  // Output
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
-                                       elementwise_add_out, act_op});
+    GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op,
+                                 elementwise_add_out, act_op});
   };
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 }  // namespace ir
...
@@ -25,8 +25,7 @@ class ConvElementwiseAddActFusePass : public FusePassBase {
   virtual ~ConvElementwiseAddActFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -30,10 +30,9 @@ namespace ir {
   GET_IR_NODE(elementwise_add_in_y);  \
   GET_IR_NODE(elementwise_add_out);
-std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "conv_elementwise_add_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
@@ -76,11 +75,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
     IR_NODE_LINK_TO(new_conv_op, elementwise_add_out);  // Output
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op});
+    GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op});
   };
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 }  // namespace ir
...
@@ -25,8 +25,7 @@ class ConvElementwiseAddFusePass : public FusePassBase {
   virtual ~ConvElementwiseAddFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -15,6 +15,8 @@
 #include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
 #include <algorithm>
 #include <string>
+#include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -201,7 +203,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     // Remove unneeded nodes.
     // TODO(jczaja): Proper removing of lookup table
     std::unordered_set<const Node*> marked_nodes(
-        //{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
+        // {lookup_table, mul, lstm, elementwise_add, fc_bias, W});
         {mul, lstm, elementwise_add, fc_bias});
     GraphSafeRemoveNodes(graph, marked_nodes);
   } else {
@@ -224,15 +226,13 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -32,8 +32,7 @@ class EmbeddingFCLSTMFusePass : public FusePassBase {
   virtual ~EmbeddingFCLSTMFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"embedding_fc_lstm_fuse"};
 };
...
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -22,10 +23,9 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("fc_fuse", graph.get());
+void FCFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("fc_fuse", graph);
   std::unordered_set<Node*> nodes2delete;
@@ -61,7 +61,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
     desc.SetType("fc");
     auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
-    GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
+    GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
     PADDLE_ENFORCE(subgraph.count(x));
     IR_NODE_LINK_TO(subgraph.at(x), fc_node);
@@ -72,10 +72,9 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     found_fc_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_fc_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -31,8 +31,7 @@ class FCFusePass : public FusePassBase {
   virtual ~FCFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -73,7 +73,7 @@ TEST(FCFusePass, basic) {
   int pre_nodes = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int after_nodes = graph->Nodes().size();
...
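[Review note] Call sites that owned the graph in a std::unique_ptr migrate mechanically, as in the test hunk above: release ownership into the pass, then re-adopt the returned raw pointer. A sketch of the pattern, under the assumption that Pass::Apply now takes and returns ir::Graph*:

    #include <memory>

    struct Graph {};
    struct Pass {
      Graph* Apply(Graph* g) const { return g; }  // stand-in for the real pass
    };

    int main() {
      std::unique_ptr<Graph> graph(new Graph);
      Pass pass;
      // Before: graph = pass->Apply(std::move(graph));
      // After:  hand the raw pointer over, then re-adopt it.
      graph.reset(pass.Apply(graph.release()));
    }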
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include "paddle/fluid/framework/lod_tensor.h"
 namespace paddle {
@@ -39,7 +40,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   // Create New OpDesc
   auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
                          Node* bias, Node* hidden, Node* fc_bias) {
-
     OpDesc op_desc;
     op_desc.SetType("fusion_gru");
@@ -155,26 +155,22 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> MulGRUFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 false /*with_fc_bias*/);
+void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
-std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -30,8 +30,7 @@ class FCGRUFusePass : public FusePassBase {
   virtual ~FCGRUFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_gru_fuse"};
 };
@@ -42,8 +41,7 @@ class MulGRUFusePass : public FusePassBase {
   virtual ~MulGRUFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_nobias_gru_fuse"};
 };
...
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include "paddle/fluid/framework/lod_tensor.h"
 namespace paddle {
@@ -157,26 +158,22 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 false /*with_fc_bias*/);
+void MulLstmFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
-std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -32,8 +32,7 @@ class FCLstmFusePass : public FusePassBase {
   virtual ~FCLstmFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_lstm_fuse"};
 };
@@ -43,8 +42,7 @@ class MulLstmFusePass : public FusePassBase {
   virtual ~MulLstmFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_nobias_lstm_fuse"};
 };
...
@@ -15,6 +15,8 @@
 #include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h"
 #include <algorithm>
 #include <string>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -23,29 +25,25 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
   std::unordered_set<std::string> act_types = {"relu", "scale"};
-  graph = FuseActElewiseAdd(std::move(graph), act_types);
-  graph = FuseElewiseAddAct(std::move(graph), act_types);
+  graph = FuseActElewiseAdd(graph, act_types);
+  graph = FuseElewiseAddAct(graph, act_types);
   // backward
   {
     std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
-    graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types);
+    graph = FuseElewiseAddActInplaceGrad(graph, in_place_act_types);
   }
   // Remove the removable intermediate_out.
-  RemoveIntermediateOut(graph.get());
-  return graph;
+  RemoveIntermediateOut(graph);
 }
 // ele_add(x, act(y))
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("elewise_add_act", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
+    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("elewise_add_act", graph);
   GraphPatternDetector gpd;
   auto *x = gpd.mutable_pattern()
@@ -86,18 +84,17 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
     found_elewise_add_act_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_elewise_add_act_count);
   return graph;
 }
 // act(ele_add(x,y))
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("act_elewise_add", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
+    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("act_elewise_add", graph);
   GraphPatternDetector gpd;
   auto *x = gpd.mutable_pattern()
@@ -137,7 +134,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
     found_elewise_add_act_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_elewise_add_act_count);
   return graph;
@@ -146,11 +143,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
 // the backward of act(ele_add(x,y))
 // act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
 // ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("elewise_add_act_grad", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
+    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("elewise_add_act_grad", graph);
   GraphPatternDetector gpd;
   auto *d_act_out = gpd.mutable_pattern()
@@ -217,7 +213,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
     found_elewise_add_act_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_elewise_add_act_count);
   return graph;
...
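[Review note] Helper passes that used to thread a unique_ptr through each stage now pass the raw pointer straight through, which keeps the chaining in ApplyImpl trivial. A standalone sketch of that shape (the FusePassBase machinery is omitted; FuseA/FuseB are placeholders):

    #include <string>
    #include <unordered_set>

    struct Graph {};

    Graph *FuseA(Graph *g, const std::unordered_set<std::string> &) { return g; }
    Graph *FuseB(Graph *g, const std::unordered_set<std::string> &) { return g; }

    void ApplyImpl(Graph *graph) {
      std::unordered_set<std::string> act_types = {"relu", "scale"};
      // Each stage mutates *graph and hands the same pointer to the next
      // stage; no std::move, no ownership transfer.
      graph = FuseA(graph, act_types);
      graph = FuseB(graph, act_types);
    }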
@@ -14,6 +14,8 @@
 #pragma once
 #include <string>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -32,20 +34,16 @@ class FuseElewiseAddActPass : public FusePassBase {
   virtual ~FuseElewiseAddActPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;
-  std::unique_ptr<ir::Graph> FuseElewiseAddAct(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseElewiseAddAct(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
-  std::unique_ptr<ir::Graph> FuseActElewiseAdd(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseActElewiseAdd(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
-  std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseElewiseAddActInplaceGrad(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
   /**
    * Remove the removable intermediate_out.
...
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
 #include <algorithm>
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -23,20 +24,18 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  graph = FuseReluDepthwiseConv(std::move(graph), true);
-  graph = FuseReluDepthwiseConv(std::move(graph), false);
-  return graph;
+void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
+  graph = FuseReluDepthwiseConv(graph, true);
+  graph = FuseReluDepthwiseConv(graph, false);
 }
-std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
-    std::unique_ptr<ir::Graph> graph, bool only_forward) const {
-  PADDLE_ENFORCE(graph.get());
+ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
+    ir::Graph *graph, bool only_forward) const {
+  PADDLE_ENFORCE(graph);
   if (only_forward)
-    FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
+    FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
   else
-    FusePassBase::Init("relu_depthwise_conv", graph.get());
+    FusePassBase::Init("relu_depthwise_conv", graph);
   /*
    x ---act--> y ---layer-> z
     +----------+
@@ -144,10 +143,9 @@ std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
     }
     count++;
   };
-  gpd(graph.get(), handler);
-  GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
+  gpd(graph, handler);
+  GraphSafeRemoveNodes(graph, need_removed_nodes);
   AddStatis(count);
   return graph;
 }
...
@@ -32,10 +32,8 @@ class FuseReluDepthwiseConvPass : public FusePassBase {
   virtual ~FuseReluDepthwiseConvPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
-  std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
-      std::unique_ptr<ir::Graph> graph, bool only_forward) const;
+  void ApplyImpl(ir::Graph* graph) const override;
+  ir::Graph* FuseReluDepthwiseConv(ir::Graph* graph, bool only_forward) const;
 };
 }  // namespace ir
...
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include <map>
+#include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
@@ -26,8 +28,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
-    std::unique_ptr<Graph> graph) const {
+void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
   // Remove the unneeded variables after memory optimization.
   std::unordered_set<std::string> vars2remove;
   if (graph->Has(kGraphToProgramVarsToRemove)) {
@@ -73,7 +74,6 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
   }
   program.CopyFrom(*program_pb);
-  return graph;
 }
 }  // namespace ir
...
@@ -26,7 +26,7 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
 class GraphToProgramPass : public Pass {
  protected:
-  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -14,7 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
+#include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -84,7 +86,7 @@ TEST(GraphToProgramPass, Basic) {
   ProgramDesc compiled_prog;
   pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled_prog);
-  pass->Apply(std::move(g));
+  pass->Apply(g.get());
   std::vector<OpDesc*> ops = compiled_prog.Block(0).AllOps();
   EXPECT_EQ(ops[0]->Type(), "op1");
   EXPECT_EQ(ops[1]->Type(), "op2");
...
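[Review note] Where the caller never gives up ownership, as in the test above, the new interface is simpler still: pass g.get() and keep using the unique_ptr afterwards. A hedged sketch (Pass::Apply's return is ignored here, matching the test):

    #include <memory>

    struct Graph { int nodes = 3; };
    struct Pass {
      Graph* Apply(Graph* g) const { return g; }  // stand-in
    };

    int main() {
      std::unique_ptr<Graph> g(new Graph);
      Pass pass;
      pass.Apply(g.get());  // was: pass->Apply(std::move(g));
      // g is still valid and still owns the graph here.
      return g->nodes == 3 ? 0 : 1;
    }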
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include <algorithm>
+#include <unordered_map>
 #include <unordered_set>
-#include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/inference/analysis/dot.h"
 #include "paddle/fluid/string/printf.h"
@@ -38,8 +38,7 @@ std::string FormatName(const Node* node) {
 }
 }  // namespace
-std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
   const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
   VLOG(3) << "draw IR graph viz to " << graph_viz_path;
   std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
@@ -82,7 +81,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
       {Dot::Attr("style", "filled,rounded"), Dot::Attr("shape", "box"),
       Dot::Attr("fillcolor", "yellow")});
-  auto marked_nodes = ConsumeMarkedNodes(graph.get());
+  auto marked_nodes = ConsumeMarkedNodes(graph);
   // Create nodes
   for (const Node* n : graph->Nodes()) {
     std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")";
@@ -115,8 +114,6 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
   }
   sout << dot.Build();
-  return graph;
 }
 GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
@@ -135,4 +132,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
 }  // namespace paddle
 REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
     .RequirePassAttr(paddle::framework::ir::kGraphVizPath);
\ No newline at end of file
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <map>
 #include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
@@ -34,8 +35,7 @@ class GraphVizPass : public Pass {
   using marked_nodes_t = std::unordered_set<const Node*>;
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   // Tell whether there are any marked nodes in the graph. Consume the
   // corresponding attribute.
...
@@ -20,9 +20,8 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init("identity_scale_op_clean", graph.get());
+void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init("identity_scale_op_clean", graph);
   // pre_op -> scale_in -> scale_op -> scale_out
   // ->
@@ -72,8 +71,7 @@ std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
     IR_NODE_LINK_TO(pre_op_var, scale_out_var);
   };
-  detector(graph.get(), handler);
-  return graph;
+  detector(graph, handler);
 }
 }  // namespace ir
...
@@ -22,8 +22,7 @@ namespace ir {
 class IdentityScaleOpCleanPass : public FusePassBase {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
  private:
   virtual ~IdentityScaleOpCleanPass() = default;
...
@@ -26,9 +26,9 @@ class InferCleanGraphPass : public FusePassBase {
   virtual ~InferCleanGraphPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
-    FusePassBase::Init("original_graph", graph.get());
-    PADDLE_ENFORCE(graph.get());
+  void ApplyImpl(ir::Graph* graph) const {
+    FusePassBase::Init("original_graph", graph);
+    PADDLE_ENFORCE(graph);
     auto is_valid_node = [](Node* x) {
       return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
@@ -46,11 +46,9 @@ class InferCleanGraphPass : public FusePassBase {
       }
     }
-    GraphSafeRemoveNodes(graph.get(), invalid_nodes);
+    GraphSafeRemoveNodes(graph, invalid_nodes);
     AddStatis(valid_op);
-    return graph;
   }
   void CleanEdges(std::vector<Node*>* nodes,
...
@@ -20,8 +20,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void IsTestPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
              "for activations and pooling.";
   auto op_list = {"pool2d", "sigmoid", "logsigmoid",
@@ -47,7 +46,6 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
       }
     }
   }
-  return graph;
 }
 }  // namespace ir
...
@@ -22,8 +22,7 @@ namespace ir {
 class IsTestPass : public Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -97,7 +97,7 @@ TEST(IsTestPass, basic) {
   auto pass = PassRegistry::Instance().Get("is_test_pass");
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   for (auto* node : graph->Nodes()) {
     if (node->IsOp()) {
...
@@ -32,9 +32,8 @@ const char kSumGradOpName[] = "sum";
 // other optimizers later.
 const char kOptimizerType[] = "sgd";
-std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
+void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
   // We could collect all weights' name from SGD, where
   // W1 <- SGD(W0, Grad0)
@@ -92,14 +91,14 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
           // find the forward op related to the backward op
           ir::Node* forward_op =
-              FindForwardOpViaBackwardOp(graph.get(), backward_op);
+              FindForwardOpViaBackwardOp(graph, backward_op);
           VLOG(3) << "Found forward_op " << forward_op->Name();
           PADDLE_ENFORCE(forward_op);
           Node* new_optimizer_node = CreateNewSGDNode(
-              graph.get(), forward_op, backward_op, node, opt_node);
+              graph, forward_op, backward_op, node, opt_node);
           PADDLE_ENFORCE(new_optimizer_node);
         }
@@ -140,8 +139,6 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
       }
     }
   }
-  return graph;
 }
 ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
...
@@ -60,8 +60,7 @@ class LockFreeOptimizePass : public Pass {
   virtual ~LockFreeOptimizePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
  private:
   // Create a new sgd node via current optimizer node
...
@@ -38,10 +38,9 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
   return vec_y;
 }
-std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -99,7 +98,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
       conv->Op()->SetOutput("Output",
                             std::vector<std::string>({eltwise_out->Name()}));
-      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
+      GraphSafeRemoveNodes(graph, {eltwise, conv_out});
       IR_NODE_LINK_TO(conv, eltwise_out);
     } else {
@@ -123,14 +122,13 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
       IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
       IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
-      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
+      GraphSafeRemoveNodes(graph, {conv, eltwise, conv_out});
     }
     found_conv_bias_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_bias_count);
-  return graph;
 }
 }  // namespace ir
 }  // namespace framework
...
@@ -29,8 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
   virtual bool is_conv3d() const { return false; }
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"conv_bias_mkldnn_fuse"};
 };
 /*
...
@@ -13,10 +13,10 @@
 // limitations under the License.
 #include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h"
+#include <gtest/gtest.h>
 #include "paddle/fluid/framework/naive_executor.h"
 #include "paddle/fluid/platform/place.h"
-#include <gtest/gtest.h>
 #include "paddle/fluid/framework/op_proto_maker.h"
 namespace paddle {
@@ -103,7 +103,7 @@ void MainTest(bool convWithExistingBias) {
   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();
...
@@ -16,8 +16,8 @@
 #include <functional>
 #include <list>
 #include <map>
+#include <memory>
 #include <tuple>
 #include "paddle/fluid/framework/ir/graph_traits.h"
 namespace paddle {
@@ -327,17 +327,15 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
       get_node_from_elementwise_add);
 }
-graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
+  FusePassBase::Init(name_scope_, graph);
   auto fused_graph_with_stats = FuseConvAsY(
       name_scope_,
-      FuseConvAsX(
-          name_scope_,
-          FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0))));
+      FuseConvAsX(name_scope_,
+                  FuseProjectionConv(name_scope_, std::make_pair(graph, 0))));
   std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
   AddStatis(fused_graph_with_stats.second);
-  return graph;
 }
 }  // namespace ir
 }  // namespace framework
...
@@ -14,6 +14,7 @@
 #pragma once
+#include <memory>
 #include <string>
 #include <tuple>
 #include <utility>
@@ -27,7 +28,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
-using graph_ptr = std::unique_ptr<ir::Graph>;
+using graph_ptr = ir::Graph*;
 using GraphWithStats = std::pair<ir::Graph*, int>;
 void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
@@ -124,7 +125,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
   virtual ~ResidualConnectionMKLDNNFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
+  void ApplyImpl(graph_ptr graph) const;
   const std::string name_scope_{"residual_connection_fuse_pass"};
 };
...
@@ -148,7 +148,7 @@ void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
   auto pass =
       PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");

   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();

   EXPECT_TRUE(is_reachable(graph)(from, to));

@@ -258,7 +258,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
   auto pass =
       PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");

   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();

   EXPECT_TRUE(is_reachable(graph)("a", "g"));
......
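The assertions EXPECT_TRUE(is_reachable(graph)(from, to)) encode the invariant these tests guard: fusing conv and elementwise_add must not sever any path through the graph. The real helper walks Paddle's IR; a self-contained sketch of the same check over a plain adjacency list (all names here are illustrative):

#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

using AdjList = std::map<std::string, std::vector<std::string>>;

// Breadth-first search: is `to` reachable from `from`?
bool IsReachable(const AdjList& adj, const std::string& from,
                 const std::string& to) {
  std::queue<std::string> frontier;
  std::set<std::string> visited{from};
  frontier.push(from);
  while (!frontier.empty()) {
    const std::string node = frontier.front();
    frontier.pop();
    if (node == to) return true;
    const auto it = adj.find(node);
    if (it == adj.end()) continue;
    for (const std::string& next : it->second) {
      if (visited.insert(next).second) frontier.push(next);
    }
  }
  return false;
}

The NoFusion test below uses the same helper to check that a graph the pass must not touch still connects "a" to "g" afterwards.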
@@ -21,10 +21,9 @@ namespace paddle {
 namespace framework {
 namespace ir {

-std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
+void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("conv_relu_mkldnn_fuse", graph);

   GraphPatternDetector gpd;
   auto* conv_input = gpd.mutable_pattern()

@@ -56,7 +55,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
     OpDesc* desc = conv->Op();
     desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
     desc->SetAttr("fuse_relu", true);
-    GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
+    GraphSafeRemoveNodes(graph, {relu, conv_out});

     PADDLE_ENFORCE(subgraph.count(conv_input));
     IR_NODE_LINK_TO(conv, relu_out);

@@ -64,10 +63,9 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
     found_conv_relu_count++;
   };

-  gpd(graph.get(), handler);
+  gpd(graph, handler);

   AddStatis(found_conv_relu_count);
-  return graph;
 }

 }  // namespace ir
......
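The conv+relu hunk shows the canonical shape of a Paddle fuse pass: enforce a non-null graph, Init the pass, build a pattern, and hand GraphPatternDetector a handler that rewrites each match, removes the dead nodes, and tallies matches for AddStatis. A compressed sketch of that control flow with a toy detector (the real GraphPatternDetector API is considerably richer):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Node { std::string op; bool fused = false; };
struct Graph { std::vector<Node> nodes; };

// Toy detector: invokes the handler once per node whose op type matches.
void DetectPattern(Graph* graph, const std::string& op_type,
                   const std::function<void(Node*)>& handler) {
  for (Node& node : graph->nodes) {
    if (node.op == op_type) handler(&node);
  }
}

void ApplyConvReluFuse(Graph* graph) {
  int found_conv_relu_count = 0;
  auto handler = [&](Node* relu) {
    // The real handler rewires conv's Output to relu's output, sets the
    // fuse_relu attribute, and calls GraphSafeRemoveNodes on {relu, conv_out}.
    relu->fused = true;
    found_conv_relu_count++;
  };
  DetectPattern(graph, "relu", handler);
  std::cout << "found " << found_conv_relu_count << " conv+relu pairs\n";
}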
@@ -31,8 +31,7 @@ class ConvReLUFusePass : public FusePassBase {
   virtual ~ConvReLUFusePass() {}

  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };

 }  // namespace ir
......
@@ -88,7 +88,7 @@ TEST(ConvReLUFusePass, basic) {
   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();
......
@@ -216,19 +216,16 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
   PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count);
 }

-std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Quantizing the graph.";
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);

   PADDLE_ENFORCE(param_scope());

-  QuantizeConv(graph.get(), false /* with_residual_data */);
-  QuantizeConv(graph.get(), true /* with_residual_data */);
-  QuantizePool(graph.get());
-
-  return graph;
+  QuantizeConv(graph, false /* with_residual_data */);
+  QuantizeConv(graph, true /* with_residual_data */);
+  QuantizePool(graph);
 }

 }  // namespace ir
......
@@ -42,8 +42,7 @@ class CPUQuantizePass : public FusePassBase {
   virtual ~CPUQuantizePass() {}

  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;

   void QuantizeConv(Graph* graph, bool with_residual_data = false) const;
......
@@ -139,7 +139,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();
......
@@ -20,8 +20,7 @@ namespace paddle {
 namespace framework {
 namespace ir {

-std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Marks operators which are to be quantized.";
   const auto& excluded_ids_list =
       Get<std::unordered_set<int>>("quantize_excluded_op_ids");

@@ -43,7 +42,6 @@ std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
       }
     }
   }
-  return graph;
 }

 }  // namespace ir
......
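Placement passes are the simplest case of the new void-returning ApplyImpl: they flip an attribute and remove nothing. CPUQuantizePlacementPass marks each op for quantization unless its id appears in quantize_excluded_op_ids. A hedged sketch of that loop over stand-in types:

#include <unordered_set>
#include <vector>

struct OpNode {
  int id = 0;
  bool use_quantizer = false;
};

// Mark every op for quantization unless it was explicitly excluded by id.
// The real pass additionally consults the set of quantize-enabled op types.
void MarkForQuantization(std::vector<OpNode>* ops,
                         const std::unordered_set<int>& excluded_ids) {
  for (OpNode& op : *ops) {
    if (excluded_ids.count(op.id) == 0) op.use_quantizer = true;
  }
}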
@@ -25,8 +25,7 @@ namespace ir {
  */
 class CPUQuantizePlacementPass : public Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };

 }  // namespace ir
......
@@ -94,7 +94,7 @@ void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
   pass->Set("quantize_excluded_op_ids",
             new std::unordered_set<int>(quantize_excluded_op_ids));
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));

   unsigned use_quantizer_true_count = 0;
......
@@ -126,16 +126,13 @@ void CPUQuantizeSquashPass::Squash(
                   found_squash_count);
 }

-std::unique_ptr<ir::Graph> CPUQuantizeSquashPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("cpu_quantize_squash_pass", graph.get());
+void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("cpu_quantize_squash_pass", graph);
   std::unordered_map<const Node*, int> nodes_keep_counter;
-  FindNodesToKeep(graph.get(), &nodes_keep_counter);
-  Squash(graph.get(), &nodes_keep_counter);
-
-  return graph;
+  FindNodesToKeep(graph, &nodes_keep_counter);
+  Squash(graph, &nodes_keep_counter);
 }

 }  // namespace ir
......
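The squash pass is a two-phase mark-then-rewrite: FindNodesToKeep first records, per dequantize output, how many operators consume it, and Squash then removes only the quantize/dequantize pairs whose outputs nothing else needs. A simplified sketch of the counting phase, assuming a toy node type (the real pass keys the map on Paddle's Node*):

#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string op;
  std::vector<const Node*> consumers;  // ops this node's output feeds
};

// Phase 1: for each dequantize's output, count the operators it is an
// input to; Squash later skips any pair with consumers left to serve.
void FindNodesToKeep(
    const std::vector<Node>& nodes,
    std::unordered_map<const Node*, int>* nodes_keep_counter) {
  for (const Node& node : nodes) {
    if (node.op != "dequantize") continue;
    (*nodes_keep_counter)[&node] = static_cast<int>(node.consumers.size());
  }
}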
@@ -34,8 +34,7 @@ class CPUQuantizeSquashPass : public FusePassBase {
   virtual ~CPUQuantizeSquashPass() {}

  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;

   /*
    * For each dequantize's output find the number of operators it is an input to
......
@@ -125,7 +125,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();
......
@@ -25,10 +25,9 @@ namespace ir {
     auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
     PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);

-std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get());
+void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("depthwise_conv_mkldnn_pass", graph);

   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();

@@ -45,9 +44,8 @@ std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
     found_depthwise_conv_mkldnn_count++;
   };

-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_depthwise_conv_mkldnn_count);
-  return graph;
 }

 }  // namespace ir
......
@@ -25,8 +25,7 @@ class DepthwiseConvMKLDNNPass : public FusePassBase {
   virtual ~DepthwiseConvMKLDNNPass() {}

  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };

 }  // namespace ir
......
@@ -86,7 +86,7 @@ TEST(DepthwiseConvMKLDNNPass, basic) {
   counters before{1, 1, 1, 1};

-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));

   // initialize counters before loop
   counters after{0, 0, 0, 0};
......
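The depthwise tester's counters before{1, 1, 1, 1} / counters after{0, 0, 0, 0} pattern is the standard way these IR tests verify a rewrite: tally op-type occurrences before and after Apply, then assert the expected delta. A generic sketch of such a tally (field and op names are illustrative):

#include <map>
#include <string>
#include <vector>

struct Node { std::string op; };

// Count how many nodes of each op type the graph currently holds; a test
// takes one tally before the pass and one after, then compares the deltas
// (e.g. the depthwise_conv count drops to zero while conv2d rises by one).
std::map<std::string, int> CountOps(const std::vector<Node>& nodes) {
  std::map<std::string, int> counts;
  for (const Node& node : nodes) ++counts[node.op];
  return counts;
}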
@@ -14,13 +14,13 @@ limitations under the License. */

 #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
 #include <string>
+#include <unordered_set>

 namespace paddle {
 namespace framework {
 namespace ir {

-std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Applies MKL-DNN placement strategy.";
   const auto& op_types_list =
       Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");

@@ -37,7 +37,6 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
       }
     }
   }
-  return graph;
 }

 }  // namespace ir
......
@@ -26,8 +26,7 @@ namespace ir {
  */
 class MKLDNNPlacementPass : public Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };

 }  // namespace ir
......
@@ -97,7 +97,7 @@ void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
   pass->Set("mkldnn_enabled_op_types",
             new std::unordered_set<std::string>(mkldnn_enabled_op_types));
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));

   unsigned use_mkldnn_true_count = 0;
......
@@ -16,8 +16,9 @@

 #include <map>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>

 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_proto_maker.h"

@@ -68,8 +69,7 @@ VarDesc UpdateGradVarDesc(
   return *var_desc;
 }

-std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
-    std::unique_ptr<Graph> graph) const {
+void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
   int num_repeats = Get<const int>(kNumRepeats);
   std::vector<Node*> forward_backward_ops;
   std::vector<Node*> optimize_ops;

@@ -325,7 +325,6 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
   }

   result.ResolveHazard(created);
-  return graph;
 }

 }  // namespace ir
......
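BatchMergePass gets the same mechanical treatment but is the largest pass touched: it clones the forward/backward ops kNumRepeats times to emulate a bigger batch, and with the new contract the mutated graph simply is the caller's graph, so the return graph; after ResolveHazard is dropped. A grossly simplified sketch of the repetition idea (the real pass also rewires inputs/outputs, renames gradient variables per repeat, and handles optimizer ops separately):

#include <string>
#include <vector>

struct Op {
  std::string name;
  int repeat_id = 0;
};

// Emulate a larger batch by replaying the forward/backward ops
// num_repeats times; each clone is tagged with its repeat id, echoing how
// the real pass suffixes gradient variable names per repeat.
void RepeatForwardBackward(std::vector<Op>* ops, int num_repeats) {
  const std::vector<Op> original = *ops;  // snapshot before cloning
  for (int repeat = 1; repeat < num_repeats; ++repeat) {
    for (Op op : original) {
      op.repeat_id = repeat;
      op.name += "_repeat_" + std::to_string(repeat);
      ops->push_back(op);
    }
  }
}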
(Diffs for 25 more files are collapsed.)