From 753960103f18fab65a46172fc004c6039be44330 Mon Sep 17 00:00:00 2001 From: Bo Zhou <2466956298@qq.com> Date: Wed, 1 Apr 2020 20:22:25 +0800 Subject: [PATCH] add an agent for the support of asynchronous update (#235) * rename SamplingKey to SamplingInfo * add asynchronous lsagent * rename AsyncAgent to AsyncESAgent * fix comments * Update async_es_agent.cc * Sampling Info * commemt * commemt --- deepes/CMakeLists.txt | 1 + deepes/README.md | 8 +- deepes/benchmark/cartpole_config.prototxt | 19 +- deepes/demo/paddle/cartpole_async_solver.cc | 136 ++++++++++ .../demo/paddle/cartpole_solver_parallel.cc | 6 +- deepes/demo/torch/cartpole_solver_parallel.cc | 12 +- deepes/include/paddle/async_es_agent.h | 101 +++++++ deepes/include/paddle/es_agent.h | 23 +- deepes/include/torch/es_agent.h | 14 +- deepes/include/utils.h | 29 +- deepes/scripts/build.sh | 2 - deepes/src/paddle/async_es_agent.cc | 255 ++++++++++++++++++ deepes/src/paddle/es_agent.cc | 28 +- deepes/src/proto/deepes.proto | 12 +- deepes/src/utils.cc | 18 ++ deepes/test/src/torch_agent_test.cc | 4 +- 16 files changed, 602 insertions(+), 66 deletions(-) create mode 100644 deepes/demo/paddle/cartpole_async_solver.cc create mode 100644 deepes/include/paddle/async_es_agent.h create mode 100644 deepes/src/paddle/async_es_agent.cc diff --git a/deepes/CMakeLists.txt b/deepes/CMakeLists.txt index a9b1205..d1f482a 100644 --- a/deepes/CMakeLists.txt +++ b/deepes/CMakeLists.txt @@ -45,6 +45,7 @@ if (WITH_PADDLE) file(GLOB framework_src "src/paddle/*.cc") set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_solver_parallel.cc") + #set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_async_solver.cc") ########## Torch config ########## elseif (WITH_TORCH) list(APPEND CMAKE_PREFIX_PATH "./libtorch") diff --git a/deepes/README.md b/deepes/README.md index 48dbee2..c2e6ef8 100644 --- a/deepes/README.md +++ b/deepes/README.md @@ -11,14 +11,14 @@ auto agent = ESAgent(config); for (int i = 0; i < 10; ++i) { auto sampling_agnet = agent->clone(); // clone出一个sampling agent - SamplingKey key; - agent->add_noise(key); // 参数扰动,同时保存随机种子到key中 + SamplingInfo info; + agent->add_noise(info); // 参数扰动,同时保存随机种子到info中 int reward = evaluate(env, sampling_agent); //评估参数 - noisy_keys.push_back(key); // 记录随机噪声对应种子 + noisy_info.push_back(info); // 记录随机噪声对应种子 noisy_rewards.push_back(reward); // 记录评估结果 } //根据评估结果、随机种子更新参数,然后重复以上过程,直到收敛。 -agent->update(noisy_keys, noisy_rewards); +agent->update(noisy_info, noisy_rewards); ``` ## 一键运行demo列表 diff --git a/deepes/benchmark/cartpole_config.prototxt b/deepes/benchmark/cartpole_config.prototxt index 07e337c..03cc5fb 100644 --- a/deepes/benchmark/cartpole_config.prototxt +++ b/deepes/benchmark/cartpole_config.prototxt @@ -1,14 +1,15 @@ -seed : 1024 - +seed: 1024 gaussian_sampling { std: 0.5 } - optimizer { - type: "Adam", - base_lr: 0.05, - momentum: 0.9, - beta1: 0.9, - beta2: 0.999, - epsilon: 1e-8, + type: "Adam" + base_lr: 0.05 + momentum: 0.9 + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-08 +} +async_es { + model_iter_id: 0 } diff --git a/deepes/demo/paddle/cartpole_async_solver.cc b/deepes/demo/paddle/cartpole_async_solver.cc new file mode 100644 index 0000000..5cbe48e --- /dev/null +++ b/deepes/demo/paddle/cartpole_async_solver.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "cartpole.h" +#include "async_es_agent.h" +#include "paddle_api.h" + +using namespace DeepES; +using namespace paddle::lite_api; + +const int ITER = 10; + +std::shared_ptr create_paddle_predictor(const std::string& model_dir) { + // 1. Create CxxConfig + CxxConfig config; + config.set_model_dir(model_dir); + config.set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)} + }); + + // 2. Create PaddlePredictor by CxxConfig + std::shared_ptr predictor = CreatePaddlePredictor(config); + return predictor; +} + +// Use PaddlePredictor of CartPole model to predict the action. +std::vector forward(std::shared_ptr predictor, const float* obs) { + std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); + input_tensor->Resize({1, 4}); + input_tensor->CopyFromCpu(obs); + + predictor->Run(); + + std::vector probs(2, 0.0); + std::unique_ptr output_tensor( + std::move(predictor->GetOutput(0))); + output_tensor->CopyToCpu(probs.data()); + return probs; +} + +int arg_max(const std::vector& vec) { + return static_cast(std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()))); +} + + +float evaluate(CartPole& env, std::shared_ptr agent) { + float total_reward = 0.0; + env.reset(); + const float* obs = env.getState(); + + std::shared_ptr paddle_predictor; + paddle_predictor = agent->get_predictor(); + + while (true) { + std::vector probs = forward(paddle_predictor, obs); + int act = arg_max(probs); + env.step(act); + float reward = env.getReward(); + bool done = env.isDone(); + total_reward += reward; + if (done) break; + obs = env.getState(); + } + return total_reward; +} + + +int main(int argc, char* argv[]) { + std::vector envs; + for (int i = 0; i < ITER; ++i) { + envs.push_back(CartPole()); + } + + std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); + std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + + // Clone agents to sample (explore). + std::vector< std::shared_ptr > sampling_agents; + for (int i = 0; i < ITER; ++i) { + sampling_agents.push_back(agent->clone()); + } + + std::vector noisy_info; + std::vector last_noisy_info; + std::vector noisy_rewards(ITER, 0.0f); + std::vector last_noisy_rewards; + noisy_info.resize(ITER); + + omp_set_num_threads(10); + for (int epoch = 0; epoch < 100; ++epoch) { + last_noisy_info.clear(); + last_noisy_rewards.clear(); + if (epoch != 0) { + for (int i = 0; i < ITER; ++i){ + last_noisy_info.push_back(noisy_info[i]); + last_noisy_rewards.push_back(noisy_rewards[i]); + } + } +#pragma omp parallel for schedule(dynamic, 1) + for (int i = 0; i < ITER; ++i) { + std::shared_ptr sampling_agent = sampling_agents[i]; + SamplingInfo info; + bool success = sampling_agent->add_noise(info); + float reward = evaluate(envs[i], sampling_agent); + + noisy_info[i] = info; + noisy_rewards[i] = reward; + } + + for (int i = 0; i < ITER; ++i){ + last_noisy_info.push_back(noisy_info[i]); + last_noisy_rewards.push_back(noisy_rewards[i]); + } + + // NOTE: all parameters of sampling_agents will be updated + bool success = agent->update(last_noisy_info, last_noisy_rewards); + + int reward = evaluate(envs[0], agent); + LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward; + } +} diff --git a/deepes/demo/paddle/cartpole_solver_parallel.cc b/deepes/demo/paddle/cartpole_solver_parallel.cc index 5fe4ae0..9fccb1a 100644 --- a/deepes/demo/paddle/cartpole_solver_parallel.cc +++ b/deepes/demo/paddle/cartpole_solver_parallel.cc @@ -95,16 +95,16 @@ int main(int argc, char* argv[]) { sampling_agents.push_back(agent->clone()); } - std::vector noisy_keys; + std::vector noisy_keys; std::vector noisy_rewards(ITER, 0.0f); noisy_keys.resize(ITER); omp_set_num_threads(10); - for (int epoch = 0; epoch < 1000; ++epoch) { + for (int epoch = 0; epoch < 100; ++epoch) { #pragma omp parallel for schedule(dynamic, 1) for (int i = 0; i < ITER; ++i) { std::shared_ptr sampling_agent = sampling_agents[i]; - SamplingKey key; + SamplingInfo key; bool success = sampling_agent->add_noise(key); float reward = evaluate(envs[i], sampling_agent); diff --git a/deepes/demo/torch/cartpole_solver_parallel.cc b/deepes/demo/torch/cartpole_solver_parallel.cc index 24f8125..98be214 100644 --- a/deepes/demo/torch/cartpole_solver_parallel.cc +++ b/deepes/demo/torch/cartpole_solver_parallel.cc @@ -59,23 +59,23 @@ int main(int argc, char* argv[]) { sampling_agents.push_back(agent->clone()); } - std::vector noisy_keys; + std::vector noisy_info; std::vector noisy_rewards(ITER, 0.0f); - noisy_keys.resize(ITER); + noisy_info.resize(ITER); for (int epoch = 0; epoch < 100; ++epoch) { #pragma omp parallel for schedule(dynamic, 1) for (int i = 0; i < ITER; ++i) { auto sampling_agent = sampling_agents[i]; - SamplingKey key; - bool success = sampling_agent->add_noise(key); + SamplingInfo info; + bool success = sampling_agent->add_noise(info); float reward = evaluate(envs[i], sampling_agent); - noisy_keys[i] = key; + noisy_info[i] = info; noisy_rewards[i] = reward; } // Will also update parameters of sampling_agents - bool success = agent->update(noisy_keys, noisy_rewards); + bool success = agent->update(noisy_info, noisy_rewards); // Use original agent to evalute (without noise). int reward = evaluate(envs[0], agent); diff --git a/deepes/include/paddle/async_es_agent.h b/deepes/include/paddle/async_es_agent.h new file mode 100644 index 0000000..11b8dff --- /dev/null +++ b/deepes/include/paddle/async_es_agent.h @@ -0,0 +1,101 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ASYNC_ES_AGENT_H +#define ASYNC_ES_AGENT_H + +#include "es_agent.h" +#include +#include + +namespace DeepES{ +/* DeepES agent with PaddleLite as backend. This agent supports asynchronous update. + * Users mainly focus on the following functions: + * 1. clone: clone an agent for multi-thread evaluation + * 2. add_noise: add noise into parameters. + * 3. update: update parameters given data collected during evaluation. + */ +class AsyncESAgent: public ESAgent { + public: + AsyncESAgent() {} + + ~AsyncESAgent(); + + /** + * @args: + * predictor: predictor created by users for prediction. + * config_path: the path of configuration file. + * Note that AsyncESAgent will update the configuration file after calling the update function. + * Please use the up-to-date configuration. + */ + AsyncESAgent( + std::shared_ptr predictor, + std::string config_path); + + /** + * @brief: Clone an agent for sampling. + */ + std::shared_ptr clone(); + + /** + * @brief: update parameters given data collected during evaluation. + * @args: + * noisy_info: sampling information returned by add_noise function. + * noisy_reward: evaluation rewards. + */ + bool update( + std::vector& noisy_info, + std::vector& noisy_rewards); + + private: + std::map> _previous_predictors; + std::map _param_delta; + std::string _config_path; + + /** + * @brief: parse model_iter_id given a string of model directory. + * @return: an integer indicating the model_iter_id + */ + int _parse_model_iter_id(const std::string&); + + /** + * @brief: compute the distance between current parameter and previous models. + */ + bool _compute_model_diff(); + + /** + * @brief: remove expired models to avoid overuse of disk space. + * @args: + * max_to_keep: the maximum number of models to keep locally. + */ + bool _remove_expired_model(int max_to_keep); + + /** + * @brief: save up-to-date parameters to the disk. + */ + bool _save(); + + /** + * @brief: load all models in the model warehouse. + */ + bool _load(); + + /** + * @brief: load a model given the model directory. + */ + std::shared_ptr _load_previous_model(std::string model_dir); +}; + +} //namespace +#endif diff --git a/deepes/include/paddle/es_agent.h b/deepes/include/paddle/es_agent.h index ceeaf05..25c9d98 100644 --- a/deepes/include/paddle/es_agent.h +++ b/deepes/include/paddle/es_agent.h @@ -22,20 +22,19 @@ #include "deepes.pb.h" #include +using namespace paddle::lite_api; namespace DeepES { -typedef paddle::lite_api::PaddlePredictor PaddlePredictor; +int64_t ShapeProduction(const shape_t& shape); /** - * @brief DeepES agent for PaddleLite. + * @brief DeepES agent with PaddleLite as backend. + * Users mainly focus on the following functions: + * 1. clone: clone an agent for multi-thread evaluation + * 2. add_noise: add noise into parameters. + * 3. update: update parameters given data collected during evaluation. * - * Users use `clone` fucntion to clone a sampling agent, which can call `add_noise` - * function to add noise to copied parameters and call `get_predictor` fucntion to - * get a paddle predictor with added noise. - * - * Then can use `update` function to update parameters based on ES algorithm. - * Note: parameters of cloned agents will also be updated. */ class ESAgent { public: @@ -63,11 +62,11 @@ class ESAgent { * Parameters of cloned agents will also be updated. */ bool update( - std::vector& noisy_keys, + std::vector& noisy_info, std::vector& noisy_rewards); // copied parameters = original parameters + noise - bool add_noise(SamplingKey& sampling_key); + bool add_noise(SamplingInfo& sampling_info); /** * @brief Get paddle predict @@ -77,7 +76,9 @@ class ESAgent { */ std::shared_ptr get_predictor(); - private: + + + protected: int64_t _calculate_param_size(); std::shared_ptr _predictor; diff --git a/deepes/include/torch/es_agent.h b/deepes/include/torch/es_agent.h index f98d220..c4fc821 100644 --- a/deepes/include/torch/es_agent.h +++ b/deepes/include/torch/es_agent.h @@ -98,7 +98,7 @@ public: * Only not cloned ESAgent can call `update` function. * Parameters of cloned agents will also be updated. */ - bool update(std::vector& noisy_keys, std::vector& noisy_rewards) { + bool update(std::vector& noisy_info, std::vector& noisy_rewards) { if (_is_sampling_agent) { LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent."; return false; @@ -107,8 +107,8 @@ public: compute_centered_ranks(noisy_rewards); memset(_neg_gradients, 0, _param_size * sizeof(float)); - for (int i = 0; i < noisy_keys.size(); ++i) { - int key = noisy_keys[i].key(0); + for (int i = 0; i < noisy_info.size(); ++i) { + int key = noisy_info[i].key(0); float reward = noisy_rewards[i]; bool success = _sampling_method->resampling(key, _noise, _param_size); for (int64_t j = 0; j < _param_size; ++j) { @@ -116,7 +116,7 @@ public: } } for (int64_t j = 0; j < _param_size; ++j) { - _neg_gradients[j] /= -1.0 * noisy_keys.size(); + _neg_gradients[j] /= -1.0 * noisy_info.size(); } //update @@ -133,7 +133,7 @@ public: } // copied parameters = original parameters + noise - bool add_noise(SamplingKey& sampling_key) { + bool add_noise(SamplingInfo& sampling_info) { if (!_is_sampling_agent) { LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent."; return false; @@ -142,7 +142,7 @@ public: auto sampling_params = _sampling_model->named_parameters(); auto params = _model->named_parameters(); int key = _sampling_method->sampling(_noise, _param_size); - sampling_key.add_key(key); + sampling_info.add_key(key); int64_t counter = 0; for (auto& param: sampling_params) { torch::Tensor sampling_tensor = param.value().view({-1}); @@ -162,7 +162,7 @@ public: private: int64_t _calculate_param_size() { - int _param_size = 0; + _param_size = 0; auto params = _model->named_parameters(); for (auto& param: params) { torch::Tensor tensor = param.value().view({-1}); diff --git a/deepes/include/utils.h b/deepes/include/utils.h index 843f7c4..5835a43 100644 --- a/deepes/include/utils.h +++ b/deepes/include/utils.h @@ -39,9 +39,8 @@ template bool load_proto_conf(const std::string& config_file, T& proto_config) { bool success = true; std::ifstream fin(config_file); - CHECK(fin) << "open config file " << config_file; - if (fin.fail()) { - LOG(FATAL) << "open prototxt config failed: " << config_file; + if (!fin || fin.fail()) { + LOG(ERROR) << "open prototxt config failed: " << config_file; success = false; } else { fin.seekg(0, std::ios::end); @@ -53,8 +52,8 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) { std::string proto_str(file_content_buffer, file_size); if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) { - LOG(FATAL) << "Failed to load config: " << config_file; - return -1; + LOG(ERROR) << "Failed to load config: " << config_file; + success = false; } delete[] file_content_buffer; fin.close(); @@ -62,6 +61,26 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) { return success; } +template +bool save_proto_conf(const std::string& config_file, T&proto_config) { + bool success = true; + std::ofstream ofs(config_file, std::ofstream::out); + if (!ofs || ofs.fail()) { + LOG(ERROR) << "open prototxt config failed: " << config_file; + success = false; + } else { + std::string config_str; + success = google::protobuf::TextFormat::PrintToString(proto_config, &config_str); + if (!success) { + return success; + } + ofs << config_str; + } + return success; +} + +std::vector list_all_model_dirs(std::string path); + } #endif diff --git a/deepes/scripts/build.sh b/deepes/scripts/build.sh index c1a4067..2b5f6da 100644 --- a/deepes/scripts/build.sh +++ b/deepes/scripts/build.sh @@ -35,8 +35,6 @@ else exit 0 fi -export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH - #----------------protobuf-------------# cp ./src/proto/deepes.proto ./ protoc deepes.proto --cpp_out ./ diff --git a/deepes/src/paddle/async_es_agent.cc b/deepes/src/paddle/async_es_agent.cc new file mode 100644 index 0000000..f128ddc --- /dev/null +++ b/deepes/src/paddle/async_es_agent.cc @@ -0,0 +1,255 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "async_es_agent.h" +namespace DeepES { + +AsyncESAgent::AsyncESAgent( + std::shared_ptr predictor, + std::string config_path): ESAgent(predictor, config_path) { + _config_path = config_path; +} +AsyncESAgent::~AsyncESAgent() { + for(const auto kv: _param_delta) { + float* delta = kv.second; + delete[] delta; + } +} + +bool AsyncESAgent::_save() { + bool success = true; + if (_is_sampling_agent) { + LOG(ERROR) << "[DeepES] Cloned AsyncESAgent cannot call `save`.Please use original AsyncESAgent."; + success = false; + return success; + } + int model_iter_id = _config->async_es().model_iter_id() + 1; + //current time + time_t rawtime; + struct tm * timeinfo; + char buffer[80]; + + time (&rawtime); + timeinfo = localtime(&rawtime); + + std::string model_name = "model_iter_id-"+ std::to_string(model_iter_id); + std::string model_path = _config->async_es().model_warehouse() + "/" + model_name; + LOG(INFO) << "[save]model_path: " << model_path; + _predictor->SaveOptimizedModel(model_path, LiteModelType::kProtobuf); + // save config + auto async_es = _config->mutable_async_es(); + async_es->set_model_iter_id(model_iter_id); + success = save_proto_conf(_config_path, *_config); + if (!success) { + LOG(ERROR) << "[]unable to save config for AsyncESAgent"; + success = false; + return success; + } + int max_to_keep = _config->async_es().max_to_keep(); + success = _remove_expired_model(max_to_keep); + return success; +} + +bool AsyncESAgent::_remove_expired_model(int max_to_keep) { + bool success = true; + std::string model_path = _config->async_es().model_warehouse(); + std::vector model_dirs = list_all_model_dirs(model_path); + int model_iter_id = _config->async_es().model_iter_id() + 1; + for (const auto& dir: model_dirs) { + int dir_model_iter_id = _parse_model_iter_id(dir); + if (model_iter_id - dir_model_iter_id >= max_to_keep) { + std::string rm_command = std::string("rm -rf ") + dir; + int ret = system(rm_command.c_str()); + if (ret == 0) { + LOG(INFO) << "[DeepES] remove expired Model: " << dir; + } else { + LOG(ERROR) << "[DeepES] fail to remove expired Model: " << dir; + success = false; + return success; + } + } + } + return success; +} + +bool AsyncESAgent::_compute_model_diff() { + bool success = true; + for (const auto& kv: _previous_predictors) { + int model_iter_id = kv.first; + std::shared_ptr old_predictor = kv.second; + float* diff = new float[_param_size]; + memset(diff, 0, _param_size * sizeof(float)); + int offset = 0; + for (const std::string& param_name: _param_names) { + auto des_tensor = old_predictor->GetTensor(param_name); + auto src_tensor = _predictor->GetTensor(param_name); + const float* des_data = des_tensor->data(); + const float* src_data = src_tensor->data(); + int64_t tensor_size = ShapeProduction(src_tensor->shape()); + for (int i = 0; i < tensor_size; ++i) { + diff[i + offset] = des_data[i] - src_data[i]; + } + offset += tensor_size; + } + _param_delta[model_iter_id] = diff; + } + return success; +} + +bool AsyncESAgent::_load() { + bool success = true; + std::string model_path = _config->async_es().model_warehouse(); + std::vector model_dirs = list_all_model_dirs(model_path); + if (model_dirs.size() == 0) { + int model_iter_id = _config->async_es().model_iter_id(); + success = model_iter_id == 0 ? true: false; + if (!success) { + LOG(WARNING) << "[DeepES] current_model_iter_id is nonzero, but no model is \ + found at the dir: " << model_path; + } + return success; + } + for(auto &dir: model_dirs) { + int model_iter_id = _parse_model_iter_id(dir); + if (model_iter_id == -1) { + LOG(WARNING) << "[DeepES] fail to parse model_iter_id: " << dir; + success = false; + return success; + } + std::shared_ptr predictor = _load_previous_model(dir); + if (predictor == nullptr) { + success = false; + LOG(WARNING) << "[DeepES] fail to load model: " << dir; + return success; + } + _previous_predictors[model_iter_id] = predictor; + } + success = _compute_model_diff(); + return success; +} + +std::shared_ptr AsyncESAgent::_load_previous_model(std::string model_dir) { + // 1. Create CxxConfig + CxxConfig config; + config.set_model_file(model_dir + "/model"); + config.set_param_file(model_dir + "/params"); + config.set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)} + }); + + // 2. Create PaddlePredictor by CxxConfig + std::shared_ptr predictor = CreatePaddlePredictor(config); + return predictor; +} + +std::shared_ptr AsyncESAgent::clone() { + std::shared_ptr new_sampling_predictor = _predictor->Clone(); + + std::shared_ptr new_agent = std::make_shared(); + + float* noise = new float [_param_size]; + + new_agent->_predictor = _predictor; + new_agent->_sampling_predictor = new_sampling_predictor; + + new_agent->_is_sampling_agent = true; + new_agent->_sampling_method = _sampling_method; + new_agent->_param_names = _param_names; + new_agent->_param_size = _param_size; + new_agent->_config = _config; + new_agent->_noise = noise; + + return new_agent; +} + +bool AsyncESAgent::update( + std::vector& noisy_info, + std::vector& noisy_rewards) { + + CHECK(!_is_sampling_agent) << "[DeepES] Cloned ESAgent cannot call update function. \ + Please use original ESAgent."; + + bool success = _load(); + CHECK(success) << "[DeepES] fail to load previous models."; + + int current_model_iter_id = _config->async_es().model_iter_id(); + // validate model_iter_id for each sample before the update + for (int i = 0; i < noisy_info.size(); ++i) { + int model_iter_id = noisy_info[i].model_iter_id(); + if (model_iter_id != current_model_iter_id + && _previous_predictors.count(model_iter_id) == 0) { + LOG(WARNING) << "[DeepES] The sample with model_dir_id: " << model_iter_id \ + << " cannot match any local model"; + success = false; + return success; + } + } + + compute_centered_ranks(noisy_rewards); + memset(_neg_gradients, 0, _param_size * sizeof(float)); + + for (int i = 0; i < noisy_info.size(); ++i) { + int key = noisy_info[i].key(0); + float reward = noisy_rewards[i]; + int model_iter_id = noisy_info[i].model_iter_id(); + bool success = _sampling_method->resampling(key, _noise, _param_size); + CHECK(success) << "[DeepES] resampling error occurs at sample: " << i; + float* delta = _param_delta[model_iter_id]; + // compute neg_gradients + if (model_iter_id == current_model_iter_id) { + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] += _noise[j] * reward; + } + } else { + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] += (_noise[j] + delta[j]) * reward; + } + } + } + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] /= -1.0 * noisy_info.size(); + } + + //update + int64_t counter = 0; + + for (std::string param_name: _param_names) { + std::unique_ptr tensor = _predictor->GetMutableTensor(param_name); + float* tensor_data = tensor->mutable_data(); + int64_t tensor_size = ShapeProduction(tensor->shape()); + _optimizer->update(tensor_data, _neg_gradients + counter, tensor_size, param_name); + counter += tensor_size; + } + success = _save(); + CHECK(success) << "[DeepES] fail to save model."; + return true; +} + +int AsyncESAgent::_parse_model_iter_id(const std::string& model_path) { + int model_iter_id = -1; + int pow = 1; + for (int i = model_path.size() - 1; i >= 0; --i) { + if (model_path[i] >= '0' && model_path[i] <= '9') { + if (model_iter_id == -1) model_iter_id = 0; + } else { + break; + } + model_iter_id += pow * (model_path[i] - '0'); + pow *= 10; + } + return model_iter_id; +} + +}//namespace diff --git a/deepes/src/paddle/es_agent.cc b/deepes/src/paddle/es_agent.cc index a27a0d7..6cd6f93 100644 --- a/deepes/src/paddle/es_agent.cc +++ b/deepes/src/paddle/es_agent.cc @@ -13,14 +13,11 @@ // limitations under the License. #include "es_agent.h" +#include namespace DeepES { -typedef paddle::lite_api::PaddlePredictor PaddlePredictor; -typedef paddle::lite_api::Tensor Tensor; -typedef paddle::lite_api::shape_t shape_t; - -inline int64_t ShapeProduction(const shape_t& shape) { +int64_t ShapeProduction(const shape_t& shape) { int64_t res = 1; for (auto i : shape) res *= i; return res; @@ -71,6 +68,7 @@ std::shared_ptr ESAgent::clone() { new_agent->_is_sampling_agent = true; new_agent->_sampling_method = _sampling_method; new_agent->_param_names = _param_names; + new_agent->_config = _config; new_agent->_param_size = _param_size; new_agent->_noise = noise; @@ -78,7 +76,7 @@ std::shared_ptr ESAgent::clone() { } bool ESAgent::update( - std::vector& noisy_keys, + std::vector& noisy_info, std::vector& noisy_rewards) { if (_is_sampling_agent) { LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent."; @@ -88,8 +86,8 @@ bool ESAgent::update( compute_centered_ranks(noisy_rewards); memset(_neg_gradients, 0, _param_size * sizeof(float)); - for (int i = 0; i < noisy_keys.size(); ++i) { - int key = noisy_keys[i].key(0); + for (int i = 0; i < noisy_info.size(); ++i) { + int key = noisy_info[i].key(0); float reward = noisy_rewards[i]; bool success = _sampling_method->resampling(key, _noise, _param_size); for (int64_t j = 0; j < _param_size; ++j) { @@ -97,7 +95,7 @@ bool ESAgent::update( } } for (int64_t j = 0; j < _param_size; ++j) { - _neg_gradients[j] /= -1.0 * noisy_keys.size(); + _neg_gradients[j] /= -1.0 * noisy_info.size(); } //update @@ -111,17 +109,18 @@ bool ESAgent::update( counter += tensor_size; } return true; - } -bool ESAgent::add_noise(SamplingKey& sampling_key) { +bool ESAgent::add_noise(SamplingInfo& sampling_info) { if (!_is_sampling_agent) { LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent."; return false; } int key = _sampling_method->sampling(_noise, _param_size); - sampling_key.add_key(key); + int model_iter_id = _config->async_es().model_iter_id(); + sampling_info.add_key(key); + sampling_info.set_model_iter_id(model_iter_id); int64_t counter = 0; for (std::string param_name: _param_names) { @@ -137,7 +136,6 @@ bool ESAgent::add_noise(SamplingKey& sampling_key) { return true; } - std::shared_ptr ESAgent::get_predictor() { return _sampling_predictor; } @@ -151,6 +149,4 @@ int64_t ESAgent::_calculate_param_size() { return param_size; } - -} - +}//namespace diff --git a/deepes/src/proto/deepes.proto b/deepes/src/proto/deepes.proto index 38abee9..b839ef2 100644 --- a/deepes/src/proto/deepes.proto +++ b/deepes/src/proto/deepes.proto @@ -23,6 +23,8 @@ message DeepESConfig { optional GaussianSamplingConfig gaussian_sampling = 3; // Optimizer Configuration optional OptimizerConfig optimizer = 4; + // AsyncESAgent Configuration + optional AsyncESConfig async_es = 5; } message GaussianSamplingConfig { @@ -40,6 +42,14 @@ message OptimizerConfig{ optional float epsilon = 6 [default = 1e-8]; } -message SamplingKey{ +message SamplingInfo{ repeated int32 key = 1; + optional int32 model_iter_id = 2; +} + +message AsyncESConfig{ + optional string model_warehouse = 1 [default = "./model_warehouse"]; + repeated string model_md5 = 2; + optional int32 max_to_keep = 3 [default = 5]; + optional int32 model_iter_id = 4 [default = 0]; } diff --git a/deepes/src/utils.cc b/deepes/src/utils.cc index 726da66..cd5b055 100644 --- a/deepes/src/utils.cc +++ b/deepes/src/utils.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "utils.h" +#include namespace DeepES { @@ -34,4 +35,21 @@ bool compute_centered_ranks(std::vector &reward) { return true; } +std::vector list_all_model_dirs(std::string path) { + std::vector model_dirs; + DIR *dpdf; + struct dirent *epdf; + dpdf = opendir(path.data()); + if (dpdf != NULL){ + while (epdf = readdir(dpdf)){ + std::string dir(epdf->d_name); + if (dir.find("model_iter_id") != std::string::npos) { + model_dirs.push_back(path + "/" + dir); + } + } + } + closedir(dpdf); + return model_dirs; +} + }//namespace diff --git a/deepes/test/src/torch_agent_test.cc b/deepes/test/src/torch_agent_test.cc index cf0b322..a0aabf2 100644 --- a/deepes/test/src/torch_agent_test.cc +++ b/deepes/test/src/torch_agent_test.cc @@ -89,7 +89,7 @@ protected: sampling_agents.push_back(agent->clone()); } - std::vector noisy_keys; + std::vector noisy_keys; std::vector noisy_rewards(iter, 0.0f); noisy_keys.resize(iter); @@ -98,7 +98,7 @@ protected: #pragma omp parallel for schedule(dynamic, 1) for (int i = 0; i < iter; ++i) { auto sampling_agent = sampling_agents[i]; - SamplingKey key; + SamplingInfo key; bool success = sampling_agent->add_noise(key); float reward = evaluate(x_list, y_list, train_data_size, sampling_agent); noisy_keys[i] = key; -- GitLab