未验证 提交 eb83abea 编写于 作者: G gongweibao 提交者: GitHub

Add DGC(Deep Gradient Compression) interface. (#15841)

上级 b1d26051
......@@ -193,6 +193,12 @@ if(WITH_GPU)
include(tensorrt)
include(anakin_subgraph)
endif()
if(WITH_GPU AND NOT WIN32)
message(STATUS "add dgc lib.")
include(external/dgc)
endif()
if(WITH_MKL OR WITH_MKLML)
include(external/anakin)
elseif()
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc")
SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc")
SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE)
SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE)
INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR})
ExternalProject_Add(
extern_dgc
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/PaddlePaddle/Fleet"
GIT_TAG "2d04dc3800cdd0601f1b65d547dabcc60b0cf9dc"
SOURCE_DIR "${DGC_SOURCES_DIR}"
CONFIGURE_COMMAND ""
BUILD_COMMAND cd collective && make -j
INSTALL_COMMAND mkdir -p ${DGC_INSTALL_DIR}/lib/ ${DGC_INCLUDE_DIR}/dgc
&& cp ${DGC_SOURCES_DIR}/collective/build/lib/libdgc.a ${DGC_LIBRARIES}
&& cp ${DGC_SOURCES_DIR}/collective/build/include/dgc.h ${DGC_INCLUDE_DIR}/dgc/
BUILD_IN_SOURCE 1
)
ADD_LIBRARY(dgc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES})
ADD_DEPENDENCIES(dgc extern_dgc)
LIST(APPEND external_project_dependencies dgc)
......@@ -131,6 +131,15 @@ elseif (NOT CBLAS_FOUND OR WIN32)
)
endif ()
if (WITH_GPU AND NOT WIN32)
set(dgc_dir "${FLUID_INSTALL_DIR}/third_party/install/dgc")
copy(dgc_lib
SRCS ${DGC_INSTALL_DIR}/lib ${DGC_INSTALL_DIR}/include
DSTS ${dgc_dir} ${dgc_dir}
DEPS dgc)
endif()
if (WITH_MKLDNN)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mkldnn")
copy(mkldnn_lib
......
......@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -483,6 +483,11 @@ paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['sel
paddle.fluid.optimizer.LarsMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.LarsMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.optimizer.DGCMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'rampup_begin_step', 'rampup_step', 'sparsity', 'use_nesterov', 'local_grad_clip_norm', 'num_trainers', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1, [0.999], False, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.DGCMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '1a79bd7d10ae54ca763ec81bca36ba24'))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
......
......@@ -23,7 +23,7 @@ endif()
if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
dynload_cuda variable_visitor dgc)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
if(WITH_DISTRIBUTE)
......
......@@ -86,7 +86,8 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
}
}
VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl;
VLOG(10) << "dist_ops size:" << dist_ops.size()
<< ", outputs size:" << vars.size() << ", ops size:" << ops.size();
std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
OpHandleBase* op2) {
......@@ -99,6 +100,10 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
auto l_it = vars.find(i0->name());
auto r_it = vars.find(i1->name());
PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(),
"can't find var's name %s and %s in opdesc", i0->name(),
i1->name());
if (l_it->second < r_it->second) return true;
if (l_it->second == r_it->second) {
......
......@@ -16,6 +16,13 @@
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "dgc/dgc.h"
#endif
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
// asynchronous nccl allreduce or synchronous issue:
......@@ -33,11 +40,14 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs)
const platform::NCCLContextMap *ctxs,
bool is_encoded, int nranks)
: OpHandleBase(node),
local_scopes_(local_scopes),
places_(places),
nccl_ctxs_(ctxs) {
nccl_ctxs_(ctxs),
is_encoded_(is_encoded),
nranks_(nranks) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
......@@ -51,7 +61,185 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void AllReduceOpHandle::RunImplEncoded() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(
in_var_handles.size(), places_.size(),
"The NoDummyInputSize should be equal to the number of places.");
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
std::vector<const LoDTensor *> ins;
std::vector<LoDTensor *> outs;
int k = -1;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &local_scope =
local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto original_name =
paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
auto encode_var_name = original_name + g_dgc_encoded;
auto *in_var = local_scope->FindVar(encode_var_name);
PADDLE_ENFORCE_NOT_NULL(in_var);
auto &in = in_var->Get<LoDTensor>();
ins.emplace_back(&in);
auto *out = local_scope->FindVar(out_var_handles[i]->name())
->GetMutable<LoDTensor>();
outs.emplace_back(out);
if (k < 0) {
k = GetKValue(in_var_handles[i]->name());
}
}
PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
int dtype = -1;
size_t in_numel = 0;
size_t out_numel = 0;
PADDLE_ENFORCE(nranks_ > 1);
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &place = places_[i];
auto &in = *ins[i];
void *in_tensor_buf = const_cast<void *>(in.data<void>());
auto &out = *outs[i];
float *out_tensor_buf = out.data<float>();
dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
PADDLE_ENFORCE(in_numel % 2 == 0);
PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
int encode_size = 2 * k * sizeof(int);
// dgc use ncclAllGather to get all the encoded data
// so the buffer need nranks.
int buf_size = nranks_ * encode_size;
auto tmp_ious_data = allocator.Allocate(buf_size);
void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
<< ", nranks:" << nranks_ << ", gather_buf size:" << buf_size
<< ", k:" << k << ", place:" << place << ", dtype:" << dtype;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce(
in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm,
stream));
});
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device
all_reduce_calls[0]();
} else {
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
}
});
if (FLAGS_sync_nccl_allreduce) {
for (auto &p : places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
}
cudaError_t e_get = cudaGetLastError();
if (e_get != 0) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno:" << e_get;
}
}
}
}
int AllReduceOpHandle::GetKValue(const std::string &grad_name) {
auto original_name = paddle::framework::GradOriginalVarName(grad_name);
auto var_name = original_name + g_dgc_k;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto var = local_scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
auto tensor = var->Get<LoDTensor>().data<float>();
return *tensor;
}
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
bool AllReduceOpHandle::IsEncoded() {
if (!is_encoded_) {
return false;
}
auto counter_name = g_dgc_counter_name;
auto step_name = g_dgc_rampup_begin_step;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto count_var = local_scope->FindVar(counter_name);
auto step_var = local_scope->FindVar(step_name);
if (count_var == nullptr || step_var == nullptr) {
PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
step_var);
}
float count = *count_var->Get<LoDTensor>().data<float>();
float step = *step_var->Get<LoDTensor>().data<float>();
if (static_cast<int>(count) < static_cast<int>(step)) {
VLOG(10) << "in all_reduce currentstep:" << count
<< " < rampup_begin_step:" << step
<< " so not use sparse all reduce";
return false;
}
return true;
}
#else
bool AllReduceOpHandle::IsEncoded() { return false; }
#endif
void AllReduceOpHandle::RunImpl() {
if (!IsEncoded()) {
RunImplNormal();
return;
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
RunImplEncoded();
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
}
void AllReduceOpHandle::RunImplNormal() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
......@@ -72,6 +260,8 @@ void AllReduceOpHandle::RunImpl() {
auto &lod_tensor =
local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor);
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
<< ", out_name:" << out_var_handles[i]->name();
PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal.");
}
......@@ -99,13 +289,17 @@ void AllReduceOpHandle::RunImpl() {
auto &nccl_ctx = nccl_ctxs_->at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
<< ", dev_id:" << dev_id << ", dtype:" << dtype
<< ", place:" << p;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
});
}
this->RunAndRecordEvent([&] {
if (all_reduce_calls.size() == 1UL) {
// Do not use NCCLGroup when manage NCCL by per thread per device
......
......@@ -28,11 +28,19 @@ namespace paddle {
namespace framework {
namespace details {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
constexpr char g_dgc_counter_name[] = "__g_dgc_counter__";
constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__";
constexpr char g_dgc_encoded[] = "__dgc_encoded__";
constexpr char g_dgc_k[] = "__dgc_k__";
#endif
struct AllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *ctxs);
const platform::NCCLContextMap *ctxs,
bool is_encoded = false, int nranks = -1);
#else
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
......@@ -50,8 +58,14 @@ struct AllReduceOpHandle : public OpHandleBase {
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void RunImplEncoded();
const platform::NCCLContextMap *nccl_ctxs_;
bool is_encoded_{false};
int nranks_{-1};
int GetKValue(const std::string &grad_name);
#endif
void RunImplNormal();
bool IsEncoded();
};
} // namespace details
......
......@@ -32,6 +32,7 @@
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace framework {
......@@ -209,7 +210,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
for (size_t i = 0; i < backward_vars.size(); i += 2) {
auto &p_name = backward_vars[i];
auto &g_name = backward_vars[i + 1];
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name
<< " op_type " << node->Op()->Type();
if (NeedCollectiveForGrad(g_name, sorted_ops)) {
InsertCollectiveOp(&result, p_name, g_name);
}
......@@ -414,8 +416,9 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
CreateOpHandleIOs(result, node, dev_id);
}
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
ir::Graph *result, const std::string &og) const {
void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
const std::string &og,
bool is_encoded) const {
OpHandleBase *op_handle = nullptr;
auto append_allreduce_op = [&](
......@@ -424,7 +427,9 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
scopes, places, nccl_ctxs_));
scopes, places, nccl_ctxs_, is_encoded,
static_cast<int>(strategy_.trainers_endpoints_.size()) *
places_.size()));
#else
result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
......@@ -446,12 +451,15 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad);
VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString();
auto var =
new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
vars.size(), i, og, places_[i]);
vars.emplace_back(var);
op_handle->AddOutput(var);
VLOG(10) << "all_reduce_op_handle add output " << og
<< ", handle:" << var->DebugString();
}
}
......@@ -941,6 +949,17 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
return op_dev_id;
}
bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
auto u_name = p_name + "__dgc_u__";
auto it = all_vars_.find(u_name);
if (it == all_vars_.end()) {
VLOG(10) << "can't find u_name, so it's not encoded:" << u_name;
return false;
}
return true;
}
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name,
const std::string &g_name) const {
......@@ -956,7 +975,11 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
CreateReduceOp(result, g_name, 0);
CreateBroadcastOp(result, g_name, 0);
} else {
CreateAllReduceOp(result, g_name);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
CreateAllReduceOp(result, g_name, IsEncoded(p_name));
#else
PADDLE_ENFORCE(false, "Compiled withoud cuda!");
#endif
}
break;
default:
......
......@@ -75,7 +75,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool IsSparseGradient(const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
void CreateAllReduceOp(ir::Graph *result, const std::string &og,
bool is_encoded = false) const;
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;
......@@ -171,6 +172,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
mutable std::vector<std::unordered_set<std::string>> bcast_var_name_set_;
mutable bool need_broadcast_var_{false};
bool IsEncoded(const std::string &p_name) const;
};
std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
......
......@@ -24,7 +24,8 @@ VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
std::string VarHandle::DebugString() const {
std::stringstream ss;
ss << name_ << ":" << place_;
ss << "name:" << name_ << ", place:" << place_ << ", version:" << version_
<< ", scope_idx:" << scope_idx_;
return ss.str();
}
......
......@@ -373,6 +373,11 @@ std::vector<std::string> OpDesc::AttrNames() const {
return retv;
}
void OpDesc::RemoveAttr(const std::string &name) {
attrs_.erase(name);
need_update_ = true;
}
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(minqiyang): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
......@@ -644,6 +649,7 @@ void OpDesc::CheckAttrs() {
// not by users.
return;
}
VLOG(10) << "begin to check attribute of " << Type();
checker->Check(&attrs_);
}
......
......@@ -72,6 +72,7 @@ class OpDesc {
std::vector<std::string> AttrNames() const;
void SetAttr(const std::string &name, const Attribute &v);
void RemoveAttr(const std::string &name);
void SetBlockAttr(const std::string &name, BlockDesc *block);
......
......@@ -1110,8 +1110,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == data_type || data_type == dafault_data_type,
"DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
Type(), DataTypeToString(data_type), DataTypeToString(tmp));
"DataType of Paddle Op %s %s must be the same. Get (%d) != (%d)",
Type(), input.first, DataTypeToString(data_type),
DataTypeToString(tmp));
data_type = tmp;
}
}
......
......@@ -49,6 +49,11 @@ set(SHARED_INFERENCE_SRCS
${mkldnn_quantizer_src}
${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc)
# FIXME(gongwb): hidden libdgc.a
if(WITH_GPU AND NOT WIN32)
set(fluid_modules ${fluid_modules} dgc)
endif()
if(WIN32)
sep_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array
analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder)
......
......@@ -48,7 +48,7 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
......@@ -72,6 +72,12 @@ endif()
set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
if (WITH_GPU AND NOT WIN32)
op_library(dgc_op DEPS dgc)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(dgc);\n")
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
......
......@@ -14,69 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/clip_by_norm_op.h"
namespace paddle {
namespace operators {
class ClipByNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipByNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipByNormOp should not be null.");
auto max_norm = ctx->Attrs().Get<float>("max_norm");
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9].");
AddOutput("Out",
"(Tensor) The output of clip_by_norm op with shape as input(X)");
AddAttr<float>("max_norm", "(float) The maximum norm value.");
AddComment(R"DOC(
ClipByNorm Operator.
This operator limits the L2 norm of the input $X$ within $max\_norm$.
If the L2 norm of $X$ is less than or equal to $max\_norm$, $Out$ will be
the same as $X$. If the L2 norm of $X$ is greater than $max\_norm$, $X$ will
be linearly scaled to make the L2 norm of $Out$ equal to $max\_norm$, as
shown in the following formula:
$$
Out = \\frac{max\\_norm * X}{norm(X)},
$$
where $norm(X)$ represents the L2 norm of $X$.
Examples:
.. code-block:: python
data = fluid.layer.data(
name='data', shape=[2, 4, 6], dtype='float32')
reshaped = fluid.layers.clip_by_norm(
x=data, max_norm=0.5)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
ops::ClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL(
clip_by_norm,
ops::ClipByNormKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -83,5 +83,59 @@ class ClipByNormKernel : public framework::OpKernel<T> {
}
};
class ClipByNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipByNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipByNormOp should not be null.");
auto max_norm = ctx->Attrs().Get<float>("max_norm");
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input of clip_by_norm op."
"The number of dimensions must be between [1, 9].");
AddOutput("Out",
"(Tensor) The output of clip_by_norm op with shape as input(X)");
AddAttr<float>("max_norm", "(float) The maximum norm value.");
AddComment(R"DOC(
ClipByNorm Operator.
This operator limits the L2 norm of the input $X$ within $max\_norm$.
If the L2 norm of $X$ is less than or equal to $max\_norm$, $Out$ will be
the same as $X$. If the L2 norm of $X$ is greater than $max\_norm$, $X$ will
be linearly scaled to make the L2 norm of $Out$ equal to $max\_norm$, as
shown in the following formula:
$$
Out = \\frac{max\\_norm * X}{norm(X)},
$$
where $norm(X)$ represents the L2 norm of $X$.
Examples:
.. code-block:: python
data = fluid.layer.data(
name='data', shape=[2, 4, 6], dtype='float32')
reshaped = fluid.layers.clip_by_norm(
x=data, max_norm=0.5)
)DOC");
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/dgc_clip_by_norm_op.h"
namespace paddle {
namespace operators {
class DGCClipByNormOp : public ClipByNormOp {
public:
using ClipByNormOp::ClipByNormOp;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("current_step"),
"current_step should be set.");
return ClipByNormOp::InferShape(ctx);
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "current_step") {
VLOG(10) << "var_name:" << var_name << " need not to transform";
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};
class DGCClipByNormOpMaker : public ClipByNormOpMaker {
public:
void Make() override {
AddInput("current_step", "(Tensor) Current step.");
AddAttr<float>("rampup_begin_step",
"(float, -1.0)"
"The period when begin k_select.")
.SetDefault(-1.0);
return ClipByNormOpMaker::Make();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc_clip_by_norm, ops::DGCClipByNormOp,
ops::DGCClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL(
dgc_clip_by_norm,
ops::DGCClipByNormKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_clip_by_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
dgc_clip_by_norm,
ops::DGCClipByNormKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/clip_by_norm_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
if (static_cast<int>(rampup_begin_step) >= 0) {
auto current_step_tensor =
context.Input<framework::Tensor>("current_step");
auto* current_step = current_step_tensor->data<T>();
if (static_cast<int>(*current_step) <
static_cast<int>(rampup_begin_step)) {
VLOG(10) << "current_step:" << *current_step
<< " < rampup_begin_step:" << rampup_begin_step
<< " so does't use dgc_clip_by_norm";
return;
}
}
return ClipByNormKernel<DeviceContext, T>::Compute(context);
};
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class DGCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasInput("current_step"),
"Input(current_step) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("U_out"),
"Output(U_out) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("V_out"),
"Output(V_out) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("k"),
"Output(k) of DGCop should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("EncodeGrad"),
"Output(EncodeGrad) of DGCop should not be null.");
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "current_step" || var_name == "rampup_step" ||
var_name == "k") {
VLOG(10) << "var_name:" << var_name << " need not to transform";
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};
class DGCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("U", "(Tensor) Middle tensor of DGC");
AddInput("V", "(Tensor) Middle tensor of DGC");
AddInput("Grad", "(Tensor) Input gradient");
AddInput("current_step", "(Tensor) Current step.");
AddOutput("U_out",
"(Tensor) "
"Output encoded gradient");
AddOutput("V_out",
"(Tensor) "
"Output encoded gradient");
AddOutput("EncodeGrad",
"(Tensor) "
"Output encoded gradient");
AddOutput("Grad_out",
"(Tensor) "
"Output grad gradient");
AddOutput("k",
"(Tensor) "
"Output top-k value");
AddAttr<float>("m",
"(float, 0.9) "
"The momentum of learning rate.")
.SetDefault(0.9);
AddAttr<bool>("use_nesterov",
"(bool, true)"
"The momentum of learning rate.")
.SetDefault(true);
AddAttr<std::vector<float>>("sparsity",
"(vecotr, float)"
"The period sparsity of k_select.");
AddAttr<float>("rampup_begin_step",
"(float, 0.0)"
"The period when begin k_select.")
.SetDefault(0.0);
AddAttr<float>("rampup_step",
"(float, 0.0)"
"The period when begin k_select.");
AddComment(R"DOC(
Original paper is https://arxiv.org/abs/1712.01887
DGC reduce the communication bandwidth by sending only the important gradients (sparse update):\
only gradients larger than a threshold are transmitted.
To avoid losing information, DGC accumulate the rest of the gradients locally.
Eventually, these gradients become large enough to be transmitted.
Thus, DGC send the large gradients immediately but eventually send all of the gradients over time.
To ensure no loss of accuracy, DGC employs momentum correc-tionandlocal gradient clipping on top of the gradient sparsification to maintain model performance.
DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication.
This optimizer will do two things:
1. Compress the gradient by get TopK import value from tensor \
and use it for allreduce to reduce network bandwidth.
2. Call momentum to optimize on the cost.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(dgc, ops::DGCOp, ops::DGCOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dgc_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
dgc, ops::DGCOpKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
namespace paddle {
namespace operators {
inline float get_period_sparcity(const std::vector<float>& sparsity,
float cur_step, float rampup_steps) {
PADDLE_ENFORCE(static_cast<int>(cur_step) >= 0);
size_t idx = static_cast<int>(cur_step * sparsity.size() / rampup_steps);
if (idx >= sparsity.size()) {
return 0.999;
}
PADDLE_ENFORCE(idx < sparsity.size());
return sparsity[idx];
}
template <typename DeviceContext, typename T>
class DGCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto u = ctx.Input<framework::Tensor>("U");
auto v = ctx.Input<framework::Tensor>("V");
auto g = ctx.Input<framework::Tensor>("Grad");
// attrs
float m = ctx.Attr<float>("m");
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto sparsity = ctx.Attr<std::vector<float>>("sparsity");
auto rampup_begin_step = ctx.Attr<float>("rampup_begin_step");
auto rampup_step = ctx.Attr<float>("rampup_step");
// current step
auto current_step_tensor = ctx.Input<framework::Tensor>("current_step");
const float* current_step = current_step_tensor->data<float>();
if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
VLOG(10) << "current_step:" << *current_step
<< " < rampup_begin_step:" << rampup_begin_step
<< " so does't use dgc";
return;
}
float ratio =
1 - get_period_sparcity(sparsity, static_cast<float>(*current_step),
rampup_step);
PADDLE_ENFORCE(ratio > 0.0 && ratio < 1.0);
int k = static_cast<int>(g->numel() * ratio);
VLOG(10) << "m:" << m << ", use_nesterov:" << use_nesterov
<< ", rampup_begin_step:" << rampup_begin_step
<< ", rampup_step:" << rampup_step
<< ", current_step:" << *current_step << ", ratio:" << ratio
<< ", k:" << k;
auto k_out = ctx.Output<framework::Tensor>("k");
T* k_out_data = k_out->data<T>();
*k_out_data = k;
auto u_out = ctx.Output<framework::Tensor>("U_out");
auto v_out = ctx.Output<framework::Tensor>("V_out");
auto encode_grad_out = ctx.Output<framework::Tensor>("EncodeGrad");
// FIXME(gongwb): use cublas.
auto u_out_e = framework::EigenVector<T>::Flatten(*u_out);
auto u_e = framework::EigenVector<T>::Flatten(*u);
auto g_e = framework::EigenVector<T>::Flatten(*g);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto& eigen_ctx = *dev_ctx.eigen_device();
if (use_nesterov) {
// u = m * (u + g)
u_out_e.device(eigen_ctx) = m * (u_e + g_e);
// v = u + v + g
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, u, v, 0, AddFunctor<T>(), v_out);
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, g, v, 0, AddFunctor<T>(), v_out);
} else {
// u = m * u + g
u_out_e.device(eigen_ctx) = m * u_e + g_e;
// v = u + v
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, u, v, 0, AddFunctor<T>(), v_out);
}
T* v_out_data = v_out->mutable_data<T>(ctx.GetPlace());
T* u_out_data = u_out->mutable_data<T>(ctx.GetPlace());
T* encode_grad_out_data = encode_grad_out->mutable_data<T>(
framework::DDim{2 * k}, ctx.GetPlace());
int buf_size = paddle::communication::dgc::get_buffer_size(k);
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(
ctx.GetPlace(), dev_ctx.stream());
auto tmp_ious_data = allocator.Allocate(buf_size);
void* buf = reinterpret_cast<void*>(tmp_ious_data->ptr());
if (!paddle::communication::dgc::k_select(
static_cast<void*>(encode_grad_out_data), k, v_out_data,
static_cast<int>(v_out->numel()), buf, dev_ctx.stream(),
u_out_data)) {
LOG(FATAL) << "v_out numel:" << v_out->numel();
}
auto grad_out = ctx.Output<framework::Tensor>("Grad_out");
math::SetConstant<DeviceContext, T> tset;
tset(dev_ctx, grad_out, static_cast<T>(0));
}
};
} // namespace operators
} // namespace paddle
......@@ -46,8 +46,9 @@ cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper)
IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
set(dgc_deps dgc)
ELSE()
set(GPU_CTX_DEPS)
set(dgc_deps)
ENDIF()
IF(WITH_MKLDNN)
......@@ -68,7 +69,8 @@ ENDIF()
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} temp_allocator)
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
temp_allocator ${dgc_deps})
if(WIN32)
if(WITH_GPU AND NOT WITH_DSO)
......
......@@ -37,13 +37,13 @@ limitations under the License. */
} \
} while (0)
#define PADDLE_ASSERT_MSG_CODE(e, m, c) \
do { \
if (!(e)) { \
printf("%s:%d Assertion `%s` failed (%s %d).\n", __FILE__, __LINE__, \
TOSTRING(e), m, c); \
asm("trap;"); \
} \
#define PADDLE_ASSERT_MSG_CODE(e, m, c) \
do { \
if (!(e)) { \
printf("%s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, __LINE__, \
TOSTRING(e), m, c); \
asm("trap;"); \
} \
} while (0)
#else
#include <assert.h>
......
......@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "glog/logging.h"
namespace paddle {
namespace platform {
......@@ -324,8 +326,17 @@ void CUDADeviceContext::Wait() const {
auto& allocator =
DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this);
allocator.Release([this]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError());
cudaError_t e_sync = cudaStreamSynchronize(stream_);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
<< " errno:" << e_sync;
}
cudaError_t e_get = cudaGetLastError();
if (e_get != 0) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno:" << e_get;
}
});
}
......
......@@ -31,6 +31,10 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/piece.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "dgc/dgc.h"
#endif
DEFINE_int32(paddle_num_threads, 1,
"Number of threads for each paddle instance.");
DEFINE_int32(multiple_of_cupti_buffer_size, 1,
......@@ -43,6 +47,10 @@ namespace framework {
std::once_flag gflags_init_flag;
std::once_flag p2p_init_flag;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std::once_flag dgc_init_flag;
#endif
void InitGflags(std::vector<std::string> argv) {
std::call_once(gflags_init_flag, [&]() {
FLAGS_logtostderr = true;
......@@ -203,5 +211,15 @@ void InitGLOG(const std::string &prog_name) {
#endif
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void InitDGC() {
std::call_once(dgc_init_flag, []() {
PADDLE_ENFORCE(paddle::communication::dgc::dynloadNcclLib());
});
}
#else
void InitDGC() {}
#endif
} // namespace framework
} // namespace paddle
......@@ -30,5 +30,7 @@ void InitDevices(bool init_p2p);
void InitDevices(bool init_p2p, const std::vector<int> devices);
void InitDGC();
} // namespace framework
} // namespace paddle
......@@ -222,6 +222,7 @@ void BindOpDesc(pybind11::module *m) {
.def("attr_type", &pd::OpDesc::GetAttrType)
.def("attr_names", &pd::OpDesc::AttrNames)
.def("_set_attr", &pd::OpDesc::SetAttr)
.def("remove_attr", &pd::OpDesc::RemoveAttr)
.def("attr", &pd::OpDesc::GetAttr)
.def("set_block_attr", &pd::OpDesc::SetBlockAttr)
.def("set_blocks_attr", &pd::OpDesc::SetBlocksAttr)
......
......@@ -933,6 +933,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG);
m.def("init_dgc", framework::InitDGC);
m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); });
......
......@@ -1202,6 +1202,9 @@ class Operator(object):
"""
self._update_desc_attr(name, val)
def _remove_attr(self, name):
self.desc.remove_attr(name)
def _update_desc_attr(self, name, val):
"""
Update the value of desc's attribute by attribute's name.
......@@ -2725,6 +2728,10 @@ class Program(object):
self._trainers_endpoints = []
# the distributed lookup table names
self._distributed_lookup_table = None
# use Deep gradient comrepssion or not
self._enable_dgc = False
# @deprecated(the python memory optimize transpiler is deprecated)
# whether the program is optimized by memory_optimize_transpiler
self.__is_mem_optimized = False
......@@ -2775,6 +2782,15 @@ class Program(object):
def set_op_role_var(self, var_name):
self._op_role_var = [var_name]
@contextlib.contextmanager
def _backward_role_guard(self):
tmp_role = self._current_role
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.Backward
yield
self._current_role = tmp_role
@signature_safe_contextmanager
def _optimized_guard(self, param_and_grads):
"""
......
......@@ -17,7 +17,7 @@ from __future__ import print_function
from collections import defaultdict
from .wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from . import framework
......@@ -31,13 +31,17 @@ from .layer_helper import LayerHelper
from .layers import ops
from .regularizer import append_regularization_ops
from .imperative import base as imperative_base
from paddle.fluid import core
from paddle.fluid.layers import tensor
from functools import reduce
import copy
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
'LarsMomentumOptimizer'
'LarsMomentumOptimizer', 'DGCMomentumOptimizer'
]
......@@ -294,6 +298,9 @@ class Optimizer(object):
outputs={"ParamOut": param_and_grad[0]})
return new_param_grads, (table_param, table_grad), sgd_op
def _append_dgc_ops(self, param_and_grad):
pass
def backward(self,
loss,
startup_program=None,
......@@ -415,6 +422,9 @@ class Optimizer(object):
with program_guard(program, startup_program):
params_grads = self.backward(loss, startup_program,
parameter_list, no_grad_set)
# Note: since we can't use all_reduce_op now,
# dgc_op should be the last op of one grad.
self._append_dgc_ops(params_grads)
optimize_ops = self.apply_gradients(params_grads)
return optimize_ops, params_grads
......@@ -552,6 +562,264 @@ class MomentumOptimizer(Optimizer):
return momentum_op
class DGCMomentumOptimizer(MomentumOptimizer):
"""
Original paper is https://arxiv.org/abs/1712.01887
DGC reduce the communication bandwidth by sending only the important gradients (sparse update):\
only gradients larger than a threshold are transmitted.
To avoid losing information, DGC accumulate the rest of the gradients locally.
Eventually, these gradients become large enough to be transmitted.
Thus, DGC send the large gradients immediately but eventually send all of the gradients over time.
To ensure no loss of accuracy, DGC employs momentum correc-tionandlocal gradient clipping on top of the gradient sparsification to maintain model performance.
DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication.
This optimizer will do two things:
1. Compress the gradient by get TopK import value from tensor \
and use it for allreduce to reduce network bandwidth.
2. Call momentum to optimize on the cost.
Args:
learning_rate (float|Variable): the learning rate used to update parameters. \
Can be a float value or a Variable with one float value as data element.
momentum (float): Momentum factor.
rampup_begin_step (int): The begining step from which gradient compression is implemented.
rampup_step (int): How long it use the sparsity periods. Default is 1.
for example: If the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999], and the rampup_step is 5, \
it will use 0.75 at 0 step, and 0.9375 at 1 step, and so on. And when reach sparsity array ends, \
it will use 0.999 then and after.
sparsity (list[float]): Get top important element from gradient tensor, the ratio is (1 - current sparsity).
use_nesterov (bool): Enables Nesterov momentum. True means use nesterov.
local_grad_clip_norm (float): Clip norm value if needed.
num_trainers: The number of training node.
regularization: A Regularizer, such as fluid.regularizer.L2DecayRegularizer.
name: A optional name prefix.
Examples:
.. code-block:: python
optimizer = fluid.optimizer.DGCMomentumOptimizer(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
rampup_begin_step=1252,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(cost)
"""
def __init__(self,
learning_rate,
momentum,
rampup_begin_step,
rampup_step=1,
sparsity=[0.999],
use_nesterov=False,
local_grad_clip_norm=None,
num_trainers=None,
regularization=None,
name=None):
self._sparsity = sparsity
self._rampup_step = rampup_step
self._rampup_step_var = None
self._rampup_begin_step = rampup_begin_step
self._rampup_begin_step_var = None
self._global_step_var = None
self._local_grad_clip_norm = None
self._clip_norm = None
if local_grad_clip_norm is not None:
assert isinstance(num_trainers, int)
assert isinstance(local_grad_clip_norm, float)
assert num_trainers > 0
self._local_grad_clip_norm = local_grad_clip_norm
self._num_trainers = num_trainers
self._clip_norm = local_grad_clip_norm / (num_trainers *
num_trainers)
super(DGCMomentumOptimizer, self).__init__(
learning_rate, momentum, use_nesterov, regularization, name)
core.init_dgc()
def _add_auto_increment_var(self, counter_name, begin, step=1):
helper = LayerHelper('global_step_counter')
counter, is_new_var = helper.create_or_get_global_variable(
name=counter_name, dtype='float32', shape=[1], persistable=True)
if is_new_var:
helper.set_variable_initializer(
counter,
initializer=Constant(
value=float(begin - 1), force_cpu=True))
helper.main_program.global_block()._prepend_op(
type='increment',
inputs={'X': [counter]},
outputs={'Out': [counter]},
attrs={'step': float(step)},
stop_gradient=True)
counter.stop_gradient = True
return counter
def _append_dgc_ops(self, param_and_grads):
start_program = default_startup_program()
main_program = default_main_program()
main_program._enable_dgc = True
# step counter
self._global_step_var = self._add_auto_increment_var(
counter_name='__g_dgc_counter__', begin=0)
# rampup begin step var for all_reduce_op_handle
self._rampup_begin_step_var = tensor.create_global_var(
shape=[1],
dtype=core.VarDesc.VarType.FP32,
persistable=True,
name='__g_rampup_begin_step__',
value=self._rampup_begin_step * 1.0,
force_cpu=True)
for param_var, grad_var in param_and_grads:
var_numel = reduce(lambda x, y: x * y, param_var.shape)
if var_numel < 16384 or \
param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \
param_var.dtype != core.VarDesc.VarType.FP32 :
continue
u_var = tensor.create_global_var(
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True,
name=param_var.name + "__dgc_u__",
value=0.0)
v_var = tensor.create_global_var(
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True,
name=param_var.name + "__dgc_v__",
value=0.0)
k_var = tensor.create_global_var(
shape=[1],
dtype=param_var.dtype,
persistable=True,
name=param_var.name + "__dgc_k__",
value=0.0,
force_cpu=True)
encoded_var = tensor.create_global_var(
shape=[1],
dtype=param_var.dtype,
persistable=True,
name=param_var.name + "__dgc_encoded__",
value=0.0,
force_cpu=False)
# del back oprolevarname
op_maker = core.op_proto_and_checker_maker
backward = core.op_proto_and_checker_maker.OpRole.Backward
for op in main_program.global_block().ops:
if not self._is_the_backward_op(op):
continue
var_attr = op.all_attrs()[op_maker.kOpRoleVarAttrName()]
if param_var.name not in var_attr:
continue
var_attr.remove(param_var.name)
var_attr.remove(grad_var.name)
if len(var_attr) > 1:
op._set_attr(op_maker.kOpRoleVarAttrName(), var_attr)
else:
op._remove_attr(op_maker.kOpRoleVarAttrName())
clip_var = grad_var
if self._local_grad_clip_norm is not None:
clip_var = self._append_clip_norm(grad_var, self._clip_norm)
self._dgc_op(param_var, clip_var, grad_var, u_var, v_var, k_var,
encoded_var)
def _is_the_backward_op(self, op):
op_maker = core.op_proto_and_checker_maker
backward = core.op_proto_and_checker_maker.OpRole.Backward
if op_maker.kOpRoleVarAttrName() in op.attr_names and \
int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward):
return True
return False
def _clip_by_norm(self, x, max_norm, name=None):
args = {'x': x, 'max_norm': max_norm, 'name': name}
helper = LayerHelper("dgc_clip_by_norm_op", **args)
if name is None:
name = unique_name.generate(".".join([helper.name, 'tmp']))
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="clip_by_norm",
inputs={"X": x,
"current_step": self._global_step_var},
attrs={
"max_norm": max_norm,
"rampup_begin_step": float(self._rampup_begin_step)
},
outputs={"Out": out})
return out
def _append_clip_norm(self, grad_var, clip_norm):
with grad_var.block.program._backward_role_guard():
return self._clip_by_norm(
x=grad_var, max_norm=clip_norm, name=grad_var.name + "@DGC")
def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var,
encoded_var):
block = framework.default_main_program().global_block()
op_maker = core.op_proto_and_checker_maker
dgc_op = block.append_op(
type="dgc",
inputs={
"U": u_var,
"V": v_var,
"Grad": clip_var,
"current_step": self._global_step_var
},
outputs={
"U_out": u_var,
"V_out": v_var,
"EncodeGrad": encoded_var,
"k": k_var,
"Grad_out": grad_var
},
attrs={
"m": self._momentum,
"sparsity": self._sparsity,
"use_nesterov": self._use_nesterov,
"rampup_begin_step": float(self._rampup_begin_step),
"rampup_step": float(self._rampup_step)
},
stop_gradient=True)
backward = op_maker.OpRole.Backward
dgc_op._set_attr(op_maker.kOpRoleAttrName(), backward)
dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
[param_var.name, grad_var.name])
class LarsMomentumOptimizer(Optimizer):
"""
Momentum optimizer with LARS support
......
......@@ -103,6 +103,12 @@ class ParallelExecutor(object):
) if use_cuda else framework.cpu_places()
self._scope = scope if scope is not None else executor.global_scope()
if main_program is not None and main_program._enable_dgc:
assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce
assert num_trainers * len(
self._places) > 1, "dgc is not useful for single card training"
assert use_cuda
main_program = main_program if main_program is not None \
else framework.default_main_program()
......
......@@ -70,6 +70,7 @@ list(REMOVE_ITEM TEST_OPS test_dist_transpiler)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
list(REMOVE_ITEM TEST_OPS test_dgc_op)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_nccl)
list(REMOVE_ITEM TEST_OPS test_dist_transformer)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer)
......@@ -97,6 +98,7 @@ if(WITH_DISTRIBUTE)
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
py_test_modules(test_dgc_op MODULES test_dgc_op)
set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
py_test_modules(test_dist_se_resnext_nccl MODULES test_dist_se_resnext_nccl)
set_tests_properties(test_dist_se_resnext_nccl PROPERTIES TIMEOUT 1000)
......@@ -107,16 +109,20 @@ if(WITH_DISTRIBUTE)
endif(NOT APPLE)
# py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif()
py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450)
py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL)
if(NOT WIN32)
py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer SERIAL)
py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer SERIAL)
endif()
if(NOT APPLE)
py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL)
endif()
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
# change the timeout from 600 to 2200, because in debug mode, this test need more time.
set_tests_properties(test_parallel_executor_seresnext PROPERTIES TIMEOUT 2200)
......
......@@ -73,7 +73,7 @@ def cnn_model(data):
class TestDistMnist2x2(TestDistRunnerBase):
def get_model(self, batch_size=2):
def get_model(self, batch_size=2, use_dgc=False):
# Input data
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
......@@ -93,7 +93,11 @@ class TestDistMnist2x2(TestDistRunnerBase):
# TODO(typhoonzero): fix distributed adam optimizer
# opt = fluid.optimizer.AdamOptimizer(
# learning_rate=0.001, beta1=0.9, beta2=0.999)
opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9)
if not use_dgc:
opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9)
else:
opt = fluid.optimizer.DGCMomentumOptimizer(
learning_rate=self.lr, momentum=0.9, rampup_begin_step=0)
# Reader
train_reader = paddle.batch(
......
......@@ -210,7 +210,7 @@ class SE_ResNeXt():
class DistSeResneXt2x2(TestDistRunnerBase):
def get_model(self, batch_size=2):
def get_model(self, batch_size=2, use_dgc=False):
# Input data
image = fluid.layers.data(
name="data", shape=[3, 224, 224], dtype='float32')
......@@ -237,11 +237,19 @@ class DistSeResneXt2x2(TestDistRunnerBase):
base_lr = 0.1
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
if not use_dgc:
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
else:
optimizer = fluid.optimizer.DGCMomentumOptimizer(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
rampup_begin_step=0,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost)
# Reader
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
g_array_size = 102400
class TestDGCOp(unittest.TestCase):
def setup(self, place, array_size=g_array_size):
size = array_size
np.random.seed(5) # fix seed
self.scope = fluid.global_scope()
self.place = place
print("place:", place)
# numpy data
# inputs: U, V, Grad, current_step
self.u_name = "U"
self.u = np.random.random(size).astype("float32")
self.v_name = "V"
self.v = np.random.random(size).astype("float32")
self.grad_name = "Grad"
self.grad = np.random.random(size).astype("float32")
self.current_step_name = "current_step"
self.current_step = np.full((1), 0.0).astype("float32")
# output: U_out, V_out, EncodeGrad, GradLocal_out
self.encode_grad_name = "EncodeGrad"
self.k_name = "k"
self.k = np.full((1), 0.0).astype("float32")
# scope data
self.u_tensor = self.scope.var(self.u_name).get_tensor()
self.u_tensor.set(self.u, place)
self.v_tensor = self.scope.var(self.v_name).get_tensor()
self.v_tensor.set(self.v, place)
self.grad_tensor = self.scope.var(self.grad_name).get_tensor()
self.grad_tensor.set(self.grad, place)
self.encode_grad_tensor = self.scope.var(
self.encode_grad_name).get_tensor()
self.current_step_tensor = self.scope.var(
self.current_step_name).get_tensor()
self.current_step_tensor.set(self.current_step, core.CPUPlace())
self.k_tensor = self.scope.var(self.k_name).get_tensor()
self.k_tensor.set(self.k, core.CPUPlace())
def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+ str(expect_t) + "\n" + "But Got" + str(actual_t))
def test_run_and_check(self):
self.setup(place=core.CUDAPlace(0))
kwargs = {
# inputs
'U': self.u_name,
'V': self.v_name,
'Grad': self.grad_name,
'current_step': self.current_step_name,
# outputs
'U_out': self.u_name,
'V_out': self.v_name,
'EncodeGrad': self.encode_grad_name,
'Grad_out': self.grad_name,
'k': self.k_name,
# attrs
'm': 0.9,
'sparsity': [0.75, 0.9375, 0.984375, 0.996, 0.999],
'use_nesterov': True,
'rampup_begin_step': float(0.0),
'rampup_step': float(10.0),
}
dgc_op = Operator('dgc', **kwargs)
#atol = 1e-6
dgc_op.run(self.scope, self.place)
u_out = np.array(self.u_tensor)
v_out = np.array(self.v_tensor)
grad_out = np.array(self.grad_tensor)
encode_grad_out = np.array(self.encode_grad_tensor)
k = int(np.array(self.k_tensor)[0])
print("u_out:", u_out[0:20])
print("v_out:", v_out[0:20])
print("encode_grad_out:", encode_grad_out)
print("k_out:", k)
self.assertEqual(k, int(g_array_size * 0.25))
index = encode_grad_out[0:k].view(dtype=np.int32)
value = encode_grad_out[k:2 * k]
acl = 1e-7
for i in range(0, k):
self.assertAlmostEqual(u_out[index[i]], 0.0)
self.assertAlmostEqual(v_out[index[i]], 0.0)
a_min = np.amin(value)
dangling = [x for x in v_out if x > a_min]
if __name__ == "__main__":
unittest.main()
......@@ -36,7 +36,8 @@ class TestDistRunnerBase(object):
def get_model(self,
batch_size=DEFAULT_BATCH_SIZE,
lr=0.1,
single_device=False):
single_device=False,
use_dgc=False):
raise NotImplementedError(
"get_model should be implemented by child classes.")
......@@ -82,6 +83,9 @@ class TestDistRunnerBase(object):
if args.nccl2_reduce_layer_local_run:
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size, single_device=True)
elif args.use_dgc:
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size, use_dgc=args.use_dgc)
else:
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size)
......@@ -200,6 +204,7 @@ def runtime_main(test_class):
parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--mem_opt', action='store_true')
parser.add_argument('--use_cuda', action='store_true')
parser.add_argument('--use_dgc', action='store_true')
parser.add_argument('--use_reduce', action='store_true')
parser.add_argument('--dc_asgd', action='store_true')
parser.add_argument(
......@@ -235,6 +240,7 @@ class TestDistBase(unittest.TestCase):
def _after_setup_config(self):
if self._enforce_place == "CPU":
self.__use_cuda = False
self._use_dgc = False
elif self._enforce_place == "GPU":
self.__use_cuda = True
else:
......@@ -242,6 +248,10 @@ class TestDistBase(unittest.TestCase):
self.__use_cuda = True
else:
self.__use_cuda = False
self._use_dgc = False
if self._use_reduce:
assert not self._use_dgc
def setUp(self):
self._trainers = 2
......@@ -264,6 +274,7 @@ class TestDistBase(unittest.TestCase):
# test, reduce check this argument everywhere.
self._nccl2_reduce_layer = False
self._lr = 0.001
self._use_dgc = False
self._setup_config()
self._after_setup_config()
......@@ -506,6 +517,9 @@ class TestDistBase(unittest.TestCase):
env0 = {'CPU_NUM': '1'}
env1 = {'CPU_NUM': '1'}
if self._use_dgc:
tr0_cmd += " --use_dgc"
tr1_cmd += " --use_dgc"
if self._mp_mode:
env0 = {"FLAGS_selected_gpus": "0"}
env1 = {"FLAGS_selected_gpus": "1"}
......
......@@ -39,6 +39,20 @@ class TestDistMnistNCCL2(TestDistBase):
self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnistNCCL2DGC(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._use_dgc = True
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnist2x2Lars(TestDistBase):
def _setup_config(self):
self._sync_mode = True
......
......@@ -60,5 +60,20 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self.check_with_place("dist_se_resnext.py", delta=100)
class TestDistSeResnetNCCL2DGC(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._use_dgc = True
@skip_ci
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place("dist_se_resnext.py", delta=30)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册