diff --git a/doc/fluid/design/dist_train/async_update.md b/doc/fluid/design/dist_train/async_update.md index 6a0835b761b69030ba30697e6e8863928efbf57f..248d2ec18dafdecac9184527638754b6ba4d85b8 100644 --- a/doc/fluid/design/dist_train/async_update.md +++ b/doc/fluid/design/dist_train/async_update.md @@ -4,34 +4,37 @@ For the typical synchronous distributed training, some significant steps are as follows: -1. A Trainer will compute the gradients and SEND them to the Parameter Server(PServer) nodes. -1. After the PServer node received gradients came from all the Trainers, It will aggregate the +1. A trainer process will compute the gradients and **send** them to the parameter server (PS) nodes. +1. After the PS node received gradients came from all the Trainers, It will aggregate the gradient variables for the same parameter into one gradient variable and then apply the aggregated gradient to the respective parameter, finally using an optimize algorithms(SGD, Monument...) to update the parameters. -1. The Trainer would wait for the PServers finished the optimize stage, and GET the parameters from PServer, +1. The Trainer would wait for the PS finished the optimize stage, and GET the parameters from PS, so all the Trainers would get the same parameters. -In the synchronously distributed training, there should be a `Barrier` to synchronise the -parameters after the optimizing stage. The performance of a distributed training job would -depend on the slowest node if there were hundreds or thousands of training nodes in a -Job, the performance of synchronously distributed training might be very poor because of -the slow node. So this design doc would introduce an approach to implement -*asynchronously* distributed training in PaddlePaddle Fluid. +In Synchronous Distributed Training, there is a **barrier** on each PS to wait until all trainers processes +have completed running current mini-batch. After that, all trainers can continue to run the next +mini-batch. So, we can find that the overall performance of Synchronous Distributed Training depends +on the slowest node. + +In Asynchronous Distributed Training, we don't need to wait for a global mini-bach, the optimizer on +the PS will run immediately when the gradient is uploaded to the PS from one trainer. This mode would +train such models that achieve scaling, better throughput. In this design doc, we will introduce how to +implement the Asynchronous Distributed Training base on PaddlePaddle Fluid. ## Design -As the figure above, we describe a global view of asynchronously update process and use +As the figure above, we describe a global view of the asynchronous update process and use the parameter `w1` as an example to introduce the steps: 1. For each gradient variables, they may distribute on different GPU card and aggregate them while they are all calculated. -1. Split the gradient variable into multiple blocks according to the number of PServer +1. Split the gradient variable into multiple blocks according to the number of PS instances and then send them. -1. PServer would run an `Optimize Block` using a specified optimize algorithm to update +1. PS would run an `Optimize Block` using a specified optimize algorithm to update the specified parameter. -1. The trainer will fetch latest parameter from PServer before running forward Op which depends +1. The trainer will fetch the latest parameter from PS before running forward Op which depends on the specified parameter. 1. Broadcast the received variable into multiple GPU cards and continue to run the next mini-batch. @@ -40,8 +43,8 @@ mini-batch. - For the multiple devices distributed training, we need to aggregate the gradient variables which placed on different devices firstly and then schedule a `SendVars` Operator to -send the gradient variables to the multiple PServer instances. -- Schedule `FetchVars` operator to fetch the latest parameter from PServer before running +send the gradient variables to the multiple PS instances. +- Schedule `FetchVars` operator to fetch the latest parameter from PS before running the forward ops. - There could be a large number of gradient variables to be sent, so we need to use another thread pool(IO Threadpool) whose a number of the schedulable threads is larger than the diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index ab71e0e63ce18e4f221a046eeb2c39499c1c3816..ed1e70c6460b513c1d2e1add18ac037f71d36944 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -5,11 +5,11 @@ proto_library(framework_proto SRCS framework.proto) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) - +cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context) if(WITH_GPU) - nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place memory device_context framework_proto) + nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type) else() - cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place memory device_context framework_proto) + cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type) endif() cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9c90cb0c32f337ba82ce1eaa5b43199540491ef --- /dev/null +++ b/paddle/fluid/framework/data_type.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/data_type.h" +#include +#include +#include + +namespace paddle { +namespace framework { + +struct DataTypeMap { + std::unordered_map cpp_to_proto_; + std::unordered_map proto_to_cpp_; + std::unordered_map proto_to_str_; + std::unordered_map cpp_to_size_; +}; + +static DataTypeMap* InitDataTypeMap(); +static DataTypeMap& gDataTypeMap() { + static DataTypeMap* g_data_type_map_ = InitDataTypeMap(); + return *g_data_type_map_; +} + +template +static inline void RegisterType(DataTypeMap* map, + proto::VarType::Type proto_type, + const std::string& name) { + map->proto_to_cpp_.emplace(static_cast(proto_type), typeid(T)); + map->cpp_to_proto_.emplace(typeid(T), proto_type); + map->proto_to_str_.emplace(static_cast(proto_type), name); + map->cpp_to_size_.emplace(typeid(T), sizeof(T)); +} + +static DataTypeMap* InitDataTypeMap() { + auto retv = new DataTypeMap(); + +#define RegType(cc_type, proto_type) \ + RegisterType(retv, proto_type, #cc_type) + + // NOTE: Add your customize type here. + RegType(platform::float16, proto::VarType::FP16); + RegType(float, proto::VarType::FP32); + RegType(double, proto::VarType::FP64); + RegType(int, proto::VarType::INT32); + RegType(int64_t, proto::VarType::INT64); + RegType(bool, proto::VarType::BOOL); + RegType(size_t, proto::VarType::SIZE_T); + RegType(int16_t, proto::VarType::INT16); + +#undef RegType + return retv; +} + +proto::VarType::Type ToDataType(std::type_index type) { + auto it = gDataTypeMap().cpp_to_proto_.find(type); + if (it != gDataTypeMap().cpp_to_proto_.end()) { + return it->second; + } + PADDLE_THROW("Not support %s as tensor type", type.name()); +} + +std::type_index ToTypeIndex(proto::VarType::Type type) { + auto it = gDataTypeMap().proto_to_cpp_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_cpp_.end()) { + return it->second; + } + PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", + static_cast(type)); +} + +std::string DataTypeToString(const proto::VarType::Type type) { + auto it = gDataTypeMap().proto_to_str_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_str_.end()) { + return it->second; + } + PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type", + static_cast(type)); +} + +size_t SizeOfType(std::type_index type) { + auto it = gDataTypeMap().cpp_to_size_.find(type); + if (it != gDataTypeMap().cpp_to_size_.end()) { + return it->second; + } + PADDLE_THROW("Not support %s as tensor type", type.name()); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 2a528eb3aa562568c92059250f2c9bc5a75ec103..4b9f572ec5f1cda71c8b8dd8fae54b42e9f16f7a 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -17,51 +17,14 @@ limitations under the License. */ #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/platform/enforce.h" + #include "paddle/fluid/platform/float16.h" namespace paddle { namespace framework { -inline proto::VarType::Type ToDataType(std::type_index type) { - if (typeid(platform::float16).hash_code() == type.hash_code()) { - return proto::VarType::FP16; - } else if (typeid(const float).hash_code() == type.hash_code()) { - // CPPLint complains Using C-style cast. Use static_cast() instead - // One fix to this is to replace float with const float because - // typeid(T) == typeid(const T) - // http://en.cppreference.com/w/cpp/language/typeid - return proto::VarType::FP32; - } else if (typeid(const double).hash_code() == type.hash_code()) { - return proto::VarType::FP64; - } else if (typeid(const int).hash_code() == type.hash_code()) { - return proto::VarType::INT32; - } else if (typeid(const int64_t).hash_code() == type.hash_code()) { - return proto::VarType::INT64; - } else if (typeid(const bool).hash_code() == type.hash_code()) { - return proto::VarType::BOOL; - } else { - PADDLE_THROW("Not supported"); - } -} - -inline std::type_index ToTypeIndex(proto::VarType::Type type) { - switch (type) { - case proto::VarType::FP16: - return typeid(platform::float16); - case proto::VarType::FP32: - return typeid(float); - case proto::VarType::FP64: - return typeid(double); - case proto::VarType::INT32: - return typeid(int); - case proto::VarType::INT64: - return typeid(int64_t); - case proto::VarType::BOOL: - return typeid(bool); - default: - PADDLE_THROW("Not support type %d", type); - } -} +extern proto::VarType::Type ToDataType(std::type_index type); +extern std::type_index ToTypeIndex(proto::VarType::Type type); template inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { @@ -89,32 +52,12 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { } } -inline std::string DataTypeToString(const proto::VarType::Type type) { - switch (type) { - case proto::VarType::FP16: - return "float16"; - case proto::VarType::FP32: - return "float32"; - case proto::VarType::FP64: - return "float64"; - case proto::VarType::INT16: - return "int16"; - case proto::VarType::INT32: - return "int32"; - case proto::VarType::INT64: - return "int64"; - case proto::VarType::BOOL: - return "bool"; - default: - PADDLE_THROW("Not support type %d", type); - } -} - +extern std::string DataTypeToString(const proto::VarType::Type type); +extern size_t SizeOfType(std::type_index type); inline std::ostream& operator<<(std::ostream& out, const proto::VarType::Type& type) { out << DataTypeToString(type); return out; } - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 96f53dc1bc8747e1b8ea84166614f98ff363ae5e..d2558f111f49139b33f921f7260b41830279edc8 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -101,6 +101,8 @@ message VarType { FP16 = 4; FP32 = 5; FP64 = 6; + // Tensor is used in C++. + SIZE_T = 19; // Other types that may need additional descriptions LOD_TENSOR = 7; diff --git a/paddle/fluid/framework/op_kernel_type_test.cc b/paddle/fluid/framework/op_kernel_type_test.cc index d37ce149ce3df63692b41289bb03448d54e392f5..db95861c510b52a5b52229541434e6437d3fb9f4 100644 --- a/paddle/fluid/framework/op_kernel_type_test.cc +++ b/paddle/fluid/framework/op_kernel_type_test.cc @@ -27,7 +27,7 @@ TEST(OpKernelType, ToString) { LibraryType::kCUDNN); ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type), - "data_type[float32]:data_layout[NCHW]:place[CPUPlace]:library_type[" + "data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type[" "CUDNN]"); } diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index f49d1a47a325b2aac6185073203df124be18b54d..0a1db7758bd9ec0dac133efcbf495de1d690021d 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -13,54 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" namespace paddle { namespace framework { - -template -struct SizeOfTypeFunctor; - -template -struct SizeOfTypeFunctor { - size_t operator()(std::type_index type) const { - if (typeid(T).hash_code() == type.hash_code()) { - return sizeof(T); - } else { - return 0UL; - } - } -}; - -template <> -struct SizeOfTypeFunctor<> { - size_t operator()(std::type_index type) const { return 0UL; } -}; - -template -struct SizeOfTypeFunctor { - size_t operator()(std::type_index type) const { - SizeOfTypeFunctor head; - size_t head_size = head(type); - if (head_size != 0) { - return head_size; - } - SizeOfTypeFunctor tail; - return tail(type); - } -}; - -static inline size_t SizeOfType(std::type_index type) { - SizeOfTypeFunctor - functor; - size_t size = functor(type); - PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); - return size; -} - +extern size_t SizeOfType(std::type_index type); inline void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index e6ee28ea8d920ef80fead258a9bd0d5f6762c879..d09f8479b765ad26cc202bfdb2692828213c7956 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -306,7 +306,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { } RequestPrefetch* prefetch = new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_, - dev_ctx_, executor_, program_, prefetch_ctx_); + dev_ctx_, executor_, program_, prefetch_ctx_.get()); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 18f1bc53d0f561f412a5bbbe018bc3d427ac9ef9..238aaa29634a7eff65429c27aa3538a185723eb2 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -64,8 +64,9 @@ class AsyncGRPCServer final { void SetExecutor(framework::Executor *executor) { executor_ = executor; } - void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) { - prefetch_ctx_ = prepared; + void SetPrefetchPreparedCtx( + std::unique_ptr prepared) { + prefetch_ctx_.reset(prepared.release()); } int GetSelectedPort() const { return selected_port_; } @@ -116,7 +117,7 @@ class AsyncGRPCServer final { std::unique_ptr t_get_; std::unique_ptr t_prefetch_; - framework::ExecutorPrepareContext *prefetch_ctx_; + std::unique_ptr prefetch_ctx_; framework::ProgramDesc *program_; framework::Executor *executor_; int selected_port_; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index 25b95d608d10d6e456d5f563ce9fbe35d812cb0f..b8db0ad987cdfaec1fc9236c3f26e88891376dce 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -100,7 +100,7 @@ void StartServer(const std::string& endpoint) { InitTensorsOnServer(&scope, &place, 10); rpc_service_->SetProgram(&program); - rpc_service_->SetPrefetchPreparedCtx(prepared.get()); + rpc_service_->SetPrefetchPreparedCtx(std::move(prepared)); rpc_service_->SetDevCtx(&ctx); rpc_service_->SetScope(&scope); rpc_service_->SetExecutor(&exe); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index a29e0cd52cfccf242a6490822234045e6eb66c0f..abc88d3eb1514e159f4a880f44ecc0c0960a73d9 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -322,8 +322,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, // prepare for prefetch VLOG(3) << "prefetch block id is " << prefetch_block->ID(); auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); - rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get()); - prefetch_prepared.release(); + rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared)); // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 7af6ed1463ab737e871da487f2a687301652ef2d..32b1b65bd97ef1e512a5880843509611b606f52d 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -480,6 +480,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, program.current_block_idx = current_block_idx program.sync_with_cpp() + # FIXME(zcd): prevent loss.grad optimized by mem_opt. + loss.block.var(_append_grad_suffix_(loss.name)).persistable = True if parameter_list is not None: parameters = parameter_list diff --git a/python/paddle/fluid/tests/unittests/test_split_var.py b/python/paddle/fluid/tests/unittests/test_split_var.py index 79d387f0066672058d1640f4e5fd28ed8913fe4c..0c5e8901b903375c7d4de32943e657b205d8fae9 100644 --- a/python/paddle/fluid/tests/unittests/test_split_var.py +++ b/python/paddle/fluid/tests/unittests/test_split_var.py @@ -21,15 +21,7 @@ import random class TestSplitVar(unittest.TestCase): - def test_check_output(self): - # split below shapes to 10 servers - shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10]] - expected_sizes = [ - [15], [1024], - [2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 784], - [2040, 2040, 2040, 2040], - [1150, 1150, 1150, 1150, 1150, 1150, 1100] - ] + def check_split_output(self, shapes, expected_sizes, min_size): var_list = [] program = fluid.Program() for shape in shapes: @@ -39,7 +31,7 @@ class TestSplitVar(unittest.TestCase): # dtype=core.VarDesc.VarType.LOD_TENSOR, shape=shape) var_list.append(var) - blocks = split_dense_variable(var_list, 10) + blocks = split_dense_variable(var_list, 10, min_size) all_sizes = [] for s in expected_sizes: for s2 in s: @@ -48,6 +40,25 @@ class TestSplitVar(unittest.TestCase): varname, block_id, size = block_str.split(":") self.assertEqual(int(size), all_sizes[i]) + def test_1k(self): + shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10]] + expected_sizes = [ + [15], [1024], + [2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 784], + [2040, 2040, 2040, 2040], + [1150, 1150, 1150, 1150, 1150, 1150, 1100] + ] + + self.check_split_output(shapes, expected_sizes, 1024) + + def test_check_output_8k(self): + shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10], + [6, 33, 33, 33]] + expected_sizes = [[15], [1024], [10976, 10976], [8160], [8000], + [35937, 35937, 35937, 35937, 35937, 35937]] + + self.check_split_output(shapes, expected_sizes, 8192) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index a323f8d03613e7c4149812daab4ccb57fb940a7e..42ff0a9eb1112ed5709749e3867794c80be8f1d1 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -93,30 +93,33 @@ def same_or_split_var(p_name, var_name): return p_name == var_name or p_name.startswith(var_name + ".block") -def split_dense_variable(var_list, - pserver_count, - min_block_size=1024, - max_block_size=1048576): +def split_dense_variable(var_list, service_count, min_block_size=8192): """ - We may need to split dense tensor to one or more blocks and put - them equally onto parameter server. One block is a sub-tensor - aligned by dim[0] of the tensor. - - We need to have a minimal block size so that the calculations in - the parameter server side can gain better performance. By default - minimum block size is 1024. The max block size is used to prevent - very large blocks that may cause send error. - :return: A list of VarBlocks. Each VarBlock specifies a shard of - the var. + We may need to split dense tensor to one or more blocks and put + them equally onto parameter server. One block is a sub-tensor + aligned by dim[0] of the tensor. + + We need to have a minimal block size so that the calculations in + the parameter server side can gain better performance. By default + minimum block size 8K elements (maybe 16bit or 32bit or 64bit). + + Args: + var_list (list): List of variables. + service_count (int): Numel of pserver services. A pserver may have two + or more listening ports. + min_block_size (int): Minimum splitted block size. + Returns: + blocks (list[(varname, block_id, current_block_size)]): A list + of VarBlocks. Each VarBlock specifies a shard of the var. """ blocks = [] for var in var_list: - split_count = pserver_count + split_count = service_count var_numel = reduce(lambda x, y: x * y, var.shape) max_pserver_count = int(math.floor(var_numel / float(min_block_size))) if max_pserver_count == 0: max_pserver_count = 1 - if max_pserver_count < pserver_count: + if max_pserver_count < service_count: split_count = max_pserver_count block_size = int(math.ceil(var_numel / float(split_count))) @@ -270,6 +273,7 @@ class DistributeTranspiler: grad_var_mapping = self._append_split_op(program, grad_blocks) param_var_mapping = self._create_vars_from_blocklist(program, param_blocks) + # step3: Add gradients as send op inputs and parameters as send # op outputs. send_inputs = [] @@ -277,9 +281,11 @@ class DistributeTranspiler: for b in grad_blocks: # append by order varname, block_id, _ = b.split(":") send_inputs.append(grad_var_mapping[varname][int(block_id)]) + for b in param_blocks: varname, block_id, _ = b.split(":") send_outputs.append(param_var_mapping[varname][int(block_id)]) + # let send_op know which endpoint to send which var to, eplist has the same # order as send_inputs. eplist = split_method(send_inputs, pserver_endpoints) @@ -751,9 +757,18 @@ class DistributeTranspiler: Create vars for each split. NOTE: only grads need to be named for different trainers, use add_trainer_suffix to rename the grad vars. - :return: A dict mapping from original var name to each var split. + Args: + program (ProgramDesc): ProgramDesc which gradients blong. + block_list (list[(varname, block_id, block_size)]): List of gradient blocks. + add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True. + Returns: + var_mapping (dict(varname->[new_varname_variable])):A dict mapping + from original var name to each var split. """ + + # varname->[(block_id, current_block_size)] block_map = dict() + var_mapping = dict() for block_str in block_list: varname, offset, size = block_str.split(":") @@ -824,7 +839,16 @@ class DistributeTranspiler: persistable=persistable) def _append_split_op(self, program, gradblocks): - # Split variables that need to be split and append respective ops + """ + Split variables that need to be split and append respective ops + Args: + program (ProgramDesc): ProgramDesc that gradients blong. + gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks. + Returns: + var_mapping (dict(varname->[new_splitted_variable])):A dict mapping + from original var name to each var split. + """ + add_suffix = False if self.trainer_num > 1: add_suffix = True @@ -1148,6 +1172,12 @@ class DistributeTranspiler: return lr_ops def _get_optimize_pass(self): + """ + Get optimizer operators, paramters and gradients from origin_program + Returns: + opt_ops (list): optimize operators. + params_grads (dict): paramter->gradient. + """ block = self.origin_program.global_block() opt_ops = [] params_grads = []