Unverified commit 4ead9a5a, authored by T Thunderbrook, committed via GitHub

[PsCore] solve Brpc dep (#32632)

* Revert "Revert "[PsCore] optimize performance of large kv (#32535)" (#32599)"

This reverts commit 809ac036.

* brpc dep
Parent 9ee709fc
@@ -353,6 +353,11 @@ if (WITH_MIPS)
     add_definitions(-DPADDLE_WITH_MIPS)
 endif()

+if (WITH_HETERPS)
+    if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
+    endif()
+endif()

 set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
 set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
......
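Note on the new flag: GCC 7+ only performs C++17 over-aligned allocation in pre-C++17 dialects when -faligned-new is passed, and the flag is presumably added here because HeterPS allocates types aligned beyond alignof(std::max_align_t). A minimal C++ sketch of the kind of allocation this flag exists for (the type and fields below are illustrative, not from the patch):

#include <cstdint>
#include <cstdio>

// Hypothetical record padded to a 64-byte cache line. Its alignment
// exceeds alignof(std::max_align_t), so `new AlignedSlot` needs the
// C++17 aligned operator new, which GCC 7+ only provides in C++14 mode
// under -faligned-new (the flag the hunk above enables for WITH_HETERPS).
struct alignas(64) AlignedSlot {
  float grad[8];
  float show_click[2];
};

int main() {
  AlignedSlot* slot = new AlignedSlot();
  bool aligned = reinterpret_cast<std::uintptr_t>(slot) % 64 == 0;
  std::printf("64-byte aligned: %s\n", aligned ? "yes" : "no");
  delete slot;
  return 0;
}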
@@ -11,8 +11,8 @@ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
         "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
 endif()

-add_subdirectory(table)
 add_subdirectory(service)
+add_subdirectory(table)
 add_subdirectory(test)
 add_subdirectory(index_dataset)
......
@@ -14,6 +14,7 @@

 #include "paddle/fluid/distributed/service/brpc_ps_server.h"
 #include <thread>  // NOLINT
+#include "butil/object_pool.h"
 #include "paddle/fluid/distributed/table/depends/sparse_utils.h"
 #include "paddle/fluid/distributed/table/table.h"
 #include "paddle/fluid/framework/archive.h"
@@ -196,12 +197,13 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
     return 0;
   }

-  std::vector<float> res_data;
-  res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
-  table->pull_dense(res_data.data(), num);
+  auto res_data = butil::get_object<std::vector<float>>();
+  res_data->resize(num * table->value_accesor()->select_size() / sizeof(float));
+  table->pull_dense(res_data->data(), num);

-  cntl->response_attachment().append((char *)res_data.data(),
-                                     res_data.size() * sizeof(float));
+  cntl->response_attachment().append((char *)(res_data->data()),
+                                     res_data->size() * sizeof(float));
+  butil::return_object(res_data);

   return 0;
 }
@@ -367,12 +369,13 @@ int32_t BrpcPsService::pull_sparse(Table *table,

   value.DeserializeFromBytes(const_cast<void *>(data));

-  std::vector<float> res_data;
-  res_data.resize(num * dim);
-  table->pull_sparse(res_data.data(), value);
+  auto res_data = butil::get_object<std::vector<float>>();
+  res_data->resize(num * dim);
+  table->pull_sparse(res_data->data(), value);

-  cntl->response_attachment().append((char *)res_data.data(),
-                                     res_data.size() * sizeof(float));
+  cntl->response_attachment().append((char *)(res_data->data()),
+                                     res_data->size() * sizeof(float));
+  butil::return_object(res_data);

   return 0;
 }
......
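Note on the buffer change: instead of constructing a fresh std::vector<float> per request, the response buffer is now borrowed from brpc's butil object pool and returned once the attachment has been appended, so its grown capacity is recycled across requests. A minimal self-contained sketch of the same pattern (butil::get_object and butil::return_object are the actual butil APIs; the helper function is hypothetical):

#include <cstddef>
#include <vector>
#include "butil/object_pool.h"

// Borrow a vector from butil's object pool, use it as a scratch response
// buffer, then hand it back. return_object() does not destroy the vector,
// so the heap capacity it accumulated is reused by the next borrower --
// the point of the pull_dense/pull_sparse change above.
void fill_response(size_t num_floats) {
  auto* buf = butil::get_object<std::vector<float>>();
  buf->resize(num_floats);  // overwrites any stale contents from last use
  // ... copy table values into buf->data() and append to the RPC body ...
  butil::return_object(buf);  // recycle rather than delete
}

One caveat: a pooled object keeps whatever state its previous user left behind, so callers must resize()/overwrite before reading.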
@@ -13,7 +13,11 @@ set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTR
 set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

-cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator)
+get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
+
+cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc
+sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS}
+${RPC_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator)

 set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
@@ -125,34 +125,37 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,

 int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
                    const int mode) {
-  int64_t not_save_num = 0;
-  for (auto& value : block->values_) {
-    if (mode == SaveMode::delta && !value.second.need_save_) {
-      not_save_num++;
-      continue;
-    }
-
-    auto* vs = value.second.data_;
-    std::stringstream ss;
-    auto id = value.first;
-    ss << id << "\t" << value.second.count_ << "\t" << value.second.unseen_days_
-       << "\t" << value.second.is_entry_ << "\t";
-
-    for (int i = 0; i < block->value_length_; i++) {
-      ss << vs[i];
-      ss << ",";
-    }
-
-    ss << "\n";
-    os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
-
-    if (mode == SaveMode::base || mode == SaveMode::delta) {
-      value.second.need_save_ = false;
+  int64_t save_num = 0;
+  for (auto& table : block->values_) {
+    for (auto& value : table) {
+      if (mode == SaveMode::delta && !value.second->need_save_) {
+        continue;
+      }
+      save_num += 1;
+
+      auto* vs = value.second->data_.data();
+      std::stringstream ss;
+      auto id = value.first;
+      ss << id << "\t" << value.second->count_ << "\t"
+         << value.second->unseen_days_ << "\t" << value.second->is_entry_
+         << "\t";
+
+      for (int i = 0; i < block->value_length_; i++) {
+        ss << vs[i];
+        ss << ",";
+      }
+
+      ss << "\n";
+      os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
+
+      if (mode == SaveMode::base || mode == SaveMode::delta) {
+        value.second->need_save_ = false;
+      }
     }
   }

-  return block->values_.size() - not_save_num;
+  return save_num;
 }

 int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
@@ -183,7 +186,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,

     block->Init(id, false);

-    auto value_instant = block->GetValue(id);
+    VALUE* value_instant = block->GetValue(id);
     if (values.size() == 5) {
       value_instant->count_ = std::stoi(values[1]);
       value_instant->unseen_days_ = std::stoi(values[2]);
@@ -373,8 +376,10 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
   int64_t feasign_size = 0;
   int64_t mf_size = 0;

-  for (auto& value : shard_values_) {
-    feasign_size += value->values_.size();
+  for (auto& shard : shard_values_) {
+    for (auto& table : shard->values_) {
+      feasign_size += table.size();
+    }
   }

   return {feasign_size, mf_size};
......
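For reference, the save path above emits one line per feature of the form id \t count \t unseen_days \t is_entry \t v0,v1,..., which is exactly what LoadFromText's values.size() == 5 check parses back. A tiny sketch reproducing the format (the values are made up):

#include <cstdint>
#include <iostream>
#include <sstream>
#include <vector>

// Mirrors the stringstream logic in SaveToText: four tab-separated meta
// fields, then the embedding values joined by commas (with a trailing
// comma), one feature per line.
int main() {
  uint64_t id = 12345;
  int count = 7, unseen_days = 0;
  bool is_entry = true;
  std::vector<float> data = {0.5f, -1.25f, 3.0f};

  std::stringstream ss;
  ss << id << "\t" << count << "\t" << unseen_days << "\t" << is_entry << "\t";
  for (float v : data) ss << v << ",";
  ss << "\n";
  std::cout << ss.str();  // prints: 12345  7  0  1  0.5,-1.25,3,
  return 0;
}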
@@ -26,6 +26,7 @@

 #include <vector>

 #include "gflags/gflags.h"
+#include "butil/object_pool.h"
 #include "paddle/fluid/distributed/common/utils.h"
 #include "paddle/fluid/distributed/table/depends/initializers.h"
 #include "paddle/fluid/distributed/thirdparty/round_robin.h"
@@ -48,6 +49,10 @@ namespace distributed {

 enum Mode { training, infer };

+static const int SPARSE_SHARD_BUCKET_NUM_BITS = 6;
+static const size_t SPARSE_SHARD_BUCKET_NUM = (size_t)1
+                                              << SPARSE_SHARD_BUCKET_NUM_BITS;
+
 struct VALUE {
   explicit VALUE(size_t length)
       : length_(length),
@@ -55,46 +60,16 @@ struct VALUE {
         unseen_days_(0),
         need_save_(false),
         is_entry_(false) {
-    data_ = new float[length];
-    memset(data_, 0, sizeof(float) * length);
-  }
-
-  VALUE(const VALUE &value) {
-    length_ = value.length_;
-    count_ = value.count_;
-    unseen_days_ = value.unseen_days_;
-    need_save_ = value.need_save_;
-    is_entry_ = value.is_entry_;
-
-    data_ = new float[length_];
-    memcpy(data_, value.data_, sizeof(float) * length_);
-  }
-
-  VALUE &operator=(const VALUE &value) {
-    if (this != &value) {
-      delete[] data_;
-
-      length_ = value.length_;
-      count_ = value.count_;
-      unseen_days_ = value.unseen_days_;
-      need_save_ = value.need_save_;
-      is_entry_ = value.is_entry_;
-
-      data_ = new float[length_];
-      memcpy(data_, value.data_, sizeof(float) * length_);
-    }
-    return *this;
-  }
-
-  ~VALUE() {
-    delete[] data_;
-    data_ = nullptr;
+    data_.resize(length);
+    memset(data_.data(), 0, sizeof(float) * length);
   }

   size_t length_;
+  std::vector<float> data_;
   int count_;
   int unseen_days_;  // use to check knock-out
   bool need_save_;   // whether need to save
   bool is_entry_;    // whether knock-in
-  float *data_;
 };

 inline bool count_entry(VALUE *value, int threshold) {
@@ -176,12 +151,12 @@ class ValueBlock {
                                      const std::vector<int> &value_dims) {
     auto pts = std::vector<float *>();
     pts.reserve(value_names.size());
-    auto &values = values_.at(id);
+    auto values = GetValue(id);
     for (int i = 0; i < static_cast<int>(value_names.size()); i++) {
       PADDLE_ENFORCE_EQ(
           value_dims[i], value_dims_[i],
           platform::errors::InvalidArgument("value dims is not match"));
-      pts.push_back(values.data_ +
+      pts.push_back(values->data_.data() +
                     value_offsets_.at(value_idx_.at(value_names[i])));
     }
     return pts;
@@ -190,33 +165,45 @@ class ValueBlock {

   // pull
   float *Init(const uint64_t &id, const bool with_update = true,
               const int counter = 1) {
-    if (!Has(id)) {
-      values_.emplace(std::make_pair(id, VALUE(value_length_)));
-    }
-
-    auto &value = values_.at(id);
+    size_t hash = _hasher(id);
+    size_t bucket = compute_bucket(hash);
+
+    auto &table = values_[bucket];
+    auto res = table.find(id);
+
+    VALUE *value = nullptr;
+    if (res == table.end()) {
+      value = butil::get_object<VALUE>(value_length_);
+
+      table[id] = value;
+    } else {
+      value = res->second;
+    }

     if (with_update) {
-      AttrUpdate(&value, counter);
+      AttrUpdate(value, counter);
     }

-    return value.data_;
+    return value->data_.data();
   }

   VALUE *InitGet(const uint64_t &id, const bool with_update = true,
                  const int counter = 1) {
-    if (!Has(id)) {
-      values_.emplace(std::make_pair(id, VALUE(value_length_)));
-    }
-
-    auto &value = values_.at(id);
-
-    if (with_update) {
-      AttrUpdate(&value, counter);
+    size_t hash = _hasher(id);
+    size_t bucket = compute_bucket(hash);
+
+    auto &table = values_[bucket];
+    auto res = table.find(id);
+
+    VALUE *value = nullptr;
+    if (res == table.end()) {
+      value = butil::get_object<VALUE>(value_length_);
+      // value = _alloc.acquire(value_length_);
+      table[id] = value;
+    } else {
+      value = (VALUE *)(void *)(res->second);
     }
-
-    return &value;
+    return value;
   }

   void AttrUpdate(VALUE *value, const int counter) {
@@ -229,7 +216,7 @@ class ValueBlock {
     if (value->is_entry_) {
       // initialize
       for (size_t x = 0; x < value_names_.size(); ++x) {
-        initializers_[x]->GetValue(value->data_ + value_offsets_[x],
+        initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
                                    value_dims_[x]);
       }
       value->need_save_ = true;
@@ -243,42 +230,73 @@ class ValueBlock {

   // dont jude if (has(id))
   float *Get(const uint64_t &id) {
-    auto &value = values_.at(id);
-    return value.data_;
+    size_t hash = _hasher(id);
+    size_t bucket = compute_bucket(hash);
+    auto &table = values_[bucket];
+
+    // auto &value = table.at(id);
+    // return value->data_.data();
+    auto res = table.find(id);
+    VALUE *value = res->second;
+    return value->data_.data();
   }

   // for load, to reset count, unseen_days
-  VALUE *GetValue(const uint64_t &id) { return &values_.at(id); }
+  VALUE *GetValue(const uint64_t &id) {
+    size_t hash = _hasher(id);
+    size_t bucket = compute_bucket(hash);
+
+    auto &table = values_[bucket];
+    auto res = table.find(id);
+    return res->second;
+  }

   bool GetEntry(const uint64_t &id) {
-    auto &value = values_.at(id);
-    return value.is_entry_;
+    auto value = GetValue(id);
+    return value->is_entry_;
   }

   void SetEntry(const uint64_t &id, const bool state) {
-    auto &value = values_.at(id);
-    value.is_entry_ = state;
+    auto value = GetValue(id);
+    value->is_entry_ = state;
   }

   void Shrink(const int threshold) {
-    for (auto iter = values_.begin(); iter != values_.end();) {
-      auto &value = iter->second;
-      value.unseen_days_++;
-      if (value.unseen_days_ >= threshold) {
-        iter = values_.erase(iter);
-      } else {
-        ++iter;
+    for (auto &table : values_) {
+      for (auto iter = table.begin(); iter != table.end();) {
+        // VALUE* value = (VALUE*)(void*)(iter->second);
+        VALUE *value = iter->second;
+        value->unseen_days_++;
+        if (value->unseen_days_ >= threshold) {
+          butil::return_object(iter->second);
+          //_alloc.release(iter->second);
+          //_alloc.release(value);
+          iter = table.erase(iter);
+        } else {
+          ++iter;
+        }
       }
     }
     return;
   }

   float GetThreshold() { return threshold_; }

+  size_t compute_bucket(size_t hash) {
+    if (SPARSE_SHARD_BUCKET_NUM == 1) {
+      return 0;
+    } else {
+      return hash >> (sizeof(size_t) * 8 - SPARSE_SHARD_BUCKET_NUM_BITS);
+    }
+  }
+
  private:
   bool Has(const uint64_t id) {
-    auto got = values_.find(id);
-    if (got == values_.end()) {
+    size_t hash = _hasher(id);
+    size_t bucket = compute_bucket(hash);
+    auto &table = values_[bucket];
+
+    auto got = table.find(id);
+    if (got == table.end()) {
       return false;
     } else {
       return true;
@@ -286,8 +304,9 @@ class ValueBlock {
   }

  public:
-  robin_hood::unordered_map<uint64_t, VALUE> values_;
+  robin_hood::unordered_map<uint64_t, VALUE *> values_[SPARSE_SHARD_BUCKET_NUM];
   size_t value_length_ = 0;
+  std::hash<uint64_t> _hasher;

  private:
   const std::vector<std::string> &value_names_;
@@ -302,4 +321,3 @@ class ValueBlock {

 }  // namespace distributed
 }  // namespace paddle
-
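To make the new sharded layout concrete: ValueBlock now keeps 2^6 = 64 maps instead of one and routes an id to a map by the top six bits of its hash. A minimal stand-alone sketch of that routing (std::unordered_map standing in for robin_hood::unordered_map, object pooling omitted):

#include <cstdint>
#include <cstdio>
#include <functional>
#include <unordered_map>

static const int SPARSE_SHARD_BUCKET_NUM_BITS = 6;
static const size_t SPARSE_SHARD_BUCKET_NUM = (size_t)1
                                              << SPARSE_SHARD_BUCKET_NUM_BITS;

struct Slot {
  int count = 0;
};

// Simplified stand-in for ValueBlock: 64 small maps instead of one big
// map, selected by the high bits of the hash -- the same compute_bucket
// rule the diff adds. Smaller per-bucket maps mean cheaper rehashes as
// the shard grows, which is what "optimize performance of large kv" is
// after.
struct MiniShard {
  std::unordered_map<uint64_t, Slot> buckets[SPARSE_SHARD_BUCKET_NUM];
  std::hash<uint64_t> hasher;

  size_t compute_bucket(size_t hash) const {
    return hash >> (sizeof(size_t) * 8 - SPARSE_SHARD_BUCKET_NUM_BITS);
  }

  Slot* init(uint64_t id) { return &buckets[compute_bucket(hasher(id))][id]; }

  size_t size() const {
    size_t total = 0;
    for (const auto& table : buckets) total += table.size();  // per-bucket walk
    return total;
  }
};

int main() {
  MiniShard shard;
  for (uint64_t id = 0; id < 1000; ++id) shard.init(id)->count++;
  std::printf("entries: %zu\n", shard.size());
  return 0;
}

One caveat worth knowing: libstdc++'s std::hash<uint64_t> is the identity function, so the dense toy ids above all land in bucket 0; real feasign keys are large, well-spread values, so the buckets fill evenly in practice.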
 set_source_files_properties(table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-cc_test(table_test SRCS table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
+cc_test(table_test SRCS table_test.cc DEPS common_table table tensor_accessor
+ps_framework_proto ${COMMON_DEPS} ${RPC_DEPS})

 set_source_files_properties(dense_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-cc_test(dense_table_test SRCS dense_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
+cc_test(dense_table_test SRCS dense_table_test.cc DEPS common_table table
+tensor_accessor ps_framework_proto ${COMMON_DEPS} ${RPC_DEPS})

 set_source_files_properties(barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 cc_test(barrier_table_test SRCS barrier_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
......
@@ -301,8 +301,14 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
                     fast_threaded_ssa_graph_executor variable_helper)

 cc_library(executor_cache SRCS executor_cache.cc DEPS executor)
-cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
-        conditional_block_op executor)
+if(WITH_PSCORE)
+  get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
+  cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
+          conditional_block_op executor ${RPC_DEPS})
+else()
+  cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
+          conditional_block_op executor)
+endif()
 cc_library(prune SRCS prune.cc DEPS framework_proto boost)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
 cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
......
 IF(WITH_GPU)
-    nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
+    SET(HETERPS_DEPS device_context)
+    if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
+        SET(HETERPS_DEPS ${HETERPS_DEPS} cub)
+    endif()
+    if(WITH_PSCORE)
+        get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
+        SET(HETERPS_DEPS ${HETERPS_DEPS} ${RPC_DEPS})
+    endif()
+    nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS ${HETERPS_DEPS})
     nv_test(test_heter_comm SRCS test_heter_comm.cu feature_value.h DEPS heter_comm)
     nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
 ENDIF()
......
@@ -26,7 +26,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_feed.h"
 #include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/device_worker.h"
-#include "paddle/fluid/framework/fleet/heter_context.h"
 #include "paddle/fluid/framework/fleet/heter_wrapper.h"
 #include "paddle/fluid/framework/heter_service.h"
 #include "paddle/fluid/framework/lod_tensor.h"
......
@@ -77,10 +77,13 @@ class CommonAccessor:
                                  ("Moment2", None), ("Beta1Pow", 1),
                                  ("Beta2Pow", 1), ("LearningRate", 1)]
         opt_input_map["sum"] = [("Param", None)]
+        opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1),
+                                          ("LearningRate", 1)]

         opt_attr_map = {}
         opt_attr_map["sgd"] = []
         opt_attr_map["sum"] = []
+        opt_attr_map["naive_adagrad"] = []
         opt_attr_map["adam"] = [("beta1", "f"), ("beta2", "f"),
                                 ("epsilon", "f")]
@@ -169,6 +172,10 @@ class CommonAccessor:
             param_varnames = self.opt_input_map["sum"]
             attr_varnames = self.opt_attr_map["sum"]
             self.accessor_class = "sum"
+        elif compiled_strategy.use_ps_gpu and is_sparse:
+            param_varnames = self.opt_input_map["naive_adagrad"]
+            attr_varnames = self.opt_attr_map["naive_adagrad"]
+            self.accessor_class = "sgd"
         else:
             param_varnames = self.opt_input_map[oop.type]
             attr_varnames = self.opt_attr_map[oop.type]
@@ -176,20 +183,28 @@ class CommonAccessor:
         for (formal_name, shape) in param_varnames:
             params.append(formal_name)
-            param = main_program.global_block().vars[oop.input(formal_name)[0]]
-            if formal_name == "LearningRate" and param.name != "learning_rate_0":
-                warnings.warn("will support decay soon")
-                param = main_program.global_block().vars["learning_rate_0"]
-
-            if shape is None:
-                if is_sparse:
-                    shape = total_dims
-                else:
-                    shape = self.get_shard(total_dims, pserver_num, pserver_id)
-            dims.append(shape)
-
-            initializer = self.get_initializer_attr(param.name, startup_program)
-            initializers.append(initializer)
+            if formal_name == "G2Sum":
+                dims.append(1)
+                initializer = "fill_constant&0"
+                initializers.append(initializer)
+            else:
+                param = main_program.global_block().vars[oop.input(formal_name)[
+                    0]]
+                if formal_name == "LearningRate" and param.name != "learning_rate_0":
+                    warnings.warn("will support decay soon")
+                    param = main_program.global_block().vars["learning_rate_0"]

+                if shape is None:
+                    if is_sparse:
+                        shape = total_dims
+                    else:
+                        shape = self.get_shard(total_dims, pserver_num,
+                                               pserver_id)
+                dims.append(shape)

+                initializer = self.get_initializer_attr(param.name,
+                                                        startup_program)
+                initializers.append(initializer)

         for (attr_varname, type_) in attr_varnames:
             value = oop.attr(attr_varname)
@@ -435,6 +450,8 @@ class TheOnePSRuntime(RuntimeBase):
         if not strategy:
             raise ValueError("k_steps must be invalid value, please check")

+        if dist_strategy.a_sync_configs["use_ps_gpu"]:
+            strategy.use_ps_gpu = True
         return strategy

     def build_compiled_startegy(self):
@@ -443,6 +460,8 @@ class TheOnePSRuntime(RuntimeBase):
         compiled_config = CompileTimeStrategy(
             self.origin_main_program, self.origin_main_program,
             self.async_strategy, self.role_maker)
+        if self.async_strategy.use_ps_gpu:
+            compiled_config.use_ps_gpu = True
         return compiled_config

     def _init_worker(self):
......
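For intuition about the new naive_adagrad entry: G2Sum is a single accumulator per parameter row (dims is 1, initialized to zero via fill_constant&0), in the style of textbook Adagrad with a scalar squared-gradient sum. A generic C++ sketch of that update rule, assuming the usual formulation (this is an illustration of the technique, not Paddle's exact kernel):

#include <cmath>
#include <cstdio>

// "Naive" Adagrad with one accumulated squared-gradient scalar (g2sum),
// matching the shape hints in the hunk above. Assumed update rule:
//   g2sum += mean(g^2);  w -= lr * g / (sqrt(g2sum) + eps)
void naive_adagrad_step(float* param, const float* grad, int n,
                        float* g2sum, float lr, float epsilon = 1e-6f) {
  float add = 0.f;
  for (int i = 0; i < n; ++i) add += grad[i] * grad[i];
  *g2sum += add / n;  // accumulate mean squared gradient
  float scale = lr / (std::sqrt(*g2sum) + epsilon);
  for (int i = 0; i < n; ++i) param[i] -= scale * grad[i];
}

int main() {
  float w[4] = {0.f, 0.f, 0.f, 0.f};
  float g[4] = {0.1f, -0.2f, 0.3f, -0.4f};
  float g2sum = 0.f;  // the role G2Sum plays, starting from 0
  for (int step = 0; step < 3; ++step) naive_adagrad_step(w, g, 4, &g2sum, 0.1f);
  std::printf("w[0] after 3 steps: %f\n", w[0]);
  return 0;
}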
@@ -149,6 +149,7 @@ class DistributedStrategy(object):
             if num_threads > 1:
                 self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         self.debug_opt = None
+        self.use_ps_gpu = False

     def set_debug_opt(self, opt_info):
         self.debug_opt = opt_info
......
@@ -138,6 +138,7 @@ class CompileTimeStrategy(object):
         self.strategy = strategy
         self.role_maker = role_maker

+        self.use_ps_gpu = False
         try:
             self.is_heter_ps_mode = role_maker._is_heter_parameter_server_mode
         except:
......