未验证 提交 75396010 编写于 作者: B Bo Zhou 提交者: GitHub

add an agent for the support of asynchronous update (#235)

* rename SamplingKey to SamplingInfo

* add asynchronous lsagent

* rename AsyncAgent to AsyncESAgent

* fix comments

* Update async_es_agent.cc

* Sampling Info

* commemt

* commemt
上级 a1ac2da3
......@@ -45,6 +45,7 @@ if (WITH_PADDLE)
file(GLOB framework_src "src/paddle/*.cc")
set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_solver_parallel.cc")
#set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_async_solver.cc")
########## Torch config ##########
elseif (WITH_TORCH)
list(APPEND CMAKE_PREFIX_PATH "./libtorch")
......
......@@ -11,14 +11,14 @@ auto agent = ESAgent(config);
for (int i = 0; i < 10; ++i) {
auto sampling_agnet = agent->clone(); // clone出一个sampling agent
SamplingKey key;
agent->add_noise(key); // 参数扰动,同时保存随机种子到key
SamplingInfo info;
agent->add_noise(info); // 参数扰动,同时保存随机种子到info
int reward = evaluate(env, sampling_agent); //评估参数
noisy_keys.push_back(key); // 记录随机噪声对应种子
noisy_info.push_back(info); // 记录随机噪声对应种子
noisy_rewards.push_back(reward); // 记录评估结果
}
//根据评估结果、随机种子更新参数,然后重复以上过程,直到收敛。
agent->update(noisy_keys, noisy_rewards);
agent->update(noisy_info, noisy_rewards);
```
## 一键运行demo列表
......
seed : 1024
seed: 1024
gaussian_sampling {
std: 0.5
}
optimizer {
type: "Adam",
base_lr: 0.05,
momentum: 0.9,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
type: "Adam"
base_lr: 0.05
momentum: 0.9
beta1: 0.9
beta2: 0.999
epsilon: 1e-08
}
async_es {
model_iter_id: 0
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <glog/logging.h>
#include <omp.h>
#include "cartpole.h"
#include "async_es_agent.h"
#include "paddle_api.h"
using namespace DeepES;
using namespace paddle::lite_api;
const int ITER = 10;
std::shared_ptr<PaddlePredictor> create_paddle_predictor(const std::string& model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_dir(model_dir);
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
// 2. Create PaddlePredictor by CxxConfig
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
return predictor;
}
// Use PaddlePredictor of CartPole model to predict the action.
std::vector<float> forward(std::shared_ptr<PaddlePredictor> predictor, const float* obs) {
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
input_tensor->Resize({1, 4});
input_tensor->CopyFromCpu(obs);
predictor->Run();
std::vector<float> probs(2, 0.0);
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
output_tensor->CopyToCpu(probs.data());
return probs;
}
int arg_max(const std::vector<float>& vec) {
return static_cast<int>(std::distance(vec.begin(), std::max_element(vec.begin(), vec.end())));
}
float evaluate(CartPole& env, std::shared_ptr<AsyncESAgent> agent) {
float total_reward = 0.0;
env.reset();
const float* obs = env.getState();
std::shared_ptr<PaddlePredictor> paddle_predictor;
paddle_predictor = agent->get_predictor();
while (true) {
std::vector<float> probs = forward(paddle_predictor, obs);
int act = arg_max(probs);
env.step(act);
float reward = env.getReward();
bool done = env.isDone();
total_reward += reward;
if (done) break;
obs = env.getState();
}
return total_reward;
}
int main(int argc, char* argv[]) {
std::vector<CartPole> envs;
for (int i = 0; i < ITER; ++i) {
envs.push_back(CartPole());
}
std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
std::shared_ptr<AsyncESAgent> agent = std::make_shared<AsyncESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
// Clone agents to sample (explore).
std::vector< std::shared_ptr<AsyncESAgent> > sampling_agents;
for (int i = 0; i < ITER; ++i) {
sampling_agents.push_back(agent->clone());
}
std::vector<SamplingInfo> noisy_info;
std::vector<SamplingInfo> last_noisy_info;
std::vector<float> noisy_rewards(ITER, 0.0f);
std::vector<float> last_noisy_rewards;
noisy_info.resize(ITER);
omp_set_num_threads(10);
for (int epoch = 0; epoch < 100; ++epoch) {
last_noisy_info.clear();
last_noisy_rewards.clear();
if (epoch != 0) {
for (int i = 0; i < ITER; ++i){
last_noisy_info.push_back(noisy_info[i]);
last_noisy_rewards.push_back(noisy_rewards[i]);
}
}
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
std::shared_ptr<AsyncESAgent> sampling_agent = sampling_agents[i];
SamplingInfo info;
bool success = sampling_agent->add_noise(info);
float reward = evaluate(envs[i], sampling_agent);
noisy_info[i] = info;
noisy_rewards[i] = reward;
}
for (int i = 0; i < ITER; ++i){
last_noisy_info.push_back(noisy_info[i]);
last_noisy_rewards.push_back(noisy_rewards[i]);
}
// NOTE: all parameters of sampling_agents will be updated
bool success = agent->update(last_noisy_info, last_noisy_rewards);
int reward = evaluate(envs[0], agent);
LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
}
}
......@@ -95,16 +95,16 @@ int main(int argc, char* argv[]) {
sampling_agents.push_back(agent->clone());
}
std::vector<SamplingKey> noisy_keys;
std::vector<SamplingInfo> noisy_keys;
std::vector<float> noisy_rewards(ITER, 0.0f);
noisy_keys.resize(ITER);
omp_set_num_threads(10);
for (int epoch = 0; epoch < 1000; ++epoch) {
for (int epoch = 0; epoch < 100; ++epoch) {
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
SamplingKey key;
SamplingInfo key;
bool success = sampling_agent->add_noise(key);
float reward = evaluate(envs[i], sampling_agent);
......
......@@ -59,23 +59,23 @@ int main(int argc, char* argv[]) {
sampling_agents.push_back(agent->clone());
}
std::vector<SamplingKey> noisy_keys;
std::vector<SamplingInfo> noisy_info;
std::vector<float> noisy_rewards(ITER, 0.0f);
noisy_keys.resize(ITER);
noisy_info.resize(ITER);
for (int epoch = 0; epoch < 100; ++epoch) {
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
auto sampling_agent = sampling_agents[i];
SamplingKey key;
bool success = sampling_agent->add_noise(key);
SamplingInfo info;
bool success = sampling_agent->add_noise(info);
float reward = evaluate(envs[i], sampling_agent);
noisy_keys[i] = key;
noisy_info[i] = info;
noisy_rewards[i] = reward;
}
// Will also update parameters of sampling_agents
bool success = agent->update(noisy_keys, noisy_rewards);
bool success = agent->update(noisy_info, noisy_rewards);
// Use original agent to evalute (without noise).
int reward = evaluate(envs[0], agent);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef ASYNC_ES_AGENT_H
#define ASYNC_ES_AGENT_H
#include "es_agent.h"
#include <map>
#include <stdlib.h>
namespace DeepES{
/* DeepES agent with PaddleLite as backend. This agent supports asynchronous update.
* Users mainly focus on the following functions:
* 1. clone: clone an agent for multi-thread evaluation
* 2. add_noise: add noise into parameters.
* 3. update: update parameters given data collected during evaluation.
*/
class AsyncESAgent: public ESAgent {
public:
AsyncESAgent() {}
~AsyncESAgent();
/**
* @args:
* predictor: predictor created by users for prediction.
* config_path: the path of configuration file.
* Note that AsyncESAgent will update the configuration file after calling the update function.
* Please use the up-to-date configuration.
*/
AsyncESAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path);
/**
* @brief: Clone an agent for sampling.
*/
std::shared_ptr<AsyncESAgent> clone();
/**
* @brief: update parameters given data collected during evaluation.
* @args:
* noisy_info: sampling information returned by add_noise function.
* noisy_reward: evaluation rewards.
*/
bool update(
std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards);
private:
std::map<int, std::shared_ptr<PaddlePredictor>> _previous_predictors;
std::map<int, float*> _param_delta;
std::string _config_path;
/**
* @brief: parse model_iter_id given a string of model directory.
* @return: an integer indicating the model_iter_id
*/
int _parse_model_iter_id(const std::string&);
/**
* @brief: compute the distance between current parameter and previous models.
*/
bool _compute_model_diff();
/**
* @brief: remove expired models to avoid overuse of disk space.
* @args:
* max_to_keep: the maximum number of models to keep locally.
*/
bool _remove_expired_model(int max_to_keep);
/**
* @brief: save up-to-date parameters to the disk.
*/
bool _save();
/**
* @brief: load all models in the model warehouse.
*/
bool _load();
/**
* @brief: load a model given the model directory.
*/
std::shared_ptr<PaddlePredictor> _load_previous_model(std::string model_dir);
};
} //namespace
#endif
......@@ -22,20 +22,19 @@
#include "deepes.pb.h"
#include <vector>
using namespace paddle::lite_api;
namespace DeepES {
typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
int64_t ShapeProduction(const shape_t& shape);
/**
* @brief DeepES agent for PaddleLite.
* @brief DeepES agent with PaddleLite as backend.
* Users mainly focus on the following functions:
* 1. clone: clone an agent for multi-thread evaluation
* 2. add_noise: add noise into parameters.
* 3. update: update parameters given data collected during evaluation.
*
* Users use `clone` fucntion to clone a sampling agent, which can call `add_noise`
* function to add noise to copied parameters and call `get_predictor` fucntion to
* get a paddle predictor with added noise.
*
* Then can use `update` function to update parameters based on ES algorithm.
* Note: parameters of cloned agents will also be updated.
*/
class ESAgent {
public:
......@@ -63,11 +62,11 @@ class ESAgent {
* Parameters of cloned agents will also be updated.
*/
bool update(
std::vector<SamplingKey>& noisy_keys,
std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards);
// copied parameters = original parameters + noise
bool add_noise(SamplingKey& sampling_key);
bool add_noise(SamplingInfo& sampling_info);
/**
* @brief Get paddle predict
......@@ -77,7 +76,9 @@ class ESAgent {
*/
std::shared_ptr<PaddlePredictor> get_predictor();
private:
protected:
int64_t _calculate_param_size();
std::shared_ptr<PaddlePredictor> _predictor;
......
......@@ -98,7 +98,7 @@ public:
* Only not cloned ESAgent can call `update` function.
* Parameters of cloned agents will also be updated.
*/
bool update(std::vector<SamplingKey>& noisy_keys, std::vector<float>& noisy_rewards) {
bool update(std::vector<SamplingInfo>& noisy_info, std::vector<float>& noisy_rewards) {
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
return false;
......@@ -107,8 +107,8 @@ public:
compute_centered_ranks(noisy_rewards);
memset(_neg_gradients, 0, _param_size * sizeof(float));
for (int i = 0; i < noisy_keys.size(); ++i) {
int key = noisy_keys[i].key(0);
for (int i = 0; i < noisy_info.size(); ++i) {
int key = noisy_info[i].key(0);
float reward = noisy_rewards[i];
bool success = _sampling_method->resampling(key, _noise, _param_size);
for (int64_t j = 0; j < _param_size; ++j) {
......@@ -116,7 +116,7 @@ public:
}
}
for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] /= -1.0 * noisy_keys.size();
_neg_gradients[j] /= -1.0 * noisy_info.size();
}
//update
......@@ -133,7 +133,7 @@ public:
}
// copied parameters = original parameters + noise
bool add_noise(SamplingKey& sampling_key) {
bool add_noise(SamplingInfo& sampling_info) {
if (!_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
return false;
......@@ -142,7 +142,7 @@ public:
auto sampling_params = _sampling_model->named_parameters();
auto params = _model->named_parameters();
int key = _sampling_method->sampling(_noise, _param_size);
sampling_key.add_key(key);
sampling_info.add_key(key);
int64_t counter = 0;
for (auto& param: sampling_params) {
torch::Tensor sampling_tensor = param.value().view({-1});
......@@ -162,7 +162,7 @@ public:
private:
int64_t _calculate_param_size() {
int _param_size = 0;
_param_size = 0;
auto params = _model->named_parameters();
for (auto& param: params) {
torch::Tensor tensor = param.value().view({-1});
......
......@@ -39,9 +39,8 @@ template<typename T>
bool load_proto_conf(const std::string& config_file, T& proto_config) {
bool success = true;
std::ifstream fin(config_file);
CHECK(fin) << "open config file " << config_file;
if (fin.fail()) {
LOG(FATAL) << "open prototxt config failed: " << config_file;
if (!fin || fin.fail()) {
LOG(ERROR) << "open prototxt config failed: " << config_file;
success = false;
} else {
fin.seekg(0, std::ios::end);
......@@ -53,8 +52,8 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
std::string proto_str(file_content_buffer, file_size);
if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) {
LOG(FATAL) << "Failed to load config: " << config_file;
return -1;
LOG(ERROR) << "Failed to load config: " << config_file;
success = false;
}
delete[] file_content_buffer;
fin.close();
......@@ -62,6 +61,26 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
return success;
}
template<typename T>
bool save_proto_conf(const std::string& config_file, T&proto_config) {
bool success = true;
std::ofstream ofs(config_file, std::ofstream::out);
if (!ofs || ofs.fail()) {
LOG(ERROR) << "open prototxt config failed: " << config_file;
success = false;
} else {
std::string config_str;
success = google::protobuf::TextFormat::PrintToString(proto_config, &config_str);
if (!success) {
return success;
}
ofs << config_str;
}
return success;
}
std::vector<std::string> list_all_model_dirs(std::string path);
}
#endif
......@@ -35,8 +35,6 @@ else
exit 0
fi
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
#----------------protobuf-------------#
cp ./src/proto/deepes.proto ./
protoc deepes.proto --cpp_out ./
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "async_es_agent.h"
namespace DeepES {
AsyncESAgent::AsyncESAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path): ESAgent(predictor, config_path) {
_config_path = config_path;
}
AsyncESAgent::~AsyncESAgent() {
for(const auto kv: _param_delta) {
float* delta = kv.second;
delete[] delta;
}
}
bool AsyncESAgent::_save() {
bool success = true;
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Cloned AsyncESAgent cannot call `save`.Please use original AsyncESAgent.";
success = false;
return success;
}
int model_iter_id = _config->async_es().model_iter_id() + 1;
//current time
time_t rawtime;
struct tm * timeinfo;
char buffer[80];
time (&rawtime);
timeinfo = localtime(&rawtime);
std::string model_name = "model_iter_id-"+ std::to_string(model_iter_id);
std::string model_path = _config->async_es().model_warehouse() + "/" + model_name;
LOG(INFO) << "[save]model_path: " << model_path;
_predictor->SaveOptimizedModel(model_path, LiteModelType::kProtobuf);
// save config
auto async_es = _config->mutable_async_es();
async_es->set_model_iter_id(model_iter_id);
success = save_proto_conf(_config_path, *_config);
if (!success) {
LOG(ERROR) << "[]unable to save config for AsyncESAgent";
success = false;
return success;
}
int max_to_keep = _config->async_es().max_to_keep();
success = _remove_expired_model(max_to_keep);
return success;
}
bool AsyncESAgent::_remove_expired_model(int max_to_keep) {
bool success = true;
std::string model_path = _config->async_es().model_warehouse();
std::vector<std::string> model_dirs = list_all_model_dirs(model_path);
int model_iter_id = _config->async_es().model_iter_id() + 1;
for (const auto& dir: model_dirs) {
int dir_model_iter_id = _parse_model_iter_id(dir);
if (model_iter_id - dir_model_iter_id >= max_to_keep) {
std::string rm_command = std::string("rm -rf ") + dir;
int ret = system(rm_command.c_str());
if (ret == 0) {
LOG(INFO) << "[DeepES] remove expired Model: " << dir;
} else {
LOG(ERROR) << "[DeepES] fail to remove expired Model: " << dir;
success = false;
return success;
}
}
}
return success;
}
bool AsyncESAgent::_compute_model_diff() {
bool success = true;
for (const auto& kv: _previous_predictors) {
int model_iter_id = kv.first;
std::shared_ptr<PaddlePredictor> old_predictor = kv.second;
float* diff = new float[_param_size];
memset(diff, 0, _param_size * sizeof(float));
int offset = 0;
for (const std::string& param_name: _param_names) {
auto des_tensor = old_predictor->GetTensor(param_name);
auto src_tensor = _predictor->GetTensor(param_name);
const float* des_data = des_tensor->data<float>();
const float* src_data = src_tensor->data<float>();
int64_t tensor_size = ShapeProduction(src_tensor->shape());
for (int i = 0; i < tensor_size; ++i) {
diff[i + offset] = des_data[i] - src_data[i];
}
offset += tensor_size;
}
_param_delta[model_iter_id] = diff;
}
return success;
}
bool AsyncESAgent::_load() {
bool success = true;
std::string model_path = _config->async_es().model_warehouse();
std::vector<std::string> model_dirs = list_all_model_dirs(model_path);
if (model_dirs.size() == 0) {
int model_iter_id = _config->async_es().model_iter_id();
success = model_iter_id == 0 ? true: false;
if (!success) {
LOG(WARNING) << "[DeepES] current_model_iter_id is nonzero, but no model is \
found at the dir: " << model_path;
}
return success;
}
for(auto &dir: model_dirs) {
int model_iter_id = _parse_model_iter_id(dir);
if (model_iter_id == -1) {
LOG(WARNING) << "[DeepES] fail to parse model_iter_id: " << dir;
success = false;
return success;
}
std::shared_ptr<PaddlePredictor> predictor = _load_previous_model(dir);
if (predictor == nullptr) {
success = false;
LOG(WARNING) << "[DeepES] fail to load model: " << dir;
return success;
}
_previous_predictors[model_iter_id] = predictor;
}
success = _compute_model_diff();
return success;
}
std::shared_ptr<PaddlePredictor> AsyncESAgent::_load_previous_model(std::string model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_file(model_dir + "/model");
config.set_param_file(model_dir + "/params");
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)}
});
// 2. Create PaddlePredictor by CxxConfig
std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
return predictor;
}
std::shared_ptr<AsyncESAgent> AsyncESAgent::clone() {
std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
std::shared_ptr<AsyncESAgent> new_agent = std::make_shared<AsyncESAgent>();
float* noise = new float [_param_size];
new_agent->_predictor = _predictor;
new_agent->_sampling_predictor = new_sampling_predictor;
new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method;
new_agent->_param_names = _param_names;
new_agent->_param_size = _param_size;
new_agent->_config = _config;
new_agent->_noise = noise;
return new_agent;
}
bool AsyncESAgent::update(
std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards) {
CHECK(!_is_sampling_agent) << "[DeepES] Cloned ESAgent cannot call update function. \
Please use original ESAgent.";
bool success = _load();
CHECK(success) << "[DeepES] fail to load previous models.";
int current_model_iter_id = _config->async_es().model_iter_id();
// validate model_iter_id for each sample before the update
for (int i = 0; i < noisy_info.size(); ++i) {
int model_iter_id = noisy_info[i].model_iter_id();
if (model_iter_id != current_model_iter_id
&& _previous_predictors.count(model_iter_id) == 0) {
LOG(WARNING) << "[DeepES] The sample with model_dir_id: " << model_iter_id \
<< " cannot match any local model";
success = false;
return success;
}
}
compute_centered_ranks(noisy_rewards);
memset(_neg_gradients, 0, _param_size * sizeof(float));
for (int i = 0; i < noisy_info.size(); ++i) {
int key = noisy_info[i].key(0);
float reward = noisy_rewards[i];
int model_iter_id = noisy_info[i].model_iter_id();
bool success = _sampling_method->resampling(key, _noise, _param_size);
CHECK(success) << "[DeepES] resampling error occurs at sample: " << i;
float* delta = _param_delta[model_iter_id];
// compute neg_gradients
if (model_iter_id == current_model_iter_id) {
for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] += _noise[j] * reward;
}
} else {
for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] += (_noise[j] + delta[j]) * reward;
}
}
}
for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] /= -1.0 * noisy_info.size();
}
//update
int64_t counter = 0;
for (std::string param_name: _param_names) {
std::unique_ptr<Tensor> tensor = _predictor->GetMutableTensor(param_name);
float* tensor_data = tensor->mutable_data<float>();
int64_t tensor_size = ShapeProduction(tensor->shape());
_optimizer->update(tensor_data, _neg_gradients + counter, tensor_size, param_name);
counter += tensor_size;
}
success = _save();
CHECK(success) << "[DeepES] fail to save model.";
return true;
}
int AsyncESAgent::_parse_model_iter_id(const std::string& model_path) {
int model_iter_id = -1;
int pow = 1;
for (int i = model_path.size() - 1; i >= 0; --i) {
if (model_path[i] >= '0' && model_path[i] <= '9') {
if (model_iter_id == -1) model_iter_id = 0;
} else {
break;
}
model_iter_id += pow * (model_path[i] - '0');
pow *= 10;
}
return model_iter_id;
}
}//namespace
......@@ -13,14 +13,11 @@
// limitations under the License.
#include "es_agent.h"
#include <ctime>
namespace DeepES {
typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
typedef paddle::lite_api::Tensor Tensor;
typedef paddle::lite_api::shape_t shape_t;
inline int64_t ShapeProduction(const shape_t& shape) {
int64_t ShapeProduction(const shape_t& shape) {
int64_t res = 1;
for (auto i : shape) res *= i;
return res;
......@@ -71,6 +68,7 @@ std::shared_ptr<ESAgent> ESAgent::clone() {
new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method;
new_agent->_param_names = _param_names;
new_agent->_config = _config;
new_agent->_param_size = _param_size;
new_agent->_noise = noise;
......@@ -78,7 +76,7 @@ std::shared_ptr<ESAgent> ESAgent::clone() {
}
bool ESAgent::update(
std::vector<SamplingKey>& noisy_keys,
std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards) {
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
......@@ -88,8 +86,8 @@ bool ESAgent::update(
compute_centered_ranks(noisy_rewards);
memset(_neg_gradients, 0, _param_size * sizeof(float));
for (int i = 0; i < noisy_keys.size(); ++i) {
int key = noisy_keys[i].key(0);
for (int i = 0; i < noisy_info.size(); ++i) {
int key = noisy_info[i].key(0);
float reward = noisy_rewards[i];
bool success = _sampling_method->resampling(key, _noise, _param_size);
for (int64_t j = 0; j < _param_size; ++j) {
......@@ -97,7 +95,7 @@ bool ESAgent::update(
}
}
for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] /= -1.0 * noisy_keys.size();
_neg_gradients[j] /= -1.0 * noisy_info.size();
}
//update
......@@ -111,17 +109,18 @@ bool ESAgent::update(
counter += tensor_size;
}
return true;
}
bool ESAgent::add_noise(SamplingKey& sampling_key) {
bool ESAgent::add_noise(SamplingInfo& sampling_info) {
if (!_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
return false;
}
int key = _sampling_method->sampling(_noise, _param_size);
sampling_key.add_key(key);
int model_iter_id = _config->async_es().model_iter_id();
sampling_info.add_key(key);
sampling_info.set_model_iter_id(model_iter_id);
int64_t counter = 0;
for (std::string param_name: _param_names) {
......@@ -137,7 +136,6 @@ bool ESAgent::add_noise(SamplingKey& sampling_key) {
return true;
}
std::shared_ptr<PaddlePredictor> ESAgent::get_predictor() {
return _sampling_predictor;
}
......@@ -151,6 +149,4 @@ int64_t ESAgent::_calculate_param_size() {
return param_size;
}
}
}//namespace
......@@ -23,6 +23,8 @@ message DeepESConfig {
optional GaussianSamplingConfig gaussian_sampling = 3;
// Optimizer Configuration
optional OptimizerConfig optimizer = 4;
// AsyncESAgent Configuration
optional AsyncESConfig async_es = 5;
}
message GaussianSamplingConfig {
......@@ -40,6 +42,14 @@ message OptimizerConfig{
optional float epsilon = 6 [default = 1e-8];
}
message SamplingKey{
message SamplingInfo{
repeated int32 key = 1;
optional int32 model_iter_id = 2;
}
message AsyncESConfig{
optional string model_warehouse = 1 [default = "./model_warehouse"];
repeated string model_md5 = 2;
optional int32 max_to_keep = 3 [default = 5];
optional int32 model_iter_id = 4 [default = 0];
}
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "utils.h"
#include <dirent.h>
namespace DeepES {
......@@ -34,4 +35,21 @@ bool compute_centered_ranks(std::vector<float> &reward) {
return true;
}
std::vector<std::string> list_all_model_dirs(std::string path) {
std::vector<std::string> model_dirs;
DIR *dpdf;
struct dirent *epdf;
dpdf = opendir(path.data());
if (dpdf != NULL){
while (epdf = readdir(dpdf)){
std::string dir(epdf->d_name);
if (dir.find("model_iter_id") != std::string::npos) {
model_dirs.push_back(path + "/" + dir);
}
}
}
closedir(dpdf);
return model_dirs;
}
}//namespace
......@@ -89,7 +89,7 @@ protected:
sampling_agents.push_back(agent->clone());
}
std::vector<SamplingKey> noisy_keys;
std::vector<SamplingInfo> noisy_keys;
std::vector<float> noisy_rewards(iter, 0.0f);
noisy_keys.resize(iter);
......@@ -98,7 +98,7 @@ protected:
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < iter; ++i) {
auto sampling_agent = sampling_agents[i];
SamplingKey key;
SamplingInfo key;
bool success = sampling_agent->add_noise(key);
float reward = evaluate(x_list, y_list, train_data_size, sampling_agent);
noisy_keys[i] = key;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册