未验证 提交 9b3b53ba 编写于 作者: Z zhaocaibei123 提交者: GitHub

geo memory sparse table (#39250)

* geo depends

* add memory geo table

* fix
上级 bafea65c
......@@ -213,6 +213,7 @@ int32_t BrpcPsClient::initialize() {
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_client_pull_dense");
profiler.register_profiler("pserver_client_pull_sparse");
profiler.register_profiler("pserver_client_pull_sparse_param");
profiler.register_profiler("pserver_client_pull_sparse_local");
profiler.register_profiler("pserver_client_push_sparse");
profiler.register_profiler("pserver_client_push_sparse_parse");
......@@ -543,6 +544,7 @@ std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
return fut;
}
// for GEO
std::future<int32_t> BrpcPsClient::push_sparse_param(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) {
......@@ -558,18 +560,8 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
ids.resize(request_call_num);
value_ptrs.resize(request_call_num);
const auto &server_param = _config.server_param().downpour_server_param();
uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
const auto &table_param = server_param.downpour_table_param(i);
if (table_param.table_id() == table_id) {
shard_num = table_param.shard_num();
break;
}
}
for (size_t i = 0; i < num; ++i) {
size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
size_t pserver_idx = keys[i] % request_call_num;
ids[pserver_idx].push_back(keys[i]);
value_ptrs[pserver_idx].push_back(update_values[i]);
}
......@@ -1003,6 +995,120 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
return fut;
}
// for GEO
std::future<int32_t> BrpcPsClient::pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
bool is_training) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse_param");
size_t request_call_num = _server_channels.size();
auto shard_sorted_kvs = std::make_shared<
std::vector<std::vector<std::pair<uint64_t, float *>>>>();
shard_sorted_kvs->resize(request_call_num);
for (size_t i = 0; i < num; ++i) {
size_t shard_id = keys[i] % request_call_num;
shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
}
auto *accessor = table_accessor(table_id);
size_t value_size = accessor->select_size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [shard_sorted_kvs, value_size](void *done) {
int ret = 0;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
ret = -1;
break;
}
auto &request_kvs = shard_sorted_kvs->at(i);
auto &res_io_buffer = closure->cntl(i)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
uint64_t last_key = UINT64_MAX;
float *last_value_data = NULL;
// can remove sort&unique
for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
auto *kv_pair = &(request_kvs[kv_idx]);
if (kv_pair->first == last_key) {
memcpy(reinterpret_cast<void *>(kv_pair->second),
reinterpret_cast<void *>(last_value_data), value_size);
} else {
last_key = kv_pair->first;
last_value_data = kv_pair->second;
if (value_size !=
io_buffer_itr.copy_and_forward(
reinterpret_cast<void *>(last_value_data), value_size)) {
LOG(WARNING) << "res data is lack or not in format";
ret = -1;
break;
}
}
}
}
closure->set_promise_value(ret);
});
closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
auto &sorted_kvs = shard_sorted_kvs->at(i);
std::sort(sorted_kvs.begin(), sorted_kvs.end(),
[](const std::pair<uint64_t, float *> &k1,
const std::pair<uint64_t, float *> &k2) {
return k1.first < k2.first;
});
uint64_t last_key = UINT64_MAX;
uint32_t kv_request_count = 0;
size_t sorted_kv_size = sorted_kvs.size();
auto &request_buffer = closure->cntl(i)->request_attachment();
request_buffer.append(reinterpret_cast<void *>(&is_training), sizeof(bool));
std::vector<uint32_t> keys_counter;
keys_counter.reserve(sorted_kv_size);
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
++kv_request_count;
uint32_t keys = 1;
last_key = sorted_kvs[kv_idx].first;
request_buffer.append(reinterpret_cast<void *>(&last_key),
sizeof(uint64_t));
while (kv_idx < sorted_kv_size - 1 &&
last_key == sorted_kvs[kv_idx + 1].first) {
++kv_idx;
++keys;
}
keys_counter.push_back(keys);
}
request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
sizeof(uint32_t) * keys_counter.size());
if (kv_request_count == 0) {
closure->Run();
} else {
closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
sizeof(uint32_t));
PsService_Stub rpc_stub(get_cmd_channel(i));
closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
}
return fut;
}
std::future<int32_t> BrpcPsClient::send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
......@@ -1067,12 +1173,14 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
std::string var_name = "";
int64_t var_num = 0;
int64_t var_shape = 0;
std::string table_class;
const auto &worker_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
if (worker_param.downpour_table_param(i).table_id() == table_id) {
var_name = worker_param.downpour_table_param(i).common().table_name();
var_num = worker_param.downpour_table_param(i).common().table_num();
var_shape = worker_param.downpour_table_param(i).common().table_dim();
table_class = worker_param.downpour_table_param(i).table_class();
break;
}
}
......@@ -1094,9 +1202,19 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
save_vec.push_back(save_huge_vec.data() + i * var_shape);
}
auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true);
status.wait();
VLOG(2) << "recv_and_save_table: table_class: " << table_class;
// TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
// recv_and_save_table
if (table_class == "MemorySparseGeoTable") {
auto status =
pull_sparse_param(reinterpret_cast<float **>(save_vec.data()), table_id,
save_key.data(), save_key.size(), true);
status.wait();
} else {
auto status = pull_sparse(reinterpret_cast<float **>(save_vec.data()),
table_id, save_key.data(), save_key.size(), true);
status.wait();
}
// create lod tensor
std::shared_ptr<framework::Scope> scope;
......
......@@ -194,6 +194,10 @@ class BrpcPsClient : public PSClient {
size_t table_id,
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
......
......@@ -354,7 +354,7 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
bool training = true;
auto status = _worker_ptr->pull_sparse(
auto status = _worker_ptr->pull_sparse_param(
(float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait();
......@@ -1029,7 +1029,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
auto &sparse_ids_set = iter.second;
auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
sparse_id_queues_.at(key)->Push(sparse_ids_vec);
sparse_id_queues_.at(key)->Put(sparse_ids_vec);
VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
<< "'s queue";
}
......@@ -1051,7 +1051,10 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
if (!ctx.is_sparse) continue;
if (!ctx.is_sparse) {
parallel_task_nums_ += 1;
continue;
}
auto &varnames = ctx.origin_varnames;
PADDLE_ENFORCE_EQ(
varnames.size(), 1,
......@@ -1060,12 +1063,11 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
for (auto &splited_var : ctx.splited_varnames) {
parallel_task_nums_ += 1;
sparse_id_queues_.insert(
std::pair<std::string, std::shared_ptr<BlockingQueue<
std::shared_ptr<std::vector<int64_t>>>>>(
std::pair<std::string, paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>(
splited_var,
std::make_shared<
BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>(
send_queue_size_)));
paddle::framework::MakeChannel<
std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
}
}
......@@ -1242,8 +1244,8 @@ std::vector<int64_t> GeoCommunicator::MergeSparseIds(
VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
if (sparse_id_queues_.at(send_varname)->Size() > 0) {
wait_times = 0;
std::shared_ptr<std::vector<int64_t>> pop_ids =
sparse_id_queues_.at(send_varname)->Pop();
std::shared_ptr<std::vector<int64_t>> pop_ids = nullptr;
sparse_id_queues_.at(send_varname)->Get(pop_ids);
for (size_t j = 0; j < pop_ids->size(); j++) {
sparse_ids.insert(pop_ids->at(j));
}
......@@ -1268,6 +1270,9 @@ void GeoCommunicator::SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, int table_id,
int ep_idx) {
platform::RecordEvent record_event("GeoCommunicator->SendSparse");
if (sparse_ids.size() == 0) {
return;
}
std::string param_name = SplitedGradToParam(varname);
VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name
<< ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id
......@@ -1313,6 +1318,10 @@ void GeoCommunicator::SendSparse(const std::string &varname,
t_value + j * dims1,
t_old->data<float>() + sparse_ids[j] * dims1);
push_g_vec.push_back(t_value + j * dims1);
VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key "
<< sparse_ids[j] << " value[0] " << push_g_vec[j][0]
<< " value[-1] " << push_g_vec[j][dims1 - 1];
}
++_async_call_num;
......@@ -1367,6 +1376,9 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int table_id,
cpu_ctx);
for (auto j = 0; j < static_cast<int>(keys.size()); ++j) {
VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j]
<< "value[0] " << values[j * dims1] << " value[-1] "
<< values[j * dims1 + dims1 - 1];
float *latest_data = t_latest->data<float>() + keys[j] * dims1;
float *old_data = t_old->data<float>() + keys[j] * dims1;
// pserver - old => delta
......
......@@ -30,6 +30,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
......@@ -626,9 +627,8 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;
std::unordered_map<
std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>>
std::unordered_map<std::string, paddle::framework::Channel<
std::shared_ptr<std::vector<int64_t>>>>
sparse_id_queues_;
};
......
......@@ -128,6 +128,17 @@ class PSClient {
const uint64_t *keys, size_t num,
bool is_training) = 0;
virtual std::future<int32_t> pull_sparse_param(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num, bool is_training) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual ::std::future<int32_t> pull_sparse_ptr(char **select_values,
size_t table_id,
const uint64_t *keys,
......
......@@ -47,6 +47,9 @@ cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framewo
cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)
cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)
cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
target_link_libraries(table -fopenmp)
......@@ -15,13 +15,9 @@
#pragma once
#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
namespace paddle {
namespace distributed {
int32_t MemorySparseGeoTable::push_sparse_param(const uint64_t* keys,
const float* values,
size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param begin "
"push_sparse_param "
<< num;
auto shard_num = _task_pool_size;
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(shard_num);
for (int x = 0; x < num; ++x) {
auto y = keys[x] % shard_num;
offset_bucket[y].push_back(x);
if (x < 10) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse_param key: "
<< keys[x] << " shard: " << y;
}
}
std::vector<std::future<int>> tasks(shard_num);
for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &offset_bucket, &values]() -> int {
auto& local_shard = _local_shards[shard_id];
auto& offsets = offset_bucket[shard_id];
for (int i = 0; i < offsets.size(); ++i) {
auto offset = offsets[i];
auto id = keys[offset];
auto& feature_value = local_shard[id];
feature_value.resize(_dim);
std::copy_n(values + _dim * offset, _dim, feature_value.data());
if (i < 10) {
VLOG(5) << "MemorySparseGeoTable::push_sparse_param "
"push_sparse_param key "
<< id << " value[0]: " << (values + _dim * offset)[0]
<< " data: " << feature_value.data()[0]
<< " value[-1]: " << (values + _dim * offset)[_dim - 1]
<< " data: " << feature_value.data()[_dim - 1];
}
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
int32_t MemorySparseGeoTable::pull_geo_param(const uint32_t trainer_id,
std::vector<float>* values,
std::vector<uint64_t>* ids) {
_geo_recorder->GetAndClear(trainer_id, ids);
VLOG(5)
<< "DEBUG MemorySparseGeoTable::pull_geo_param pull_geo_param trainer_id "
<< trainer_id << " id_num: " << ids->size();
std::vector<uint32_t> frequencies;
frequencies.resize(ids->size(), 1);
auto pull_value = PullSparseValue(ids->size(), _dim);
pull_value.is_training_ = true;
pull_value.feasigns_ = ids->data();
pull_value.frequencies_ = frequencies.data();
values->resize(ids->size() * _dim);
pull_sparse(values->data(), pull_value);
return 0;
}
int32_t MemorySparseGeoTable::push_sparse(const uint64_t* keys,
const float* values, size_t num) {
VLOG(5) << "DEBUG MemorySparseGeoTable::push_sparse keys[0]" << keys[0]
<< " key_num: " << num;
std::vector<uint64_t> ids;
ids.resize(num);
std::copy_n(keys, num, ids.begin());
_geo_recorder->Update(ids);
_push_sparse(keys, values, num);
return 0;
}
int32_t MemorySparseGeoTable::initialize() {
if (!_geo_recorder) {
auto trainers = _config.common().trainer_num();
_geo_recorder = std::make_shared<GeoRecorder>(trainers);
}
_dim = _config.common().dims()[0];
_shards_task_pool.resize(_task_pool_size);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
}
_local_shards.reset(new shard_type[_task_pool_size]);
return 0;
}
int32_t MemorySparseGeoTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num);
size_t num = pull_value.numel_;
for (size_t i = 0; i < num; ++i) {
int shard_id = pull_value.feasigns_[i] % shard_num;
task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
}
for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &task_keys, pull_values]() -> int {
auto& local_shard = _local_shards[shard_id];
auto& keys = task_keys[shard_id];
for (size_t i = 0; i < keys.size(); i++) {
uint64_t key = keys[i].first;
auto offset = keys[i].second;
float* select_data = pull_values + _dim * offset;
auto itr = local_shard.find(key);
if (itr == local_shard.end()) {
// ++missed_keys;
auto& feature_value = local_shard[key];
feature_value.resize(_dim);
memset(feature_value.data(), 0, sizeof(float) * _dim);
VLOG(0) << "MemorySparseGeoTable pull_sparse key not found!!! "
<< key;
itr = local_shard.find(key);
}
memcpy(select_data, itr.value().data(), _dim * sizeof(float));
VLOG(5) << "DEBUG MemorySparseGeoTable::pull_sparse key: " << key
<< " select_data[0] " << select_data[0]
<< " value[0]: " << itr.value().data()[0];
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
int32_t MemorySparseGeoTable::_push_sparse(const uint64_t* keys,
const float* values, size_t num) {
auto shard_num = _task_pool_size;
std::vector<std::future<int>> tasks(shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(shard_num);
for (size_t i = 0; i < num; ++i) {
int shard_id = keys[i] % shard_num;
task_keys[shard_id].push_back({keys[i], i});
}
for (size_t shard_id = 0; shard_id < shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, values, &task_keys]() -> int {
auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id];
auto blas = GetBlas<float>();
for (int i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first;
uint64_t push_data_idx = keys[i].second;
const float* update_data = values + push_data_idx * _dim;
auto itr = local_shard.find(key);
if (itr == local_shard.end()) {
VLOG(0) << "sparse geo table push not found key!!! " << key;
auto& feature_value = local_shard[key];
feature_value.resize(_dim);
memset(feature_value.data(), 0, sizeof(float) * _dim);
itr = local_shard.find(key);
}
auto& feature_value = itr.value();
float* value_data = feature_value.data();
VLOG(5) << "DEBUG MemorySparseGeoTable::_push_sparse before key: "
<< key << " update_data[0] " << update_data[0]
<< " value[0]: " << value_data[0];
blas.VADD(_dim, update_data, value_data, value_data);
VLOG(5) << "DEBUG MemorySparseGeoTable::_push_sparse after key: "
<< key << " value[0]: " << value_data[0];
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <assert.h>
// #include <pthread.h>
#include <stdint.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/distributed/ps/table/depends/geo_recorder.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
class GeoRecorder;
class MemorySparseGeoTable : public SparseTable {
public:
typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
MemorySparseGeoTable() { _geo_recorder = nullptr; }
virtual ~MemorySparseGeoTable() {}
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t load(const std::string& path, const std::string& param) {
return 0;
}
virtual int32_t save(const std::string& path, const std::string& param) {
return 0;
}
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string& param) { return 0; }
virtual void clear() { return; }
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
int32_t push_sparse_param(const uint64_t* keys, const float* values,
size_t num);
// TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse
int32_t pull_geo_param(const uint32_t trainer_id, std::vector<float>* values,
std::vector<uint64_t>* keys);
int32_t push_sparse(const uint64_t* keys, const float* values,
size_t num) override;
int32_t _push_sparse(const uint64_t* keys, const float* values, size_t num);
// int32_t _pull_sparse(float* pull_values, const PullSparseValue&
// pull_value);
private:
std::shared_ptr<GeoRecorder> _geo_recorder;
const int _task_pool_size = 10;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::unique_ptr<shard_type[]> _local_shards;
int _dim;
};
} // namespace distributed
} // namespace paddle
......@@ -20,6 +20,7 @@
#include "paddle/fluid/distributed/ps/table/common_dense_table.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/common_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/sparse_geo_table.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h"
......@@ -43,6 +44,7 @@ REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable);
REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule);
......
......@@ -35,3 +35,6 @@ cc_test(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} boost ta
set_source_files_properties(memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS ${COMMON_DEPS} boost table)
set_source_files_properties(memory_geo_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(memory_sparse_geo_table_test SRCS memory_geo_table_test.cc DEPS ${COMMON_DEPS} boost table)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
// MemorySparseGeoTable
TEST(MemorySparseGeoTable, SSUM) {
int emb_dim = 10;
int trainers = 2;
TableParameter table_config;
table_config.set_table_class("MemorySparseGeoTable");
FsClientParameter fs_config;
Table *table = new MemorySparseGeoTable();
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CommMergeAccessor");
accessor_config->set_fea_dim(10);
CommonAccessorParameter *common_config = table_config.mutable_common();
common_config->set_name("sum");
common_config->set_table_name("ssum_test_table");
common_config->set_trainer_num(trainers);
common_config->add_params("Param");
common_config->add_dims(emb_dim);
common_config->add_initializers("fill_constant&1.0");
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// test push_sparse_param, and create params
std::vector<uint64_t> init_keys = {0, 1, 2, 3, 4};
std::vector<uint32_t> init_fres = {1, 1, 1, 1, 1};
std::vector<float> init_values;
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
init_values.push_back(0.0);
}
table->push_sparse_param(init_keys.data(), init_values.data(),
init_keys.size());
std::vector<float> pull_values(init_values.size());
auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(pull_values.data(), value);
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5);
}
std::vector<std::vector<uint64_t>> trainer_keys;
std::vector<std::vector<float>> trainer_values;
trainer_keys.resize(trainers);
trainer_values.resize(trainers);
float start = 0.0;
for (int i = 0; i < trainers; i++) {
trainer_keys[i] = init_keys;
for (size_t j = 0; j < trainer_keys[i].size(); j++) {
auto id = trainer_keys[i][j];
for (int k = 0; k < emb_dim; k++) {
trainer_values[i].push_back(start);
pull_values[id * emb_dim + k] += start;
start += 0.1;
}
}
}
std::shared_ptr<::ThreadPool> pool_ =
std::make_shared<::ThreadPool>(trainers);
std::vector<std::future<void>> task_status;
for (int i = 0; i < trainers; i++) {
auto &push_keys = trainer_keys[i];
auto &push_values = trainer_values[i];
auto task = [table, &push_keys, &push_values] {
table->push_sparse(push_keys.data(), push_values.data(),
push_keys.size());
};
task_status.push_back(pool_->enqueue(std::move(task)));
}
for (auto &status : task_status) {
status.wait();
}
std::vector<std::vector<uint64_t>> geo_pull_ids;
std::vector<std::vector<float>> geo_pull_values;
geo_pull_ids.resize(trainers);
geo_pull_values.resize(trainers);
for (int i = 0; i < trainers; i++) {
table->pull_geo_param(i, &geo_pull_values[i], &geo_pull_ids[i]);
ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim);
for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) {
auto id = geo_pull_ids[i][j];
for (int k = 0; k < emb_dim; k++) {
ASSERT_TRUE(abs(geo_pull_values[i][j * emb_dim + k] -
pull_values[id * emb_dim + k]) < 1e-5);
}
}
}
}
} // namespace distributed
} // namespace paddle
......@@ -943,7 +943,7 @@ class TheOnePSRuntime(RuntimeBase):
ctx.origin_varnames()[0]]
if self.compiled_strategy.is_geo_mode():
table.table_class = "SparseGeoTable"
table.table_class = "MemorySparseGeoTable"
else:
all_table_proto = self.context[
"user_defined_strategy"].sparse_table_configs
......@@ -1306,6 +1306,7 @@ class TheOnePSRuntime(RuntimeBase):
is_dense=True,
split_dense_table=self.role_maker._is_heter_parameter_server_mode,
use_origin_program=True)
# TODO(zhaocaibei123): for GEO: should call GeoCommunicator::RecvDense
self._communicator.pull_dense(denses)
generate_vars = self.context[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册