Commit 752974cb authored by zhoubo01

rename AsyncAgent to AsyncESAgent

Parent 7793b479
@@ -45,6 +45,7 @@ if (WITH_PADDLE)
     file(GLOB framework_src "src/paddle/*.cc")
     set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_solver_parallel.cc")
+    #set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_async_solver.cc")
 ########## Torch config ##########
 elseif (WITH_TORCH)
     list(APPEND CMAKE_PREFIX_PATH "./libtorch")
...
@@ -58,7 +58,7 @@ int arg_max(const std::vector<float>& vec) {
 }
-float evaluate(CartPole& env, std::shared_ptr<AsyncAgent> agent) {
+float evaluate(CartPole& env, std::shared_ptr<AsyncESAgent> agent) {
     float total_reward = 0.0;
     env.reset();
     const float* obs = env.getState();
@@ -87,10 +87,10 @@ int main(int argc, char* argv[]) {
     }
     std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
-    std::shared_ptr<AsyncAgent> agent = std::make_shared<AsyncAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
+    std::shared_ptr<AsyncESAgent> agent = std::make_shared<AsyncESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
     // Clone agents to sample (explore).
-    std::vector< std::shared_ptr<AsyncAgent> > sampling_agents;
+    std::vector< std::shared_ptr<AsyncESAgent> > sampling_agents;
     for (int i = 0; i < ITER; ++i) {
         sampling_agents.push_back(agent->clone());
     }
@@ -113,7 +113,7 @@ int main(int argc, char* argv[]) {
     }
 #pragma omp parallel for schedule(dynamic, 1)
     for (int i = 0; i < ITER; ++i) {
-        std::shared_ptr<AsyncAgent> sampling_agent = sampling_agents[i];
+        std::shared_ptr<AsyncESAgent> sampling_agent = sampling_agents[i];
         SamplingInfo info;
         bool success = sampling_agent->add_noise(info);
         float reward = evaluate(envs[i], sampling_agent);
...
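The demo's sampling loop above is cut off before the results are aggregated. Below is a minimal sketch of how one ES iteration with this API typically fits together; only AsyncESAgent, SamplingInfo, add_noise, evaluate, and update appear in the diff, while the container names noisy_infos/noisy_rewards and the surrounding variables (ITER, envs, sampling_agents, agent) are assumed from the demo context rather than quoted from it.

// Hedged sketch, not the demo's literal code: OpenMP-parallel rollouts on the
// cloned agents, then a single update() on the original agent.
std::vector<SamplingInfo> noisy_infos(ITER);    // assumed container names
std::vector<float> noisy_rewards(ITER, 0.0f);

#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
    std::shared_ptr<AsyncESAgent> sampling_agent = sampling_agents[i];
    SamplingInfo info;
    bool ok = sampling_agent->add_noise(info);            // noise is added only on clones
    float reward = ok ? evaluate(envs[i], sampling_agent) : 0.0f;
    noisy_infos[i] = info;
    noisy_rewards[i] = reward;
}

// The original (non-cloned) agent consumes the sampled results.
bool updated = agent->update(noisy_infos, noisy_rewards);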
@@ -26,27 +26,27 @@ namespace DeepES{
  * 2. add_noise: add noise into parameters.
  * 3. update: update parameters given data collected during evaluation.
  */
-class AsyncAgent: public ESAgent {
+class AsyncESAgent: public ESAgent {
 public:
-    AsyncAgent() {}
-    ~AsyncAgent();
+    AsyncESAgent() {}
+    ~AsyncESAgent();
     /**
      * @args:
      *    predictor: predictor created by users for prediction.
      *    config_path: the path of configuration file.
-     * Note that AsyncAgent will update the configuration file after calling the update function.
+     * Note that AsyncESAgent will update the configuration file after calling the update function.
      * Please use the up-to-date configuration.
      */
-    AsyncAgent(
+    AsyncESAgent(
         std::shared_ptr<PaddlePredictor> predictor,
         std::string config_path);
     /**
      * @brief: Clone an agent for sampling.
      */
-    std::shared_ptr<AsyncAgent> clone();
+    std::shared_ptr<AsyncESAgent> clone();
     /**
      * @brief: Clone an agent for sampling.
...
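Read together with the demo above, the header comments describe the intended lifecycle: the original agent owns the predictor and the configuration file (which update() rewrites), and exploration noise is only ever added on clones. A minimal usage sketch follows; the create_paddle_predictor helper and the two paths come from the demo diff, everything else is an illustrative assumption rather than documented behaviour.

// Hedged usage sketch of the renamed class.
std::shared_ptr<PaddlePredictor> predictor =
    create_paddle_predictor("../demo/paddle/cartpole_init_model");

// Construct the original agent from the up-to-date config, since update()
// rewrites the configuration file (see the header note above).
std::shared_ptr<AsyncESAgent> agent =
    std::make_shared<AsyncESAgent>(predictor, "../benchmark/cartpole_config.prototxt");

// Sampling (exploration) happens on clones; the original agent only updates.
std::shared_ptr<AsyncESAgent> sampling_agent = agent->clone();
SamplingInfo info;
bool ok = sampling_agent->add_noise(info);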
@@ -15,22 +15,22 @@
 #include "async_es_agent.h"
 namespace DeepES {
-AsyncAgent::AsyncAgent(
+AsyncESAgent::AsyncESAgent(
         std::shared_ptr<PaddlePredictor> predictor,
         std::string config_path): ESAgent(predictor, config_path) {
     _config_path = config_path;
 }
-AsyncAgent::~AsyncAgent() {
+AsyncESAgent::~AsyncESAgent() {
     for(const auto kv: _param_delta) {
         float* delta = kv.second;
         delete[] delta;
     }
 }
-bool AsyncAgent::_save() {
+bool AsyncESAgent::_save() {
     bool success = true;
     if (_is_sampling_agent) {
-        LOG(ERROR) << "[DeepES] Original AsyncAgent cannot call add_noise function, please use cloned AsyncAgent.";
+        LOG(ERROR) << "[DeepES] Original AsyncESAgent cannot call add_noise function, please use cloned AsyncESAgent.";
         success = false;
         return success;
     }
@@ -55,7 +55,7 @@ bool AsyncAgent::_save() {
     async_es->set_model_iter_id(model_iter_id);
     success = save_proto_conf(_config_path, *_config);
     if (!success) {
-        LOG(ERROR) << "[]unable to save config for AsyncAgent";
+        LOG(ERROR) << "[]unable to save config for AsyncESAgent";
         success = false;
         return success;
     }
@@ -64,7 +64,7 @@ bool AsyncAgent::_save() {
     return success;
 }
-bool AsyncAgent::_remove_expired_model(int max_to_keep) {
+bool AsyncESAgent::_remove_expired_model(int max_to_keep) {
     bool success = true;
     std::string model_path = _config->async_es().model_warehouse();
     std::vector<std::string> model_dirs = list_all_model_dirs(model_path);
@@ -86,7 +86,7 @@ bool AsyncAgent::_remove_expired_model(int max_to_keep) {
     return success;
 }
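_remove_expired_model(max_to_keep) evidently prunes the model warehouse so that only the newest snapshots survive. Its body is collapsed in this view, so here is a hedged sketch of one way such a policy can be written in standard C++17; list_all_model_dirs, the snapshot directory naming, and the parse_iter_id helper are assumptions, not the DeepES implementation (a possible implementation of the id parsing itself is sketched after the last hunk below).

#include <algorithm>
#include <filesystem>
#include <string>
#include <utility>
#include <vector>

int parse_iter_id(const std::string& name);  // hypothetical helper, sketched later

// Hedged sketch: keep the `max_to_keep` newest snapshot directories
// (largest iteration id), delete the rest.
bool remove_expired_models(const std::string& warehouse, int max_to_keep) {
    namespace fs = std::filesystem;
    std::vector<std::pair<int, fs::path>> snapshots;
    for (const auto& entry : fs::directory_iterator(warehouse)) {
        if (!entry.is_directory()) continue;
        int id = parse_iter_id(entry.path().filename().string());
        if (id >= 0) snapshots.emplace_back(id, entry.path());
    }
    // Sort newest first by iteration id.
    std::sort(snapshots.begin(), snapshots.end(),
              [](const auto& a, const auto& b) { return a.first > b.first; });
    bool success = true;
    for (std::size_t i = static_cast<std::size_t>(max_to_keep); i < snapshots.size(); ++i) {
        std::error_code ec;
        fs::remove_all(snapshots[i].second, ec);   // best-effort delete of an expired snapshot
        if (ec) success = false;
    }
    return success;
}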
-bool AsyncAgent::_compute_model_diff() {
+bool AsyncESAgent::_compute_model_diff() {
     bool success = true;
     for (const auto& kv: _previous_predictors) {
         int model_iter_id = kv.first;
@@ -108,7 +108,7 @@ bool AsyncAgent::_compute_model_diff() {
     return success;
 }
-bool AsyncAgent::_load() {
+bool AsyncESAgent::_load() {
     bool success = true;
     std::string model_path = _config->async_es().model_warehouse();
     std::vector<std::string> model_dirs = list_all_model_dirs(model_path);
@@ -140,7 +140,7 @@ bool AsyncAgent::_load() {
     return success;
 }
-std::shared_ptr<PaddlePredictor> AsyncAgent::_load_previous_model(std::string model_dir) {
+std::shared_ptr<PaddlePredictor> AsyncESAgent::_load_previous_model(std::string model_dir) {
     // 1. Create CxxConfig
     CxxConfig config;
     config.set_model_file(model_dir + "/model");
@@ -155,10 +155,10 @@ std::shared_ptr<PaddlePredictor> AsyncAgent::_load_previous_model(std::string mo
     return predictor;
 }
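_load_previous_model rebuilds a predictor from a snapshot directory, starting from the CxxConfig shown above. A hedged sketch of the usual pattern is given below; the set_param_file call, the "/params" file name, and the CreatePaddlePredictor factory are assumptions drawn from common Paddle inference usage (with the relevant Paddle headers assumed to be available via async_es_agent.h), not details confirmed by this diff.

// Hedged sketch, mirroring the pattern begun in the hunk above.
std::shared_ptr<PaddlePredictor> load_previous_model(const std::string& model_dir) {
    CxxConfig config;
    config.set_model_file(model_dir + "/model");   // shown in the diff
    config.set_param_file(model_dir + "/params");  // assumption: separate parameter file
    // Assumption: the predictor factory used elsewhere in DeepES.
    std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
    return predictor;
}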
-std::shared_ptr<AsyncAgent> AsyncAgent::clone() {
+std::shared_ptr<AsyncESAgent> AsyncESAgent::clone() {
     std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
-    std::shared_ptr<AsyncAgent> new_agent = std::make_shared<AsyncAgent>();
+    std::shared_ptr<AsyncESAgent> new_agent = std::make_shared<AsyncESAgent>();
     float* noise = new float [_param_size];
@@ -175,7 +175,7 @@ std::shared_ptr<AsyncAgent> AsyncAgent::clone() {
     return new_agent;
 }
-bool AsyncAgent::update(
+bool AsyncESAgent::update(
     std::vector<SamplingInfo>& noisy_info,
     std::vector<float>& noisy_rewards) {
@@ -237,7 +237,7 @@ bool AsyncAgent::update(
     return true;
 }
-int AsyncAgent::_parse_model_iter_id(const std::string& model_path) {
+int AsyncESAgent::_parse_model_iter_id(const std::string& model_path) {
     int model_iter_id = -1;
     int pow = 1;
     for (int i = model_path.size() - 1; i >= 0; --i) {
...
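_parse_model_iter_id walks the path from its last character and accumulates trailing digits with a growing power of ten to recover the numeric iteration id. The rest of the function is cut off in this view, so the following is a hedged reconstruction of that idea in standard C++, not the file's actual ending.

// Hedged sketch: extract the trailing integer id from a snapshot path whose
// directory name ends in the iteration number; returns -1 if no trailing digits.
int parse_iter_id(const std::string& model_path) {
    int model_iter_id = -1;
    int pow = 1;
    for (int i = static_cast<int>(model_path.size()) - 1; i >= 0; --i) {
        char c = model_path[i];
        if (c < '0' || c > '9') {
            break;                               // stop at the first non-digit
        }
        if (model_iter_id < 0) {
            model_iter_id = 0;
        }
        model_iter_id += (c - '0') * pow;        // accumulate least-significant digit first
        pow *= 10;
    }
    return model_iter_id;
}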