diff --git a/deepes/CMakeLists.txt b/deepes/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f4c66fbaae4244edceb55f89dcd6d7d32bdec09f --- /dev/null +++ b/deepes/CMakeLists.txt @@ -0,0 +1,26 @@ +cmake_minimum_required (VERSION 2.6) +project (DeepES) + +find_package(OpenMP) +if (OPENMP_FOUND) + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") +endif() + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +find_package(Torch REQUIRED ON) + +file(GLOB demo_src "demo/*.cpp") +file(GLOB core_src "src/*.cpp") +file(GLOB pb_src "src/*.cc") + +include_directories("include") +include_directories("demo") +include_directories("benchmark") +link_directories("/usr/lib/x86_64-linux-gnu/") + +add_executable(parallel_main "./demo/cartpole_solver_parallel.cpp" ${core_src} ${pb_src} ${benchmark_src}) +target_link_libraries(parallel_main gflags protobuf pthread glog "${TORCH_LIBRARIES}") diff --git a/deepes/DeepES.gif b/deepes/DeepES.gif new file mode 100644 index 0000000000000000000000000000000000000000..7240118f3fce55b587690450e0c9cafc2f0694db Binary files /dev/null and b/deepes/DeepES.gif differ diff --git a/deepes/README.md b/deepes/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ae14819304b173344dc62a8281025b91ef679d41 --- /dev/null +++ b/deepes/README.md @@ -0,0 +1,36 @@ +# DeepES工具 +DeepES是一个支持**快速验证**ES效果、**兼容多个框架**的C++库。 +

+PARL +

+
+## 使用示范
+```c++
+//实例化一个预测,根据配置文件加载模型,采样方式(Gaussian\CMA sampling..)、更新方式(SGD\Adam)等
+auto predictor = Predictor(config);
+
+for (int i = 0; i < 100; ++i) {
+  auto noisy_predictor = predictor->clone(); // copy 一份参数
+  int key = noisy_predictor->add_noise(); // 参数扰动,同时保存随机种子
+  int reward = evaluate(env, noisy_predictor); //评估参数
+  noisy_keys.push_back(key); // 记录随机噪声对应种子
+  noisy_rewards.push_back(reward); // 记录评估结果
+}
+//根据评估结果、随机种子更新参数,然后重复以上过程,直到收敛。
+predictor->update(noisy_keys, noisy_rewards);
+```
+
+## 一键运行demo列表
+- **Torch**: sh [./scripts/build.sh](./scripts/build.sh)
+- **Paddle**:
+- **裸写网络**:
+
+## 相关依赖:
+- Protobuf >= 2.4.2
+- glog
+- gflags
+
+## 额外依赖:
+
+### 使用torch
+下载[libtorch](https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip)或者编译torch源码,得到libtorch文件夹,放在当前目录中。
diff --git a/deepes/benchmark/cartpole.h b/deepes/benchmark/cartpole.h
new file mode 100644
index 0000000000000000000000000000000000000000..6935f8ddb3a058444945c3dab08be088a0152454
--- /dev/null
+++ b/deepes/benchmark/cartpole.h
@@ -0,0 +1,91 @@
+// Third party code
+// This code is copied or modified from openai/gym's cartpole.py
+
+#include <cmath>
+#include <torch/torch.h>
+
+const double kPi = 3.1415926535898;
+
+// Cart-pole balancing environment.
+// The state tensor holds four values, in order: x, x_dot, theta, theta_dot.
+// Typical use: reset(), then step(action) until isDone() returns true.
+class CartPole {
+public:
+  double gravity = 9.8;
+  double masscart = 1.0;
+  double masspole = 0.1;
+  double total_mass = (masspole + masscart);
+  double length = 0.5; // actually half the pole's length;
+  double polemass_length = (masspole * length);
+  double force_mag = 10.0;
+  double tau = 0.02; // seconds between state updates;
+
+  // Angle at which to fail the episode
+  double theta_threshold_radians = 12 * 2 * kPi / 360;
+  double x_threshold = 2.4;
+  int steps_beyond_done = -1;
+
+  torch::Tensor state;
+  // Both are given defined values so the getters are safe before the first step().
+  double reward = 0.0;
+  bool done = false;
+  int step_ = 0;
+
+  torch::Tensor getState() {
+    return state;
+  }
+
+  double getReward() {
+    return reward;
+  }
+
+  // True once the episode has terminated (pole fell, cart left the track,
+  // or the step limit was exceeded).
+  bool isDone() {
+    return done;
+  }
+
+  // Starts a new episode from a small random state in [-0.05, 0.05).
+  void reset() {
+    state = torch::empty({ 4 }).uniform_(-0.05, 0.05);
+    steps_beyond_done = -1;
+    step_ = 0;
+    reward = 0.0;
+    done = false;
+  }
+
+  CartPole() {
+    reset();
+  }
+
+  // Advances the simulation by one time step of length tau.
+  // action == 1 pushes the cart with +force_mag, any other value with -force_mag.
+  void step(int action) {
+    auto x = state[0].item<float>();
+    auto x_dot = state[1].item<float>();
+    auto theta = state[2].item<float>();
+    auto theta_dot = state[3].item<float>();
+
+    auto force = (action == 1) ? force_mag : -force_mag;
+    auto costheta = std::cos(theta);
+    auto sintheta = std::sin(theta);
+    auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) /
+      total_mass;
+    auto thetaacc = (gravity * sintheta - costheta * temp) /
+      (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
+    auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;
+
+    // Explicit Euler integration.
+    x = x + tau * x_dot;
+    x_dot = x_dot + tau * xacc;
+    theta = theta + tau * theta_dot;
+    theta_dot = theta_dot + tau * thetaacc;
+    state = torch::tensor({ x, x_dot, theta, theta_dot });
+
+    done = x < -x_threshold || x > x_threshold ||
+      theta < -theta_threshold_radians || theta > theta_threshold_radians ||
+      step_ > 200;
+
+    if (!done) {
+      reward = 1.0;
+    }
+    else if (steps_beyond_done == -1) {
+      // Pole just fell!
+      steps_beyond_done = 0;
+      reward = 0;
+    }
+    else {
+      if (steps_beyond_done == 0) {
+        AT_ASSERT(false); // Can't do this: step() was called after the episode ended
+      }
+      reward = 0;
+    }
+    step_++;
+  }
+};
diff --git a/deepes/deepes_config.prototxt b/deepes/deepes_config.prototxt
new file mode 100644
index 0000000000000000000000000000000000000000..db2608f0638b16b5793a6d40221ae69c17953ed8
--- /dev/null
+++ b/deepes/deepes_config.prototxt
@@ -0,0 +1,10 @@
+seed : 1024
+
+gaussian_sampling {
+  std: 0.3
+}
+
+optimizer {
+  type: "SGD",
+  base_lr: 1e-2
+}
diff --git a/deepes/demo/cartpole_solver_parallel.cpp b/deepes/demo/cartpole_solver_parallel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f6f39a6835bdf96f4691d92f71e9f64569e7a89a
--- /dev/null
+++ b/deepes/demo/cartpole_solver_parallel.cpp
@@ -0,0 +1,78 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <tuple>
+#include <vector>
+#include <glog/logging.h>
+#include <torch/torch.h>
+#include "cartpole.h"
+#include "gaussian_sampling.h"
+#include "model.h"
+#include "torch_predictor.h"
+
+using namespace DeepES;
+const int ITER = 100;
+
+// Plays one episode of `env` with `predictor` choosing the highest-scoring
+// action at every step, and returns the sum of the rewards collected.
+float evaluate(CartPole& env, std::shared_ptr<Predictor<Model>> predictor) {
+  float total_reward = 0.0f;
+  env.reset();
+  auto obs = env.getState();
+  while (true) {
+    torch::Tensor action = predictor->predict(obs);
+    // max(-1) yields (values, indices); the index is the chosen action.
+    int act = std::get<1>(action.max(-1)).item<int>();
+    env.step(act);
+    total_reward += static_cast<float>(env.getReward());
+    if (env.isDone()) break;
+    obs = env.getState();
+  }
+  return total_reward;
+}
+
+// Trains a CartPole policy with evolution strategies: every epoch evaluates
+// ITER noisy copies of the model in parallel, then updates the base model
+// from the (noise key, reward) pairs.
+int main(int argc, char* argv[]) {
+  //google::InitGoogleLogging(argv[0]);
+  // One environment and one noisy predictor per parallel worker slot.
+  std::vector<CartPole> envs(ITER);
+
+  auto model = std::make_shared<Model>(4, 2);
+  std::shared_ptr<Predictor<Model>> predictor =
+    std::make_shared<Predictor<Model>>(model, "../deepes_config.prototxt");
+  std::vector<std::shared_ptr<Predictor<Model>>> noisy_predictors;
+  noisy_predictors.reserve(ITER);
+  for (int i = 0; i < ITER; ++i) {
+    noisy_predictors.push_back(predictor->clone());
+  }
+
+  std::vector<SamplingKey> noisy_keys(ITER);
+  std::vector<float> noisy_rewards(ITER, 0.0f);
+
+  for (int epoch = 0; epoch < 10000; ++epoch) {
+    // NOTE(review): every clone shares one sampling method, and
+    // GaussianSampling::sampling() draws its key with rand(), which is not
+    // guaranteed to be thread-safe. Keys drawn inside this parallel loop may
+    // race - confirm, or give each worker its own generator.
+#pragma omp parallel for schedule(dynamic, 1)
+    for (int i = 0; i < ITER; ++i) {
+      auto noisy_predictor = noisy_predictors[i];
+      SamplingKey key = noisy_predictor->add_noise();
+      float reward = evaluate(envs[i], noisy_predictor);
+      noisy_keys[i] = key;
+      noisy_rewards[i] = reward;
+    }
+
+    predictor->update(noisy_keys, noisy_rewards);
+
+    // Kept as float: evaluate() returns float and an int here would truncate.
+    float reward = evaluate(envs[0], predictor);
+    LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
+  }
+  return 0;
+}
diff --git a/deepes/demo/model.h b/deepes/demo/model.h
new file mode 100644
index 0000000000000000000000000000000000000000..27373ceffd66bffd9d8a047a2e4fc5fe3a14005a
--- /dev/null
+++ b/deepes/demo/model.h
@@ -0,0 +1,61 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _MODEL_H
+#define _MODEL_H
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+#include <torch/torch.h>
+
+// Two-layer policy network: Linear(obs_dim, act_dim * 10) -> tanh ->
+// Linear(act_dim * 10, act_dim) -> softmax over the action dimension.
+struct Model : public torch::nn::Module{
+
+  Model() = delete;
+
+  Model(const int obs_dim, const int act_dim)
+    : _act_dim(act_dim), _obs_dim(obs_dim) {
+    const int hid1_size = act_dim * 10;
+    fc1 = register_module("fc1", torch::nn::Linear(obs_dim, hid1_size));
+    fc2 = register_module("fc2", torch::nn::Linear(hid1_size, act_dim));
+  }
+
+  // Maps observations to action probabilities. The input is reshaped to
+  // (-1, obs_dim), so a single flat observation becomes a batch of one.
+  torch::Tensor forward(torch::Tensor x) {
+    x = x.reshape({-1, _obs_dim});
+    x = torch::tanh(fc1->forward(x));
+    x = torch::softmax(fc2->forward(x), 1);
+    return x;
+  }
+
+  // Returns a new Model of the same dimensions whose parameters are an
+  // element-wise copy of this model's parameters (no storage is shared).
+  std::shared_ptr<Model> clone() {
+    std::shared_ptr<Model> model = std::make_shared<Model>(_obs_dim, _act_dim);
+    std::vector<torch::Tensor> parameters1 = parameters();
+    std::vector<torch::Tensor> parameters2 = model->parameters();
+    for (std::size_t i = 0; i < parameters1.size(); ++i) {
+      torch::Tensor src = parameters1[i].view({-1});
+      torch::Tensor des = parameters2[i].view({-1});
+      auto src_a = src.accessor<float,1>();
+      auto des_a = des.accessor<float,1>();
+      for (int64_t j = 0; j < src.size(0); ++j) {
+        des_a[j] = src_a[j];
+      }
+    }
+    return model;
+  }
+
+  int _act_dim;
+  int _obs_dim;
+  torch::nn::Linear fc1{nullptr}, fc2{nullptr};
+};
+
+#endif
diff --git a/deepes/include/gaussian_sampling.h b/deepes/include/gaussian_sampling.h
new file mode 100644
index 0000000000000000000000000000000000000000..59c753e279d8575c3dda85ca855a099f0eabe398
--- /dev/null
+++ b/deepes/include/gaussian_sampling.h
@@ -0,0 +1,62 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+#ifndef _GAUSSIAN_SAMPLING_H
+#define _GAUSSIAN_SAMPLING_H
+#include "sampling_method.h"
+
+namespace DeepES{
+
+/* Sampling method that draws zero-mean Gaussian noise scaled by a configured
+ * standard deviation.
+ */
+class GaussianSampling: public SamplingMethod {
+
+public:
+  GaussianSampling() {}
+
+  ~GaussianSampling() {}
+
+  /*@brief Initialize the sampler from the protobuf configuration.
+   * Reads the standard deviation from the `gaussian_sampling` section and the
+   * random seed from the top-level `seed` field. The DeepES library uses one
+   * configuration file for all sampling algorithms (see deepes_config.prototxt);
+   * usually you do not have to modify the items of algorithms you are not using.
+   *
+   *@Args:
+   *  config: the parsed DeepES configuration.
+   */
+  void load_config(const DeepESConfig& config) override;
+
+  /*@brief Sample Gaussian noise.
+   *
+   *@Args:
+   *  noise: pointer to the memory that receives the sampled noise.
+   *  size: the number of floats to sample.
+   *
+   *@return:
+   *  key: the key identifying this noise; pass it to resampling() to
+   *       regenerate exactly the same values.
+   */
+  int sampling(float* noise, int size) override;
+
+  /*@brief Reconstruct the Gaussian noise given the key.
+   * This function is often used for updating the neural network parameters in
+   * the offline environment.
+   *
+   *@Args:
+   *  key: a unique key associated with the sampled noise.
+   *  noise: pointer to the memory that receives the noise.
+   *  size: the number of floats to sample.
+   *
+   *@return:
+   *  success: false when `noise` is a null pointer, true otherwise.
+   */
+  bool resampling(int key, float* noise, int size) override;
+
+private:
+  float _std;
+};
+
+}
+
+#endif
diff --git a/deepes/include/optimizer.h b/deepes/include/optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..6aca7588d7e58a3aacbb2d25a5d0c145724a2a25
--- /dev/null
+++ b/deepes/include/optimizer.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPTIMIZER_H
+#define OPTIMIZER_H
+
+#include <string>
+
+namespace DeepES{
+
+/* Base class for optimizers. Subclasses are required to implement the following function:
+ * 1. compute_step
+ */
+
+class Optimizer {
+public:
+  Optimizer() : _base_lr(1e-3), _update_times(0) {}
+  explicit Optimizer(float base_lr) : _base_lr(base_lr), _update_times(0) {}
+  // Optimizers are held and destroyed through base-class pointers.
+  virtual ~Optimizer() {}
+
+  /*@brief Apply one update step to a block of weights.
+   * compute_step() may first rewrite `gradient` in place; afterwards each
+   * weight is moved by -base_lr * gradient[i].
+   *
+   *@Args:
+   *  weights: any object indexable with operator[] holding `size` floats.
+   *  gradient: pointer to `size` gradient values.
+   *  size: number of elements to update.
+   *  param_name: optional name of the parameter being updated.
+   *
+   *@return:
+   *  success: always true in this implementation.
+   */
+  template<typename T>
+  bool update(T weights, float* gradient, int size, std::string param_name="") {
+    ++_update_times;
+    compute_step(gradient, size, param_name);
+    for (int i = 0; i < size; ++i) {
+      weights[i] -= _base_lr * gradient[i];
+    }
+    return true;
+  } // template function
+
+protected:
+  // Hook that turns the raw gradient into the step direction, in place.
+  virtual void compute_step(float* gradient, int size, std::string param_name="") = 0;
+  float _base_lr;
+  float _update_times;
+};
+
+/* Plain SGD: the step is the gradient itself. The momentum value is stored
+ * but not used by compute_step() yet.
+ */
+class SGDOptimizer: public Optimizer {
+public:
+  SGDOptimizer(float base_lr, float momentum=0.0):Optimizer(base_lr), _momentum(momentum) {}
+
+protected:
+  // Intentionally empty: leaves the gradient unchanged.
+  void compute_step(float* gradient, int size, std::string param_name="") override {
+  }
+
+private:
+  float _momentum;
+
+}; //class
+
+//class AdamOptimizer: public Optimizer {
+//public:
+//  AdamOptimizer(float base)
+//};
+
+}//namespace
+#endif
diff --git a/deepes/include/sampling_method.h b/deepes/include/sampling_method.h
new file mode 100644
index 0000000000000000000000000000000000000000..a23273273decccb988449783176dbc501824bc39
--- /dev/null
+++ b/deepes/include/sampling_method.h
@@ -0,0 +1,86 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _SAMPLING_METHOD_H
+#define _SAMPLING_METHOD_H
+
+#include <cstdlib>
+#include <string>
+#include "deepes.pb.h"
+
+namespace DeepES{
+
+/*Base class for sampling algorithms. All algorithms are required to override the following functions:
+ *
+ * 1. load_config
+ * 2. sampling
+ * 3. resampling
+ *
+ * View a demonstrative algorithm in gaussian_sampling.h
+ * */
+
+class SamplingMethod{
+
+public:
+
+  SamplingMethod(): _seed(0) {}
+
+  virtual ~SamplingMethod() {}
+
+  /*@brief Initialize the sampling algorithm from the protobuf configuration.
+   * The DeepES library uses one configuration file for all sampling algorithms
+   * (see deepes_config.prototxt); usually you do not have to modify the items
+   * of algorithms you are not using.
+   *
+   *@Args:
+   *  config: the parsed DeepES configuration.
+   */
+  virtual void load_config(const DeepESConfig& config)=0;
+
+  /*@brief Sample noise.
+   *
+   *@Args:
+   *  noise: pointer to the memory that receives the sampled noise.
+   *  size: the number of floats to sample.
+   *
+   *@return:
+   *  key: a key associated with the sampled noise; resampling() takes it to
+   *       reconstruct the noise.
+   */
+  virtual int sampling(float* noise, int size)=0;
+
+  /*@brief Reconstruct the noise given the key.
+   * This function is often used for updating the neural network parameters in
+   * the offline environment.
+   *
+   *@Args:
+   *  key: a unique key associated with the sampled noise.
+   *  noise: pointer to the memory that receives the noise.
+   *  size: the number of floats to sample.
+   *
+   *@return:
+   *  success: whether the noise was reconstructed.
+   */
+  virtual bool resampling(int key, float* noise, int size)=0;
+
+  // Stores the seed and reseeds the C library generator (srand), which is
+  // process-wide state.
+  bool set_seed(int seed) {
+    _seed = seed;
+    srand(_seed);
+    return true;
+  }
+
+  int get_seed() {
+    return _seed;
+  }
+
+protected:
+  int _seed;
+
+};
+
+}
+#endif
diff --git a/deepes/include/torch_predictor.h b/deepes/include/torch_predictor.h
new file mode 100644
index 0000000000000000000000000000000000000000..f17f65f0d831f40cbb31a0f21c0c2d7d8c9acd0d
--- /dev/null
+++ b/deepes/include/torch_predictor.h
@@ -0,0 +1,165 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TORCHPREDICTOR_H
+#define TORCHPREDICTOR_H
+#include <memory>
+#include <string>
+#include "optimizer.h"
+#include "utils.h"
+#include "gaussian_sampling.h"
+#include "deepes.pb.h"
+
+namespace DeepES{
+
+/* DeepES predictor for Torch.
+ * Our implementation is flexible to support any model that subclasses torch::nn::Module.
+ * That is, we can instantiate a predictor by: predictor = Predictor(model);
+ * After that, users can clone a predictor for multi-thread processing, add parametric noise for exploration,
+ * and update the parameters, according to the evaluation results of noisy parameters.
+ *
+ */
+template <typename T>
+class Predictor{
+public:
+  // Creates an empty predictor; clone() fills it in through the setters.
+  Predictor(): _param_size(0){}
+
+  /*@brief Build the base predictor.
+   * Loads the protobuf configuration from `config_path`, creates a Gaussian
+   * sampling method and an SGD optimizer from it, and counts the parameters
+   * of `model`.
+   *
+   *@Args:
+   *  model: the model to optimize; it also serves as the sampled model.
+   *  config_path: path of the DeepES prototxt configuration file.
+   */
+  Predictor(std::shared_ptr<T> model, std::string config_path): _model(model) {
+    _config = std::make_shared<DeepESConfig>();
+    load_proto_conf(config_path, *_config);
+    _sampling_method = std::make_shared<GaussianSampling>();
+    _sampling_method->load_config(*_config);
+    _optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr());
+    _param_size = 0;
+    _sampled_model = model;
+    param_size();
+  }
+
+  /*@brief Create a predictor for evaluating noisy parameters.
+   * The clone owns a fresh copy of the model as its sampled model, and shares
+   * the base model and the sampling method with this predictor. It receives
+   * neither the optimizer nor the config, so update() on a clone returns false.
+   */
+  std::shared_ptr<Predictor> clone() {
+    std::shared_ptr<T> new_model = _model->clone();
+    std::shared_ptr<Predictor> new_predictor = std::make_shared<Predictor>();
+    new_predictor->set_model(new_model, _model);
+    new_predictor->set_sampling_method(_sampling_method);
+    new_predictor->set_param_size(_param_size);
+    return new_predictor;
+  }
+
+  void set_config(std::shared_ptr<DeepESConfig> config) {
+    _config = config;
+  }
+
+  void set_sampling_method(std::shared_ptr<SamplingMethod> sampling_method) {
+    _sampling_method = sampling_method;
+  }
+
+  // sampled_model receives the noisy parameters; model holds the base parameters.
+  void set_model(std::shared_ptr<T> sampled_model, std::shared_ptr<T> model) {
+    _sampled_model = sampled_model;
+    _model = model;
+  }
+
+  std::shared_ptr<SamplingMethod> get_sampling_method() {
+    return _sampling_method;
+  }
+
+  std::shared_ptr<Optimizer> get_optimizer() {
+    return _optimizer;
+  }
+
+  void set_optimizer(std::shared_ptr<Optimizer> optimizer) {
+    _optimizer = optimizer;
+  }
+
+  void set_param_size(int param_size) {
+    _param_size = param_size;
+  }
+
+  // Runs the sampled (possibly noisy) model on `x`.
+  torch::Tensor predict(const torch::Tensor& x) {
+    return _sampled_model->forward(x);
+  }
+
+  /*@brief Update the base model from evaluated noise samples.
+   * The rewards are converted to centered ranks in place, every noise vector
+   * is regenerated from its key, and the rank-weighted average of the noise
+   * is handed to the optimizer as the negative gradient.
+   *
+   *@Args:
+   *  noisy_keys: keys returned by add_noise(), one per evaluation.
+   *  noisy_rewards: reward of each evaluation; overwritten with its rank.
+   *
+   *@return:
+   *  success: false when the predictor has no model, sampling method or
+   *           optimizer, when the inputs are empty or differ in length, when
+   *           a key is empty, or when the noise cannot be regenerated.
+   */
+  bool update(std::vector<SamplingKey>& noisy_keys, std::vector<float>& noisy_rewards) {
+    if (!_model || !_sampling_method || !_optimizer) {
+      return false;
+    }
+    if (noisy_keys.empty() || noisy_keys.size() != noisy_rewards.size()) {
+      return false;
+    }
+    compute_centered_ranks(noisy_rewards);
+
+    // Vectors release the buffers on every return path.
+    std::vector<float> noise(_param_size);
+    std::vector<float> neg_gradients(_param_size, 0.0f);
+    for (size_t i = 0; i < noisy_keys.size(); ++i) {
+      if (noisy_keys[i].key_size() == 0) {
+        return false;
+      }
+      const int key = noisy_keys[i].key(0);
+      const float reward = noisy_rewards[i];
+      if (!_sampling_method->resampling(key, noise.data(), _param_size)) {
+        return false;
+      }
+      for (int j = 0; j < _param_size; ++j) {
+        neg_gradients[j] += noise[j] * reward;
+      }
+    }
+    for (int j = 0; j < _param_size; ++j) {
+      neg_gradients[j] /= -1.0 * noisy_keys.size();
+    }
+
+    //update: walk the parameters in order; `counter` is the offset of the
+    //current tensor inside the flat gradient buffer.
+    auto params = _model->named_parameters();
+    int counter = 0;
+    for (auto& param: params) {
+      torch::Tensor tensor = param.value().view({-1});
+      auto tensor_a = tensor.accessor<float,1>();
+      _optimizer->update(tensor_a, neg_gradients.data() + counter, tensor.size(0));
+      counter += tensor.size(0);
+    }
+    return true;
+  }
+
+  /*@brief Write base parameters plus fresh noise into the sampled model.
+   *
+   *@return:
+   *  sampling_key: holds the key from which update() regenerates the noise.
+   */
+  SamplingKey add_noise() {
+    SamplingKey sampling_key;
+    auto sampled_params = _sampled_model->named_parameters();
+    auto params = _model->named_parameters();
+    std::vector<float> noise(_param_size);
+    int key = _sampling_method->sampling(noise.data(), _param_size);
+    sampling_key.add_key(key);
+    int counter = 0;
+    for (auto& param: sampled_params) {
+      torch::Tensor sampled_tensor = param.value().view({-1});
+      std::string param_name = param.key();
+      // The base parameter with the same name supplies the unperturbed values.
+      torch::Tensor tensor = params.find(param_name)->view({-1});
+      auto sampled_tensor_a = sampled_tensor.accessor<float,1>();
+      auto tensor_a = tensor.accessor<float,1>();
+      for (int j = 0; j < tensor.size(0); ++j) {
+        sampled_tensor_a[j] = tensor_a[j] + noise[counter + j];
+      }
+      counter += tensor.size(0);
+    }
+    return sampling_key;
+  }
+
+  // Total number of scalar parameters of the base model; counted on the first
+  // call and cached afterwards.
+  int param_size() {
+    if (_param_size == 0) {
+      auto params = _model->named_parameters();
+      for (auto& param: params) {
+        torch::Tensor tensor = param.value().view({-1});
+        _param_size += tensor.size(0);
+      }
+    }
+    return _param_size;
+  }
+
+private:
+  std::shared_ptr<T> _sampled_model;
+  std::shared_ptr<T> _model;
+  std::shared_ptr<SamplingMethod> _sampling_method;
+  std::shared_ptr<Optimizer> _optimizer;
+  std::shared_ptr<DeepESConfig> _config;
+  int _param_size;
+};
+
+}
+#endif
diff --git a/deepes/include/utils.h b/deepes/include/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..6733e7ccb765194f4df65bfcdbd470f4a09fbb0c
--- /dev/null
+++ b/deepes/include/utils.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef UTILS_H
+#define UTILS_H
+#include <algorithm>
+#include <fstream>
+#include <iterator>
+#include <string>
+#include <vector>
+#include <glog/logging.h>
+#include "deepes.pb.h"
+#include <google/protobuf/text_format.h>
+
+namespace DeepES{
+
+/*Replace the rewards with their ranks, normalized to [-0.5, 0.5].
+  Args:
+    reward: an array of rewards; overwritten with the normalized ranks.
+*/
+void compute_centered_ranks(std::vector<float> &reward) ;
+
+/* Load a protobuf-based configuration from a text-format file.
+ * Args:
+ *   config_file: file path.
+ *   proto_config: protobuf message that receives the configuration.
+ * return
+ *   true when the file was opened and parsed; false otherwise. Note that
+ *   both failures are also reported with CHECK / LOG(FATAL).
+ */
+template<typename T>
+bool load_proto_conf(const std::string& config_file, T& proto_config) {
+  std::ifstream fin(config_file);
+  CHECK(fin) << "open config file " << config_file;
+  if (fin.fail()) {
+    LOG(FATAL) << "open prototxt config failed: " << config_file;
+    return false;
+  }
+
+  // Read the whole file into a string; no manual buffer to leak.
+  const std::string proto_str((std::istreambuf_iterator<char>(fin)),
+                              std::istreambuf_iterator<char>());
+  if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) {
+    LOG(FATAL) << "Failed to load config: " << config_file;
+    return false;
+  }
+  return true;
+}
+
+}
+
+#endif
diff --git a/deepes/scripts/build.sh b/deepes/scripts/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..70017d2487db13f0c11634eb16a5256e21722e21
--- /dev/null
+++ b/deepes/scripts/build.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
+
+#----------------protobuf-------------#
+cp ./src/proto/deepes.proto ./
+protoc deepes.proto --cpp_out ./
+mv deepes.pb.h ./include
+mv deepes.pb.cc ./src
+
+#---------------libtorch-------------#
+if [ ! -d "./libtorch" ];then
+  echo "Cannot find the torch library: ./libtorch"
+  echo "Please put the torch libraray to current folder according the instruction in README"
+  exit 1
+fi
+
+#----------------build---------------#
+rm -rf build
+mkdir build
+cd build
+cmake -DCMAKE_PREFIX_PATH=./libtorch ../
+make -j10
+./parallel_main
diff --git a/deepes/src/gaussian_sampling.cpp b/deepes/src/gaussian_sampling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4ad6cf3021ea31cb371ffa59d42378971fae3016
--- /dev/null
+++ b/deepes/src/gaussian_sampling.cpp
@@ -0,0 +1,54 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdlib>
+#include <random>
+#include "gaussian_sampling.h"
+#include "utils.h"
+
+namespace DeepES{
+
+// Reads the noise standard deviation and the random seed from the config.
+void GaussianSampling::load_config(const DeepESConfig& config) {
+  _std = config.gaussian_sampling().std();
+  set_seed(config.seed());
+}
+
+// Draws a fresh key with rand() and fills `noise` with the values that key
+// generates. Delegating to resampling() keeps the two code paths identical,
+// so the same key always regenerates the same noise.
+// NOTE(review): rand() uses process-wide state and is not guaranteed to be
+// thread-safe; callers invoking this from several threads should confirm.
+int GaussianSampling::sampling(float* noise, int size) {
+  int key = rand();
+  resampling(key, noise, size);
+  return key;
+}
+
+// Regenerates the noise for `key`: `size` standard-normal values, each scaled
+// by _std. Returns false, writing nothing, when `noise` is null.
+bool GaussianSampling::resampling(int key, float* noise, int size) {
+  if (noise == nullptr) {
+    return false;
+  }
+  std::default_random_engine generator(key);
+  std::normal_distribution<float> norm;
+  for (int i = 0; i < size; ++i) {
+    noise[i] = norm(generator) * _std;
+  }
+  return true;
+}
+
+}
diff --git a/deepes/src/proto/deepes.proto b/deepes/src/proto/deepes.proto
new file mode 100644
index 0000000000000000000000000000000000000000..26a97a52f8001c268f00fa477159675b66668a50
--- /dev/null
+++ b/deepes/src/proto/deepes.proto
@@ -0,0 +1,40 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package DeepES; + +message DeepESConfig { + //sampling configuration + optional int32 seed = 1 [default = 18]; + optional int32 buffer_size = 2 [default = 100000]; + optional GaussianSamplingConfig gaussian_sampling = 3; + // Optimizer Configuration + optional OptimizerConfig optimizer = 4; +} + +message GaussianSamplingConfig { + optional float std = 1 [default = 1.0]; +} + +message OptimizerConfig{ + optional string type = 1 [default = "SGD"]; + optional float base_lr = 2; // The base learning rate + optional float momentum = 3; // The momentum value. +} + +message SamplingKey{ + repeated int32 key = 1; +} diff --git a/deepes/src/utils.cpp b/deepes/src/utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a2e8c37fd2967bd378230858e941d1b8a684b89 --- /dev/null +++ b/deepes/src/utils.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "utils.h"
+
+namespace DeepES {
+
+// Replaces each reward with its rank among all rewards, normalized so the
+// smallest maps to -0.5 and the largest to 0.5. Equal rewards are ordered by
+// their original position. An empty vector is left untouched, and a single
+// reward maps to 0: with nothing to compare against it carries no signal.
+void compute_centered_ranks(std::vector<float> &reward) {
+  const size_t count = reward.size();
+  if (count == 0) {
+    return;
+  }
+  if (count == 1) {
+    // Avoids the division by (count - 1) == 0 below.
+    reward[0] = 0.0f;
+    return;
+  }
+
+  std::vector<std::pair<float, int>> reward_index;
+  reward_index.reserve(count);
+  for (size_t i = 0; i < count; ++i) {
+    reward_index.push_back(std::make_pair(reward[i], static_cast<int>(i)));
+  }
+  std::sort(reward_index.begin(), reward_index.end());
+
+  // Each rank is computed directly rather than by accumulating a step, so
+  // rounding error does not build up across the vector.
+  const double gap = 1.0 / static_cast<double>(count - 1);
+  for (size_t rank = 0; rank < count; ++rank) {
+    const int id = reward_index[rank].second;
+    reward[id] = static_cast<float>(-0.5 + gap * static_cast<double>(rank));
+  }
+}
+
+}//namespace