提交 752974cb 编写于 作者: Z zhoubo01

rename AsyncAgent to AsyncESAgent

上级 7793b479
......@@ -45,6 +45,7 @@ if (WITH_PADDLE)
file(GLOB framework_src "src/paddle/*.cc")
set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_solver_parallel.cc")
#set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_async_solver.cc")
########## Torch config ##########
elseif (WITH_TORCH)
list(APPEND CMAKE_PREFIX_PATH "./libtorch")
......
......@@ -58,7 +58,7 @@ int arg_max(const std::vector<float>& vec) {
}
float evaluate(CartPole& env, std::shared_ptr<AsyncAgent> agent) {
float evaluate(CartPole& env, std::shared_ptr<AsyncESAgent> agent) {
float total_reward = 0.0;
env.reset();
const float* obs = env.getState();
......@@ -87,10 +87,10 @@ int main(int argc, char* argv[]) {
}
std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
std::shared_ptr<AsyncAgent> agent = std::make_shared<AsyncAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
std::shared_ptr<AsyncESAgent> agent = std::make_shared<AsyncESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
// Clone agents to sample (explore).
std::vector< std::shared_ptr<AsyncAgent> > sampling_agents;
std::vector< std::shared_ptr<AsyncESAgent> > sampling_agents;
for (int i = 0; i < ITER; ++i) {
sampling_agents.push_back(agent->clone());
}
......@@ -113,7 +113,7 @@ int main(int argc, char* argv[]) {
}
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
std::shared_ptr<AsyncAgent> sampling_agent = sampling_agents[i];
std::shared_ptr<AsyncESAgent> sampling_agent = sampling_agents[i];
SamplingInfo info;
bool success = sampling_agent->add_noise(info);
float reward = evaluate(envs[i], sampling_agent);
......
......@@ -26,27 +26,27 @@ namespace DeepES{
* 2. add_noise: add noise into parameters.
* 3. update: update parameters given data collected during evaluation.
*/
class AsyncAgent: public ESAgent {
class AsyncESAgent: public ESAgent {
public:
AsyncAgent() {}
AsyncESAgent() {}
~AsyncAgent();
~AsyncESAgent();
/**
* @args:
* predictor: predictor created by users for prediction.
* config_path: the path of configuration file.
* Note that AsyncAgent will update the configuration file after calling the update function.
* Note that AsyncESAgent will update the configuration file after calling the update function.
* Please use the up-to-date configuration.
*/
AsyncAgent(
AsyncESAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path);
/**
* @brief: Clone an agent for sampling.
*/
std::shared_ptr<AsyncAgent> clone();
std::shared_ptr<AsyncESAgent> clone();
/**
* @brief: Clone an agent for sampling.
......
......@@ -47,7 +47,7 @@ rm -rf build
mkdir build
cd build
cmake ../ ${FLAGS}
make -j10
make -j10
#-----------------run----------------#
./parallel_main
......@@ -15,22 +15,22 @@
#include "async_es_agent.h"
namespace DeepES {
AsyncAgent::AsyncAgent(
// Construct an AsyncESAgent on top of the base ESAgent.
// @args:
//   predictor: predictor created by the user for prediction.
//   config_path: path of the prototxt configuration file. The path is
//     retained because _save() later rewrites this file (via
//     save_proto_conf) with the up-to-date model_iter_id.
AsyncESAgent::AsyncESAgent(
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path): ESAgent(predictor, config_path) {
// Remember where the config lives so it can be rewritten after updates.
_config_path = config_path;
}
AsyncAgent::~AsyncAgent() {
// Free the raw float buffers owned by _param_delta (presumably keyed by
// model iteration id — values are heap arrays allocated with new[]).
// NOTE(review): a std::vector<float>/std::unique_ptr<float[]> value type
// would make this destructor unnecessary (Rule of Zero).
AsyncESAgent::~AsyncESAgent() {
    // Iterate by const reference: avoids copying each map entry
    // (clang-tidy performance-for-range-copy).
    for (const auto& kv : _param_delta) {
        float* delta = kv.second;
        delete[] delta;
    }
}
bool AsyncAgent::_save() {
bool AsyncESAgent::_save() {
bool success = true;
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Original AsyncAgent cannot call add_noise function, please use cloned AsyncAgent.";
LOG(ERROR) << "[DeepES] Original AsyncESAgent cannot call add_noise function, please use cloned AsyncESAgent.";
success = false;
return success;
}
......@@ -55,7 +55,7 @@ bool AsyncAgent::_save() {
async_es->set_model_iter_id(model_iter_id);
success = save_proto_conf(_config_path, *_config);
if (!success) {
LOG(ERROR) << "[]unable to save config for AsyncAgent";
LOG(ERROR) << "[]unable to save config for AsyncESAgent";
success = false;
return success;
}
......@@ -64,7 +64,7 @@ bool AsyncAgent::_save() {
return success;
}
bool AsyncAgent::_remove_expired_model(int max_to_keep) {
bool AsyncESAgent::_remove_expired_model(int max_to_keep) {
bool success = true;
std::string model_path = _config->async_es().model_warehouse();
std::vector<std::string> model_dirs = list_all_model_dirs(model_path);
......@@ -86,7 +86,7 @@ bool AsyncAgent::_remove_expired_model(int max_to_keep) {
return success;
}
bool AsyncAgent::_compute_model_diff() {
bool AsyncESAgent::_compute_model_diff() {
bool success = true;
for (const auto& kv: _previous_predictors) {
int model_iter_id = kv.first;
......@@ -108,7 +108,7 @@ bool AsyncAgent::_compute_model_diff() {
return success;
}
bool AsyncAgent::_load() {
bool AsyncESAgent::_load() {
bool success = true;
std::string model_path = _config->async_es().model_warehouse();
std::vector<std::string> model_dirs = list_all_model_dirs(model_path);
......@@ -140,7 +140,7 @@ bool AsyncAgent::_load() {
return success;
}
std::shared_ptr<PaddlePredictor> AsyncAgent::_load_previous_model(std::string model_dir) {
std::shared_ptr<PaddlePredictor> AsyncESAgent::_load_previous_model(std::string model_dir) {
// 1. Create CxxConfig
CxxConfig config;
config.set_model_file(model_dir + "/model");
......@@ -155,10 +155,10 @@ std::shared_ptr<PaddlePredictor> AsyncAgent::_load_previous_model(std::string mo
return predictor;
}
std::shared_ptr<AsyncAgent> AsyncAgent::clone() {
std::shared_ptr<AsyncESAgent> AsyncESAgent::clone() {
std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
std::shared_ptr<AsyncAgent> new_agent = std::make_shared<AsyncAgent>();
std::shared_ptr<AsyncESAgent> new_agent = std::make_shared<AsyncESAgent>();
float* noise = new float [_param_size];
......@@ -175,7 +175,7 @@ std::shared_ptr<AsyncAgent> AsyncAgent::clone() {
return new_agent;
}
bool AsyncAgent::update(
bool AsyncESAgent::update(
std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards) {
......@@ -237,7 +237,7 @@ bool AsyncAgent::update(
return true;
}
int AsyncAgent::_parse_model_iter_id(const std::string& model_path) {
int AsyncESAgent::_parse_model_iter_id(const std::string& model_path) {
int model_iter_id = -1;
int pow = 1;
for (int i = model_path.size() - 1; i >= 0; --i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册