From 7793b47927202b5c59fc1f3a3c77a2d7c16b7715 Mon Sep 17 00:00:00 2001 From: zhoubo01 Date: Mon, 30 Mar 2020 15:42:38 +0800 Subject: [PATCH] add asynchronous lsagent --- deepes/benchmark/cartpole_config.prototxt | 19 +- deepes/demo/paddle/cartpole_async_solver.cc | 136 ++++++++++ .../demo/paddle/cartpole_solver_parallel.cc | 14 +- deepes/include/paddle/async_es_agent.h | 98 +++++++ deepes/include/paddle/es_agent.h | 19 +- deepes/include/torch/es_agent.h | 5 +- deepes/include/utils.h | 24 +- deepes/scripts/build.sh | 2 +- deepes/src/paddle/async_es_agent.cc | 255 ++++++++++++++++++ deepes/src/paddle/es_agent.cc | 13 +- deepes/src/proto/deepes.proto | 1 + deepes/src/utils.cc | 18 ++ 12 files changed, 567 insertions(+), 37 deletions(-) create mode 100644 deepes/demo/paddle/cartpole_async_solver.cc create mode 100644 deepes/include/paddle/async_es_agent.h create mode 100644 deepes/src/paddle/async_es_agent.cc diff --git a/deepes/benchmark/cartpole_config.prototxt b/deepes/benchmark/cartpole_config.prototxt index 07e337c..03cc5fb 100644 --- a/deepes/benchmark/cartpole_config.prototxt +++ b/deepes/benchmark/cartpole_config.prototxt @@ -1,14 +1,15 @@ -seed : 1024 - +seed: 1024 gaussian_sampling { std: 0.5 } - optimizer { - type: "Adam", - base_lr: 0.05, - momentum: 0.9, - beta1: 0.9, - beta2: 0.999, - epsilon: 1e-8, + type: "Adam" + base_lr: 0.05 + momentum: 0.9 + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-08 +} +async_es { + model_iter_id: 0 } diff --git a/deepes/demo/paddle/cartpole_async_solver.cc b/deepes/demo/paddle/cartpole_async_solver.cc new file mode 100644 index 0000000..e2e7c06 --- /dev/null +++ b/deepes/demo/paddle/cartpole_async_solver.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "cartpole.h" +#include "async_es_agent.h" +#include "paddle_api.h" + +using namespace DeepES; +using namespace paddle::lite_api; + +const int ITER = 10; + +std::shared_ptr create_paddle_predictor(const std::string& model_dir) { + // 1. Create CxxConfig + CxxConfig config; + config.set_model_dir(model_dir); + config.set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)} + }); + + // 2. Create PaddlePredictor by CxxConfig + std::shared_ptr predictor = CreatePaddlePredictor(config); + return predictor; +} + +// Use PaddlePredictor of CartPole model to predict the action. +std::vector forward(std::shared_ptr predictor, const float* obs) { + std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); + input_tensor->Resize({1, 4}); + input_tensor->CopyFromCpu(obs); + + predictor->Run(); + + std::vector probs(2, 0.0); + std::unique_ptr output_tensor( + std::move(predictor->GetOutput(0))); + output_tensor->CopyToCpu(probs.data()); + return probs; +} + +int arg_max(const std::vector& vec) { + return static_cast(std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()))); +} + + +float evaluate(CartPole& env, std::shared_ptr agent) { + float total_reward = 0.0; + env.reset(); + const float* obs = env.getState(); + + std::shared_ptr paddle_predictor; + paddle_predictor = agent->get_predictor(); + + while (true) { + std::vector probs = forward(paddle_predictor, obs); + int act = arg_max(probs); + env.step(act); + float reward = env.getReward(); + bool done = env.isDone(); + total_reward += reward; + if (done) break; + obs = env.getState(); + } + return total_reward; +} + + +int main(int argc, char* argv[]) { + std::vector envs; + for (int i = 0; i < ITER; ++i) { + envs.push_back(CartPole()); + } + + std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); + std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + + // Clone agents to sample (explore). + std::vector< std::shared_ptr > sampling_agents; + for (int i = 0; i < ITER; ++i) { + sampling_agents.push_back(agent->clone()); + } + + std::vector noisy_info; + std::vector last_noisy_info; + std::vector noisy_rewards(ITER, 0.0f); + std::vector last_noisy_rewards; + noisy_info.resize(ITER); + + omp_set_num_threads(10); + for (int epoch = 0; epoch < 100; ++epoch) { + last_noisy_info.clear(); + last_noisy_rewards.clear(); + if (epoch != 0) { + for (int i = 0; i < ITER; ++i){ + last_noisy_info.push_back(noisy_info[i]); + last_noisy_rewards.push_back(noisy_rewards[i]); + } + } +#pragma omp parallel for schedule(dynamic, 1) + for (int i = 0; i < ITER; ++i) { + std::shared_ptr sampling_agent = sampling_agents[i]; + SamplingInfo info; + bool success = sampling_agent->add_noise(info); + float reward = evaluate(envs[i], sampling_agent); + + noisy_info[i] = info; + noisy_rewards[i] = reward; + } + + for (int i = 0; i < ITER; ++i){ + last_noisy_info.push_back(noisy_info[i]); + last_noisy_rewards.push_back(noisy_rewards[i]); + } + + // NOTE: all parameters of sampling_agents will be updated + bool success = agent->update(last_noisy_info, last_noisy_rewards); + + int reward = evaluate(envs[0], agent); + LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward; + } +} diff --git a/deepes/demo/paddle/cartpole_solver_parallel.cc b/deepes/demo/paddle/cartpole_solver_parallel.cc index 704b35e..9fccb1a 100644 --- a/deepes/demo/paddle/cartpole_solver_parallel.cc +++ b/deepes/demo/paddle/cartpole_solver_parallel.cc @@ -95,25 +95,25 @@ int main(int argc, char* argv[]) { sampling_agents.push_back(agent->clone()); } - std::vector noisy_info; + std::vector noisy_keys; std::vector noisy_rewards(ITER, 0.0f); - noisy_info.resize(ITER); + noisy_keys.resize(ITER); omp_set_num_threads(10); - for (int epoch = 0; epoch < 300; ++epoch) { + for (int epoch = 0; epoch < 100; ++epoch) { #pragma omp parallel for schedule(dynamic, 1) for (int i = 0; i < ITER; ++i) { std::shared_ptr sampling_agent = sampling_agents[i]; - SamplingInfo info; - bool success = sampling_agent->add_noise(info); + SamplingInfo key; + bool success = sampling_agent->add_noise(key); float reward = evaluate(envs[i], sampling_agent); - noisy_info[i] = info; + noisy_keys[i] = key; noisy_rewards[i] = reward; } // NOTE: all parameters of sampling_agents will be updated - bool success = agent->update(noisy_info, noisy_rewards); + bool success = agent->update(noisy_keys, noisy_rewards); int reward = evaluate(envs[0], agent); LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward; diff --git a/deepes/include/paddle/async_es_agent.h b/deepes/include/paddle/async_es_agent.h new file mode 100644 index 0000000..8989831 --- /dev/null +++ b/deepes/include/paddle/async_es_agent.h @@ -0,0 +1,98 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _ASYNC_ES_AGENT_H +#define _ASYNC_ES_AGENT_H + +#include "es_agent.h" +#include +#include + +namespace DeepES{ +/* DeepES agent with PaddleLite as backend. This agent supports asynchronous update. + * Users mainly focus on the following functions: + * 1. clone: clone an agent for multi-thread evaluation + * 2. add_noise: add noise into parameters. + * 3. update: update parameters given data collected during evaluation. + */ +class AsyncAgent: public ESAgent { + public: + AsyncAgent() {} + + ~AsyncAgent(); + + /** + * @args: + * predictor: predictor created by users for prediction. + * config_path: the path of configuration file. + * Note that AsyncAgent will update the configuration file after calling the update function. + * Please use the up-to-date configuration. + */ + AsyncAgent( + std::shared_ptr predictor, + std::string config_path); + + /** + * @brief: Clone an agent for sampling. + */ + std::shared_ptr clone(); + + /** + * @brief: Clone an agent for sampling. + */ + bool update( + std::vector& noisy_info, + std::vector& noisy_rewards); + + private: + std::map> _previous_predictors; + std::map _param_delta; + std::string _config_path; + + /** + * @brief: parse model_iter_id given a string of model directory. + * @return: an integer indicating the model_iter_id + */ + int _parse_model_iter_id(const std::string&); + + /** + * @brief: compute the distance between current parameter and previous models. + */ + bool _compute_model_diff(); + + /** + * @brief: remove expired models to avoid overuse of disk space. + * @args: + * max_to_keep: the maximum number of models to keep locally. + */ + bool _remove_expired_model(int max_to_keep); + + /** + * @brief: save up-to-date parameters to the disk. + */ + bool _save(); + + /** + * @brief: load all models in the model warehouse. + */ + bool _load(); + + /** + * @brief: load a model given the model directory. + */ + std::shared_ptr _load_previous_model(std::string model_dir); +}; + +} //namespace +#endif diff --git a/deepes/include/paddle/es_agent.h b/deepes/include/paddle/es_agent.h index 219c58f..cd95fb2 100644 --- a/deepes/include/paddle/es_agent.h +++ b/deepes/include/paddle/es_agent.h @@ -21,21 +21,22 @@ #include "gaussian_sampling.h" #include "deepes.pb.h" #include +using namespace paddle::lite_api; namespace DeepES { +int64_t ShapeProduction(const shape_t& shape); + typedef paddle::lite_api::PaddlePredictor PaddlePredictor; /** - * @brief DeepES agent for PaddleLite. - * - * Users use `clone` fucntion to clone a sampling agent, which can call `add_noise` - * function to add noise to copied parameters and call `get_predictor` fucntion to - * get a paddle predictor with added noise. + * @brief DeepES agent with PaddleLite as backend. + * Users mainly focus on the following functions: + * 1. clone: clone an agent for multi-thread evaluation + * 2. add_noise: add noise into parameters. + * 3. update: update parameters given data collected during evaluation. * - * Then can use `update` function to update parameters based on ES algorithm. - * Note: parameters of cloned agents will also be updated. */ class ESAgent { public: @@ -77,7 +78,9 @@ class ESAgent { */ std::shared_ptr get_predictor(); - private: + + + protected: int64_t _calculate_param_size(); std::shared_ptr _predictor; diff --git a/deepes/include/torch/es_agent.h b/deepes/include/torch/es_agent.h index 50434b5..c4fc821 100644 --- a/deepes/include/torch/es_agent.h +++ b/deepes/include/torch/es_agent.h @@ -125,7 +125,7 @@ public: for (auto& param: params) { torch::Tensor tensor = param.value().view({-1}); auto tensor_a = tensor.accessor(); - _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.info()); + _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.key()); counter += tensor.size(0); } @@ -146,7 +146,7 @@ public: int64_t counter = 0; for (auto& param: sampling_params) { torch::Tensor sampling_tensor = param.value().view({-1}); - std::string param_name = param.info(); + std::string param_name = param.key(); torch::Tensor tensor = params.find(param_name)->view({-1}); auto sampling_tensor_a = sampling_tensor.accessor(); auto tensor_a = tensor.accessor(); @@ -162,6 +162,7 @@ public: private: int64_t _calculate_param_size() { + _param_size = 0; auto params = _model->named_parameters(); for (auto& param: params) { torch::Tensor tensor = param.value().view({-1}); diff --git a/deepes/include/utils.h b/deepes/include/utils.h index 6733e7c..002fa33 100644 --- a/deepes/include/utils.h +++ b/deepes/include/utils.h @@ -39,8 +39,7 @@ template bool load_proto_conf(const std::string& config_file, T& proto_config) { bool success = true; std::ifstream fin(config_file); - CHECK(fin) << "open config file " << config_file; - if (fin.fail()) { + if (!fin || fin.fail()) { LOG(FATAL) << "open prototxt config failed: " << config_file; success = false; } else { @@ -54,7 +53,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) { std::string proto_str(file_content_buffer, file_size); if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) { LOG(FATAL) << "Failed to load config: " << config_file; - return -1; + success = false; } delete[] file_content_buffer; fin.close(); @@ -62,6 +61,25 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) { return success; } +template +bool save_proto_conf(const std::string& config_file, T&proto_config) { + bool success = true; + std::ofstream ofs(config_file, std::ofstream::out); + if (!ofs || ofs.fail()) { + LOG(FATAL) << "open prototxt config failed: " << config_file; + success = false; + } else { + std::string config_str; + success = google::protobuf::TextFormat::PrintToString(proto_config, &config_str); + if (!success) { + return success; + } + ofs << config_str; + } +} + +std::vector list_all_model_dirs(std::string path); + } #endif diff --git a/deepes/scripts/build.sh b/deepes/scripts/build.sh index 4dc8d51..f5f5c7e 100644 --- a/deepes/scripts/build.sh +++ b/deepes/scripts/build.sh @@ -47,7 +47,7 @@ rm -rf build mkdir build cd build cmake ../ ${FLAGS} -make -j10 +make -j10 #-----------------run----------------# ./parallel_main diff --git a/deepes/src/paddle/async_es_agent.cc b/deepes/src/paddle/async_es_agent.cc new file mode 100644 index 0000000..48f8f67 --- /dev/null +++ b/deepes/src/paddle/async_es_agent.cc @@ -0,0 +1,255 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "async_es_agent.h" +namespace DeepES { + +AsyncAgent::AsyncAgent( + std::shared_ptr predictor, + std::string config_path): ESAgent(predictor, config_path) { + _config_path = config_path; +} +AsyncAgent::~AsyncAgent() { + for(const auto kv: _param_delta) { + float* delta = kv.second; + delete[] delta; + } +} + +bool AsyncAgent::_save() { + bool success = true; + if (_is_sampling_agent) { + LOG(ERROR) << "[DeepES] Original AsyncAgent cannot call add_noise function, please use cloned AsyncAgent."; + success = false; + return success; + } + int model_iter_id = _config->async_es().model_iter_id() + 1; + //current time + time_t rawtime; + struct tm * timeinfo; + char buffer[80]; + + time (&rawtime); + timeinfo = localtime(&rawtime); + + strftime(buffer,sizeof(buffer),"%d-%m-%Y-%H:%M:%S",timeinfo); + std::string current_time(buffer); + std::string model_name = current_time + "-model_iter_id-"+ std::to_string(model_iter_id); + model_name = "model_iter_id-"+ std::to_string(model_iter_id); + std::string model_path = _config->async_es().model_warehouse() + "/" + model_name; + LOG(INFO) << "[save]model_path: " << model_path; + _predictor->SaveOptimizedModel(model_path, LiteModelType::kProtobuf); + // save config + auto async_es = _config->mutable_async_es(); + async_es->set_model_iter_id(model_iter_id); + success = save_proto_conf(_config_path, *_config); + if (!success) { + LOG(ERROR) << "[]unable to save config for AsyncAgent"; + success = false; + return success; + } + int max_to_keep = _config->async_es().max_to_keep(); + success = _remove_expired_model(max_to_keep); + return success; +} + +bool AsyncAgent::_remove_expired_model(int max_to_keep) { + bool success = true; + std::string model_path = _config->async_es().model_warehouse(); + std::vector model_dirs = list_all_model_dirs(model_path); + int model_iter_id = _config->async_es().model_iter_id() + 1; + for (const auto& dir: model_dirs) { + int dir_model_iter_id = _parse_model_iter_id(dir); + if (model_iter_id - dir_model_iter_id >= max_to_keep) { + std::string rm_command = std::string("rm -rf ") + dir; + int ret = system(rm_command.c_str()); + if (ret == 0) { + LOG(INFO) << "[DeepES] remove expired Model: " << dir; + } else { + LOG(ERROR) << "[DeepES] fail to remove expired Model: " << dir; + success = false; + return success; + } + } + } + return success; +} + +bool AsyncAgent::_compute_model_diff() { + bool success = true; + for (const auto& kv: _previous_predictors) { + int model_iter_id = kv.first; + std::shared_ptr old_predictor = kv.second; + float* diff = new float[_param_size]; + memset(diff, 0, _param_size * sizeof(float)); + for (std::string param_name: _param_names) { + auto des_tensor = old_predictor->GetTensor(param_name); + auto src_tensor = _predictor->GetTensor(param_name); + const float* des_data = des_tensor->data(); + const float* src_data = src_tensor->data(); + int64_t tensor_size = ShapeProduction(src_tensor->shape()); + for (int i = 0; i < tensor_size; ++i) { + diff[i] = des_data[i] - src_data[i]; + } + } + _param_delta[model_iter_id] = diff; + } + return success; +} + +bool AsyncAgent::_load() { + bool success = true; + std::string model_path = _config->async_es().model_warehouse(); + std::vector model_dirs = list_all_model_dirs(model_path); + if (model_dirs.size() == 0) { + int model_iter_id = _config->async_es().model_iter_id(); + success = model_iter_id == 0 ? true: false; + if (!success) { + LOG(WARNING) << "[DeepES] current_model_iter_id is nonzero, but no model is \ + found at the dir: " << model_path; + } + return success; + } + for(auto &dir: model_dirs) { + int model_iter_id = _parse_model_iter_id(dir); + if (model_iter_id == -1) { + LOG(WARNING) << "[DeepES] fail to parse model_iter_id: " << dir; + success = false; + return success; + } + std::shared_ptr predictor = _load_previous_model(dir); + if (predictor == nullptr) { + success = false; + LOG(WARNING) << "[DeepES] fail to load model: " << dir; + return success; + } + _previous_predictors[model_iter_id] = predictor; + } + success = _compute_model_diff(); + return success; +} + +std::shared_ptr AsyncAgent::_load_previous_model(std::string model_dir) { + // 1. Create CxxConfig + CxxConfig config; + config.set_model_file(model_dir + "/model"); + config.set_param_file(model_dir + "/params"); + config.set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)} + }); + + // 2. Create PaddlePredictor by CxxConfig + std::shared_ptr predictor = CreatePaddlePredictor(config); + return predictor; +} + +std::shared_ptr AsyncAgent::clone() { + std::shared_ptr new_sampling_predictor = _predictor->Clone(); + + std::shared_ptr new_agent = std::make_shared(); + + float* noise = new float [_param_size]; + + new_agent->_predictor = _predictor; + new_agent->_sampling_predictor = new_sampling_predictor; + + new_agent->_is_sampling_agent = true; + new_agent->_sampling_method = _sampling_method; + new_agent->_param_names = _param_names; + new_agent->_param_size = _param_size; + new_agent->_config = _config; + new_agent->_noise = noise; + + return new_agent; +} + +bool AsyncAgent::update( + std::vector& noisy_info, + std::vector& noisy_rewards) { + + CHECK(!_is_sampling_agent) << "[DeepES] Cloned ESAgent cannot call update function. \ + Please use original ESAgent."; + + bool success = _load(); + CHECK(success) << "[DeepES] fail to load previous models."; + + int current_model_iter_id = _config->async_es().model_iter_id(); + // validate model_iter_id for each sample before the update + for (int i = 0; i < noisy_info.size(); ++i) { + int model_iter_id = noisy_info[i].model_iter_id(); + if (model_iter_id != current_model_iter_id + && _previous_predictors.count(model_iter_id) == 0) { + LOG(WARNING) << "[DeepES] The sample with model_dir_id: " << model_iter_id \ + << " cannot match any local model"; + success = false; + return success; + } + } + + compute_centered_ranks(noisy_rewards); + memset(_neg_gradients, 0, _param_size * sizeof(float)); + + for (int i = 0; i < noisy_info.size(); ++i) { + int key = noisy_info[i].key(0); + float reward = noisy_rewards[i]; + int model_iter_id = noisy_info[i].model_iter_id(); + bool success = _sampling_method->resampling(key, _noise, _param_size); + float* delta = _param_delta[model_iter_id]; + // compute neg_gradients + if (model_iter_id == current_model_iter_id) { + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] += _noise[j] * reward; + } + } else { + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] += (_noise[j] + delta[j]) * reward; + } + } + } + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] /= -1.0 * noisy_info.size(); + } + + //update + int64_t counter = 0; + + for (std::string param_name: _param_names) { + std::unique_ptr tensor = _predictor->GetMutableTensor(param_name); + float* tensor_data = tensor->mutable_data(); + int64_t tensor_size = ShapeProduction(tensor->shape()); + _optimizer->update(tensor_data, _neg_gradients + counter, tensor_size, param_name); + counter += tensor_size; + } + success = _save(); + CHECK(success) << "[DeepES] fail to save model."; + return true; +} + +int AsyncAgent::_parse_model_iter_id(const std::string& model_path) { + int model_iter_id = -1; + int pow = 1; + for (int i = model_path.size() - 1; i >= 0; --i) { + if (model_path[i] >= '0' && model_path[i] <= '9') { + if (model_iter_id == -1) model_iter_id = 0; + } else { + break; + } + model_iter_id += pow * (model_path[i] - '0'); + pow *= 10; + } + return model_iter_id; +} + +}//namespace diff --git a/deepes/src/paddle/es_agent.cc b/deepes/src/paddle/es_agent.cc index ed2b1b9..4299df4 100644 --- a/deepes/src/paddle/es_agent.cc +++ b/deepes/src/paddle/es_agent.cc @@ -13,14 +13,14 @@ // limitations under the License. #include "es_agent.h" +#include namespace DeepES { -typedef paddle::lite_api::PaddlePredictor PaddlePredictor; typedef paddle::lite_api::Tensor Tensor; typedef paddle::lite_api::shape_t shape_t; -inline int64_t ShapeProduction(const shape_t& shape) { +int64_t ShapeProduction(const shape_t& shape) { int64_t res = 1; for (auto i : shape) res *= i; return res; @@ -71,6 +71,7 @@ std::shared_ptr ESAgent::clone() { new_agent->_is_sampling_agent = true; new_agent->_sampling_method = _sampling_method; new_agent->_param_names = _param_names; + new_agent->_config = _config; new_agent->_param_size = _param_size; new_agent->_noise = noise; @@ -111,7 +112,6 @@ bool ESAgent::update( counter += tensor_size; } return true; - } bool ESAgent::add_noise(SamplingInfo& sampling_info) { @@ -121,7 +121,9 @@ bool ESAgent::add_noise(SamplingInfo& sampling_info) { } int key = _sampling_method->sampling(_noise, _param_size); + int model_iter_id = _config->async_es().model_iter_id(); sampling_info.add_key(key); + sampling_info.set_model_iter_id(model_iter_id); int64_t counter = 0; for (std::string param_name: _param_names) { @@ -137,7 +139,6 @@ bool ESAgent::add_noise(SamplingInfo& sampling_info) { return true; } - std::shared_ptr ESAgent::get_predictor() { return _sampling_predictor; } @@ -151,6 +152,4 @@ int64_t ESAgent::_calculate_param_size() { return param_size; } - -} - +}//namespace diff --git a/deepes/src/proto/deepes.proto b/deepes/src/proto/deepes.proto index c6c1c9c..b839ef2 100644 --- a/deepes/src/proto/deepes.proto +++ b/deepes/src/proto/deepes.proto @@ -51,4 +51,5 @@ message AsyncESConfig{ optional string model_warehouse = 1 [default = "./model_warehouse"]; repeated string model_md5 = 2; optional int32 max_to_keep = 3 [default = 5]; + optional int32 model_iter_id = 4 [default = 0]; } diff --git a/deepes/src/utils.cc b/deepes/src/utils.cc index 5a2e8c3..153f5a1 100644 --- a/deepes/src/utils.cc +++ b/deepes/src/utils.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "utils.h" +#include namespace DeepES { @@ -33,4 +34,21 @@ void compute_centered_ranks(std::vector &reward) { } } +std::vector list_all_model_dirs(std::string path) { + std::vector model_dirs; + DIR *dpdf; + struct dirent *epdf; + dpdf = opendir(path.data()); + if (dpdf != NULL){ + while (epdf = readdir(dpdf)){ + std::string dir(epdf->d_name); + if (dir.find("model_iter_id") != std::string::npos) { + model_dirs.push_back(path + "/" + dir); + } + } + } + closedir(dpdf); + return model_dirs; +} + }//namespace -- GitLab