Commit 9552cf55 authored by: linan17

Merge branch 'master' of ssh://icode.baidu.com:8235/baidu/feed-mlarch/paddle-trainer

Change-Id: If03dc2ea6e1a8b9bcedb9d5c9fa9dbcc44d41396
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/feed-mlarch/hopscotch-map@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
......
@@ -68,6 +68,16 @@ public:
    std::string data;  // sample data, possibly in compressed format
};
template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar, DataItem& x) {
return ar >> x.id >> x.data;
}
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar, const DataItem& x) {
return ar << x.id << x.data;
}
typedef std::shared_ptr<Pipeline<DataItem, SampleInstance>> SampleInstancePipe;
inline SampleInstancePipe make_sample_instance_channel() {
    return std::make_shared<Pipeline<DataItem, SampleInstance>>();
......
@@ -243,7 +243,7 @@ paddle::framework::Channel<DataItem> MultiThreadExecutor::run(
    for (auto& monitor : _monitors) {
        if (monitor->need_compute_result(epoch_id)) {
            monitor->compute_result();
            ENVLOG_WORKER_MASTER_NOTICE("[Monitor]%s, monitor:%s, result:%s",
                _train_exe_name.c_str(), monitor->get_name().c_str(), monitor->format_result().c_str());
            _trainer_context->monitor_ssm << _train_exe_name << ":" <<
                monitor->get_name() << ":" << monitor->format_result() << ",";
......
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
    // load trainer config
    auto trainer_context_ptr = std::make_shared<TrainerContext>();
trainer_context_ptr->cache_dict.reset(new SignCacheDict);
    trainer_context_ptr->trainer_config = YAML::LoadFile(FLAGS_feed_trainer_conf_path);
    // environment
......
@@ -16,6 +16,8 @@ namespace feed {
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    int ret = Process::initialize(context_ptr);
    auto& config = _context_ptr->trainer_config;
_is_dump_cache_model = config["dump_cache_model"].as<bool>(false);
_cache_load_converter = config["load_cache_converter"].as<std::string>("");
    _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
    if (config["executor"]) {
        _executors.resize(config["executor"].size());
@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
    return 0;
}
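The two new keys above are read with defaults of false and "" respectively. A minimal sketch of the matching trainer YAML entries, with illustrative values (the converter path is a made-up placeholder, not part of this change):

    dump_cache_model: true                        # enable cache-model dump for inference saves
    load_cache_converter: "./tool/cache_conv.sh"  # hypothetical pre-load conversion script
    startup_dump_inference_base: false            # dump an inference base immediately on startup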
// Update the CacheModel stored on each node
int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client();
auto* environment = _context_ptr->environment.get();
auto* epoch_accessor = _context_ptr->epoch_accessor.get();
if (!epoch_accessor->need_save_model(epoch_id, way)) {
return 0;
}
auto* ps_param = _context_ptr->pslib->get_param();
if (_is_dump_cache_model && way == ModelSaveWay::ModelSaveInferenceBase) {
auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
for (auto& param : table_param) {
if (param.type() != paddle::PS_SPARSE_TABLE) {
continue;
}
auto cache_model_path = fs->path_join(
model_dir, string::format_string("%03d_cache/", param.table_id()));
if (!fs->exists(cache_model_path)) {
continue;
}
auto& cache_dict = *(_context_ptr->cache_dict.get());
cache_dict.clear();
cache_dict.reserve(_cache_sign_max_num);
auto cache_file_list = fs->list(fs->path_join(cache_model_path, "part*"));
for (auto& cache_path : cache_file_list) {
auto cache_file = fs->open_read(cache_path, _cache_load_converter);
char *buffer = nullptr;
size_t buffer_size = 0;
ssize_t line_len = 0;
while ((line_len = getline(&buffer, &buffer_size, cache_file.get())) != -1) {
if (line_len <= 1) {
continue;
}
char* data_ptr = NULL;
cache_dict.append(strtoul(buffer, &data_ptr, 10));
}
if (buffer != nullptr) {
free(buffer);
}
}
break;
}
}
return 0;
}
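Note that this loader only consumes the leading decimal sign on each line: strtoul stops at the first non-digit, and the remainder of the line (left in data_ptr) is ignored here. A sketch of the assumed part-file line layout, with made-up signs and values:

    8431765190275143 0.031 0.124 ...
    9921004172635531 0.007 0.530 ...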
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
    auto fs = _context_ptr->file_system;
    auto* ps_client = _context_ptr->pslib->ps_client();
@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
    }
    timer.Pause();
    VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
    // save cache model; only inference saves need the cache_model
auto* ps_param = _context_ptr->pslib->get_param();
if (_is_dump_cache_model && (way == ModelSaveWay::ModelSaveInferenceBase ||
way == ModelSaveWay::ModelSaveInferenceDelta)) {
auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
for (auto& param : table_param) {
if (param.type() != paddle::PS_SPARSE_TABLE) {
continue;
}
double cache_threshold = 0.0;
auto status = ps_client->get_cache_threshold(param.table_id(), cache_threshold);
CHECK(status.get() == 0) << "CacheThreshold Get failed!";
status = ps_client->cache_shuffle(param.table_id(), model_dir, std::to_string((int)way),
std::to_string(cache_threshold));
CHECK(status.get() == 0) << "Cache Shuffler Failed";
status = ps_client->save_cache(param.table_id(), model_dir, std::to_string((int)way));
auto feature_size = status.get();
CHECK(feature_size >= 0) << "Cache Save Failed";
auto cache_model_path = fs->path_join(
model_dir, string::format_string("%03d_cache/sparse_cache.meta", param.table_id()));
auto cache_meta_file = fs->open_write(cache_model_path, "");
auto meta = string::format_string("file_prefix:part\npart_num:%d\nkey_num:%d\n",
param.sparse_table_cache_file_num(), feature_size);
CHECK(fwrite(meta.c_str(), meta.size(), 1, cache_meta_file.get()) == 1) << "Cache Meta Failed";
if (feature_size > _cache_sign_max_num) {
_cache_sign_max_num = feature_size;
}
}
}
    _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
    return all_ret;
}
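Given the format string above, the sparse_cache.meta written for a table would look like the following (part_num and key_num values are illustrative):

    file_prefix:part
    part_num:16
    key_num:12345678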
@@ -176,6 +256,9 @@ int LearnerProcess::run() {
    {
        wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
        environment->barrier(EnvironmentRole::WORKER);
update_cache_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
environment->barrier(EnvironmentRole::WORKER);
        if (epoch_accessor->is_last_epoch(epoch_id)) {
            wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
        } else {
......
@@ -22,9 +22,13 @@ protected:
    virtual int load_model(uint64_t epoch_id);
    // Synchronously save all models; is_force_dump: skip the dump-condition check and force the dump
    virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump = false);
virtual int update_cache_model(uint64_t epoch_id, ModelSaveWay way);
private:
    bool _is_dump_cache_model;          // whether to dump the cache model
    uint32_t _cache_sign_max_num = 0;   // maximum number of cache signs
    std::string _cache_load_converter;  // pre-load conversion script for the cache
    bool _startup_dump_inference_base;  // dump an inference base immediately on startup
    std::vector<std::shared_ptr<MultiThreadExecutor>> _executors;
};
......
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h" #include "paddle/fluid/train/custom_trainer/feed/shuffler/shuffler.h"
#include <bthread/butex.h>
namespace paddle {
namespace custom_trainer {
@@ -40,73 +41,195 @@ public:
        Shuffler::initialize(config, context_ptr);
        _max_concurrent_num = config["max_concurrent_num"].as<int>(4);  // maximum number of concurrent sends
        _max_package_size = config["max_package_size"].as<int>(1024);   // maximum package size: one send carries this many items
_shuffle_data_msg_type = config["shuffle_data_msg_type"].as<int>(3); // c2c msg type
_finish_msg_type = config["finish_msg_type"].as<int>(4); // c2c msg type
reset_channel();
auto binded = std::bind(&GlobalShuffler::get_client2client_msg, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
_trainer_context->pslib->ps_client()->registe_client2client_msg_handler(_shuffle_data_msg_type,
binded);
_trainer_context->pslib->ps_client()->registe_client2client_msg_handler(_finish_msg_type,
binded);
        return 0;
    }
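For reference, a sketch of the shuffler config block these reads correspond to; the four keys and their defaults come from the code above, while the surrounding nesting and the class key are assumptions about the trainer YAML layout:

    shuffler:
      class: GlobalShuffler
      max_concurrent_num: 4
      max_package_size: 1024
      shuffle_data_msg_type: 3
      finish_msg_type: 4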
    // All workers must call shuffle, and the shuffler can only run one shuffle task at a time
    virtual int shuffle(::paddle::framework::Channel<DataItem>& data_channel) {
        uint32_t send_count = 0;
        uint32_t package_size = _max_package_size;
        uint32_t concurrent_num = _max_concurrent_num;
        ::paddle::framework::Channel<DataItem> input_channel = ::paddle::framework::MakeChannel<DataItem>(data_channel);
        data_channel.swap(input_channel);
        set_channel(data_channel);
        auto* environment = _trainer_context->environment.get();
        auto worker_num = environment->node_num(EnvironmentRole::WORKER);
        std::vector<std::vector<std::future<int>>> waits(concurrent_num);
        std::vector<DataItem> send_buffer(package_size);
        std::vector<std::vector<DataItem>> send_buffer_worker(worker_num);
        int status = 0;  // >0: finished; =0: running; <0: failed
        while (status == 0) {
            // update throttling state:
            // while training is in progress, throttle the shuffle;
            // while in the wait state, shuffle at full speed
            if (_trainer_context->is_status(TrainerStatus::Training)) {
                concurrent_num = 1;
                package_size = _max_concurrent_num / 2;
            } else {
                package_size = _max_package_size;
                concurrent_num = _max_concurrent_num;
            }
            for (uint32_t current_wait_idx = 0; status == 0 && current_wait_idx < concurrent_num; ++current_wait_idx) {
                auto read_size = input_channel->Read(package_size, send_buffer.data());
                if (read_size == 0) {
                    status = 1;
                    break;
                }
                for (int i = 0; i < worker_num; ++i) {
                    send_buffer_worker[i].clear();  // reuse each per-worker bucket
                }
                // shard data by sign across workers
                for (int i = 0; i < read_size; ++i) {
                    auto worker_idx = _shuffle_key_func(send_buffer[i].id) % worker_num;
                    send_buffer_worker[worker_idx].push_back(std::move(send_buffer[i]));
                }
                // wait for the async sends previously issued on this slot
                for (auto& wait_s : waits[current_wait_idx]) {
                    if (wait_s.get() != 0) {
                        LOG(WARNING) << "fail to send shuffle data";
                        status = -1;
                        break;
                    }
                }
                if (status != 0) {
                    break;
                }
                waits[current_wait_idx].clear();
                for (int i = 0; i < worker_num; ++i) {
                    if (!send_buffer_worker[i].empty()) {
                        waits[current_wait_idx].push_back(send_shuffle_data(i, send_buffer_worker[i]));
                    }
                }
            }
        }
        for (auto& waits_s : waits) {
            for (auto& wait_s : waits_s) {
                if (wait_s.get() != 0) {
                    LOG(WARNING) << "fail to send shuffle data";
                    status = -1;
                }
            }
        }
        VLOG(5) << "start send finish, worker_num: " << worker_num;
        waits[0].clear();
        for (int i = 0; i < worker_num; ++i) {
            waits[0].push_back(send_finish(i));
        }
        VLOG(5) << "wait all finish";
        for (int i = 0; i < worker_num; ++i) {
            if (waits[0][i].get() != 0) {
                LOG(WARNING) << "fail to send finish " << i;
                status = -1;
            }
        }
        VLOG(5) << "finish shuffler, status: " << status;
        return status < 0 ? status : 0;
    }
private:
    /*
     Channel synchronization:
     1. Some c2c send_shuffle_data messages may arrive early, before the channel is set; they block in wait_channel.
     2. shuffle() calls set_channel, which first resets wait_num and then unlocks the channel mutex.
     3. Once finish requests from all workers have arrived, reset_channel runs first, then all of them return together.
    */
bool wait_channel() {
VLOG(5) << "wait_channel";
std::lock_guard<bthread::Mutex> lock(_channel_mutex);
return _out_channel != nullptr;
}
void reset_channel() {
VLOG(5) << "reset_channel";
_channel_mutex.lock();
if (_out_channel != nullptr) {
_out_channel->Close();
}
_out_channel = nullptr;
}
void reset_wait_num() {
_wait_num_mutex.lock();
_wait_num = _trainer_context->environment->node_num(EnvironmentRole::WORKER);
VLOG(5) << "reset_wait_num: " << _wait_num;
}
void set_channel(paddle::framework::Channel<DataItem>& channel) {
VLOG(5) << "set_channel";
        // reset wait_num before any node starts writing to the channel
CHECK(_out_channel == nullptr);
_out_channel = channel;
reset_wait_num();
_channel_mutex.unlock();
} }
    int32_t finish_write_channel() {
        int wait_num = --_wait_num;
        VLOG(5) << "finish_write_channel, wait_num: " << wait_num;
        // synchronize all workers: after all writes complete, reset the channel before the c2c_msg returns
        if (wait_num == 0) {
            reset_channel();
            _wait_num_mutex.unlock();
        } else {
            std::lock_guard<bthread::Mutex> lock(_wait_num_mutex);
        }
        return 0;
    }
int32_t write_to_channel(std::vector<DataItem>&& items) {
size_t items_size = items.size();
VLOG(5) << "write_to_channel, items_size: " << items_size;
return _out_channel->Write(std::move(items)) == items_size ? 0 : -1;
    }
int32_t get_client2client_msg(int msg_type, int from_client, const std::string& msg) {
// wait channel
if (!wait_channel()) {
LOG(FATAL) << "out_channel is null";
return -1;
}
VLOG(5) << "get c2c msg, type: " << msg_type << ", from_client: " << from_client << ", msg_size: " << msg.size();
if (msg_type == _shuffle_data_msg_type) {
paddle::framework::BinaryArchive ar;
ar.SetReadBuffer(const_cast<char*>(msg.data()), msg.size(), [](char*){});
std::vector<DataItem> items;
ar >> items;
return write_to_channel(std::move(items));
} else if (msg_type == _finish_msg_type) {
return finish_write_channel();
}
LOG(FATAL) << "no such msg type: " << msg_type;
return -1;
}
std::future<int32_t> send_shuffle_data(int to_client_id, std::vector<DataItem>& items) {
VLOG(5) << "send_shuffle_data, to_client_id: " << to_client_id << ", items_size: " << items.size();
paddle::framework::BinaryArchive ar;
ar << items;
return _trainer_context->pslib->ps_client()->send_client2client_msg(_shuffle_data_msg_type, to_client_id,
std::string(ar.Buffer(), ar.Length()));
}
std::future<int32_t> send_finish(int to_client_id) {
VLOG(5) << "send_finish, to_client_id: " << to_client_id;
static const std::string empty_str;
return _trainer_context->pslib->ps_client()->send_client2client_msg(_finish_msg_type, to_client_id, empty_str);
    }
private:
    uint32_t _max_package_size = 0;
    uint32_t _max_concurrent_num = 0;
int _shuffle_data_msg_type = 3;
int _finish_msg_type = 4;
bthread::Mutex _channel_mutex;
paddle::framework::Channel<DataItem> _out_channel = nullptr;
bthread::Mutex _wait_num_mutex;
std::atomic<int> _wait_num;
};
REGIST_CLASS(Shuffler, GlobalShuffler);
......
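On the receiving side, each worker drains the channel published by set_channel(); once every worker's finish message has arrived, reset_channel() closes it and Read() returns 0. A minimal consumer sketch under those Channel semantics (shuffled_channel and process_item are hypothetical names, not part of this change):

    std::vector<DataItem> batch(1024);
    size_t n = 0;
    while ((n = shuffled_channel->Read(batch.size(), batch.data())) > 0) {
        for (size_t i = 0; i < n; ++i) {
            process_item(std::move(batch[i]));  // hypothetical per-item handler
        }
    }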
@@ -3,6 +3,7 @@
#include <memory>
#include <vector>
#include <sstream>
#include <tsl/bhopscotch_map.h>
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h" #include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h" #include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
...@@ -35,15 +36,61 @@ enum class TrainerStatus { ...@@ -35,15 +36,61 @@ enum class TrainerStatus {
Saving = 1 // 模型存储状态 Saving = 1 // 模型存储状态
}; };
const uint32_t SignCacheMaxValueNum = 13;
struct SignCacheData {
SignCacheData() {
memset(cache_value, 0, sizeof(float) * SignCacheMaxValueNum);
}
uint32_t idx;
float cache_value[SignCacheMaxValueNum];
};
class SignCacheDict {
public:
    inline int32_t sign2index(uint64_t sign) {
        auto itr = _sign2data_map.find(sign);
        if (itr == _sign2data_map.end()) {
            return -1;
        }
        return itr->second.idx;
    }
    inline uint64_t index2sign(int32_t index) {
        if (index < 0 || static_cast<size_t>(index) >= _sign_list.size()) {
            return 0;
        }
return _sign_list[index];
}
inline void reserve(uint32_t size) {
_sign_list.reserve(size);
_sign2data_map.reserve(size);
}
inline void clear() {
_sign_list.clear();
_sign2data_map.clear();
}
inline void append(uint64_t sign) {
if (_sign2data_map.find(sign) != _sign2data_map.end()) {
return;
}
SignCacheData data;
data.idx = _sign_list.size();
_sign_list.push_back(sign);
_sign2data_map.emplace(sign, std::move(data));
}
inline SignCacheData* data(uint64_t sign) {
tsl::bhopscotch_pg_map<uint64_t, SignCacheData>::iterator itr = _sign2data_map.find(sign);
if (itr == _sign2data_map.end()) {
return nullptr;
}
return const_cast<SignCacheData*>(&(itr->second));
}
private:
std::vector<uint64_t> _sign_list;
tsl::bhopscotch_pg_map<uint64_t, SignCacheData> _sign2data_map;
};
class TrainerContext {
......
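A short usage sketch of the new SignCacheDict: append() assigns each distinct sign the next dense index, sign2index()/index2sign() translate between sign and index, and data() exposes the per-sign cache slot (the signs below are made up):

    SignCacheDict dict;
    dict.reserve(1000);
    dict.append(8431765190275143UL);    // assigned index 0
    dict.append(9921004172635531UL);    // assigned index 1
    dict.append(8431765190275143UL);    // duplicate sign, ignored
    int32_t idx = dict.sign2index(9921004172635531UL);   // == 1
    uint64_t sign = dict.index2sign(idx);                // == 9921004172635531
    SignCacheData* slot = dict.data(sign);               // idx plus cache_value[SignCacheMaxValueNum]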
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
TEST(Archive, DataItem) {
paddle::custom_trainer::feed::DataItem item;
paddle::custom_trainer::feed::DataItem item2;
item.id = "123";
item.data = "name";
paddle::framework::BinaryArchive ar;
ar << item;
ar >> item2;
ASSERT_EQ(item.id, item2.id);
ASSERT_EQ(item.data, item2.data);
item.id += "~";
item.data += "~";
ASSERT_NE(item.id, item2.id);
ASSERT_NE(item.data, item2.data);
}
\ No newline at end of file