diff --git a/deepes/CMakeLists.txt b/deepes/CMakeLists.txt index f4c66fbaae4244edceb55f89dcd6d7d32bdec09f..a9b120579e474aa374610678c563297cfa14a2e4 100644 --- a/deepes/CMakeLists.txt +++ b/deepes/CMakeLists.txt @@ -1,5 +1,25 @@ cmake_minimum_required (VERSION 2.6) project (DeepES) +set(TARGET parallel_main) + +########## options ########## +option(WITH_PADDLE "Compile DeepES with PaddleLite framework." OFF) +option(WITH_TORCH "Compile DeepES with Torch framework." OFF) + +message("WITH_PADDLE: "${WITH_PADDLE}) +message("WITH_TORCH: "${WITH_TORCH}) + +if (NOT (WITH_PADDLE OR WITH_TORCH)) + message("ERROR: You should choose at least one framework to compile DeepES.") + return() +elseif(WITH_PADDLE AND WITH_TORCH) + message("ERROR: You cannot choose more than one framework to compile DeepES.") + return() +endif() + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) find_package(OpenMP) if (OPENMP_FOUND) @@ -8,19 +28,47 @@ if (OPENMP_FOUND) set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") endif() -set(CMAKE_CXX_STANDARD 11) -set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CXX_EXTENSIONS OFF) -find_package(Torch REQUIRED ON) - -file(GLOB demo_src "demo/*.cpp") -file(GLOB core_src "src/*.cpp") -file(GLOB pb_src "src/*.cc") +file(GLOB src "src/*.cc") include_directories("include") -include_directories("demo") include_directories("benchmark") -link_directories("/usr/lib/x86_64-linux-gnu/") -add_executable(parallel_main "./demo/cartpole_solver_parallel.cpp" ${core_src} ${pb_src} ${benchmark_src}) -target_link_libraries(parallel_main gflags protobuf pthread glog "${TORCH_LIBRARIES}") +########## PaddleLite config ########## +if (WITH_PADDLE) + add_definitions(-g -O3 -pthread) + + include_directories("include/paddle") + include_directories("${PROJECT_SOURCE_DIR}/inference_lite_lib/cxx/include" + "${PROJECT_SOURCE_DIR}/inference_lite_lib/third_party/mklml/include") + link_directories("${PROJECT_SOURCE_DIR}/inference_lite_lib/cxx/lib" + "${PROJECT_SOURCE_DIR}/inference_lite_lib/third_party/mklml/lib") + + file(GLOB framework_src "src/paddle/*.cc") + set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_solver_parallel.cc") +########## Torch config ########## +elseif (WITH_TORCH) + list(APPEND CMAKE_PREFIX_PATH "./libtorch") + find_package(Torch REQUIRED ON) + + include_directories("include/torch") + include_directories("demo/torch") + + file(GLOB framework_src "src/torch/*.cc") + set(demo "${PROJECT_SOURCE_DIR}/demo/torch/cartpole_solver_parallel.cc") +else () + message("ERROR: You should choose at least one framework to compile DeepES.") +endif() + +add_executable(${TARGET} ${demo} ${src} ${framework_src}) + +target_link_libraries(${TARGET} gflags protobuf pthread glog) + +########## PaddleLite libraries ########## +if (WITH_PADDLE) + target_link_libraries(${TARGET} -lpaddle_full_api_shared) + target_link_libraries(${TARGET} -lmklml_intel) + target_link_libraries(${TARGET} -ldl) +########## Torch libraries ########## +elseif (WITH_TORCH) + target_link_libraries(${TARGET} "${TORCH_LIBRARIES}") +endif() diff --git a/deepes/README.md b/deepes/README.md index ae14819304b173344dc62a8281025b91ef679d41..81b9c5d469c109d011ae6752c70abc2b2064aace 100644 --- a/deepes/README.md +++ b/deepes/README.md @@ -7,22 +7,21 @@ DeepES是一个支持**快速验证**ES效果、**兼容多个框架**的C++库 ## 使用示范 ```c++ //实例化一个预测,根据配置文件加载模型,采样方式(Gaussian\CMA sampling..)、更新方式(SGD\Adam)等 -auto predictor = Predicotr(config); +auto agent = ESAgent(config); for (int i = 0; i < 100; ++i) { - auto noisy_predictor = predictor->clone(); // copy 一份参数 - int key = noisy_predictor->add_noise(); // 参数扰动,同时保存随机种子 - int reward = evaluate(env, noisiy_predictor); //评估参数 + int key = agent->add_noise(); // 参数扰动,同时保存随机种子 + int reward = evaluate(env, agent); //评估参数 noisy_keys.push_back(key); // 记录随机噪声对应种子 noisy_rewards.push_back(reward); // 记录评估结果 } //根据评估结果、随机种子更新参数,然后重复以上过程,直到收敛。 -predictor->update(noisy_keys, noisy_rewards); +agent->update(noisy_keys, noisy_rewards); ``` ## 一键运行demo列表 -- **Torch**: sh [./scripts/build.sh](./scripts/build.sh) -- **Paddle**: +- **PaddleLite**: sh ./scripts/build.sh paddle +- **Torch**: sh ./scripts/build.sh torch - **裸写网络**: ## 相关依赖: @@ -32,5 +31,8 @@ predictor->update(noisy_keys, noisy_rewards); ## 额外依赖: +### 使用PaddleLite +下载PaddleLite的X86预编译库,或者编译PaddleLite源码,得到inference_lite_lib文件夹,放在当前目录中。(可参考:[PaddleLite使用X86预测部署](https://paddle-lite.readthedocs.io/zh/latest/demo_guides/x86.html)) + ### 使用torch 下载[libtorch](https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip)或者编译torch源码,得到libtorch文件夹,放在当前目录中。 diff --git a/deepes/benchmark/cartpole.h b/deepes/benchmark/cartpole.h index 6935f8ddb3a058444945c3dab08be088a0152454..48d2e08515ed3e105e6a700caf7dfa889fbce0b3 100644 --- a/deepes/benchmark/cartpole.h +++ b/deepes/benchmark/cartpole.h @@ -1,7 +1,9 @@ // Third party code // This code is copied or modified from openai/gym's cartpole.py - -#include +#include +#include +#include +#include const double kPi = 3.1415926535898; @@ -21,13 +23,13 @@ public: double x_threshold = 2.4; int steps_beyond_done = -1; - torch::Tensor state; + std::vector state = {0, 0, 0, 0}; double reward; bool done; int step_ = 0; - torch::Tensor getState() { - return state; + const float* getState() { + return state.data(); } double getReward() { @@ -39,7 +41,13 @@ public: } void reset() { - state = torch::empty({ 4 }).uniform_(-0.05, 0.05); + std::random_device rd; + std::default_random_engine generator(rd()); + std::uniform_real_distribution distribution(-0.05, 0.05); + for (int i = 0; i < 4; ++i) { + state[i] = distribution(generator); + } + steps_beyond_done = -1; step_ = 0; } @@ -49,10 +57,10 @@ public: } void step(int action) { - auto x = state[0].item(); - auto x_dot = state[1].item(); - auto theta = state[2].item(); - auto theta_dot = state[3].item(); + float x = state[0]; + float x_dot = state[1]; + float theta = state[2]; + float theta_dot = state[3]; auto force = (action == 1) ? force_mag : -force_mag; auto costheta = std::cos(theta); @@ -67,7 +75,8 @@ public: x_dot = x_dot + tau * xacc; theta = theta + tau * theta_dot; theta_dot = theta_dot + tau * thetaacc; - state = torch::tensor({ x, x_dot, theta, theta_dot }); + + state = {x, x_dot, theta, theta_dot}; done = x < -x_threshold || x > x_threshold || theta < -theta_threshold_radians || theta > theta_threshold_radians || @@ -83,7 +92,7 @@ public: } else { if (steps_beyond_done == 0) { - AT_ASSERT(false); // Can't do this + assert(false); // Can't do this } } step_++; diff --git a/deepes/deepes_config.prototxt b/deepes/benchmark/cartpole_config.prototxt similarity index 100% rename from deepes/deepes_config.prototxt rename to deepes/benchmark/cartpole_config.prototxt diff --git a/deepes/demo/paddle/cartpole_init_model/__model__ b/deepes/demo/paddle/cartpole_init_model/__model__ new file mode 100644 index 0000000000000000000000000000000000000000..d08e3f68ba790f0393417035d496c16ed3f8915e Binary files /dev/null and b/deepes/demo/paddle/cartpole_init_model/__model__ differ diff --git a/deepes/demo/paddle/cartpole_init_model/fc_0.b_0 b/deepes/demo/paddle/cartpole_init_model/fc_0.b_0 new file mode 100644 index 0000000000000000000000000000000000000000..9dbc513dffea31e50e053694f0b2bb0a1d855982 Binary files /dev/null and b/deepes/demo/paddle/cartpole_init_model/fc_0.b_0 differ diff --git a/deepes/demo/paddle/cartpole_init_model/fc_0.w_0 b/deepes/demo/paddle/cartpole_init_model/fc_0.w_0 new file mode 100644 index 0000000000000000000000000000000000000000..00221f8eca184598b9476f53bbd7bbc718c4d834 Binary files /dev/null and b/deepes/demo/paddle/cartpole_init_model/fc_0.w_0 differ diff --git a/deepes/demo/paddle/cartpole_init_model/fc_1.b_0 b/deepes/demo/paddle/cartpole_init_model/fc_1.b_0 new file mode 100644 index 0000000000000000000000000000000000000000..998b905201aed8d9b191a4f077e4bf976a0eb682 Binary files /dev/null and b/deepes/demo/paddle/cartpole_init_model/fc_1.b_0 differ diff --git a/deepes/demo/paddle/cartpole_init_model/fc_1.w_0 b/deepes/demo/paddle/cartpole_init_model/fc_1.w_0 new file mode 100644 index 0000000000000000000000000000000000000000..5c30b360b4ab5511cc9ef847522b07c93210b338 Binary files /dev/null and b/deepes/demo/paddle/cartpole_init_model/fc_1.w_0 differ diff --git a/deepes/demo/paddle/cartpole_solver_parallel.cc b/deepes/demo/paddle/cartpole_solver_parallel.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2049c054f1076a83a5ae724c9f1e9ed3be3c499 --- /dev/null +++ b/deepes/demo/paddle/cartpole_solver_parallel.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "cartpole.h" +#include "gaussian_sampling.h" +#include "es_agent.h" +#include "paddle_api.h" + +using namespace DeepES; +using namespace paddle::lite_api; + +const int ITER = 10; + +std::shared_ptr create_paddle_predictor(const std::string& model_dir) { + // 1. Create CxxConfig + CxxConfig config; + config.set_model_dir(model_dir); + config.set_valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)} + }); + + // 2. Create PaddlePredictor by CxxConfig + std::shared_ptr predictor = CreatePaddlePredictor(config); + return predictor; +} + +// Use PaddlePredictor of CartPole model to predict the action. +std::vector forward(std::shared_ptr predictor, const float* obs) { + std::unique_ptr input_tensor(std::move(predictor->GetInput(0))); + input_tensor->Resize({1, 4}); + input_tensor->CopyFromCpu(obs); + + predictor->Run(); + + std::vector probs(2, 0.0); + std::unique_ptr output_tensor( + std::move(predictor->GetOutput(0))); + output_tensor->CopyToCpu(probs.data()); + return probs; +} + +int arg_max(const std::vector& vec) { + return static_cast(std::distance(vec.begin(), std::max_element(vec.begin(), vec.end()))); +} + + +float evaluate(CartPole& env, std::shared_ptr agent, bool is_eval=false) { + float total_reward = 0.0; + env.reset(); + const float* obs = env.getState(); + + std::shared_ptr paddle_predictor; + if (is_eval) + paddle_predictor = agent->get_evaluate_predictor(); // For evaluate + else + paddle_predictor = agent->get_sample_predictor(); // For sampling (ES exploring) + + while (true) { + std::vector probs = forward(paddle_predictor, obs); + int act = arg_max(probs); + env.step(act); + float reward = env.getReward(); + bool done = env.isDone(); + total_reward += reward; + if (done) break; + obs = env.getState(); + } + return total_reward; +} + + +int main(int argc, char* argv[]) { + std::vector envs; + for (int i = 0; i < ITER; ++i) { + envs.push_back(CartPole()); + } + + std::shared_ptr paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model"); + std::shared_ptr agent = std::make_shared(paddle_predictor, "../benchmark/cartpole_config.prototxt"); + + std::vector< std::shared_ptr > sampling_agents{ agent }; + for (int i = 0; i < (ITER - 1); ++i) { + sampling_agents.push_back(agent->clone()); + } + + std::vector noisy_keys; + std::vector noisy_rewards(ITER, 0.0f); + noisy_keys.resize(ITER); + + omp_set_num_threads(10); + for (int epoch = 0; epoch < 10000; ++epoch) { +#pragma omp parallel for schedule(dynamic, 1) + for (int i = 0; i < ITER; ++i) { + std::shared_ptr sampling_agent = sampling_agents[i]; + SamplingKey key = sampling_agent->add_noise(); + float reward = evaluate(envs[i], sampling_agent); + + noisy_keys[i] = key; + noisy_rewards[i] = reward; + } + + // NOTE: all parameters of sampling_agents will be updated + agent->update(noisy_keys, noisy_rewards); + + int reward = evaluate(envs[0], agent, true); + LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward; + } +} diff --git a/deepes/demo/paddle/gen_cartpole_init_model.py b/deepes/demo/paddle/gen_cartpole_init_model.py new file mode 100644 index 0000000000000000000000000000000000000000..16d020021e639c360588d417738d3e2c923d9269 --- /dev/null +++ b/deepes/demo/paddle/gen_cartpole_init_model.py @@ -0,0 +1,36 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import fluid + +def net(obs, act_dim): + hid1_size = act_dim * 10 + hid1 = fluid.layers.fc(obs, size=hid1_size) + prob = fluid.layers.fc(hid1, size=act_dim, act='softmax') + return prob + +if __name__ == '__main__': + obs_dim = 4 + act_dim = 2 + + obs = fluid.layers.data(name="obs", shape=[obs_dim], dtype='float32') + + prob = net(obs, act_dim) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + fluid.io.save_inference_model(dirname='cartpole_init_model', + feeded_var_names=['obs'], + target_vars=[prob], + executor=exe) diff --git a/deepes/demo/cartpole_solver_parallel.cpp b/deepes/demo/torch/cartpole_solver_parallel.cc similarity index 67% rename from deepes/demo/cartpole_solver_parallel.cpp rename to deepes/demo/torch/cartpole_solver_parallel.cc index f6f39a6835bdf96f4691d92f71e9f64569e7a89a..4313fd6da2633feddcd86ce1b738bfe9ed8b2a5b 100644 --- a/deepes/demo/cartpole_solver_parallel.cpp +++ b/deepes/demo/torch/cartpole_solver_parallel.cc @@ -20,17 +20,18 @@ #include "cartpole.h" #include "gaussian_sampling.h" #include "model.h" -#include "torch_predictor.h" +#include "es_agent.h" using namespace DeepES; -const int ITER = 100; +const int ITER = 10; -float evaluate(CartPole& env, std::shared_ptr> predictor) { +float evaluate(CartPole& env, std::shared_ptr> agent, bool is_eval=false) { float total_reward = 0.0; env.reset(); - auto obs = env.getState(); + const float* obs = env.getState(); while (true) { - torch::Tensor action = predictor->predict(obs); + torch::Tensor obs_tensor = torch::tensor({obs[0], obs[1], obs[2], obs[3]}); + torch::Tensor action = agent->predict(obs_tensor, is_eval); int act = std::get<1>(action.max(-1)).item(); env.step(act); float reward = env.getReward(); @@ -50,10 +51,11 @@ int main(int argc, char* argv[]) { } auto model = std::make_shared(4, 2); - std::shared_ptr> predictor = std::make_shared>(model, "../deepes_config.prototxt"); - std::vector>> noisy_predictors; - for (int i = 0; i < ITER; ++i) { - noisy_predictors.push_back(predictor->clone()); + std::shared_ptr> agent = std::make_shared>(model, "../benchmark/cartpole_config.prototxt"); + + std::vector>> sampling_agents = {agent}; + for (int i = 0; i < ITER - 1; ++i) { + sampling_agents.push_back(agent->clone()); } std::vector noisy_keys; @@ -63,16 +65,16 @@ int main(int argc, char* argv[]) { for (int epoch = 0; epoch < 10000; ++epoch) { #pragma omp parallel for schedule(dynamic, 1) for (int i = 0; i < ITER; ++i) { - auto noisy_predictor = noisy_predictors[i]; - SamplingKey key = noisy_predictor->add_noise(); - float reward = evaluate(envs[i], noisy_predictor); + auto sampling_agent = sampling_agents[i]; + SamplingKey key = sampling_agent->add_noise(); + float reward = evaluate(envs[i], sampling_agent); noisy_keys[i] = key; noisy_rewards[i] = reward; } - predictor->update(noisy_keys, noisy_rewards); + agent->update(noisy_keys, noisy_rewards); - int reward = evaluate(envs[0], predictor); + int reward = evaluate(envs[0], agent, true); LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward; } } diff --git a/deepes/demo/model.h b/deepes/demo/torch/model.h similarity index 100% rename from deepes/demo/model.h rename to deepes/demo/torch/model.h diff --git a/deepes/include/gaussian_sampling.h b/deepes/include/gaussian_sampling.h index 59c753e279d8575c3dda85ca855a099f0eabe398..82c58e50a1078faec011dba94ef66079479ab289 100644 --- a/deepes/include/gaussian_sampling.h +++ b/deepes/include/gaussian_sampling.h @@ -41,7 +41,7 @@ public: *@return: * success: load configuration successfully or not. */ - int sampling(float* noise, int size); + int sampling(float* noise, int64_t size); /*@brief reconstruct the Gaussion noise given the key. * This function is often used for updating the neuron network parameters in the offline environment. @@ -51,7 +51,7 @@ public: * noise: a pointer pointed to the memory that stores the noise * size: the number of float to be sampled. */ - bool resampling(int key, float* noise, int size); + bool resampling(int key, float* noise, int64_t size); private: float _std; diff --git a/deepes/include/paddle/es_agent.h b/deepes/include/paddle/es_agent.h new file mode 100644 index 0000000000000000000000000000000000000000..382637864f8052cb8683a408297b43379050ff4a --- /dev/null +++ b/deepes/include/paddle/es_agent.h @@ -0,0 +1,102 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DEEPES_PADDLE_ES_AGENT_H_ +#define DEEPES_PADDLE_ES_AGENT_H_ + +#include "paddle_api.h" +#include "optimizer.h" +#include "utils.h" +#include "gaussian_sampling.h" +#include "deepes.pb.h" +#include + + +namespace DeepES { + +/* DeepES agent for PaddleLite. + * Users can use `add_noise` function to add noise to parameters and use `get_sample_predictor` + * function to get a predictor with added noise to explore. + * Then can use `update` function to update parameters based on ES algorithm. + * Users also can `clone` multi agents to sample in multi-thread way. + */ + +typedef paddle::lite_api::PaddlePredictor PaddlePredictor; + +class ESAgent { + public: + ESAgent(); + + ~ESAgent(); + + ESAgent( + std::shared_ptr predictor, + std::string config_path); + + // Return a cloned ESAgent, whose _predictor is same with this->_predictor + // but _sample_predictor is pointed to a newly created object. + // This function mainly used to clone a new ESAgent to do sampling in multi-thread way. + // NOTE: when calling `update` function of current object or cloned one, both of their + // parameters will be updated. Because their _predictor is point to same object. + std::shared_ptr clone(); + + // Update parameters of _predictor + bool update( + std::vector& noisy_keys, + std::vector& noisy_rewards); + + // parameters of _sample_predictor = parameters of _predictor + noise + SamplingKey add_noise(); + + std::shared_ptr get_sampling_method(); + std::shared_ptr get_optimizer(); + std::shared_ptr get_config(); + int64_t get_param_size(); + std::vector get_param_names(); + + // Return paddle predict _sample_predictor (with addded noise) + std::shared_ptr get_sample_predictor(); + + // Return paddle predict _predictor (without addded noise) + std::shared_ptr get_evaluate_predictor(); + + void set_config(std::shared_ptr config); + void set_sampling_method(std::shared_ptr sampling_method); + void set_optimizer(std::shared_ptr optimizer); + void set_param_size(int64_t param_size); + void set_param_names(std::vector param_names); + void set_noise(float* noise); + void set_neg_gradients(float* neg_gradients); + void set_predictor( + std::shared_ptr predictor, + std::shared_ptr sample_predictor); + + private: + std::shared_ptr _predictor; + std::shared_ptr _sample_predictor; + std::shared_ptr _sampling_method; + std::shared_ptr _optimizer; + std::shared_ptr _config; + int64_t _param_size; + std::vector _param_names; + // malloc memory of noise and neg_gradients in advance. + float* _noise; + float* _neg_gradients; + + int64_t _calculate_param_size(); +}; + +} + +#endif /* DEEPES_PADDLE_ES_AGENT_H_ */ diff --git a/deepes/include/sampling_method.h b/deepes/include/sampling_method.h index a23273273decccb988449783176dbc501824bc39..835c8d77294de1befe6ebaf27601eebce4bcfa9f 100644 --- a/deepes/include/sampling_method.h +++ b/deepes/include/sampling_method.h @@ -55,7 +55,7 @@ public: *@return: * success: load configuration successfully or not. */ - virtual int sampling(float* noise, int size)=0; + virtual int sampling(float* noise, int64_t size)=0; /*@brief reconstruct the Gaussion noise given the key. * This function is often used for updating the neuron network parameters in the offline environment. @@ -65,7 +65,7 @@ public: * noise: a pointer pointed to the memory that stores the noise * size: the number of float to be sampled. */ - virtual bool resampling(int key, float* noise, int size)=0; + virtual bool resampling(int key, float* noise, int64_t size)=0; bool set_seed(int seed) { _seed = seed; diff --git a/deepes/include/torch_predictor.h b/deepes/include/torch/es_agent.h similarity index 58% rename from deepes/include/torch_predictor.h rename to deepes/include/torch/es_agent.h index f17f65f0d831f40cbb31a0f21c0c2d7d8c9acd0d..6c4427336f261898ab7c7ac2cce05198df7a6b91 100644 --- a/deepes/include/torch_predictor.h +++ b/deepes/include/torch/es_agent.h @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TORCHPREDICTOR_H -#define TORCHPREDICTOR_H +#ifndef TORCH_ESAGENT_H +#define TORCH_ESAGENT_H + #include #include #include "optimizer.h" @@ -23,36 +24,51 @@ namespace DeepES{ -/* DeepES predictor for Torch. +/* DeepES agent for Torch. * Our implemtation is flexible to support any model that subclass torch::nn::Module. - * That is, we can instantiate a preditor by: predictor = Predcitor(model); - * After that, users can clone a predictor for multi-thread processing, add parametric noise for exploration, + * That is, we can instantiate a agent by: es_agent = ESAgent(model); + * After that, users can clone a agent for multi-thread processing, add parametric noise for exploration, * and update the parameteres, according to the evaluation resutls of noisy parameters. * */ template -class Predictor{ +class ESAgent{ public: - Predictor(): _param_size(0){} + ESAgent(): _param_size(0){} + + ~ESAgent() { + delete[] _noise; + delete[] _neg_gradients; + } - Predictor(std::shared_ptr model, std::string config_path): _model(model) { + ESAgent(std::shared_ptr model, std::string config_path): _model(model) { _config = std::make_shared(); load_proto_conf(config_path, *_config); _sampling_method = std::make_shared(); _sampling_method->load_config(*_config); _optimizer = std::make_shared(_config->optimizer().base_lr()); _param_size = 0; - _sampled_model = model; + _sampled_model = model->clone(); param_size(); + + _noise = new float [_param_size]; + _neg_gradients = new float [_param_size]; } - std::shared_ptr clone() { + std::shared_ptr clone() { std::shared_ptr new_model = _model->clone(); - std::shared_ptr new_predictor = std::make_shared(); - new_predictor->set_model(new_model, _model); - new_predictor->set_sampling_method(_sampling_method); - new_predictor->set_param_size(_param_size); - return new_predictor; + std::shared_ptr new_agent = std::make_shared(); + new_agent->set_model(_model, new_model); + new_agent->set_sampling_method(_sampling_method); + new_agent->set_optimizer(_optimizer); + new_agent->set_config(_config); + new_agent->set_param_size(_param_size); + + float* new_noise = new float [_param_size]; + float* new_neg_gradients = new float [_param_size]; + new_agent->set_noise(new_noise); + new_agent->set_neg_gradients(new_neg_gradients); + return new_agent; } void set_config(std::shared_ptr config) { @@ -63,9 +79,9 @@ public: _sampling_method = sampling_method; } - void set_model(std::shared_ptr sampled_model, std::shared_ptr model) { - _sampled_model = sampled_model; + void set_model(std::shared_ptr model, std::shared_ptr sampled_model) { _model = model; + _sampled_model = sampled_model; } std::shared_ptr get_sampling_method() { @@ -80,68 +96,79 @@ public: _optimizer = optimizer; } - void set_param_size(int param_size) { + void set_param_size(int64_t param_size) { _param_size = param_size; } - torch::Tensor predict(const torch::Tensor& x) { - return _sampled_model->forward(x); + void set_noise(float* noise) { + _noise = noise; + } + + void set_neg_gradients(float* neg_gradients) { + _neg_gradients = neg_gradients; + } + + + torch::Tensor predict(const torch::Tensor& x, bool is_eval=false) { + if (is_eval) { + // predict with _model (without addding noise) + return _model->forward(x); + } + else { + // predict with _sampled_model (with adding noise) + return _sampled_model->forward(x); + } } bool update(std::vector& noisy_keys, std::vector& noisy_rewards) { compute_centered_ranks(noisy_rewards); - float* noise = new float [_param_size]; - float* neg_gradients = new float [_param_size]; - memset(neg_gradients, 0, _param_size * sizeof(float)); + + memset(_neg_gradients, 0, _param_size * sizeof(float)); for (int i = 0; i < noisy_keys.size(); ++i) { int key = noisy_keys[i].key(0); float reward = noisy_rewards[i]; - bool success = _sampling_method->resampling(key, noise, _param_size); - for (int j = 0; j < _param_size; ++j) { - neg_gradients[j] += noise[j] * reward; + bool success = _sampling_method->resampling(key, _noise, _param_size); + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] += _noise[j] * reward; } } - for (int j = 0; j < _param_size; ++j) { - neg_gradients[j] /= -1.0 * noisy_keys.size(); + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] /= -1.0 * noisy_keys.size(); } //update auto params = _model->named_parameters(); - int counter = 0; + int64_t counter = 0; for (auto& param: params) { torch::Tensor tensor = param.value().view({-1}); auto tensor_a = tensor.accessor(); - _optimizer->update(tensor_a, neg_gradients+counter, tensor.size(0)); + _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0)); counter += tensor.size(0); } - delete[] noise; - delete[] neg_gradients; } SamplingKey add_noise() { SamplingKey sampling_key; auto sampled_params = _sampled_model->named_parameters(); auto params = _model->named_parameters(); - float* noise = new float [_param_size]; - int key = _sampling_method->sampling(noise, _param_size); + int key = _sampling_method->sampling(_noise, _param_size); sampling_key.add_key(key); - int counter = 0; + int64_t counter = 0; for (auto& param: sampled_params) { torch::Tensor sampled_tensor = param.value().view({-1}); std::string param_name = param.key(); torch::Tensor tensor = params.find(param_name)->view({-1}); auto sampled_tensor_a = sampled_tensor.accessor(); auto tensor_a = tensor.accessor(); - for (int j = 0; j < tensor.size(0); ++j) { - sampled_tensor_a[j] = tensor_a[j] + noise[counter + j]; + for (int64_t j = 0; j < tensor.size(0); ++j) { + sampled_tensor_a[j] = tensor_a[j] + _noise[counter + j]; } counter += tensor.size(0); } - delete[] noise; return sampling_key; } - int param_size() { + int64_t param_size() { if (_param_size == 0) { auto params = _model->named_parameters(); for (auto& param: params) { @@ -158,8 +185,12 @@ private: std::shared_ptr _sampling_method; std::shared_ptr _optimizer; std::shared_ptr _config; - int _param_size; + int64_t _param_size; + // malloc memory of noise and neg_gradients in advance. + float* _noise; + float* _neg_gradients; }; } -#endif + +#endif /* TORCH_ESAGENT_H */ diff --git a/deepes/scripts/build.sh b/deepes/scripts/build.sh index 70017d2487db13f0c11634eb16a5256e21722e21..da9dadb691c28d94d6691d35a98f6777fdd4fa71 100644 --- a/deepes/scripts/build.sh +++ b/deepes/scripts/build.sh @@ -1,23 +1,48 @@ #!/bin/bash -export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + +if [ $# != 1 ]; then + echo "You must choose one framework (paddle/torch) to compile DeepES." + exit 0 +fi + +if [ $1 = "paddle" ]; then + #---------------paddlelite-------------# + if [ ! -d "./inference_lite_lib" ];then + echo "Cannot find the PaddleLite library: ./inference_lite_lib" + echo "Please put the PaddleLite libraray to current folder according the instruction in README" + exit 1 + fi + + FLAGS=" -DWITH_PADDLE=ON" +elif [ $1 = "torch" ]; then + #---------------libtorch-------------# + if [ ! -d "./libtorch" ];then + echo "Cannot find the torch library: ./libtorch" + echo "Please put the torch libraray to current folder according the instruction in README" + exit 1 + fi + FLAGS=" -DWITH_TORCH=ON" +else + echo "Invalid arguments. [paddle/torch]" + exit 0 +fi + +#export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH #----------------protobuf-------------# cp ./src/proto/deepes.proto ./ protoc deepes.proto --cpp_out ./ mv deepes.pb.h ./include mv deepes.pb.cc ./src - -#---------------libtorch-------------# -if [ ! -d "./libtorch" ];then - echo "Cannot find the torch library: ./libtorch" - echo "Please put the torch libraray to current folder according the instruction in README" - exit 1 -fi +rm deepes.proto #----------------build---------------# +echo ${FLAGS} rm -rf build mkdir build cd build -cmake -DCMAKE_PREFIX_PATH=./libtorch ../ +cmake ../ ${FLAGS} make -j10 + +#-----------------run----------------# ./parallel_main diff --git a/deepes/src/gaussian_sampling.cpp b/deepes/src/gaussian_sampling.cc similarity index 86% rename from deepes/src/gaussian_sampling.cpp rename to deepes/src/gaussian_sampling.cc index 4ad6cf3021ea31cb371ffa59d42378971fae3016..f44dd5abecaa4c45a9c829952a38c2c4c26cf4aa 100644 --- a/deepes/src/gaussian_sampling.cpp +++ b/deepes/src/gaussian_sampling.cc @@ -26,17 +26,17 @@ void GaussianSampling::load_config(const DeepESConfig& config) { set_seed(config.seed()); } -int GaussianSampling::sampling(float* noise, int size) { +int GaussianSampling::sampling(float* noise, int64_t size) { int key = rand(); std::default_random_engine generator(key); std::normal_distribution norm; - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { *(noise + i) = norm(generator) * _std; } return key; } -bool GaussianSampling::resampling(int key, float* noise, int size) { +bool GaussianSampling::resampling(int key, float* noise, int64_t size) { bool success = true; if (noise == nullptr) { success = false; @@ -44,7 +44,7 @@ bool GaussianSampling::resampling(int key, float* noise, int size) { else { std::default_random_engine generator(key); std::normal_distribution norm; - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { *(noise + i) = norm(generator) * _std; } } diff --git a/deepes/src/paddle/es_agent.cc b/deepes/src/paddle/es_agent.cc new file mode 100644 index 0000000000000000000000000000000000000000..13160270be26189ea969f9b7158312e525969435 --- /dev/null +++ b/deepes/src/paddle/es_agent.cc @@ -0,0 +1,213 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "es_agent.h" +#include "paddle_api.h" +#include "optimizer.h" +#include "utils.h" +#include "gaussian_sampling.h" +#include "deepes.pb.h" + + +namespace DeepES { + +typedef paddle::lite_api::PaddlePredictor PaddlePredictor; +typedef paddle::lite_api::Tensor Tensor; +typedef paddle::lite_api::shape_t shape_t; + +int64_t ShapeProduction(const shape_t& shape) { + int64_t res = 1; + for (auto i : shape) res *= i; + return res; +} + +ESAgent::ESAgent() {} + +ESAgent::~ESAgent() { + delete[] _noise; + delete[] _neg_gradients; +} + +ESAgent::ESAgent( + std::shared_ptr predictor, + std::string config_path) { + + _predictor = predictor; + _sample_predictor = predictor->Clone(); + + _config = std::make_shared(); + load_proto_conf(config_path, *_config); + + _sampling_method = std::make_shared(); + _sampling_method->load_config(*_config); + + _optimizer = std::make_shared(_config->optimizer().base_lr()); + + _param_names = _predictor->GetParamNames(); + _param_size = _calculate_param_size(); + + _noise = new float [_param_size]; + _neg_gradients = new float [_param_size]; +} + +std::shared_ptr ESAgent::clone() { + std::shared_ptr new_sample_predictor = _predictor->Clone(); + + std::shared_ptr new_agent = std::make_shared(); + + float* new_noise = new float [_param_size]; + float* new_neg_gradients = new float [_param_size]; + + new_agent->set_predictor(_predictor, new_sample_predictor); + new_agent->set_sampling_method(_sampling_method); + new_agent->set_optimizer(_optimizer); + new_agent->set_config(_config); + new_agent->set_param_size(_param_size); + new_agent->set_param_names(_param_names); + new_agent->set_noise(new_noise); + new_agent->set_neg_gradients(new_neg_gradients); + return new_agent; +} + +bool ESAgent::update( + std::vector& noisy_keys, + std::vector& noisy_rewards) { + compute_centered_ranks(noisy_rewards); + + memset(_neg_gradients, 0, _param_size * sizeof(float)); + for (int i = 0; i < noisy_keys.size(); ++i) { + int key = noisy_keys[i].key(0); + float reward = noisy_rewards[i]; + bool success = _sampling_method->resampling(key, _noise, _param_size); + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] += _noise[j] * reward; + } + } + for (int64_t j = 0; j < _param_size; ++j) { + _neg_gradients[j] /= -1.0 * noisy_keys.size(); + } + + //update + int64_t counter = 0; + + for (std::string param_name: _param_names) { + std::unique_ptr tensor = _predictor->GetMutableTensor(param_name); + float* tensor_data = tensor->mutable_data(); + int64_t tensor_size = ShapeProduction(tensor->shape()); + _optimizer->update(tensor_data, _neg_gradients + counter, tensor_size); + counter += tensor_size; + } + +} + +SamplingKey ESAgent::add_noise() { + SamplingKey sampling_key; + int key = _sampling_method->sampling(_noise, _param_size); + sampling_key.add_key(key); + int64_t counter = 0; + + for (std::string param_name: _param_names) { + std::unique_ptr sample_tensor = _sample_predictor->GetMutableTensor(param_name); + std::unique_ptr tensor = _predictor->GetTensor(param_name); + int64_t tensor_size = ShapeProduction(tensor->shape()); + for (int64_t j = 0; j < tensor_size; ++j) { + sample_tensor->mutable_data()[j] = tensor->data()[j] + _noise[counter + j]; + } + counter += tensor_size; + } + + return sampling_key; +} + +std::shared_ptr ESAgent::get_sampling_method() { + return _sampling_method; +} + +std::shared_ptr ESAgent::get_optimizer() { + return _optimizer; +} + +std::shared_ptr ESAgent::get_config() { + return _config; +} + + +int64_t ESAgent::get_param_size() { + return _param_size; +} + +std::vector ESAgent::get_param_names() { + return _param_names; +} + + +std::shared_ptr ESAgent::get_sample_predictor() { + return _sample_predictor; +} + +std::shared_ptr ESAgent::get_evaluate_predictor() { + return _predictor; +} + + +void ESAgent::set_predictor( + std::shared_ptr predictor, + std::shared_ptr sample_predictor) { + _predictor = predictor; + _sample_predictor = sample_predictor; +} + +void ESAgent::set_sampling_method(std::shared_ptr sampling_method) { + _sampling_method = sampling_method; +} + +void ESAgent::set_optimizer(std::shared_ptr optimizer) { + _optimizer = optimizer; +} + +void ESAgent::set_config(std::shared_ptr config) { + _config = config; +} + +void ESAgent::set_param_size(int64_t param_size) { + _param_size = param_size; +} + +void ESAgent::set_param_names(std::vector param_names) { + _param_names = param_names; +} + +void ESAgent::set_noise(float* noise) { + _noise = noise; +} + +void ESAgent::set_neg_gradients(float* neg_gradients) { + _neg_gradients = neg_gradients; +} + + +int64_t ESAgent::_calculate_param_size() { + int64_t param_size = 0; + for (std::string param_name: _param_names) { + std::unique_ptr tensor = _predictor->GetTensor(param_name); + param_size += ShapeProduction(tensor->shape()); + } + return param_size; +} + + +} + diff --git a/deepes/src/utils.cpp b/deepes/src/utils.cc similarity index 100% rename from deepes/src/utils.cpp rename to deepes/src/utils.cc