diff --git a/deepes/benchmark/cartpole_config.prototxt b/deepes/demo/cartpole_config.prototxt similarity index 100% rename from deepes/benchmark/cartpole_config.prototxt rename to deepes/demo/cartpole_config.prototxt diff --git a/deepes/demo/paddle/cartpole_async_solver.cc b/deepes/demo/paddle/cartpole_async_solver.cc index 5cbe48eb0ddc8e1dc7458371ae8c0367e19e0198..9b3d0d31eaf72592e4f630ad70306b097c8e86db 100644 --- a/deepes/demo/paddle/cartpole_async_solver.cc +++ b/deepes/demo/paddle/cartpole_async_solver.cc @@ -24,20 +24,6 @@ using namespace paddle::lite_api; const int ITER = 10; -std::shared_ptr create_paddle_predictor(const std::string& model_dir) { - // 1. Create CxxConfig - CxxConfig config; - config.set_model_dir(model_dir); - config.set_valid_places({ - Place{TARGET(kX86), PRECISION(kFloat)}, - Place{TARGET(kHost), PRECISION(kFloat)} - }); - - // 2. Create PaddlePredictor by CxxConfig - std::shared_ptr predictor = CreatePaddlePredictor(config); - return predictor; -} - // Use PaddlePredictor of CartPole model to predict the action. std::vector forward(std::shared_ptr predictor, const float* obs) { std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); @@ -86,8 +72,8 @@ int main(int argc, char* argv[]) { envs.push_back(CartPole()); } - std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); - std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + std::shared_ptr agent = std::make_shared("../demo/paddle/cartpole_init_model", + "../demo/cartpole_config.prototxt"); // Clone agents to sample (explore). std::vector< std::shared_ptr > sampling_agents; diff --git a/deepes/demo/paddle/cartpole_init_model.zip b/deepes/demo/paddle/cartpole_init_model.zip index 04d21fb870a13f149f9ed6d05a4618fa4cefcd4a..16a7720959786471f8f500e7aa031615d53a1928 100644 Binary files a/deepes/demo/paddle/cartpole_init_model.zip and b/deepes/demo/paddle/cartpole_init_model.zip differ diff --git a/deepes/demo/paddle/cartpole_solver_parallel.cc b/deepes/demo/paddle/cartpole_solver_parallel.cc index 9fccb1a995774a9e98d50dfbf4e42470237c0fed..952f92a85df71c852c53fab56756fd26d87d8436 100644 --- a/deepes/demo/paddle/cartpole_solver_parallel.cc +++ b/deepes/demo/paddle/cartpole_solver_parallel.cc @@ -24,20 +24,6 @@ using namespace paddle::lite_api; const int ITER = 10; -std::shared_ptr create_paddle_predictor(const std::string& model_dir) { - // 1. Create CxxConfig - CxxConfig config; - config.set_model_dir(model_dir); - config.set_valid_places({ - Place{TARGET(kX86), PRECISION(kFloat)}, - Place{TARGET(kHost), PRECISION(kFloat)} - }); - - // 2. Create PaddlePredictor by CxxConfig - std::shared_ptr predictor = CreatePaddlePredictor(config); - return predictor; -} - // Use PaddlePredictor of CartPole model to predict the action. std::vector forward(std::shared_ptr predictor, const float* obs) { std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); @@ -86,8 +72,8 @@ int main(int argc, char* argv[]) { envs.push_back(CartPole()); } - std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); - std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + std::shared_ptr agent = std::make_shared("../demo/paddle/cartpole_init_model", + "../demo/cartpole_config.prototxt"); // Clone agents to sample (explore). std::vector< std::shared_ptr > sampling_agents; diff --git a/deepes/demo/paddle/gen_cartpole_init_model.py b/deepes/demo/paddle/gen_cartpole_init_model.py index 66b841aaf4ac428ca2232324a35fa66bd683c572..9295224953e74a9572915d3612bd4634f61de55e 100644 --- a/deepes/demo/paddle/gen_cartpole_init_model.py +++ b/deepes/demo/paddle/gen_cartpole_init_model.py @@ -36,4 +36,6 @@ if __name__ == '__main__': dirname='cartpole_init_model', feeded_var_names=['obs'], target_vars=[prob], + params_filename='param', + model_filename='model', executor=exe) diff --git a/deepes/demo/torch/cartpole_solver_parallel.cc b/deepes/demo/torch/cartpole_solver_parallel.cc index 98be214c1da5cb1921b5d83cb64ae98d46d50bdd..f7b071de307fd7f1f0253b3d0c75ef4ebd295ded 100644 --- a/deepes/demo/torch/cartpole_solver_parallel.cc +++ b/deepes/demo/torch/cartpole_solver_parallel.cc @@ -51,7 +51,8 @@ int main(int argc, char* argv[]) { } auto model = std::make_shared(4, 2); - std::shared_ptr> agent = std::make_shared>(model, "../benchmark/cartpole_config.prototxt"); + std::shared_ptr> agent = std::make_shared>(model, + "../demo/cartpole_config.prototxt"); // Clone agents to sample (explore). std::vector>> sampling_agents; diff --git a/deepes/include/paddle/async_es_agent.h b/deepes/include/paddle/async_es_agent.h index 11b8dff53bdafd23b1ce3524f42b97a688283a53..9a1f61d78304b804c91221889d6c0e210178b905 100644 --- a/deepes/include/paddle/async_es_agent.h +++ b/deepes/include/paddle/async_es_agent.h @@ -40,8 +40,8 @@ class AsyncESAgent: public ESAgent { * Please use the up-to-date configuration. */ AsyncESAgent( - std::shared_ptr predictor, - std::string config_path); + const std::string& model_dir, + const std::string& config_path); /** * @brief: Clone an agent for sampling. diff --git a/deepes/include/paddle/es_agent.h b/deepes/include/paddle/es_agent.h index 25c9d98e9b11776669692233cc1b9061cc8fe1eb..ffe27fbec568041aeab1956e368b1f07b4a0390e 100644 --- a/deepes/include/paddle/es_agent.h +++ b/deepes/include/paddle/es_agent.h @@ -38,13 +38,11 @@ int64_t ShapeProduction(const shape_t& shape); */ class ESAgent { public: - ESAgent(); + ESAgent() {} ~ESAgent(); - ESAgent( - std::shared_ptr predictor, - std::string config_path); + ESAgent(const std::string& model_dir, const std::string& config_path); /** * @breif Clone a sampling agent @@ -83,15 +81,16 @@ class ESAgent { std::shared_ptr _predictor; std::shared_ptr _sampling_predictor; - bool _is_sampling_agent; std::shared_ptr _sampling_method; std::shared_ptr _optimizer; std::shared_ptr _config; - int64_t _param_size; + std::shared_ptr _cxx_config; std::vector _param_names; // malloc memory of noise and neg_gradients in advance. float* _noise; float* _neg_gradients; + int64_t _param_size; + bool _is_sampling_agent; }; } diff --git a/deepes/include/utils.h b/deepes/include/utils.h index 5835a43defd6a4abfeae7a68a5671f3c3239dcfc..76ba45b23b4729170d3bdcb657cecf345fa9107f 100644 --- a/deepes/include/utils.h +++ b/deepes/include/utils.h @@ -20,6 +20,7 @@ #include #include "deepes.pb.h" #include +#include namespace DeepES{ @@ -29,6 +30,8 @@ namespace DeepES{ */ bool compute_centered_ranks(std::vector &reward); +std::string read_file(const std::string& filename); + /* Load a protobuf-based configuration from the file. * Args: * config_file: file path. diff --git a/deepes/src/paddle/async_es_agent.cc b/deepes/src/paddle/async_es_agent.cc index f128ddcde2556009f56f9a9aea829d3ee46ce7b3..134bcbf3cce104c6616f55c5bf0d32f3993bef8d 100644 --- a/deepes/src/paddle/async_es_agent.cc +++ b/deepes/src/paddle/async_es_agent.cc @@ -16,8 +16,8 @@ namespace DeepES { AsyncESAgent::AsyncESAgent( - std::shared_ptr predictor, - std::string config_path): ESAgent(predictor, config_path) { + const std::string& model_dir, + const std::string& config_path): ESAgent(model_dir, config_path) { _config_path = config_path; } AsyncESAgent::~AsyncESAgent() { @@ -155,15 +155,13 @@ std::shared_ptr AsyncESAgent::_load_previous_model(std::string } std::shared_ptr AsyncESAgent::clone() { - std::shared_ptr new_sampling_predictor = _predictor->Clone(); std::shared_ptr new_agent = std::make_shared(); float* noise = new float [_param_size]; new_agent->_predictor = _predictor; - new_agent->_sampling_predictor = new_sampling_predictor; - + new_agent->_sampling_predictor = CreatePaddlePredictor(*_cxx_config); new_agent->_is_sampling_agent = true; new_agent->_sampling_method = _sampling_method; new_agent->_param_names = _param_names; diff --git a/deepes/src/paddle/es_agent.cc b/deepes/src/paddle/es_agent.cc index 6cd6f93c896ee23ad15d0b947f823ea52011d806..2593472e751f5240daf449ecf333290330decadc 100644 --- a/deepes/src/paddle/es_agent.cc +++ b/deepes/src/paddle/es_agent.cc @@ -23,22 +23,31 @@ int64_t ShapeProduction(const shape_t& shape) { return res; } -ESAgent::ESAgent() {} - ESAgent::~ESAgent() { delete[] _noise; if (!_is_sampling_agent) delete[] _neg_gradients; } -ESAgent::ESAgent( - std::shared_ptr predictor, - std::string config_path) { +ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) { + // 1. Create CxxConfig + _cxx_config = std::make_shared(); + std::string model_path = model_dir + "/model"; + std::string param_path = model_dir + "/param"; + std::string model_buffer = read_file(model_path); + std::string param_buffer = read_file(param_path); + _cxx_config->set_model_buffer(model_buffer.c_str(), model_buffer.size(), + param_buffer.c_str(), param_buffer.size()); + _cxx_config->set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)} + }); + + _predictor = CreatePaddlePredictor(*_cxx_config); _is_sampling_agent = false; - _predictor = predictor; // Original agent can't be used to sample, so keep it same with _predictor for evaluating. - _sampling_predictor = predictor; + _sampling_predictor = _predictor; _config = std::make_shared(); load_proto_conf(config_path, *_config); @@ -56,15 +65,17 @@ ESAgent::ESAgent( } std::shared_ptr ESAgent::clone() { - std::shared_ptr new_sampling_predictor = _predictor->Clone(); - + if (_is_sampling_agent) { + LOG(ERROR) << "[DeepES] only original ESAgent can call `clone` function."; + return nullptr; + } std::shared_ptr new_agent = std::make_shared(); float* noise = new float [_param_size]; + new_agent->_sampling_predictor = CreatePaddlePredictor(*_cxx_config); new_agent->_predictor = _predictor; - new_agent->_sampling_predictor = new_sampling_predictor; - + new_agent->_cxx_config = _cxx_config; new_agent->_is_sampling_agent = true; new_agent->_sampling_method = _sampling_method; new_agent->_param_names = _param_names; diff --git a/deepes/src/utils.cc b/deepes/src/utils.cc index cd5b055405ceefc41d7f8be007b52e9e4ddd7221..ff2b624d971aae2459aa5c563c68a21a26b0f6d1 100644 --- a/deepes/src/utils.cc +++ b/deepes/src/utils.cc @@ -52,4 +52,17 @@ std::vector list_all_model_dirs(std::string path) { return model_dirs; } +std::string read_file(const std::string& filename) { + std::ifstream ifile(filename.c_str()); + if (!ifile.is_open()) { + LOG(ERROR) << "Open file: [" << filename << "] failed."; + return ""; + } + std::ostringstream buf; + char ch; + while (buf && ifile.get(ch)) buf.put(ch); + ifile.close(); + return buf.str(); +} + }//namespace