From 65554cd654206c554d1dc9a4b30d6a776e3c769c Mon Sep 17 00:00:00 2001 From: Bo Zhou <2466956298@qq.com> Date: Thu, 2 Apr 2020 16:43:18 +0800 Subject: [PATCH] remove depedence on predictor.clone() (#240) * remove depedence on predictor.clone() * remove commented lines * remove FATAL logging * remove ESAgent's construnction function based on cxx_config * mv prototxt --- .../cartpole_config.prototxt | 0 deepes/demo/paddle/cartpole_async_solver.cc | 18 ++-------- deepes/demo/paddle/cartpole_init_model.zip | Bin 3487 -> 2380 bytes .../demo/paddle/cartpole_solver_parallel.cc | 18 ++-------- deepes/demo/paddle/gen_cartpole_init_model.py | 2 ++ deepes/demo/torch/cartpole_solver_parallel.cc | 3 +- deepes/include/paddle/async_es_agent.h | 4 +-- deepes/include/paddle/es_agent.h | 11 +++--- deepes/include/utils.h | 3 ++ deepes/src/paddle/async_es_agent.cc | 8 ++--- deepes/src/paddle/es_agent.cc | 33 ++++++++++++------ deepes/src/utils.cc | 13 +++++++ 12 files changed, 56 insertions(+), 57 deletions(-) rename deepes/{benchmark => demo}/cartpole_config.prototxt (100%) diff --git a/deepes/benchmark/cartpole_config.prototxt b/deepes/demo/cartpole_config.prototxt similarity index 100% rename from deepes/benchmark/cartpole_config.prototxt rename to deepes/demo/cartpole_config.prototxt diff --git a/deepes/demo/paddle/cartpole_async_solver.cc b/deepes/demo/paddle/cartpole_async_solver.cc index 5cbe48e..9b3d0d3 100644 --- a/deepes/demo/paddle/cartpole_async_solver.cc +++ b/deepes/demo/paddle/cartpole_async_solver.cc @@ -24,20 +24,6 @@ using namespace paddle::lite_api; const int ITER = 10; -std::shared_ptr create_paddle_predictor(const std::string& model_dir) { - // 1. Create CxxConfig - CxxConfig config; - config.set_model_dir(model_dir); - config.set_valid_places({ - Place{TARGET(kX86), PRECISION(kFloat)}, - Place{TARGET(kHost), PRECISION(kFloat)} - }); - - // 2. Create PaddlePredictor by CxxConfig - std::shared_ptr predictor = CreatePaddlePredictor(config); - return predictor; -} - // Use PaddlePredictor of CartPole model to predict the action. std::vector forward(std::shared_ptr predictor, const float* obs) { std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); @@ -86,8 +72,8 @@ int main(int argc, char* argv[]) { envs.push_back(CartPole()); } - std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); - std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + std::shared_ptr agent = std::make_shared("../demo/paddle/cartpole_init_model", + "../demo/cartpole_config.prototxt"); // Clone agents to sample (explore). std::vector< std::shared_ptr > sampling_agents; diff --git a/deepes/demo/paddle/cartpole_init_model.zip b/deepes/demo/paddle/cartpole_init_model.zip index 04d21fb870a13f149f9ed6d05a4618fa4cefcd4a..16a7720959786471f8f500e7aa031615d53a1928 100644 GIT binary patch literal 2380 zcma);c{J4h9>;%-v0f&OeQPipO4%t)F_Pz6hRN8Op)8YiY$+*3mLlR&CWYv+WMb?g zMW~s{I+5&5iOHlw581AIy7%1YIgfkp_j^8{^ZEYqeVy<3ocH;B-nNz?E`H$qI`Lba z?T_MTf&#L@DbLXGkl+As93hAhjtdMv?HzCkYtILO!08Dd+&d8wB>-@6&4U2oN3z|a zL3i%+{-E8uej-7H69DpdLTVRUh-av0;D3x2zduc^MIj#dwrhUX>}Jyu9kCF{0_Cvpq6 zk?V!wbrudi@lCJppM)5_hZrfOuZw<0!=1Gb|DI3Xs#D6L8uxa^WR3wV9|YeH-|c{H z`>TaV9>bip@p9eYnBh3+)WAc`y}i;~8RsC{*$|?jf-|$gy)jYhJObP5QLz|@+fNL# zM(bvjx}XXO^dpJf9=MU-M8_5X$@#AoHBUs$We@fDMd2(Vw^p)~x7A(@gR|u7^SCQT z?PFaTm40+=A+qHLfvH(-bqPGVlB2qSN_OR^7M=$>XCgjrnJgp6O#1sR5eyA2)pgs9 zXO|%r5<1gd0zF@JdVCUUq(Sf%W2oC`+_#wnIrV9t(?Nn*NK0Htkgwmx^wb>1F=?lE zw>IxzkNSwMGUyij$NK_PF=YPmyXO%~&U@!sCNAL@&7VSjeb@&CE48z>y}|3v3yVC* zw-JYKrnDA|?5Vak7?4GkyBnR~FfI|oHsK|yN8sm}} zaz~0`jSlhILXvFQN$>MgR%=mWye@MmS;FpqR^>vGVXUhI@*BXPMF#bBUXla@K#{`k zBKslS|0OcEb8j%y3~Dp7kr&tiBNtRBR2?WR#wIh1$MyENr^*|U@xf22Bm-twIh(RT z91Xzor6}nu3(00Q0uoz*c$*ZF*(9=0xg%{GXuC~pBIebyQUiIhp=F!1Yu&$3v@G7X z5)R*eXbqvjOwH0sAptgL-9*xx{j{GxH*0Xq zT5Uo_cQ}ARqg=DCGui~+&s<{;Psy>NQ#rzWi!{|W>+M|PDc#emsj^!@I zdNND1jaOCF>5EQ!nPPQlxqHZuj4DeR5)DxY8lTwxQjNL`5MlG zNnO09d2~+p4t`ssR6)c6n*eT&zS(rzUqJ3cWSE7z^yS7Ulj6LGfG-$M+FtIZ3NPE{ zP+WZv$LlPq=>h3z5}?P5o{?5&#c&)tdC129qy7Oyx|M<28|mJj`#J$k|%2N0LA-{A~Xor zFO9^;84FxWtt3sQuUF`)kKZ$&_39@vx+N z2f=98(@igTUGdGwHk`xgo{Kj8CnR)j`AQg{{p&RNz!`S3l@+SE;F|vRP~4CKg}OrR z4@izW=&6XGWD!D7hyyI)*d~H~VXuEE-dIKMW5LqCTeF)V&M-6& zpn#dMv5h_EpBJL((0T6loi0Ob7bQI+9)_-nowIcY7ybiT_pq0;Zw>r>Hue;**GPg~ zCxiaYe;VJi>?<~5QI%=>jMXQDG|V^Y&+@9Ok&plln?+1zL|n@~?v5YQAb(aaA4@y> z+I-m1&CE&$15ylwJxY_tm?F63LedxRe#i;EZt&g_y&~`9q+hm(=Ps7$YD@Ucynwk> zo4MsjjlnfjW45n%wjPC$bZ)L+{3M$<*kxe&*6v14UKwgYC>M?xD&u-}S_nygncbET zJ#paGjS@ip9Dkq3E7hrzPF1|TWP<^CgQq0c^&w(=*6LFA1m8gUI18&d?_|cNDA9t~ihFeEmB){S%^#1?7K6~-p~l$* zl_eMjxwqoYRlbSn;& zD!cObpnefOvP$$hE7%?FIH?N|{i_`+sV#+uG&uB#q2o@&d@wX5@Fe(FUnZ`44r}v; zUdipKNIUU553}s6)x)4hS8ogK_l>X|3L3l06isYlmBFJt`=@@`(?-(&!&jkDQ>{L?x7fA~L7 iiT}pe?3k0kd-%_pu^S%D{r$(Foh``$0P?)wPyYhGEe7KN literal 3487 zcma)ZYt0}b9>$ubl9V-DWM8gjMp;6Vh zf;Y@E5&pKW&)j=1o7_@T`M|gLLf}YvLqgFcj1D03R`=rm2YpVTg-Uz0T8TsDq(C@V-2Vc~?Yj=Zj(%`nskafs1 zruy&--?##4LDyF^=un(##j7I`3rhS8UFiq)e;HDn5-*ryGcnhWGF>#`Xw6*WxL5i3 zj-#XIQAb5eLFH?ELV7l8^TWcZJwY4S9a~`f%)H=d+@aing4ZoPBK#$kThrn3)r7kH zhN!s_y}8iOjlvFHVH^hq8$ENZ5yqROgtR7DZ8TkvCG5qr43f1N$KUSnMoDKXi1Ex? zm>=I{oK-GVuVI=>*2Sj|@C83kP&rY9oPC#QL(*7ejanJd zBb()%!Qz=iyj~_5KPq5L&dB_9DX6Ls)Y%6zuOR3phZy`chhnQ1c-n|Q0JiSI{n@f~la-o8steUC~ z?&my}j%*?EGj(1S0ePz76Fq^z@2w2W^=DEARb@$R;S6jy{Q#0)UxdYCgwNWdy^x^3 zV-EV|@u4eXj4Xyz^PGoOtJj!H$Wg*SU2n3mgZVQ9!y+m;8*o3@TJ*hizym~O*s5YP zzPP418~4IKT@a-VqKurJeW;)^XeRfDQ*L;3lXo18ZV4nMP&*He zj|Rk1vG3xkgUBy53Di4=GUWk9n+3VN_U93+o#%DErW4TaLy!N-`frJ%aE2wN*Uz%}A5TJD*41P@s?dLNL$uC} z$`gssLwEj1J8St;--=+5RIItPc9E$>on1RM_gL#0i+L4d>?7J%g(M`WR-+tWa`n|` zeVw}b!bVE``!IC!{^NLkIy)cCz zXpMG6UNFv2tRIspNr-lT^hsJ)Ywa_rh(gZ3iszomc2s@J_YvQR^`qzr3v+dPzSDmk zYoC%!bmx=#$Qz&QX_dKrkrze`J4>qq-@NQmm#s5OAsqi5kHOrd>%!?&Msxo9Ag#zG z!#l@pzi*lobHpVc1wCWcB%Slu`F z&_KVyAZt;=4d^aKf@*1!uRo65)KsV~$aT}dB`165jOyx5{q(w0q>29Rxq*hE( zK1Yr2nK9o{Pa2L~j}hd7?&3bog!qeNFg0#xr+u17B_7+)o9T8q7t@v(GS-uE)*?x* z2Z`3Q!iv0Hh}0uV!G{{oqQ}BK<={QSpM4*yg@*^3T;QeBbqllSS-X7Xeio}M>AS%p z=OPx^c@kxoB_+M$)1@}u!;f8=8NkAJm)qqD#o)(V^}O!KH61B~`L3q_gHzb9=^i^Z z9YI*I6WWA|8n9$Vg-FUNSM!GD4uf-O9w4_xO>+LGE%mdt_tB;SgxY;X90Po-C?z zbIcM30NeeCF6#eed~1oj8vRzpNXJgg&c&te0J66{Qwtu*ZXmVlK}P4>z;(n6Ff}6_ zsuElR`)^l6z%5aH;OP?3Zleh_Zn6QVhufiSgH+X7<{k)fx($?54uyaV>vc+632LX= z3!z&Nt3Xsr351pU2&r{EfVfK+AkB&eFvKbgBC$_{vVpx&KM!p~l42?dh?`F>U2 z%TfPV833xnPVb@o&0qgJ!freLJ)v?B!guES-ywEiA>ShqjQ^aTt&jNYLiu-u-N*8K uf-KVy2s^(aKO%hF;@yYxdqVQw1lg~b;$8&0?Gp*mZ#_CR0HDph9sLPmFt0EG diff --git a/deepes/demo/paddle/cartpole_solver_parallel.cc b/deepes/demo/paddle/cartpole_solver_parallel.cc index 9fccb1a..952f92a 100644 --- a/deepes/demo/paddle/cartpole_solver_parallel.cc +++ b/deepes/demo/paddle/cartpole_solver_parallel.cc @@ -24,20 +24,6 @@ using namespace paddle::lite_api; const int ITER = 10; -std::shared_ptr create_paddle_predictor(const std::string& model_dir) { - // 1. Create CxxConfig - CxxConfig config; - config.set_model_dir(model_dir); - config.set_valid_places({ - Place{TARGET(kX86), PRECISION(kFloat)}, - Place{TARGET(kHost), PRECISION(kFloat)} - }); - - // 2. Create PaddlePredictor by CxxConfig - std::shared_ptr predictor = CreatePaddlePredictor(config); - return predictor; -} - // Use PaddlePredictor of CartPole model to predict the action. std::vector forward(std::shared_ptr predictor, const float* obs) { std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); @@ -86,8 +72,8 @@ int main(int argc, char* argv[]) { envs.push_back(CartPole()); } - std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); - std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + std::shared_ptr agent = std::make_shared("../demo/paddle/cartpole_init_model", + "../demo/cartpole_config.prototxt"); // Clone agents to sample (explore). std::vector< std::shared_ptr > sampling_agents; diff --git a/deepes/demo/paddle/gen_cartpole_init_model.py b/deepes/demo/paddle/gen_cartpole_init_model.py index 66b841a..9295224 100644 --- a/deepes/demo/paddle/gen_cartpole_init_model.py +++ b/deepes/demo/paddle/gen_cartpole_init_model.py @@ -36,4 +36,6 @@ if __name__ == '__main__': dirname='cartpole_init_model', feeded_var_names=['obs'], target_vars=[prob], + params_filename='param', + model_filename='model', executor=exe) diff --git a/deepes/demo/torch/cartpole_solver_parallel.cc b/deepes/demo/torch/cartpole_solver_parallel.cc index 98be214..f7b071d 100644 --- a/deepes/demo/torch/cartpole_solver_parallel.cc +++ b/deepes/demo/torch/cartpole_solver_parallel.cc @@ -51,7 +51,8 @@ int main(int argc, char* argv[]) { } auto model = std::make_shared(4, 2); - std::shared_ptr> agent = std::make_shared>(model, "../benchmark/cartpole_config.prototxt"); + std::shared_ptr> agent = std::make_shared>(model, + "../demo/cartpole_config.prototxt"); // Clone agents to sample (explore). std::vector>> sampling_agents; diff --git a/deepes/include/paddle/async_es_agent.h b/deepes/include/paddle/async_es_agent.h index 11b8dff..9a1f61d 100644 --- a/deepes/include/paddle/async_es_agent.h +++ b/deepes/include/paddle/async_es_agent.h @@ -40,8 +40,8 @@ class AsyncESAgent: public ESAgent { * Please use the up-to-date configuration. */ AsyncESAgent( - std::shared_ptr predictor, - std::string config_path); + const std::string& model_dir, + const std::string& config_path); /** * @brief: Clone an agent for sampling. diff --git a/deepes/include/paddle/es_agent.h b/deepes/include/paddle/es_agent.h index 25c9d98..ffe27fb 100644 --- a/deepes/include/paddle/es_agent.h +++ b/deepes/include/paddle/es_agent.h @@ -38,13 +38,11 @@ int64_t ShapeProduction(const shape_t& shape); */ class ESAgent { public: - ESAgent(); + ESAgent() {} ~ESAgent(); - ESAgent( - std::shared_ptr predictor, - std::string config_path); + ESAgent(const std::string& model_dir, const std::string& config_path); /** * @breif Clone a sampling agent @@ -83,15 +81,16 @@ class ESAgent { std::shared_ptr _predictor; std::shared_ptr _sampling_predictor; - bool _is_sampling_agent; std::shared_ptr _sampling_method; std::shared_ptr _optimizer; std::shared_ptr _config; - int64_t _param_size; + std::shared_ptr _cxx_config; std::vector _param_names; // malloc memory of noise and neg_gradients in advance. float* _noise; float* _neg_gradients; + int64_t _param_size; + bool _is_sampling_agent; }; } diff --git a/deepes/include/utils.h b/deepes/include/utils.h index 5835a43..76ba45b 100644 --- a/deepes/include/utils.h +++ b/deepes/include/utils.h @@ -20,6 +20,7 @@ #include #include "deepes.pb.h" #include +#include namespace DeepES{ @@ -29,6 +30,8 @@ namespace DeepES{ */ bool compute_centered_ranks(std::vector &reward); +std::string read_file(const std::string& filename); + /* Load a protobuf-based configuration from the file. * Args: * config_file: file path. diff --git a/deepes/src/paddle/async_es_agent.cc b/deepes/src/paddle/async_es_agent.cc index f128ddc..134bcbf 100644 --- a/deepes/src/paddle/async_es_agent.cc +++ b/deepes/src/paddle/async_es_agent.cc @@ -16,8 +16,8 @@ namespace DeepES { AsyncESAgent::AsyncESAgent( - std::shared_ptr predictor, - std::string config_path): ESAgent(predictor, config_path) { + const std::string& model_dir, + const std::string& config_path): ESAgent(model_dir, config_path) { _config_path = config_path; } AsyncESAgent::~AsyncESAgent() { @@ -155,15 +155,13 @@ std::shared_ptr AsyncESAgent::_load_previous_model(std::string } std::shared_ptr AsyncESAgent::clone() { - std::shared_ptr new_sampling_predictor = _predictor->Clone(); std::shared_ptr new_agent = std::make_shared(); float* noise = new float [_param_size]; new_agent->_predictor = _predictor; - new_agent->_sampling_predictor = new_sampling_predictor; - + new_agent->_sampling_predictor = CreatePaddlePredictor(*_cxx_config); new_agent->_is_sampling_agent = true; new_agent->_sampling_method = _sampling_method; new_agent->_param_names = _param_names; diff --git a/deepes/src/paddle/es_agent.cc b/deepes/src/paddle/es_agent.cc index 6cd6f93..2593472 100644 --- a/deepes/src/paddle/es_agent.cc +++ b/deepes/src/paddle/es_agent.cc @@ -23,22 +23,31 @@ int64_t ShapeProduction(const shape_t& shape) { return res; } -ESAgent::ESAgent() {} - ESAgent::~ESAgent() { delete[] _noise; if (!_is_sampling_agent) delete[] _neg_gradients; } -ESAgent::ESAgent( - std::shared_ptr predictor, - std::string config_path) { +ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) { + // 1. Create CxxConfig + _cxx_config = std::make_shared(); + std::string model_path = model_dir + "/model"; + std::string param_path = model_dir + "/param"; + std::string model_buffer = read_file(model_path); + std::string param_buffer = read_file(param_path); + _cxx_config->set_model_buffer(model_buffer.c_str(), model_buffer.size(), + param_buffer.c_str(), param_buffer.size()); + _cxx_config->set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)} + }); + + _predictor = CreatePaddlePredictor(*_cxx_config); _is_sampling_agent = false; - _predictor = predictor; // Original agent can't be used to sample, so keep it same with _predictor for evaluating. - _sampling_predictor = predictor; + _sampling_predictor = _predictor; _config = std::make_shared(); load_proto_conf(config_path, *_config); @@ -56,15 +65,17 @@ ESAgent::ESAgent( } std::shared_ptr ESAgent::clone() { - std::shared_ptr new_sampling_predictor = _predictor->Clone(); - + if (_is_sampling_agent) { + LOG(ERROR) << "[DeepES] only original ESAgent can call `clone` function."; + return nullptr; + } std::shared_ptr new_agent = std::make_shared(); float* noise = new float [_param_size]; + new_agent->_sampling_predictor = CreatePaddlePredictor(*_cxx_config); new_agent->_predictor = _predictor; - new_agent->_sampling_predictor = new_sampling_predictor; - + new_agent->_cxx_config = _cxx_config; new_agent->_is_sampling_agent = true; new_agent->_sampling_method = _sampling_method; new_agent->_param_names = _param_names; diff --git a/deepes/src/utils.cc b/deepes/src/utils.cc index cd5b055..ff2b624 100644 --- a/deepes/src/utils.cc +++ b/deepes/src/utils.cc @@ -52,4 +52,17 @@ std::vector list_all_model_dirs(std::string path) { return model_dirs; } +std::string read_file(const std::string& filename) { + std::ifstream ifile(filename.c_str()); + if (!ifile.is_open()) { + LOG(ERROR) << "Open file: [" << filename << "] failed."; + return ""; + } + std::ostringstream buf; + char ch; + while (buf && ifile.get(ch)) buf.put(ch); + ifile.close(); + return buf.str(); +} + }//namespace -- GitLab