未验证 提交 f1201482 编写于 作者: Z zhaocaibei123 提交者: GitHub

pscore performance optimization (#38582)

上级 050fd168
...@@ -19,7 +19,7 @@ IF((NOT DEFINED LIBMCT_VER) OR (NOT DEFINED LIBMCT_URL)) ...@@ -19,7 +19,7 @@ IF((NOT DEFINED LIBMCT_VER) OR (NOT DEFINED LIBMCT_URL))
MESSAGE(STATUS "use pre defined download url") MESSAGE(STATUS "use pre defined download url")
SET(LIBMCT_VER "0.1.0" CACHE STRING "" FORCE) SET(LIBMCT_VER "0.1.0" CACHE STRING "" FORCE)
SET(LIBMCT_NAME "libmct" CACHE STRING "" FORCE) SET(LIBMCT_NAME "libmct" CACHE STRING "" FORCE)
SET(LIBMCT_URL "https://pslib.bj.bcebos.com/libmct.tar.gz" CACHE STRING "" FORCE) SET(LIBMCT_URL "https://pslib.bj.bcebos.com/libmct/libmct.tar.gz" CACHE STRING "" FORCE)
ENDIF() ENDIF()
MESSAGE(STATUS "LIBMCT_NAME: ${LIBMCT_NAME}, LIBMCT_URL: ${LIBMCT_URL}") MESSAGE(STATUS "LIBMCT_NAME: ${LIBMCT_NAME}, LIBMCT_URL: ${LIBMCT_URL}")
SET(LIBMCT_PREFIX_DIR "${THIRD_PARTY_PATH}/libmct") SET(LIBMCT_PREFIX_DIR "${THIRD_PARTY_PATH}/libmct")
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdlib.h>  // posix_memalign, free

#include <algorithm>
#include <cstddef>
#include <new>
#include <utility>

#include <glog/logging.h>
namespace paddle {
namespace distributed {
// Fast allocation and deallocation of objects of a single type T.
//
// Storage is obtained from the system one chunk at a time; every chunk holds
// `chunk_size` fixed-size nodes. Nodes that are not currently holding a T are
// threaded onto an intrusive singly linked free list (the link pointer lives
// in the node's own storage), so acquire()/release() are O(1) pointer swaps.
// The most recently released node is the first one to be reused.
//
// Contract notes:
//  * The allocator performs no locking of its own.
//  * The destructor only returns the chunks to the system. It does NOT run
//    ~T() on objects that are still acquired; callers must release() every
//    object first if T's destructor matters.
//  * The allocator is neither copyable nor copy-assignable, because a copy
//    would share (and later double free) the same chunks.
template <class T>
class ChunkAllocator {
 public:
  // chunk_size: number of T slots allocated per chunk. A value of 0 could
  // never satisfy an acquire(), so it is treated as 1.
  explicit ChunkAllocator(size_t chunk_size = 64)
      : _chunk_size(chunk_size == 0 ? 1 : chunk_size) {
    // A node must be exactly as large as the bigger of a free-list link and
    // a T; this is a compile-time property, so it is checked at compile time.
    static_assert(
        sizeof(Node) ==
            (sizeof(T) > sizeof(void*) ? sizeof(T) : sizeof(void*)),
        "ChunkAllocator: Node must be exactly max(sizeof(void*), sizeof(T))");
  }

  ChunkAllocator(const ChunkAllocator&) = delete;
  ChunkAllocator& operator=(const ChunkAllocator&) = delete;

  // Frees every chunk. Objects that were never release()d are not destroyed.
  ~ChunkAllocator() {
    while (_chunks != nullptr) {
      Chunk* doomed = _chunks;
      _chunks = doomed->next;
      free(doomed);
    }
  }

  // Constructs a T in a free slot, forwarding `args` to T's constructor, and
  // returns a pointer to it. Allocates a new chunk when no slot is free.
  // Throws std::bad_alloc if the chunk cannot be allocated; propagates any
  // exception thrown by T's constructor, leaving the allocator unchanged.
  template <class... ARGS>
  T* acquire(ARGS&&... args) {
    if (_free_nodes == nullptr) {
      create_new_chunk();
    }
    Node* node = _free_nodes;
    // Constructing T overwrites the link stored in the node, so save it.
    Node* following = node->next;
    T* object = nullptr;
    try {
      object = new (static_cast<void*>(node)) T(std::forward<ARGS>(args)...);
    } catch (...) {
      // Put the link back so the slot stays on the free list.
      node->next = following;
      throw;
    }
    _free_nodes = following;
    ++_counter;
    return object;
  }

  // Destroys *x and makes its slot available for reuse. `x` must have been
  // returned by acquire() on this allocator; a null pointer is ignored.
  void release(T* x) {
    if (x == nullptr) {
      return;
    }
    x->~T();
    Node* node = static_cast<Node*>(static_cast<void*>(x));
    node->next = _free_nodes;
    _free_nodes = node;
    --_counter;
  }

  // Number of objects currently acquired and not yet released.
  size_t size() const { return _counter; }

 private:
  // One slot: either a link in the free list or raw storage for a T.
  // The alignment request sits on the byte array (natural alignment 1), so it
  // is valid for every T, including types less aligned than a pointer.
  struct Node {
    union {
      Node* next;
      alignas(T) char data[sizeof(T)];
    };
  };

  // Header of one allocation, followed in memory by `_chunk_size` nodes.
  struct Chunk {
    Chunk* next;
    Node nodes[];
  };

  size_t _chunk_size;           // how many elements in one chunk
  Chunk* _chunks = nullptr;     // a list of every chunk allocated so far
  Node* _free_nodes = nullptr;  // a list of slots not holding a T
  size_t _counter = 0;          // how many elements are acquired

  // Allocates one more chunk and pushes all of its nodes onto the free list.
  // Throws std::bad_alloc when the system allocation fails.
  void create_new_chunk() {
    void* raw = nullptr;
    // posix_memalign needs a power-of-two multiple of sizeof(void*).
    const size_t alignment = std::max<size_t>(sizeof(void*), alignof(Chunk));
    const size_t bytes = sizeof(Chunk) + sizeof(Node) * _chunk_size;
    if (posix_memalign(&raw, alignment, bytes) != 0 || raw == nullptr) {
      throw std::bad_alloc();
    }
    Chunk* chunk = static_cast<Chunk*>(raw);
    chunk->next = _chunks;
    _chunks = chunk;
    for (size_t i = 0; i < _chunk_size; ++i) {
      Node* node = &chunk->nodes[i];
      node->next = _free_nodes;
      _free_nodes = node;
    }
  }
};
} // namespace distributed
} // namespace paddle
...@@ -460,25 +460,7 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -460,25 +460,7 @@ void FleetWrapper::PushSparseFromTensorAsync(
clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0]; clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
CHECK(clk_size == batch_size || clk_size == 1); CHECK(clk_size == batch_size || clk_size == 1);
std::vector<float> g; CHECK(outputs->size() == inputs->size());
for (framework::LoDTensor* g_tensor : *outputs) {
float* g_ori = g_tensor->data<float>();
// no cvm
if (batch_size_consist) { // TODO(zhaocaibei123): add config
// scale_sparse_gradient_with_batch_size_
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g_ori, g_tensor->numel() / fea_dim, fea_dim);
g_mat.rightCols(fea_dim) *= batch_size;
}
size_t origin = g.size();
size_t add = g_tensor->numel();
g.resize(origin + add);
memcpy(g.data() + origin, g_tensor->data<float>(), add * sizeof(float));
}
std::vector<uint64_t> push_keys; std::vector<uint64_t> push_keys;
push_keys.reserve(MAX_FEASIGN_NUM / 100); push_keys.reserve(MAX_FEASIGN_NUM / 100);
std::vector<std::vector<float>> push_values; std::vector<std::vector<float>> push_values;
...@@ -495,9 +477,21 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -495,9 +477,21 @@ void FleetWrapper::PushSparseFromTensorAsync(
const int64_t* clk_tensor = clks->data<int64_t>(); const int64_t* clk_tensor = clks->data<int64_t>();
for (size_t index = 0; index < inputs->size(); ++index) { for (size_t index = 0; index < inputs->size(); ++index) {
framework::LoDTensor* g_tensor = outputs->at(index);
float* g = g_tensor->data<float>();
// no cvm
if (batch_size_consist) { // TODO(zhaocaibei123): add config
// scale_sparse_gradient_with_batch_size_
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g, g_tensor->numel() / fea_dim, fea_dim);
g_mat.rightCols(fea_dim) *= batch_size;
}
const framework::LoDTensor* tensor = inputs->at(index); const framework::LoDTensor* tensor = inputs->at(index);
const int64_t* ids = tensor->data<int64_t>(); const int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel(); size_t len = tensor->numel();
output_len = 0;
if (tensor->lod().size() > 0) { if (tensor->lod().size() > 0) {
for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) { for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
...@@ -519,7 +513,7 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -519,7 +513,7 @@ void FleetWrapper::PushSparseFromTensorAsync(
float* data = push_values.back().data() + 3; float* data = push_values.back().data() + 3;
memcpy(data, g.data() + output_len, sizeof(float) * fea_dim); memcpy(data, g + output_len, sizeof(float) * fea_dim);
++input_idx; ++input_idx;
} }
...@@ -542,14 +536,13 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -542,14 +536,13 @@ void FleetWrapper::PushSparseFromTensorAsync(
float* data = push_values.back().data() + 3; float* data = push_values.back().data() + 3;
memcpy(data, g.data() + output_len, sizeof(float) * fea_dim); memcpy(data, g + output_len, sizeof(float) * fea_dim);
++input_idx; ++input_idx;
} }
} }
CHECK(output_len == g_tensor->numel());
} }
VLOG(1) << "output_len: " << output_len << " g.size(): " << g.size();
CHECK(output_len == g.size());
std::vector<float*> push_g_vec(input_idx, nullptr); std::vector<float*> push_g_vec(input_idx, nullptr);
......
...@@ -210,6 +210,23 @@ int32_t BrpcPsClient::initialize() { ...@@ -210,6 +210,23 @@ int32_t BrpcPsClient::initialize() {
} }
} }
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_client_pull_dense");
profiler.register_profiler("pserver_client_pull_sparse");
profiler.register_profiler("pserver_client_pull_sparse_local");
profiler.register_profiler("pserver_client_push_sparse");
profiler.register_profiler("pserver_client_push_sparse_parse");
profiler.register_profiler("client_push_sparse_put");
profiler.register_profiler("pserver_client_push_sparse");
profiler.register_profiler("pserver_client_push_sparse_merge");
profiler.register_profiler("pserver_client_push_sparse_rpc");
profiler.register_profiler("pserver_client_push_dense");
profiler.register_profiler("pserver_client_push_dense_parse");
profiler.register_profiler("push_dense_put");
profiler.register_profiler("pserver_client_push_dense_merge");
profiler.register_profiler("pserver_client_push_dense_rpc");
profiler.register_profiler("pserver_client_push_dense_send");
_running = true; _running = true;
_flushing = false; _flushing = false;
// 启动异步push线程 // 启动异步push线程
...@@ -588,6 +605,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param( ...@@ -588,6 +605,7 @@ std::future<int32_t> BrpcPsClient::push_sparse_param(
std::future<int32_t> BrpcPsClient::pull_dense(Region *regions, std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
auto *accessor = table_accessor(table_id); auto *accessor = table_accessor(table_id);
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
uint32_t num_per_shard = uint32_t num_per_shard =
...@@ -643,6 +661,7 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions, ...@@ -643,6 +661,7 @@ std::future<int32_t> BrpcPsClient::pull_dense(Region *regions,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise); closure->add_promise(promise);
std::future<int> fut = promise->get_future(); std::future<int> fut = promise->get_future();
...@@ -865,6 +884,9 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values, ...@@ -865,6 +884,9 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
size_t table_id, size_t table_id,
const uint64_t *keys, size_t num, const uint64_t *keys, size_t num,
bool is_training) { bool is_training) {
auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse");
auto local_timer =
std::make_shared<CostTimer>("pserver_client_pull_sparse_local");
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
auto shard_sorted_kvs = std::make_shared< auto shard_sorted_kvs = std::make_shared<
...@@ -925,7 +947,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values, ...@@ -925,7 +947,7 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise); closure->add_promise(promise);
std::future<int> fut = promise->get_future(); std::future<int> fut = promise->get_future();
...@@ -1110,8 +1132,8 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id, ...@@ -1110,8 +1132,8 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num) { size_t num) {
auto push_timer = auto push_timer = std::make_shared<CostTimer>("pserver_client_push_sparse");
std::make_shared<CostTimer>("pserver_client_push_sparse_parse"); CostTimer parse_timer("pserver_client_push_sparse_parse");
int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) { while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) {
// LOG(INFO) << "push_sparse Waiting for async_call_num comsume, task_num:" // LOG(INFO) << "push_sparse Waiting for async_call_num comsume, task_num:"
...@@ -1121,6 +1143,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id, ...@@ -1121,6 +1143,7 @@ std::future<int32_t> BrpcPsClient::push_sparse(size_t table_id,
// push_sparse_async_num = _push_sparse_task_queue_map[table_id]->size(); // push_sparse_async_num = _push_sparse_task_queue_map[table_id]->size();
push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size(); push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
} }
auto put_timer = std::make_shared<CostTimer>("client_push_sparse_put");
thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>> thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>>
shard_sorted_kv_list; shard_sorted_kv_list;
auto *accessor = table_accessor(table_id); auto *accessor = table_accessor(table_id);
...@@ -1250,14 +1273,14 @@ void BrpcPsClient::push_sparse_task_consume() { ...@@ -1250,14 +1273,14 @@ void BrpcPsClient::push_sparse_task_consume() {
for_each(task_list.begin() + 1, task_list.end(), for_each(task_list.begin() + 1, task_list.end(),
[&request_kv_num, request_call_num, [&request_kv_num, request_call_num,
closure](std::shared_ptr<SparseAsyncTask> &task) { closure](std::shared_ptr<SparseAsyncTask> &task) {
// closure->add_timer(task->timer()); closure->add_timer(task->timer());
closure->add_promise(task->promise()); closure->add_promise(task->promise());
}); });
// CostTimer merge_timer("pserver_client_push_sparse_merge"); CostTimer merge_timer("pserver_client_push_sparse_merge");
// auto rpc_timer = auto rpc_timer =
// std::make_shared<CostTimer>("pserver_client_push_sparse_rpc"); std::make_shared<CostTimer>("pserver_client_push_sparse_rpc");
// closure->add_timer(rpc_timer); closure->add_timer(rpc_timer);
std::vector<std::future<int>> merge_status(request_call_num); std::vector<std::future<int>> merge_status(request_call_num);
for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
...@@ -1295,6 +1318,7 @@ void BrpcPsClient::push_sparse_task_consume() { ...@@ -1295,6 +1318,7 @@ void BrpcPsClient::push_sparse_task_consume() {
std::vector<std::future<int>>().swap(merge_status); std::vector<std::future<int>>().swap(merge_status);
} }
} }
timeline.Pause();
auto wait_ms = auto wait_ms =
FLAGS_pserver_async_push_sparse_interval_ms - (timeline.ElapsedMS()); FLAGS_pserver_async_push_sparse_interval_ms - (timeline.ElapsedMS());
if (wait_ms > 0) { if (wait_ms > 0) {
...@@ -1464,10 +1488,12 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions, ...@@ -1464,10 +1488,12 @@ std::future<int32_t> BrpcPsClient::push_dense(const Region *regions,
usleep(5000); // 5ms usleep(5000); // 5ms
push_dense_async_num = _push_dense_task_queue_map[table_id]->Size(); push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
} }
auto push_dense_timer = std::make_shared<CostTimer>("push_dense_put");
// auto dense_data = _dense_matrix_obj_pool.get(); // auto dense_data = _dense_matrix_obj_pool.get();
auto dense_data = std::make_shared<std::vector<float>>(); auto dense_data = std::make_shared<std::vector<float>>();
auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer); auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer);
size_t request_call_num = _server_channels.size(); size_t request_call_num = _server_channels.size();
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num); dense_dim_per_shard(accessor->fea_dim(), request_call_num);
...@@ -1567,6 +1593,7 @@ void BrpcPsClient::push_dense_task_consume() { ...@@ -1567,6 +1593,7 @@ void BrpcPsClient::push_dense_task_consume() {
<< total_send_data[total_send_data_size - 2] << total_send_data[total_send_data_size - 2]
<< total_send_data[0] << " total_send_data[-1]" << total_send_data[0] << " total_send_data[-1]"
<< total_send_data[total_send_data_size - 1]; << total_send_data[total_send_data_size - 1];
if (scale_gradient && merge_count > 1) { if (scale_gradient && merge_count > 1) {
Eigen::Map<Eigen::MatrixXf> mat(total_send_data, 1, Eigen::Map<Eigen::MatrixXf> mat(total_send_data, 1,
total_send_data_size); total_send_data_size);
...@@ -1585,6 +1612,7 @@ void BrpcPsClient::push_dense_task_consume() { ...@@ -1585,6 +1612,7 @@ void BrpcPsClient::push_dense_task_consume() {
push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size, push_dense_raw_gradient(task_ptr, total_send_data, total_send_data_size,
closure); closure);
} }
timeline.Pause();
auto wait_ms = auto wait_ms =
FLAGS_pserver_async_push_dense_interval_ms - (timeline.ElapsedMS()); FLAGS_pserver_async_push_dense_interval_ms - (timeline.ElapsedMS());
if (wait_ms > 0) { if (wait_ms > 0) {
...@@ -1603,6 +1631,8 @@ void BrpcPsClient::push_dense_raw_gradient( ...@@ -1603,6 +1631,8 @@ void BrpcPsClient::push_dense_raw_gradient(
closure->add_timer(timer); closure->add_timer(timer);
uint32_t num_per_shard = uint32_t num_per_shard =
dense_dim_per_shard(accessor->fea_dim(), request_call_num); dense_dim_per_shard(accessor->fea_dim(), request_call_num);
auto send_timer =
std::make_shared<CostTimer>("pserver_client_push_dense_send");
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
closure->request(i)->set_table_id(task->table_id()); closure->request(i)->set_table_id(task->table_id());
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT #include <thread> // NOLINT
#include "butil/object_pool.h" #include "butil/object_pool.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/table/depends/sparse_utils.h" #include "paddle/fluid/distributed/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
...@@ -117,6 +118,11 @@ int32_t BrpcPsService::initialize() { ...@@ -117,6 +118,11 @@ int32_t BrpcPsService::initialize() {
_service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler; _service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler; _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::push_global_step; _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::push_global_step;
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense");
profiler.register_profiler("pserver_server_pull_sparse");
profiler.register_profiler("pserver_server_push_sparse");
// shard初始化,server启动后才可从env获取到server_list的shard信息 // shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info(); initialize_shard_info();
...@@ -190,6 +196,7 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, ...@@ -190,6 +196,7 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
"PsRequestMessage.datas is requeired at least 1 for num of dense"); "PsRequestMessage.datas is requeired at least 1 for num of dense");
return 0; return 0;
} }
CostTimer timer("pserver_server_pull_dense");
uint32_t num = *(const uint32_t *)request.params(0).c_str(); uint32_t num = *(const uint32_t *)request.params(0).c_str();
if (num < 0) { if (num < 0) {
set_response_code(response, -1, set_response_code(response, -1,
...@@ -246,6 +253,7 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, ...@@ -246,6 +253,7 @@ int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
CostTimer timer("pserver_server_push_dense");
/* /*
Push Content: Push Content:
|--num--|---valuesData---| |--num--|---valuesData---|
...@@ -356,6 +364,7 @@ int32_t BrpcPsService::pull_sparse(Table *table, ...@@ -356,6 +364,7 @@ int32_t BrpcPsService::pull_sparse(Table *table,
return 0; return 0;
} }
CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); uint32_t num = *(uint32_t *)(request.params(0).c_str());
auto dim = table->value_accesor()->select_dim(); auto dim = table->value_accesor()->select_dim();
...@@ -396,6 +405,7 @@ int32_t BrpcPsService::push_sparse(Table *table, ...@@ -396,6 +405,7 @@ int32_t BrpcPsService::push_sparse(Table *table,
"least 1 for num of sparse_key"); "least 1 for num of sparse_key");
return 0; return 0;
} }
CostTimer timer("pserver_server_push_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); uint32_t num = *(uint32_t *)(request.params(0).c_str());
/* /*
Push Content: Push Content:
......
...@@ -16,6 +16,11 @@ set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DIS ...@@ -16,6 +16,11 @@ set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DIS
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/")
include_directories(${PADDLE_LIB_THIRD_PARTY_PATH}libmct/src/extern_libmct/libmct/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(EXTERN_DEP "") set(EXTERN_DEP "")
if(WITH_HETERPS) if(WITH_HETERPS)
set(TABLE_SRC common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc) set(TABLE_SRC common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc)
...@@ -43,3 +48,5 @@ cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_pro ...@@ -43,3 +48,5 @@ cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_pro
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table) cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)
cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
target_link_libraries(table -fopenmp)
...@@ -57,7 +57,7 @@ class CommonDenseTable : public DenseTable { ...@@ -57,7 +57,7 @@ class CommonDenseTable : public DenseTable {
int32_t _push_dense(const float* values, size_t num); int32_t _push_dense(const float* values, size_t num);
private: private:
const int task_pool_size_ = 1; const int task_pool_size_ = 10;
bool sync = true; bool sync = true;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
int param_dim_ = 0; int param_dim_ = 0;
......
...@@ -99,6 +99,7 @@ class DSGD : public DenseOptimizer { ...@@ -99,6 +99,7 @@ class DSGD : public DenseOptimizer {
}; };
// adam optimizer for dense tensor // adam optimizer for dense tensor
// TODO(zhaocaibei123): add CHECK(common_dense_table.task_pool_size_) == 1
class DAdam : public DenseOptimizer { class DAdam : public DenseOptimizer {
public: public:
explicit DAdam(const CommonAccessorParameter& accessor, explicit DAdam(const CommonAccessorParameter& accessor,
...@@ -131,6 +132,8 @@ class DAdam : public DenseOptimizer { ...@@ -131,6 +132,8 @@ class DAdam : public DenseOptimizer {
epsilon = 1.0e-8; epsilon = 1.0e-8;
} }
// make sure common_dense_table.task_pool_size_ == 1;
// otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication
void update(const float* update_values, size_t num, int begin, void update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
...@@ -221,45 +224,35 @@ class DAdamD2Sum : public DenseOptimizer { ...@@ -221,45 +224,35 @@ class DAdamD2Sum : public DenseOptimizer {
void update(const float* update_values, size_t num, int begin, void update(const float* update_values, size_t num, int begin,
int end) override { int end) override {
auto update_numel = end - begin; auto update_numel = end - begin;
std::vector<float> grad, grad2, scale; Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(ada_g2sum + begin, 1,
grad.resize(update_numel); update_numel);
grad2.resize(update_numel);
scale.resize(update_numel); Eigen::Map<Eigen::MatrixXf> mat_ada_d2sum(ada_d2sum + begin, 1,
update_numel);
auto blas = GetBlas<float>(); Eigen::Map<Eigen::MatrixXf> mat_mom_velocity(mom_velocity + begin, 1,
// copy grad update_numel);
blas.VCOPY(update_numel, update_values + begin, grad.data()); Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel);
blas.VCOPY(update_numel, update_values + begin, grad2.data());
Eigen::Map<const Eigen::MatrixXf> mat_grad(update_values + begin, 1,
// d2sum update_numel);
blas.SCAL(update_numel, ada_decay_rate[0], ada_d2sum + begin);
ADD<float>(update_numel, ada_d2sum + begin, 1, ada_d2sum + begin); mat_ada_d2sum = (mat_ada_d2sum * ada_decay_rate[0]).array() + 1;
mat_ada_g2sum =
// g2sum (mat_ada_g2sum * ada_decay_rate[0]) + mat_grad.cwiseProduct(mat_grad);
blas.SCAL(update_numel, ada_decay_rate[0], ada_g2sum + begin);
blas.VSQUARE(update_numel, grad2.data(), grad2.data()); thread_local std::vector<float> scale_vec;
blas.VADD(update_numel, ada_g2sum + begin, grad2.data(), ada_g2sum + begin); scale_vec.resize(update_numel);
Eigen::Map<Eigen::MatrixXf> scale(scale_vec.data(), 1, update_numel);
// mom memcpy(scale_vec.data(), mat_ada_d2sum.data(),
blas.SCAL(update_numel, mom_decay_rate[0], mom_velocity + begin); sizeof(float) * update_numel);
blas.SCAL(update_numel, 1 - mom_decay_rate[0], grad.data());
blas.VADD(update_numel, mom_velocity + begin, grad.data(), scale = scale.array() * ada_epsilon[0];
mom_velocity + begin); scale = (mat_ada_d2sum + scale).cwiseQuotient(mat_ada_g2sum + scale);
scale = scale.cwiseSqrt();
// scale mat_mom_velocity =
float* scale_ = scale.data(); (mat_mom_velocity - mat_grad) * mom_decay_rate[0] + mat_grad;
blas.VDIV(update_numel, ada_g2sum + begin, ada_d2sum + begin, scale_);
ADD<float>(update_numel, scale_, ada_epsilon[0], scale_); mat_w -= learning_rate[0] * mat_mom_velocity.cwiseProduct(scale);
DIV<float>(update_numel, 1 + ada_epsilon[0], scale_, scale_);
SQRT<float>(update_numel, scale_, scale_);
blas.SCAL(update_numel, learning_rate[0], scale_);
// TODO(zhaocaibei123): check if there exists elementwise_multiply in blas
// TODO(zhaocaibei123): blas.VMUL
ELE_MUL<float>(update_numel, scale_, mom_velocity + begin, scale_);
blas.VSUB(update_numel, param + begin, scale_, param + begin);
} }
float* learning_rate; float* learning_rate;
......
...@@ -14,35 +14,11 @@ ...@@ -14,35 +14,11 @@
#pragma once #pragma once
#include <ThreadPool.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "butil/object_pool.h" #include <mct/hash-map.hpp>
#include "paddle/fluid/distributed/common/utils.h" #include "paddle/fluid/distributed/common/chunk_allocator.h"
#include "paddle/fluid/distributed/table/depends/initializers.h"
#include "paddle/fluid/distributed/thirdparty/round_robin.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -55,112 +31,169 @@ class FixedFeatureValue { ...@@ -55,112 +31,169 @@ class FixedFeatureValue {
public: public:
FixedFeatureValue() {} FixedFeatureValue() {}
~FixedFeatureValue() {} ~FixedFeatureValue() {}
float *data() { return data_.data(); } float* data() { return _data.data(); }
size_t size() { return data_.size(); } size_t size() { return _data.size(); }
void resize(size_t size) { data_.resize(size); } void resize(size_t size) { _data.resize(size); }
void shrink_to_fit() { data_.shrink_to_fit(); } void shrink_to_fit() { _data.shrink_to_fit(); }
private: private:
std::vector<float> data_; std::vector<float> _data;
}; };
class SparseTableShard { template <class KEY, class VALUE>
struct alignas(64) SparseTableShard {
public: public:
typedef typename robin_hood::unordered_map<uint64_t, FixedFeatureValue *> typedef typename mct::closed_hash_map<KEY, mct::Pointer, std::hash<KEY>>
map_type; map_type;
SparseTableShard() {} struct iterator {
~SparseTableShard() {} typename map_type::iterator it;
size_t bucket;
map_type* buckets;
friend bool operator==(const iterator& a, const iterator& b) {
return a.it == b.it;
}
friend bool operator!=(const iterator& a, const iterator& b) {
return a.it != b.it;
}
const KEY& key() const { return it->first; }
VALUE& value() const { return *(VALUE*)(void*)it->second; } // NOLINT
iterator& operator++() {
++it;
FixedFeatureValue *Init(const uint64_t &id) { while (it == buckets[bucket].end() &&
size_t hash = hasher_(id); bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) {
size_t bucket = compute_bucket(hash); it = buckets[++bucket].begin();
auto &table = values_[bucket]; }
FixedFeatureValue *value = nullptr; return *this;
value = butil::get_object<FixedFeatureValue>(); }
table[id] = value; iterator operator++(int) {
return value; iterator ret = *this;
++*this;
return ret;
}
};
struct local_iterator {
typename map_type::iterator it;
friend bool operator==(const local_iterator& a, const local_iterator& b) {
return a.it == b.it;
}
friend bool operator!=(const local_iterator& a, const local_iterator& b) {
return a.it != b.it;
}
const KEY& key() const { return it->first; }
VALUE& value() const { return *(VALUE*)(void*)it->second; } // NOLINT
local_iterator& operator++() {
++it;
return *this;
}
local_iterator operator++(int) { return {it++}; }
};
~SparseTableShard() { clear(); }
bool empty() { return _alloc.size() == 0; }
size_t size() { return _alloc.size(); }
void set_max_load_factor(float x) {
for (size_t bucket = 0; bucket < CTR_SPARSE_SHARD_BUCKET_NUM; bucket++) {
_buckets[bucket].max_load_factor(x);
}
} }
size_t bucket_count() { return CTR_SPARSE_SHARD_BUCKET_NUM; }
// dont judge if (has(id)) size_t bucket_size(size_t bucket) { return _buckets[bucket].size(); }
float *Get(const uint64_t &id) { void clear() {
size_t hash = hasher_(id); for (size_t bucket = 0; bucket < CTR_SPARSE_SHARD_BUCKET_NUM; bucket++) {
size_t bucket = compute_bucket(hash); map_type& data = _buckets[bucket];
auto &table = values_[bucket]; for (auto it = data.begin(); it != data.end(); ++it) {
_alloc.release((VALUE*)(void*)it->second); // NOLINT
// auto &value = table.at(id); }
// return value->data_.data(); data.clear();
auto res = table.find(id); }
FixedFeatureValue *value = res->second;
return value->data();
} }
iterator begin() {
// for load, to reset count, unseen_days auto it = _buckets[0].begin();
FixedFeatureValue *GetValue(const uint64_t &id) { size_t bucket = 0;
size_t hash = hasher_(id); while (it == _buckets[bucket].end() &&
size_t bucket = compute_bucket(hash); bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) {
it = _buckets[++bucket].begin();
auto &table = values_[bucket]; }
auto res = table.find(id); return {it, bucket, _buckets};
return res->second;
} }
iterator end() {
void erase(uint64_t feasign) { return {_buckets[CTR_SPARSE_SHARD_BUCKET_NUM - 1].end(),
size_t hash = hasher_(feasign); CTR_SPARSE_SHARD_BUCKET_NUM - 1, _buckets};
}
local_iterator begin(size_t bucket) { return {_buckets[bucket].begin()}; }
local_iterator end(size_t bucket) { return {_buckets[bucket].end()}; }
iterator find(const KEY& key) {
size_t hash = _hasher(key);
size_t bucket = compute_bucket(hash); size_t bucket = compute_bucket(hash);
auto &table = values_[bucket]; auto it = _buckets[bucket].find_with_hash(key, hash);
if (it == _buckets[bucket].end()) {
auto iter = table.find(feasign); return end();
if (iter != table.end()) {
butil::return_object(iter->second);
iter = table.erase(iter);
} }
return {it, bucket, _buckets};
}
// Map-style subscript: routes through emplace(key) with no constructor
// arguments and returns a reference to the value the resulting iterator
// points at.
VALUE& operator[](const KEY& key) {
  auto inserted = emplace(key);
  return inserted.first.value();
}
std::pair<iterator, bool> insert(const KEY& key, const VALUE& val) {
return emplace(key, val);
} }
// Rvalue overload of insert: hands `val` to emplace() as an rvalue and
// returns emplace()'s {iterator, bool} result unchanged.
std::pair<iterator, bool> insert(const KEY& key, VALUE&& val) {
  auto outcome = emplace(key, std::move(val));
  return outcome;
}
template <class... ARGS>
std::pair<iterator, bool> emplace(const KEY& key, ARGS&&... args) {
size_t hash = _hasher(key);
size_t bucket = compute_bucket(hash);
auto res = _buckets[bucket].insert_with_hash({key, NULL}, hash);
void clear() {} if (res.second) {
res.first->second = _alloc.acquire(std::forward<ARGS>(args)...);
}
size_t compute_bucket(size_t hash) { return {{res.first, bucket, _buckets}, res.second};
if (CTR_SPARSE_SHARD_BUCKET_NUM == 1) { }
return 0; iterator erase(iterator it) {
} else { _alloc.release((VALUE*)(void*)it.it->second); // NOLINT
return hash >> (sizeof(size_t) * 8 - CTR_SPARSE_SHARD_BUCKET_NUM_BITS); size_t bucket = it.bucket;
auto it2 = _buckets[bucket].erase(it.it);
while (it2 == _buckets[bucket].end() &&
bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) {
it2 = _buckets[++bucket].begin();
} }
return {it2, bucket, _buckets};
} }
void quick_erase(iterator it) {
map_type::iterator end() { _alloc.release((VALUE*)(void*)it.it->second); // NOLINT
return values_[CTR_SPARSE_SHARD_BUCKET_NUM - 1].end(); _buckets[it.bucket].quick_erase(it.it);
} }
local_iterator erase(size_t bucket, local_iterator it) {
map_type::iterator Find(uint64_t id) { _alloc.release((VALUE*)(void*)it.it->second); // NOLINT
size_t hash = hasher_(id); return {_buckets[bucket].erase(it.it)};
size_t bucket = compute_bucket(hash); }
auto &table = values_[bucket]; void quick_erase(size_t bucket, local_iterator it) {
_alloc.release((VALUE*)(void*)it.it->second); // NOLINT
auto got = table.find(id); _buckets[bucket].quick_erase(it.it);
if (got == table.end()) { }
return end(); size_t erase(const KEY& key) {
} else { auto it = find(key);
return got; if (it == end()) {
return 0;
} }
quick_erase(it);
return 1;
} }
size_t compute_bucket(size_t hash) {
private: if (CTR_SPARSE_SHARD_BUCKET_NUM == 1) {
bool Has(const uint64_t id) { return 0;
size_t hash = hasher_(id);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
auto got = table.find(id);
if (got == table.end()) {
return false;
} else { } else {
return true; return hash >> (sizeof(size_t) * 8 - CTR_SPARSE_SHARD_BUCKET_NUM_BITS);
} }
} }
public: private:
map_type values_[CTR_SPARSE_SHARD_BUCKET_NUM]; map_type _buckets[CTR_SPARSE_SHARD_BUCKET_NUM];
std::hash<uint64_t> hasher_; ChunkAllocator<VALUE> _alloc;
std::hash<KEY> _hasher;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -36,6 +36,7 @@ namespace distributed { ...@@ -36,6 +36,7 @@ namespace distributed {
class MemorySparseTable : public SparseTable { class MemorySparseTable : public SparseTable {
public: public:
typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
MemorySparseTable() {} MemorySparseTable() {}
virtual ~MemorySparseTable() {} virtual ~MemorySparseTable() {}
...@@ -59,6 +60,9 @@ class MemorySparseTable : public SparseTable { ...@@ -59,6 +60,9 @@ class MemorySparseTable : public SparseTable {
int32_t save_local_fs(const std::string& path, const std::string& param, int32_t save_local_fs(const std::string& path, const std::string& param,
const std::string& prefix); const std::string& prefix);
int64_t local_size();
int64_t local_mf_size();
virtual std::pair<int64_t, int64_t> print_table_stat(); virtual std::pair<int64_t, int64_t> print_table_stat();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
...@@ -80,12 +84,12 @@ class MemorySparseTable : public SparseTable { ...@@ -80,12 +84,12 @@ class MemorySparseTable : public SparseTable {
size_t num); size_t num);
protected: protected:
const int task_pool_size_ = 24; const int _task_pool_size = 24;
size_t avg_local_shard_num_; size_t _avg_local_shard_num;
size_t real_local_shard_num_; size_t _real_local_shard_num;
size_t sparse_table_shard_num_; size_t _sparse_table_shard_num;
std::vector<std::shared_ptr<::ThreadPool>> shards_task_pool_; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<SparseTableShard>> shard_values_; std::unique_ptr<shard_type[]> _local_shards;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -27,9 +27,6 @@ class Table; ...@@ -27,9 +27,6 @@ class Table;
TEST(CommonDenseTable, Adam) { TEST(CommonDenseTable, Adam) {
int fea_dim = 10; int fea_dim = 10;
int trainers = 2; int trainers = 2;
float beta1 = 0.9;
float beta2 = 0.999;
float epsilon = 1.0e-8;
TableParameter table_config; TableParameter table_config;
table_config.set_table_class("CommonDenseTable"); table_config.set_table_class("CommonDenseTable");
...@@ -39,27 +36,33 @@ TEST(CommonDenseTable, Adam) { ...@@ -39,27 +36,33 @@ TEST(CommonDenseTable, Adam) {
accessor_config->set_accessor_class("CommMergeAccessor"); accessor_config->set_accessor_class("CommMergeAccessor");
CommonAccessorParameter *common_config = table_config.mutable_common(); CommonAccessorParameter *common_config = table_config.mutable_common();
// set adam optimize config // set adam optimize config
common_config->set_name("adam"); common_config->set_name("adam_d2sum");
common_config->set_table_name("adam_test_table"); common_config->set_table_name("adam_test_table");
common_config->set_trainer_num(trainers); common_config->set_trainer_num(trainers);
common_config->add_params("Param"); common_config->add_params("Param");
common_config->add_dims(fea_dim); common_config->add_dims(fea_dim);
common_config->add_initializers("gaussian_random&0&0.0&1.0"); common_config->add_initializers("gaussian_random&0&0.0&1.0");
common_config->add_params("LearningRate"); common_config->add_params("D2Sum");
common_config->add_dims(1); common_config->add_dims(fea_dim);
common_config->add_initializers("fill_constant&1.0"); common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Moment1"); common_config->add_params("G2Sum");
common_config->add_dims(fea_dim); common_config->add_dims(fea_dim);
common_config->add_initializers("fill_constant&0.0"); common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Moment2"); common_config->add_params("Moment");
common_config->add_dims(fea_dim); common_config->add_dims(fea_dim);
common_config->add_initializers("fill_constant&0.0"); common_config->add_initializers("fill_constant&0.0");
common_config->add_params("Beta1Pow"); common_config->add_params("MomentDecayRate");
common_config->add_dims(1); common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0"); common_config->add_initializers("fill_constant&0.99");
common_config->add_params("Beta2Pow"); common_config->add_params("AdaDecayRate");
common_config->add_dims(1); common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0"); common_config->add_initializers("fill_constant&0.9999");
common_config->add_params("AdaEpsilon");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&1.0e-8");
common_config->add_params("LearningRate");
common_config->add_dims(1);
common_config->add_initializers("fill_constant&5e-6");
auto ret = table->initialize(table_config, fs_config); auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0); ASSERT_EQ(ret, 0);
...@@ -89,29 +92,30 @@ TEST(CommonDenseTable, Adam) { ...@@ -89,29 +92,30 @@ TEST(CommonDenseTable, Adam) {
pull_values.resize(fea_dim); pull_values.resize(fea_dim);
table->pull_dense(pull_values.data(), fea_dim); table->pull_dense(pull_values.data(), fea_dim);
std::vector<float> beta1_pow, beta2_pow, lr, mom1, mom2, param; float mom_rate = 0.99;
beta1_pow.push_back(beta1); float decay_rate = 0.9999;
beta2_pow.push_back(beta2); float epsilon = 1.0e-8;
lr.push_back(1.0); float lr = 5e-6;
std::vector<float> d2sum, g2sum, mom, param;
for (int i = 0; i < fea_dim; i++) { for (int i = 0; i < fea_dim; i++) {
mom1.push_back(0.0); mom.push_back(0.0);
mom2.push_back(0.0); d2sum.push_back(0.0);
g2sum.push_back(0.0);
param.push_back(init_values[i]); param.push_back(init_values[i]);
} }
for (int i = 0; i < trainers; i++) { for (int i = 0; i < trainers; i++) {
auto lr_ = lr[0] * sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);
for (int j = 0; j < fea_dim; j++) { for (int j = 0; j < fea_dim; j++) {
mom1[j] = beta1 * mom1[j] + (1 - beta1) * trainer_gradient_values[i][j]; d2sum[j] = d2sum[j] * decay_rate + 1;
mom2[j] = beta2 * mom2[j] + g2sum[j] = g2sum[j] * decay_rate +
(1 - beta2) * trainer_gradient_values[i][j] * trainer_gradient_values[i][j] * trainer_gradient_values[i][j];
trainer_gradient_values[i][j]; float scale = d2sum[j] * epsilon;
param[j] = scale = (scale + d2sum[j]) / (scale + g2sum[j]);
param[j] - scale = sqrt(scale);
lr_ * (mom1[j] / (sqrt(mom2[j]) + epsilon * sqrt(1 - beta2_pow[0]))); mom[j] = (mom[j] - trainer_gradient_values[i][j]) * mom_rate +
trainer_gradient_values[i][j];
param[j] = param[j] - lr * scale * mom[j];
} }
beta1_pow[0] *= beta1;
beta2_pow[0] *= beta2;
} }
for (int j = 0; j < fea_dim; j++) { for (int j = 0; j < fea_dim; j++) {
ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5); ASSERT_TRUE(abs(param[j] - pull_values[j]) < 1e-5);
......
...@@ -12,38 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,38 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <ThreadPool.h> #include "paddle/fluid/distributed/table/depends/feature_value.h"
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include <vector> #include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/distributed/table/depends/feature_value.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
TEST(BENCHMARK, LargeScaleKV) { TEST(BENCHMARK, LargeScaleKV) {
std::shared_ptr<SparseTableShard> shard = typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
std::make_shared<SparseTableShard>(); shard_type shard;
uint64_t key = 1; uint64_t key = 1;
auto itr = shard->Find(key); auto itr = shard.find(key);
ASSERT_TRUE(itr == shard->end()); ASSERT_TRUE(itr == shard.end());
std::vector<float> vec = {0.0, 0.1, 0.2, 0.3}; std::vector<float> vec = {0.0, 0.1, 0.2, 0.3};
auto* feature_value = shard->Init(key); auto& feature_value = shard[key];
feature_value->resize(vec.size()); feature_value.resize(vec.size());
memcpy(feature_value->data(), vec.data(), vec.size() * sizeof(float)); memcpy(feature_value.data(), vec.data(), vec.size() * sizeof(float));
itr = shard->Find(key); itr = shard.find(key);
ASSERT_TRUE(itr != shard->end()); ASSERT_TRUE(itr != shard.end());
feature_value = itr->second; feature_value = itr.value();
float* value_data = feature_value->data(); float* value_data = feature_value.data();
ASSERT_FLOAT_EQ(value_data[0], 0.0); ASSERT_FLOAT_EQ(value_data[0], 0.0);
ASSERT_FLOAT_EQ(value_data[1], 0.1); ASSERT_FLOAT_EQ(value_data[1], 0.1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册