From 4ead9a5a3c936d045ffa400536ec348e81bcaea2 Mon Sep 17 00:00:00 2001 From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com> Date: Wed, 28 Apr 2021 15:02:33 +0800 Subject: [PATCH] [PsCore] solve Brpc dep (#32632) * Revert "Revert "[PsCore] optimize performance of large kv (#32535)" (#32599)" This reverts commit 809ac03656712744d6dea7a6268aeeea46b6f12e. * brpc dep --- CMakeLists.txt | 5 + paddle/fluid/distributed/CMakeLists.txt | 2 +- .../distributed/service/brpc_ps_server.cc | 23 +-- paddle/fluid/distributed/table/CMakeLists.txt | 6 +- .../distributed/table/common_sparse_table.cc | 55 +++--- .../table/depends/large_scale_kv.h | 158 ++++++++++-------- paddle/fluid/distributed/test/CMakeLists.txt | 6 +- paddle/fluid/framework/CMakeLists.txt | 10 +- .../framework/fleet/heter_ps/CMakeLists.txt | 10 +- paddle/fluid/framework/trainer.h | 1 - .../distributed/fleet/runtime/the_one_ps.py | 45 +++-- .../distributed_strategy.py | 1 + .../fleet/parameter_server/ir/public.py | 1 + 13 files changed, 197 insertions(+), 126 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2f16c390d8..f30671bd3a 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -353,6 +353,11 @@ if (WITH_MIPS) add_definitions(-DPADDLE_WITH_MIPS) endif() +if (WITH_HETERPS) + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new") + endif() +endif() set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG") diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt index a2062d82c8..905347d031 100644 --- a/paddle/fluid/distributed/CMakeLists.txt +++ b/paddle/fluid/distributed/CMakeLists.txt @@ -11,8 +11,8 @@ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") endif() -add_subdirectory(table) add_subdirectory(service) +add_subdirectory(table) add_subdirectory(test) add_subdirectory(index_dataset) diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc index a9370561a5..a1440260bf 100644 --- a/paddle/fluid/distributed/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/service/brpc_ps_server.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/distributed/service/brpc_ps_server.h" #include // NOLINT +#include "butil/object_pool.h" #include "paddle/fluid/distributed/table/depends/sparse_utils.h" #include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/framework/archive.h" @@ -196,12 +197,13 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, return 0; } - std::vector res_data; - res_data.resize(num * table->value_accesor()->select_size() / sizeof(float)); - table->pull_dense(res_data.data(), num); + auto res_data = butil::get_object>(); + res_data->resize(num * table->value_accesor()->select_size() / sizeof(float)); + table->pull_dense(res_data->data(), num); - cntl->response_attachment().append((char *)res_data.data(), - res_data.size() * sizeof(float)); + cntl->response_attachment().append((char *)(res_data->data()), + res_data->size() * sizeof(float)); + butil::return_object(res_data); return 0; } @@ -367,12 +369,13 @@ int32_t BrpcPsService::pull_sparse(Table *table, value.DeserializeFromBytes(const_cast(data)); - std::vector res_data; - res_data.resize(num * dim); - table->pull_sparse(res_data.data(), value); + auto res_data = butil::get_object>(); + res_data->resize(num * dim); + table->pull_sparse(res_data->data(), value); - cntl->response_attachment().append((char *)res_data.data(), - res_data.size() * sizeof(float)); + cntl->response_attachment().append((char *)(res_data->data()), + res_data->size() * sizeof(float)); + butil::return_object(res_data); return 0; } diff --git a/paddle/fluid/distributed/table/CMakeLists.txt b/paddle/fluid/distributed/table/CMakeLists.txt index dde1f5ae8e..dab3909580 100644 --- a/paddle/fluid/distributed/table/CMakeLists.txt +++ b/paddle/fluid/distributed/table/CMakeLists.txt @@ -13,7 +13,11 @@ set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTR set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator) +get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + +cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc +sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} +${RPC_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator) set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc index 1c315d34ab..718fce9950 100644 --- a/paddle/fluid/distributed/table/common_sparse_table.cc +++ b/paddle/fluid/distributed/table/common_sparse_table.cc @@ -125,34 +125,37 @@ void ProcessALine(const std::vector& columns, const Meta& meta, int64_t SaveToText(std::ostream* os, std::shared_ptr block, const int mode) { - int64_t not_save_num = 0; - for (auto& value : block->values_) { - if (mode == SaveMode::delta && !value.second.need_save_) { - not_save_num++; - continue; - } - - auto* vs = value.second.data_; - std::stringstream ss; - auto id = value.first; - ss << id << "\t" << value.second.count_ << "\t" << value.second.unseen_days_ - << "\t" << value.second.is_entry_ << "\t"; - - for (int i = 0; i < block->value_length_; i++) { - ss << vs[i]; - ss << ","; - } + int64_t save_num = 0; + for (auto& table : block->values_) { + for (auto& value : table) { + if (mode == SaveMode::delta && !value.second->need_save_) { + continue; + } + save_num += 1; + + auto* vs = value.second->data_.data(); + std::stringstream ss; + auto id = value.first; + ss << id << "\t" << value.second->count_ << "\t" + << value.second->unseen_days_ << "\t" << value.second->is_entry_ + << "\t"; + + for (int i = 0; i < block->value_length_; i++) { + ss << vs[i]; + ss << ","; + } - ss << "\n"; + ss << "\n"; - os->write(ss.str().c_str(), sizeof(char) * ss.str().size()); + os->write(ss.str().c_str(), sizeof(char) * ss.str().size()); - if (mode == SaveMode::base || mode == SaveMode::delta) { - value.second.need_save_ = false; + if (mode == SaveMode::base || mode == SaveMode::delta) { + value.second->need_save_ = false; + } } } - return block->values_.size() - not_save_num; + return save_num; } int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, @@ -183,7 +186,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, block->Init(id, false); - auto value_instant = block->GetValue(id); + VALUE* value_instant = block->GetValue(id); if (values.size() == 5) { value_instant->count_ = std::stoi(values[1]); value_instant->unseen_days_ = std::stoi(values[2]); @@ -373,8 +376,10 @@ std::pair CommonSparseTable::print_table_stat() { int64_t feasign_size = 0; int64_t mf_size = 0; - for (auto& value : shard_values_) { - feasign_size += value->values_.size(); + for (auto& shard : shard_values_) { + for (auto& table : shard->values_) { + feasign_size += table.size(); + } } return {feasign_size, mf_size}; diff --git a/paddle/fluid/distributed/table/depends/large_scale_kv.h b/paddle/fluid/distributed/table/depends/large_scale_kv.h index bb4174bd2c..5c10fca98c 100644 --- a/paddle/fluid/distributed/table/depends/large_scale_kv.h +++ b/paddle/fluid/distributed/table/depends/large_scale_kv.h @@ -26,6 +26,7 @@ #include #include "gflags/gflags.h" +#include "butil/object_pool.h" #include "paddle/fluid/distributed/common/utils.h" #include "paddle/fluid/distributed/table/depends/initializers.h" #include "paddle/fluid/distributed/thirdparty/round_robin.h" @@ -48,6 +49,10 @@ namespace distributed { enum Mode { training, infer }; +static const int SPARSE_SHARD_BUCKET_NUM_BITS = 6; +static const size_t SPARSE_SHARD_BUCKET_NUM = (size_t)1 + << SPARSE_SHARD_BUCKET_NUM_BITS; + struct VALUE { explicit VALUE(size_t length) : length_(length), @@ -55,46 +60,16 @@ struct VALUE { unseen_days_(0), need_save_(false), is_entry_(false) { - data_ = new float[length]; - memset(data_, 0, sizeof(float) * length); - } - - VALUE(const VALUE &value) { - length_ = value.length_; - count_ = value.count_; - unseen_days_ = value.unseen_days_; - need_save_ = value.need_save_; - is_entry_ = value.is_entry_; - data_ = new float[length_]; - memcpy(data_, value.data_, sizeof(float) * length_); - } - - VALUE &operator=(const VALUE &value) { - if (this != &value) { - delete[] data_; - length_ = value.length_; - count_ = value.count_; - unseen_days_ = value.unseen_days_; - need_save_ = value.need_save_; - is_entry_ = value.is_entry_; - - data_ = new float[length_]; - memcpy(data_, value.data_, sizeof(float) * length_); - } - return *this; - } - - ~VALUE() { - delete[] data_; - data_ = nullptr; + data_.resize(length); + memset(data_.data(), 0, sizeof(float) * length); } size_t length_; + std::vector data_; int count_; int unseen_days_; // use to check knock-out bool need_save_; // whether need to save bool is_entry_; // whether knock-in - float *data_; }; inline bool count_entry(VALUE *value, int threshold) { @@ -176,12 +151,12 @@ class ValueBlock { const std::vector &value_dims) { auto pts = std::vector(); pts.reserve(value_names.size()); - auto &values = values_.at(id); + auto values = GetValue(id); for (int i = 0; i < static_cast(value_names.size()); i++) { PADDLE_ENFORCE_EQ( value_dims[i], value_dims_[i], platform::errors::InvalidArgument("value dims is not match")); - pts.push_back(values.data_ + + pts.push_back(values->data_.data() + value_offsets_.at(value_idx_.at(value_names[i]))); } return pts; @@ -190,33 +165,45 @@ class ValueBlock { // pull float *Init(const uint64_t &id, const bool with_update = true, const int counter = 1) { - if (!Has(id)) { - values_.emplace(std::make_pair(id, VALUE(value_length_))); - } + size_t hash = _hasher(id); + size_t bucket = compute_bucket(hash); - auto &value = values_.at(id); + auto &table = values_[bucket]; + auto res = table.find(id); - if (with_update) { - AttrUpdate(&value, counter); + VALUE *value = nullptr; + if (res == table.end()) { + value = butil::get_object(value_length_); + + table[id] = value; + + } else { + value = res->second; } - return value.data_; + if (with_update) { + AttrUpdate(value, counter); + } + return value->data_.data(); } - VALUE *InitGet(const uint64_t &id, const bool with_update = true, const int counter = 1) { - if (!Has(id)) { - values_.emplace(std::make_pair(id, VALUE(value_length_))); - } + size_t hash = _hasher(id); + size_t bucket = compute_bucket(hash); - auto &value = values_.at(id); + auto &table = values_[bucket]; + auto res = table.find(id); - if (with_update) { - AttrUpdate(&value, counter); + VALUE *value = nullptr; + if (res == table.end()) { + value = butil::get_object(value_length_); + // value = _alloc.acquire(value_length_); + table[id] = value; + } else { + value = (VALUE *)(void *)(res->second); } - - return &value; + return value; } void AttrUpdate(VALUE *value, const int counter) { @@ -229,7 +216,7 @@ class ValueBlock { if (value->is_entry_) { // initialize for (size_t x = 0; x < value_names_.size(); ++x) { - initializers_[x]->GetValue(value->data_ + value_offsets_[x], + initializers_[x]->GetValue(value->data_.data() + value_offsets_[x], value_dims_[x]); } value->need_save_ = true; @@ -243,42 +230,73 @@ class ValueBlock { // dont jude if (has(id)) float *Get(const uint64_t &id) { - auto &value = values_.at(id); - return value.data_; + size_t hash = _hasher(id); + size_t bucket = compute_bucket(hash); + auto &table = values_[bucket]; + + // auto &value = table.at(id); + // return value->data_.data(); + auto res = table.find(id); + VALUE *value = res->second; + return value->data_.data(); } // for load, to reset count, unseen_days - VALUE *GetValue(const uint64_t &id) { return &values_.at(id); } + VALUE *GetValue(const uint64_t &id) { + size_t hash = _hasher(id); + size_t bucket = compute_bucket(hash); + + auto &table = values_[bucket]; + auto res = table.find(id); + return res->second; + } bool GetEntry(const uint64_t &id) { - auto &value = values_.at(id); - return value.is_entry_; + auto value = GetValue(id); + return value->is_entry_; } void SetEntry(const uint64_t &id, const bool state) { - auto &value = values_.at(id); - value.is_entry_ = state; + auto value = GetValue(id); + value->is_entry_ = state; } void Shrink(const int threshold) { - for (auto iter = values_.begin(); iter != values_.end();) { - auto &value = iter->second; - value.unseen_days_++; - if (value.unseen_days_ >= threshold) { - iter = values_.erase(iter); - } else { - ++iter; + for (auto &table : values_) { + for (auto iter = table.begin(); iter != table.end();) { + // VALUE* value = (VALUE*)(void*)(iter->second); + VALUE *value = iter->second; + value->unseen_days_++; + if (value->unseen_days_ >= threshold) { + butil::return_object(iter->second); + //_alloc.release(iter->second); + //_alloc.release(value); + iter = table.erase(iter); + } else { + ++iter; + } } } return; } float GetThreshold() { return threshold_; } + size_t compute_bucket(size_t hash) { + if (SPARSE_SHARD_BUCKET_NUM == 1) { + return 0; + } else { + return hash >> (sizeof(size_t) * 8 - SPARSE_SHARD_BUCKET_NUM_BITS); + } + } private: bool Has(const uint64_t id) { - auto got = values_.find(id); - if (got == values_.end()) { + size_t hash = _hasher(id); + size_t bucket = compute_bucket(hash); + auto &table = values_[bucket]; + + auto got = table.find(id); + if (got == table.end()) { return false; } else { return true; @@ -286,8 +304,9 @@ class ValueBlock { } public: - robin_hood::unordered_map values_; + robin_hood::unordered_map values_[SPARSE_SHARD_BUCKET_NUM]; size_t value_length_ = 0; + std::hash _hasher; private: const std::vector &value_names_; @@ -302,4 +321,3 @@ class ValueBlock { } // namespace distributed } // namespace paddle - diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt index b756c740ac..af87e1b6cc 100644 --- a/paddle/fluid/distributed/test/CMakeLists.txt +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -1,8 +1,10 @@ set_source_files_properties(table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test(table_test SRCS table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) +cc_test(table_test SRCS table_test.cc DEPS common_table table tensor_accessor +ps_framework_proto ${COMMON_DEPS} ${RPC_DEPS}) set_source_files_properties(dense_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_test(dense_table_test SRCS dense_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) +cc_test(dense_table_test SRCS dense_table_test.cc DEPS common_table table +tensor_accessor ps_framework_proto ${COMMON_DEPS} ${RPC_DEPS}) set_source_files_properties(barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(barrier_table_test SRCS barrier_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 24bed27728..1494e74c07 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -301,8 +301,14 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS fast_threaded_ssa_graph_executor variable_helper) cc_library(executor_cache SRCS executor_cache.cc DEPS executor) -cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS - conditional_block_op executor) +if(WITH_PSCORE) + get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS + conditional_block_op executor ${RPC_DEPS}) +else() + cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS + conditional_block_op executor) +endif() cc_library(prune SRCS prune.cc DEPS framework_proto boost) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry diff --git a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt index 6df2cd52bb..67c44368b7 100644 --- a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt @@ -1,5 +1,13 @@ IF(WITH_GPU) - nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context) + SET(HETERPS_DEPS device_context) + if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) + SET(HETERPS_DEPS ${HETERPS_DEPS} cub) + endif() + if(WITH_PSCORE) + get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + SET(HETERPS_DEPS ${HETERPS_DEPS} ${RPC_DEPS}) + endif() + nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS ${HETERPS_DEPS}) nv_test(test_heter_comm SRCS test_heter_comm.cu feature_value.h DEPS heter_comm) nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm) ENDIF() diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 01aa07e618..10f6c1ddbd 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/device_worker.h" -#include "paddle/fluid/framework/fleet/heter_context.h" #include "paddle/fluid/framework/fleet/heter_wrapper.h" #include "paddle/fluid/framework/heter_service.h" #include "paddle/fluid/framework/lod_tensor.h" diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index df07a7a6e7..24b83662c9 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -77,10 +77,13 @@ class CommonAccessor: ("Moment2", None), ("Beta1Pow", 1), ("Beta2Pow", 1), ("LearningRate", 1)] opt_input_map["sum"] = [("Param", None)] + opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1), + ("LearningRate", 1)] opt_attr_map = {} opt_attr_map["sgd"] = [] opt_attr_map["sum"] = [] + opt_attr_map["naive_adagrad"] = [] opt_attr_map["adam"] = [("beta1", "f"), ("beta2", "f"), ("epsilon", "f")] @@ -169,6 +172,10 @@ class CommonAccessor: param_varnames = self.opt_input_map["sum"] attr_varnames = self.opt_attr_map["sum"] self.accessor_class = "sum" + elif compiled_strategy.use_ps_gpu and is_sparse: + param_varnames = self.opt_input_map["naive_adagrad"] + attr_varnames = self.opt_attr_map["naive_adagrad"] + self.accessor_class = "sgd" else: param_varnames = self.opt_input_map[oop.type] attr_varnames = self.opt_attr_map[oop.type] @@ -176,20 +183,28 @@ class CommonAccessor: for (formal_name, shape) in param_varnames: params.append(formal_name) - param = main_program.global_block().vars[oop.input(formal_name)[0]] - if formal_name == "LearningRate" and param.name != "learning_rate_0": - warnings.warn("will support decay soon") - param = main_program.global_block().vars["learning_rate_0"] - - if shape is None: - if is_sparse: - shape = total_dims - else: - shape = self.get_shard(total_dims, pserver_num, pserver_id) - dims.append(shape) + if formal_name == "G2Sum": + dims.append(1) + initializer = "fill_constant&0" + initializers.append(initializer) + else: + param = main_program.global_block().vars[oop.input(formal_name)[ + 0]] + if formal_name == "LearningRate" and param.name != "learning_rate_0": + warnings.warn("will support decay soon") + param = main_program.global_block().vars["learning_rate_0"] + + if shape is None: + if is_sparse: + shape = total_dims + else: + shape = self.get_shard(total_dims, pserver_num, + pserver_id) + dims.append(shape) - initializer = self.get_initializer_attr(param.name, startup_program) - initializers.append(initializer) + initializer = self.get_initializer_attr(param.name, + startup_program) + initializers.append(initializer) for (attr_varname, type_) in attr_varnames: value = oop.attr(attr_varname) @@ -435,6 +450,8 @@ class TheOnePSRuntime(RuntimeBase): if not strategy: raise ValueError("k_steps must be invalid value, please check") + if dist_strategy.a_sync_configs["use_ps_gpu"]: + strategy.use_ps_gpu = True return strategy def build_compiled_startegy(self): @@ -443,6 +460,8 @@ class TheOnePSRuntime(RuntimeBase): compiled_config = CompileTimeStrategy( self.origin_main_program, self.origin_main_program, self.async_strategy, self.role_maker) + if self.async_strategy.use_ps_gpu: + compiled_config.use_ps_gpu = True return compiled_config def _init_worker(self): diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py index 35029a3dfc..2a9d26daae 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py @@ -149,6 +149,7 @@ class DistributedStrategy(object): if num_threads > 1: self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce self.debug_opt = None + self.use_ps_gpu = False def set_debug_opt(self, opt_info): self.debug_opt = opt_info diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py index baf8add04c..b2735727f6 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py @@ -138,6 +138,7 @@ class CompileTimeStrategy(object): self.strategy = strategy self.role_maker = role_maker + self.use_ps_gpu = False try: self.is_heter_ps_mode = role_maker._is_heter_parameter_server_mode except: -- GitLab