提交 d4514949 编写于 作者: D dongdaxiang

remove local random engine in fleet with rand_r()

test=develop
上级 e82969ee
......@@ -349,7 +349,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
for (int64_t i = interval.first; i < interval.second; ++i) {
// if get ins id, can also use hash
// std::string ins_id = memory_data_[i].ins_id;
int64_t random_num = fleet_ptr->LocalRandomEngine()();
int64_t random_num = rand_r(&rand_seed);
int64_t node_id = random_num % trainer_num_;
send_vec[node_id].push_back(&((*memory_data_)[i]));
if (i % fleet_send_batch_size_ == 0 && i != 0) {
......
......@@ -232,6 +232,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
int thread_id_;
int thread_num_;
int trainer_num_;
uint32_t rand_seed;
std::vector<T>* memory_data_;
std::mutex* mutex_for_update_memory_data_;
// when read ins, we put ins from one channel to the other,
......
......@@ -250,7 +250,7 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length=" << msg.length();
auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
int64_t index = rand_r(&rand_seed) % thread_num_;
VLOG(3) << "ramdom index=" << index;
readers_[index]->PutInsToChannel(msg);
return 0;
......
......@@ -136,6 +136,7 @@ class DatasetImpl : public Dataset {
std::mutex mutex_for_pick_file_;
std::string fs_name_;
std::string fs_ugi_;
unsigned int rand_seed;
};
// use std::vector<MultiSlotType> as data type
......
......@@ -210,52 +210,20 @@ void FleetWrapper::PushDenseParamSync(
const ProgramDesc& program, const uint64_t table_id,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
paddle::framework::Scope scope;
auto& block = program.Block(0);
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope.Var(var->Name());
InitializeVariable(ptr, var->GetType());
} else {
auto* ptr = scope.Var(var->Name());
InitializeVariable(ptr, var->GetType());
}
}
auto place = platform::CPUPlace();
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
std::vector<int64_t> dim;
for (auto& var : block.AllVars()) {
if (var->Name() == t) {
dim = var->GetShape();
break;
}
}
int cnt = 1;
for (auto& i : dim) {
cnt *= i;
}
DDim d(std::vector<int64_t>{cnt}.data(), 1);
float* g = tensor->mutable_data<float>(d, place);
CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(LocalRandomEngine()) * init_range;
}
float* g = tensor->mutable_data<float>(place);
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
}
auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
CHECK(status == 0) << "push dense param failed, status[" << status << "]";
}
#endif
}
......@@ -372,22 +340,6 @@ std::future<int32_t> FleetWrapper::SendClientToClientMsg(
return std::future<int32_t>();
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
engine_wrapper_t() {
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
engine.seed(sseq);
}
};
thread_local engine_wrapper_t r;
return r.engine;
}
template <typename T>
void FleetWrapper::Serialize(const std::vector<T*>& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB
......
......@@ -127,7 +127,6 @@ class FleetWrapper {
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg);
std::default_random_engine& LocalRandomEngine();
template <typename T>
void Serialize(const std::vector<T*>& t, std::string* str);
template <typename T>
......
......@@ -79,7 +79,7 @@ inline int str_to_float(const char* str, float* v) {
// A line buffer is maintained. It
// doesn't need to know the maximum possible length of a line.
char* LineFileReader::getdelim(FILE* f, char delim) {
#ifndef __WIN32
#ifndef _WIN32
int32_t ret = ::getdelim(&_buffer, &_buf_size, delim, f);
if (ret >= 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册