From eb83abeac3c0146b921ce72d06fef2551ab3e8d8 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 28 Mar 2019 09:23:47 +0800 Subject: [PATCH] Add DGC(Deep Gradient Compression) interface. (#15841) --- CMakeLists.txt | 6 + cmake/external/dgc.cmake | 42 +++ cmake/inference_lib.cmake | 9 + cmake/operators.cmake | 2 +- paddle/fluid/API.spec | 5 + paddle/fluid/framework/details/CMakeLists.txt | 2 +- .../framework/details/all_reduce_deps_pass.cc | 7 +- .../framework/details/all_reduce_op_handle.cc | 200 ++++++++++++- .../framework/details/all_reduce_op_handle.h | 16 +- .../details/multi_devices_graph_pass.cc | 33 ++- .../details/multi_devices_graph_pass.h | 5 +- paddle/fluid/framework/details/var_handle.cc | 3 +- paddle/fluid/framework/op_desc.cc | 6 + paddle/fluid/framework/op_desc.h | 1 + paddle/fluid/framework/operator.cc | 5 +- paddle/fluid/inference/CMakeLists.txt | 5 + paddle/fluid/operators/CMakeLists.txt | 8 +- paddle/fluid/operators/clip_by_norm_op.cc | 61 +--- paddle/fluid/operators/clip_by_norm_op.h | 54 ++++ paddle/fluid/operators/dgc_clip_by_norm_op.cc | 67 +++++ paddle/fluid/operators/dgc_clip_by_norm_op.cu | 20 ++ paddle/fluid/operators/dgc_clip_by_norm_op.h | 46 +++ paddle/fluid/operators/dgc_op.cc | 138 +++++++++ paddle/fluid/operators/dgc_op.cu | 20 ++ paddle/fluid/operators/dgc_op.h | 132 +++++++++ paddle/fluid/platform/CMakeLists.txt | 6 +- paddle/fluid/platform/assert.h | 14 +- paddle/fluid/platform/device_context.cc | 15 +- paddle/fluid/platform/init.cc | 18 ++ paddle/fluid/platform/init.h | 2 + paddle/fluid/pybind/protobuf.cc | 1 + paddle/fluid/pybind/pybind.cc | 1 + python/paddle/fluid/framework.py | 16 ++ python/paddle/fluid/optimizer.py | 272 +++++++++++++++++- python/paddle/fluid/parallel_executor.py | 6 + .../fluid/tests/unittests/CMakeLists.txt | 8 +- .../fluid/tests/unittests/dist_mnist.py | 8 +- .../fluid/tests/unittests/dist_se_resnext.py | 20 +- .../fluid/tests/unittests/test_dgc_op.py | 138 +++++++++ .../fluid/tests/unittests/test_dist_base.py | 16 +- .../fluid/tests/unittests/test_dist_mnist.py | 14 + .../tests/unittests/test_dist_se_resnext.py | 15 + 42 files changed, 1363 insertions(+), 100 deletions(-) create mode 100644 cmake/external/dgc.cmake create mode 100644 paddle/fluid/operators/dgc_clip_by_norm_op.cc create mode 100644 paddle/fluid/operators/dgc_clip_by_norm_op.cu create mode 100644 paddle/fluid/operators/dgc_clip_by_norm_op.h create mode 100644 paddle/fluid/operators/dgc_op.cc create mode 100644 paddle/fluid/operators/dgc_op.cu create mode 100644 paddle/fluid/operators/dgc_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_dgc_op.py diff --git a/CMakeLists.txt b/CMakeLists.txt index a38e32b73d5..9ad69738eb2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -193,6 +193,12 @@ if(WITH_GPU) include(tensorrt) include(anakin_subgraph) endif() + +if(WITH_GPU AND NOT WIN32) + message(STATUS "add dgc lib.") + include(external/dgc) +endif() + if(WITH_MKL OR WITH_MKLML) include(external/anakin) elseif() diff --git a/cmake/external/dgc.cmake b/cmake/external/dgc.cmake new file mode 100644 index 00000000000..199ca88b477 --- /dev/null +++ b/cmake/external/dgc.cmake @@ -0,0 +1,42 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INCLUDE(ExternalProject) + +SET(DGC_SOURCES_DIR "${THIRD_PARTY_PATH}/dgc") +SET(DGC_INSTALL_DIR "${THIRD_PARTY_PATH}/install/dgc") +SET(DGC_INCLUDE_DIR "${DGC_INSTALL_DIR}/include" CACHE PATH "dgc include directory." FORCE) +SET(DGC_LIBRARIES "${DGC_INSTALL_DIR}/lib/libdgc.a" CACHE FILEPATH "dgc library." FORCE) +INCLUDE_DIRECTORIES(${DGC_INCLUDE_DIR}) + +ExternalProject_Add( + extern_dgc + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/PaddlePaddle/Fleet" + GIT_TAG "2d04dc3800cdd0601f1b65d547dabcc60b0cf9dc" + SOURCE_DIR "${DGC_SOURCES_DIR}" + CONFIGURE_COMMAND "" + BUILD_COMMAND cd collective && make -j + INSTALL_COMMAND mkdir -p ${DGC_INSTALL_DIR}/lib/ ${DGC_INCLUDE_DIR}/dgc + && cp ${DGC_SOURCES_DIR}/collective/build/lib/libdgc.a ${DGC_LIBRARIES} + && cp ${DGC_SOURCES_DIR}/collective/build/include/dgc.h ${DGC_INCLUDE_DIR}/dgc/ + BUILD_IN_SOURCE 1 +) + +ADD_LIBRARY(dgc SHARED IMPORTED GLOBAL) +SET_PROPERTY(TARGET dgc PROPERTY IMPORTED_LOCATION ${DGC_LIBRARIES}) +ADD_DEPENDENCIES(dgc extern_dgc) + +LIST(APPEND external_project_dependencies dgc) + diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index a7dce4dfdb5..b7c32f80db0 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -131,6 +131,15 @@ elseif (NOT CBLAS_FOUND OR WIN32) ) endif () +if (WITH_GPU AND NOT WIN32) + set(dgc_dir "${FLUID_INSTALL_DIR}/third_party/install/dgc") + copy(dgc_lib + SRCS ${DGC_INSTALL_DIR}/lib ${DGC_INSTALL_DIR}/include + DSTS ${dgc_dir} ${dgc_dir} + DEPS dgc) +endif() + + if (WITH_MKLDNN) set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mkldnn") copy(mkldnn_lib diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 34c6cbd73dd..c17e718f427 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -110,7 +110,7 @@ function(op_library TARGET) # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" -"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op") +"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 851308a0f65..e6f5cb7473c 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -483,6 +483,11 @@ paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['sel paddle.fluid.optimizer.LarsMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.LarsMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.DGCMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'rampup_begin_step', 'rampup_step', 'sparsity', 'use_nesterov', 'local_grad_clip_norm', 'num_trainers', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1, [0.999], False, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.optimizer.DGCMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) +paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '1a79bd7d10ae54ca763ec81bca36ba24')) paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 77e94e998c4..046ec6978a8 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -23,7 +23,7 @@ endif() if(WITH_GPU) nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - dynload_cuda variable_visitor) + dynload_cuda variable_visitor dgc) nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) if(WITH_DISTRIBUTE) diff --git a/paddle/fluid/framework/details/all_reduce_deps_pass.cc b/paddle/fluid/framework/details/all_reduce_deps_pass.cc index c084410864b..98a74d630cd 100644 --- a/paddle/fluid/framework/details/all_reduce_deps_pass.cc +++ b/paddle/fluid/framework/details/all_reduce_deps_pass.cc @@ -86,7 +86,8 @@ std::unique_ptr AllReduceDepsPass::ApplyImpl( } } - VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl; + VLOG(10) << "dist_ops size:" << dist_ops.size() + << ", outputs size:" << vars.size() << ", ops size:" << ops.size(); std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1, OpHandleBase* op2) { @@ -99,6 +100,10 @@ std::unique_ptr AllReduceDepsPass::ApplyImpl( auto l_it = vars.find(i0->name()); auto r_it = vars.find(i1->name()); + PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(), + "can't find var's name %s and %s in opdesc", i0->name(), + i1->name()); + if (l_it->second < r_it->second) return true; if (l_it->second == r_it->second) { diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index fdaff08e537..6e477cd2977 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -16,6 +16,13 @@ #include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/framework/details/variable_visitor.h" +#include "paddle/fluid/framework/operator.h" + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "dgc/dgc.h" +#endif + +#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/profiler.h" // asynchronous nccl allreduce or synchronous issue: @@ -33,11 +40,14 @@ namespace details { AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, - const platform::NCCLContextMap *ctxs) + const platform::NCCLContextMap *ctxs, + bool is_encoded, int nranks) : OpHandleBase(node), local_scopes_(local_scopes), places_(places), - nccl_ctxs_(ctxs) { + nccl_ctxs_(ctxs), + is_encoded_(is_encoded), + nranks_(nranks) { if (nccl_ctxs_) { for (auto &p : places_) { this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p)); @@ -51,7 +61,185 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +void AllReduceOpHandle::RunImplEncoded() { + platform::RecordEvent record_event(Name()); + + WaitInputVarGenerated(); + + auto in_var_handles = DynamicCast(this->Inputs()); + auto out_var_handles = DynamicCast(this->Outputs()); + PADDLE_ENFORCE_EQ( + in_var_handles.size(), places_.size(), + "The NoDummyInputSize should be equal to the number of places."); + PADDLE_ENFORCE_EQ( + in_var_handles.size(), out_var_handles.size(), + "The NoDummyInputSize and NoDummyOutputSize should be equal."); + + std::vector ins; + std::vector outs; + int k = -1; + for (size_t i = 0; i < local_scopes_.size(); ++i) { + auto &local_scope = + local_scopes_[i]->FindVar(kLocalExecScopeName)->Get(); + auto original_name = + paddle::framework::GradOriginalVarName(in_var_handles[i]->name()); + auto encode_var_name = original_name + g_dgc_encoded; + auto *in_var = local_scope->FindVar(encode_var_name); + PADDLE_ENFORCE_NOT_NULL(in_var); + auto &in = in_var->Get(); + ins.emplace_back(&in); + + auto *out = local_scope->FindVar(out_var_handles[i]->name()) + ->GetMutable(); + outs.emplace_back(out); + + if (k < 0) { + k = GetKValue(in_var_handles[i]->name()); + } + } + + PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place())); + PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place())); + PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr."); + + int dtype = -1; + size_t in_numel = 0; + size_t out_numel = 0; + PADDLE_ENFORCE(nranks_ > 1); + std::vector> all_reduce_calls; + + for (size_t i = 0; i < local_scopes_.size(); ++i) { + auto &place = places_[i]; + auto &in = *ins[i]; + void *in_tensor_buf = const_cast(in.data()); + + auto &out = *outs[i]; + float *out_tensor_buf = out.data(); + + dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype; + in_numel = (in_numel == 0) ? static_cast(in.numel()) : in_numel; + PADDLE_ENFORCE(in_numel % 2 == 0); + PADDLE_ENFORCE(in_numel / 2 == static_cast(k)); + out_numel = (out_numel == 0) ? static_cast(out.numel()) : out_numel; + + int dev_id = boost::get(place).device; + auto &nccl_ctx = nccl_ctxs_->at(dev_id); + auto stream = nccl_ctx.stream(); + auto comm = nccl_ctx.comm_; + + auto &allocator = + platform::DeviceTemporaryAllocator::Instance().Get(place, stream); + int encode_size = 2 * k * sizeof(int); + // dgc use ncclAllGather to get all the encoded data + // so the buffer need nranks. + int buf_size = nranks_ * encode_size; + auto tmp_ious_data = allocator.Allocate(buf_size); + void *gather_buff = reinterpret_cast(tmp_ious_data->ptr()); + + VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel + << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size + << ", k:" << k << ", place:" << place << ", dtype:" << dtype; + + all_reduce_calls.emplace_back([=] { + PADDLE_ENFORCE(paddle::communication::dgc::sparseAllGReduce( + in_tensor_buf, gather_buff, k, out_tensor_buf, out_numel, comm, + stream)); + }); + } + + this->RunAndRecordEvent([&] { + if (all_reduce_calls.size() == 1UL) { + // Do not use NCCLGroup when manage NCCL by per thread per device + all_reduce_calls[0](); + } else { + platform::NCCLGroupGuard guard; + for (auto &call : all_reduce_calls) { + call(); + } + } + }); + + if (FLAGS_sync_nccl_allreduce) { + for (auto &p : places_) { + int dev_id = boost::get(p).device; + auto &nccl_ctx = nccl_ctxs_->at(dev_id); + auto stream = nccl_ctx.stream(); + cudaError_t e_sync = cudaStreamSynchronize(stream); + if (e_sync != 0) { + LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync); + } + + cudaError_t e_get = cudaGetLastError(); + if (e_get != 0) { + LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get) + << " errno:" << e_get; + } + } + } +} + +int AllReduceOpHandle::GetKValue(const std::string &grad_name) { + auto original_name = paddle::framework::GradOriginalVarName(grad_name); + auto var_name = original_name + g_dgc_k; + PADDLE_ENFORCE(local_scopes_.size() > 0); + + auto *scope = local_scopes_[0]; + auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get(); + auto var = local_scope->FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL(var); + auto tensor = var->Get().data(); + return *tensor; +} +#endif + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +bool AllReduceOpHandle::IsEncoded() { + if (!is_encoded_) { + return false; + } + auto counter_name = g_dgc_counter_name; + auto step_name = g_dgc_rampup_begin_step; + PADDLE_ENFORCE(local_scopes_.size() > 0); + + auto *scope = local_scopes_[0]; + auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get(); + auto count_var = local_scope->FindVar(counter_name); + auto step_var = local_scope->FindVar(step_name); + if (count_var == nullptr || step_var == nullptr) { + PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name, + step_var); + } + + float count = *count_var->Get().data(); + float step = *step_var->Get().data(); + if (static_cast(count) < static_cast(step)) { + VLOG(10) << "in all_reduce currentstep:" << count + << " < rampup_begin_step:" << step + << " so not use sparse all reduce"; + return false; + } + + return true; +} +#else +bool AllReduceOpHandle::IsEncoded() { return false; } +#endif + void AllReduceOpHandle::RunImpl() { + if (!IsEncoded()) { + RunImplNormal(); + return; + } + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + RunImplEncoded(); +#else + PADDLE_THROW("Not compiled with CUDA"); +#endif +} + +void AllReduceOpHandle::RunImplNormal() { platform::RecordEvent record_event(Name()); WaitInputVarGenerated(); @@ -72,6 +260,8 @@ void AllReduceOpHandle::RunImpl() { auto &lod_tensor = local_scope.FindVar(in_var_handles[i]->name())->Get(); lod_tensors.emplace_back(&lod_tensor); + VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name() + << ", out_name:" << out_var_handles[i]->name(); PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(), "The name of input and output should be equal."); } @@ -99,13 +289,17 @@ void AllReduceOpHandle::RunImpl() { auto &nccl_ctx = nccl_ctxs_->at(dev_id); auto stream = nccl_ctx.stream(); auto comm = nccl_ctx.comm_; + + VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel + << ", dev_id:" << dev_id << ", dtype:" << dtype + << ", place:" << p; + all_reduce_calls.emplace_back([=] { PADDLE_ENFORCE(platform::dynload::ncclAllReduce( buffer, buffer, numel, static_cast(dtype), ncclSum, comm, stream)); }); } - this->RunAndRecordEvent([&] { if (all_reduce_calls.size() == 1UL) { // Do not use NCCLGroup when manage NCCL by per thread per device diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h index b449796fcae..ca75186f6ce 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/all_reduce_op_handle.h @@ -28,11 +28,19 @@ namespace paddle { namespace framework { namespace details { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +constexpr char g_dgc_counter_name[] = "__g_dgc_counter__"; +constexpr char g_dgc_rampup_begin_step[] = "__g_rampup_begin_step__"; +constexpr char g_dgc_encoded[] = "__dgc_encoded__"; +constexpr char g_dgc_k[] = "__dgc_k__"; +#endif + struct AllReduceOpHandle : public OpHandleBase { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, - const platform::NCCLContextMap *ctxs); + const platform::NCCLContextMap *ctxs, + bool is_encoded = false, int nranks = -1); #else AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); @@ -50,8 +58,14 @@ struct AllReduceOpHandle : public OpHandleBase { std::vector local_scopes_; std::vector places_; #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + void RunImplEncoded(); const platform::NCCLContextMap *nccl_ctxs_; + bool is_encoded_{false}; + int nranks_{-1}; + int GetKValue(const std::string &grad_name); #endif + void RunImplNormal(); + bool IsEncoded(); }; } // namespace details diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 253cf5b4a82..8c61684c9c9 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -32,6 +32,7 @@ #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/operators/math/math_function.h" namespace paddle { namespace framework { @@ -209,7 +210,8 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( for (size_t i = 0; i < backward_vars.size(); i += 2) { auto &p_name = backward_vars[i]; auto &g_name = backward_vars[i + 1]; - VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; + VLOG(10) << "Bcast " << g_name << " for parameter " << p_name + << " op_type " << node->Op()->Type(); if (NeedCollectiveForGrad(g_name, sorted_ops)) { InsertCollectiveOp(&result, p_name, g_name); } @@ -414,8 +416,9 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, CreateOpHandleIOs(result, node, dev_id); } -void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( - ir::Graph *result, const std::string &og) const { +void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result, + const std::string &og, + bool is_encoded) const { OpHandleBase *op_handle = nullptr; auto append_allreduce_op = [&]( @@ -424,7 +427,9 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), - scopes, places, nccl_ctxs_)); + scopes, places, nccl_ctxs_, is_encoded, + static_cast(strategy_.trainers_endpoints_.size()) * + places_.size())); #else result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), @@ -446,12 +451,15 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp( PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad); + VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString(); auto var = new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), vars.size(), i, og, places_[i]); vars.emplace_back(var); op_handle->AddOutput(var); + VLOG(10) << "all_reduce_op_handle add output " << og + << ", handle:" << var->DebugString(); } } @@ -941,6 +949,17 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, return op_dev_id; } +bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const { + auto u_name = p_name + "__dgc_u__"; + auto it = all_vars_.find(u_name); + if (it == all_vars_.end()) { + VLOG(10) << "can't find u_name, so it's not encoded:" << u_name; + return false; + } + + return true; +} + void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const { @@ -956,7 +975,11 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, CreateReduceOp(result, g_name, 0); CreateBroadcastOp(result, g_name, 0); } else { - CreateAllReduceOp(result, g_name); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + CreateAllReduceOp(result, g_name, IsEncoded(p_name)); +#else + PADDLE_ENFORCE(false, "Compiled withoud cuda!"); +#endif } break; default: diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 0ee3a060629..8bfd7b9bf8f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -75,7 +75,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { bool IsSparseGradient(const std::string &og) const; - void CreateAllReduceOp(ir::Graph *result, const std::string &og) const; + void CreateAllReduceOp(ir::Graph *result, const std::string &og, + bool is_encoded = false) const; void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, size_t src_dev_id) const; @@ -171,6 +172,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder { mutable std::vector> bcast_var_name_set_; mutable bool need_broadcast_var_{false}; + + bool IsEncoded(const std::string &p_name) const; }; std::unordered_set &MultiDevSSAGraphBuilder(); diff --git a/paddle/fluid/framework/details/var_handle.cc b/paddle/fluid/framework/details/var_handle.cc index 30da029ca2a..95d62e66415 100644 --- a/paddle/fluid/framework/details/var_handle.cc +++ b/paddle/fluid/framework/details/var_handle.cc @@ -24,7 +24,8 @@ VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); } std::string VarHandle::DebugString() const { std::stringstream ss; - ss << name_ << ":" << place_; + ss << "name:" << name_ << ", place:" << place_ << ", version:" << version_ + << ", scope_idx:" << scope_idx_; return ss.str(); } diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 8f9c6cb5e92..353db435213 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -373,6 +373,11 @@ std::vector OpDesc::AttrNames() const { return retv; } +void OpDesc::RemoveAttr(const std::string &name) { + attrs_.erase(name); + need_update_ = true; +} + void OpDesc::SetAttr(const std::string &name, const Attribute &v) { // NOTICE(minqiyang): pybind11 will take the empty list in python as // the std::vector type in C++; so we have to change the attr's type @@ -644,6 +649,7 @@ void OpDesc::CheckAttrs() { // not by users. return; } + VLOG(10) << "begin to check attribute of " << Type(); checker->Check(&attrs_); } diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index d7352c5ee5a..dedaf243647 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -72,6 +72,7 @@ class OpDesc { std::vector AttrNames() const; void SetAttr(const std::string &name, const Attribute &v); + void RemoveAttr(const std::string &name); void SetBlockAttr(const std::string &name, BlockDesc *block); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index eef84d17a4b..b0ac73f9f52 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1110,8 +1110,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( proto::VarType::Type tmp = t->type(); PADDLE_ENFORCE( tmp == data_type || data_type == dafault_data_type, - "DataType of Paddle Op %s must be the same. Get (%d) != (%d)", - Type(), DataTypeToString(data_type), DataTypeToString(tmp)); + "DataType of Paddle Op %s %s must be the same. Get (%d) != (%d)", + Type(), input.first, DataTypeToString(data_type), + DataTypeToString(tmp)); data_type = tmp; } } diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 5e0be5d445e..fb433ff2a2b 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -49,6 +49,11 @@ set(SHARED_INFERENCE_SRCS ${mkldnn_quantizer_src} ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc) +# FIXME(gongwb): hidden libdgc.a +if(WITH_GPU AND NOT WIN32) + set(fluid_modules ${fluid_modules} dgc) +endif() + if(WIN32) sep_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array analysis_config ${mkldnn_quantizer_cfg} paddle_pass_builder) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index afac8e4d2a3..e52e83673fe 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -48,7 +48,7 @@ if (WITH_DISTRIBUTE) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) endif() -register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) +register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) if (WITH_GPU) # warpctc_op needs cudnn 7 above @@ -72,6 +72,12 @@ endif() set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) +if (WITH_GPU AND NOT WIN32) + op_library(dgc_op DEPS dgc) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(dgc);\n") + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dgc) +endif() + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) diff --git a/paddle/fluid/operators/clip_by_norm_op.cc b/paddle/fluid/operators/clip_by_norm_op.cc index eae86a373be..5720b295ecf 100644 --- a/paddle/fluid/operators/clip_by_norm_op.cc +++ b/paddle/fluid/operators/clip_by_norm_op.cc @@ -14,69 +14,10 @@ limitations under the License. */ #include "paddle/fluid/operators/clip_by_norm_op.h" -namespace paddle { -namespace operators { - -class ClipByNormOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ClipByNormOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of ClipByNormOp should not be null."); - auto max_norm = ctx->Attrs().Get("max_norm"); - PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0."); - auto x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } -}; - -class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(Tensor) The input of clip_by_norm op." - "The number of dimensions must be between [1, 9]."); - AddOutput("Out", - "(Tensor) The output of clip_by_norm op with shape as input(X)"); - AddAttr("max_norm", "(float) The maximum norm value."); - AddComment(R"DOC( -ClipByNorm Operator. - -This operator limits the L2 norm of the input $X$ within $max\_norm$. -If the L2 norm of $X$ is less than or equal to $max\_norm$, $Out$ will be -the same as $X$. If the L2 norm of $X$ is greater than $max\_norm$, $X$ will -be linearly scaled to make the L2 norm of $Out$ equal to $max\_norm$, as -shown in the following formula: - -$$ -Out = \\frac{max\\_norm * X}{norm(X)}, -$$ - -where $norm(X)$ represents the L2 norm of $X$. - -Examples: - .. code-block:: python - - data = fluid.layer.data( - name='data', shape=[2, 4, 6], dtype='float32') - reshaped = fluid.layers.clip_by_norm( - x=data, max_norm=0.5) - -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp, ops::ClipByNormOpMaker); + REGISTER_OP_CPU_KERNEL( clip_by_norm, ops::ClipByNormKernel); diff --git a/paddle/fluid/operators/clip_by_norm_op.h b/paddle/fluid/operators/clip_by_norm_op.h index 49e734ce96b..d8baa4b8b23 100644 --- a/paddle/fluid/operators/clip_by_norm_op.h +++ b/paddle/fluid/operators/clip_by_norm_op.h @@ -83,5 +83,59 @@ class ClipByNormKernel : public framework::OpKernel { } }; +class ClipByNormOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ClipByNormOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ClipByNormOp should not be null."); + auto max_norm = ctx->Attrs().Get("max_norm"); + PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0."); + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor) The input of clip_by_norm op." + "The number of dimensions must be between [1, 9]."); + AddOutput("Out", + "(Tensor) The output of clip_by_norm op with shape as input(X)"); + AddAttr("max_norm", "(float) The maximum norm value."); + AddComment(R"DOC( +ClipByNorm Operator. + +This operator limits the L2 norm of the input $X$ within $max\_norm$. +If the L2 norm of $X$ is less than or equal to $max\_norm$, $Out$ will be +the same as $X$. If the L2 norm of $X$ is greater than $max\_norm$, $X$ will +be linearly scaled to make the L2 norm of $Out$ equal to $max\_norm$, as +shown in the following formula: + +$$ +Out = \\frac{max\\_norm * X}{norm(X)}, +$$ + +where $norm(X)$ represents the L2 norm of $X$. + +Examples: + .. code-block:: python + + data = fluid.layer.data( + name='data', shape=[2, 4, 6], dtype='float32') + reshaped = fluid.layers.clip_by_norm( + x=data, max_norm=0.5) + +)DOC"); + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/dgc_clip_by_norm_op.cc b/paddle/fluid/operators/dgc_clip_by_norm_op.cc new file mode 100644 index 00000000000..6ebad4de3c8 --- /dev/null +++ b/paddle/fluid/operators/dgc_clip_by_norm_op.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/operators/dgc_clip_by_norm_op.h" + +namespace paddle { +namespace operators { + +class DGCClipByNormOp : public ClipByNormOp { + public: + using ClipByNormOp::ClipByNormOp; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("current_step"), + "current_step should be set."); + + return ClipByNormOp::InferShape(ctx); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "current_step") { + VLOG(10) << "var_name:" << var_name << " need not to transform"; + return expected_kernel_type; + } + + return framework::OperatorWithKernel::GetKernelTypeForVar( + var_name, tensor, expected_kernel_type); + } +}; + +class DGCClipByNormOpMaker : public ClipByNormOpMaker { + public: + void Make() override { + AddInput("current_step", "(Tensor) Current step."); + AddAttr("rampup_begin_step", + "(float, -1.0)" + "The period when begin k_select.") + .SetDefault(-1.0); + + return ClipByNormOpMaker::Make(); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(dgc_clip_by_norm, ops::DGCClipByNormOp, + ops::DGCClipByNormOpMaker); + +REGISTER_OP_CPU_KERNEL( + dgc_clip_by_norm, + ops::DGCClipByNormKernel); diff --git a/paddle/fluid/operators/dgc_clip_by_norm_op.cu b/paddle/fluid/operators/dgc_clip_by_norm_op.cu new file mode 100644 index 00000000000..e7f564b7ab4 --- /dev/null +++ b/paddle/fluid/operators/dgc_clip_by_norm_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/dgc_clip_by_norm_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + dgc_clip_by_norm, + ops::DGCClipByNormKernel); diff --git a/paddle/fluid/operators/dgc_clip_by_norm_op.h b/paddle/fluid/operators/dgc_clip_by_norm_op.h new file mode 100644 index 00000000000..bd22d16f7a2 --- /dev/null +++ b/paddle/fluid/operators/dgc_clip_by_norm_op.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/clip_by_norm_op.h" + +namespace paddle { +namespace operators { + +template +class DGCClipByNormKernel : public ClipByNormKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto rampup_begin_step = context.Attr("rampup_begin_step"); + if (static_cast(rampup_begin_step) >= 0) { + auto current_step_tensor = + context.Input("current_step"); + auto* current_step = current_step_tensor->data(); + + if (static_cast(*current_step) < + static_cast(rampup_begin_step)) { + VLOG(10) << "current_step:" << *current_step + << " < rampup_begin_step:" << rampup_begin_step + << " so does't use dgc_clip_by_norm"; + return; + } + } + + return ClipByNormKernel::Compute(context); + }; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/dgc_op.cc b/paddle/fluid/operators/dgc_op.cc new file mode 100644 index 00000000000..ccdeea2d0a9 --- /dev/null +++ b/paddle/fluid/operators/dgc_op.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/dgc_op.h" +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class DGCOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) of DGCop should not be null."); + PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) of DGCop should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of DGCop should not be null."); + PADDLE_ENFORCE(ctx->HasInput("current_step"), + "Input(current_step) of DGCop should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("U_out"), + "Output(U_out) of DGCop should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("V_out"), + "Output(V_out) of DGCop should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("k"), + "Output(k) of DGCop should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("EncodeGrad"), + "Output(EncodeGrad) of DGCop should not be null."); + } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "current_step" || var_name == "rampup_step" || + var_name == "k") { + VLOG(10) << "var_name:" << var_name << " need not to transform"; + return expected_kernel_type; + } + + return framework::OperatorWithKernel::GetKernelTypeForVar( + var_name, tensor, expected_kernel_type); + } +}; + +class DGCOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("U", "(Tensor) Middle tensor of DGC"); + AddInput("V", "(Tensor) Middle tensor of DGC"); + AddInput("Grad", "(Tensor) Input gradient"); + AddInput("current_step", "(Tensor) Current step."); + + AddOutput("U_out", + "(Tensor) " + "Output encoded gradient"); + AddOutput("V_out", + "(Tensor) " + "Output encoded gradient"); + AddOutput("EncodeGrad", + "(Tensor) " + "Output encoded gradient"); + AddOutput("Grad_out", + "(Tensor) " + "Output grad gradient"); + AddOutput("k", + "(Tensor) " + "Output top-k value"); + + AddAttr("m", + "(float, 0.9) " + "The momentum of learning rate.") + .SetDefault(0.9); + + AddAttr("use_nesterov", + "(bool, true)" + "The momentum of learning rate.") + .SetDefault(true); + + AddAttr>("sparsity", + "(vecotr, float)" + "The period sparsity of k_select."); + + AddAttr("rampup_begin_step", + "(float, 0.0)" + "The period when begin k_select.") + .SetDefault(0.0); + + AddAttr("rampup_step", + "(float, 0.0)" + "The period when begin k_select."); + + AddComment(R"DOC( + Original paper is https://arxiv.org/abs/1712.01887 + + DGC reduce the communication bandwidth by sending only the important gradients (sparse update):\ + only gradients larger than a threshold are transmitted. + + To avoid losing information, DGC accumulate the rest of the gradients locally. + + Eventually, these gradients become large enough to be transmitted. + + Thus, DGC send the large gradients immediately but eventually send all of the gradients over time. + + To ensure no loss of accuracy, DGC employs momentum correc-tionandlocal gradient clipping on top of the gradient sparsification to maintain model performance. + + DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication. + + This optimizer will do two things: + + 1. Compress the gradient by get TopK import value from tensor \ + and use it for allreduce to reduce network bandwidth. + + 2. Call momentum to optimize on the cost. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(dgc, ops::DGCOp, ops::DGCOpMaker); diff --git a/paddle/fluid/operators/dgc_op.cu b/paddle/fluid/operators/dgc_op.cu new file mode 100644 index 00000000000..0f0bf441a70 --- /dev/null +++ b/paddle/fluid/operators/dgc_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/dgc_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + dgc, ops::DGCOpKernel); diff --git a/paddle/fluid/operators/dgc_op.h b/paddle/fluid/operators/dgc_op.h new file mode 100644 index 00000000000..8d1683bdb2d --- /dev/null +++ b/paddle/fluid/operators/dgc_op.h @@ -0,0 +1,132 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "dgc/dgc.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" + +namespace paddle { +namespace operators { + +inline float get_period_sparcity(const std::vector& sparsity, + float cur_step, float rampup_steps) { + PADDLE_ENFORCE(static_cast(cur_step) >= 0); + + size_t idx = static_cast(cur_step * sparsity.size() / rampup_steps); + if (idx >= sparsity.size()) { + return 0.999; + } + + PADDLE_ENFORCE(idx < sparsity.size()); + return sparsity[idx]; +} + +template +class DGCOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto u = ctx.Input("U"); + auto v = ctx.Input("V"); + auto g = ctx.Input("Grad"); + + // attrs + float m = ctx.Attr("m"); + bool use_nesterov = ctx.Attr("use_nesterov"); + auto sparsity = ctx.Attr>("sparsity"); + auto rampup_begin_step = ctx.Attr("rampup_begin_step"); + auto rampup_step = ctx.Attr("rampup_step"); + + // current step + auto current_step_tensor = ctx.Input("current_step"); + const float* current_step = current_step_tensor->data(); + + if (static_cast(*current_step) < static_cast(rampup_begin_step)) { + VLOG(10) << "current_step:" << *current_step + << " < rampup_begin_step:" << rampup_begin_step + << " so does't use dgc"; + return; + } + + float ratio = + 1 - get_period_sparcity(sparsity, static_cast(*current_step), + rampup_step); + PADDLE_ENFORCE(ratio > 0.0 && ratio < 1.0); + int k = static_cast(g->numel() * ratio); + + VLOG(10) << "m:" << m << ", use_nesterov:" << use_nesterov + << ", rampup_begin_step:" << rampup_begin_step + << ", rampup_step:" << rampup_step + << ", current_step:" << *current_step << ", ratio:" << ratio + << ", k:" << k; + + auto k_out = ctx.Output("k"); + T* k_out_data = k_out->data(); + *k_out_data = k; + + auto u_out = ctx.Output("U_out"); + auto v_out = ctx.Output("V_out"); + auto encode_grad_out = ctx.Output("EncodeGrad"); + + // FIXME(gongwb): use cublas. + auto u_out_e = framework::EigenVector::Flatten(*u_out); + auto u_e = framework::EigenVector::Flatten(*u); + auto g_e = framework::EigenVector::Flatten(*g); + auto& dev_ctx = ctx.template device_context(); + auto& eigen_ctx = *dev_ctx.eigen_device(); + if (use_nesterov) { + // u = m * (u + g) + u_out_e.device(eigen_ctx) = m * (u_e + g_e); + + // v = u + v + g + ElementwiseComputeEx, DeviceContext, T>( + ctx, u, v, 0, AddFunctor(), v_out); + + ElementwiseComputeEx, DeviceContext, T>( + ctx, g, v, 0, AddFunctor(), v_out); + } else { + // u = m * u + g + u_out_e.device(eigen_ctx) = m * u_e + g_e; + + // v = u + v + ElementwiseComputeEx, DeviceContext, T>( + ctx, u, v, 0, AddFunctor(), v_out); + } + + T* v_out_data = v_out->mutable_data(ctx.GetPlace()); + T* u_out_data = u_out->mutable_data(ctx.GetPlace()); + T* encode_grad_out_data = encode_grad_out->mutable_data( + framework::DDim{2 * k}, ctx.GetPlace()); + + int buf_size = paddle::communication::dgc::get_buffer_size(k); + auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get( + ctx.GetPlace(), dev_ctx.stream()); + auto tmp_ious_data = allocator.Allocate(buf_size); + void* buf = reinterpret_cast(tmp_ious_data->ptr()); + + if (!paddle::communication::dgc::k_select( + static_cast(encode_grad_out_data), k, v_out_data, + static_cast(v_out->numel()), buf, dev_ctx.stream(), + u_out_data)) { + LOG(FATAL) << "v_out numel:" << v_out->numel(); + } + + auto grad_out = ctx.Output("Grad_out"); + math::SetConstant tset; + tset(dev_ctx, grad_out, static_cast(0)); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 9220d35707b..c3db59563f3 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -46,8 +46,9 @@ cc_test(cpu_helper_test SRCS cpu_helper_test.cc DEPS cpu_helper) IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) + set(dgc_deps dgc) ELSE() - set(GPU_CTX_DEPS) + set(dgc_deps) ENDIF() IF(WITH_MKLDNN) @@ -68,7 +69,8 @@ ENDIF() # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS} - place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} temp_allocator) + place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} + temp_allocator ${dgc_deps}) if(WIN32) if(WITH_GPU AND NOT WITH_DSO) diff --git a/paddle/fluid/platform/assert.h b/paddle/fluid/platform/assert.h index 2e8fa7c1b8f..497c7b3c87f 100644 --- a/paddle/fluid/platform/assert.h +++ b/paddle/fluid/platform/assert.h @@ -37,13 +37,13 @@ limitations under the License. */ } \ } while (0) -#define PADDLE_ASSERT_MSG_CODE(e, m, c) \ - do { \ - if (!(e)) { \ - printf("%s:%d Assertion `%s` failed (%s %d).\n", __FILE__, __LINE__, \ - TOSTRING(e), m, c); \ - asm("trap;"); \ - } \ +#define PADDLE_ASSERT_MSG_CODE(e, m, c) \ + do { \ + if (!(e)) { \ + printf("%s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, __LINE__, \ + TOSTRING(e), m, c); \ + asm("trap;"); \ + } \ } while (0) #else #include diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 48002a76202..61386bdf05a 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/fluid/platform/cuda_device_guard.h" #endif +#include "glog/logging.h" + namespace paddle { namespace platform { @@ -324,8 +326,17 @@ void CUDADeviceContext::Wait() const { auto& allocator = DeviceTemporaryAllocator::Instance().Get(*this); allocator.Release([this]() { - PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); - PADDLE_ENFORCE(cudaGetLastError()); + cudaError_t e_sync = cudaStreamSynchronize(stream_); + if (e_sync != 0) { + LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync) + << " errno:" << e_sync; + } + + cudaError_t e_get = cudaGetLastError(); + if (e_get != 0) { + LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get) + << " errno:" << e_get; + } }); } diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index d53a4029e1b..407d1b12998 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -31,6 +31,10 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/piece.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "dgc/dgc.h" +#endif + DEFINE_int32(paddle_num_threads, 1, "Number of threads for each paddle instance."); DEFINE_int32(multiple_of_cupti_buffer_size, 1, @@ -43,6 +47,10 @@ namespace framework { std::once_flag gflags_init_flag; std::once_flag p2p_init_flag; +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +std::once_flag dgc_init_flag; +#endif + void InitGflags(std::vector argv) { std::call_once(gflags_init_flag, [&]() { FLAGS_logtostderr = true; @@ -203,5 +211,15 @@ void InitGLOG(const std::string &prog_name) { #endif } +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +void InitDGC() { + std::call_once(dgc_init_flag, []() { + PADDLE_ENFORCE(paddle::communication::dgc::dynloadNcclLib()); + }); +} +#else +void InitDGC() {} +#endif + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/platform/init.h b/paddle/fluid/platform/init.h index 0e305946729..01d66f57dc9 100644 --- a/paddle/fluid/platform/init.h +++ b/paddle/fluid/platform/init.h @@ -30,5 +30,7 @@ void InitDevices(bool init_p2p); void InitDevices(bool init_p2p, const std::vector devices); +void InitDGC(); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 7b5e417504f..31b5dd5d7c0 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -222,6 +222,7 @@ void BindOpDesc(pybind11::module *m) { .def("attr_type", &pd::OpDesc::GetAttrType) .def("attr_names", &pd::OpDesc::AttrNames) .def("_set_attr", &pd::OpDesc::SetAttr) + .def("remove_attr", &pd::OpDesc::RemoveAttr) .def("attr", &pd::OpDesc::GetAttr) .def("set_block_attr", &pd::OpDesc::SetBlockAttr) .def("set_blocks_attr", &pd::OpDesc::SetBlocksAttr) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index dca40edf0bb..3b0939ef820 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -933,6 +933,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("init_gflags", framework::InitGflags); m.def("init_glog", framework::InitGLOG); + m.def("init_dgc", framework::InitDGC); m.def("init_devices", [](bool init_p2p) { framework::InitDevices(init_p2p); }); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 85e1916a3a5..4a5301b436e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1202,6 +1202,9 @@ class Operator(object): """ self._update_desc_attr(name, val) + def _remove_attr(self, name): + self.desc.remove_attr(name) + def _update_desc_attr(self, name, val): """ Update the value of desc's attribute by attribute's name. @@ -2725,6 +2728,10 @@ class Program(object): self._trainers_endpoints = [] # the distributed lookup table names self._distributed_lookup_table = None + + # use Deep gradient comrepssion or not + self._enable_dgc = False + # @deprecated(the python memory optimize transpiler is deprecated) # whether the program is optimized by memory_optimize_transpiler self.__is_mem_optimized = False @@ -2775,6 +2782,15 @@ class Program(object): def set_op_role_var(self, var_name): self._op_role_var = [var_name] + @contextlib.contextmanager + def _backward_role_guard(self): + tmp_role = self._current_role + + OpRole = core.op_proto_and_checker_maker.OpRole + self._current_role = OpRole.Backward + yield + self._current_role = tmp_role + @signature_safe_contextmanager def _optimized_guard(self, param_and_grads): """ diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index c0deb5eacca..e21f303a3e0 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -17,7 +17,7 @@ from __future__ import print_function from collections import defaultdict from .wrapped_decorator import signature_safe_contextmanager -from paddle.fluid.framework import Program, Variable, name_scope, default_main_program +from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from . import framework @@ -31,13 +31,17 @@ from .layer_helper import LayerHelper from .layers import ops from .regularizer import append_regularization_ops from .imperative import base as imperative_base +from paddle.fluid import core +from paddle.fluid.layers import tensor +from functools import reduce +import copy __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum', - 'LarsMomentumOptimizer' + 'LarsMomentumOptimizer', 'DGCMomentumOptimizer' ] @@ -294,6 +298,9 @@ class Optimizer(object): outputs={"ParamOut": param_and_grad[0]}) return new_param_grads, (table_param, table_grad), sgd_op + def _append_dgc_ops(self, param_and_grad): + pass + def backward(self, loss, startup_program=None, @@ -415,6 +422,9 @@ class Optimizer(object): with program_guard(program, startup_program): params_grads = self.backward(loss, startup_program, parameter_list, no_grad_set) + # Note: since we can't use all_reduce_op now, + # dgc_op should be the last op of one grad. + self._append_dgc_ops(params_grads) optimize_ops = self.apply_gradients(params_grads) return optimize_ops, params_grads @@ -552,6 +562,264 @@ class MomentumOptimizer(Optimizer): return momentum_op +class DGCMomentumOptimizer(MomentumOptimizer): + """ + + Original paper is https://arxiv.org/abs/1712.01887 + + DGC reduce the communication bandwidth by sending only the important gradients (sparse update):\ + only gradients larger than a threshold are transmitted. + + To avoid losing information, DGC accumulate the rest of the gradients locally. + + Eventually, these gradients become large enough to be transmitted. + + Thus, DGC send the large gradients immediately but eventually send all of the gradients over time. + + To ensure no loss of accuracy, DGC employs momentum correc-tionandlocal gradient clipping on top of the gradient sparsification to maintain model performance. + + DGC also uses momentum factor masking and warmup training to overcome the staleness problem caused by reduced communication. + + This optimizer will do two things: + + 1. Compress the gradient by get TopK import value from tensor \ + and use it for allreduce to reduce network bandwidth. + + 2. Call momentum to optimize on the cost. + + Args: + learning_rate (float|Variable): the learning rate used to update parameters. \ + Can be a float value or a Variable with one float value as data element. + momentum (float): Momentum factor. + rampup_begin_step (int): The begining step from which gradient compression is implemented. + rampup_step (int): How long it use the sparsity periods. Default is 1. + for example: If the sparsity is [0.75, 0.9375, 0.984375, 0.996, 0.999], and the rampup_step is 5, \ + it will use 0.75 at 0 step, and 0.9375 at 1 step, and so on. And when reach sparsity array ends, \ + it will use 0.999 then and after. + sparsity (list[float]): Get top important element from gradient tensor, the ratio is (1 - current sparsity). + use_nesterov (bool): Enables Nesterov momentum. True means use nesterov. + local_grad_clip_norm (float): Clip norm value if needed. + num_trainers: The number of training node. + regularization: A Regularizer, such as fluid.regularizer.L2DecayRegularizer. + name: A optional name prefix. + + Examples: + .. code-block:: python + + optimizer = fluid.optimizer.DGCMomentumOptimizer( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + momentum=0.9, + rampup_begin_step=1252, + regularization=fluid.regularizer.L2Decay(1e-4)) + optimizer.minimize(cost) + + """ + + def __init__(self, + learning_rate, + momentum, + rampup_begin_step, + rampup_step=1, + sparsity=[0.999], + use_nesterov=False, + local_grad_clip_norm=None, + num_trainers=None, + regularization=None, + name=None): + self._sparsity = sparsity + self._rampup_step = rampup_step + self._rampup_step_var = None + + self._rampup_begin_step = rampup_begin_step + self._rampup_begin_step_var = None + + self._global_step_var = None + self._local_grad_clip_norm = None + self._clip_norm = None + + if local_grad_clip_norm is not None: + assert isinstance(num_trainers, int) + assert isinstance(local_grad_clip_norm, float) + assert num_trainers > 0 + + self._local_grad_clip_norm = local_grad_clip_norm + self._num_trainers = num_trainers + self._clip_norm = local_grad_clip_norm / (num_trainers * + num_trainers) + + super(DGCMomentumOptimizer, self).__init__( + learning_rate, momentum, use_nesterov, regularization, name) + + core.init_dgc() + + def _add_auto_increment_var(self, counter_name, begin, step=1): + helper = LayerHelper('global_step_counter') + counter, is_new_var = helper.create_or_get_global_variable( + name=counter_name, dtype='float32', shape=[1], persistable=True) + if is_new_var: + helper.set_variable_initializer( + counter, + initializer=Constant( + value=float(begin - 1), force_cpu=True)) + helper.main_program.global_block()._prepend_op( + type='increment', + inputs={'X': [counter]}, + outputs={'Out': [counter]}, + attrs={'step': float(step)}, + stop_gradient=True) + counter.stop_gradient = True + + return counter + + def _append_dgc_ops(self, param_and_grads): + start_program = default_startup_program() + main_program = default_main_program() + main_program._enable_dgc = True + + # step counter + self._global_step_var = self._add_auto_increment_var( + counter_name='__g_dgc_counter__', begin=0) + + # rampup begin step var for all_reduce_op_handle + self._rampup_begin_step_var = tensor.create_global_var( + shape=[1], + dtype=core.VarDesc.VarType.FP32, + persistable=True, + name='__g_rampup_begin_step__', + value=self._rampup_begin_step * 1.0, + force_cpu=True) + + for param_var, grad_var in param_and_grads: + var_numel = reduce(lambda x, y: x * y, param_var.shape) + if var_numel < 16384 or \ + param_var.type == core.VarDesc.VarType.SELECTED_ROWS or \ + grad_var.type == core.VarDesc.VarType.SELECTED_ROWS or \ + param_var.dtype != core.VarDesc.VarType.FP32 : + continue + + u_var = tensor.create_global_var( + shape=param_var.shape, + dtype=param_var.dtype, + persistable=True, + name=param_var.name + "__dgc_u__", + value=0.0) + v_var = tensor.create_global_var( + shape=param_var.shape, + dtype=param_var.dtype, + persistable=True, + name=param_var.name + "__dgc_v__", + value=0.0) + + k_var = tensor.create_global_var( + shape=[1], + dtype=param_var.dtype, + persistable=True, + name=param_var.name + "__dgc_k__", + value=0.0, + force_cpu=True) + + encoded_var = tensor.create_global_var( + shape=[1], + dtype=param_var.dtype, + persistable=True, + name=param_var.name + "__dgc_encoded__", + value=0.0, + force_cpu=False) + + # del back oprolevarname + op_maker = core.op_proto_and_checker_maker + backward = core.op_proto_and_checker_maker.OpRole.Backward + for op in main_program.global_block().ops: + if not self._is_the_backward_op(op): + continue + + var_attr = op.all_attrs()[op_maker.kOpRoleVarAttrName()] + if param_var.name not in var_attr: + continue + + var_attr.remove(param_var.name) + var_attr.remove(grad_var.name) + if len(var_attr) > 1: + op._set_attr(op_maker.kOpRoleVarAttrName(), var_attr) + else: + op._remove_attr(op_maker.kOpRoleVarAttrName()) + + clip_var = grad_var + if self._local_grad_clip_norm is not None: + clip_var = self._append_clip_norm(grad_var, self._clip_norm) + self._dgc_op(param_var, clip_var, grad_var, u_var, v_var, k_var, + encoded_var) + + def _is_the_backward_op(self, op): + op_maker = core.op_proto_and_checker_maker + backward = core.op_proto_and_checker_maker.OpRole.Backward + if op_maker.kOpRoleVarAttrName() in op.attr_names and \ + int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward): + return True + return False + + def _clip_by_norm(self, x, max_norm, name=None): + args = {'x': x, 'max_norm': max_norm, 'name': name} + + helper = LayerHelper("dgc_clip_by_norm_op", **args) + + if name is None: + name = unique_name.generate(".".join([helper.name, 'tmp'])) + + out = helper.create_variable( + type=x.type, name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type="clip_by_norm", + inputs={"X": x, + "current_step": self._global_step_var}, + attrs={ + "max_norm": max_norm, + "rampup_begin_step": float(self._rampup_begin_step) + }, + outputs={"Out": out}) + return out + + def _append_clip_norm(self, grad_var, clip_norm): + with grad_var.block.program._backward_role_guard(): + return self._clip_by_norm( + x=grad_var, max_norm=clip_norm, name=grad_var.name + "@DGC") + + def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var, + encoded_var): + block = framework.default_main_program().global_block() + op_maker = core.op_proto_and_checker_maker + dgc_op = block.append_op( + type="dgc", + inputs={ + "U": u_var, + "V": v_var, + "Grad": clip_var, + "current_step": self._global_step_var + }, + outputs={ + "U_out": u_var, + "V_out": v_var, + "EncodeGrad": encoded_var, + "k": k_var, + "Grad_out": grad_var + }, + attrs={ + "m": self._momentum, + "sparsity": self._sparsity, + "use_nesterov": self._use_nesterov, + "rampup_begin_step": float(self._rampup_begin_step), + "rampup_step": float(self._rampup_step) + }, + stop_gradient=True) + + backward = op_maker.OpRole.Backward + dgc_op._set_attr(op_maker.kOpRoleAttrName(), backward) + dgc_op._set_attr(op_maker.kOpRoleVarAttrName(), + [param_var.name, grad_var.name]) + + class LarsMomentumOptimizer(Optimizer): """ Momentum optimizer with LARS support diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 6702fc808b1..6b88e7a99fd 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -103,6 +103,12 @@ class ParallelExecutor(object): ) if use_cuda else framework.cpu_places() self._scope = scope if scope is not None else executor.global_scope() + if main_program is not None and main_program._enable_dgc: + assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce + assert num_trainers * len( + self._places) > 1, "dgc is not useful for single card training" + assert use_cuda + main_program = main_program if main_program is not None \ else framework.default_main_program() diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index cefa2b49197..d139feac6ff 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -70,6 +70,7 @@ list(REMOVE_ITEM TEST_OPS test_dist_transpiler) list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf) list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed) list(REMOVE_ITEM TEST_OPS test_dist_se_resnext) +list(REMOVE_ITEM TEST_OPS test_dgc_op) list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_nccl) list(REMOVE_ITEM TEST_OPS test_dist_transformer) list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer) @@ -97,6 +98,7 @@ if(WITH_DISTRIBUTE) set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext) + py_test_modules(test_dgc_op MODULES test_dgc_op) set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000) py_test_modules(test_dist_se_resnext_nccl MODULES test_dist_se_resnext_nccl) set_tests_properties(test_dist_se_resnext_nccl PROPERTIES TIMEOUT 1000) @@ -107,16 +109,20 @@ if(WITH_DISTRIBUTE) endif(NOT APPLE) # py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) endif() + py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) + if(NOT WIN32) -py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer SERIAL) + py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer SERIAL) endif() + if(NOT APPLE) py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL) endif() + if(CMAKE_BUILD_TYPE STREQUAL "Debug") # change the timeout from 600 to 2200, because in debug mode, this test need more time. set_tests_properties(test_parallel_executor_seresnext PROPERTIES TIMEOUT 2200) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 1c45a10a9dd..c598260e13c 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -73,7 +73,7 @@ def cnn_model(data): class TestDistMnist2x2(TestDistRunnerBase): - def get_model(self, batch_size=2): + def get_model(self, batch_size=2, use_dgc=False): # Input data images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) label = fluid.layers.data(name='label', shape=[1], dtype='int64') @@ -93,7 +93,11 @@ class TestDistMnist2x2(TestDistRunnerBase): # TODO(typhoonzero): fix distributed adam optimizer # opt = fluid.optimizer.AdamOptimizer( # learning_rate=0.001, beta1=0.9, beta2=0.999) - opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9) + if not use_dgc: + opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9) + else: + opt = fluid.optimizer.DGCMomentumOptimizer( + learning_rate=self.lr, momentum=0.9, rampup_begin_step=0) # Reader train_reader = paddle.batch( diff --git a/python/paddle/fluid/tests/unittests/dist_se_resnext.py b/python/paddle/fluid/tests/unittests/dist_se_resnext.py index c3d84dba0ae..a2fd61e2387 100644 --- a/python/paddle/fluid/tests/unittests/dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/dist_se_resnext.py @@ -210,7 +210,7 @@ class SE_ResNeXt(): class DistSeResneXt2x2(TestDistRunnerBase): - def get_model(self, batch_size=2): + def get_model(self, batch_size=2, use_dgc=False): # Input data image = fluid.layers.data( name="data", shape=[3, 224, 224], dtype='float32') @@ -237,11 +237,19 @@ class DistSeResneXt2x2(TestDistRunnerBase): base_lr = 0.1 lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] - optimizer = fluid.optimizer.Momentum( - learning_rate=fluid.layers.piecewise_decay( - boundaries=bd, values=lr), - momentum=0.9, - regularization=fluid.regularizer.L2Decay(1e-4)) + if not use_dgc: + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + else: + optimizer = fluid.optimizer.DGCMomentumOptimizer( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + momentum=0.9, + rampup_begin_step=0, + regularization=fluid.regularizer.L2Decay(1e-4)) optimizer.minimize(avg_cost) # Reader diff --git a/python/paddle/fluid/tests/unittests/test_dgc_op.py b/python/paddle/fluid/tests/unittests/test_dgc_op.py new file mode 100644 index 00000000000..04766dd8584 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dgc_op.py @@ -0,0 +1,138 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid + +g_array_size = 102400 + + +class TestDGCOp(unittest.TestCase): + def setup(self, place, array_size=g_array_size): + size = array_size + np.random.seed(5) # fix seed + + self.scope = fluid.global_scope() + self.place = place + print("place:", place) + + # numpy data + # inputs: U, V, Grad, current_step + self.u_name = "U" + self.u = np.random.random(size).astype("float32") + + self.v_name = "V" + self.v = np.random.random(size).astype("float32") + + self.grad_name = "Grad" + self.grad = np.random.random(size).astype("float32") + + self.current_step_name = "current_step" + self.current_step = np.full((1), 0.0).astype("float32") + + # output: U_out, V_out, EncodeGrad, GradLocal_out + self.encode_grad_name = "EncodeGrad" + self.k_name = "k" + self.k = np.full((1), 0.0).astype("float32") + + # scope data + self.u_tensor = self.scope.var(self.u_name).get_tensor() + self.u_tensor.set(self.u, place) + + self.v_tensor = self.scope.var(self.v_name).get_tensor() + self.v_tensor.set(self.v, place) + + self.grad_tensor = self.scope.var(self.grad_name).get_tensor() + self.grad_tensor.set(self.grad, place) + + self.encode_grad_tensor = self.scope.var( + self.encode_grad_name).get_tensor() + + self.current_step_tensor = self.scope.var( + self.current_step_name).get_tensor() + self.current_step_tensor.set(self.current_step, core.CPUPlace()) + + self.k_tensor = self.scope.var(self.k_name).get_tensor() + self.k_tensor.set(self.k, core.CPUPlace()) + + def check(self, actual_t, expect_t, place, out_name, atol=1e-5): + self.assertTrue( + np.allclose( + actual_t, expect_t, atol=atol), + "Output (" + out_name + ") has diff at " + str(place) + "\nExpect " + + str(expect_t) + "\n" + "But Got" + str(actual_t)) + + def test_run_and_check(self): + self.setup(place=core.CUDAPlace(0)) + kwargs = { + # inputs + 'U': self.u_name, + 'V': self.v_name, + 'Grad': self.grad_name, + 'current_step': self.current_step_name, + + # outputs + 'U_out': self.u_name, + 'V_out': self.v_name, + 'EncodeGrad': self.encode_grad_name, + 'Grad_out': self.grad_name, + 'k': self.k_name, + + # attrs + 'm': 0.9, + 'sparsity': [0.75, 0.9375, 0.984375, 0.996, 0.999], + 'use_nesterov': True, + 'rampup_begin_step': float(0.0), + 'rampup_step': float(10.0), + } + + dgc_op = Operator('dgc', **kwargs) + + #atol = 1e-6 + dgc_op.run(self.scope, self.place) + + u_out = np.array(self.u_tensor) + v_out = np.array(self.v_tensor) + grad_out = np.array(self.grad_tensor) + encode_grad_out = np.array(self.encode_grad_tensor) + k = int(np.array(self.k_tensor)[0]) + + print("u_out:", u_out[0:20]) + print("v_out:", v_out[0:20]) + print("encode_grad_out:", encode_grad_out) + print("k_out:", k) + + self.assertEqual(k, int(g_array_size * 0.25)) + + index = encode_grad_out[0:k].view(dtype=np.int32) + value = encode_grad_out[k:2 * k] + + acl = 1e-7 + + for i in range(0, k): + self.assertAlmostEqual(u_out[index[i]], 0.0) + self.assertAlmostEqual(v_out[index[i]], 0.0) + + a_min = np.amin(value) + dangling = [x for x in v_out if x > a_min] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 969f5cb63c9..9c0efe6d905 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -36,7 +36,8 @@ class TestDistRunnerBase(object): def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1, - single_device=False): + single_device=False, + use_dgc=False): raise NotImplementedError( "get_model should be implemented by child classes.") @@ -82,6 +83,9 @@ class TestDistRunnerBase(object): if args.nccl2_reduce_layer_local_run: test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ self.get_model(batch_size=args.batch_size, single_device=True) + elif args.use_dgc: + test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ + self.get_model(batch_size=args.batch_size, use_dgc=args.use_dgc) else: test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ self.get_model(batch_size=args.batch_size) @@ -200,6 +204,7 @@ def runtime_main(test_class): parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--mem_opt', action='store_true') parser.add_argument('--use_cuda', action='store_true') + parser.add_argument('--use_dgc', action='store_true') parser.add_argument('--use_reduce', action='store_true') parser.add_argument('--dc_asgd', action='store_true') parser.add_argument( @@ -235,6 +240,7 @@ class TestDistBase(unittest.TestCase): def _after_setup_config(self): if self._enforce_place == "CPU": self.__use_cuda = False + self._use_dgc = False elif self._enforce_place == "GPU": self.__use_cuda = True else: @@ -242,6 +248,10 @@ class TestDistBase(unittest.TestCase): self.__use_cuda = True else: self.__use_cuda = False + self._use_dgc = False + + if self._use_reduce: + assert not self._use_dgc def setUp(self): self._trainers = 2 @@ -264,6 +274,7 @@ class TestDistBase(unittest.TestCase): # test, reduce check this argument everywhere. self._nccl2_reduce_layer = False self._lr = 0.001 + self._use_dgc = False self._setup_config() self._after_setup_config() @@ -506,6 +517,9 @@ class TestDistBase(unittest.TestCase): env0 = {'CPU_NUM': '1'} env1 = {'CPU_NUM': '1'} + if self._use_dgc: + tr0_cmd += " --use_dgc" + tr1_cmd += " --use_dgc" if self._mp_mode: env0 = {"FLAGS_selected_gpus": "0"} env1 = {"FLAGS_selected_gpus": "1"} diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 030860ec792..b9d2f6db394 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -39,6 +39,20 @@ class TestDistMnistNCCL2(TestDistBase): self.check_with_place("dist_mnist.py", delta=1e-5) +class TestDistMnistNCCL2DGC(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl2_mode = True + self._use_dgc = True + + def test_dist_train(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place("dist_mnist.py", delta=1e-5) + + class TestDistMnist2x2Lars(TestDistBase): def _setup_config(self): self._sync_mode = True diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index 28602d3251a..4e9ca01f43e 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -60,5 +60,20 @@ class TestDistSeResneXt2x2Async(TestDistBase): self.check_with_place("dist_se_resnext.py", delta=100) +class TestDistSeResnetNCCL2DGC(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl2_mode = True + self._use_dgc = True + + @skip_ci + def test_dist_train(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place("dist_se_resnext.py", delta=30) + + if __name__ == "__main__": unittest.main() -- GitLab