提交 7793b479 编写于 作者: Z zhoubo01

add asynchronous lsagent

上级 3f22fd3e
seed : 1024 seed: 1024
gaussian_sampling { gaussian_sampling {
std: 0.5 std: 0.5
} }
optimizer { optimizer {
type: "Adam", type: "Adam"
base_lr: 0.05, base_lr: 0.05
momentum: 0.9, momentum: 0.9
beta1: 0.9, beta1: 0.9
beta2: 0.999, beta2: 0.999
epsilon: 1e-8, epsilon: 1e-08
}
async_es {
model_iter_id: 0
} }
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <glog/logging.h>
#include <omp.h>
#include "cartpole.h"
#include "async_es_agent.h"
#include "paddle_api.h"
using namespace DeepES;
using namespace paddle::lite_api;
const int ITER = 10;
std::shared_ptr<PaddlePredictor> create_paddle_predictor(const std::string& model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_dir(model_dir);
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
// 2. Create PaddlePredictor by CxxConfig
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
return predictor;
}
// Use PaddlePredictor of CartPole model to predict the action.
std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) {
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
input_tensor->Resize({1, 4});
input_tensor->CopyFromCpu(obs);
predictor->Run();
std::vector<float> probs(2, 0.0);
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
output_tensor->CopyToCpu(probs.data());
return probs;
}
int arg_max(const std::vector<float>& vec) {
return static_cast<int>(std::distance(vec.begin(), std::max_element(vec.begin(), vec.end())));
}
float evaluate(CartPole& env, std::shared_ptr<AsyncAgent> agent) {
float total_reward = 0.0;
env.reset();
const float* obs = env.getState();
std::shared_ptr<PaddlePredictor> paddle_predictor;
paddle_predictor = agent->get_predictor();
while (true) {
std::vector<float> probs = forward(paddle_predictor, obs);
int act = arg_max(probs);
env.step(act);
float reward = env.getReward();
bool done = env.isDone();
total_reward += reward;
if (done) break;
obs = env.getState();
}
return total_reward;
}
int main(int argc, char* argv[]) {
std::vector<CartPole> envs;
for (int i = 0; i < ITER; ++i) {
envs.push_back(CartPole());
}
std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
std::shared_ptr<AsyncAgent> agent = std::make_shared<AsyncAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
// Clone agents to sample (explore).
std::vector< std::shared_ptr<AsyncAgent> > sampling_agents;
for (int i = 0; i < ITER; ++i) {
sampling_agents.push_back(agent->clone());
}
std::vector<SamplingInfo> noisy_info;
std::vector<SamplingInfo> last_noisy_info;
std::vector<float> noisy_rewards(ITER, 0.0f);
std::vector<float> last_noisy_rewards;
noisy_info.resize(ITER);
omp_set_num_threads(10);
for (int epoch = 0; epoch < 100; ++epoch) {
last_noisy_info.clear();
last_noisy_rewards.clear();
if (epoch != 0) {
for (int i = 0; i < ITER; ++i){
last_noisy_info.push_back(noisy_info[i]);
last_noisy_rewards.push_back(noisy_rewards[i]);
}
}
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
std::shared_ptr<AsyncAgent> sampling_agent = sampling_agents[i];
SamplingInfo info;
bool success = sampling_agent->add_noise(info);
float reward = evaluate(envs[i], sampling_agent);
noisy_info[i] = info;
noisy_rewards[i] = reward;
}
for (int i = 0; i < ITER; ++i){
last_noisy_info.push_back(noisy_info[i]);
last_noisy_rewards.push_back(noisy_rewards[i]);
}
// NOTE: all parameters of sampling_agents will be updated
bool success = agent->update(last_noisy_info, last_noisy_rewards);
int reward = evaluate(envs[0], agent);
LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
}
}
...@@ -95,25 +95,25 @@ int main(int argc, char* argv[]) { ...@@ -95,25 +95,25 @@ int main(int argc, char* argv[]) {
sampling_agents.push_back(agent->clone()); sampling_agents.push_back(agent->clone());
} }
std::vector<SamplingInfo> noisy_info; std::vector<SamplingInfo> noisy_keys;
std::vector<float> noisy_rewards(ITER, 0.0f); std::vector<float> noisy_rewards(ITER, 0.0f);
noisy_info.resize(ITER); noisy_keys.resize(ITER);
omp_set_num_threads(10); omp_set_num_threads(10);
for (int epoch = 0; epoch < 300; ++epoch) { for (int epoch = 0; epoch < 100; ++epoch) {
#pragma omp parallel for schedule(dynamic, 1) #pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) { for (int i = 0; i < ITER; ++i) {
std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i]; std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
SamplingInfo info; SamplingInfo key;
bool success = sampling_agent->add_noise(info); bool success = sampling_agent->add_noise(key);
float reward = evaluate(envs[i], sampling_agent); float reward = evaluate(envs[i], sampling_agent);
noisy_info[i] = info; noisy_keys[i] = key;
noisy_rewards[i] = reward; noisy_rewards[i] = reward;
} }
// NOTE: all parameters of sampling_agents will be updated // NOTE: all parameters of sampling_agents will be updated
bool success = agent->update(noisy_info, noisy_rewards); bool success = agent->update(noisy_keys, noisy_rewards);
int reward = evaluate(envs[0], agent); int reward = evaluate(envs[0], agent);
LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward; LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _ASYNC_ES_AGENT_H
#define _ASYNC_ES_AGENT_H
#include "es_agent.h"
#include <map>
#include <stdlib.h>
namespace DeepES{
/* DeepES agent with PaddleLite as backend. This agent supports asynchronous update.
* Users mainly focus on the following functions:
* 1. clone: clone an agent for multi-thread evaluation
* 2. add_noise: add noise into parameters.
* 3. update: update parameters given data collected during evaluation.
*/
class AsyncAgent: public ESAgent {
public:
AsyncAgent() {}
~AsyncAgent();
/**
* @args:
* predictor: predictor created by users for prediction.
* config_path: the path of configuration file.
* Note that AsyncAgent will update the configuration file after calling the update function.
* Please use the up-to-date configuration.
*/
AsyncAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path);
/**
* @brief: Clone an agent for sampling.
*/
std::shared_ptr<AsyncAgent> clone();
/**
* @brief: Clone an agent for sampling.
*/
bool update(
std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards);
private:
std::map<int, std::shared_ptr<PaddlePredictor>> _previous_predictors;
std::map<int, float*> _param_delta;
std::string _config_path;
/**
* @brief: parse model_iter_id given a string of model directory.
* @return: an integer indicating the model_iter_id
*/
int _parse_model_iter_id(const std::string&);
/**
* @brief: compute the distance between current parameter and previous models.
*/
bool _compute_model_diff();
/**
* @brief: remove expired models to avoid overuse of disk space.
* @args:
* max_to_keep: the maximum number of models to keep locally.
*/
bool _remove_expired_model(int max_to_keep);
/**
* @brief: save up-to-date parameters to the disk.
*/
bool _save();
/**
* @brief: load all models in the model warehouse.
*/
bool _load();
/**
* @brief: load a model given the model directory.
*/
std::shared_ptr<PaddlePredictor> _load_previous_model(std::string model_dir);
};
} //namespace
#endif
...@@ -21,21 +21,22 @@ ...@@ -21,21 +21,22 @@
#include "gaussian_sampling.h" #include "gaussian_sampling.h"
#include "deepes.pb.h" #include "deepes.pb.h"
#include <vector> #include <vector>
using namespace paddle::lite_api;
namespace DeepES { namespace DeepES {
int64_t ShapeProduction(const shape_t& shape);
typedef paddle::lite_api::PaddlePredictor PaddlePredictor; typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
/** /**
* @brief DeepES agent for PaddleLite. * @brief DeepES agent with PaddleLite as backend.
* * Users mainly focus on the following functions:
* Users use `clone` fucntion to clone a sampling agent, which can call `add_noise` * 1. clone: clone an agent for multi-thread evaluation
* function to add noise to copied parameters and call `get_predictor` fucntion to * 2. add_noise: add noise into parameters.
* get a paddle predictor with added noise. * 3. update: update parameters given data collected during evaluation.
* *
* Then can use `update` function to update parameters based on ES algorithm.
* Note: parameters of cloned agents will also be updated.
*/ */
class ESAgent { class ESAgent {
public: public:
...@@ -77,7 +78,9 @@ class ESAgent { ...@@ -77,7 +78,9 @@ class ESAgent {
*/ */
std::shared_ptr<PaddlePredictor> get_predictor(); std::shared_ptr<PaddlePredictor> get_predictor();
private:
protected:
int64_t _calculate_param_size(); int64_t _calculate_param_size();
std::shared_ptr<PaddlePredictor> _predictor; std::shared_ptr<PaddlePredictor> _predictor;
......
...@@ -125,7 +125,7 @@ public: ...@@ -125,7 +125,7 @@ public:
for (auto& param: params) { for (auto& param: params) {
torch::Tensor tensor = param.value().view({-1}); torch::Tensor tensor = param.value().view({-1});
auto tensor_a = tensor.accessor<float,1>(); auto tensor_a = tensor.accessor<float,1>();
_optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.info()); _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.key());
counter += tensor.size(0); counter += tensor.size(0);
} }
...@@ -146,7 +146,7 @@ public: ...@@ -146,7 +146,7 @@ public:
int64_t counter = 0; int64_t counter = 0;
for (auto& param: sampling_params) { for (auto& param: sampling_params) {
torch::Tensor sampling_tensor = param.value().view({-1}); torch::Tensor sampling_tensor = param.value().view({-1});
std::string param_name = param.info(); std::string param_name = param.key();
torch::Tensor tensor = params.find(param_name)->view({-1}); torch::Tensor tensor = params.find(param_name)->view({-1});
auto sampling_tensor_a = sampling_tensor.accessor<float,1>(); auto sampling_tensor_a = sampling_tensor.accessor<float,1>();
auto tensor_a = tensor.accessor<float,1>(); auto tensor_a = tensor.accessor<float,1>();
...@@ -162,6 +162,7 @@ public: ...@@ -162,6 +162,7 @@ public:
private: private:
int64_t _calculate_param_size() { int64_t _calculate_param_size() {
_param_size = 0;
auto params = _model->named_parameters(); auto params = _model->named_parameters();
for (auto& param: params) { for (auto& param: params) {
torch::Tensor tensor = param.value().view({-1}); torch::Tensor tensor = param.value().view({-1});
......
...@@ -39,8 +39,7 @@ template<typename T> ...@@ -39,8 +39,7 @@ template<typename T>
bool load_proto_conf(const std::string& config_file, T& proto_config) { bool load_proto_conf(const std::string& config_file, T& proto_config) {
bool success = true; bool success = true;
std::ifstream fin(config_file); std::ifstream fin(config_file);
CHECK(fin) << "open config file " << config_file; if (!fin || fin.fail()) {
if (fin.fail()) {
LOG(FATAL) << "open prototxt config failed: " << config_file; LOG(FATAL) << "open prototxt config failed: " << config_file;
success = false; success = false;
} else { } else {
...@@ -54,7 +53,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) { ...@@ -54,7 +53,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
std::string proto_str(file_content_buffer, file_size); std::string proto_str(file_content_buffer, file_size);
if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) { if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) {
LOG(FATAL) << "Failed to load config: " << config_file; LOG(FATAL) << "Failed to load config: " << config_file;
return -1; success = false;
} }
delete[] file_content_buffer; delete[] file_content_buffer;
fin.close(); fin.close();
...@@ -62,6 +61,25 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) { ...@@ -62,6 +61,25 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
return success; return success;
} }
template<typename T>
bool save_proto_conf(const std::string& config_file, T&proto_config) {
bool success = true;
std::ofstream ofs(config_file, std::ofstream::out);
if (!ofs || ofs.fail()) {
LOG(FATAL) << "open prototxt config failed: " << config_file;
success = false;
} else {
std::string config_str;
success = google::protobuf::TextFormat::PrintToString(proto_config, &config_str);
if (!success) {
return success;
}
ofs << config_str;
}
}
std::vector<std::string> list_all_model_dirs(std::string path);
} }
#endif #endif
...@@ -47,7 +47,7 @@ rm -rf build ...@@ -47,7 +47,7 @@ rm -rf build
mkdir build mkdir build
cd build cd build
cmake ../ ${FLAGS} cmake ../ ${FLAGS}
make -j10 make -j10
#-----------------run----------------# #-----------------run----------------#
./parallel_main ./parallel_main
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "async_es_agent.h"
namespace DeepES {
AsyncAgent::AsyncAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path): ESAgent(predictor, config_path) {
_config_path = config_path;
}
AsyncAgent::~AsyncAgent() {
for(const auto kv: _param_delta) {
float* delta = kv.second;
delete[] delta;
}
}
bool AsyncAgent::_save() {
bool success = true;
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Original AsyncAgent cannot call add_noise function, please use cloned AsyncAgent.";
success = false;
return success;
}
int model_iter_id = _config->async_es().model_iter_id() + 1;
//current time
time_t rawtime;
struct tm * timeinfo;
char buffer[80];
time (&rawtime);
timeinfo = localtime(&rawtime);
strftime(buffer,sizeof(buffer),"%d-%m-%Y-%H:%M:%S",timeinfo);
std::string current_time(buffer);
std::string model_name = current_time + "-model_iter_id-"+ std::to_string(model_iter_id);
model_name = "model_iter_id-"+ std::to_string(model_iter_id);
std::string model_path = _config->async_es().model_warehouse() + "/" + model_name;
LOG(INFO) << "[save]model_path: " << model_path;
_predictor->SaveOptimizedModel(model_path, LiteModelType::kProtobuf);
// save config
auto async_es = _config->mutable_async_es();
async_es->set_model_iter_id(model_iter_id);
success = save_proto_conf(_config_path, *_config);
if (!success) {
LOG(ERROR) << "[]unable to save config for AsyncAgent";
success = false;
return success;
}
int max_to_keep = _config->async_es().max_to_keep();
success = _remove_expired_model(max_to_keep);
return success;
}
bool AsyncAgent::_remove_expired_model(int max_to_keep) {
bool success = true;
std::string model_path = _config->async_es().model_warehouse();
std::vector<std::string> model_dirs = list_all_model_dirs(model_path);
int model_iter_id = _config->async_es().model_iter_id() + 1;
for (const auto& dir: model_dirs) {
int dir_model_iter_id = _parse_model_iter_id(dir);
if (model_iter_id - dir_model_iter_id >= max_to_keep) {
std::string rm_command = std::string("rm -rf ") + dir;
int ret = system(rm_command.c_str());
if (ret == 0) {
LOG(INFO) << "[DeepES] remove expired Model: " << dir;
} else {
LOG(ERROR) << "[DeepES] fail to remove expired Model: " << dir;
success = false;
return success;
}
}
}
return success;
}
bool AsyncAgent::_compute_model_diff() {
bool success = true;
for (const auto& kv: _previous_predictors) {
int model_iter_id = kv.first;
std::shared_ptr<PaddlePredictor> old_predictor = kv.second;
float* diff = new float[_param_size];
memset(diff, 0, _param_size * sizeof(float));
for (std::string param_name: _param_names) {
auto des_tensor = old_predictor->GetTensor(param_name);
auto src_tensor = _predictor->GetTensor(param_name);
const float* des_data = des_tensor->data<float>();
const float* src_data = src_tensor->data<float>();
int64_t tensor_size = ShapeProduction(src_tensor->shape());
for (int i = 0; i < tensor_size; ++i) {
diff[i] = des_data[i] - src_data[i];
}
}
_param_delta[model_iter_id] = diff;
}
return success;
}
bool AsyncAgent::_load() {
bool success = true;
std::string model_path = _config->async_es().model_warehouse();
std::vector<std::string> model_dirs = list_all_model_dirs(model_path);
if (model_dirs.size() == 0) {
int model_iter_id = _config->async_es().model_iter_id();
success = model_iter_id == 0 ? true: false;
if (!success) {
LOG(WARNING) << "[DeepES] current_model_iter_id is nonzero, but no model is \
found at the dir: " << model_path;
}
return success;
}
for(auto &dir: model_dirs) {
int model_iter_id = _parse_model_iter_id(dir);
if (model_iter_id == -1) {
LOG(WARNING) << "[DeepES] fail to parse model_iter_id: " << dir;
success = false;
return success;
}
std::shared_ptr<PaddlePredictor> predictor = _load_previous_model(dir);
if (predictor == nullptr) {
success = false;
LOG(WARNING) << "[DeepES] fail to load model: " << dir;
return success;
}
_previous_predictors[model_iter_id] = predictor;
}
success = _compute_model_diff();
return success;
}
std::shared_ptr<PaddlePredictor> AsyncAgent::_load_previous_model(std::string model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_file(model_dir + "/model");
config.set_param_file(model_dir + "/params");
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
// 2. Create PaddlePredictor by CxxConfig
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
return predictor;
}
std::shared_ptr<AsyncAgent> AsyncAgent::clone() {
std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
std::shared_ptr<AsyncAgent> new_agent = std::make_shared<AsyncAgent>();
float* noise = new float [_param_size];
new_agent->_predictor = _predictor;
new_agent->_sampling_predictor = new_sampling_predictor;
new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method;
new_agent->_param_names = _param_names;
new_agent->_param_size = _param_size;
new_agent->_config = _config;
new_agent->_noise = noise;
return new_agent;
}
bool AsyncAgent::update(
std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards) {
CHECK(!_is_sampling_agent) << "[DeepES] Cloned ESAgent cannot call update function. \
Please use original ESAgent.";
bool success = _load();
CHECK(success) << "[DeepES] fail to load previous models.";
int current_model_iter_id = _config->async_es().model_iter_id();
// validate model_iter_id for each sample before the update
for (int i = 0; i < noisy_info.size(); ++i) {
int model_iter_id = noisy_info[i].model_iter_id();
if (model_iter_id != current_model_iter_id
&& _previous_predictors.count(model_iter_id) == 0) {
LOG(WARNING) << "[DeepES] The sample with model_dir_id: " << model_iter_id \
<< " cannot match any local model";
success = false;
return success;
}
}
compute_centered_ranks(noisy_rewards);
memset(_neg_gradients, 0, _param_size * sizeof(float));
for (int i = 0; i < noisy_info.size(); ++i) {
int key = noisy_info[i].key(0);
float reward = noisy_rewards[i];
int model_iter_id = noisy_info[i].model_iter_id();
bool success = _sampling_method->resampling(key, _noise, _param_size);
float* delta = _param_delta[model_iter_id];
// compute neg_gradients
if (model_iter_id == current_model_iter_id) {
for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] += _noise[j] * reward;
}
} else {
for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] += (_noise[j] + delta[j]) * reward;
}
}
}
for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] /= -1.0 * noisy_info.size();
}
//update
int64_t counter = 0;
for (std::string param_name: _param_names) {
std::unique_ptr<Tensor> tensor = _predictor->GetMutableTensor(param_name);
float* tensor_data = tensor->mutable_data<float>();
int64_t tensor_size = ShapeProduction(tensor->shape());
_optimizer->update(tensor_data, _neg_gradients + counter, tensor_size, param_name);
counter += tensor_size;
}
success = _save();
CHECK(success) << "[DeepES] fail to save model.";
return true;
}
int AsyncAgent::_parse_model_iter_id(const std::string& model_path) {
int model_iter_id = -1;
int pow = 1;
for (int i = model_path.size() - 1; i >= 0; --i) {
if (model_path[i] >= '0' && model_path[i] <= '9') {
if (model_iter_id == -1) model_iter_id = 0;
} else {
break;
}
model_iter_id += pow * (model_path[i] - '0');
pow *= 10;
}
return model_iter_id;
}
}//namespace
...@@ -13,14 +13,14 @@ ...@@ -13,14 +13,14 @@
// limitations under the License. // limitations under the License.
#include "es_agent.h" #include "es_agent.h"
#include <ctime>
namespace DeepES { namespace DeepES {
typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
typedef paddle::lite_api::Tensor Tensor; typedef paddle::lite_api::Tensor Tensor;
typedef paddle::lite_api::shape_t shape_t; typedef paddle::lite_api::shape_t shape_t;
inline int64_t ShapeProduction(const shape_t& shape) { int64_t ShapeProduction(const shape_t& shape) {
int64_t res = 1; int64_t res = 1;
for (auto i : shape) res *= i; for (auto i : shape) res *= i;
return res; return res;
...@@ -71,6 +71,7 @@ std::shared_ptr<ESAgent> ESAgent::clone() { ...@@ -71,6 +71,7 @@ std::shared_ptr<ESAgent> ESAgent::clone() {
new_agent->_is_sampling_agent = true; new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method; new_agent->_sampling_method = _sampling_method;
new_agent->_param_names = _param_names; new_agent->_param_names = _param_names;
new_agent->_config = _config;
new_agent->_param_size = _param_size; new_agent->_param_size = _param_size;
new_agent->_noise = noise; new_agent->_noise = noise;
...@@ -111,7 +112,6 @@ bool ESAgent::update( ...@@ -111,7 +112,6 @@ bool ESAgent::update(
counter += tensor_size; counter += tensor_size;
} }
return true; return true;
} }
bool ESAgent::add_noise(SamplingInfo& sampling_info) { bool ESAgent::add_noise(SamplingInfo& sampling_info) {
...@@ -121,7 +121,9 @@ bool ESAgent::add_noise(SamplingInfo& sampling_info) { ...@@ -121,7 +121,9 @@ bool ESAgent::add_noise(SamplingInfo& sampling_info) {
} }
int key = _sampling_method->sampling(_noise, _param_size); int key = _sampling_method->sampling(_noise, _param_size);
int model_iter_id = _config->async_es().model_iter_id();
sampling_info.add_key(key); sampling_info.add_key(key);
sampling_info.set_model_iter_id(model_iter_id);
int64_t counter = 0; int64_t counter = 0;
for (std::string param_name: _param_names) { for (std::string param_name: _param_names) {
...@@ -137,7 +139,6 @@ bool ESAgent::add_noise(SamplingInfo& sampling_info) { ...@@ -137,7 +139,6 @@ bool ESAgent::add_noise(SamplingInfo& sampling_info) {
return true; return true;
} }
std::shared_ptr<PaddlePredictor> ESAgent::get_predictor() { std::shared_ptr<PaddlePredictor> ESAgent::get_predictor() {
return _sampling_predictor; return _sampling_predictor;
} }
...@@ -151,6 +152,4 @@ int64_t ESAgent::_calculate_param_size() { ...@@ -151,6 +152,4 @@ int64_t ESAgent::_calculate_param_size() {
return param_size; return param_size;
} }
}//namespace
}
...@@ -51,4 +51,5 @@ message AsyncESConfig{ ...@@ -51,4 +51,5 @@ message AsyncESConfig{
optional string model_warehouse = 1 [default = "./model_warehouse"]; optional string model_warehouse = 1 [default = "./model_warehouse"];
repeated string model_md5 = 2; repeated string model_md5 = 2;
optional int32 max_to_keep = 3 [default = 5]; optional int32 max_to_keep = 3 [default = 5];
optional int32 model_iter_id = 4 [default = 0];
} }
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "utils.h" #include "utils.h"
#include <dirent.h>
namespace DeepES { namespace DeepES {
...@@ -33,4 +34,21 @@ void compute_centered_ranks(std::vector<float> &reward) { ...@@ -33,4 +34,21 @@ void compute_centered_ranks(std::vector<float> &reward) {
} }
} }
std::vector<std::string> list_all_model_dirs(std::string path) {
std::vector<std::string> model_dirs;
DIR *dpdf;
struct dirent *epdf;
dpdf = opendir(path.data());
if (dpdf != NULL){
while (epdf = readdir(dpdf)){
std::string dir(epdf->d_name);
if (dir.find("model_iter_id") != std::string::npos) {
model_dirs.push_back(path + "/" + dir);
}
}
}
closedir(dpdf);
return model_dirs;
}
}//namespace }//namespace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册