Commit 2f9e5621 authored by zhoukunsheng

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into linspace

@@ -193,6 +193,12 @@ if(WITH_GPU)
   include(tensorrt)
   include(anakin_subgraph)
 endif()
+
+if(WITH_GPU AND NOT WIN32)
+  message(STATUS "add dgc lib.")
+  include(external/dgc)
+endif()
+
 if(WITH_MKL OR WITH_MKLML)
   include(external/anakin)
 elseif()
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc")
SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc")
SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE)
SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE)
INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})
ExternalProject_Add(
extern_dgc
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/PaddlePaddle/Fleet"
GIT_TAG "2d04dc3800cdd0601f1b65d547dabcc60b0cf9dc"
SOURCE_DIR "${DGC_SOURCES_DIR}"
CONFIGURE_COMMAND ""
BUILD_COMMAND cd collective && make -j
INSTALL_COMMAND mkdir -p ${DGC_INSTALL_DIR}/lib/ ${DGC_INCLUDE_DIR}/dgc
&& cp ${DGC_SOURCES_DIR}/collective/build/lib/libdgc.a ${DGC_LIBRARIES}
&& cp ${DGC_SOURCES_DIR}/collective/build/include/dgc.h ${DGC_INCLUDE_DIR}/dgc/
BUILD_IN_SOURCE 1
)
ADD_LIBRARY(dgc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
ADD_DEPENDENCIES(dgc extern_dgc)
LIST(APPEND external_project_dependencies dgc)
@@ -62,6 +62,11 @@ ExternalProject_Add(
     GIT_TAG         ${NGRAPH_GIT_TAG}
     PREFIX          ${NGRAPH_SOURCES_DIR}
     UPDATE_COMMAND  ""
+    CMAKE_GENERATOR ${CMAKE_GENERATOR}
+    CMAKE_GENERATOR_PLATFORM ${CMAKE_GENERATOR_PLATFORM}
+    CMAKE_GENERATOR_TOOLSET ${CMAKE_GENERATOR_TOOLSET}
+    CMAKE_ARGS      -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+    CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
     CMAKE_ARGS      -DCMAKE_INSTALL_PREFIX=${NGRAPH_INSTALL_DIR}
     CMAKE_ARGS      -DNGRAPH_UNIT_TEST_ENABLE=FALSE
     CMAKE_ARGS      -DNGRAPH_TOOLS_ENABLE=FALSE
......
@@ -131,6 +131,15 @@ elseif (NOT CBLAS_FOUND OR WIN32)
         )
 endif ()
+
+if (WITH_GPU AND NOT WIN32)
+    set(dgc_dir "${FLUID_INSTALL_DIR}/third_party/install/dgc")
+    copy(dgc_lib
+            SRCS ${DGC_INSTALL_DIR}/lib ${DGC_INSTALL_DIR}/include
+            DSTS ${dgc_dir} ${dgc_dir}
+            DEPS dgc)
+endif()
+
 if (WITH_MKLDNN)
     set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mkldnn")
     copy(mkldnn_lib
......
@@ -110,7 +110,7 @@ function(op_library TARGET)
     # Define operators that don't need pybind here.
     foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
             "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
-            "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op")
+            "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
         if ("${TARGET}" STREQUAL "${manual_pybind_op}")
             set(pybind_flag 1)
         endif()
......
@@ -211,7 +211,7 @@ paddle.fluid.layers.mean (ArgSpec(args=['x', 'name'], varargs=None, keywords=Non
 paddle.fluid.layers.mul (ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)), ('document', 'ccd37fa6b53f074adbfb732d738c4c2d'))
 paddle.fluid.layers.sigmoid_cross_entropy_with_logits (ArgSpec(args=['x', 'label', 'ignore_index', 'name', 'normalize'], varargs=None, keywords=None, defaults=(-100, None, False)), ('document', '180c284317ea45ef89a460d8d79c0b72'))
 paddle.fluid.layers.maxout (ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '891870d069a6aea746d34cc53b61690c'))
-paddle.fluid.layers.space_to_depth (ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '5f207ae10589ebe38a63575ef6ff8e1e'))
+paddle.fluid.layers.space_to_depth (ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'a9221eaef53884a00654e028551b78e2'))
 paddle.fluid.layers.affine_grid (ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '51def402b8910e163cbace9d0c0526ed'))
 paddle.fluid.layers.sequence_reverse (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '77a6d80aa5551ca70324fc975c44507f'))
 paddle.fluid.layers.affine_channel (ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name', 'act'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None, None)), ('document', 'ab84fdc6dc60f3ad9aa397e6007e3bf9'))
@@ -484,6 +484,11 @@ paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['sel
 paddle.fluid.optimizer.LarsMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
 paddle.fluid.optimizer.LarsMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'rampup_begin_step', 'rampup_step', 'sparsity', 'use_nesterov', 'local_grad_clip_norm', 'num_trainers', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1, [0.999], False, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
+paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
 paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '1a79bd7d10ae54ca763ec81bca36ba24'))
 paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......
@@ -23,7 +23,7 @@ endif()
 if(WITH_GPU)
     nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
-            dynload_cuda variable_visitor)
+            dynload_cuda variable_visitor dgc)
     nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
             dynload_cuda variable_visitor)
     if(WITH_DISTRIBUTE)
......
@@ -42,8 +42,7 @@ VarHandle* GetValidInput(const OpHandleBase* a) {
   return nullptr;
 }
 
-std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
   auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
 
   // get vars order
@@ -86,7 +85,8 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
     }
   }
 
-  VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl;
+  VLOG(10) << "dist_ops size:" << dist_ops.size()
+           << ", outputs size:" << vars.size() << ", ops size:" << ops.size();
 
   std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
                                                   OpHandleBase* op2) {
@@ -99,6 +99,10 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
     auto l_it = vars.find(i0->name());
     auto r_it = vars.find(i1->name());
 
+    PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(),
+                   "can't find var's name %s and %s in opdesc", i0->name(),
+                   i1->name());
+
     if (l_it->second < r_it->second) return true;
 
     if (l_it->second == r_it->second) {
@@ -126,8 +130,6 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
     VLOG(10) << "pre_op:" << pre_op->DebugString()
              << ", op:" << op->DebugString();
   }
-
-  return graph;
 }
 
 }  // namespace details
......
@@ -24,8 +24,7 @@ namespace details {
 // TODO(gongwb): overlap allreduce with backward computation.
 class AllReduceDepsPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace details
......
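The same ApplyImpl migration repeats in every pass file below: the old interface threaded std::unique_ptr<ir::Graph> through each pass, while the new one mutates a raw ir::Graph* in place and returns void. A self-contained sketch of the change, using stand-in types rather than the real Paddle classes:

#include <memory>
#include <set>

// Stand-in types, not the real Paddle classes: the old-style pass took and
// returned std::unique_ptr<ir::Graph>; the new-style pass mutates a raw
// ir::Graph* in place and returns void.
struct Graph {
  std::set<int> nodes;
};

struct Pass {
  virtual ~Pass() = default;
  // New-style entry point: the caller keeps ownership of the graph.
  virtual void ApplyImpl(Graph *graph) const = 0;
};

struct NoopPass : public Pass {
  void ApplyImpl(Graph *graph) const override {
    // A plain `return;` replaces the old `return std::move(graph);`.
    if (graph->nodes.empty()) return;
    graph->nodes.insert(0);
  }
};

int main() {
  auto graph = std::make_unique<Graph>();
  NoopPass{}.ApplyImpl(graph.get());  // ownership never leaves the caller
  return 0;
}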
@@ -16,6 +16,13 @@
 #include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/framework/operator.h"
+
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#include "dgc/dgc.h"
+#endif
+
+#include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/profiler.h"
 
 // asynchronous nccl allreduce or synchronous issue:
@@ -33,11 +40,14 @@ namespace details {
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                      const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places,
-                                     const platform::NCCLContextMap *ctxs)
+                                     const platform::NCCLContextMap *ctxs,
+                                     bool is_encoded, int nranks)
     : OpHandleBase(node),
       local_scopes_(local_scopes),
       places_(places),
-      nccl_ctxs_(ctxs) {
+      nccl_ctxs_(ctxs),
+      is_encoded_(is_encoded),
+      nranks_(nranks) {
   if (nccl_ctxs_) {
     for (auto &p : places_) {
       this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
@@ -51,7 +61,185 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
     : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
 #endif
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+void AllReduceOpHandle::RunImplEncoded() {
+  platform::RecordEvent record_event(Name());
+
+  WaitInputVarGenerated();
+
+  auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
+  auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
+  PADDLE_ENFORCE_EQ(
+      in_var_handles.size(), places_.size(),
+      "The NoDummyInputSize should be equal to the number of places.");
+  PADDLE_ENFORCE_EQ(
+      in_var_handles.size(), out_var_handles.size(),
+      "The NoDummyInputSize and NoDummyOutputSize should be equal.");
+
+  std::vector<const LoDTensor *> ins;
+  std::vector<LoDTensor *> outs;
+  int k = -1;
+  for (size_t i = 0; i < local_scopes_.size(); ++i) {
+    auto &local_scope =
+        local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
+    auto original_name =
+        paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
+    auto encode_var_name = original_name + g_dgc_encoded;
+    auto *in_var = local_scope->FindVar(encode_var_name);
+    PADDLE_ENFORCE_NOT_NULL(in_var);
+    auto &in = in_var->Get<LoDTensor>();
+    ins.emplace_back(&in);
+
+    auto *out = local_scope->FindVar(out_var_handles[i]->name())
+                    ->GetMutable<LoDTensor>();
+    outs.emplace_back(out);
+
+    if (k < 0) {
+      k = GetKValue(in_var_handles[i]->name());
+    }
+  }
+
+  PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
+  PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
+  PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
+
+  int dtype = -1;
+  size_t in_numel = 0;
+  size_t out_numel = 0;
+  PADDLE_ENFORCE(nranks_ > 1);
+  std::vector<std::function<void()>> all_reduce_calls;
+
+  for (size_t i = 0; i < local_scopes_.size(); ++i) {
+    auto &place = places_[i];
+    auto &in = *ins[i];
+    void *in_tensor_buf = const_cast<void *>(in.data<void>());
+
+    auto &out = *outs[i];
+    float *out_tensor_buf = out.data<float>();
+
+    dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
+    in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
+    PADDLE_ENFORCE(in_numel % 2 == 0);
+    PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
+    out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
+
+    int dev_id = boost::get<platform::CUDAPlace>(place).device;
+    auto &nccl_ctx = nccl_ctxs_->at(dev_id);
+    auto stream = nccl_ctx.stream();
+    auto comm = nccl_ctx.comm_;
+
+    auto &allocator =
+        platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
+    int encode_size = 2 * k * sizeof(int);
+    // DGC uses ncclAllGather to collect the encoded data from all ranks,
+    // so the buffer needs one slice per rank.
+    int buf_size = nranks_ * encode_size;
+    auto tmp_ious_data = allocator.Allocate(buf_size);
+    void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
+
+    VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
+             << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
+             << ", k:" << k << ", place:" << place << ", dtype:" << dtype;
+
+    all_reduce_calls.emplace_back([=] {
+      PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
+          in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
+          stream));
+    });
+  }
+
+  this->RunAndRecordEvent([&] {
+    if (all_reduce_calls.size() == 1UL) {
+      // Do not use NCCLGroup when manage NCCL by per thread per device
+      all_reduce_calls[0]();
+    } else {
+      platform::NCCLGroupGuard guard;
+      for (auto &call : all_reduce_calls) {
+        call();
+      }
+    }
+  });
+
+  if (FLAGS_sync_nccl_allreduce) {
+    for (auto &p : places_) {
+      int dev_id = boost::get<platform::CUDAPlace>(p).device;
+      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
+      auto stream = nccl_ctx.stream();
+      cudaError_t e_sync = cudaStreamSynchronize(stream);
+      if (e_sync != 0) {
+        LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
+      }
+
+      cudaError_t e_get = cudaGetLastError();
+      if (e_get != 0) {
+        LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
+                   << " errno:" << e_get;
+      }
+    }
+  }
+}
+
+int AllReduceOpHandle::GetKValue(const std::string &grad_name) {
+  auto original_name = paddle::framework::GradOriginalVarName(grad_name);
+  auto var_name = original_name + g_dgc_k;
+  PADDLE_ENFORCE(local_scopes_.size() > 0);
+
+  auto *scope = local_scopes_[0];
+  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  auto var = local_scope->FindVar(var_name);
+  PADDLE_ENFORCE_NOT_NULL(var);
+  auto tensor = var->Get<LoDTensor>().data<float>();
+  return *tensor;
+}
+#endif
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+bool AllReduceOpHandle::IsEncoded() {
+  if (!is_encoded_) {
+    return false;
+  }
+  auto counter_name = g_dgc_counter_name;
+  auto step_name = g_dgc_rampup_begin_step;
+  PADDLE_ENFORCE(local_scopes_.size() > 0);
+
+  auto *scope = local_scopes_[0];
+  auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  auto count_var = local_scope->FindVar(counter_name);
+  auto step_var = local_scope->FindVar(step_name);
+  if (count_var == nullptr || step_var == nullptr) {
+    PADDLE_THROW("can't find count_var:%s or step_var:%s", counter_name,
+                 step_name);
+  }
+
+  float count = *count_var->Get<LoDTensor>().data<float>();
+  float step = *step_var->Get<LoDTensor>().data<float>();
+  if (static_cast<int>(count) < static_cast<int>(step)) {
+    VLOG(10) << "in all_reduce current step:" << count
+             << " < rampup_begin_step:" << step
+             << ", so not using sparse all reduce";
+    return false;
+  }
+
+  return true;
+}
+#else
+bool AllReduceOpHandle::IsEncoded() { return false; }
+#endif
 void AllReduceOpHandle::RunImpl() {
+  if (!IsEncoded()) {
+    RunImplNormal();
+    return;
+  }
+
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+  RunImplEncoded();
+#else
+  PADDLE_THROW("Not compiled with CUDA");
+#endif
+}
+
+void AllReduceOpHandle::RunImplNormal() {
   platform::RecordEvent record_event(Name());
   WaitInputVarGenerated();
@@ -72,6 +260,8 @@ void AllReduceOpHandle::RunImpl() {
     auto &lod_tensor =
         local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
     lod_tensors.emplace_back(&lod_tensor);
+    VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
+             << ", out_name:" << out_var_handles[i]->name();
     PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
                       "The name of input and output should be equal.");
   }
@@ -99,13 +289,17 @@ void AllReduceOpHandle::RunImpl() {
       auto &nccl_ctx = nccl_ctxs_->at(dev_id);
       auto stream = nccl_ctx.stream();
       auto comm = nccl_ctx.comm_;
+
+      VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
+               << ", dev_id:" << dev_id << ", dtype:" << dtype
+               << ", place:" << p;
       all_reduce_calls.emplace_back([=] {
         PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             comm, stream));
       });
     }
 
     this->RunAndRecordEvent([&] {
       if (all_reduce_calls.size() == 1UL) {
         // Do not use NCCLGroup when manage NCCL by per thread per device
......
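The gather-buffer sizing in RunImplEncoded follows directly from the DGC encoding: each rank contributes 2*k ints (k indices plus k values), and ncclAllGather concatenates one such slice per rank. A runnable sketch of that arithmetic, with assumed values for k and nranks:

#include <cstdio>

int main() {
  const int k = 512;     // top-k elements kept per gradient (assumed value)
  const int nranks = 8;  // total GPU count across trainers (assumed value)

  // Mirrors RunImplEncoded above: one encoded slice is k (index, value)
  // pairs, and the allgather buffer holds one slice per rank.
  const int encode_size = 2 * k * static_cast<int>(sizeof(int));
  const int buf_size = nranks * encode_size;

  std::printf("encode_size = %d bytes, gather buffer = %d bytes\n",
              encode_size, buf_size);
  return 0;
}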
@@ -28,11 +28,19 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
+constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
+constexpr char g_dgc_encoded[] = "__dgc_encoded__";
+constexpr char g_dgc_k[] = "__dgc_k__";
+#endif
+
 struct AllReduceOpHandle : public OpHandleBase {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places,
-                    const platform::NCCLContextMap *ctxs);
+                    const platform::NCCLContextMap *ctxs,
+                    bool is_encoded = false, int nranks = -1);
 #else
   AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places);
@@ -50,8 +58,14 @@ struct AllReduceOpHandle : public OpHandleBase {
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+  void RunImplEncoded();
   const platform::NCCLContextMap *nccl_ctxs_;
+  bool is_encoded_{false};
+  int nranks_{-1};
+  int GetKValue(const std::string &grad_name);
 #endif
+
+  void RunImplNormal();
+  bool IsEncoded();
 };
 
 }  // namespace details
......
@@ -46,8 +46,7 @@ static framework::proto::VarType::Type kDefaultDtype =
 class AllocContinuousSpaceForGradPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
     ir::Graph &result = *graph;
 
     auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@@ -65,7 +64,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
     if (params_grads.size() == 0) {
       VLOG(10) << "Doesn't find gradients";
-      return std::move(graph);
+      return;
     }
 
     std::unordered_map<std::string, ir::Node *> vars;
@@ -124,8 +123,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
     InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
                                       fused_var_name, params_grads);
-
-    return std::move(graph);
   }
 
   template <typename AttrType>
......
@@ -204,13 +204,14 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
   return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
 }
 
-std::unique_ptr<ir::Graph> BuildStrategy::Apply(
-    std::unique_ptr<ir::Graph> graph,
-    const std::vector<platform::Place> &places,
-    const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
-    const size_t &nranks,
+ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
+                                const std::vector<platform::Place> &places,
+                                const std::string &loss_var_name,
+                                const std::vector<Scope *> &local_scopes,
+                                const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
+                                const bool use_cuda,
+                                platform::NCCLContextMap *nccl_ctxs) const {
 #else
-    const bool use_cuda) const {
+                                const bool use_cuda) const {
 #endif
@@ -265,7 +266,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
     }
   }
   VLOG(3) << "Start Apply Pass " << pass->Type();
-  graph = pass->Apply(std::move(graph));
+  graph = pass->Apply(graph);
   VLOG(3) << "Finish Apply Pass " << pass->Type();
 }
 return graph;
......
@@ -120,8 +120,7 @@ struct BuildStrategy {
   // Apply the passes built by the pass_builder_. The passes will be
   // applied to the Program and output an ir::Graph.
-  std::unique_ptr<ir::Graph> Apply(std::unique_ptr<ir::Graph> graph,
-                                   const std::vector<platform::Place> &places,
+  ir::Graph *Apply(ir::Graph *graph, const std::vector<platform::Place> &places,
                    const std::string &loss_var_name,
                    const std::vector<Scope *> &local_scopes,
                    const size_t &nranks,
......
@@ -170,12 +170,10 @@ static OpToVarNameSetMap ShrinkGCVars(
 class EagerDeletionPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;
 };
 
-std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
   auto &ref_cnts =
       Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
   PADDLE_ENFORCE(ref_cnts.empty(),
@@ -240,7 +238,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
   auto while_op_eager_deletion_pass =
       ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
-  return while_op_eager_deletion_pass->Apply(std::move(graph));
+  while_op_eager_deletion_pass->Apply(graph);
 }
 
 }  // namespace details
......
@@ -28,8 +28,7 @@ namespace details {
 class FuseAllReduceOpPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
     ir::Graph &result = *graph;
 
     auto &places = Get<const std::vector<platform::Place>>(kPlaces);
@@ -71,7 +70,7 @@ class FuseAllReduceOpPass : public ir::Pass {
     VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size();
     if (all_reduce_ops.size() == 0) {
-      return std::move(graph);
+      return;
     }
 
     PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(),
@@ -99,7 +98,6 @@ class FuseAllReduceOpPass : public ir::Pass {
                            group_all_reduce_ops, &result);
 #endif
     }
-    return std::move(graph);
   }
 
   void InsertFusedAllReduce(const std::vector<platform::Place> &places,
......
@@ -144,10 +144,9 @@ void InplacePass::InitSSAGraphNodes() const {
   }
 }
 
-std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void InplacePass::ApplyImpl(ir::Graph* graph) const {
   var_nodes_.clear();
-  view_.Build(graph.get());
+  view_.Build(graph);
   InitSSAGraphNodes();
 
   auto cnt = 0;
@@ -155,11 +154,9 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
     VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
     if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
       continue;
-    TryInplaceOpInputOutput(op, graph.get());
+    TryInplaceOpInputOutput(op, graph);
   }
   // graph->ResolveHazard(var_nodes_);
-
-  return graph;
 }
 
 void InplacePass::InplaceModifyDesc(const std::string& var,
......
@@ -69,8 +69,7 @@ class InplacePass : public ir::Pass {
   InplacePass();
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 
   void InitSSAGraphNodes() const;
......
@@ -44,8 +44,7 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
   auto nodes = graph->Nodes();
   CollectSkipVarsSet(nodes);
 
@@ -113,7 +112,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
         cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
         RenameVarInGraphDesc(var_name, cache_name, idx);
-        RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
+        RenameVarInGraphNode(var_name, cache_name, idx, graph);
         pool_.Erase(cache_name);
       }
     }
@@ -128,8 +127,6 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
     }
   }
   graph->ResolveHazard(var_nodes_);
-
-  return graph;
 }
 
 void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
......
@@ -21,6 +21,7 @@
 #include <set>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -35,8 +36,7 @@ namespace details {
 class MemoryOptimizePass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   // fill the variable map(var_nodes) by version.
   void InitSSAGraphNodes() const;
......
@@ -34,8 +34,7 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
   return true;
 }
 
-std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> ir_graph) const {
+void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
   auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
   OpGraphView graph_view(all_ops);
   for (auto &op : all_ops) {
@@ -49,7 +48,6 @@ std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
               << compute_op->DebugString();
     }
   }
-  return ir_graph;
 }
 
 }  // namespace details
......
@@ -23,8 +23,7 @@ namespace details {
 class ModifyOpLockAndRecordEventPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace details
......
@@ -23,10 +23,8 @@ namespace details {
 class SSAGraghBuilderWithChecker : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
-    PADDLE_ENFORCE(IsValidGraph(graph.get()));
-    return graph;
+  void ApplyImpl(ir::Graph *graph) const override {
+    PADDLE_ENFORCE(IsValidGraph(graph));
   }
 
   bool IsValidGraph(const ir::Graph *graph) const {
......
@@ -32,6 +32,7 @@
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/operators/math/math_function.h"
 
 namespace paddle {
 namespace framework {
@@ -152,8 +153,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 }
 
-std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
   Init();
   CheckGraph(*graph);
   std::vector<ir::Node *> sorted_ops = SortOperations(*graph);
@@ -209,7 +209,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
         for (size_t i = 0; i < backward_vars.size(); i += 2) {
           auto &p_name = backward_vars[i];
           auto &g_name = backward_vars[i + 1];
-          VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+          VLOG(10) << "Bcast " << g_name << " for parameter " << p_name
+                   << " op_type " << node->Op()->Type();
           if (NeedCollectiveForGrad(g_name, sorted_ops)) {
             InsertCollectiveOp(&result, p_name, g_name);
           }
@@ -234,7 +235,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
   AddOutputToLeafOps(&result);
 
   result.Erase(kGraphOps);
-  return graph;
 }
 
 void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
@@ -414,8 +414,9 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
   CreateOpHandleIOs(result, node, dev_id);
 }
 
-void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
-    ir::Graph *result, const std::string &og) const {
+void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
                                                     const std::string &og,
+                                                    bool is_encoded) const {
   OpHandleBase *op_handle = nullptr;
 
   auto append_allreduce_op = [&](
@@ -424,7 +425,9 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
         result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-        scopes, places, nccl_ctxs_));
+        scopes, places, nccl_ctxs_, is_encoded,
+        static_cast<int>(strategy_.trainers_endpoints_.size()) *
+            places_.size()));
 #else
     result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
         result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
@@ -446,12 +449,15 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad);
+    VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString();
 
     auto var =
         new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
                       vars.size(), i, og, places_[i]);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
+    VLOG(10) << "all_reduce_op_handle add output " << og
+             << ", handle:" << var->DebugString();
   }
 }
 
@@ -941,6 +947,17 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
   return op_dev_id;
 }
 
+bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
+  auto u_name = p_name + "__dgc_u__";
+  auto it = all_vars_.find(u_name);
+  if (it == all_vars_.end()) {
+    VLOG(10) << "can't find u_name, so it's not encoded:" << u_name;
+    return false;
+  }
+
+  return true;
+}
+
 void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
                                              const std::string &p_name,
                                              const std::string &g_name) const {
@@ -956,7 +973,11 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
         CreateReduceOp(result, g_name, 0);
         CreateBroadcastOp(result, g_name, 0);
       } else {
-        CreateAllReduceOp(result, g_name);
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+        CreateAllReduceOp(result, g_name, IsEncoded(p_name));
+#else
+        PADDLE_ENFORCE(false, "Compiled without CUDA!");
+#endif
       }
       break;
     default:
......
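DistSSAGraphBuilder::IsEncoded above relies purely on a naming convention: the DGC optimizer registers a helper variable named <param>__dgc_u__, and its presence marks the parameter's gradient as encoded. A self-contained sketch of that lookup (stand-in variable set, hypothetical parameter names):

#include <cstdio>
#include <string>
#include <unordered_set>

int main() {
  // Stand-in for all_vars_: the set of variable names known to the builder.
  std::unordered_set<std::string> all_vars = {"fc_0.w_0",
                                              "fc_0.w_0__dgc_u__"};

  auto is_encoded = [&](const std::string &p_name) {
    return all_vars.count(p_name + "__dgc_u__") > 0;
  };

  std::printf("fc_0.w_0 encoded: %d\n", is_encoded("fc_0.w_0"));  // 1
  std::printf("fc_1.w_0 encoded: %d\n", is_encoded("fc_1.w_0"));  // 0
  return 0;
}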
@@ -36,8 +36,7 @@ namespace details {
 class MultiDevSSAGraphBuilderBase : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;
 
   virtual void Init() const;
 
@@ -75,7 +74,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
   bool IsSparseGradient(const std::string &og) const;
 
-  void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
+  void CreateAllReduceOp(ir::Graph *result, const std::string &og,
+                         bool is_encoded = false) const;
 
   void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
                          size_t src_dev_id) const;
@@ -171,6 +171,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
   mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
   mutable bool need_broadcast_var_{false};
+
+  bool IsEncoded(const std::string &p_name) const;
 };
 
 std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
......
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
......
@@ -17,6 +17,7 @@
 #include <glog/logging.h>
 #include <fstream>
 #include <iosfwd>
+#include <memory>
 #include <ostream>
 #include <string>
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
@@ -40,13 +41,11 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
 class SSAGraghBuilderWithPrinter : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph* graph) const override {
    std::unique_ptr<std::ostream> fout(
        new std::ofstream(Get<std::string>(kGraphvizPath)));
    PADDLE_ENFORCE(fout->good());
    Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
-
-    return graph;
   }
 };
......
@@ -96,7 +96,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
   auto seq_allreduce_pass =
       ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
   for (size_t i = 0; i < graphs_.size(); ++i) {
-    graphs_[i] = seq_allreduce_pass->Apply(std::move(graphs_[i]));
+    graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
   }
 
   // set the correct size of thread pool to each device.
......
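Since Pass::Apply now takes and returns a raw ir::Graph* instead of passing std::unique_ptr through, callers that own the graph release it across the call and re-wrap the result, as in the hunk above. A minimal sketch of the idiom with a stand-in type:

#include <cstdio>
#include <memory>

struct Graph {  // stand-in for ir::Graph
  int version = 0;
};

// Stand-in for Pass::Apply: may return the same pointer or a replacement.
Graph *Apply(Graph *g) {
  g->version += 1;
  return g;
}

int main() {
  std::unique_ptr<Graph> graph(new Graph);
  // Same pattern as graphs_[i].reset(pass->Apply(graphs_[i].release())):
  // ownership is handed to Apply and immediately reclaimed.
  graph.reset(Apply(graph.release()));
  std::printf("version = %d\n", graph->version);
  return 0;
}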
@@ -266,8 +266,7 @@ static bool ShrinkNoNeedBufferVarOpDependency(
   }
 }
 
-std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
   auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
   auto &last_live_ops_of_vars =
       Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
@@ -335,14 +334,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
                        var_name);
         ref_cnts[i].emplace(var_name, result.size());
         last_live_ops_of_vars[i].emplace(var_name, std::move(result));
+        break;
       }
 
       // Seldomly, all preceding trying failed.
       // Just skip this corner case
     }
   }
-
-  return graph;
 }
 
 }  // namespace details
......
@@ -23,8 +23,7 @@ namespace details {
 class ReferenceCountPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace details
......
@@ -29,8 +29,7 @@ static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
          op1->Outputs() == op2->Outputs();
 }
 
-std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
   // FIXME(zjl): Insert dependencies between some distributed ops may cause
   // the multi_devices_graph_pass fails. So we skip these ops here.
   // Indeed, maybe we should not insert dependencies between these ops
@@ -98,7 +97,6 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
     VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
              << " and " << op_node_list[i]->Name();
   }
-  return graph;
 }
 
 }  // namespace details
......
@@ -23,8 +23,7 @@ namespace details {
 class SequentialExecutionPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace details
......
@@ -24,7 +24,8 @@ VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
 
 std::string VarHandle::DebugString() const {
   std::stringstream ss;
-  ss << name_ << ":" << place_;
+  ss << "name:" << name_ << ", place:" << place_ << ", version:" << version_
+     << ", scope_idx:" << scope_idx_;
   return ss.str();
 }
......
@@ -23,8 +23,7 @@ namespace details {
 class WhileOpEagerDeletionPass : public ir::Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override {
+  void ApplyImpl(ir::Graph *graph) const override {
     auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
 
     // Find all while_op and while_grad_op
@@ -50,7 +49,6 @@ class WhileOpEagerDeletionPass : public ir::Pass {
       operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
           while_ops, while_grad_ops);
     }
-    return graph;
   }
 };
......
@@ -29,10 +29,9 @@ namespace ir {
   GET_IR_NODE(elementwise_mul);       \
   GET_IR_NODE(elementwise_mul_out);
 
-std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
 
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
@@ -69,12 +68,11 @@ std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
     IR_NODE_LINK_TO(scale_op, elementwise_mul_out);  // Output
 
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(),
+    GraphSafeRemoveNodes(graph,
                          {fill_constant, fill_constant_out, elementwise_mul});
   };
 
-  gpd(graph.get(), handler);
-
-  return graph;
+  gpd(graph, handler);
 }
 
 }  // namespace ir
......
@@ -26,8 +26,7 @@ class AnakinFillconstantElementwisemulFuse : public FusePassBase {
   virtual ~AnakinFillconstantElementwisemulFuse() {}
 
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
......
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -253,8 +254,7 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
 
 // Parameters
 
-std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void AttentionLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
   PDPattern external_pattern, subblock_pattern;
 
   // Use the following variables to tell whether this model is RNN1.
@@ -269,12 +269,11 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
     }
   }
   if (count < specified_vars.size()) {
-    return graph;
+    return;
   }
 
   // Continue to fuse.
-  FindWhileOp(graph.get());
-  return graph;
+  FindWhileOp(graph);
 }
 
 }  // namespace ir
......
@@ -22,8 +22,7 @@ namespace ir {
 class AttentionLSTMFusePass : public FusePassBase {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 
 }  // namespace ir
......
@@ -77,10 +77,9 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
   weights_array_2d.colwise() *= scale_array;
 }
 
-std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
 
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -139,7 +138,7 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
     desc.SetAttr("axis", 1);
     auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.
 
-    GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel});
+    GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
 
     IR_NODE_LINK_TO(conv_out, eltwise_op);
     IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
@@ -147,16 +146,14 @@ std::unique_ptr<ir::Graph> ConvAffineChannelFusePass::ApplyImpl(
     found_conv_ac_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
 
   AddStatis(found_conv_ac_count);
-  return graph;
 }
 
-std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
 
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -199,7 +196,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
     eltwise->Op()->SetAttr("axis", 1);
     eltwise->Op()->SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
 
-    GraphSafeRemoveNodes(graph.get(),
+    GraphSafeRemoveNodes(graph,
                          {ac_scale, ac_bias, affine_channel, eltwise_out});
 
     IR_NODE_LINK_TO(eltwise, ac_out);
@@ -207,9 +204,8 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddAffineChannelFusePass::ApplyImpl(
     found_conv_ac_count++;
   };
 
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_ac_count);
-  return graph;
 }
 
 }  // namespace ir
......
...@@ -31,8 +31,7 @@ class ConvAffineChannelFusePass : public FusePassBase { ...@@ -31,8 +31,7 @@ class ConvAffineChannelFusePass : public FusePassBase {
virtual ~ConvAffineChannelFusePass() {} virtual ~ConvAffineChannelFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( void ApplyImpl(ir::Graph*) const override;
std::unique_ptr<ir::Graph> graph) const override;
const std::string name_scope_{"conv_affine_channel_fuse"}; const std::string name_scope_{"conv_affine_channel_fuse"};
}; };
...@@ -41,8 +40,7 @@ class ConvEltwiseAddAffineChannelFusePass : public FusePassBase { ...@@ -41,8 +40,7 @@ class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
virtual ~ConvEltwiseAddAffineChannelFusePass() {} virtual ~ConvEltwiseAddAffineChannelFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( void ApplyImpl(ir::Graph*) const override;
std::unique_ptr<ir::Graph> graph) const override;
const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"}; const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"};
}; };
......
...@@ -101,10 +101,9 @@ void recompute_bias_and_weights(const Scope* scope, ...@@ -101,10 +101,9 @@ void recompute_bias_and_weights(const Scope* scope,
weights_array_2d.colwise() *= variance_array; weights_array_2d.colwise() *= variance_array;
} }
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl( void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
std::unique_ptr<ir::Graph> graph) const { PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(graph.get()); FusePassBase::Init(name_scope_, graph);
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE(scope);
...@@ -187,7 +186,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl( ...@@ -187,7 +186,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::vector<std::string>({bn_out->Name()})); std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes( GraphSafeRemoveNodes(
graph.get(), graph,
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance}); bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
...@@ -203,10 +202,9 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl( ...@@ -203,10 +202,9 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
desc.SetAttr("axis", 1); desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes( GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
graph.get(), batch_norm, bn_mean_out, bn_variance_out,
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, bn_saved_mean, bn_saved_variance});
bn_variance_out, bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv_out, eltwise_op); IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op); IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
...@@ -215,16 +213,14 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl( ...@@ -215,16 +213,14 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
} }
}; };
gpd(graph.get(), handler); gpd(graph, handler);
AddStatis(found_conv_bn_count); AddStatis(found_conv_bn_count);
return graph;
} }
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl( void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
std::unique_ptr<ir::Graph> graph) const { PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(graph.get()); FusePassBase::Init(name_scope_, graph);
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE(scope);
...@@ -274,7 +270,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl( ...@@ -274,7 +270,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()})); eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes( GraphSafeRemoveNodes(
graph.get(), graph,
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out}); bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
...@@ -283,10 +279,9 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl( ...@@ -283,10 +279,9 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
found_conv_bn_count++; found_conv_bn_count++;
}; };
gpd(graph.get(), handler); gpd(graph, handler);
AddStatis(found_conv_bn_count); AddStatis(found_conv_bn_count);
return graph;
} }
} // namespace ir } // namespace ir
......
...@@ -31,8 +31,7 @@ class ConvBNFusePass : public FusePassBase { ...@@ -31,8 +31,7 @@ class ConvBNFusePass : public FusePassBase {
virtual ~ConvBNFusePass() {} virtual ~ConvBNFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( void ApplyImpl(ir::Graph* graph) const override;
std::unique_ptr<ir::Graph> graph) const override;
const std::string name_scope_{"conv_bn_fuse"}; const std::string name_scope_{"conv_bn_fuse"};
}; };
...@@ -41,8 +40,7 @@ class ConvEltwiseAddBNFusePass : public FusePassBase { ...@@ -41,8 +40,7 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
virtual ~ConvEltwiseAddBNFusePass() {} virtual ~ConvEltwiseAddBNFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( void ApplyImpl(ir::Graph* graph) const override;
std::unique_ptr<ir::Graph> graph) const override;
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"}; const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
}; };
......
...@@ -50,10 +50,9 @@ framework::proto::OpDesc PrepareOpDesc( ...@@ -50,10 +50,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto(); return *desc.Proto();
} }
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl( void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse"; const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get()); FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input( auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
...@@ -95,7 +94,6 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl( ...@@ -95,7 +94,6 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
elementwise_add_out}); elementwise_add_out});
}; };
gpd(graph.get(), handler); gpd(graph.get(), handler);
return graph;
} }
} // namespace ir } // namespace ir
......
...@@ -51,10 +51,9 @@ framework::proto::OpDesc PrepareOpDesc( ...@@ -51,10 +51,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto(); return *desc.Proto();
} }
std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl( void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add2_act_fuse"; const std::string pattern_name = "conv_elementwise_add2_act_fuse";
FusePassBase::Init(pattern_name, graph.get()); FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input( auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
...@@ -92,12 +91,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl( ...@@ -92,12 +91,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
// Delete the unneeded nodes. // Delete the unneeded nodes.
GraphSafeRemoveNodes( GraphSafeRemoveNodes(
graph.get(), graph, {conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
{conv_op, conv_out, elementwise_add_op, elementwise_add_op_1,
elementwise_add_out, elementwise_add_out_1, act_op}); elementwise_add_out, elementwise_add_out_1, act_op});
}; };
gpd(graph.get(), handler); gpd(graph, handler);
return graph;
} }
} // namespace ir } // namespace ir
......
...@@ -25,8 +25,7 @@ class ConvElementwiseAdd2ActFusePass : public FusePassBase { ...@@ -25,8 +25,7 @@ class ConvElementwiseAdd2ActFusePass : public FusePassBase {
virtual ~ConvElementwiseAdd2ActFusePass() {} virtual ~ConvElementwiseAdd2ActFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( void ApplyImpl(ir::Graph* graph) const override;
std::unique_ptr<ir::Graph> graph) const override;
}; };
} // namespace ir } // namespace ir
......
...@@ -48,10 +48,9 @@ framework::proto::OpDesc PrepareOpDesc( ...@@ -48,10 +48,9 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto(); return *desc.Proto();
} }
std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl( void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_act_fuse"; const std::string pattern_name = "conv_elementwise_add_act_fuse";
FusePassBase::Init(pattern_name, graph.get()); FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern() auto* x = gpd.mutable_pattern()
...@@ -88,12 +87,11 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl( ...@@ -88,12 +87,11 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
IR_NODE_LINK_TO(new_conv_op, act_out); // Output IR_NODE_LINK_TO(new_conv_op, act_out); // Output
// Delete the unneeded nodes. // Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op, GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op,
elementwise_add_out, act_op}); elementwise_add_out, act_op});
}; };
gpd(graph.get(), handler); gpd(graph, handler);
return graph;
} }
} // namespace ir } // namespace ir
......
...@@ -25,8 +25,7 @@ class ConvElementwiseAddActFusePass : public FusePassBase { ...@@ -25,8 +25,7 @@ class ConvElementwiseAddActFusePass : public FusePassBase {
virtual ~ConvElementwiseAddActFusePass() {} virtual ~ConvElementwiseAddActFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( void ApplyImpl(ir::Graph* graph) const override;
std::unique_ptr<ir::Graph> graph) const override;
}; };
} // namespace ir } // namespace ir
......
...@@ -30,10 +30,9 @@ namespace ir { ...@@ -30,10 +30,9 @@ namespace ir {
GET_IR_NODE(elementwise_add_in_y); \ GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out); GET_IR_NODE(elementwise_add_out);
std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl( void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
std::unique_ptr<ir::Graph> graph) const {
const std::string pattern_name = "conv_elementwise_add_fuse"; const std::string pattern_name = "conv_elementwise_add_fuse";
FusePassBase::Init(pattern_name, graph.get()); FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern() auto* x = gpd.mutable_pattern()
...@@ -76,11 +75,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl( ...@@ -76,11 +75,10 @@ std::unique_ptr<ir::Graph> ConvElementwiseAddFusePass::ApplyImpl(
IR_NODE_LINK_TO(new_conv_op, elementwise_add_out); // Output IR_NODE_LINK_TO(new_conv_op, elementwise_add_out); // Output
// Delete the unneeded nodes. // Delete the unneeded nodes.
GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op}); GraphSafeRemoveNodes(graph, {conv_op, conv_out, elementwise_add_op});
}; };
gpd(graph.get(), handler); gpd(graph, handler);
return graph;
} }
} // namespace ir } // namespace ir
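All of the fuse passes above share one control flow: build a GraphPatternDetector, register a handler lambda that rewrites a single matched subgraph and deletes the dead nodes, then run `gpd(graph, handler)` — now with a plain pointer instead of `graph.get()`. Below is a self-contained mock of that idiom; the detector and remove function are simplified stand-ins, not the real APIs from graph_pattern_detector.h:

#include <functional>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Node { std::string op_type; };
struct Graph { std::set<Node*> nodes; };

// Stand-in detector: collects "matches" first, then fires the handler for each,
// so the handler may safely erase nodes without invalidating this iteration.
struct MockPatternDetector {
  void operator()(Graph* graph,
                  const std::function<void(Node*, Graph*)>& handler) const {
    std::vector<Node*> matches;
    for (Node* n : graph->nodes)
      if (n->op_type == "conv2d") matches.push_back(n);
    for (Node* n : matches) handler(n, graph);
  }
};

void MockGraphSafeRemoveNodes(Graph* g, const std::set<Node*>& victims) {
  for (Node* n : victims) g->nodes.erase(n);
}

int main() {
  Node a{"conv2d"}, b{"relu"};
  Graph g{{&a, &b}};
  int found_count = 0;
  auto handler = [&](Node* matched, Graph* graph) {
    MockGraphSafeRemoveNodes(graph, {matched});  // rewrite the matched subgraph
    ++found_count;  // mirrors the found_*_count++ / AddStatis() pattern above
  };
  MockPatternDetector gpd;
  gpd(&g, handler);  // call sites now pass the raw graph pointer
  std::cout << found_count << " match, " << g.nodes.size() << " node left\n";
}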
......
@@ -25,8 +25,7 @@ class ConvElementwiseAddFusePass : public FusePassBase {
   virtual ~ConvElementwiseAddFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
......
@@ -15,6 +15,8 @@
 #include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
 #include <algorithm>
 #include <string>
+#include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -201,7 +203,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
       // Remove unneeded nodes.
       // TODO(jczaja): Proper removing of lookup table
       std::unordered_set<const Node*> marked_nodes(
-          //{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
+          // {lookup_table, mul, lstm, elementwise_add, fc_bias, W});
          {mul, lstm, elementwise_add, fc_bias});
       GraphSafeRemoveNodes(graph, marked_nodes);
     } else {
@@ -224,15 +226,13 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+void EmbeddingFCLSTMFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
......
@@ -32,8 +32,7 @@ class EmbeddingFCLSTMFusePass : public FusePassBase {
   virtual ~EmbeddingFCLSTMFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"embedding_fc_lstm_fuse"};
 };
......
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -22,10 +23,9 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("fc_fuse", graph.get());
+void FCFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("fc_fuse", graph);
   std::unordered_set<Node*> nodes2delete;
@@ -61,7 +61,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
     desc.SetType("fc");
     auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
-    GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
+    GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
     PADDLE_ENFORCE(subgraph.count(x));
     IR_NODE_LINK_TO(subgraph.at(x), fc_node);
@@ -72,10 +72,9 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     found_fc_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_fc_count);
-  return graph;
 }
 }  // namespace ir
......
@@ -31,8 +31,7 @@ class FCFusePass : public FusePassBase {
   virtual ~FCFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
......
@@ -73,7 +73,7 @@ TEST(FCFusePass, basic) {
   int pre_nodes = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int after_nodes = graph->Nodes().size();
......
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include "paddle/fluid/framework/lod_tensor.h"
 namespace paddle {
@@ -39,7 +40,6 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   // Create New OpDesc
   auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
                          Node* bias, Node* hidden, Node* fc_bias) {
     OpDesc op_desc;
     op_desc.SetType("fusion_gru");
@@ -155,26 +155,22 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> MulGRUFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 false /*with_fc_bias*/);
+void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
-std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
......
@@ -30,8 +30,7 @@ class FCGRUFusePass : public FusePassBase {
   virtual ~FCGRUFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_gru_fuse"};
 };
@@ -42,8 +41,7 @@ class MulGRUFusePass : public FusePassBase {
   virtual ~MulGRUFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_nobias_gru_fuse"};
 };
......
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include "paddle/fluid/framework/lod_tensor.h"
 namespace paddle {
@@ -157,26 +158,22 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 false /*with_fc_bias*/);
+void MulLstmFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
-std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
-                                 true /*with_fc_bias*/);
+void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count =
+      BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
......
@@ -32,8 +32,7 @@ class FCLstmFusePass : public FusePassBase {
   virtual ~FCLstmFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_lstm_fuse"};
 };
@@ -43,8 +42,7 @@ class MulLstmFusePass : public FusePassBase {
   virtual ~MulLstmFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"fc_nobias_lstm_fuse"};
 };
......
@@ -15,6 +15,8 @@
 #include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h"
 #include <algorithm>
 #include <string>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -23,29 +25,25 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
   std::unordered_set<std::string> act_types = {"relu", "scale"};
-  graph = FuseActElewiseAdd(std::move(graph), act_types);
-  graph = FuseElewiseAddAct(std::move(graph), act_types);
+  graph = FuseActElewiseAdd(graph, act_types);
+  graph = FuseElewiseAddAct(graph, act_types);
   // backward
   {
     std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
-    graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types);
+    graph = FuseElewiseAddActInplaceGrad(graph, in_place_act_types);
   }
   // Remove the removable intermediate_out.
-  RemoveIntermediateOut(graph.get());
-  return graph;
+  RemoveIntermediateOut(graph);
 }
 // ele_add(x, act(y))
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("elewise_add_act", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseElewiseAddAct(
+    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("elewise_add_act", graph);
   GraphPatternDetector gpd;
   auto *x = gpd.mutable_pattern()
@@ -86,18 +84,17 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
     found_elewise_add_act_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_elewise_add_act_count);
   return graph;
 }
 // act(ele_add(x,y))
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("act_elewise_add", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
+    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("act_elewise_add", graph);
   GraphPatternDetector gpd;
   auto *x = gpd.mutable_pattern()
@@ -137,7 +134,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
     found_elewise_add_act_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_elewise_add_act_count);
   return graph;
@@ -146,11 +143,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
 // the backward of act(ele_add(x,y))
 // act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
 // ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
-std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
-    std::unique_ptr<ir::Graph> graph,
-    const std::unordered_set<std::string> &act_types) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("elewise_add_act_grad", graph.get());
+ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("elewise_add_act_grad", graph);
   GraphPatternDetector gpd;
   auto *d_act_out = gpd.mutable_pattern()
@@ -217,7 +213,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
     found_elewise_add_act_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_elewise_add_act_count);
   return graph;
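Note the asymmetry this file introduces in FuseElewiseAddActPass: ApplyImpl itself becomes void, but the three private helpers keep returning the (now raw) ir::Graph* so the `graph = Helper(graph, ...)` chaining style survives without std::move. A minimal sketch of that shape, with mocked types and a hypothetical FuseStage standing in for the real helpers:

#include <iostream>
#include <string>
#include <unordered_set>

struct Graph { int fused = 0; };

// Stand-in for FuseElewiseAddAct / FuseActElewiseAdd / FuseElewiseAddActInplaceGrad:
// takes and returns the same raw pointer, so stages compose by plain assignment.
Graph* FuseStage(Graph* graph, const std::unordered_set<std::string>& types) {
  graph->fused += static_cast<int>(types.size());  // pretend each type fuses once
  return graph;
}

void ApplyImplSketch(Graph* graph) {
  std::unordered_set<std::string> act_types = {"relu", "scale"};
  graph = FuseStage(graph, act_types);      // forward fusions
  graph = FuseStage(graph, {"relu_grad"});  // backward, in-place grad fusion
}

int main() {
  Graph g;
  ApplyImplSketch(&g);
  std::cout << g.fused << std::endl;  // prints 3
}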
......
@@ -14,6 +14,8 @@
 #pragma once
 #include <string>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -32,20 +34,16 @@ class FuseElewiseAddActPass : public FusePassBase {
   virtual ~FuseElewiseAddActPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph *graph) const override;
-  std::unique_ptr<ir::Graph> FuseElewiseAddAct(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseElewiseAddAct(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
-  std::unique_ptr<ir::Graph> FuseActElewiseAdd(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseActElewiseAdd(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
-  std::unique_ptr<ir::Graph> FuseElewiseAddActInplaceGrad(
-      std::unique_ptr<ir::Graph> graph,
-      const std::unordered_set<std::string> &act_types) const;
+  ir::Graph *FuseElewiseAddActInplaceGrad(
+      ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
   /**
    * Remove the removable intermediate_out.
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
 #include <algorithm>
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -23,20 +24,18 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  graph = FuseReluDepthwiseConv(std::move(graph), true);
-  graph = FuseReluDepthwiseConv(std::move(graph), false);
-  return graph;
+void FuseReluDepthwiseConvPass::ApplyImpl(ir::Graph *graph) const {
+  graph = FuseReluDepthwiseConv(graph, true);
+  graph = FuseReluDepthwiseConv(graph, false);
 }
-std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
-    std::unique_ptr<ir::Graph> graph, bool only_forward) const {
-  PADDLE_ENFORCE(graph.get());
+ir::Graph *FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
+    ir::Graph *graph, bool only_forward) const {
+  PADDLE_ENFORCE(graph);
   if (only_forward)
-    FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
+    FusePassBase::Init("relu_depthwise_conv_only_forward", graph);
   else
-    FusePassBase::Init("relu_depthwise_conv", graph.get());
+    FusePassBase::Init("relu_depthwise_conv", graph);
   /*
            x ---act--> y ---layer-> z
           +----------+
@@ -144,10 +143,9 @@ std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
     }
     count++;
   };
-  gpd(graph.get(), handler);
-  GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
+  gpd(graph, handler);
+  GraphSafeRemoveNodes(graph, need_removed_nodes);
   AddStatis(count);
   return graph;
 }
......
@@ -32,10 +32,8 @@ class FuseReluDepthwiseConvPass : public FusePassBase {
   virtual ~FuseReluDepthwiseConvPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
-  std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
-      std::unique_ptr<ir::Graph> graph, bool only_forward) const;
+  void ApplyImpl(ir::Graph* graph) const override;
+  ir::Graph* FuseReluDepthwiseConv(ir::Graph* graph, bool only_forward) const;
 };
 }  // namespace ir
......
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include <map>
+#include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
@@ -26,8 +28,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
-    std::unique_ptr<Graph> graph) const {
+void GraphToProgramPass::ApplyImpl(ir::Graph* graph) const {
   // Remove the unneeded variables after memory optimization.
   std::unordered_set<std::string> vars2remove;
   if (graph->Has(kGraphToProgramVarsToRemove)) {
@@ -73,7 +74,6 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
   }
   program.CopyFrom(*program_pb);
-  return graph;
 }
 }  // namespace ir
......
@@ -26,7 +26,7 @@ const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
 class GraphToProgramPass : public Pass {
  protected:
-  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
......
@@ -14,7 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
+#include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -84,7 +86,7 @@ TEST(GraphToProgramPass, Basic) {
   ProgramDesc compiled_prog;
   pass->SetNotOwned<paddle::framework::ProgramDesc>("program", &compiled_prog);
-  pass->Apply(std::move(g));
+  pass->Apply(g.get());
   std::vector<OpDesc*> ops = compiled_prog.Block(0).AllOps();
   EXPECT_EQ(ops[0]->Type(), "op1");
   EXPECT_EQ(ops[1]->Type(), "op2");
......
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include <algorithm>
+#include <unordered_map>
 #include <unordered_set>
-#include "paddle/fluid/framework/ir/graph_viz_pass.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/inference/analysis/dot.h"
 #include "paddle/fluid/string/printf.h"
@@ -38,8 +38,7 @@ std::string FormatName(const Node* node) {
 }
 }  // namespace
-std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
   const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
   VLOG(3) << "draw IR graph viz to " << graph_viz_path;
   std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
@@ -82,7 +81,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
                   {Dot::Attr("style", "filled,rounded"), Dot::Attr("shape", "box"),
                    Dot::Attr("fillcolor", "yellow")});
-  auto marked_nodes = ConsumeMarkedNodes(graph.get());
+  auto marked_nodes = ConsumeMarkedNodes(graph);
   // Create nodes
   for (const Node* n : graph->Nodes()) {
     std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")";
@@ -115,8 +114,6 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
   }
   sout << dot.Build();
-  return graph;
 }
 GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
......
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <map>
 #include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
@@ -34,8 +35,7 @@ class GraphVizPass : public Pass {
   using marked_nodes_t = std::unordered_set<const Node*>;
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   // Tell whether there are any marked nodes in the graph. Consume the
   // corresponding attribute.
......
@@ -20,9 +20,8 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init("identity_scale_op_clean", graph.get());
+void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init("identity_scale_op_clean", graph);
   // pre_op -> scale_in -> scale_op -> scale_out
   // ->
@@ -72,8 +71,7 @@ std::unique_ptr<ir::Graph> IdentityScaleOpCleanPass::ApplyImpl(
     IR_NODE_LINK_TO(pre_op_var, scale_out_var);
   };
-  detector(graph.get(), handler);
-  return graph;
+  detector(graph, handler);
 }
 }  // namespace ir
......
@@ -22,8 +22,7 @@ namespace ir {
 class IdentityScaleOpCleanPass : public FusePassBase {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
  private:
   virtual ~IdentityScaleOpCleanPass() = default;
......
@@ -26,9 +26,9 @@ class InferCleanGraphPass : public FusePassBase {
   virtual ~InferCleanGraphPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
-    FusePassBase::Init("original_graph", graph.get());
-    PADDLE_ENFORCE(graph.get());
+  void ApplyImpl(ir::Graph* graph) const {
+    FusePassBase::Init("original_graph", graph);
+    PADDLE_ENFORCE(graph);
     auto is_valid_node = [](Node* x) {
       return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
@@ -46,11 +46,9 @@ class InferCleanGraphPass : public FusePassBase {
       }
     }
-    GraphSafeRemoveNodes(graph.get(), invalid_nodes);
+    GraphSafeRemoveNodes(graph, invalid_nodes);
     AddStatis(valid_op);
-    return graph;
   }
   void CleanEdges(std::vector<Node*>* nodes,
......
@@ -20,8 +20,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void IsTestPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
              "for activations and pooling.";
   auto op_list = {"pool2d", "sigmoid", "logsigmoid",
@@ -47,7 +46,6 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
       }
     }
   }
-  return graph;
 }
 }  // namespace ir
......
@@ -22,8 +22,7 @@ namespace ir {
 class IsTestPass : public Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
......
@@ -97,7 +97,7 @@ TEST(IsTestPass, basic) {
   auto pass = PassRegistry::Instance().Get("is_test_pass");
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   for (auto* node : graph->Nodes()) {
     if (node->IsOp()) {
......
@@ -32,9 +32,8 @@ const char kSumGradOpName[] = "sum";
 // other optimizers later.
 const char kOptimizerType[] = "sgd";
-std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
+void LockFreeOptimizePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
   // We could collect all weights' name from SGD, where
   // W1 <- SGD(W0, Grad0)
@@ -92,14 +91,14 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
           // find the forward op related to the backward op
           ir::Node* forward_op =
-              FindForwardOpViaBackwardOp(graph.get(), backward_op);
+              FindForwardOpViaBackwardOp(graph, backward_op);
           VLOG(3) << "Found forward_op " << forward_op->Name();
           PADDLE_ENFORCE(forward_op);
           Node* new_optimizer_node = CreateNewSGDNode(
-              graph.get(), forward_op, backward_op, node, opt_node);
+              graph, forward_op, backward_op, node, opt_node);
           PADDLE_ENFORCE(new_optimizer_node);
         }
@@ -140,8 +139,6 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
       }
     }
   }
-  return graph;
 }
 ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
......
@@ -60,8 +60,7 @@ class LockFreeOptimizePass : public Pass {
   virtual ~LockFreeOptimizePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
  private:
   // Create a new sgd node via current optimizer node
......
@@ -38,10 +38,9 @@ LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
   return vec_y;
 }
-std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
@@ -99,7 +98,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
       conv->Op()->SetOutput("Output",
                             std::vector<std::string>({eltwise_out->Name()}));
-      GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
+      GraphSafeRemoveNodes(graph, {eltwise, conv_out});
       IR_NODE_LINK_TO(conv, eltwise_out);
     } else {
@@ -123,14 +122,13 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
       IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
       IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
-      GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
+      GraphSafeRemoveNodes(graph, {conv, eltwise, conv_out});
     }
     found_conv_bias_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_bias_count);
-  return graph;
 }
 }  // namespace ir
 }  // namespace framework
......
@@ -29,8 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
   virtual bool is_conv3d() const { return false; }
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"conv_bias_mkldnn_fuse"};
 };
 /*
......
@@ -13,10 +13,10 @@
 // limitations under the License.
 #include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h"
+#include <gtest/gtest.h>
 #include "paddle/fluid/framework/naive_executor.h"
 #include "paddle/fluid/platform/place.h"
-#include <gtest/gtest.h>
 #include "paddle/fluid/framework/op_proto_maker.h"
 namespace paddle {
@@ -103,7 +103,7 @@ void MainTest(bool convWithExistingBias) {
   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();
......
@@ -16,8 +16,8 @@
 #include <functional>
 #include <list>
 #include <map>
+#include <memory>
 #include <tuple>
 #include "paddle/fluid/framework/ir/graph_traits.h"
 namespace paddle {
@@ -327,17 +327,15 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
       get_node_from_elementwise_add);
 }
-graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
+  FusePassBase::Init(name_scope_, graph);
   auto fused_graph_with_stats = FuseConvAsY(
       name_scope_,
-      FuseConvAsX(
-          name_scope_,
-          FuseProjectionConv(name_scope_, std::make_pair(graph.get(), 0))));
+      FuseConvAsX(name_scope_,
+                  FuseProjectionConv(name_scope_, std::make_pair(graph, 0))));
   std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
   AddStatis(fused_graph_with_stats.second);
-  return graph;
 }
 }  // namespace ir
 }  // namespace framework
......
@@ -14,6 +14,7 @@
 #pragma once
+#include <memory>
 #include <string>
 #include <tuple>
 #include <utility>
@@ -27,7 +28,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
-using graph_ptr = std::unique_ptr<ir::Graph>;
+using graph_ptr = ir::Graph*;
 using GraphWithStats = std::pair<ir::Graph*, int>;
 void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
@@ -124,7 +125,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
   virtual ~ResidualConnectionMKLDNNFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
+  void ApplyImpl(graph_ptr graph) const;
   const std::string name_scope_{"residual_connection_fuse_pass"};
 };
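The `using graph_ptr = ir::Graph*;` change above is the whole ownership story in one line: the alias used to own the graph, so every call through it transferred ownership; now it is a non-owning view and the caller's unique_ptr stays valid across the pass. A hypothetical side-by-side of the two conventions:

#include <memory>

namespace ir { struct Graph {}; }

// Old convention: the alias owned the graph, so ApplyImpl(graph_ptr) consumed it.
using owning_graph_ptr = std::unique_ptr<ir::Graph>;
// New convention: a borrowed view; ownership never leaves the caller.
using graph_ptr = ir::Graph*;

int main() {
  std::unique_ptr<ir::Graph> owner(new ir::Graph);
  graph_ptr view = owner.get();  // pass `view` to any number of passes
  (void)view;                    // `owner` still frees the graph at scope exit
  return 0;
}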
......
...@@ -148,7 +148,7 @@ void RunPassAndAssert(ProgramDesc* prog, const std::string& from, ...@@ -148,7 +148,7 @@ void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
auto pass = auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size(); int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph)); graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)(from, to)); EXPECT_TRUE(is_reachable(graph)(from, to));
...@@ -258,7 +258,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) { ...@@ -258,7 +258,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
auto pass = auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass"); PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size(); int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph)); graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "g")); EXPECT_TRUE(is_reachable(graph)("a", "g"));
......
@@ -21,10 +21,9 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
+void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("conv_relu_mkldnn_fuse", graph);
   GraphPatternDetector gpd;
   auto* conv_input = gpd.mutable_pattern()
@@ -56,7 +55,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
     OpDesc* desc = conv->Op();
     desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
     desc->SetAttr("fuse_relu", true);
-    GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
+    GraphSafeRemoveNodes(graph, {relu, conv_out});
     PADDLE_ENFORCE(subgraph.count(conv_input));
     IR_NODE_LINK_TO(conv, relu_out);
@@ -64,10 +63,9 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
     found_conv_relu_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_conv_relu_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -31,8 +31,7 @@ class ConvReLUFusePass : public FusePassBase {
   virtual ~ConvReLUFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -88,7 +88,7 @@ TEST(ConvReLUFusePass, basic) {
   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();
...
@@ -216,19 +216,16 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
   PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count);
 }
-std::unique_ptr<ir::Graph> CPUQuantizePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Quantizing the graph.";
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init(name_scope_, graph.get());
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
   PADDLE_ENFORCE(param_scope());
-  QuantizeConv(graph.get(), false /* with_residual_data */);
-  QuantizeConv(graph.get(), true /* with_residual_data */);
-  QuantizePool(graph.get());
-  return graph;
+  QuantizeConv(graph, false /* with_residual_data */);
+  QuantizeConv(graph, true /* with_residual_data */);
+  QuantizePool(graph);
 }
 }  // namespace ir
...
@@ -42,8 +42,7 @@ class CPUQuantizePass : public FusePassBase {
   virtual ~CPUQuantizePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   void QuantizeConv(Graph* graph, bool with_residual_data = false) const;
...
@@ -139,7 +139,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();
...
@@ -20,8 +20,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Marks operators which are to be quantized.";
   const auto& excluded_ids_list =
       Get<std::unordered_set<int>>("quantize_excluded_op_ids");
@@ -43,7 +42,6 @@ std::unique_ptr<ir::Graph> CPUQuantizePlacementPass::ApplyImpl(
       }
     }
   }
-  return graph;
 }
 }  // namespace ir
...
@@ -25,8 +25,7 @@ namespace ir {
  */
 class CPUQuantizePlacementPass : public Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -94,7 +94,7 @@ void MainTest(std::initializer_list<std::string> quantize_enabled_op_types,
   pass->Set("quantize_excluded_op_ids",
             new std::unordered_set<int>(quantize_excluded_op_ids));
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   unsigned use_quantizer_true_count = 0;
...
@@ -126,16 +126,13 @@ void CPUQuantizeSquashPass::Squash(
                   found_squash_count);
 }
-std::unique_ptr<ir::Graph> CPUQuantizeSquashPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("cpu_quantize_squash_pass", graph.get());
+void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("cpu_quantize_squash_pass", graph);
   std::unordered_map<const Node*, int> nodes_keep_counter;
-  FindNodesToKeep(graph.get(), &nodes_keep_counter);
-  Squash(graph.get(), &nodes_keep_counter);
-  return graph;
+  FindNodesToKeep(graph, &nodes_keep_counter);
+  Squash(graph, &nodes_keep_counter);
 }
 }  // namespace ir
...
@@ -34,8 +34,7 @@ class CPUQuantizeSquashPass : public FusePassBase {
   virtual ~CPUQuantizeSquashPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   /*
    * For each dequantize's output find the number of operators it is an input to
...
@@ -125,7 +125,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
   int original_nodes_num = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   int current_nodes_num = graph->Nodes().size();
...
@@ -25,10 +25,9 @@ namespace ir {
   auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
   PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
-std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  PADDLE_ENFORCE(graph.get());
-  FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get());
+void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("depthwise_conv_mkldnn_pass", graph);
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
@@ -45,9 +44,8 @@ std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
     found_depthwise_conv_mkldnn_count++;
   };
-  gpd(graph.get(), handler);
+  gpd(graph, handler);
   AddStatis(found_depthwise_conv_mkldnn_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -25,8 +25,7 @@ class DepthwiseConvMKLDNNPass : public FusePassBase {
   virtual ~DepthwiseConvMKLDNNPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -86,7 +86,7 @@ TEST(DepthwiseConvMKLDNNPass, basic) {
   counters before{1, 1, 1, 1};
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   // initialize counters before loop
   counters after{0, 0, 0, 0};
...
@@ -14,13 +14,13 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
 #include <string>
+#include <unordered_set>
 namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Applies MKL-DNN placement strategy.";
   const auto& op_types_list =
       Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
@@ -37,7 +37,6 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
       }
     }
   }
-  return graph;
 }
 }  // namespace ir
...
@@ -26,8 +26,7 @@ namespace ir {
  */
 class MKLDNNPlacementPass : public Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -97,7 +97,7 @@ void MainTest(std::initializer_list<std::string> mkldnn_enabled_op_types,
   pass->Set("mkldnn_enabled_op_types",
             new std::unordered_set<std::string>(mkldnn_enabled_op_types));
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   unsigned use_mkldnn_true_count = 0;
...
@@ -16,8 +16,9 @@
 #include <map>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
@@ -68,8 +69,7 @@ VarDesc UpdateGradVarDesc(
   return *var_desc;
 }
-std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
-    std::unique_ptr<Graph> graph) const {
+void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
   int num_repeats = Get<const int>(kNumRepeats);
   std::vector<Node*> forward_backward_ops;
   std::vector<Node*> optimize_ops;
@@ -325,7 +325,6 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
   }
   result.ResolveHazard(created);
-  return graph;
 }
 }  // namespace ir
...
@@ -36,7 +36,7 @@ class BatchMergePass : public Pass {
   virtual ~BatchMergePass() {}
  protected:
-  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
+  void ApplyImpl(Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -18,8 +18,8 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
-  PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
+Graph* Pass::Apply(Graph* graph) const {
+  PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
   for (const std::string& attr : required_pass_attrs_) {
     PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
                    "Required pass atrribute %s not set.", attr);
@@ -28,16 +28,16 @@ std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
     PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
                    attr);
   }
-  auto* native_graph = graph.get();
-  auto applied_graph = ApplyImpl(std::move(graph));
+  auto* native_graph = graph;
+  ApplyImpl(graph);
   // TODO(panyx0718): Add more verifications.
-  PADDLE_ENFORCE(!HasCircle(*applied_graph),
+  PADDLE_ENFORCE(!HasCircle(*graph),
                  "Illegal Pass. Generated graph shouldn't has cycle.");
-  PADDLE_ENFORCE(applied_graph.get() == native_graph,
+  PADDLE_ENFORCE(graph == native_graph,
                  "Pass::Apply() cannot delete the passed graph and shouldn't "
                  "return a new graph.(For the need of pybind11)");
   applied_ = true;
-  return applied_graph;
+  return graph;
 }
 PassRegistry& PassRegistry::Instance() {
...
@@ -16,8 +16,10 @@ limitations under the License. */
 #include <functional>
 #include <map>
+#include <memory>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -44,7 +46,7 @@ class Pass {
   std::string Type() const { return type_; }
-  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
+  Graph *Apply(Graph *graph) const;
   // Get a reference to the attributed previously set.
   template <typename AttrType>
@@ -98,9 +100,8 @@ class Pass {
   }
  protected:
-  virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
+  virtual void ApplyImpl(Graph *graph) const {
     LOG(FATAL) << "Calling virtual Pass not implemented.";
-    return graph;
   }
  private:
...
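Together, pass.h and pass.cc define the new contract: Apply() borrows the graph, checks the required attributes, runs ApplyImpl() in place, verifies that the same pointer comes back, and returns it. For call sites that still keep the graph in a std::unique_ptr (the testers in this commit, IRPassManager), the updated idiom briefly releases ownership and re-wraps the returned pointer; since Apply() is enforced to return exactly the pointer it received, the round-trip is safe. A minimal sketch ("some_pass" is a placeholder pass name):

std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = framework::ir::PassRegistry::Instance().Get("some_pass");
// release() hands the raw pointer to Apply(); reset() reclaims the same one.
graph.reset(pass->Apply(graph.release()));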
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/framework/ir/pass.h"
+#include <memory>
 #include <string>
+#include <utility>
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -39,7 +41,7 @@ void BuildCircleGraph(Graph* g) {
 class TestPass : public Pass {
  protected:
-  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
+  void ApplyImpl(ir::Graph* graph) const {
     graph->Set<int>("copy_test_pass_attr", new int);
     graph->Set<int>("copy_test_graph_attr", new int);
@@ -48,7 +50,6 @@ class TestPass : public Pass {
     int test_graph_attr = graph->Get<int>("test_graph_attr");
     graph->Get<int>("copy_test_graph_attr") = test_graph_attr + 1;
-    return graph;
   }
 };
@@ -58,7 +59,7 @@ TEST(PassTest, TestPassAttrCheck) {
   std::unique_ptr<Graph> graph(new Graph(prog));
   std::string exception;
   try {
-    graph = pass->Apply(std::move(graph));
+    graph.reset(pass->Apply(graph.release()));
   } catch (paddle::platform::EnforceNotMet e) {
     exception = std::string(e.what());
   }
@@ -69,7 +70,7 @@ TEST(PassTest, TestPassAttrCheck) {
   pass->SetNotOwned<int>("test_pass_attr", &val);
   try {
-    graph = pass->Apply(std::move(graph));
+    graph.reset(pass->Apply(graph.release()));
   } catch (paddle::platform::EnforceNotMet e) {
     exception = std::string(e.what());
   }
@@ -78,14 +79,14 @@ TEST(PassTest, TestPassAttrCheck) {
   graph.reset(new Graph(prog));
   graph->Set<int>("test_graph_attr", new int);
   graph->Get<int>("test_graph_attr") = 1;
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
   ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);
   // Allow apply more than once.
   graph.reset(new Graph(prog));
   graph->Set<int>("test_graph_attr", new int);
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   pass = PassRegistry::Instance().Get("test_pass");
   pass->SetNotOwned<int>("test_pass_attr", &val);
@@ -94,7 +95,7 @@ TEST(PassTest, TestPassAttrCheck) {
   graph->Set<int>("test_graph_attr", new int);
   graph->Get<int>("test_graph_attr") = 2;
   try {
-    auto tmp = pass->Apply(std::move(graph));
+    pass->Apply(graph.release());
   } catch (paddle::platform::EnforceNotMet e) {
     exception = std::string(e.what());
   }
...
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"
 #include <algorithm>  // for max
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -365,17 +366,14 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
   int fusion_count = 0;
   for (int i = MAX_NUM_FC; i > 1; --i) {
     fusion_count +=
-        BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
+        BuildFusion(graph, name_scope_ + "/" + std::to_string(i), i);
   }
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -31,8 +31,7 @@ class RepeatedFCReluFusePass : public FusePassBase {
   virtual ~RepeatedFCReluFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"repeated_fc_relu_fuse"};
 };
...
@@ -20,15 +20,13 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> RuntimeContextCachePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Applies Runtime Context Cache strategy.";
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp()) {
       n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
     }
   }
-  return graph;
 }
 }  // namespace ir
...
@@ -23,8 +23,7 @@ namespace ir {
 class RuntimeContextCachePass : public Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
 #include <set>
 #include <string>
+#include <unordered_set>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
-#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 namespace paddle {
@@ -178,9 +178,8 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
   return fc_out;
 }
-std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init("seq_concat_fc_fuse", graph.get());
+void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init("seq_concat_fc_fuse", graph);
   GraphPatternDetector detector;
   auto* pattern = detector.mutable_pattern();
   auto* concat_out = BuildSeqExpandConcatPattern(pattern);
@@ -194,7 +193,7 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
   int fuse_count{0};
-  detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
+  detector(graph, [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* graph) {
     VLOG(4) << "get one concat pattern";
     // fc
@@ -246,8 +245,6 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
   });
   AddStatis(fuse_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -27,8 +27,7 @@ class SeqConcatFcFusePass : public FusePassBase {
   virtual ~SeqConcatFcFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include "paddle/fluid/framework/lod_tensor.h"
 namespace paddle {
@@ -83,14 +84,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
+void SeqConvEltAddReluFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count = BuildFusion(graph, name_scope_, param_scope());
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -28,8 +28,7 @@ class SeqConvEltAddReluFusePass : public FusePassBase {
   virtual ~SeqConvEltAddReluFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
 };
...
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -194,17 +195,14 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> SeqPoolConcatFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
+void SeqPoolConcatFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
   int fusion_count = 0;
   for (int i = MAX_CONCAT_INPUTS; i > 0; --i) {
     fusion_count +=
-        BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
+        BuildFusion(graph, name_scope_ + "/" + std::to_string(i), i);
   }
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -42,8 +42,7 @@ class SeqPoolConcatFusePass : public FusePassBase {
   virtual ~SeqPoolConcatFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"seqpool_concat_fuse"};
 };
...
@@ -59,7 +59,7 @@ std::unique_ptr<ir::Graph> GetNumNodesOfBeforeAfter(
     const std::string& pass_type = "seqpool_concat_fuse_pass") {
   auto pass = PassRegistry::Instance().Get(pass_type);
   *before = graph->Nodes().size();
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   *after = graph->Nodes().size();
   return graph;
 }
...
@@ -24,11 +24,11 @@ namespace framework {
 namespace ir {
 template <int times>
-std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
+    ir::Graph *graph) const {
   const std::string pattern_name =
       "simplify_anakin_detection_pattern_pass" + std::to_string(times);
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
   GraphPatternDetector gpd;
   std::vector<PDNode *> input_nodes;
@@ -207,11 +207,10 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
     multiclass_nms_out->inputs.push_back(detection_out_op);
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(), delete_nodes);
+    GraphSafeRemoveNodes(graph, delete_nodes);
   };
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 template class SimplifyAnakinDetectionPatternPass<1>;
...
@@ -32,8 +32,7 @@ class SimplifyAnakinDetectionPatternPass : public FusePassBase {
   virtual ~SimplifyAnakinDetectionPatternPass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h"
 #include <string>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -362,13 +363,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   return fusion_count;
 }
-std::unique_ptr<ir::Graph> SquaredMatSubFusePass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(name_scope_, graph.get());
-  int fusion_count = BuildFusion(graph.get(), name_scope_);
+void SquaredMatSubFusePass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count = BuildFusion(graph, name_scope_);
   AddStatis(fusion_count);
-  return graph;
 }
 }  // namespace ir
...
@@ -31,8 +31,7 @@ class SquaredMatSubFusePass : public FusePassBase {
   virtual ~SquaredMatSubFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
   const std::string name_scope_{"squared_mat_sub_fuse"};
 };
...
@@ -21,8 +21,7 @@ namespace paddle {
 namespace framework {
 namespace ir {
-std::unique_ptr<ir::Graph> SyncBatchNormPass::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void SyncBatchNormPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Use synchronous batch norm";
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp()) {
@@ -35,7 +34,6 @@ std::unique_ptr<ir::Graph> SyncBatchNormPass::ApplyImpl(
       }
     }
   }
-  return graph;
 }
 }  // namespace ir
...
@@ -23,8 +23,7 @@ namespace ir {
 class SyncBatchNormPass : public Pass {
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -60,7 +60,7 @@ TEST(IsTestPass, basic) {
   auto pass = PassRegistry::Instance().Get("sync_batch_norm_pass");
-  graph = pass->Apply(std::move(graph));
+  graph.reset(pass->Apply(graph.release()));
   for (auto* node : graph->Nodes()) {
     if (node->IsOp()) {
...
@@ -26,11 +26,10 @@ namespace framework {
 namespace ir {
 template <int times>
-std::unique_ptr<ir::Graph> TransposeFlattenConcatFusePass<times>::ApplyImpl(
-    std::unique_ptr<ir::Graph> graph) const {
+void TransposeFlattenConcatFusePass<times>::ApplyImpl(ir::Graph *graph) const {
   const std::string pattern_name =
       "transpose_flatten" + std::to_string(times) + "_concat_fuse";
-  FusePassBase::Init(pattern_name, graph.get());
+  FusePassBase::Init(pattern_name, graph);
   GraphPatternDetector gpd;
   std::vector<PDNode *> input_nodes;
@@ -117,11 +116,10 @@ std::unique_ptr<ir::Graph> TransposeFlattenConcatFusePass<times>::ApplyImpl(
     concat_out->inputs.push_back(new_conv_op);
     // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph.get(), delete_nodes);
+    GraphSafeRemoveNodes(graph, delete_nodes);
   };
-  gpd(graph.get(), handler);
-  return graph;
+  gpd(graph, handler);
 }
 template class TransposeFlattenConcatFusePass<1>;
...
@@ -30,8 +30,7 @@ class TransposeFlattenConcatFusePass : public FusePassBase {
   virtual ~TransposeFlattenConcatFusePass() {}
  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(
-      std::unique_ptr<ir::Graph> graph) const override;
+  void ApplyImpl(ir::Graph* graph) const override;
 };
 }  // namespace ir
...
@@ -373,6 +373,11 @@ std::vector<std::string> OpDesc::AttrNames() const {
   return retv;
 }
+void OpDesc::RemoveAttr(const std::string &name) {
+  attrs_.erase(name);
+  need_update_ = true;
+}
 void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
   // NOTICE(minqiyang): pybind11 will take the empty list in python as
   // the std::vector<int> type in C++; so we have to change the attr's type
@@ -644,6 +649,7 @@ void OpDesc::CheckAttrs() {
     // not by users.
     return;
   }
+  VLOG(10) << "begin to check attribute of " << Type();
   checker->Check(&attrs_);
 }
...
@@ -72,6 +72,7 @@ class OpDesc {
   std::vector<std::string> AttrNames() const;
   void SetAttr(const std::string &name, const Attribute &v);
+  void RemoveAttr(const std::string &name);
   void SetBlockAttr(const std::string &name, BlockDesc *block);
...
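The new OpDesc::RemoveAttr mirrors SetAttr: it erases the entry and marks the descriptor dirty (need_update_) so the protobuf is re-serialized on the next sync. A minimal usage sketch; the op type and attribute name here are made up for illustration:

framework::OpDesc op;
op.SetType("scale");
op.SetAttr("scale", 2.0f);
op.RemoveAttr("scale");  // erases the attribute and sets need_update_
// AttrNames() no longer contains "scale"; a later Proto() call re-syncs.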
@@ -1110,8 +1110,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
         proto::VarType::Type tmp = t->type();
         PADDLE_ENFORCE(
             tmp == data_type || data_type == dafault_data_type,
-            "DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
-            Type(), DataTypeToString(data_type), DataTypeToString(tmp));
+            "DataType of Paddle Op %s %s must be the same. Get (%d) != (%d)",
+            Type(), input.first, DataTypeToString(data_type),
+            DataTypeToString(tmp));
         data_type = tmp;
       }
     }
...
@@ -365,9 +365,6 @@ class ExecutionContext {
     auto shared_allocation = std::shared_ptr<memory::allocation::Allocation>(
         allocation_ptr, deleter);
-    PADDLE_ENFORCE(
-        dynamic_cast<platform::TemporaryAllocation*>(allocation_ptr) != nullptr,
-        "The AllocationPtr must be TemporaryAllocation.");
     PADDLE_ENFORCE_GE(allocation_ptr->size(),
                       framework::product(dim) * sizeof(T));
...
@@ -77,8 +77,7 @@ class ParallelExecutorPrivate {
     }
   }
-  std::unique_ptr<ir::Graph> PrepareGCAndRefCnts(
-      std::unique_ptr<ir::Graph> graph, size_t max_memory_size);
+  ir::Graph *PrepareGCAndRefCnts(ir::Graph *graph, size_t max_memory_size);
   inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
@@ -118,8 +117,8 @@ class ParallelExecutorPrivate {
   details::GarbageCollectorMap gcs_;
 };
-std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
-    std::unique_ptr<ir::Graph> graph, size_t max_memory_size) {
+ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
+    ir::Graph *graph, size_t max_memory_size) {
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &place = places_[i];
     if (gcs_.count(place) > 0) {
@@ -161,7 +160,7 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
                               &global_ref_cnts_);
     ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
                               &last_live_ops_of_vars);
-    graph = ref_cnt_pass->Apply(std::move(graph));
+    graph = ref_cnt_pass->Apply(graph);
     VLOG(10) << "ReferenceCountPass Applied";
     auto eager_deletion_pass =
@@ -172,10 +171,9 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
     eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
                                      &last_live_ops_of_vars);
     eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
-    graph = eager_deletion_pass->Apply(std::move(graph));
+    graph = eager_deletion_pass->Apply(graph);
     VLOG(10) << "EagerDeletionPass Applied";
   }
   return graph;
 }
@@ -220,13 +218,11 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     }
   }
-  std::unique_ptr<ir::Graph> temp_owned_graph(graph);
   // FIXME(Yancey1989): parallel graph mode get better performance
   // in GPU allreduce distributed training. Need an elegant way to
   // choice the execution strategy.
-  build_strategy.enable_parallel_graph_ = EnableParallelGraphExecution(
-      *temp_owned_graph, exec_strategy, build_strategy);
+  build_strategy.enable_parallel_graph_ =
+      EnableParallelGraphExecution(*graph, exec_strategy, build_strategy);
   if (build_strategy.enable_parallel_graph_)
     VLOG(0) << "The Executor would execute the graph by ParallelGraph "
                "Execution which can get better performance,"
@@ -304,27 +300,21 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  temp_owned_graph = build_strategy.Apply(
-      std::move(temp_owned_graph), member_->places_, loss_var_name,
-      member_->local_scopes_, member_->nranks_, member_->use_cuda_,
-      member_->nccl_ctxs_.get());
+  graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
                               member_->local_scopes_, member_->nranks_,
                               member_->use_cuda_, member_->nccl_ctxs_.get());
 #else
-  temp_owned_graph = build_strategy.Apply(
-      std::move(temp_owned_graph), member_->places_, loss_var_name,
-      member_->local_scopes_, member_->nranks_, member_->use_cuda_);
+  graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
                               member_->local_scopes_, member_->nranks_,
                               member_->use_cuda_);
 #endif
   auto max_memory_size = GetEagerDeletionThreshold();
   VLOG(10) << "Eager Deletion Threshold "
            << static_cast<float>(max_memory_size) / (1 << 30);
   if (max_memory_size >= 0) {
-    graph = member_
-                ->PrepareGCAndRefCnts(std::move(temp_owned_graph),
-                                      static_cast<size_t>(max_memory_size))
-                .release();
-  } else {
-    graph = temp_owned_graph.release();
+    graph = member_->PrepareGCAndRefCnts(graph,
                                         static_cast<size_t>(max_memory_size));
   }
   // Step 3. Create vars in each scope. Passes may also create new vars.
...
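Since Pass::Apply now returns the same raw pointer it was given, passes chain without the temp_owned_graph ownership juggling this hunk deletes. Roughly (a sketch of the pattern, not a quote from this file):

ir::Graph *g = /* owned elsewhere, e.g. by the ParallelExecutor's caller */;
g = ref_cnt_pass->Apply(g);         // mutates in place, returns its argument
g = eager_deletion_pass->Apply(g);  // so assignments compose like a pipeline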
@@ -37,18 +37,29 @@ endif(WIN32)
 add_subdirectory(api)
+if(WITH_MKLDNN)
+  set(mkldnn_quantizer_src ${CMAKE_CURRENT_SOURCE_DIR}/api/mkldnn_quantizer.cc)
+  set(mkldnn_quantizer_cfg mkldnn_quantizer_config)
+endif()
 set(STATIC_INFERENCE_APIS paddle_fluid_api paddle_inference_api analysis_predictor)
 set(SHARED_INFERENCE_SRCS
     io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc
+    ${mkldnn_quantizer_src}
    ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc)
+# FIXME(gongwb): hidden libdgc.a
+if(WITH_GPU AND NOT WIN32)
+  set(fluid_modules ${fluid_modules} dgc)
+endif()
 if(WIN32)
   sep_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array
-              analysis_config paddle_pass_builder)
+              analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder)
 else(WIN32)
   cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS}
-             zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder)
+             zero_copy_tensor reset_tensor_array analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder)
 endif(WIN32)
 if(NOT APPLE)
@@ -61,11 +72,11 @@ endif()
 if(WIN32)
   sep_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
               DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array
-              analysis_config paddle_pass_builder)
+              analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder)
 else(WIN32)
   cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
              DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array
-             analysis_config paddle_pass_builder)
+             analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder)
 endif()
 get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
 target_link_libraries(paddle_fluid_shared ${os_dependency_modules})
...
@@ -140,7 +140,7 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
     if (pass->Type() != "graph_viz_pass") {
       PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
     }
-    graph = pass->Apply(std::move(graph));
+    graph.reset(pass->Apply(graph.release()));
   }
   return graph;
 }
@@ -156,7 +156,7 @@ framework::proto::ProgramDesc IRPassManager::AcquireProgram(
   desc.CopyFrom(*program->Proto());
   pass->SetNotOwned("program", &desc);
   auto *the_graph = graph->release();
-  *graph = pass->Apply(std::unique_ptr<Graph>(the_graph));
+  graph->reset(pass->Apply(the_graph));
   return *desc.Proto();
 }
...
@@ -35,8 +35,8 @@ namespace analysis {
 using framework::ir::Node;
-std::unique_ptr<framework::ir::Graph> analysis::AnakinSubgraphPass::ApplyImpl(
-    std::unique_ptr<framework::ir::Graph> graph) const {
+void analysis::AnakinSubgraphPass::ApplyImpl(
+    framework::ir::Graph *graph) const {
   framework::ir::FusePassBase::Init("anakin_subgraph_pass", graph.get());
   auto teller = [](const framework::ir::Node *node) {
@@ -72,8 +72,6 @@ std::unique_ptr<framework::ir::Graph> analysis::AnakinSubgraphPass::ApplyImpl(
   framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove);
   graph->Set(framework::ir::kRepetitiveParamAttr,
              new std::vector<std::string>(repetitive_params));
-  return graph;
 }
 std::string GenerateAnakinEngineKey(const std::set<std::string> &engine_inputs,
...
@@ -29,8 +29,7 @@ namespace analysis {
 class AnakinSubgraphPass : public framework::ir::FusePassBase {
  public:
-  std::unique_ptr<framework::ir::Graph> ApplyImpl(
-      std::unique_ptr<framework::ir::Graph> graph) const override;
+  void ApplyImpl(framework::ir::Graph *graph) const override;
  private:
   void CreateAnakinOp(framework::ir::Node *x, framework::ir::Graph *graph,
...
@@ -31,16 +31,16 @@ namespace analysis {
 using framework::ir::Node;
-std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
-    std::unique_ptr<framework::ir::Graph> graph) const {
-  framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph.get());
+void analysis::TensorRtSubgraphPass::ApplyImpl(
+    framework::ir::Graph *graph) const {
+  framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);
   auto teller = [](const framework::ir::Node *node) {
     if (!node->IsOp() || !node->Op()) return false;
     return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
   };
-  SubGraphFuser fuser(graph.get(), teller,
+  SubGraphFuser fuser(graph, teller,
                       Get<int>("min_subgraph_size") /*min subgraph size*/);
   fuser();
@@ -52,12 +52,11 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
   for (auto *node : graph->Nodes()) {
     if (node->IsOp() && !Agent(node).subgraph()->empty()) {
-      CreateTensorRTOp(node, graph.get(), graph_param_names,
-                       &repetitive_params);
+      CreateTensorRTOp(node, graph, graph_param_names, &repetitive_params);
       std::unordered_set<const Node *> nodes2remove(
           Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
-      framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove);
+      framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
     }
   }
@@ -67,11 +66,9 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
       nodes2remove.insert(node);
     }
   }
-  framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove);
+  framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
   graph->Set(framework::ir::kRepetitiveParamAttr,
              new std::vector<std::string>(repetitive_params));
-  return graph;
 }
 std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
...
@@ -28,8 +28,7 @@ namespace analysis {
 class TensorRtSubgraphPass : public framework::ir::FusePassBase {
  public:
-  std::unique_ptr<framework::ir::Graph> ApplyImpl(
-      std::unique_ptr<framework::ir::Graph> graph) const override;
+  void ApplyImpl(framework::ir::Graph *graph) const override;
  private:
   void CreateTensorRTOp(framework::ir::Node *x, framework::ir::Graph *graph,
...
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
+#include <memory>
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/program_desc.h"
@@ -37,8 +38,7 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
   framework::ProgramDesc desc;
   desc.CopyFrom(*argument->main_program().Proto());
   pass->SetNotOwned("program", &desc);
-  auto thegraph = pass->Apply(std::move(graph));
-  thegraph.release();  // the argument still own the graph.
+  pass->Apply(graph.release());  // the argument still own the graph.
   argument->SetIrAnalyzedProgram(
       new framework::proto::ProgramDesc(*desc.Proto()));
...
@@ -33,13 +33,19 @@ endif()
 add_subdirectory(details)
-cc_library(analysis_config SRCS analysis_config.cc DEPS lod_tensor paddle_pass_builder)
+if(WITH_MKLDNN)
+  set(mkldnn_quantizer_src mkldnn_quantizer.cc)
+  set(mkldnn_quantizer_cfg mkldnn_quantizer_config)
+  cc_library(${mkldnn_quantizer_cfg} SRCS mkldnn_quantizer_config.cc DEPS lod_tensor paddle_pass_builder)
+endif()
+cc_library(analysis_config SRCS analysis_config.cc DEPS ${mkldnn_quantizer_cfg} lod_tensor paddle_pass_builder)
 cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
-cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api zero_copy_tensor
+cc_library(analysis_predictor SRCS analysis_predictor.cc ${mkldnn_quantizer_src} DEPS paddle_inference_api zero_copy_tensor
            reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager ${inference_deps})
 cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS
     lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config
-    analysis_config paddle_pass_builder zero_copy_tensor
+    paddle_pass_builder zero_copy_tensor
     reset_tensor_array)
 cc_test(test_paddle_inference_api
...
@@ -108,6 +108,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   // MKLDNN related.
   CP_MEMBER(use_mkldnn_);
   CP_MEMBER(mkldnn_enabled_op_types_);
+  // Quantization related.
+  CP_MEMBER(use_mkldnn_quantizer_);
+  CP_MEMBER(mkldnn_quantizer_config_);

   CP_MEMBER(use_anakin_);
   CP_MEMBER(anakin_max_batchsize_);
@@ -148,6 +151,26 @@ void AnalysisConfig::EnableMKLDNN() {
   Update();
 }

+void AnalysisConfig::EnableMkldnnQuantizer() {
+#ifdef PADDLE_WITH_MKLDNN
+  if (!mkldnn_quantizer_config_)
+    mkldnn_quantizer_config_.reset(new MkldnnQuantizerConfig());
+  use_mkldnn_quantizer_ = true;
+#else
+  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
+  use_mkldnn_quantizer_ = false;
+#endif
+
+  Update();
+}
+
+std::shared_ptr<MkldnnQuantizerConfig> AnalysisConfig::mkldnn_quantizer_config()
+    const {
+  PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
+                          "MkldnnQuantizer was not enabled yet.");
+  return mkldnn_quantizer_config_;
+}
+
 void AnalysisConfig::EnableTensorRtEngine(
     int workspace_size, int max_batch_size, int min_subgraph_size,
     AnalysisConfig::Precision precision_mode, bool use_static) {
@@ -224,15 +247,27 @@ void AnalysisConfig::Update() {
 #endif
   }

-  if (enable_memory_optim_) {
-    auto analysis_passes = pass_builder()->AnalysisPasses();
-    auto memory_opti_pass_name = "memory_optimize_pass";
-    bool already_exists =
-        std::find(analysis_passes.begin(), analysis_passes.end(),
-                  memory_opti_pass_name) != analysis_passes.end();
-    if (!already_exists) {
-      pass_builder()->AppendAnalysisPass(memory_opti_pass_name);
-    }
+  // Quantization passes must come after all other optimization passes
+  if (use_mkldnn_quantizer_) {
+    if (!enable_ir_optim_) {
+      LOG(ERROR) << "EnableMkldnnQuantizer() only works when IR optimization "
+                    "is enabled.";
+    }
+#ifdef PADDLE_WITH_MKLDNN
+    pass_builder()->EnableMkldnnQuantizer();
+#else
+    LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
+    use_mkldnn_quantizer_ = false;
+#endif
+  }
+
+#ifdef PADDLE_WITH_MKLDNN
+  // Do not optimize before quantization
+  if (enable_memory_optim_ && !use_mkldnn_quantizer_) {
+#else
+  if (enable_memory_optim_) {
+#endif
+    pass_builder()->AppendAnalysisPass("memory_optimize_pass");
   }

   if (use_anakin_) {
@@ -277,6 +312,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   for (auto &item : mkldnn_enabled_op_types_) ss << item;
   ss << ";";

+  ss << use_mkldnn_quantizer_;
   ss << model_from_memory_;
   ss << enable_ir_optim_;
...
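A minimal usage sketch of the new quantizer switch added above (the model path is a placeholder; a build with PADDLE_WITH_MKLDNN is assumed):

  paddle::AnalysisConfig config("/path/to/model");
  config.EnableMkldnnQuantizer();  // Update() appends the quantization passes last
  // While the quantizer is active, memory_optimize_pass is skipped, so the
  // graph is not rewritten before scales are gathered.
  if (config.mkldnn_quantizer_enabled()) {
    auto qcfg = config.mkldnn_quantizer_config();  // created on demand above
  }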
@@ -18,6 +18,7 @@
 #include <fstream>
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
@@ -35,8 +36,13 @@
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/cpu_helper.h"
 #include "paddle/fluid/platform/gpu_info.h"
+#include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"

+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/inference/api/mkldnn_quantizer.h"
+#endif
+
 #if PADDLE_WITH_TENSORRT
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
@@ -341,10 +347,7 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
   return true;
 }

-// NOTE All the members in AnalysisConfig should be copied to Argument.
-void AnalysisPredictor::OptimizeInferenceProgram() {
-  status_program_optimized_ = true;
+void AnalysisPredictor::PrepareArgument() {
   argument_.SetUseGPU(config_.use_gpu());
   argument_.SetGPUDeviceId(config_.gpu_device_id());
   argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
@@ -390,6 +393,16 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
     argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
   }

+#ifdef PADDLE_WITH_MKLDNN
+  if (config_.mkldnn_quantizer_enabled()) {
+    LOG(INFO) << "Quantization is enabled";
+    argument_.SetQuantizeEnabledOpTypes(
+        config_.mkldnn_quantizer_config()->enabled_op_types());
+    argument_.SetQuantizeExcludedOpIds(
+        config_.mkldnn_quantizer_config()->excluded_op_ids());
+  }
+#endif
+
   auto passes = config_.pass_builder()->AllPasses();
   if (!config_.ir_optim()) {
     passes.clear();
@@ -398,6 +411,13 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
   argument_.SetIrAnalysisPasses(passes);
   argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses());
   argument_.SetScopeNotOwned(scope_.get());
+}
+
+// NOTE All the members in AnalysisConfig should be copied to Argument.
+void AnalysisPredictor::OptimizeInferenceProgram() {
+  status_program_optimized_ = true;
+
+  PrepareArgument();
   Analyzer().Run(&argument_);
   PADDLE_ENFORCE(argument_.scope_valid());
@@ -439,12 +459,31 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   }

   std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
-  if (!dynamic_cast<AnalysisPredictor *>(predictor.get())->Init(nullptr)) {
+  auto predictor_p = dynamic_cast<AnalysisPredictor *>(predictor.get());
+
+  if (!predictor_p->Init(nullptr)) {
+    return nullptr;
+  }
+
+  if (config.mkldnn_quantizer_enabled() && !predictor_p->MkldnnQuantize()) {
     return nullptr;
   }
+
   return predictor;
 }

+bool AnalysisPredictor::MkldnnQuantize() {
+#if PADDLE_WITH_MKLDNN
+  if (!mkldnn_quantizer_)
+    mkldnn_quantizer_ = new AnalysisPredictor::MkldnnQuantizer(
+        *this, config_.mkldnn_quantizer_config());
+  return mkldnn_quantizer_->Quantize();
+#else
+  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
+  return false;
+#endif
+}
+
 void AnalysisPredictor::PrepareFeedFetch() {
   PADDLE_ENFORCE_NOT_NULL(sub_scope_);
   CreateFeedFetchVar(sub_scope_);
@@ -703,6 +742,13 @@ AnalysisPredictor::~AnalysisPredictor() {
     scope_->DeleteScope(sub_scope_);
   }

+#if PADDLE_WITH_MKLDNN
+  if (mkldnn_quantizer_) {
+    delete mkldnn_quantizer_;
+    mkldnn_quantizer_ = nullptr;
+  }
+#endif
+
   // TODO(Superjomn) deduce the directory path.
   std::string out_path = inference::analysis::GetMemoryCachePath(
       config_.model_dir(), config_.prog_file());
...
@@ -70,6 +70,7 @@ class AnalysisPredictor : public PaddlePredictor {
   void CreateFeedFetchVar(framework::Scope *scope);
   void PrepareFeedFetch();

+  void PrepareArgument();
   void OptimizeInferenceProgram();

   Argument &analysis_argument() { return argument_; }
@@ -83,6 +84,8 @@ class AnalysisPredictor : public PaddlePredictor {
   std::string GetSerializedProgram() const override;

+  bool MkldnnQuantize();
+
  protected:
   // For memory optimization.
   bool need_collect_var_shapes_for_memory_optim();
@@ -143,6 +146,16 @@ class AnalysisPredictor : public PaddlePredictor {
   std::vector<framework::OpDesc *> fetches_;
   std::map<size_t, std::string> idx2fetches_;

+#if PADDLE_WITH_MKLDNN
+  // Helper class to perform quantization
+  class MkldnnQuantizer;
+  MkldnnQuantizer *mkldnn_quantizer_{nullptr};
+
+#if PADDLE_WITH_TESTING
+  friend class MkldnnQuantizerTest;
+#endif
+#endif
+
   // Memory buffer for feed inputs. The temporary LoDTensor will cause serious
   // concurrency problems, wrong results and memory leak, so cache them.
   std::vector<framework::LoDTensor> feed_tensors_;
...
@@ -17,9 +17,13 @@
 #include <gtest/gtest.h>
 #include <thread>  // NOLINT
 #include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include "paddle/fluid/inference/tests/api/tester_helper.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/inference/api/mkldnn_quantizer.h"
+#endif

 DEFINE_string(dirname, "", "dirname to tests.");
@@ -243,4 +247,241 @@ TEST(AnalysisPredictor, memory_optim) {
   inference::CompareResult(output, output1);
 }
#ifdef PADDLE_WITH_MKLDNN
class MkldnnQuantizerTest : public testing::Test {
public:
MkldnnQuantizerTest() {
AnalysisConfig config(FLAGS_dirname);
predictor.reset(new AnalysisPredictor(config));
auto* predictor_p = static_cast<AnalysisPredictor*>(predictor.get());
auto qconfig = std::make_shared<MkldnnQuantizerConfig>();
mkldnn_quantizer.reset(
new AnalysisPredictor::MkldnnQuantizer(*predictor_p, qconfig));
}
std::pair<std::vector<int>, float> Histogram(
const framework::LoDTensor& var_tensor, float min_val, float max_val,
int num_bins) const {
return mkldnn_quantizer->Histogram(var_tensor, min_val, max_val, num_bins);
}
std::pair<bool, framework::LoDTensor> GetMaxScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const {
return mkldnn_quantizer->GetMaxScalingFactor(var_tensor, is_unsigned);
}
std::pair<bool, framework::LoDTensor> GetMaxChScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const {
return mkldnn_quantizer->GetMaxChScalingFactor(var_tensor, is_unsigned);
}
std::pair<bool, framework::LoDTensor> GetKLScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const {
return mkldnn_quantizer->GetKLScalingFactor(var_tensor, is_unsigned);
}
protected:
std::unique_ptr<PaddlePredictor> predictor;
std::unique_ptr<AnalysisPredictor::MkldnnQuantizer> mkldnn_quantizer;
float abs_error = 1e-6;
static const std::array<float, 10> non_negative_values;
static const std::array<float, 10> positive_and_negative_values;
};
const std::array<float, 10> MkldnnQuantizerTest::non_negative_values = {
0.0158671, 0.026459, 0.0280772, 0.00962479, 0.0131628,
0.016704, 0.00118407, 0.00765726, 0.0123213, 0.00944741};
const std::array<float, 10> MkldnnQuantizerTest::positive_and_negative_values =
{-0.0482659, -0.0102493, -0.00794221, -0.00387115, -0.00674586,
-0.0495346, 0.0629528, -0.00531285, -0.0230353, 0.0269089};
TEST_F(MkldnnQuantizerTest, histogram_inverted_min_max) {
const auto& values = non_negative_values;
auto min_val = *std::min_element(values.begin(), values.end());
auto max_val = *std::max_element(values.begin(), values.end());
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(values.size()));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
ASSERT_THROW(Histogram(var_tensor, max_val, min_val, 3),
platform::EnforceNotMet);
}
TEST_F(MkldnnQuantizerTest, histogram_non_negative_to_3) {
// all non-negative values
const auto& values = non_negative_values;
auto min_val = *std::min_element(values.begin(), values.end());
auto max_val = *std::max_element(values.begin(), values.end());
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(values.size()));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
std::vector<int> histogram;
float bin_width;
std::tie(histogram, bin_width) = Histogram(var_tensor, min_val, max_val, 3);
ASSERT_NEAR(bin_width, std::abs(max_val - min_val) / 3.f, abs_error)
<< "Improperly calculated bin_width.";
ASSERT_EQ(histogram[0], 4);
ASSERT_EQ(histogram[1], 4);
ASSERT_EQ(histogram[2], 2);
}
TEST_F(MkldnnQuantizerTest, histogram_positive_and_negative_to_3) {
const auto& values = positive_and_negative_values;
auto min_val = *std::min_element(values.begin(), values.end());
auto max_val = *std::max_element(values.begin(), values.end());
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(values.size()));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
std::vector<int> histogram;
float bin_width;
std::tie(histogram, bin_width) = Histogram(var_tensor, min_val, max_val, 3);
ASSERT_NEAR(bin_width, std::abs(max_val - min_val) / 3.0f, abs_error)
<< "Improperly calculated bin_width.";
ASSERT_EQ(histogram[0], 3);
ASSERT_EQ(histogram[1], 5);
ASSERT_EQ(histogram[2], 2);
}
TEST_F(MkldnnQuantizerTest, histogram_zero_bins) {
const auto& values = non_negative_values;
auto min_val = *std::min_element(values.begin(), values.end());
auto max_val = *std::max_element(values.begin(), values.end());
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(values.size()));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
ASSERT_THROW(Histogram(var_tensor, min_val, max_val, 0),
platform::EnforceNotMet);
}
TEST_F(MkldnnQuantizerTest, histogram_empty) {
// empty tensor
ASSERT_THROW(Histogram({}, -1, 1, 1), platform::EnforceNotMet);
// zero tensor
framework::LoDTensor var_tensor;
var_tensor.Resize({0});
ASSERT_TRUE(var_tensor.mutable_data<double>(platform::CPUPlace()));
ASSERT_THROW(Histogram(var_tensor, -1, 1, 1), platform::EnforceNotMet);
}
TEST_F(MkldnnQuantizerTest, kl_scaling_factor_signed) {
const auto& values = positive_and_negative_values;
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(values.size()));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
bool is_unsigned;
framework::LoDTensor lod_tensor;
std::tie(is_unsigned, lod_tensor) = GetKLScalingFactor(var_tensor, false);
ASSERT_EQ(is_unsigned, false);
ASSERT_EQ(lod_tensor.numel(), 1);
ASSERT_NEAR(lod_tensor.data<double>()[0], 1.0 / 0.0899106152344, abs_error);
}
TEST_F(MkldnnQuantizerTest, max_scaling_factor_signed) {
const auto& values = positive_and_negative_values;
auto max_val = *std::max_element(values.begin(), values.end());
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(values.size()));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
bool is_unsigned;
framework::LoDTensor lod_tensor;
std::tie(is_unsigned, lod_tensor) = GetMaxScalingFactor(var_tensor, false);
ASSERT_EQ(is_unsigned, false);
ASSERT_EQ(lod_tensor.numel(), 1);
ASSERT_NEAR(lod_tensor.data<double>()[0], 1.0 / max_val, abs_error);
}
TEST_F(MkldnnQuantizerTest, max_scaling_factor_unsigned) {
const auto& values = non_negative_values;
auto max_val = *std::max_element(values.begin(), values.end());
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(values.size()));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
bool is_unsigned;
framework::LoDTensor lod_tensor;
std::tie(is_unsigned, lod_tensor) = GetMaxScalingFactor(var_tensor, true);
ASSERT_EQ(is_unsigned, true);
ASSERT_EQ(lod_tensor.numel(), 1);
ASSERT_NEAR(lod_tensor.data<double>()[0], 1.0 / max_val, abs_error);
}
TEST_F(MkldnnQuantizerTest, max_scaling_factor_chwise_unsigned) {
const auto& values = non_negative_values;
auto max_val = *std::max_element(values.begin(), values.end());
int channels = 3;
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(channels, 1, 1, values.size()));
for (int i = 0; i < channels; i++)
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()) +
i * values.size());
bool is_unsigned;
framework::LoDTensor lod_tensor;
std::tie(is_unsigned, lod_tensor) = GetMaxChScalingFactor(var_tensor, true);
ASSERT_EQ(is_unsigned, true);
ASSERT_EQ(lod_tensor.numel(), channels);
for (int i = 0; i < channels; i++) {
ASSERT_NEAR(lod_tensor.data<double>()[i], 1.0 / max_val, abs_error);
}
}
TEST_F(MkldnnQuantizerTest, kl_scaling_factor_unsigned) {
const auto& values = non_negative_values;
framework::LoDTensor var_tensor;
var_tensor.Resize(framework::make_dim(values.size()));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
bool is_unsigned;
framework::LoDTensor lod_tensor;
std::tie(is_unsigned, lod_tensor) = GetKLScalingFactor(var_tensor, true);
ASSERT_EQ(is_unsigned, true);
ASSERT_EQ(lod_tensor.numel(), 1);
ASSERT_NEAR(lod_tensor.data<double>()[0], 1.0 / 0.0252845321362, abs_error);
}
#endif
}  // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/mkldnn_quantizer.h"
#include <algorithm>
#include <map>
#include <numeric>
#include <unordered_map>
#include <utility>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
using platform::CPUPlace;
using framework::LoDTensor;
using framework::ir::Graph;
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
using string::PrettyLogH1;
bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
PrettyLogH1("--- Calculating scales for quantization");
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
std::map<std::string, std::map<std::string, LoDTensor>> gathered_data;
for (const auto* op : predictor_.inference_program_->Block(0).AllOps()) {
if (op->HasAttr("use_quantizer") &&
boost::get<bool>(op->GetAttr("use_quantizer"))) {
const VariableNameMap& connections_in = op->Inputs();
const VariableNameMap& connections_out = op->Outputs();
auto glambda = [&](const VariableNameMap& connections, bool is_output) {
for (auto const& conn : connections) {
if (conn.second.size() == 0) continue;
auto& var_name = conn.second[0];
// skip if scale already computed
if (scales_.find(var_name) != scales_.end()) return;
auto* var = predictor_.sub_scope_->FindVar(var_name);
PADDLE_ENFORCE(var, "%s is not in the scope", var_name);
PADDLE_ENFORCE(var->IsType<LoDTensor>(),
"Only support lod tensor now.");
LoDTensor* var_tensor = var->GetMutable<LoDTensor>();
// force unsigned type if already know it
bool is_unsigned = false;
if (is_output && op->Type() == "conv2d") {
// output of conv2d with relu must be unsigned
is_unsigned = op->HasAttr("fuse_relu") &&
boost::get<bool>(op->GetAttr("fuse_relu"));
} else if (is_output && op->Type() == "pool2d") {
// output of pool2d with unsigned input must be unsigned
auto input_var_name = op->Input("X")[0];
if (scales_.find(input_var_name) != scales_.end()) {
is_unsigned = scales_[input_var_name].first;
}
}
CalculateSingleScale(op->Type(), conn.first, var_name, *var_tensor,
is_unsigned);
}
};
// handle outputs first so unsigned outputs could be inferred
glambda(connections_out, true /* is_output */);
glambda(connections_in, false /* is_output */);
}
}
return true;
}
void AnalysisPredictor::MkldnnQuantizer::CalculateSingleScale(
const std::string& op_type_name, const std::string& conn_name,
const std::string& var_name, const LoDTensor& var_tensor,
bool is_unsigned) {
auto rule = qconfig_->scale_algo(op_type_name, conn_name);
if (rule == ScaleAlgo::NONE) return;
PADDLE_ENFORCE(
var_tensor.numel() > 0,
"MkldnnQuantizer: LoDTensor of variable %s for quantization of op "
"%s of connection %s should not be empty.",
var_name, op_type_name, conn_name);
switch (rule) {
case ScaleAlgo::MAX:
scales_[var_name] = GetMaxScalingFactor(var_tensor, is_unsigned);
break;
case ScaleAlgo::MAX_CH:
scales_[var_name] = GetMaxChScalingFactor(var_tensor, is_unsigned);
break;
case ScaleAlgo::KL:
scales_[var_name] = GetKLScalingFactor(var_tensor, is_unsigned);
break;
default:
throw std::runtime_error(
"MkldnnQuantizer: Unexpected ScaleAlgo specified.");
}
}
std::vector<int> AnalysisPredictor::MkldnnQuantizer::ExpandQuantizedBins(
std::vector<int> quantized_bins, std::vector<int> reference_bins) const {
std::vector<int> expanded_quantized_bins(reference_bins.size(), 0);
int num_merged_bins = reference_bins.size() / quantized_bins.size();
int j_start = 0;
int j_end = num_merged_bins;
for (size_t idx = 0; idx < quantized_bins.size(); idx++) {
int zero_count =
std::count(&reference_bins[j_start], &reference_bins[j_end], 0);
num_merged_bins = j_end - j_start;
int avg_bin_ele;
if (zero_count == num_merged_bins) {
avg_bin_ele = 0;
} else {
avg_bin_ele = quantized_bins[idx] / (num_merged_bins - zero_count + 0.0);
}
for (int idx1 = j_start; idx1 < j_end; idx1++) {
expanded_quantized_bins[idx1] =
(reference_bins[idx1] == 0) ? 0 : avg_bin_ele;
}
j_start += num_merged_bins;
j_end += num_merged_bins;
if ((idx + 1) == quantized_bins.size() - 1) {
j_end = reference_bins.size();
}
}
return expanded_quantized_bins;
}
std::pair<bool, LoDTensor>
AnalysisPredictor::MkldnnQuantizer::GetKLScalingFactor(
const LoDTensor& var_tensor, bool is_unsigned) const {
ConstEigenVectorArrayMap eigen_tensor{var_tensor.data<float>(),
var_tensor.numel(), 1};
int precision_hist_num_bins = 2048;
float max_val = eigen_tensor.maxCoeff();
float min_val = eigen_tensor.minCoeff();
bool is_positive = min_val >= 0.0f;
if (is_unsigned)
PADDLE_ENFORCE(
is_positive,
"Tensor is claimed to be unsigned, but its min value (%f) is < 0.0",
min_val);
int num_quantized_bins = 255;
std::vector<int> hist;
float bin_width;
int starting_iter;
int ending_iter = precision_hist_num_bins - 1;
if (is_positive) {
std::tie(hist, bin_width) =
Histogram(var_tensor, min_val, max_val, precision_hist_num_bins);
starting_iter = static_cast<int>(ending_iter * 0.7);
} else {
float th = std::max(std::abs(max_val), std::abs(min_val));
std::tie(hist, bin_width) =
Histogram(var_tensor, -th, th, precision_hist_num_bins);
starting_iter = 0;
if (std::abs(max_val) > std::abs(min_val)) {
while (starting_iter < ending_iter) {
if (hist[starting_iter] == 0) {
++starting_iter;
continue;
} else {
break;
}
}
starting_iter += static_cast<int>((ending_iter - starting_iter) * 0.6);
} else {
while (ending_iter > 0) {
if (hist[ending_iter] == 0) {
--ending_iter;
continue;
} else {
break;
}
}
starting_iter = static_cast<int>(0.6 * ending_iter);
}
}
auto P_sum = eigen_tensor.size();
int min_kl_divergence = 0;
int min_kl_index = 0;
bool kl_inited = false;
for (int i = starting_iter; i <= ending_iter; i++) {
std::vector<int> reference_distr_P(&hist[0], &hist[i]);
auto outliers_count =
std::accumulate(&hist[i], &hist[precision_hist_num_bins], 0);
if (reference_distr_P[i - 1] == 0) {
continue;
}
reference_distr_P[i - 1] += outliers_count;
auto reference_distr_bins = reference_distr_P;
std::vector<int> candidate_distr_Q(&hist[0], &hist[i]);
int num_merged_bins = i / num_quantized_bins;
std::vector<int> candidate_distr_Q_quantized(num_quantized_bins, 0);
int j_start = 0;
int j_end = num_merged_bins;
for (int idx = 0; idx < num_quantized_bins; idx++) {
candidate_distr_Q_quantized[idx] = std::accumulate(
&candidate_distr_Q[j_start], &candidate_distr_Q[j_end], 0);
j_start += num_merged_bins;
j_end += num_merged_bins;
if ((idx + 1) == num_quantized_bins - 1) {
j_end = i;
}
}
candidate_distr_Q =
ExpandQuantizedBins(candidate_distr_Q_quantized, reference_distr_bins);
int Q_sum =
std::accumulate(candidate_distr_Q.begin(), candidate_distr_Q.end(), 0);
auto kl_divergence =
SafeEntropy(reference_distr_P, P_sum, candidate_distr_Q, Q_sum);
if (!kl_inited) {
min_kl_divergence = kl_divergence;
min_kl_index = i;
kl_inited = true;
} else if (kl_divergence < min_kl_divergence) {
min_kl_divergence = kl_divergence;
min_kl_index = i;
} else {
}
}
if (min_kl_index == 0) {
while (starting_iter > 0) {
if (hist[starting_iter] == 0) {
starting_iter -= 1;
continue;
} else {
break;
}
}
min_kl_index = starting_iter;
}
LoDTensor scale_tensor;
scale_tensor.Resize({1});
auto* scale_ptr = scale_tensor.mutable_data<double>(CPUPlace());
scale_ptr[0] = 1.0 / ((min_kl_index + 0.5) * bin_width);
return std::make_pair(is_unsigned, scale_tensor);
}
std::pair<bool, LoDTensor>
AnalysisPredictor::MkldnnQuantizer::GetMaxScalingFactor(
const LoDTensor& var_tensor, bool is_unsigned) const {
ConstEigenVectorArrayMap eigen_tensor{var_tensor.data<float>(),
var_tensor.numel(), 1};
float max_abs = eigen_tensor.abs().maxCoeff();
float min_val = eigen_tensor.minCoeff();
if (is_unsigned)
PADDLE_ENFORCE(
min_val >= 0.0f,
"Tensor is claimed to be unsigned, but its min value (%f) is < 0.0",
min_val);
LoDTensor scale_tensor;
scale_tensor.Resize({1});
auto* scale_ptr = scale_tensor.mutable_data<double>(CPUPlace());
scale_ptr[0] = 1.0 / max_abs;
return std::make_pair(is_unsigned, scale_tensor);
}
std::pair<bool, LoDTensor>
AnalysisPredictor::MkldnnQuantizer::GetMaxChScalingFactor(
const LoDTensor& var_tensor, bool is_unsigned) const {
PADDLE_ENFORCE(var_tensor.dims().size() > 0, "Tensor dimension is empty.");
ConstEigenVectorArrayMap eigen_tensor{var_tensor.data<float>(),
var_tensor.numel(), 1};
float min_val = eigen_tensor.minCoeff();
if (is_unsigned)
PADDLE_ENFORCE(
min_val >= 0.0f,
"Tensor is claimed to be unsigned, but its min value (%f) is < 0.0",
min_val);
int channels = var_tensor.dims()[0];
LoDTensor scale_tensor;
scale_tensor.Resize({channels});
auto* scale_ptr = scale_tensor.mutable_data<double>(CPUPlace());
for (int i = 0; i < channels; ++i) {
const auto tensor = var_tensor.Slice(i, i + 1);
ConstEigenVectorArrayMap eigen_tensor{tensor.data<float>(), tensor.numel(),
1};
float max_abs = eigen_tensor.abs().maxCoeff();
scale_ptr[i] = 1.0 / max_abs;
}
return std::make_pair(is_unsigned, scale_tensor);
}
std::pair<std::vector<int>, float>
AnalysisPredictor::MkldnnQuantizer::Histogram(
const framework::LoDTensor& var_tensor, float min_val, float max_val,
size_t num_bins) const {
PADDLE_ENFORCE_GT(num_bins, 0,
"MkldnnQuantizer: To calculate Histogram, num_bins (" +
std::to_string(num_bins) + ") must be positive.");
PADDLE_ENFORCE_GT(
var_tensor.numel(), 0,
"MkldnnQuantizer: To calculate Histogram, the tensor must not be empty.");
PADDLE_ENFORCE(max_val >= min_val,
"MkldnnQuantizer: To calculate Histogram, max_val (" +
std::to_string(max_val) +
") must be greater or equal"
"to min_val (" +
std::to_string(min_val) + ").");
ConstEigenVectorArrayMap eigen_tensor{var_tensor.data<float>(),
var_tensor.numel(), 1};
auto bin_width = std::abs(max_val - min_val) / num_bins;
std::vector<int> hist(num_bins);
for (int i = 0; i < eigen_tensor.size(); i++) {
int bin = std::min(
num_bins - 1,
static_cast<size_t>(floor((eigen_tensor[i] - min_val) / bin_width)));
++hist[bin];
}
return std::make_pair(std::move(hist), std::move(bin_width));
}
void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
auto& arg = predictor_.argument_;
if (!arg.scope_valid()) arg.SetScope(new framework::Scope);
arg.SetMainProgramNotOwned(predictor_.inference_program_.get());
auto graph = std::unique_ptr<Graph>(new Graph(arg.main_program()));
arg.SetMainGraph(graph.release());
arg.main_graph().Set(framework::ir::kParamScopeAttr,
new framework::Scope*(arg.scope_ptr()));
auto* builder = predictor_.config_.pass_builder();
builder->SetPasses({
"infer_clean_graph_pass", "cpu_quantize_pass", "cpu_quantize_squash_pass",
});
if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
auto passes = builder->AllPasses();
predictor_.argument_.SetIrAnalysisPasses(passes);
predictor_.argument_.SetAnalysisPasses(
{"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
predictor_.argument_.SetQuantVarScales(scales_);
}
bool AnalysisPredictor::MkldnnQuantizer::Quantize() {
if (!RunWarmup()) return false;
if (!CalculateScales()) return false;
predictor_.PrepareScope(predictor_.scope_);
predictor_.CreateExecutor();
if (!RunQuantizePasses()) return false;
predictor_.PrepareExecutor();
predictor_.PrepareFeedFetch();
return true;
}
bool AnalysisPredictor::MkldnnQuantizer::RunQuantizePasses() const {
predictor_.executor_->CreateVariables(*predictor_.inference_program_, 0, true,
predictor_.sub_scope_);
PrepareArgument();
auto& arg = predictor_.argument_;
Analyzer().Run(&arg);
PADDLE_ENFORCE(arg.scope_valid());
VLOG(5) << "to prepare executor";
ARGUMENT_CHECK_FIELD((&arg), ir_analyzed_program);
predictor_.inference_program_.reset(
new framework::ProgramDesc(arg.ir_analyzed_program()));
LOG(INFO) << "== optimize 2 end ==";
predictor_.executor_->CreateVariables(*predictor_.inference_program_, 0,
false, predictor_.sub_scope_);
return true;
}
bool AnalysisPredictor::MkldnnQuantizer::RunWarmup() const {
VLOG(3) << "Predictor: run a quantization warmup iteration";
auto warmup_data = qconfig_->warmup_data();
PADDLE_ENFORCE_NOT_NULL(warmup_data,
"Warmup data cannot be NULL in the config.");
PrettyLogH1("--- Running warmup iteration for quantization");
// Run the inference program
std::vector<PaddleTensor> output_slots;
predictor_.Run(*warmup_data, &output_slots, qconfig_->warmup_batch_size());
return true;
}
float AnalysisPredictor::MkldnnQuantizer::SafeEntropy(
std::vector<int> reference_distr_P, int P_sum,
std::vector<int> candidate_distr_Q, int Q_sum) const {
PADDLE_ENFORCE_EQ(reference_distr_P.size(), candidate_distr_Q.size());
float tmp_sum1 = 0;
float tmp_sum2 = 0;
for (size_t idx = 0; idx < reference_distr_P.size(); idx++) {
int p_idx = reference_distr_P[idx];
int q_idx = candidate_distr_Q[idx];
if (p_idx == 0) {
tmp_sum1 += 0;
tmp_sum2 += 0;
} else {
PADDLE_ENFORCE(q_idx != 0, "MkldnnQuantizer: Fatal error! idx = " +
std::to_string(idx) +
" qindex = 0! p_idx = " +
std::to_string(p_idx));
}
tmp_sum1 += p_idx * (log(Q_sum * p_idx));
tmp_sum2 += p_idx * (log(P_sum * q_idx));
}
return (tmp_sum1 - tmp_sum2) / P_sum;
}
} // namespace paddle
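For reference, SafeEntropy() above evaluates the discrete KL divergence between the normalized reference and candidate bin distributions, in the rearranged form the code uses (a restatement of the standard definition, not part of this commit):

  \mathrm{KL}(\hat{P} \,\|\, \hat{Q})
      = \sum_i \frac{p_i}{P_{\mathrm{sum}}}
        \log \frac{p_i / P_{\mathrm{sum}}}{q_i / Q_{\mathrm{sum}}}
      = \frac{1}{P_{\mathrm{sum}}}
        \sum_i p_i \left[ \log(Q_{\mathrm{sum}}\, p_i) - \log(P_{\mathrm{sum}}\, q_i) \right]

which is exactly (tmp_sum1 - tmp_sum2) / P_sum: bins with p_i = 0 contribute nothing, and p_i > 0 with q_i = 0 is rejected via PADDLE_ENFORCE, keeping the logarithms finite.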
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/string/printf.h"
#ifdef PADDLE_WITH_TESTING
#include <gtest/gtest.h>
#include <gtest/gtest_prod.h>
#endif
namespace paddle {
/*
* Map variable name to tensor of scaling factors scaling it to MAX=1.0.
* bool denotes whether quantization of the variable should be done to unsigned
* type.
*/
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, framework::LoDTensor>>;
class AnalysisPredictor::MkldnnQuantizer {
public:
explicit MkldnnQuantizer(
AnalysisPredictor& predictor, // NOLINT
const std::shared_ptr<MkldnnQuantizerConfig>& qconfig)
: predictor_(predictor), qconfig_(qconfig) {}
// Execute full quantization procedure.
bool Quantize();
#if PADDLE_WITH_TESTING
friend class MkldnnQuantizerTest;
#endif
private:
// Run single warmup iteration
bool RunWarmup() const;
// Gather data from variables and calculate scales for them.
bool CalculateScales();
// Calculate a scale for tensor based on ScaleAlgo rules.
void CalculateSingleScale(const std::string& op_name,
const std::string& conn_name,
const std::string& var_name,
const framework::LoDTensor& var_tensor,
bool is_unsigned);
void PrepareArgument() const;
bool RunQuantizePasses() const;
std::vector<int> ExpandQuantizedBins(std::vector<int> quantized_bins,
std::vector<int> reference_bins) const;
// Using the KL-divergence method get the most precise scaling factor.
std::pair<bool, framework::LoDTensor> GetKLScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const;
std::pair<bool, framework::LoDTensor> GetMaxChScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const;
std::pair<bool, framework::LoDTensor> GetMaxScalingFactor(
const framework::LoDTensor& var_tensor, bool is_unsigned) const;
// Returns histogram and bin width
std::pair<std::vector<int>, float> Histogram(
const framework::LoDTensor& var_tensor, float min_val, float max_val,
size_t num_bins = 2048) const;
// Calculate the entropy.
float SafeEntropy(std::vector<int> reference_distr_P, int P_sum,
std::vector<int> candidate_distr_Q, int Q_sum) const;
private:
AnalysisPredictor& predictor_;
const std::shared_ptr<MkldnnQuantizerConfig> qconfig_;
// A map: variable name -> scale
VarQuantScale scales_;
};
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/paddle_mkldnn_quantizer_config.h"
namespace paddle {
MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
// The default configuration of scale computing algorithms
rules_["conv2d"]["Input"] = ScaleAlgo::KL;
rules_["conv2d"]["Filter"] = ScaleAlgo::MAX_CH;
rules_["conv2d"]["Bias"] = ScaleAlgo::NONE; // do not compute scale
rules_["conv2d"]["ResidualData"] = ScaleAlgo::KL;
rules_["conv2d"]["Output"] = ScaleAlgo::KL; // do not compute scale
rules_["pool2d"]["X"] = ScaleAlgo::KL;
rules_["pool2d"]["Out"] = ScaleAlgo::KL; // do not compute scale
}
ScaleAlgo MkldnnQuantizerConfig::scale_algo(
const std::string& op_type_name, const std::string& conn_name) const {
if (rules_.find(op_type_name) != rules_.end()) {
auto op_rule = rules_.at(op_type_name);
if (op_rule.find(conn_name) != op_rule.end()) return op_rule.at(conn_name);
}
return default_scale_algo_;
}
} // namespace paddle
@@ -27,10 +27,14 @@
 // the abstract path of this header file will be changed.
 #include "paddle_api.h"           // NOLINT
 #include "paddle_pass_builder.h"  // NOLINT
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle_mkldnn_quantizer_config.h"  // NOLINT
+#endif

 namespace paddle {

 class AnalysisPredictor;
+struct MkldnnQuantizerConfig;

 // NOTE WIP, not stable yet.
 struct AnalysisConfig {
@@ -186,6 +190,16 @@ struct AnalysisConfig {
     mkldnn_enabled_op_types_ = op_list;
   }

+  /** Turn on quantization.
+   */
+  void EnableMkldnnQuantizer();
+
+  /** A boolean state telling whether the quantization is enabled.
+   */
+  bool mkldnn_quantizer_enabled() const { return use_mkldnn_quantizer_; }
+
+  std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config() const;
+
   /** Specify the memory buffer of program and parameter
    * @param prog_buffer the memory buffer of program.
    * @param prog_buffer_size the size of the data.
@@ -271,10 +285,14 @@ struct AnalysisConfig {
   std::string serialized_info_cache_;

   mutable std::unique_ptr<PassStrategy> pass_builder_;

   bool use_anakin_{false};
   int anakin_max_batchsize_;
   std::map<std::string, std::vector<int>> anakin_max_input_shape_;
   std::map<std::string, std::string> engine_opt_info_;
+
+  bool use_mkldnn_quantizer_{false};
+  std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config_;
 };

 }  // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle_api.h" // NOLINT
namespace paddle {
// Algorithms for finding scale of quantized Tensors.
enum class ScaleAlgo {
NONE, // Do not compute scale
MAX, // Find scale based on the maximum absolute value
MAX_CH, // Find scale based on the maximum absolute value per channel
KL, // Find scale based on KL Divergence
};
struct MkldnnQuantizerConfig {
MkldnnQuantizerConfig();
/** Specify a quantization algorithm for a connection (input/output) of the
* operator type.
* @param op_type_name the operator's name.
* @param conn_name name of the connection (input/output) of the operator.
* @param algo the algorithm for computing scale.
*/
void SetScaleAlgo(std::string op_type_name, std::string conn_name,
ScaleAlgo algo) {
rules_[op_type_name][conn_name] = algo;
}
/** Get the quantization algorithm for a connection (input/output) of the
* operator type.
* @param op_type_name the operator's name.
* @param conn_name name of the connection (input/output) of the operator.
* @return the algorithm for computing scale.
*/
ScaleAlgo scale_algo(const std::string& op_type_name,
const std::string& conn_name) const;
/** Set the batch of data to be used for warm-up iteration.
* @param data batch of data.
*/
void SetWarmupData(std::shared_ptr<std::vector<PaddleTensor>> data) {
warmup_data_ = data;
}
/** Get the batch of data used for warm-up iteration.
* @return batch of data.
*/
std::shared_ptr<std::vector<PaddleTensor>> warmup_data() const {
return warmup_data_;
}
void SetWarmupBatchSize(int batch_size) { warmup_bs_ = batch_size; }
int warmup_batch_size() const { return warmup_bs_; }
void SetEnabledOpTypes(std::unordered_set<std::string> op_list) {
enabled_op_types_ = op_list;
}
const std::unordered_set<std::string>& enabled_op_types() const {
return enabled_op_types_;
}
void SetExcludedOpIds(std::unordered_set<int> op_ids_list) {
excluded_op_ids_ = op_ids_list;
}
const std::unordered_set<int>& excluded_op_ids() const {
return excluded_op_ids_;
}
void SetDefaultScaleAlgo(ScaleAlgo algo) { default_scale_algo_ = algo; }
ScaleAlgo default_scale_algo() const { return default_scale_algo_; }
protected:
std::map<std::string, std::map<std::string, ScaleAlgo>> rules_;
std::unordered_set<std::string> enabled_op_types_;
std::unordered_set<int> excluded_op_ids_;
std::shared_ptr<std::vector<PaddleTensor>> warmup_data_;
int warmup_bs_{1};
ScaleAlgo default_scale_algo_{ScaleAlgo::MAX};
};
} // namespace paddle
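A sketch of customizing this config through the AnalysisConfig accessor introduced in this commit (the warm-up batch contents are placeholders):

  config.EnableMkldnnQuantizer();
  auto qcfg = config.mkldnn_quantizer_config();
  auto warmup = std::make_shared<std::vector<paddle::PaddleTensor>>();
  // ... fill `warmup` with one representative batch of inputs ...
  qcfg->SetWarmupData(warmup);  // consumed by MkldnnQuantizer::RunWarmup()
  qcfg->SetWarmupBatchSize(8);
  // Override a single rule; unlisted connections fall back to default_scale_algo_.
  qcfg->SetScaleAlgo("conv2d", "Input", paddle::ScaleAlgo::MAX);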
@@ -107,8 +107,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   use_gpu_ = true;
 }

-void GpuPassStrategy::EnableQuantizer() {
-  LOG(ERROR) << "GPU not support quantization yet";
+void GpuPassStrategy::EnableMkldnnQuantizer() {
+  LOG(ERROR) << "GPU does not support MKL-DNN quantization";
 }

 void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
...
@@ -30,6 +30,10 @@ class PaddlePassBuilder {
   explicit PaddlePassBuilder(const std::vector<std::string> &passes)
       : passes_(passes) {}

+  void SetPasses(std::initializer_list<std::string> passes) {
+    passes_ = passes;
+  }
+
   /** Append a pass to the end of the passes. */
   void AppendPass(const std::string &pass_type);
@@ -85,9 +89,9 @@ class PassStrategy : public PaddlePassBuilder {
    */
   virtual void EnableMKLDNN() {}

-  /** Enable quantize optimization
+  /** Enable MKLDNN quantize optimization
    */
-  virtual void EnableQuantizer() {}
+  virtual void EnableMkldnnQuantizer() {}

   bool use_gpu() const { return use_gpu_; }
@@ -117,6 +121,8 @@ class CpuPassStrategy : public PassStrategy {
     for (auto &pass : std::vector<std::string>(
              {"depthwise_conv_mkldnn_pass",    //
+              "conv_bn_fuse_pass",             // Execute BN passes again to
+              "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order
               "conv_bias_mkldnn_fuse_pass",    //
               "conv3d_bias_mkldnn_fuse_pass",  //
               "conv_relu_mkldnn_fuse_pass",    //
@@ -130,15 +136,19 @@ class CpuPassStrategy : public PassStrategy {
 #endif
   }

-  void EnableQuantizer() override {
-    if (!use_quantizer_) {
+  void EnableMkldnnQuantizer() override {
+#ifdef PADDLE_WITH_MKLDNN
+    if (!use_mkldnn_quantizer_) {
       passes_.push_back("cpu_quantize_placement_pass");
     }
-    use_quantizer_ = true;
+    use_mkldnn_quantizer_ = true;
+#else
+    use_mkldnn_quantizer_ = false;
+#endif
   }

  protected:
-  bool use_quantizer_{false};
+  bool use_mkldnn_quantizer_{false};
 };

 /** The GPU passes strategy, it is used in AnalysisPredictor with GPU mode.
@@ -153,7 +163,7 @@ class GpuPassStrategy : public PassStrategy {
   }

   void EnableMKLDNN() override;
-  void EnableQuantizer() override;
+  void EnableMkldnnQuantizer() override;

   virtual ~GpuPassStrategy() = default;
 };
...
@@ -4,6 +4,7 @@ cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator)
 cc_library(locked_allocator SRCS locked_allocator.cc DEPS allocator)
 cc_library(buffered_allocator SRCS buffered_allocator.cc DEPS allocator)
 cc_library(legacy_allocator SRCS legacy_allocator.cc DEPS allocator buddy_allocator profiler)
+cc_library(zero_size_allocator SRCS zero_size_allocator.cc DEPS allocator)
 cc_test(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS best_fit_allocator locked_allocator buffered_allocator cpu_allocator)

 if (WITH_GPU)
@@ -37,30 +38,20 @@ else ()
     set(AllocatorFacadeDeps)
 endif()

+list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator best_fit_allocator aligned_allocator auto_increment_allocator conditional_allocator retry_allocator buffered_allocator legacy_allocator zero_size_allocator)
+
 cc_library(aligned_allocator SRCS aligned_allocator.cc DEPS allocator)
 cc_library(auto_increment_allocator SRCS auto_increment_allocator.cc DEPS allocator)
-cc_library(zero_size_allocator SRCS zero_size_allocator.cc DEPS allocator)
 cc_library(conditional_allocator SRCS conditional_allocator.cc DEPS allocator)
-cc_library(allocator_strategy SRCS allocator_strategy.cc DEPS gflags)
-cc_library(allocator_facade SRCS allocator_facade.cc DEPS
-  ${AllocatorFacadeDeps}
-  cpu_allocator
-  locked_allocator
-  best_fit_allocator
-  aligned_allocator
-  auto_increment_allocator
-  zero_size_allocator
-  conditional_allocator
-  retry_allocator
-  buffered_allocator
-  allocator_strategy
-  legacy_allocator
-  )
+cc_library(allocator_strategy SRCS allocator_strategy.cc DEPS gflags ${AllocatorFacadeDeps})
+cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy)

 nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade)
 cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator best_fit_allocator locked_allocator cpu_allocator)
+cc_test(naive_best_fit_allocator_facade_test SRCS naive_best_fit_allocator_facade_test.cc DEPS allocator_facade)
 cc_test(allocator_facade_abs_flags_test SRCS allocator_facade_abs_flags_test.cc DEPS allocator_facade)
 cc_test(allocator_facade_frac_flags_test SRCS allocator_facade_frac_flags_test.cc DEPS allocator_facade)
@@ -14,6 +14,7 @@
 #pragma once

 #include <memory>
+#include <utility>
 #include "paddle/fluid/memory/allocation/allocator.h"

 namespace paddle {
@@ -93,6 +94,8 @@ class AlignedAllocator : public ThinAlignedAllocator {
         underlying_allocator_->Allocate(size + kAlignment, attr);
     return new AlignedAllocation<kAlignment>(std::move(raw_allocation), size);
   }
+
+  void FreeImpl(Allocation* allocation) override { delete allocation; }
 };

 }  // namespace allocation
...
@@ -27,16 +27,24 @@ bool Allocator::IsAllocThreadSafe() const { return false; }

 AllocationPtr Allocator::Allocate(size_t size, Allocator::Attr attr) {
   auto ptr = AllocateImpl(size, attr);
-  ptr->set_allocator(this);
+  ptr->RegisterDecoratedAllocator(this);
   return AllocationPtr(ptr);
 }

-void Allocator::Free(Allocation* allocation) { delete allocation; }
+void Allocator::FreeImpl(Allocation* allocation) {
+  Allocator* allocator = allocation->TopDecoratedAllocator();
+  allocator->Free(allocation);
+}
+
+void Allocator::Free(Allocation* allocation) {
+  allocation->PopDecoratedAllocator();
+  FreeImpl(allocation);
+}

 const char* BadAlloc::what() const noexcept { return msg_.c_str(); }

 void AllocationDeleter::operator()(Allocation* allocation) const {
-  auto* allocator = allocation->allocator();
+  Allocator* allocator = allocation->TopDecoratedAllocator();
   allocator->Free(allocation);
 }
...
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
...@@ -44,13 +46,56 @@ class Allocator; ...@@ -44,13 +46,56 @@ class Allocator;
// NOTE: this is the base class of Allocation. Each allocator can use its own // NOTE: this is the base class of Allocation. Each allocator can use its own
// allocation object. // allocation object.
// NOTE: the `Allocation::ptr()` could be nullptr, if the allocation size is 0 // NOTE: the `Allocation::ptr()` could be nullptr, if the allocation size is 0
/**
* Allocation is returned by Allocator::Allocate() method.
*
* An allocator may be decorated by another allocator. For example, we can
* decorate
* a RetryAllocator to any allocator to perform allocation retry when first
* allocation request fails.
*
* Explanations of Allocator design is as follows:
*
* Suppose we have an allocator which is decorated by several allocators:
*
* A(1) <- A(2) <- A(3) <- ... <- A(n)
*
* , and the public allocator is A(1).
*
* The allocation process would be:
*
* A(n).Allocate() -> ... -> A(2).Allocate() -> A(1).Allocate()
*
* , and the free process would be:
*
* A(1).Free() -> A(2).Free() -> ... -> A(n).Free()
*
* Therefore, we should record the allocator chain when allocating, so
* that we can free the allocation in the reverse order of allocator chain.
* The field `decorated_allocators_` is used to record this chain.
*
* Another example is that we want to add additional fields in Allocation,
* e.g., something what is done in AlignedAllocator, etc.
* In this case, we should declare a derived class of Allocation, which
* contains an underlying Allocation allocated by the underlying allocator.
* Therefore, `decorated_allocators_` of the new Allocation object would
* be a new chain, differing from the underlying Allocation object.
*/
 class Allocation {
  public:
-  Allocation(void* ptr, size_t size, platform::Place place)
-      : allocator_(nullptr), ptr_(ptr), size_(size), place_(place) {}
+  Allocation(void* ptr, size_t size, platform::Place place)
+      : ptr_(ptr), size_(size), place_(place) {
+    // NOTE(zjl): Since decorated_allocators_ is usually a small vector,
+    // we reserve a small buffer for it to prevent frequent heap allocations.
+    // Not quite sure whether we need something like a gtl vector.
+    decorated_allocators_.reserve(8);
+  }

   Allocation(const Allocation& o) = delete;
   Allocation& operator=(const Allocation& o) = delete;
+  Allocation(Allocation&& o) = delete;
+  Allocation& operator=(Allocation&& o) = delete;

   // Returns the holding pointer.
   // NOTE: For performance consideration, it is better not to make this method
@@ -72,17 +117,31 @@ class Allocation {
   const platform::Place& place() const { return place_; }

-  Allocator* allocator() { return allocator_; }
-
-  void set_allocator(Allocator* allocator) { allocator_ = allocator; }
-
   virtual ~Allocation();

+ private:
+  const std::vector<Allocator*>& DecoratedAllocators() const {
+    return decorated_allocators_;
+  }
+
+  inline void RegisterDecoratedAllocator(Allocator* allocator) {
+    decorated_allocators_.push_back(allocator);
+  }
+
+  inline void PopDecoratedAllocator() { decorated_allocators_.pop_back(); }
+
+  inline Allocator* TopDecoratedAllocator() {
+    return decorated_allocators_.back();
+  }
+
  private:
-  Allocator* allocator_;
   void* ptr_;
   size_t size_;
   platform::Place place_;
+  std::vector<Allocator*> decorated_allocators_;
+
+  friend class Allocator;
+  friend class AllocationDeleter;
 };

 using AllocationPtr = std::unique_ptr<Allocation, AllocationDeleter>;
@@ -132,9 +191,12 @@ class Allocator {
   // True if the `Allocate` is thread safe.
   virtual bool IsAllocThreadSafe() const;

+  // This function should not be called outside
+  void Free(Allocation* allocation);
+
  protected:
-  virtual void Free(Allocation* allocation);
   virtual Allocation* AllocateImpl(size_t size, Allocator::Attr attr) = 0;
+  virtual void FreeImpl(Allocation* allocation);

  private:
   friend class AllocationDeleter;
......
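
To make the chain bookkeeping described in the header comment concrete, here is a minimal, self-contained C++ sketch of the same decorator pattern. All names here (MallocAllocator, LoggingAllocator) are hypothetical simplifications for illustration, not the actual Paddle classes:

#include <cstdio>
#include <cstdlib>
#include <vector>

class Allocator;

struct Allocation {
  void* ptr = nullptr;
  std::size_t size = 0;
  std::vector<Allocator*> decorated_allocators;  // records A(1) ... A(n)
};

class Allocator {
 public:
  virtual ~Allocator() = default;

  Allocation* Allocate(std::size_t size) {
    Allocation* a = AllocateImpl(size);       // may recurse into inner layers
    a->decorated_allocators.push_back(this);  // this layer registers last
    return a;
  }

  void Free(Allocation* a) {
    a->decorated_allocators.pop_back();  // unwind this layer first
    FreeImpl(a);                         // a decorator delegates inward
  }

 protected:
  virtual Allocation* AllocateImpl(std::size_t size) = 0;
  virtual void FreeImpl(Allocation* a) = 0;
};

class MallocAllocator : public Allocator {
 protected:
  Allocation* AllocateImpl(std::size_t size) override {
    return new Allocation{std::malloc(size), size, {}};
  }
  void FreeImpl(Allocation* a) override {
    std::free(a->ptr);
    delete a;
  }
};

class LoggingAllocator : public Allocator {  // a decorating allocator
 public:
  explicit LoggingAllocator(Allocator* underlying) : underlying_(underlying) {}

 protected:
  Allocation* AllocateImpl(std::size_t size) override {
    std::printf("allocating %zu bytes\n", size);
    return underlying_->Allocate(size);  // inner layer registers itself first
  }
  void FreeImpl(Allocation* a) override {
    std::printf("freeing %zu bytes\n", a->size);
    underlying_->Free(a);  // pops the next layer down the chain
  }

 private:
  Allocator* underlying_;
};

int main() {
  MallocAllocator base;
  LoggingAllocator logged(&base);       // chain: base <- logged
  Allocation* a = logged.Allocate(64);  // records {base, logged}
  logged.Free(a);                       // unwinds logged first, then base
}

Freeing through the recorded chain guarantees each decorating layer observes Free in the reverse order of Allocate, which is exactly what `decorated_allocators_` buys the real implementation.
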
@@ -17,6 +17,7 @@
 #include <map>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/memory/allocation/aligned_allocator.h"
 #include "paddle/fluid/memory/allocation/allocator_facade.h"
@@ -30,6 +31,7 @@
 #include "paddle/fluid/memory/allocation/retry_allocator.h"
 #include "paddle/fluid/memory/allocation/zero_size_allocator.h"
 #include "paddle/fluid/platform/cpu_info.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/memory/allocation/cuda_allocator.h"
@@ -47,6 +49,17 @@ namespace paddle {
 namespace memory {
 namespace allocation {

+static inline std::shared_ptr<Allocator> WrapRetryAllocator(
+    std::shared_ptr<Allocator> allocator, int64_t retry_time) {
+  if (retry_time > 0) {
+    auto* retry_allocator =
+        new RetryAllocator(std::move(allocator), retry_time);
+    allocator.reset(retry_allocator);
+  }
+
+  return allocator;
+}
+
 // TODO(yy): Dirty code here. This class should be configurable in runtime.
 class CPUManagedAllocator : public Allocator {
  public:
@@ -110,14 +123,10 @@ class ChunkedAllocator : public Allocator {
   std::shared_ptr<Allocator> CreateAllocatorWithChunk() {
     chunks_.emplace_back(raw_allocator_->Allocate(max_chunk_size_));
     auto* allocation = chunks_.back().get();
-    std::unique_ptr<Allocator> allocator(new LockedAllocator(
-        std::unique_ptr<Allocator>(new BestFitAllocator(allocation))));
+    std::shared_ptr<Allocator> allocator(new LockedAllocator(
+        std::shared_ptr<Allocator>(new BestFitAllocator(allocation))));

-    if (retry_time_ > 0) {
-      auto* retry_allocator =
-          new RetryAllocator(std::move(allocator), retry_time_);
-      allocator.reset(retry_allocator);
-    }
+    allocator = WrapRetryAllocator(allocator, retry_time_);

     return std::make_shared<AlignedAllocator<64u>>(std::move(allocator));
   }
@@ -188,13 +197,23 @@ class AllocatorFacadePrivate {
   ~AllocatorFacadePrivate() = default;

   AllocatorFacadePrivate() {
-    if (GetAllocatorStrategy() == AllocatorStrategy::kLegacy) {
-      InitLegacyAllocator();
-    } else {
-      InitCPUAllocator();
-      InitCUDAAllocator();
-      InitCUDAPinnedAllocator();
-      WrapZeroSizeAllocator();
+    auto strategy = GetAllocatorStrategy();
+    switch (strategy) {
+      case AllocatorStrategy::kLegacy: {
+        InitLegacyAllocator();
+        break;
+      }
+      case AllocatorStrategy::kNaiveBestFit: {
+        InitCPUAllocator();
+        InitCUDAAllocator();
+        InitCUDAPinnedAllocator();
+        WrapZeroSizeAllocator();
+        break;
+      }
+      default: {
+        PADDLE_THROW("Unsupported allocator strategy: %d",
+                     static_cast<int>(strategy));
+      }
     }
   }
@@ -252,8 +271,7 @@ AllocatorFacade& AllocatorFacade::Instance() {

 std::shared_ptr<Allocation> AllocatorFacade::AllocShared(
     const platform::Place& place, size_t size, Allocator::Attr attr) {
-  return std::shared_ptr<Allocation>(Alloc(place, size, attr).release(),
-                                     AllocationDeleter());
+  return std::shared_ptr<Allocation>(Alloc(place, size, attr));
 }

 AllocationPtr AllocatorFacade::Alloc(const platform::Place& place, size_t size,
......
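
The simplified AllocShared() above works because std::shared_ptr can adopt a std::unique_ptr together with its custom deleter, which makes the explicit AllocationDeleter() argument redundant. A minimal standalone demonstration of that language rule (toy types, not Paddle's):

#include <cassert>
#include <memory>

struct CountingDeleter {
  int* deletions;
  void operator()(int* p) const {
    delete p;
    ++*deletions;
  }
};

int main() {
  int deletions = 0;
  std::unique_ptr<int, CountingDeleter> u(new int(42),
                                          CountingDeleter{&deletions});
  std::shared_ptr<int> s(std::move(u));  // the deleter travels along
  s.reset();                             // CountingDeleter runs here
  assert(deletions == 1);
}
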
@@ -14,20 +14,27 @@
 #include "paddle/fluid/memory/allocation/allocator_strategy.h"
 #include "gflags/gflags.h"
+#include "paddle/fluid/platform/enforce.h"

 DEFINE_string(
     allocator_strategy, "legacy",
     "The allocation strategy. Legacy means the original allocator of Fluid."
-    "New means the experimental allocators of Fluid. in [legacy, new]");
+    "naive_best_fit means the experimental best fit allocator. "
+    "Enum in [legacy, naive_best_fit].");

 namespace paddle {
 namespace memory {
 namespace allocation {

 static AllocatorStrategy GetStrategyFromFlag() {
-  return FLAGS_allocator_strategy == "legacy"
-             ? AllocatorStrategy::kLegacy
-             : AllocatorStrategy::kNaiveBestFit;
+  if (FLAGS_allocator_strategy == "legacy") {
+    return AllocatorStrategy::kLegacy;
+  } else if (FLAGS_allocator_strategy == "naive_best_fit") {
+    return AllocatorStrategy::kNaiveBestFit;
+  } else {
+    PADDLE_THROW("Unsupported allocator strategy: %s",
+                 FLAGS_allocator_strategy);
+  }
 }

 AllocatorStrategy GetAllocatorStrategy() {
......
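
For context, a hedged sketch of how a caller drives this flag (the flag name comes from the code above; the caching behavior of GetAllocatorStrategy() is not shown in this hunk, so the flag should be set before the strategy is first queried):

#include "gflags/gflags.h"

DECLARE_string(allocator_strategy);

// Hypothetical helper, not part of Paddle.
void UseNaiveBestFit() {
  // Set the flag before GetAllocatorStrategy() is first called, in case the
  // strategy is cached after the first read.
  FLAGS_allocator_strategy = "naive_best_fit";  // or "legacy" (the default)
}
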
@@ -109,7 +109,7 @@ size_t BestFitAllocator::NumFreeChunks() const {
   }
   return num;
 }
-void BestFitAllocator::Free(Allocation* allocation) {
+void BestFitAllocator::FreeImpl(Allocation* allocation) {
   auto* bf_allocation = dynamic_cast<BestFitAllocation*>(allocation);
   PADDLE_ENFORCE_NOT_NULL(bf_allocation,
                           "The input allocation is not BestFitAllocation.");
......
@@ -119,7 +119,7 @@ class BestFitAllocator : public Allocator {
   void InsertFreeNode(const ListIt& it);

  protected:
-  void Free(Allocation* allocation) override;
+  void FreeImpl(Allocation* allocation) override;
   Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;

  private:
......
@@ -22,11 +22,11 @@ namespace paddle {
 namespace memory {
 namespace allocation {

-BufferedAllocator::BufferedAllocator(std::unique_ptr<Allocator> &&allocator)
+BufferedAllocator::BufferedAllocator(std::shared_ptr<Allocator> allocator)
     : underlying_allocator_(std::move(allocator)) {
   PADDLE_ENFORCE_NOT_NULL(
       underlying_allocator_,
-      "Underlying allocator of BufferedAllocator must be unmanaged");
+      "Underlying allocator of BufferedAllocator must not be null");
   if (underlying_allocator_->IsAllocThreadSafe()) {
     mtx_.reset(new std::mutex());
   }
@@ -41,19 +41,19 @@ void BufferedAllocator::FreeCache(size_t size) {
   while (!allocations_.empty()) {  // free the largest
     auto it = --allocations_.end();
     cur += it->second->size();
-    delete it->second.release();
+    underlying_allocator_->Free(it->second.release());
     allocations_.erase(it);
     if (cur >= size) return;
   }
 }

-bool BufferedAllocator::IsAllocThreadSafe() const {
-  return this->underlying_allocator_->IsAllocThreadSafe();
-}
-void BufferedAllocator::Free(Allocation *allocation) {
+bool BufferedAllocator::IsAllocThreadSafe() const { return mtx_ != nullptr; }
+
+void BufferedAllocator::FreeImpl(Allocation *allocation) {
   platform::LockGuardPtr<std::mutex> guard(mtx_);
   allocations_.emplace(allocation->size(), AllocationPtr(allocation));
 }
+
 Allocation *BufferedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   {
     platform::LockGuardPtr<std::mutex> guard(mtx_);
@@ -61,17 +61,15 @@ Allocation *BufferedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
     if (it != allocations_.end() && it->first < size * 2) {
       AllocationPtr result(std::move(it->second));
       allocations_.erase(it);
-      return new AllocationWithUnderlying(std::move(result));
+      return result.release();
     }
   }

   try {
-    return new AllocationWithUnderlying(
-        underlying_allocator_->Allocate(size, attr));
+    return underlying_allocator_->Allocate(size, attr).release();
   } catch (BadAlloc &) {
     FreeCache(size);
-    return new AllocationWithUnderlying(
-        underlying_allocator_->Allocate(size, attr));
+    return underlying_allocator_->Allocate(size, attr).release();
   }
 }
......
@@ -31,7 +31,7 @@ namespace allocation {
 // underlying_allocator_
 class BufferedAllocator : public Allocator {
  public:
-  explicit BufferedAllocator(std::unique_ptr<Allocator> &&allocator);
+  explicit BufferedAllocator(std::shared_ptr<Allocator> allocator);

   ~BufferedAllocator();
@@ -44,11 +44,11 @@ class BufferedAllocator : public Allocator {
   void FreeCache(size_t size);

  protected:
-  void Free(Allocation *allocation) override;
+  void FreeImpl(Allocation *allocation) override;
   Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;

  private:
-  std::unique_ptr<Allocator> underlying_allocator_;
+  std::shared_ptr<Allocator> underlying_allocator_;
   std::multimap<size_t, AllocationPtr> allocations_;
   std::unique_ptr<std::mutex> mtx_;
 };
......
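
The `it->first < size * 2` guard in AllocateImpl above implements a simple anti-fragmentation rule: a cached block is reused only if it is at least the requested size but less than twice it. A standalone sketch of the same predicate (assuming the iterator comes from a lower_bound-style lookup, which this hunk does not show):

#include <cstddef>
#include <map>

// True iff the cache would satisfy `size` under BufferedAllocator's rule.
bool WouldReuse(const std::multimap<std::size_t, int>& cache,
                std::size_t size) {
  auto it = cache.lower_bound(size);  // first cached block >= size
  return it != cache.end() && it->first < size * 2;
}
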
@@ -14,6 +14,7 @@

 #include "paddle/fluid/memory/allocation/buffered_allocator.h"
 #include <gtest/gtest.h>
+#include <utility>
 #include "paddle/fluid/memory/allocation/best_fit_allocator.h"
 #include "paddle/fluid/memory/allocation/cpu_allocator.h"
 #include "paddle/fluid/memory/allocation/locked_allocator.h"
@@ -64,7 +65,7 @@ class StubAllocator : public Allocator {
   size_t GetFreeCount() const { return destruct_count_; }

  protected:
-  void Free(Allocation *allocation) override {
+  void FreeImpl(Allocation *allocation) override {
     auto *alloc = dynamic_cast<StubAllocation *>(allocation);
     PADDLE_ENFORCE_NOT_NULL(alloc);
     if (alloc->ptr()) delete[] static_cast<uint8_t *>(alloc->ptr());
......
@@ -20,25 +20,27 @@ namespace paddle {
 namespace memory {
 namespace allocation {

-CPUAllocation::CPUAllocation(void *ptr, size_t size)
-    : Allocation(ptr, size, platform::CPUPlace()) {}
-
 bool CPUAllocator::IsAllocThreadSafe() const { return true; }

-void CPUAllocator::Free(Allocation *allocation) {
-  PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUAllocation *>(allocation));
-  free(allocation->ptr());
+void CPUAllocator::FreeImpl(Allocation *allocation) {
+  void *p = allocation->ptr();
+#ifdef _WIN32
+  _aligned_free(p);
+#else
+  free(p);
+#endif
   delete allocation;
 }

 Allocation *CPUAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
-  void *ptr;
-  auto status = posix_memalign(&ptr, kAlignment, size);
-  if (UNLIKELY(status) != 0) {
-    throw BadAlloc(string::Sprintf("Cannot allocate cpu memory %d. Errno is %d",
-                                   size, status));
-  }
-  return new CPUAllocation(ptr, size);
+  void *p;
+#ifdef _WIN32
+  p = _aligned_malloc(size, kAlignment);
+#else
+  PADDLE_ENFORCE_EQ(posix_memalign(&p, kAlignment, size), 0, "Alloc %ld error!",
+                    size);
+#endif
+  return new Allocation(p, size, platform::CPUPlace());
 }
 }  // namespace allocation
 }  // namespace memory
......
@@ -31,19 +31,13 @@ namespace allocation {
 //
 // NOTE(yy): It is no need to use `BestFitAllocator` in CPU. We can import
 // an open-sourced allocator into Paddle.
-class CPUAllocator;
-class CPUAllocation : public Allocation {
- public:
-  CPUAllocation(void* ptr, size_t size);
-};
-
 class CPUAllocator : public Allocator {
  public:
-  constexpr static size_t kAlignment = 64u;
+  constexpr static size_t kAlignment = 4096UL;
   bool IsAllocThreadSafe() const override;

  protected:
-  void Free(Allocation* allocation) override;
+  void FreeImpl(Allocation* allocation) override;
   Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
 };
 }  // namespace allocation
......
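
The change above exists because Windows has no posix_memalign, and memory from _aligned_malloc must be released with _aligned_free rather than free(); that is why FreeImpl also branches on _WIN32. A portable standalone sketch of the same pattern:

#include <cstdlib>
#ifdef _WIN32
#include <malloc.h>  // _aligned_malloc / _aligned_free
#endif

// Allocate `size` bytes aligned to `alignment` (a power of two).
void* AlignedAlloc(std::size_t alignment, std::size_t size) {
#ifdef _WIN32
  return _aligned_malloc(size, alignment);  // note the (size, alignment) order
#else
  void* p = nullptr;
  return posix_memalign(&p, alignment, size) == 0 ? p : nullptr;
#endif
}

// Memory from _aligned_malloc must NOT be passed to free().
void AlignedFree(void* p) {
#ifdef _WIN32
  _aligned_free(p);
#else
  std::free(p);
#endif
}
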
@@ -23,15 +23,14 @@ namespace paddle {
 namespace memory {
 namespace allocation {
 bool CUDAAllocator::IsAllocThreadSafe() const { return true; }
-void CUDAAllocator::Free(Allocation* allocation) {
+void CUDAAllocator::FreeImpl(Allocation* allocation) {
   platform::CUDADeviceGuard guard(place_.device);
-  auto* cuda_allocation = dynamic_cast<CUDAAllocation*>(allocation);
-  PADDLE_ENFORCE_NOT_NULL(cuda_allocation);
-  PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(cuda_allocation->place()),
+  PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(allocation->place()),
                     place_);
   PADDLE_ENFORCE(cudaFree(allocation->ptr()));
   delete allocation;
 }
+
 Allocation* CUDAAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   platform::CUDADeviceGuard guard(place_.device);
   void* ptr;
@@ -41,8 +40,9 @@ Allocation* CUDAAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
         "Cannot allocate %d on GPU %d, cuda status %d, %s", size, place_.device,
         status, cudaGetErrorString(status)));
   }
-  return new CUDAAllocation(ptr, size, platform::Place(place_));
+  return new Allocation(ptr, size, platform::Place(place_));
 }
 }  // namespace allocation
 }  // namespace memory
 }  // namespace paddle
@@ -20,13 +20,6 @@ namespace paddle {
 namespace memory {
 namespace allocation {

-// CUDA System allocator and allocation.
-// Just a flag type.
-class CUDAAllocation : public Allocation {
- public:
-  using Allocation::Allocation;
-};
-
 class CUDAAllocator : public Allocator {
  public:
   explicit CUDAAllocator(const platform::CUDAPlace& place) : place_(place) {}
@@ -35,7 +28,7 @@ class CUDAAllocator : public Allocator {
   bool IsAllocThreadSafe() const override;

  protected:
-  void Free(Allocation* allocation) override;
+  void FreeImpl(Allocation* allocation) override;
   Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;

  private:
......
@@ -134,26 +134,22 @@ size_t Used<platform::CPUPlace>(const platform::CPUPlace &place) {
 }

 #ifdef PADDLE_WITH_CUDA
-BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
-  static std::once_flag init_flag;
-  static detail::BuddyAllocator **a_arr = nullptr;
-  static std::vector<int> devices;
-
-  std::call_once(init_flag, [gpu_id]() {
-    devices = platform::GetSelectedDevices();
-    int gpu_num = devices.size();
-
-    allocation::GPUMemMonitor.Initialize(devices.size());
-
-    a_arr = new BuddyAllocator *[gpu_num];
-    for (size_t i = 0; i < devices.size(); ++i) {
-      int dev_id = devices[i];
-      a_arr[i] = nullptr;
-      platform::SetDeviceId(dev_id);
-      a_arr[i] = new BuddyAllocator(std::unique_ptr<detail::SystemAllocator>(
-                                        new detail::GPUAllocator(dev_id)),
-                                    platform::GpuMinChunkSize(),
-                                    platform::GpuMaxChunkSize());
+class GPUBuddyAllocatorList {
+ public:
+  GPUBuddyAllocatorList()
+      : allocators_(platform::GetCUDADeviceCount()),
+        flags_(platform::GetCUDADeviceCount()) {
+    allocation::GPUMemMonitor.Initialize(allocators_.size());
+  }
+
+  BuddyAllocator *Get(size_t dev_id) {
+    PADDLE_ENFORCE(dev_id < flags_.size(), "Invalid device id %s", dev_id);
+    std::call_once(flags_[dev_id], [this, dev_id] {
+      platform::SetDeviceId(dev_id);
+      allocators_[dev_id] = new BuddyAllocator(
+          std::unique_ptr<detail::SystemAllocator>(
+              new detail::GPUAllocator(dev_id)),
+          platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());

       VLOG(10) << "\n\nNOTE:\n"
                << "You can set GFlags environment variable "
@@ -167,13 +163,19 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
                << FLAGS_initial_gpu_memory_in_mb
                << ". Current 'FLAGS_reallocate_gpu_memory_in_mb' value is "
                << FLAGS_reallocate_gpu_memory_in_mb << "\n\n";
-    }
-  });
+    });

+    return allocators_[dev_id];
+  }
+
+ private:
+  std::vector<BuddyAllocator *> allocators_;
+  std::vector<std::once_flag> flags_;
+};
+
+BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
+  static GPUBuddyAllocatorList allocators;
   platform::SetDeviceId(gpu_id);
-  auto pos = std::distance(devices.begin(),
-                           std::find(devices.begin(), devices.end(), gpu_id));
-  return a_arr[pos];
+  return allocators.Get(gpu_id);
 }
 #endif
@@ -192,7 +194,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
 #ifdef PADDLE_WITH_CUDA
   auto *buddy_allocator = GetGPUBuddyAllocator(place.device);
   auto *ptr = buddy_allocator->Alloc(size);
-  if (ptr == nullptr) {
+  if (ptr == nullptr && size > 0) {
     int cur_dev = platform::GetCurrentDeviceId();
     platform::SetDeviceId(place.device);
     size_t avail, total;
@@ -347,7 +349,7 @@ Allocation *LegacyAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   return tmp_alloc;
 }

-void LegacyAllocator::Free(Allocation *allocation) {
+void LegacyAllocator::FreeImpl(Allocation *allocation) {
   boost::apply_visitor(
       legacy::FreeVisitor(allocation->ptr(), allocation->size()),
       allocation->place());
......
@@ -73,7 +73,7 @@ class LegacyAllocator : public Allocator {

  protected:
   Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;
-  void Free(Allocation *allocation) override;
+  void FreeImpl(Allocation *allocation) override;

  private:
   platform::Place place_;
......
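
The GPUBuddyAllocatorList rewrite above replaces one global std::call_once over all devices with one std::once_flag per device, so each allocator is built lazily and thread-safely on first use. The essential pattern, reduced to a hedged standalone sketch:

#include <cstddef>
#include <mutex>
#include <vector>

template <typename T>
class PerDeviceLazy {
 public:
  explicit PerDeviceLazy(std::size_t n) : slots_(n), flags_(n) {}

  T* Get(std::size_t i) {
    // Thread-safe, per-slot lazy construction; other slots stay untouched.
    std::call_once(flags_[i], [&] { slots_[i] = new T(); });
    return slots_[i];
  }

 private:
  std::vector<T*> slots_;              // leaked on purpose: process lifetime
  std::vector<std::once_flag> flags_;  // one once_flag per slot, built in place
};
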
@@ -14,8 +14,10 @@

 #include "paddle/fluid/memory/allocation/locked_allocator.h"
 #include <mutex>  // NOLINT
+#include <utility>
 #include "paddle/fluid/memory/allocation/allocation_with_underlying.h"
 #include "paddle/fluid/platform/lock_guard_ptr.h"
+
 namespace paddle {
 namespace memory {
 namespace allocation {
@@ -23,26 +25,24 @@ namespace allocation {

 bool LockedAllocator::IsAllocThreadSafe() const { return true; }

 LockedAllocator::LockedAllocator(
-    std::unique_ptr<Allocator> &&underlying_allocator)
+    std::shared_ptr<Allocator> underlying_allocator)
     : underlying_allocator_(std::move(underlying_allocator)) {
   PADDLE_ENFORCE_NOT_NULL(underlying_allocator_);
   if (!underlying_allocator_->IsAllocThreadSafe()) {
     mtx_.reset(new std::mutex());
   }
 }
-void LockedAllocator::Free(Allocation *allocation) {
-  {
-    platform::LockGuardPtr<std::mutex> guard(mtx_);
-    reinterpret_cast<AllocationWithUnderlying *>(allocation)
-        ->allocation_.reset();  // Destroy inner allocation
-  }
-  delete allocation;
+
+void LockedAllocator::FreeImpl(Allocation *allocation) {
+  platform::LockGuardPtr<std::mutex> guard(mtx_);
+  underlying_allocator_->Free(allocation);
 }
+
 Allocation *LockedAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   platform::LockGuardPtr<std::mutex> guard(mtx_);
-  return new AllocationWithUnderlying(
-      underlying_allocator_->Allocate(size, attr));
+  return underlying_allocator_->Allocate(size, attr).release();
 }
 }  // namespace allocation
 }  // namespace memory
 }  // namespace paddle
@@ -24,15 +24,15 @@ namespace allocation {
 // A allocator to make underlying allocator thread safe.
 class LockedAllocator : public Allocator {
  public:
-  explicit LockedAllocator(std::unique_ptr<Allocator> &&underlying_allocator);
+  explicit LockedAllocator(std::shared_ptr<Allocator> underlying_allocator);
   bool IsAllocThreadSafe() const override;

  protected:
-  void Free(Allocation *allocation) override;
+  void FreeImpl(Allocation *allocation) override;
   Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;

  private:
-  std::unique_ptr<Allocator> underlying_allocator_;
+  std::shared_ptr<Allocator> underlying_allocator_;
   std::unique_ptr<std::mutex> mtx_;
 };
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
DECLARE_int64(gpu_allocator_retry_time);
#endif
DECLARE_string(allocator_strategy);
namespace paddle {
namespace memory {
namespace allocation {
TEST(allocator, allocator) {
#ifdef PADDLE_WITH_CUDA
FLAGS_fraction_of_gpu_memory_to_use = 0.01;
FLAGS_gpu_allocator_retry_time = 500;
FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif
FLAGS_allocator_strategy = "naive_best_fit";
auto &instance = AllocatorFacade::Instance();
platform::Place place;
size_t size = 1024;
{
place = platform::CPUPlace();
size = 1024;
auto cpu_allocation = instance.Alloc(place, size);
ASSERT_NE(cpu_allocation, nullptr);
ASSERT_NE(cpu_allocation->ptr(), nullptr);
ASSERT_EQ(cpu_allocation->place(), place);
ASSERT_EQ(cpu_allocation->size(), size);
}
#ifdef PADDLE_WITH_CUDA
{
place = platform::CUDAPlace(0);
size = 1024;
auto gpu_allocation = instance.Alloc(place, size);
ASSERT_NE(gpu_allocation, nullptr);
ASSERT_NE(gpu_allocation->ptr(), nullptr);
ASSERT_EQ(gpu_allocation->place(), place);
ASSERT_GE(gpu_allocation->size(), size);
}
{
// Allocate 2GB gpu memory
place = platform::CUDAPlace(0);
size = 2 * static_cast<size_t>(1 << 30);
auto gpu_allocation = instance.Alloc(place, size);
ASSERT_NE(gpu_allocation, nullptr);
ASSERT_NE(gpu_allocation->ptr(), nullptr);
ASSERT_EQ(gpu_allocation->place(), place);
ASSERT_GE(gpu_allocation->size(), size);
}
{
place = platform::CUDAPinnedPlace();
size = (1 << 20);
auto cuda_pinned_allocation =
instance.Alloc(platform::CUDAPinnedPlace(), 1 << 20);
ASSERT_NE(cuda_pinned_allocation, nullptr);
ASSERT_NE(cuda_pinned_allocation->ptr(), nullptr);
ASSERT_EQ(cuda_pinned_allocation->place(), place);
ASSERT_GE(cuda_pinned_allocation->size(), size);
}
#endif
}
} // namespace allocation
} // namespace memory
} // namespace paddle
...@@ -20,20 +20,15 @@ namespace paddle { ...@@ -20,20 +20,15 @@ namespace paddle {
namespace memory { namespace memory {
namespace allocation { namespace allocation {
bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; } bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; }
void CPUPinnedAllocator::Free(Allocation *allocation) { void CPUPinnedAllocator::FreeImpl(Allocation *allocation) {
PADDLE_ENFORCE_NOT_NULL(dynamic_cast<CPUPinnedAllocation *>(allocation));
PADDLE_ENFORCE(cudaFreeHost(allocation->ptr())); PADDLE_ENFORCE(cudaFreeHost(allocation->ptr()));
delete allocation; delete allocation;
} }
Allocation *CPUPinnedAllocator::AllocateImpl(size_t size, Allocation *CPUPinnedAllocator::AllocateImpl(size_t size,
Allocator::Attr attr) { Allocator::Attr attr) {
// PADDLE_ENFORCE_EQ(
// attr, kCrossDevice,
// "CPUPinnedAllocator should be used for Cross-Device Communication");
void *ptr; void *ptr;
PADDLE_ENFORCE(cudaHostAlloc(&ptr, size, cudaHostAllocPortable)); PADDLE_ENFORCE(cudaHostAlloc(&ptr, size, cudaHostAllocPortable));
return new CPUPinnedAllocation(ptr, size); return new Allocation(ptr, size, platform::CUDAPinnedPlace());
} }
} // namespace allocation } // namespace allocation
} // namespace memory } // namespace memory
......
...@@ -20,18 +20,12 @@ namespace memory { ...@@ -20,18 +20,12 @@ namespace memory {
namespace allocation { namespace allocation {
// Allocator uses `cudaHostAlloc` // Allocator uses `cudaHostAlloc`
class CPUPinnedAllocation : public Allocation {
public:
CPUPinnedAllocation(void *ptr, size_t size)
: Allocation(ptr, size, platform::CUDAPinnedPlace()) {}
};
class CPUPinnedAllocator : public Allocator { class CPUPinnedAllocator : public Allocator {
public: public:
bool IsAllocThreadSafe() const override; bool IsAllocThreadSafe() const override;
protected: protected:
void Free(Allocation *allocation) override; void FreeImpl(Allocation *allocation) override;
Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override; Allocation *AllocateImpl(size_t size, Allocator::Attr attr) override;
}; };
......
@@ -18,25 +18,15 @@ namespace paddle {
 namespace memory {
 namespace allocation {

-bool RetryAllocator::IsAllocThreadSafe() const {
-  return underlying_allocator_->IsAllocThreadSafe();
-}
-
-void RetryAllocator::Free(Allocation* allocation) {
+void RetryAllocator::FreeImpl(Allocation* allocation) {
   // Delete underlying allocation first.
-  reinterpret_cast<AllocationWithUnderlying*>(allocation)->allocation_.reset();
-  {
-    // notify all waited allocators, they can try to allocate memory after free.
-    std::lock_guard<std::mutex> lock(mutex_);
-    cv_.notify_all();
-  }
-  delete allocation;
+  underlying_allocator_->Free(allocation);
+  cv_.notify_all();
 }

 Allocation* RetryAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   auto alloc_func = [&, this]() {
-    return new AllocationWithUnderlying(
-        underlying_allocator_->Allocate(size, attr));
+    return underlying_allocator_->Allocate(size, attr).release();
   };
   // In fact, we can unify the code of allocation success and failure
   // But it would add lock even when allocation success at the first time
......
@@ -18,38 +18,32 @@
 #include <condition_variable>  // NOLINT
 #include <memory>
 #include <mutex>  // NOLINT
+#include <utility>
 #include "paddle/fluid/memory/allocation/allocator.h"

 namespace paddle {
 namespace memory {
 namespace allocation {

-class RetryAllocator;
-
 class RetryAllocator : public Allocator {
  public:
-  RetryAllocator(std::unique_ptr<Allocator>&& allocator, size_t retry_ms)
+  RetryAllocator(std::shared_ptr<Allocator> allocator, size_t retry_ms)
       : underlying_allocator_(std::move(allocator)), retry_time_(retry_ms) {
-    EnforceCheck();
-  }
-
-  bool IsAllocThreadSafe() const override;
-
- private:
-  void EnforceCheck() {
     PADDLE_ENFORCE_NOT_NULL(
-        underlying_allocator_.get(),
-        "UnderlyingAllocator of RetryAllocator must be UnmanagedAllocator");
+        underlying_allocator_,
+        "UnderlyingAllocator of RetryAllocator must not be null");
     PADDLE_ENFORCE(underlying_allocator_->IsAllocThreadSafe(),
                    "UnderlyingAllocator of RetryAllocator must be thread-safe");
   }

+  bool IsAllocThreadSafe() const override { return true; }
+
  protected:
-  void Free(Allocation* allocation) override;
+  void FreeImpl(Allocation* allocation) override;
   Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;

  private:
-  std::unique_ptr<Allocator> underlying_allocator_;
+  std::shared_ptr<Allocator> underlying_allocator_;
   std::chrono::milliseconds retry_time_;
   std::mutex mutex_;
   std::condition_variable cv_;
@@ -57,8 +51,6 @@ class RetryAllocator : public Allocator {
   // For debug, We can add an atomic integer to record how many memory sizes are
   // waited to allocate
   // std::atomic<size_t> waited_allocate_size_{0};
-
-  friend class RetryAllocation;
 };

 }  // namespace allocation
......
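
A hedged sketch of the retry protocol implied by the members above (the real RetryAllocator::AllocateImpl body is truncated in this diff, so this is a single-retry simplification): on BadAlloc, wait on the condition variable for up to retry_time_ in the hope that a concurrent Free() signals it, then try once more.

#include <chrono>
#include <condition_variable>
#include <mutex>
#include <stdexcept>

struct BadAlloc : std::runtime_error {
  using std::runtime_error::runtime_error;
};

template <typename AllocFn>
auto AllocateWithRetry(AllocFn alloc_func, std::mutex& mutex,
                       std::condition_variable& cv,
                       std::chrono::milliseconds retry_time)
    -> decltype(alloc_func()) {
  try {
    return alloc_func();
  } catch (BadAlloc&) {
    std::unique_lock<std::mutex> lock(mutex);
    cv.wait_for(lock, retry_time);  // woken early by a concurrent Free()
    return alloc_func();            // single retry; may still throw
  }
}
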
@@ -24,11 +24,20 @@ bool ZeroSizeAllocator::IsAllocThreadSafe() const {

 Allocation *ZeroSizeAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   if (size == 0) {
-    return new ZeroSizeAllocation(place_);
+    return new Allocation(nullptr, 0, place_);
   } else {
     return underlying_allocator_->Allocate(size, attr).release();
   }
 }
+
+void ZeroSizeAllocator::FreeImpl(Allocation *allocation) {
+  if (allocation->size() == 0) {
+    delete allocation;
+  } else {
+    underlying_allocator_->Free(allocation);
+  }
+}
+
 }  // namespace allocation
 }  // namespace memory
 }  // namespace paddle
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include <memory>
 #include <utility>
 #include "paddle/fluid/memory/allocation/allocator.h"
@@ -23,12 +24,6 @@ namespace allocation {
 // The allocator handles the request's size is zero. Allocator will always
 // return an allocation even the request size is zero. However, the
 // allocation.ptr() is nullptr
-class ZeroSizeAllocation : public Allocation {
- public:
-  explicit ZeroSizeAllocation(const platform::Place& p)
-      : Allocation(nullptr, 0, p) {}
-};
-
 class ZeroSizeAllocator : public Allocator {
  public:
   ZeroSizeAllocator(std::shared_ptr<Allocator> underlying_allocator,
@@ -39,6 +34,7 @@ class ZeroSizeAllocator : public Allocator {

  protected:
   Allocation* AllocateImpl(size_t size, Allocator::Attr attr) override;
+  void FreeImpl(Allocation* allocation) override;

  private:
   std::shared_ptr<Allocator> underlying_allocator_;
......
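
The zero-size path has a simple observable contract: a request of size zero yields a real Allocation object whose ptr() is null, and FreeImpl must not forward such objects to the underlying allocator. A hypothetical test in the style of the facade test shown earlier (names assumed, not an existing test):

TEST(zero_size_allocator, returns_null_ptr) {
  auto &instance = AllocatorFacade::Instance();
  auto a = instance.Alloc(platform::CPUPlace(), 0);
  ASSERT_NE(a, nullptr);         // an Allocation object exists
  ASSERT_EQ(a->ptr(), nullptr);  // but it owns no memory
  ASSERT_EQ(a->size(), 0UL);
}
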
@@ -48,7 +48,7 @@ if (WITH_DISTRIBUTE)
     SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
 endif()

-register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
+register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})

 if (WITH_GPU)
     # warpctc_op needs cudnn 7 above
@@ -72,6 +72,12 @@ endif()

 set(COMMON_OP_DEPS ${OP_HEADER_DEPS})

+if (WITH_GPU AND NOT WIN32)
+    op_library(dgc_op DEPS dgc)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(dgc);\n")
+    set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc)
+endif()
+
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
......
@@ -14,69 +14,10 @@ limitations under the License. */

 #include "paddle/fluid/operators/clip_by_norm_op.h"

-namespace paddle {
-namespace operators {
-
-class ClipByNormOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of ClipByNormOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of ClipByNormOp should not be null.");
-    auto max_norm = ctx->Attrs().Get<float>("max_norm");
-    PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
-    auto x_dims = ctx->GetInputDim("X");
-    ctx->SetOutputDim("Out", x_dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
-};
-
-class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "(Tensor) The input of clip_by_norm op."
-             "The number of dimensions must be between [1, 9].");
-    AddOutput("Out",
-              "(Tensor) The output of clip_by_norm op with shape as input(X)");
-    AddAttr<float>("max_norm", "(float) The maximum norm value.");
-    AddComment(R"DOC(
-ClipByNorm Operator.
-
-This operator limits the L2 norm of the input $X$ within $max\_norm$.
-If the L2 norm of $X$ is less than or equal to $max\_norm$, $Out$ will be
-the same as $X$. If the L2 norm of $X$ is greater than $max\_norm$, $X$ will
-be linearly scaled to make the L2 norm of $Out$ equal to $max\_norm$, as
-shown in the following formula:
-
-$$
-Out = \\frac{max\\_norm * X}{norm(X)},
-$$
-
-where $norm(X)$ represents the L2 norm of $X$.
-
-Examples:
-        .. code-block:: python
-
-            data = fluid.layer.data(
-                name='data', shape=[2, 4, 6], dtype='float32')
-            reshaped = fluid.layers.clip_by_norm(
-                x=data, max_norm=0.5)
-
-)DOC");
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
                              ops::ClipByNormOpMaker);
 REGISTER_OP_CPU_KERNEL(
     clip_by_norm,
     ops::ClipByNormKernel<paddle::platform::CPUDeviceContext, float>);
@@ -83,5 +83,59 @@ class ClipByNormKernel : public framework::OpKernel<T> {
   }
 };

+class ClipByNormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ClipByNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ClipByNormOp should not be null.");
+    auto max_norm = ctx->Attrs().Get<float>("max_norm");
+    PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor) The input of clip_by_norm op."
+             "The number of dimensions must be between [1, 9].");
+    AddOutput("Out",
+              "(Tensor) The output of clip_by_norm op with shape as input(X)");
+    AddAttr<float>("max_norm", "(float) The maximum norm value.");
+    AddComment(R"DOC(
+ClipByNorm Operator.
+
+This operator limits the L2 norm of the input $X$ within $max\_norm$.
+If the L2 norm of $X$ is less than or equal to $max\_norm$, $Out$ will be
+the same as $X$. If the L2 norm of $X$ is greater than $max\_norm$, $X$ will
+be linearly scaled to make the L2 norm of $Out$ equal to $max\_norm$, as
+shown in the following formula:
+
+$$
+Out = \\frac{max\\_norm * X}{norm(X)},
+$$
+
+where $norm(X)$ represents the L2 norm of $X$.
+
+Examples:
+        .. code-block:: python
+
+            data = fluid.layer.data(
+                name='data', shape=[2, 4, 6], dtype='float32')
+            reshaped = fluid.layers.clip_by_norm(
+                x=data, max_norm=0.5)
+
+)DOC");
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
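
A worked numeric instance of the formula in the DOC string above (illustrative only, not part of the operator): for X = (3, 4), norm(X) = 5 and max_norm = 0.5, Out = 0.5 * X / 5 = (0.3, 0.4), so norm(Out) equals max_norm exactly.

#include <cassert>
#include <cmath>

int main() {
  const float max_norm = 0.5f;
  float x[2] = {3.0f, 4.0f};
  float norm = std::sqrt(x[0] * x[0] + x[1] * x[1]);  // 5.0
  float scale = norm > max_norm ? max_norm / norm : 1.0f;
  float out[2] = {x[0] * scale, x[1] * scale};  // (0.3, 0.4)
  float out_norm = std::sqrt(out[0] * out[0] + out[1] * out[1]);
  assert(std::fabs(out_norm - max_norm) < 1e-6f);
}
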
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/dgc_clip_by_norm_op.h"
namespace paddle {
namespace operators {
class DGCClipByNormOp : public ClipByNormOp {
public:
using ClipByNormOp::ClipByNormOp;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("current_step"),
"current_step should be set.");
return ClipByNormOp::InferShape(ctx);
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "current_step") {
VLOG(10) << "var_name:" << var_name << " does not need to be transformed";
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};
class DGCClipByNormOpMaker : public ClipByNormOpMaker {
public:
void Make() override {
AddInput("current_step", "(Tensor) Current step.");
AddAttr<float>("rampup_begin_step",
"(float, -1.0)"
"The period when begin k_select.")
.SetDefault(-1.0);
return ClipByNormOpMaker::Make();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc_clip_by_norm, ops::DGCClipByNormOp,
ops::DGCClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL(
dgc_clip_by_norm,
ops::DGCClipByNormKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_clip_by_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
dgc_clip_by_norm,
ops::DGCClipByNormKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/clip_by_norm_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
if (static_cast<int>(rampup_begin_step) >= 0) {
auto current_step_tensor =
context.Input<framework::Tensor>("current_step");
auto* current_step = current_step_tensor->data<T>();
if (static_cast<int>(*current_step) <
static_cast<int>(rampup_begin_step)) {
VLOG(10) << "current_step:" << *current_step
<< " < rampup_begin_step:" << rampup_begin_step
<< " so doesn't use dgc_clip_by_norm";
return;
}
}
return ClipByNormKernel<DeviceContext, T>::Compute(context);
};
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class DGCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("current_step"),
"Input(current_step) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("U_out"),
"Output(U_out) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("V_out"),
"Output(V_out) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("k"),
"Output(k) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("EncodeGrad"),
"Output(EncodeGrad) of DGCop should not be null.");
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "current_step" || var_name == "rampup_step" ||
var_name == "k") {
VLOG(10) << "var_name:" << var_name << " does not need to be transformed";
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};
class DGCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("U", "(Tensor) Middle tensor of DGC");
AddInput("V", "(Tensor) Middle tensor of DGC");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("current_step", "(Tensor) Current step.");
AddOutput("U_out",
"(Tensor) "
"Output encoded gradient");
AddOutput("V_out",
"(Tensor) "
"Output encoded gradient");
AddOutput("EncodeGrad",
"(Tensor) "
"Output encoded gradient");
AddOutput("Grad_out",
"(Tensor) "
"Output grad gradient");
AddOutput("k",
"(Tensor) "
"Output top-k value");
AddAttr<float>("m",
"(float, 0.9) "
"The momentum of learning rate.")
.SetDefault(0.9);
AddAttr<bool>("use_nesterov",
"(bool, true) "
"Whether to use Nesterov momentum.")
.SetDefault(true);
AddAttr<std::vector<float>>("sparsity",
"(vector, float)"
"The period sparsity of k_select.");
AddAttr<float>("rampup_begin_step",
"(float, 0.0)"
"The period when begin k_select.")
.SetDefault(0.0);
AddAttr<float>("rampup_step",
"(float, 0.0)"
"The rampup duration (in steps) of k_select.");
AddComment(R"DOC(
Original paper is https://arxiv.org/abs/1712.01887
DGC reduces the communication bandwidth by sending only the important gradients (sparse update):
only gradients larger than a threshold are transmitted.
To avoid losing information, DGC accumulates the rest of the gradients locally.
Eventually, these gradients become large enough to be transmitted.
Thus, DGC sends the large gradients immediately but eventually sends all of the gradients over time.
To ensure no loss of accuracy, DGC employs momentum correction and local gradient clipping on top of the gradient sparsification to maintain model performance.
DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication.
This optimizer will do two things:
1. Compress the gradient by getting the TopK important values from the tensor \
and use it for allreduce to reduce network bandwidth.
2. Call momentum to optimize on the cost.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc, ops::DGCOp, ops::DGCOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
dgc, ops::DGCOpKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
namespace paddle {
namespace operators {
inline float get_period_sparcity(const std::vector<float>& sparsity,
float cur_step, float rampup_steps) {
PADDLE_ENFORCE(static_cast<int>(cur_step) >= 0);
size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps);
if (idx >= sparsity.size()) {
return 0.999;
}
PADDLE_ENFORCE(idx < sparsity.size());
return sparsity[idx];
}
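
// Worked example (illustrative values, not defaults from this file): with
// sparsity = {0.75, 0.9375, 0.984375, 0.996, 0.999}, rampup_steps = 100 and
// cur_step = 40, idx = static_cast<int>(40 * 5 / 100) = 2, so
// get_period_sparcity returns 0.984375 and the kernel below keeps
// ratio = 1 - 0.984375 = 1.5625% of the gradient elements as top-k.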
template <typename DeviceContext, typename T>
class DGCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto u = ctx.Input<framework::Tensor>("U");
auto v = ctx.Input<framework::Tensor>("V");
auto g = ctx.Input<framework::Tensor>("Grad");
// attrs
float m = ctx.Attr<float>("m");
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto sparsity = ctx.Attr<std::vector<float>>("sparsity");
auto rampup_begin_step = ctx.Attr<float>("rampup_begin_step");
auto rampup_step = ctx.Attr<float>("rampup_step");
// current step
auto current_step_tensor = ctx.Input<framework::Tensor>("current_step");
const float* current_step = current_step_tensor->data<float>();
if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
VLOG(10) << "current_step:" << *current_step
<< " < rampup_begin_step:" << rampup_begin_step
<< " so doesn't use dgc";
return;
}
float ratio =
1 - get_period_sparcity(sparsity, static_cast<float>(*current_step),
rampup_step);
PADDLE_ENFORCE(ratio > 0.0 && ratio < 1.0);
int k = static_cast<int>(g->numel() * ratio);
VLOG(10) << "m:" << m << ", use_nesterov:" << use_nesterov
<< ", rampup_begin_step:" << rampup_begin_step
<< ", rampup_step:" << rampup_step
<< ", current_step:" << *current_step << ", ratio:" << ratio
<< ", k:" << k;
auto k_out = ctx.Output<framework::Tensor>("k");
T* k_out_data = k_out->data<T>();
*k_out_data = k;
auto u_out = ctx.Output<framework::Tensor>("U_out");
auto v_out = ctx.Output<framework::Tensor>("V_out");
auto encode_grad_out = ctx.Output<framework::Tensor>("EncodeGrad");
// FIXME(gongwb): use cublas.
auto u_out_e = framework::EigenVector<T>::Flatten(*u_out);
auto u_e = framework::EigenVector<T>::Flatten(*u);
auto g_e = framework::EigenVector<T>::Flatten(*g);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto& eigen_ctx = *dev_ctx.eigen_device();
if (use_nesterov) {
// u = m * (u + g)
u_out_e.device(eigen_ctx) = m * (u_e + g_e);
// v = u + v + g
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, u, v, 0, AddFunctor<T>(), v_out);
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, g, v, 0, AddFunctor<T>(), v_out);
} else {
// u = m * u + g
u_out_e.device(eigen_ctx) = m * u_e + g_e;
// v = u + v
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, u, v, 0, AddFunctor<T>(), v_out);
}
T* v_out_data = v_out->mutable_data<T>(ctx.GetPlace());
T* u_out_data = u_out->mutable_data<T>(ctx.GetPlace());
T* encode_grad_out_data = encode_grad_out->mutable_data<T>(
framework::DDim{2 * k}, ctx.GetPlace());
int buf_size = paddle::communication::dgc::get_buffer_size(k);
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(
ctx.GetPlace(), dev_ctx.stream());
auto tmp_ious_data = allocator.Allocate(buf_size);
void* buf = reinterpret_cast<void*>(tmp_ious_data->ptr());
if (!paddle::communication::dgc::k_select(
static_cast<void*>(encode_grad_out_data), k, v_out_data,
static_cast<int>(v_out->numel()), buf, dev_ctx.stream(),
u_out_data)) {
LOG(FATAL) << "v_out numel:" << v_out->numel();
}
auto grad_out = ctx.Output<framework::Tensor>("Grad_out");
math::SetConstant<DeviceContext, T> tset;
tset(dev_ctx, grad_out, static_cast<T>(0));
}
};
} // namespace operators
} // namespace paddle
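
For reference, a hedged dense sketch of the local update the kernel above performs for a single worker (the real kernel then encodes only the top-k entries via paddle::communication::dgc::k_select; this plain loop is a simplification for illustration):

#include <cstddef>
#include <vector>

void DgcLocalStep(std::vector<float>& u, std::vector<float>& v,
                  const std::vector<float>& g, float m, bool use_nesterov) {
  for (std::size_t i = 0; i < g.size(); ++i) {
    if (use_nesterov) {
      u[i] = m * (u[i] + g[i]);   // u = m * (u + g)
      v[i] = u[i] + v[i] + g[i];  // v = u + v + g
    } else {
      u[i] = m * u[i] + g[i];     // u = m * u + g
      v[i] = u[i] + v[i];         // v = u + v
    }
  }
}
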
@@ -10,6 +10,9 @@
    limitations under the License. */

 #include "paddle/fluid/operators/spectral_norm_op.h"
+
+#include <memory>
+
 #include "paddle/fluid/framework/op_registry.h"

 namespace paddle {
@@ -156,6 +159,28 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };

+class SpectralNormGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("spectral_norm_grad");
+
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetInput("Weight", Input("Weight"));
+    op->SetInput("U", Input("U"));
+    op->SetInput("V", Input("V"));
+
+    op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
+
+    op->SetAttrMap(Attrs());
+
+    return op;
+  }
+};
+
 class SpectralNormOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -185,7 +210,7 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(spectral_norm, ops::SpectralNormOp, ops::SpectralNormOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::SpectralNormGradOpDescMaker);
 REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad);
 REGISTER_OP_CPU_KERNEL(
     spectral_norm,
......
...@@ -46,8 +46,9 @@ cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper) ...@@ -46,8 +46,9 @@ cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper)
IF(WITH_GPU) IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader) set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
set(dgc_deps dgc)
ELSE() ELSE()
set(GPU_CTX_DEPS) set(dgc_deps)
ENDIF() ENDIF()
IF(WITH_MKLDNN) IF(WITH_MKLDNN)
...@@ -68,7 +69,8 @@ ENDIF() ...@@ -68,7 +69,8 @@ ENDIF()
# memcpy depends on device_context, here add deps individually for # memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies # avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS} cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} temp_allocator) place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
temp_allocator ${dgc_deps})
if(WIN32) if(WIN32)
if(WITH_GPU AND NOT WITH_DSO) if(WITH_GPU AND NOT WITH_DSO)
......
...@@ -40,7 +40,7 @@ limitations under the License. */ ...@@ -40,7 +40,7 @@ limitations under the License. */
#define PADDLE_ASSERT_MSG_CODE(e, m, c) \ #define PADDLE_ASSERT_MSG_CODE(e, m, c) \
do { \ do { \
if (!(e)) { \ if (!(e)) { \
printf("%s:%d Assertion `%s` failed (%s %d).\n", __FILE__, __LINE__, \ printf("%s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, __LINE__, \
TOSTRING(e), m, c); \ TOSTRING(e), m, c); \
asm("trap;"); \ asm("trap;"); \
} \ } \
......
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
#endif #endif
#include "glog/logging.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -324,8 +326,17 @@ void CUDADeviceContext::Wait() const { ...@@ -324,8 +326,17 @@ void CUDADeviceContext::Wait() const {
auto& allocator = auto& allocator =
DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this); DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this);
allocator.Release([this]() { allocator.Release([this]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); cudaError_t e_sync = cudaStreamSynchronize(stream_);
PADDLE_ENFORCE(cudaGetLastError()); if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
<< " errno:" << e_sync;
}
cudaError_t e_get = cudaGetLastError();
if (e_get != 0) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno:" << e_get;
}
}); });
} }
......
...@@ -31,6 +31,10 @@ limitations under the License. */ ...@@ -31,6 +31,10 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/piece.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "dgc/dgc.h"
#endif
DEFINE_int32(paddle_num_threads, 1, DEFINE_int32(paddle_num_threads, 1,
"Number of threads for each paddle instance."); "Number of threads for each paddle instance.");
DEFINE_int32(multiple_of_cupti_buffer_size, 1, DEFINE_int32(multiple_of_cupti_buffer_size, 1,
...@@ -43,6 +47,10 @@ namespace framework { ...@@ -43,6 +47,10 @@ namespace framework {
std::once_flag gflags_init_flag; std::once_flag gflags_init_flag;
std::once_flag p2p_init_flag; std::once_flag p2p_init_flag;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std::once_flag dgc_init_flag;
#endif
void InitGflags(std::vector<std::string> argv) { void InitGflags(std::vector<std::string> argv) {
std::call_once(gflags_init_flag, [&]() { std::call_once(gflags_init_flag, [&]() {
FLAGS_logtostderr = true; FLAGS_logtostderr = true;
...@@ -203,5 +211,15 @@ void InitGLOG(const std::string &prog_name) { ...@@ -203,5 +211,15 @@ void InitGLOG(const std::string &prog_name) {
#endif #endif
} }
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void InitDGC() {
std::call_once(dgc_init_flag, []() {
PADDLE_ENFORCE(paddle::communication::dgc::dynloadNcclLib());
});
}
#else
void InitDGC() {}
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -30,5 +30,7 @@ void InitDevices(bool init_p2p); ...@@ -30,5 +30,7 @@ void InitDevices(bool init_p2p);
void InitDevices(bool init_p2p, const std::vector<int> devices); void InitDevices(bool init_p2p, const std::vector<int> devices);
void InitDGC();
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/platform/temporary_allocator.h" #include "paddle/fluid/platform/temporary_allocator.h"
#include <memory>
#include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/memory/allocation/allocator_facade.h"
DEFINE_int64(limit_of_tmp_allocation, -1, DEFINE_int64(limit_of_tmp_allocation, -1,
...@@ -29,38 +30,31 @@ namespace paddle { ...@@ -29,38 +30,31 @@ namespace paddle {
namespace platform { namespace platform {
namespace alloc = memory::allocation; namespace alloc = memory::allocation;
TemporaryAllocation::TemporaryAllocation(
alloc::AllocationPtr &&underlying_allocation)
: Allocation(underlying_allocation->ptr(), underlying_allocation->size(),
underlying_allocation->place()),
underlying_allocation_(std::move(underlying_allocation)) {}
TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) { TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) {
temp_mem_map_.reset(new std::multimap<size_t, TemporaryAllocation *>()); temp_mem_map_.reset(new std::multimap<size_t, alloc::Allocation *>());
} }
bool TemporaryAllocator::IsAllocThreadSafe() const { return true; } bool TemporaryAllocator::IsAllocThreadSafe() const { return true; }
void TemporaryAllocator::Release(const std::function<void()> &callback) { void TemporaryAllocator::Release(const std::function<void()> &callback) {
std::unique_ptr<std::multimap<size_t, TemporaryAllocation *>> t_allocations; std::unique_ptr<std::multimap<size_t, alloc::Allocation *>> t_allocations;
{ {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
callback(); callback();
t_allocations.swap(temp_mem_map_); t_allocations.swap(temp_mem_map_);
temp_mem_map_.reset(new std::multimap<size_t, TemporaryAllocation *>()); temp_mem_map_.reset(new std::multimap<size_t, alloc::Allocation *>());
wait_delete_mem_ = 0; wait_delete_mem_ = 0;
} }
alloc::AllocationDeleter deleter;
for (auto tmp : *t_allocations) { for (auto tmp : *t_allocations) {
VLOG(10) << "Delete temporary allocation " << tmp.second->ptr() VLOG(10) << "Delete temporary allocation " << tmp.second->ptr()
<< " size: " << tmp.second->size(); << " size: " << tmp.second->size();
delete tmp.second; deleter(tmp.second);
} }
} }
void TemporaryAllocator::Free(alloc::Allocation *allocation) { void TemporaryAllocator::FreeImpl(alloc::Allocation *temp_allocation) {
auto *temp_allocation = dynamic_cast<TemporaryAllocation *>(allocation);
PADDLE_ENFORCE_NOT_NULL(temp_allocation);
if (platform::is_gpu_place(temp_allocation->place())) { if (platform::is_gpu_place(temp_allocation->place())) {
PADDLE_ENFORCE(platform::is_same_place(temp_allocation->place(), place_), PADDLE_ENFORCE(platform::is_same_place(temp_allocation->place(), place_),
"The place should be the same."); "The place should be the same.");
...@@ -84,7 +78,7 @@ void TemporaryAllocator::Free(alloc::Allocation *allocation) { ...@@ -84,7 +78,7 @@ void TemporaryAllocator::Free(alloc::Allocation *allocation) {
} }
VLOG(10) << "Delete temporary allocation " << temp_allocation->ptr() VLOG(10) << "Delete temporary allocation " << temp_allocation->ptr()
<< " size: " << temp_allocation->size(); << " size: " << temp_allocation->size();
delete temp_allocation; alloc::AllocationDeleter()(temp_allocation);
} }
size_t TemporaryAllocator::TemporaryAllocationQueueSize() { size_t TemporaryAllocator::TemporaryAllocationQueueSize() {
...@@ -119,11 +113,9 @@ alloc::Allocation *TemporaryAllocator::AllocateImpl( ...@@ -119,11 +113,9 @@ alloc::Allocation *TemporaryAllocator::AllocateImpl(
} }
// If no available allocation is found, get allocation from // If no available allocation is found, get allocation from
// AllocatorFacadeInstance. // AllocatorFacadeInstance.
auto raw_allocation = auto temp_mem = alloc::AllocatorFacade::Instance().Alloc(place_, size, attr);
alloc::AllocatorFacade::Instance().Alloc(place_, size, attr);
auto temp_mem = new TemporaryAllocation(std::move(raw_allocation));
VLOG(10) << "Alloc temporary allocation: " << temp_mem->ptr() << ": " << size; VLOG(10) << "Alloc temporary allocation: " << temp_mem->ptr() << ": " << size;
return temp_mem; return temp_mem.release();
} }
} // namespace platform } // namespace platform
......
...@@ -16,20 +16,13 @@ ...@@ -16,20 +16,13 @@
#include <condition_variable> // NOLINT #include <condition_variable> // NOLINT
#include <deque> #include <deque>
#include <map> #include <map>
#include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/lock_guard_ptr.h" #include "paddle/fluid/platform/lock_guard_ptr.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
class TemporaryAllocation : public memory::allocation::Allocation {
public:
explicit TemporaryAllocation(
memory::allocation::AllocationPtr &&underlying_allocation);
memory::allocation::AllocationPtr underlying_allocation_;
};
/*! \brief the TemporaryAllocator is used to alloc the temporary allocation /*! \brief the TemporaryAllocator is used to alloc the temporary allocation
* which used by CUDA's async operation. * which used by CUDA's async operation.
* *
...@@ -56,7 +49,7 @@ class TemporaryAllocator : public memory::allocation::Allocator { ...@@ -56,7 +49,7 @@ class TemporaryAllocator : public memory::allocation::Allocator {
void SetCallback(const std::function<void()> &callback); void SetCallback(const std::function<void()> &callback);
protected: protected:
void Free(memory::allocation::Allocation *allocation) override; void FreeImpl(memory::allocation::Allocation *allocation) override;
memory::allocation::Allocation *AllocateImpl( memory::allocation::Allocation *AllocateImpl(
size_t size, memory::allocation::Allocator::Attr attr) override; size_t size, memory::allocation::Allocator::Attr attr) override;
...@@ -65,8 +58,8 @@ class TemporaryAllocator : public memory::allocation::Allocator { ...@@ -65,8 +58,8 @@ class TemporaryAllocator : public memory::allocation::Allocator {
platform::Place place_; platform::Place place_;
// When the allocation is not held by any variable, it should be placed // When the allocation is not held by any variable, it should be placed
// to temp_mem_map immediately. // to temp_mem_map immediately.
std::unique_ptr<std::multimap<size_t, TemporaryAllocation *>> temp_mem_map_{ std::unique_ptr<std::multimap<size_t, memory::allocation::Allocation *>>
nullptr}; temp_mem_map_{nullptr};
std::mutex mtx_; std::mutex mtx_;
size_t wait_delete_mem_{0}; size_t wait_delete_mem_{0};
std::function<void()> callback_; std::function<void()> callback_;
......
...@@ -222,6 +222,7 @@ void BindOpDesc(pybind11::module *m) { ...@@ -222,6 +222,7 @@ void BindOpDesc(pybind11::module *m) {
.def("attr_type", &pd::OpDesc::GetAttrType) .def("attr_type", &pd::OpDesc::GetAttrType)
.def("attr_names", &pd::OpDesc::AttrNames) .def("attr_names", &pd::OpDesc::AttrNames)
.def("_set_attr", &pd::OpDesc::SetAttr) .def("_set_attr", &pd::OpDesc::SetAttr)
.def("remove_attr", &pd::OpDesc::RemoveAttr)
.def("attr", &pd::OpDesc::GetAttr) .def("attr", &pd::OpDesc::GetAttr)
.def("set_block_attr", &pd::OpDesc::SetBlockAttr) .def("set_block_attr", &pd::OpDesc::SetBlockAttr)
.def("set_blocks_attr", &pd::OpDesc::SetBlocksAttr) .def("set_blocks_attr", &pd::OpDesc::SetBlocksAttr)
......
...@@ -324,6 +324,7 @@ PYBIND11_MODULE(core, m) { ...@@ -324,6 +324,7 @@ PYBIND11_MODULE(core, m) {
[](Tensor &self, paddle::platform::CUDAPinnedPlace &place) { [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
self.mutable_data<float>(place); self.mutable_data<float>(place);
}) })
.def("_clear", &Tensor::clear)
.def("set", PyCPUTensorSetFromArray<float>) .def("set", PyCPUTensorSetFromArray<float>)
.def("set", PyCPUTensorSetFromArray<int>) .def("set", PyCPUTensorSetFromArray<int>)
.def("set", PyCPUTensorSetFromArray<double>) .def("set", PyCPUTensorSetFromArray<double>)
...@@ -932,6 +933,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -932,6 +933,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_gflags", framework::InitGflags); m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG); m.def("init_glog", framework::InitGLOG);
m.def("init_dgc", framework::InitDGC);
m.def("init_devices", m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); }); [](bool init_p2p) { framework::InitDevices(init_p2p); });
...@@ -1044,9 +1046,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1044,9 +1046,7 @@ All parameter, weight, gradient are variables in Paddle.
int val) { self.Set<const int>(name, new int(val)); }) int val) { self.Set<const int>(name, new int(val)); })
.def("type", &ir::Pass::Type) .def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) { .def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
std::unique_ptr<ir::Graph> origin_graph(graph.get()); self.Apply(graph.get());
auto optim_graph = self.Apply(std::move(origin_graph));
optim_graph.release();
}); });
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb( py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
......
...@@ -105,14 +105,12 @@ void Printf(const char* fmt, const Args&... args) { ...@@ -105,14 +105,12 @@ void Printf(const char* fmt, const Args&... args) {
Fprintf(std::cout, fmt, args...); Fprintf(std::cout, fmt, args...);
} }
template <typename T> inline std::string HumanReadableSize(double f_size) {
std::string HumanReadableSize(T size) {
size_t i = 0; size_t i = 0;
double f_size = static_cast<double>(size);
double orig = f_size; double orig = f_size;
const std::vector<std::string> units( const std::vector<std::string> units(
{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}); {"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"});
while (f_size > 1024) { while (f_size >= 1024) {
f_size /= 1024; f_size /= 1024;
i++; i++;
} }
......
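The `>` to `>=` change above matters exactly at powers of 1024: with `>`, 1024 bytes would print as "1024 B" rather than "1 kB". A small Python mirror of the fixed logic:

.. code-block:: python

    def human_readable_size(f_size):
        # mirror of the fixed C++ loop; the index bound is added for safety
        units = ["B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]
        i = 0
        while f_size >= 1024 and i + 1 < len(units):
            f_size /= 1024.0
            i += 1
        return "%g %s" % (f_size, units[i])

    assert human_readable_size(1024.0) == "1 kB"  # with '>', this stayed "1024 B"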
...@@ -152,7 +152,7 @@ class QuantizationStrategy(Strategy): ...@@ -152,7 +152,7 @@ class QuantizationStrategy(Strategy):
] ]
if self.save_in_nodes == None: if self.save_in_nodes == None:
in_vars = list(context.eval_graph.out_nodes.values()) in_vars = list(context.eval_graph.in_nodes.values())
else: else:
in_vars = self.save_in_nodes in_vars = self.save_in_nodes
......
...@@ -1202,6 +1202,9 @@ class Operator(object): ...@@ -1202,6 +1202,9 @@ class Operator(object):
""" """
self._update_desc_attr(name, val) self._update_desc_attr(name, val)
def _remove_attr(self, name):
self.desc.remove_attr(name)
def _update_desc_attr(self, name, val): def _update_desc_attr(self, name, val):
""" """
Update the value of desc's attribute by attribute's name. Update the value of desc's attribute by attribute's name.
...@@ -2725,6 +2728,10 @@ class Program(object): ...@@ -2725,6 +2728,10 @@ class Program(object):
self._trainers_endpoints = [] self._trainers_endpoints = []
# the distributed lookup table names # the distributed lookup table names
self._distributed_lookup_table = None self._distributed_lookup_table = None
# use Deep Gradient Compression (DGC) or not
self._enable_dgc = False
# @deprecated(the python memory optimize transpiler is deprecated) # @deprecated(the python memory optimize transpiler is deprecated)
# whether the program is optimized by memory_optimize_transpiler # whether the program is optimized by memory_optimize_transpiler
self.__is_mem_optimized = False self.__is_mem_optimized = False
...@@ -2775,6 +2782,15 @@ class Program(object): ...@@ -2775,6 +2782,15 @@ class Program(object):
def set_op_role_var(self, var_name): def set_op_role_var(self, var_name):
self._op_role_var = [var_name] self._op_role_var = [var_name]
@contextlib.contextmanager
def _backward_role_guard(self):
tmp_role = self._current_role
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.Backward
yield
self._current_role = tmp_role
@signature_safe_contextmanager @signature_safe_contextmanager
def _optimized_guard(self, param_and_grads): def _optimized_guard(self, param_and_grads):
""" """
......
...@@ -9674,9 +9674,15 @@ def space_to_depth(x, blocksize, name=None): ...@@ -9674,9 +9674,15 @@ def space_to_depth(x, blocksize, name=None):
.. code-block:: python .. code-block:: python
data = fluid.layers.data( data = fluid.layers.data(
name='data', shape=[1, 4, 2, 2], dtype='float32') name='data', shape=[1, 4, 2, 2], dtype='float32', append_batch_size=False)
space_to_depthed = fluid.layers.space_to_depth( space_to_depthed = fluid.layers.space_to_depth(
x=data, blocksize=2) x=data, blocksize=2)
import numpy as np
exe = fluid.Executor(fluid.CUDAPlace(0))
data_np = np.arange(0, 16).reshape((1, 4, 2, 2)).astype('float32')
out_main = exe.run(fluid.default_main_program(),
feed={'data': data_np},
fetch_list=[space_to_depthed])
""" """
helper = LayerHelper("space_to_depth", **locals()) helper = LayerHelper("space_to_depth", **locals())
......
...@@ -17,7 +17,7 @@ from __future__ import print_function ...@@ -17,7 +17,7 @@ from __future__ import print_function
from collections import defaultdict from collections import defaultdict
from .wrapped_decorator import signature_safe_contextmanager from .wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from . import framework from . import framework
...@@ -31,13 +31,17 @@ from .layer_helper import LayerHelper ...@@ -31,13 +31,17 @@ from .layer_helper import LayerHelper
from .layers import ops from .layers import ops
from .regularizer import append_regularization_ops from .regularizer import append_regularization_ops
from .imperative import base as imperative_base from .imperative import base as imperative_base
from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce
import copy
__all__ = [ __all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum', 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
'LarsMomentumOptimizer' 'LarsMomentumOptimizer', 'DGCMomentumOptimizer'
] ]
...@@ -294,6 +298,9 @@ class Optimizer(object): ...@@ -294,6 +298,9 @@ class Optimizer(object):
outputs={"ParamOut": param_and_grad[0]}) outputs={"ParamOut": param_and_grad[0]})
return new_param_grads, (table_param, table_grad), sgd_op return new_param_grads, (table_param, table_grad), sgd_op
def _append_dgc_ops(self, param_and_grad):
pass
def backward(self, def backward(self,
loss, loss,
startup_program=None, startup_program=None,
...@@ -415,6 +422,9 @@ class Optimizer(object): ...@@ -415,6 +422,9 @@ class Optimizer(object):
with program_guard(program, startup_program): with program_guard(program, startup_program):
params_grads = self.backward(loss, startup_program, params_grads = self.backward(loss, startup_program,
parameter_list, no_grad_set) parameter_list, no_grad_set)
# Note: since we can't use all_reduce_op for now,
# dgc_op should be the last op appended for each grad.
self._append_dgc_ops(params_grads)
optimize_ops = self.apply_gradients(params_grads) optimize_ops = self.apply_gradients(params_grads)
return optimize_ops, params_grads return optimize_ops, params_grads
...@@ -552,6 +562,264 @@ class MomentumOptimizer(Optimizer): ...@@ -552,6 +562,264 @@ class MomentumOptimizer(Optimizer):
return momentum_op return momentum_op
class DGCMomentumOptimizer(MomentumOptimizer):
"""
Original paper is https://arxiv.org/abs/1712.01887
DGC reduces the communication bandwidth by sending only the important gradients (sparse update): \
only gradients larger than a threshold are transmitted.
To avoid losing information, DGC accumulates the rest of the gradients locally.
Eventually, these gradients become large enough to be transmitted.
Thus, DGC sends the large gradients immediately but eventually sends all of the gradients over time.
To ensure no loss of accuracy, DGC employs momentum correction and local gradient clipping on top of the gradient sparsification to maintain model performance.
DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication.
This optimizer will do two things:
1. Compress the gradient by taking the TopK important values from the gradient tensor \
and use them for allreduce to reduce network bandwidth.
2. Call momentum to optimize the cost.
Args:
learning_rate (float|Variable): the learning rate used to update parameters. \
Can be a float value or a Variable with one float value as data element.
momentum (float): Momentum factor.
rampup_begin_step (int): The beginning step from which gradient compression is applied.
rampup_step (int): The length of the period over which the sparsity schedule is applied. Default is 1.
for example: If the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999], and the rampup_step is 5, \
it will use 0.75 at step 0, 0.9375 at step 1, and so on. When the end of the sparsity array is reached, \
it will use 0.999 from then on.
sparsity (list[float]): The sparsity schedule; at each step, the top (1 - current sparsity) fraction of gradient elements is kept and transmitted.
use_nesterov (bool): Enables Nesterov momentum. True means use Nesterov momentum.
local_grad_clip_norm (float): Clip norm value if needed.
num_trainers: The number of training nodes.
regularization: A Regularizer, such as fluid.regularizer.L2DecayRegularizer.
name: An optional name prefix.
Examples:
.. code-block:: python
optimizer = fluid.optimizer.DGCMomentumOptimizer(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
rampup_begin_step=1252,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(cost)
"""
def __init__(self,
learning_rate,
momentum,
rampup_begin_step,
rampup_step=1,
sparsity=[0.999],
use_nesterov=False,
local_grad_clip_norm=None,
num_trainers=None,
regularization=None,
name=None):
self._sparsity = sparsity
self._rampup_step = rampup_step
self._rampup_step_var = None
self._rampup_begin_step = rampup_begin_step
self._rampup_begin_step_var = None
self._global_step_var = None
self._local_grad_clip_norm = None
self._clip_norm = None
if local_grad_clip_norm is not None:
assert isinstance(num_trainers, int)
assert isinstance(local_grad_clip_norm, float)
assert num_trainers > 0
self._local_grad_clip_norm = local_grad_clip_norm
self._num_trainers = num_trainers
self._clip_norm = local_grad_clip_norm / (num_trainers *
num_trainers)
super(DGCMomentumOptimizer, self).__init__(
learning_rate, momentum, use_nesterov, regularization, name)
core.init_dgc()
def _add_auto_increment_var(self, counter_name, begin, step=1):
helper = LayerHelper('global_step_counter')
counter, is_new_var = helper.create_or_get_global_variable(
name=counter_name, dtype='float32', shape=[1], persistable=True)
if is_new_var:
helper.set_variable_initializer(
counter,
initializer=Constant(
value=float(begin - 1), force_cpu=True))
helper.main_program.global_block()._prepend_op(
type='increment',
inputs={'X': [counter]},
outputs={'Out': [counter]},
attrs={'step': float(step)},
stop_gradient=True)
counter.stop_gradient = True
return counter
def _append_dgc_ops(self, param_and_grads):
start_program = default_startup_program()
main_program = default_main_program()
main_program._enable_dgc = True
# step counter
self._global_step_var = self._add_auto_increment_var(
counter_name='__g_dgc_counter__', begin=0)
# rampup begin step var for all_reduce_op_handle
self._rampup_begin_step_var = tensor.create_global_var(
shape=[1],
dtype=core.VarDesc.VarType.FP32,
persistable=True,
name='__g_rampup_begin_step__',
value=self._rampup_begin_step * 1.0,
force_cpu=True)
for param_var, grad_var in param_and_grads:
var_numel = reduce(lambda x, y: x * y, param_var.shape)
if var_numel < 16384 or \
param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
param_var.dtype != core.VarDesc.VarType.FP32 :
continue
u_var = tensor.create_global_var(
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True,
name=param_var.name + "__dgc_u__",
value=0.0)
v_var = tensor.create_global_var(
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True,
name=param_var.name + "__dgc_v__",
value=0.0)
k_var = tensor.create_global_var(
shape=[1],
dtype=param_var.dtype,
persistable=True,
name=param_var.name + "__dgc_k__",
value=0.0,
force_cpu=True)
encoded_var = tensor.create_global_var(
shape=[1],
dtype=param_var.dtype,
persistable=True,
name=param_var.name + "__dgc_encoded__",
value=0.0,
force_cpu=False)
# remove this param/grad pair from the backward op's op_role_var attr
op_maker = core.op_proto_and_checker_maker
backward = core.op_proto_and_checker_maker.OpRole.Backward
for op in main_program.global_block().ops:
if not self._is_the_backward_op(op):
continue
var_attr = op.all_attrs()[op_maker.kOpRoleVarAttrName()]
if param_var.name not in var_attr:
continue
var_attr.remove(param_var.name)
var_attr.remove(grad_var.name)
if len(var_attr) > 1:
op._set_attr(op_maker.kOpRoleVarAttrName(), var_attr)
else:
op._remove_attr(op_maker.kOpRoleVarAttrName())
clip_var = grad_var
if self._local_grad_clip_norm is not None:
clip_var = self._append_clip_norm(grad_var, self._clip_norm)
self._dgc_op(param_var, clip_var, grad_var, u_var, v_var, k_var,
encoded_var)
def _is_the_backward_op(self, op):
op_maker = core.op_proto_and_checker_maker
backward = core.op_proto_and_checker_maker.OpRole.Backward
if op_maker.kOpRoleVarAttrName() in op.attr_names and \
int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward):
return True
return False
def _clip_by_norm(self, x, max_norm, name=None):
args = {'x': x, 'max_norm': max_norm, 'name': name}
helper = LayerHelper("dgc_clip_by_norm_op", **args)
if name is None:
name = unique_name.generate(".".join([helper.name, 'tmp']))
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="clip_by_norm",
inputs={"X": x,
"current_step": self._global_step_var},
attrs={
"max_norm": max_norm,
"rampup_begin_step": float(self._rampup_begin_step)
},
outputs={"Out": out})
return out
def _append_clip_norm(self, grad_var, clip_norm):
with grad_var.block.program._backward_role_guard():
return self._clip_by_norm(
x=grad_var, max_norm=clip_norm, name=grad_var.name + "@DGC")
def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
encoded_var):
block = framework.default_main_program().global_block()
op_maker = core.op_proto_and_checker_maker
dgc_op = block.append_op(
type="dgc",
inputs={
"U": u_var,
"V": v_var,
"Grad": clip_var,
"current_step": self._global_step_var
},
outputs={
"U_out": u_var,
"V_out": v_var,
"EncodeGrad": encoded_var,
"k": k_var,
"Grad_out": grad_var
},
attrs={
"m": self._momentum,
"sparsity": self._sparsity,
"use_nesterov": self._use_nesterov,
"rampup_begin_step": float(self._rampup_begin_step),
"rampup_step": float(self._rampup_step)
},
stop_gradient=True)
backward = op_maker.OpRole.Backward
dgc_op._set_attr(op_maker.kOpRoleAttrName(), backward)
dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
[param_var.name, grad_var.name])
class LarsMomentumOptimizer(Optimizer): class LarsMomentumOptimizer(Optimizer):
""" """
Momentum optimizer with LARS support Momentum optimizer with LARS support
......
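To make the ramp-up behavior concrete: a hypothetical Python mirror of the C++ get_period_sparcity helper the dgc kernel calls, assuming it indexes the schedule linearly over rampup_step and clamps to the last entry. This is consistent with the docstring example above and with the `k == numel * 0.25` assertion at step 0 in test_dgc_op.py.

.. code-block:: python

    def period_sparsity(sparsity, current_step, rampup_step):
        # hypothetical mirror of the C++ get_period_sparcity helper
        idx = int(current_step * len(sparsity) / rampup_step)
        idx = min(idx, len(sparsity) - 1)
        return sparsity[idx]

    sparsity = [0.75, 0.9375, 0.984375, 0.996, 0.999]
    # fraction of gradient elements actually transmitted at step 0
    ratio = 1.0 - period_sparsity(sparsity, current_step=0, rampup_step=5)
    assert ratio == 0.25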
...@@ -103,6 +103,12 @@ class ParallelExecutor(object): ...@@ -103,6 +103,12 @@ class ParallelExecutor(object):
) if use_cuda else framework.cpu_places() ) if use_cuda else framework.cpu_places()
self._scope = scope if scope is not None else executor.global_scope() self._scope = scope if scope is not None else executor.global_scope()
if main_program is not None and main_program._enable_dgc:
assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce
assert num_trainers * len(
self._places) > 1, "dgc is not useful for single card training"
assert use_cuda
main_program = main_program if main_program is not None \ main_program = main_program if main_program is not None \
else framework.default_main_program() else framework.default_main_program()
......
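A hedged end-to-end sketch of a setup that satisfies the three DGC checks above (AllReduce strategy, more than one card, CUDA). The tiny network is only there to make the snippet self-contained; it assumes at least two visible GPUs.

.. code-block:: python

    import paddle.fluid as fluid

    # minimal program so the snippet is self-contained
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    avg_cost = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    opt = fluid.optimizer.DGCMomentumOptimizer(
        learning_rate=0.001, momentum=0.9, rampup_begin_step=0)
    opt.minimize(avg_cost)  # sets main_program._enable_dgc = True

    build_strategy = fluid.BuildStrategy()
    build_strategy.reduce_strategy = \
        fluid.BuildStrategy.ReduceStrategy.AllReduce  # required by the DGC check

    exe = fluid.Executor(fluid.CUDAPlace(0))
    exe.run(fluid.default_startup_program())

    pe = fluid.ParallelExecutor(
        use_cuda=True,  # DGC is CUDA-only
        loss_name=avg_cost.name,
        main_program=fluid.default_main_program(),
        build_strategy=build_strategy)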
...@@ -70,6 +70,7 @@ list(REMOVE_ITEM TEST_OPS test_dist_transpiler) ...@@ -70,6 +70,7 @@ list(REMOVE_ITEM TEST_OPS test_dist_transpiler)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf) list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed) list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext) list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
list(REMOVE_ITEM TEST_OPS test_dgc_op)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_nccl) list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_nccl)
list(REMOVE_ITEM TEST_OPS test_dist_transformer) list(REMOVE_ITEM TEST_OPS test_dist_transformer)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer) list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer)
...@@ -97,6 +98,7 @@ if(WITH_DISTRIBUTE) ...@@ -97,6 +98,7 @@ if(WITH_DISTRIBUTE)
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
py_test_modules(test_dgc_op MODULES test_dgc_op)
set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
py_test_modules(test_dist_se_resnext_nccl MODULES test_dist_se_resnext_nccl) py_test_modules(test_dist_se_resnext_nccl MODULES test_dist_se_resnext_nccl)
set_tests_properties(test_dist_se_resnext_nccl PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_se_resnext_nccl PROPERTIES TIMEOUT 1000)
...@@ -107,16 +109,20 @@ if(WITH_DISTRIBUTE) ...@@ -107,16 +109,20 @@ if(WITH_DISTRIBUTE)
endif(NOT APPLE) endif(NOT APPLE)
# py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) # py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif() endif()
py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450) set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450)
py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL)
if(NOT WIN32) if(NOT WIN32)
py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer SERIAL) py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer SERIAL)
endif() endif()
if(NOT APPLE) if(NOT APPLE)
py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL) py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL)
endif() endif()
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")
# change the timeout from 600 to 2200, because in debug mode, this test need more time. # change the timeout from 600 to 2200, because in debug mode, this test need more time.
set_tests_properties(test_parallel_executor_seresnext PROPERTIES TIMEOUT 2200) set_tests_properties(test_parallel_executor_seresnext PROPERTIES TIMEOUT 2200)
......
...@@ -73,7 +73,7 @@ def cnn_model(data): ...@@ -73,7 +73,7 @@ def cnn_model(data):
class TestDistMnist2x2(TestDistRunnerBase): class TestDistMnist2x2(TestDistRunnerBase):
def get_model(self, batch_size=2): def get_model(self, batch_size=2, use_dgc=False):
# Input data # Input data
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
...@@ -93,7 +93,11 @@ class TestDistMnist2x2(TestDistRunnerBase): ...@@ -93,7 +93,11 @@ class TestDistMnist2x2(TestDistRunnerBase):
# TODO(typhoonzero): fix distributed adam optimizer # TODO(typhoonzero): fix distributed adam optimizer
# opt = fluid.optimizer.AdamOptimizer( # opt = fluid.optimizer.AdamOptimizer(
# learning_rate=0.001, beta1=0.9, beta2=0.999) # learning_rate=0.001, beta1=0.9, beta2=0.999)
if not use_dgc:
opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9) opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9)
else:
opt = fluid.optimizer.DGCMomentumOptimizer(
learning_rate=self.lr, momentum=0.9, rampup_begin_step=0)
# Reader # Reader
train_reader = paddle.batch( train_reader = paddle.batch(
......
...@@ -210,7 +210,7 @@ class SE_ResNeXt(): ...@@ -210,7 +210,7 @@ class SE_ResNeXt():
class DistSeResneXt2x2(TestDistRunnerBase): class DistSeResneXt2x2(TestDistRunnerBase):
def get_model(self, batch_size=2): def get_model(self, batch_size=2, use_dgc=False):
# Input data # Input data
image = fluid.layers.data( image = fluid.layers.data(
name="data", shape=[3, 224, 224], dtype='float32') name="data", shape=[3, 224, 224], dtype='float32')
...@@ -237,11 +237,19 @@ class DistSeResneXt2x2(TestDistRunnerBase): ...@@ -237,11 +237,19 @@ class DistSeResneXt2x2(TestDistRunnerBase):
base_lr = 0.1 base_lr = 0.1
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
if not use_dgc:
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay( learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr), boundaries=bd, values=lr),
momentum=0.9, momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4)) regularization=fluid.regularizer.L2Decay(1e-4))
else:
optimizer = fluid.optimizer.DGCMomentumOptimizer(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
rampup_begin_step=0,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
# Reader # Reader
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
g_array_size = 102400
class TestDGCOp(unittest.TestCase):
def setup(self, place, array_size=g_array_size):
size = array_size
np.random.seed(5) # fix seed
self.scope = fluid.global_scope()
self.place = place
print("place:", place)
# numpy data
# inputs: U, V, Grad, current_step
self.u_name = "U"
self.u = np.random.random(size).astype("float32")
self.v_name = "V"
self.v = np.random.random(size).astype("float32")
self.grad_name = "Grad"
self.grad = np.random.random(size).astype("float32")
self.current_step_name = "current_step"
self.current_step = np.full((1), 0.0).astype("float32")
# outputs: U_out, V_out, EncodeGrad, Grad_out, k
self.encode_grad_name = "EncodeGrad"
self.k_name = "k"
self.k = np.full((1), 0.0).astype("float32")
# scope data
self.u_tensor = self.scope.var(self.u_name).get_tensor()
self.u_tensor.set(self.u, place)
self.v_tensor = self.scope.var(self.v_name).get_tensor()
self.v_tensor.set(self.v, place)
self.grad_tensor = self.scope.var(self.grad_name).get_tensor()
self.grad_tensor.set(self.grad, place)
self.encode_grad_tensor = self.scope.var(
self.encode_grad_name).get_tensor()
self.current_step_tensor = self.scope.var(
self.current_step_name).get_tensor()
self.current_step_tensor.set(self.current_step, core.CPUPlace())
self.k_tensor = self.scope.var(self.k_name).get_tensor()
self.k_tensor.set(self.k, core.CPUPlace())
def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+ str(expect_t) + "\n" + "But Got" + str(actual_t))
def test_run_and_check(self):
self.setup(place=core.CUDAPlace(0))
kwargs = {
# inputs
'U': self.u_name,
'V': self.v_name,
'Grad': self.grad_name,
'current_step': self.current_step_name,
# outputs
'U_out': self.u_name,
'V_out': self.v_name,
'EncodeGrad': self.encode_grad_name,
'Grad_out': self.grad_name,
'k': self.k_name,
# attrs
'm': 0.9,
'sparsity': [0.75, 0.9375, 0.984375, 0.996, 0.999],
'use_nesterov': True,
'rampup_begin_step': float(0.0),
'rampup_step': float(10.0),
}
dgc_op = Operator('dgc', **kwargs)
#atol = 1e-6
dgc_op.run(self.scope, self.place)
u_out = np.array(self.u_tensor)
v_out = np.array(self.v_tensor)
grad_out = np.array(self.grad_tensor)
encode_grad_out = np.array(self.encode_grad_tensor)
k = int(np.array(self.k_tensor)[0])
print("u_out:", u_out[0:20])
print("v_out:", v_out[0:20])
print("encode_grad_out:", encode_grad_out)
print("k_out:", k)
self.assertEqual(k, int(g_array_size * 0.25))
index = encode_grad_out[0:k].view(dtype=np.int32)
value = encode_grad_out[k:2 * k]
acl = 1e-7
for i in range(0, k):
self.assertAlmostEqual(u_out[index[i]], 0.0)
self.assertAlmostEqual(v_out[index[i]], 0.0)
a_min = np.amin(value)
dangling = [x for x in v_out if x > a_min]
if __name__ == "__main__":
unittest.main()
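The test above reads EncodeGrad as k int32 indices followed by k float32 values. A small companion sketch (helper name illustrative) that scatters such an encoding back into a dense gradient:

.. code-block:: python

    import numpy as np

    def decode_dgc_grad(encoded, k, numel):
        # layout assumed from the test: [k indices (int32) | k values (float32)]
        index = encoded[0:k].view(dtype=np.int32)
        value = encoded[k:2 * k]
        dense = np.zeros(numel, dtype=np.float32)
        dense[index] = value
        return dense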
...@@ -36,7 +36,8 @@ class TestDistRunnerBase(object): ...@@ -36,7 +36,8 @@ class TestDistRunnerBase(object):
def get_model(self, def get_model(self,
batch_size=DEFAULT_BATCH_SIZE, batch_size=DEFAULT_BATCH_SIZE,
lr=0.1, lr=0.1,
single_device=False): single_device=False,
use_dgc=False):
raise NotImplementedError( raise NotImplementedError(
"get_model should be implemented by child classes.") "get_model should be implemented by child classes.")
...@@ -82,6 +83,9 @@ class TestDistRunnerBase(object): ...@@ -82,6 +83,9 @@ class TestDistRunnerBase(object):
if args.nccl2_reduce_layer_local_run: if args.nccl2_reduce_layer_local_run:
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size, single_device=True) self.get_model(batch_size=args.batch_size, single_device=True)
elif args.use_dgc:
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size, use_dgc=args.use_dgc)
else: else:
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size) self.get_model(batch_size=args.batch_size)
...@@ -200,6 +204,7 @@ def runtime_main(test_class): ...@@ -200,6 +204,7 @@ def runtime_main(test_class):
parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--mem_opt', action='store_true') parser.add_argument('--mem_opt', action='store_true')
parser.add_argument('--use_cuda', action='store_true') parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--use_dgc', action='store_true')
parser.add_argument('--use_reduce', action='store_true') parser.add_argument('--use_reduce', action='store_true')
parser.add_argument('--dc_asgd', action='store_true') parser.add_argument('--dc_asgd', action='store_true')
parser.add_argument( parser.add_argument(
...@@ -235,6 +240,7 @@ class TestDistBase(unittest.TestCase): ...@@ -235,6 +240,7 @@ class TestDistBase(unittest.TestCase):
def _after_setup_config(self): def _after_setup_config(self):
if self._enforce_place == "CPU": if self._enforce_place == "CPU":
self.__use_cuda = False self.__use_cuda = False
self._use_dgc = False
elif self._enforce_place == "GPU": elif self._enforce_place == "GPU":
self.__use_cuda = True self.__use_cuda = True
else: else:
...@@ -242,6 +248,10 @@ class TestDistBase(unittest.TestCase): ...@@ -242,6 +248,10 @@ class TestDistBase(unittest.TestCase):
self.__use_cuda = True self.__use_cuda = True
else: else:
self.__use_cuda = False self.__use_cuda = False
self._use_dgc = False
if self._use_reduce:
assert not self._use_dgc
def setUp(self): def setUp(self):
self._trainers = 2 self._trainers = 2
...@@ -264,6 +274,7 @@ class TestDistBase(unittest.TestCase): ...@@ -264,6 +274,7 @@ class TestDistBase(unittest.TestCase):
# test, reduce check this argument everywhere. # test, reduce check this argument everywhere.
self._nccl2_reduce_layer = False self._nccl2_reduce_layer = False
self._lr = 0.001 self._lr = 0.001
self._use_dgc = False
self._setup_config() self._setup_config()
self._after_setup_config() self._after_setup_config()
...@@ -506,6 +517,9 @@ class TestDistBase(unittest.TestCase): ...@@ -506,6 +517,9 @@ class TestDistBase(unittest.TestCase):
env0 = {'CPU_NUM': '1'} env0 = {'CPU_NUM': '1'}
env1 = {'CPU_NUM': '1'} env1 = {'CPU_NUM': '1'}
if self._use_dgc:
tr0_cmd += " --use_dgc"
tr1_cmd += " --use_dgc"
if self._mp_mode: if self._mp_mode:
env0 = {"FLAGS_selected_gpus": "0"} env0 = {"FLAGS_selected_gpus": "0"}
env1 = {"FLAGS_selected_gpus": "1"} env1 = {"FLAGS_selected_gpus": "1"}
......
...@@ -39,6 +39,20 @@ class TestDistMnistNCCL2(TestDistBase): ...@@ -39,6 +39,20 @@ class TestDistMnistNCCL2(TestDistBase):
self.check_with_place("dist_mnist.py", delta=1e-5) self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnistNCCL2DGC(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._use_dgc = True
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnist2x2Lars(TestDistBase): class TestDistMnist2x2Lars(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
......
...@@ -60,5 +60,20 @@ class TestDistSeResneXt2x2Async(TestDistBase): ...@@ -60,5 +60,20 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
class TestDistSeResnetNCCL2DGC(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._use_dgc = True
@skip_ci
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place("dist_se_resnext.py", delta=30)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -39,19 +39,21 @@ class TestFSPOp(OpTest): ...@@ -39,19 +39,21 @@ class TestFSPOp(OpTest):
self.op_type = "fsp" self.op_type = "fsp"
self.initTestCase() self.initTestCase()
feature_map_0 = np.random.uniform(0, 10, self.a_shape).astype('float32') feature_map_0 = np.random.uniform(0, 10, self.a_shape).astype('float64')
feature_map_1 = np.random.uniform(0, 10, self.b_shape).astype('float32') feature_map_1 = np.random.uniform(0, 10, self.b_shape).astype('float64')
self.inputs = {'X': feature_map_0, 'Y': feature_map_1} self.inputs = {'X': feature_map_0, 'Y': feature_map_1}
self.outputs = {'Out': fsp_matrix(feature_map_0, feature_map_1)} self.outputs = {'Out': fsp_matrix(feature_map_0, feature_map_1)}
def initTestCase(self): def initTestCase(self):
self.a_shape = (2, 16, 32, 31) self.a_shape = (2, 3, 5, 6)
self.b_shape = (2, 28, 32, 31) self.b_shape = (2, 4, 5, 6)
@unittest.skip("Disable temporarily.")
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
@unittest.skip("Disable temporarily.")
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05) self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)
......
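For context on the test above: its fsp_matrix helper is assumed to compute the FSP matrix from the distillation literature, i.e. the per-sample (C1 x C2) cross-correlation of two feature maps averaged over spatial positions. A NumPy sketch under that assumption:

.. code-block:: python

    import numpy as np

    def fsp_matrix(x, y):
        # assumed semantics: per-sample Gram matrix of two feature maps,
        # averaged over the h * w spatial positions
        b, c1, h, w = x.shape
        c2 = y.shape[1]
        x = x.reshape(b, c1, h * w)
        y = y.reshape(b, c2, h * w)
        return np.matmul(x, y.transpose(0, 2, 1)) / (h * w)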