未验证 提交 2d21e5d6 编写于 作者: B Bo Zhou 提交者: GitHub

Evo kit (#244)

* rename namespace

* update torch code

* rename deepes to evokit

* update readme

* Update CMakeLists.txt

* update build.sh
上级 6d23261a
...@@ -134,17 +134,17 @@ EOF ...@@ -134,17 +134,17 @@ EOF
rm -rf ${REPO_ROOT}/build rm -rf ${REPO_ROOT}/build
} }
function run_deepes_test { function run_evo_kit_test {
cd ${REPO_ROOT}/deepes cd ${REPO_ROOT}/evo_kit
cat <<EOF cat <<EOF
======================================== ========================================
Running DeepES test... Running evo_kit test...
======================================== ========================================
EOF EOF
sh test/run_test.sh sh test/run_test.sh
rm -rf ${REPO_ROOT}/deepes/build rm -rf ${REPO_ROOT}/evo_kit/build
rm -rf ${REPO_ROOT}/deepes/libtorch rm -rf ${REPO_ROOT}/evo_kit/libtorch
} }
function main() { function main() {
...@@ -189,7 +189,7 @@ function main() { ...@@ -189,7 +189,7 @@ function main() {
/root/miniconda3/envs/empty_env/bin/pip install . /root/miniconda3/envs/empty_env/bin/pip install .
run_import_test run_import_test
run_docs_test run_docs_test
run_deepes_test run_evo_kit_test
;; ;;
*) *)
print_usage print_usage
......
cmake_minimum_required (VERSION 2.6) cmake_minimum_required (VERSION 2.6)
project (DeepES) project (EvoKit)
set(TARGET parallel_main) set(TARGET parallel_main)
########## options ########## ########## options ##########
option(WITH_PADDLE "Compile DeepES with PaddleLite framework." OFF) option(WITH_PADDLE "Compile EvoKit with PaddleLite framework." OFF)
option(WITH_TORCH "Compile DeepES with Torch framework." OFF) option(WITH_TORCH "Compile EvoKit with Torch framework." OFF)
message("WITH_PADDLE: "${WITH_PADDLE}) message("WITH_PADDLE: "${WITH_PADDLE})
message("WITH_TORCH: "${WITH_TORCH}) message("WITH_TORCH: "${WITH_TORCH})
if (NOT (WITH_PADDLE OR WITH_TORCH)) if (NOT (WITH_PADDLE OR WITH_TORCH))
message("ERROR: You should choose at least one framework to compile DeepES.") message("ERROR: You should choose at least one framework to compile EvoKit.")
return() return()
elseif(WITH_PADDLE AND WITH_TORCH) elseif(WITH_PADDLE AND WITH_TORCH)
message("ERROR: You cannot choose more than one framework to compile DeepES.") message("ERROR: You cannot choose more than one framework to compile EvoKit.")
return() return()
endif() endif()
...@@ -29,8 +29,9 @@ if (OPENMP_FOUND) ...@@ -29,8 +29,9 @@ if (OPENMP_FOUND)
endif() endif()
file(GLOB src "core/src/*.cc") file(GLOB src "core/src/*.cc" "core/proto/evo_kit/*.cc")
include_directories("core/include") include_directories("core/include")
include_directories("core/proto")
include_directories("benchmark") include_directories("benchmark")
########## PaddleLite config ########## ########## PaddleLite config ##########
...@@ -57,7 +58,7 @@ elseif (WITH_TORCH) ...@@ -57,7 +58,7 @@ elseif (WITH_TORCH)
file(GLOB framework_src "torch/src/*.cc") file(GLOB framework_src "torch/src/*.cc")
set(demo "${PROJECT_SOURCE_DIR}/demo/torch/cartpole_solver_parallel.cc") set(demo "${PROJECT_SOURCE_DIR}/demo/torch/cartpole_solver_parallel.cc")
else () else ()
message("ERROR: You should choose at least one framework to compile DeepES.") message("ERROR: You should choose at least one framework to compile EvoKit.")
endif() endif()
add_executable(${TARGET} ${demo} ${src} ${framework_src}) add_executable(${TARGET} ${demo} ${src} ${framework_src})
......
# DeepES工具 # EvoKit
DeepES是一个支持**快速验证**ES效果、**兼容多个框架**的C++库 EvoKit 是一个集合了多种进化算法、兼容多种类预测框架的进化算法库,主打快速上线验证
<p align="center"> <p align="center">
<img src="DeepES.gif" alt="PARL" width="500"/> <img src="DeepES.gif" alt="PARL" width="500"/>
</p> </p>
......
...@@ -12,14 +12,14 @@ ...@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef ADAM_OPTIMIZER_H #ifndef EVO_KIT_ADAM_OPTIMIZER_H
#define ADAM_OPTIMIZER_H #define EVO_KIT_ADAM_OPTIMIZER_H
#include <map>
#include <cmath> #include <cmath>
#include "optimizer.h" #include <unordered_map>
#include "evo_kit/optimizer.h"
namespace deep_es { namespace evo_kit {
/*@brief AdamOptimizer. /*@brief AdamOptimizer.
* Implements Adam algorithm. * Implements Adam algorithm.
...@@ -44,8 +44,8 @@ private: ...@@ -44,8 +44,8 @@ private:
float _beta1; float _beta1;
float _beta2; float _beta2;
float _epsilon; float _epsilon;
std::map<std::string, float*> _momentum; std::unordered_map<std::string, float*> _momentum;
std::map<std::string, float*> _velocity; std::unordered_map<std::string, float*> _velocity;
}; };
}//namespace }//namespace
......
...@@ -12,18 +12,18 @@ ...@@ -12,18 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// //
#ifndef CACHED_GAUSSIAN_SAMPLING_H #ifndef EVO_KIT_CACHED_GAUSSIAN_SAMPLING_H
#define CACHED_GAUSSIAN_SAMPLING_H #define EVO_KIT_CACHED_GAUSSIAN_SAMPLING_H
#include <glog/logging.h>
#include <random> #include <random>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <time.h> #include <time.h>
#include "sampling_method.h" #include "sampling_method.h"
#include "utils.h" #include "utils.h"
#include <glog/logging.h>
namespace deep_es { namespace evo_kit {
class CachedGaussianSampling: public SamplingMethod { class CachedGaussianSampling: public SamplingMethod {
...@@ -33,12 +33,12 @@ public: ...@@ -33,12 +33,12 @@ public:
~CachedGaussianSampling(); ~CachedGaussianSampling();
/*Initialize the sampling algorithm given the config with the protobuf format. /*Initialize the sampling algorithm given the config with the protobuf format.
*DeepES library uses only one configuration file for all sampling algorithms. *EvoKit library uses only one configuration file for all sampling algorithms.
A defalut configuration file can be found at: . // TODO: where? A defalut configuration file can be found at: . // TODO: where?
Usally you won't have to modify the configuration items of other algorithms Usally you won't have to modify the configuration items of other algorithms
if you are not using them. if you are not using them.
*/ */
bool load_config(const DeepESConfig& config); bool load_config(const EvoKitConfig& config);
/*@brief generate Gaussian noise and the related key. /*@brief generate Gaussian noise and the related key.
* *
......
...@@ -12,17 +12,17 @@ ...@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// //
#ifndef GAUSSIAN_SAMPLING_H #ifndef EVO_KIT_GAUSSIAN_SAMPLING_H
#define GAUSSIAN_SAMPLING_H #define EVO_KIT_GAUSSIAN_SAMPLING_H
#include <random> #include <random>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <time.h> #include <time.h>
#include "sampling_method.h" #include "evo_kit/sampling_method.h"
#include "utils.h" #include "evo_kit/utils.h"
namespace deep_es { namespace evo_kit {
class GaussianSampling: public SamplingMethod { class GaussianSampling: public SamplingMethod {
...@@ -32,12 +32,12 @@ public: ...@@ -32,12 +32,12 @@ public:
~GaussianSampling() {} ~GaussianSampling() {}
/*Initialize the sampling algorithm given the config with the protobuf format. /*Initialize the sampling algorithm given the config with the protobuf format.
*DeepES library uses only one configuration file for all sampling algorithms. *EvoKit library uses only one configuration file for all sampling algorithms.
A defalut configuration file can be found at: . // TODO: where? A defalut configuration file can be found at: . // TODO: where?
Usally you won't have to modify the configuration items of other algorithms Usally you won't have to modify the configuration items of other algorithms
if you are not using them. if you are not using them.
*/ */
bool load_config(const DeepESConfig& config); bool load_config(const EvoKitConfig& config);
/*@brief generate Gaussian noise and the related key. /*@brief generate Gaussian noise and the related key.
* *
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef OPTIMIZER_H #ifndef EVO_KIT_OPTIMIZER_H
#define OPTIMIZER_H #define EVO_KIT_OPTIMIZER_H
#include <map>
#include <glog/logging.h> #include <glog/logging.h>
#include <unordered_map>
namespace deep_es { namespace evo_kit {
/*@brief Optimizer. Base class for optimizers. /*@brief Optimizer. Base class for optimizers.
* *
...@@ -71,7 +71,7 @@ protected: ...@@ -71,7 +71,7 @@ protected:
virtual void compute_step(float* graident, int size, std::string param_name = "") = 0; virtual void compute_step(float* graident, int size, std::string param_name = "") = 0;
float _base_lr; float _base_lr;
float _update_times; float _update_times;
std::map<std::string, int> _params_size; std::unordered_map<std::string, int> _params_size;
}; };
......
...@@ -12,18 +12,18 @@ ...@@ -12,18 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef OPTIMIZER_FACTORY_H #ifndef EVO_KIT_OPTIMIZER_FACTORY_H
#define OPTIMIZER_FACTORY_H #define EVO_KIT_OPTIMIZER_FACTORY_H
#include <algorithm> #include <algorithm>
#include <memory>
#include "optimizer.h"
#include "sgd_optimizer.h"
#include "adam_optimizer.h"
#include "deepes.pb.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <memory>
#include "evo_kit/adam_optimizer.h"
#include "evo_kit/evo_kit.pb.h"
#include "evo_kit/optimizer.h"
#include "evo_kit/sgd_optimizer.h"
namespace deep_es { namespace evo_kit {
/* @brief: create an optimizer according to the configuration" /* @brief: create an optimizer according to the configuration"
* @args: * @args:
* config: configuration for the optimizer * config: configuration for the optimizer
...@@ -31,6 +31,6 @@ namespace deep_es { ...@@ -31,6 +31,6 @@ namespace deep_es {
*/ */
std::shared_ptr<Optimizer> create_optimizer(const OptimizerConfig& optimizer_config); std::shared_ptr<Optimizer> create_optimizer(const OptimizerConfig& optimizer_config);
}//namespace } // namespace
#endif #endif
...@@ -12,25 +12,25 @@ ...@@ -12,25 +12,25 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef SAMPLING_FACTORY_H #ifndef EVO_KIT_SAMPLING_FACTORY_H
#define SAMPLING_FACTORY_H #define EVO_KIT_SAMPLING_FACTORY_H
#include <algorithm> #include <algorithm>
#include <memory>
#include "sampling_method.h"
#include "gaussian_sampling.h"
#include "cached_gaussian_sampling.h"
#include "deepes.pb.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <memory>
#include "evo_kit/cached_gaussian_sampling.h"
#include "evo_kit/evo_kit.pb.h"
#include "evo_kit/gaussian_sampling.h"
#include "evo_kit/sampling_method.h"
namespace deep_es { namespace evo_kit {
/* @brief: create an sampling_method according to the configuration" /* @brief: create an sampling_method according to the configuration"
* @args: * @args:
* config: configuration for the DeepES * config: configuration for the EvoKit
* *
*/ */
std::shared_ptr<SamplingMethod> create_sampling_method(const DeepESConfig& Config); std::shared_ptr<SamplingMethod> create_sampling_method(const EvoKitConfig& Config);
}//namespace } // namespace
#endif #endif
...@@ -12,14 +12,14 @@ ...@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef SAMPLING_METHOD_H #ifndef EVO_KIT_SAMPLING_METHOD_H
#define SAMPLING_METHOD_H #define EVO_KIT_SAMPLING_METHOD_H
#include <string> #include <string>
#include <random> #include <random>
#include "deepes.pb.h" #include "evo_kit/evo_kit.pb.h"
namespace deep_es { namespace evo_kit {
/*Base class for sampling algorithms. All algorithms are required to override the following functions: /*Base class for sampling algorithms. All algorithms are required to override the following functions:
* *
...@@ -39,12 +39,12 @@ public: ...@@ -39,12 +39,12 @@ public:
virtual ~SamplingMethod() {} virtual ~SamplingMethod() {}
/*Initialize the sampling algorithm given the config with the protobuf format. /*Initialize the sampling algorithm given the config with the protobuf format.
*DeepES library uses only one configuration file for all sampling algorithms. *EvoKit library uses only one configuration file for all sampling algorithms.
A defalut configuration file can be found at: . // TODO: where? A defalut configuration file can be found at: . // TODO: where?
Usally you won't have to modify the configuration items of other algorithms Usally you won't have to modify the configuration items of other algorithms
if you are not using them. if you are not using them.
*/ */
virtual bool load_config(const DeepESConfig& config) = 0; virtual bool load_config(const EvoKitConfig& config) = 0;
/*@brief generate Gaussian noise and the related key. /*@brief generate Gaussian noise and the related key.
* *
......
...@@ -12,14 +12,14 @@ ...@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef SGD_OPTIMIZER_H #ifndef EVO_KIT_SGD_OPTIMIZER_H
#define SGD_OPTIMIZER_H #define EVO_KIT_SGD_OPTIMIZER_H
#include <map>
#include <cmath> #include <cmath>
#include "optimizer.h" #include <unordered_map>
#include "evo_kit/optimizer.h"
namespace deep_es { namespace evo_kit {
/*@brief SGDOptimizer. /*@brief SGDOptimizer.
* Implements stochastic gradient descent (optionally with momentum). * Implements stochastic gradient descent (optionally with momentum).
...@@ -38,9 +38,9 @@ protected: ...@@ -38,9 +38,9 @@ protected:
private: private:
float _momentum; float _momentum;
std::map<std::string, float*> _velocity; std::unordered_map<std::string, float*> _velocity;
}; };
} } // namespace
#endif #endif
...@@ -12,18 +12,17 @@ ...@@ -12,18 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef UTILS_H #ifndef EVO_KIT_UTILS_H
#define UTILS_H #define EVO_KIT_UTILS_H
#include <string>
#include <fstream>
#include <algorithm> #include <algorithm>
#include <fstream>
#include <glog/logging.h> #include <glog/logging.h>
#include "deepes.pb.h"
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include <fstream> #include <string>
#include "evo_kit/evo_kit.pb.h"
namespace deep_es { namespace evo_kit {
/*Return ranks that is normliazed to [-0.5, 0.5] with the rewards as input. /*Return ranks that is normliazed to [-0.5, 0.5] with the rewards as input.
Args: Args:
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
syntax = "proto2"; syntax = "proto2";
package deep_es; package evo_kit;
message DeepESConfig { message EvoKitConfig {
//sampling configuration //sampling configuration
optional int32 seed = 1 [default = 18]; optional int32 seed = 1 [default = 18];
optional int32 buffer_size = 2 [default = 100000]; optional int32 buffer_size = 2 [default = 100000];
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "adam_optimizer.h" #include "evo_kit/adam_optimizer.h"
namespace deep_es { namespace evo_kit {
AdamOptimizer::~AdamOptimizer() { AdamOptimizer::~AdamOptimizer() {
for (auto iter = _momentum.begin(); iter != _momentum.end(); iter++) { for (auto iter = _momentum.begin(); iter != _momentum.end(); iter++) {
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "cached_gaussian_sampling.h" #include "evo_kit/cached_gaussian_sampling.h"
namespace deep_es { namespace evo_kit {
CachedGaussianSampling::CachedGaussianSampling() {} CachedGaussianSampling::CachedGaussianSampling() {}
...@@ -22,16 +22,16 @@ CachedGaussianSampling::~CachedGaussianSampling() { ...@@ -22,16 +22,16 @@ CachedGaussianSampling::~CachedGaussianSampling() {
delete[] _noise_cache; delete[] _noise_cache;
} }
bool CachedGaussianSampling::load_config(const DeepESConfig& config) { bool CachedGaussianSampling::load_config(const EvoKitConfig& config) {
bool success = true; bool success = true;
_std = config.gaussian_sampling().std(); _std = config.gaussian_sampling().std();
success = set_seed(config.seed()); success = set_seed(config.seed());
CHECK(success) << "[DeepES] Fail to set seed while load config."; CHECK(success) << "[EvoKit] Fail to set seed while load config.";
_cache_size = config.gaussian_sampling().cache_size(); _cache_size = config.gaussian_sampling().cache_size();
_noise_cache = new float [_cache_size]; _noise_cache = new float [_cache_size];
memset(_noise_cache, 0, _cache_size * sizeof(float)); memset(_noise_cache, 0, _cache_size * sizeof(float));
success = _create_noise_cache(); success = _create_noise_cache();
CHECK(success) << "[DeepES] Fail to create noise_cache while load config."; CHECK(success) << "[EvoKit] Fail to create noise_cache while load config.";
return success; return success;
} }
...@@ -39,19 +39,19 @@ bool CachedGaussianSampling::sampling(int* key, float* noise, int64_t size) { ...@@ -39,19 +39,19 @@ bool CachedGaussianSampling::sampling(int* key, float* noise, int64_t size) {
bool success = true; bool success = true;
if (_noise_cache == nullptr) { if (_noise_cache == nullptr) {
LOG(ERROR) << "[DeepES] Please use load_config() first."; LOG(ERROR) << "[EvoKit] Please use load_config() first.";
success = false; success = false;
return success; return success;
} }
if (noise == nullptr) { if (noise == nullptr) {
LOG(ERROR) << "[DeepES] Input noise array cannot be nullptr."; LOG(ERROR) << "[EvoKit] Input noise array cannot be nullptr.";
success = false; success = false;
return success; return success;
} }
if ((size >= _cache_size) || (size < 0)) { if ((size >= _cache_size) || (size < 0)) {
LOG(ERROR) << "[DeepES] Input size " << size << " is out of bounds [0, " << _cache_size << LOG(ERROR) << "[EvoKit] Input size " << size << " is out of bounds [0, " << _cache_size <<
"), cache_size: " << _cache_size; "), cache_size: " << _cache_size;
success = false; success = false;
return success; return success;
...@@ -74,26 +74,27 @@ bool CachedGaussianSampling::resampling(int key, float* noise, int64_t size) { ...@@ -74,26 +74,27 @@ bool CachedGaussianSampling::resampling(int key, float* noise, int64_t size) {
bool success = true; bool success = true;
if (_noise_cache == nullptr) { if (_noise_cache == nullptr) {
LOG(ERROR) << "[DeepES] Please use load_config() first."; LOG(ERROR) << "[EvoKit] Please use load_config() first.";
success = false; success = false;
return success; return success;
} }
if (noise == nullptr) { if (noise == nullptr) {
LOG(ERROR) << "[DeepES] Input noise array cannot be nullptr."; LOG(ERROR) << "[EvoKit] Input noise array cannot be nullptr.";
success = false; success = false;
return success; return success;
} }
if ((size >= _cache_size) || (size < 0)) { if ((size >= _cache_size) || (size < 0)) {
LOG(ERROR) << "[DeepES] Input size " << size << " is out of bounds [0, " << _cache_size << LOG(ERROR) << "[EvoKit] Input size " << size << " is out of bounds [0, " << _cache_size <<
"), cache_size: " << _cache_size; "), cache_size: " << _cache_size;
success = false; success = false;
return success; return success;
} }
if ((key > _cache_size - size) || (key < 0)) { if ((key > _cache_size - size) || (key < 0)) {
LOG(ERROR) << "[DeepES] Resampling key " << key << " is out of bounds [0, " << _cache_size - size << LOG(ERROR) << "[EvoKit] Resampling key " << key << " is out of bounds [0, "
<< _cache_size - size <<
"], cache_size: " << _cache_size << ", size: " << size; "], cache_size: " << _cache_size << ", size: " << size;
success = false; success = false;
return success; return success;
......
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gaussian_sampling.h" #include "evo_kit/gaussian_sampling.h"
namespace deep_es { namespace evo_kit {
bool GaussianSampling::load_config(const DeepESConfig& config) { bool GaussianSampling::load_config(const EvoKitConfig& config) {
bool success = true; bool success = true;
_std = config.gaussian_sampling().std(); _std = config.gaussian_sampling().std();
success = set_seed(config.seed()); success = set_seed(config.seed());
...@@ -27,7 +27,7 @@ bool GaussianSampling::sampling(int* key, float* noise, int64_t size) { ...@@ -27,7 +27,7 @@ bool GaussianSampling::sampling(int* key, float* noise, int64_t size) {
bool success = true; bool success = true;
if (noise == nullptr) { if (noise == nullptr) {
LOG(ERROR) << "[DeepES] Input noise array cannot be nullptr."; LOG(ERROR) << "[EvoKit] Input noise array cannot be nullptr.";
success = false; success = false;
return success; return success;
} }
...@@ -48,7 +48,7 @@ bool GaussianSampling::resampling(int key, float* noise, int64_t size) { ...@@ -48,7 +48,7 @@ bool GaussianSampling::resampling(int key, float* noise, int64_t size) {
bool success = true; bool success = true;
if (noise == nullptr) { if (noise == nullptr) {
LOG(ERROR) << "[DeepES] Input noise array cannot be nullptr."; LOG(ERROR) << "[EvoKit] Input noise array cannot be nullptr.";
success = false; success = false;
} else { } else {
std::default_random_engine generator(key); std::default_random_engine generator(key);
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "optimizer_factory.h" #include "evo_kit/optimizer_factory.h"
namespace deep_es { namespace evo_kit {
std::shared_ptr<Optimizer> create_optimizer(const OptimizerConfig& optimizer_config) { std::shared_ptr<Optimizer> create_optimizer(const OptimizerConfig& optimizer_config) {
std::shared_ptr<Optimizer> optimizer; std::shared_ptr<Optimizer> optimizer;
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "sampling_factory.h" #include "evo_kit/sampling_factory.h"
namespace deep_es { namespace evo_kit {
std::shared_ptr<SamplingMethod> create_sampling_method(const DeepESConfig& config) { std::shared_ptr<SamplingMethod> create_sampling_method(const EvoKitConfig& config) {
std::shared_ptr<SamplingMethod> sampling_method; std::shared_ptr<SamplingMethod> sampling_method;
bool cached = config.gaussian_sampling().cached(); bool cached = config.gaussian_sampling().cached();
...@@ -32,7 +32,7 @@ std::shared_ptr<SamplingMethod> create_sampling_method(const DeepESConfig& confi ...@@ -32,7 +32,7 @@ std::shared_ptr<SamplingMethod> create_sampling_method(const DeepESConfig& confi
if (success) { if (success) {
return sampling_method; return sampling_method;
} else { } else {
LOG(ERROR) << "[DeepES] Fail to create sampling_method"; LOG(ERROR) << "[EvoKit] Fail to create sampling_method";
return nullptr; return nullptr;
} }
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "sgd_optimizer.h" #include "evo_kit/sgd_optimizer.h"
namespace deep_es { namespace evo_kit {
SGDOptimizer::~SGDOptimizer() { SGDOptimizer::~SGDOptimizer() {
for (auto iter = _velocity.begin(); iter != _velocity.end(); iter++) { for (auto iter = _velocity.begin(); iter != _velocity.end(); iter++) {
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "utils.h" #include "evo_kit/utils.h"
#include <dirent.h> #include <dirent.h>
namespace deep_es { namespace evo_kit {
bool compute_centered_ranks(std::vector<float>& reward) { bool compute_centered_ranks(std::vector<float>& reward) {
std::vector<std::pair<float, int>> reward_index; std::vector<std::pair<float, int>> reward_index;
......
seed: 1024 seed: 1024
gaussian_sampling { gaussian_sampling {
std: 0.5 std: 0.5
cached: true cached: true
cache_size : 100000 cache_size: 100000
} }
optimizer { optimizer {
type: "Adam" type: "Adam"
base_lr: 0.05 base_lr: 0.05
...@@ -14,7 +12,6 @@ optimizer { ...@@ -14,7 +12,6 @@ optimizer {
beta2: 0.999 beta2: 0.999
epsilon: 1e-08 epsilon: 1e-08
} }
async_es { async_es {
model_iter_id: 0 model_iter_id: 0
} }
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
#include <algorithm> #include <algorithm>
#include <glog/logging.h> #include <glog/logging.h>
#include <omp.h> #include <omp.h>
#include "evo_kit/async_es_agent.h"
#include "cartpole.h" #include "cartpole.h"
#include "async_es_agent.h"
#include "paddle_api.h" #include "paddle_api.h"
using namespace deep_es; using namespace evo_kit;
using namespace paddle::lite_api; using namespace paddle::lite_api;
const int ITER = 10; const int ITER = 10;
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <omp.h> #include <omp.h>
#include "cartpole.h" #include "cartpole.h"
#include "es_agent.h" #include "evo_kit/es_agent.h"
#include "paddle_api.h" #include "paddle_api.h"
using namespace deep_es; using namespace evo_kit;
using namespace paddle::lite_api; using namespace paddle::lite_api;
const int ITER = 10; const int ITER = 10;
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
#include <algorithm> #include <algorithm>
#include <glog/logging.h> #include <glog/logging.h>
#include <omp.h> #include <omp.h>
#include "evo_kit/gaussian_sampling.h"
#include "evo_kit/es_agent.h"
#include "cartpole.h" #include "cartpole.h"
#include "gaussian_sampling.h"
#include "model.h" #include "model.h"
#include "es_agent.h"
using namespace DeepES; using namespace evo_kit;
const int ITER = 10; const int ITER = 10;
float evaluate(CartPole& env, std::shared_ptr<ESAgent<Model>> agent) { float evaluate(CartPole& env, std::shared_ptr<ESAgent<Model>> agent) {
...@@ -52,7 +52,7 @@ int main(int argc, char* argv[]) { ...@@ -52,7 +52,7 @@ int main(int argc, char* argv[]) {
auto model = std::make_shared<Model>(4, 2); auto model = std::make_shared<Model>(4, 2);
std::shared_ptr<ESAgent<Model>> agent = std::make_shared<ESAgent<Model>>(model, std::shared_ptr<ESAgent<Model>> agent = std::make_shared<ESAgent<Model>>(model,
"../demo/cartpole_config.prototxt"); "./demo/cartpole_config.prototxt");
// Clone agents to sample (explore). // Clone agents to sample (explore).
std::vector<std::shared_ptr<ESAgent<Model>>> sampling_agents; std::vector<std::shared_ptr<ESAgent<Model>>> sampling_agents;
......
...@@ -12,15 +12,15 @@ ...@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef ASYNC_ES_AGENT_H #ifndef EVO_KIT_ASYNC_ES_AGENT_H
#define ASYNC_ES_AGENT_H #define EVO_KIT_ASYNC_ES_AGENT_H
#include "es_agent.h"
#include <map>
#include <stdlib.h> #include <stdlib.h>
#include <unordered_map>
#include "evo_kit/es_agent.h"
namespace deep_es { namespace evo_kit {
/* DeepES agent with PaddleLite as backend. This agent supports asynchronous update. /* EvoKit agent with PaddleLite as backend. This agent supports asynchronous update.
* Users mainly focus on the following functions: * Users mainly focus on the following functions:
* 1. clone: clone an agent for multi-thread evaluation * 1. clone: clone an agent for multi-thread evaluation
* 2. add_noise: add noise into parameters. * 2. add_noise: add noise into parameters.
...@@ -59,8 +59,8 @@ public: ...@@ -59,8 +59,8 @@ public:
std::vector<float>& noisy_rewards); std::vector<float>& noisy_rewards);
private: private:
std::map<int, std::shared_ptr<PaddlePredictor>> _previous_predictors; std::unordered_map<int, std::shared_ptr<PaddlePredictor>> _previous_predictors;
std::map<int, float*> _param_delta; std::unordered_map<int, float*> _param_delta;
std::string _config_path; std::string _config_path;
/** /**
...@@ -97,5 +97,5 @@ private: ...@@ -97,5 +97,5 @@ private:
std::shared_ptr<PaddlePredictor> _load_previous_model(std::string model_dir); std::shared_ptr<PaddlePredictor> _load_previous_model(std::string model_dir);
}; };
} //namespace } // namespace
#endif #endif
...@@ -12,17 +12,17 @@ ...@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef DEEPES_PADDLE_ES_AGENT_H_ #ifndef EVO_KIT_DEEPES_PADDLE_ES_AGENT_H_
#define DEEPES_PADDLE_ES_AGENT_H_ #define EVO_KIT_DEEPES_PADDLE_ES_AGENT_H_
#include "paddle_api.h"
#include "optimizer_factory.h"
#include "sampling_factory.h"
#include "utils.h"
#include "deepes.pb.h"
#include <vector> #include <vector>
#include "evo_kit/evo_kit.pb.h"
#include "evo_kit/optimizer_factory.h"
#include "evo_kit/sampling_factory.h"
#include "evo_kit/utils.h"
#include "paddle_api.h"
namespace deep_es { namespace evo_kit {
typedef paddle::lite_api::PaddlePredictor PaddlePredictor; typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
typedef paddle::lite_api::CxxConfig CxxConfig; typedef paddle::lite_api::CxxConfig CxxConfig;
...@@ -31,7 +31,7 @@ typedef paddle::lite_api::Tensor Tensor; ...@@ -31,7 +31,7 @@ typedef paddle::lite_api::Tensor Tensor;
int64_t ShapeProduction(const paddle::lite_api::shape_t& shape); int64_t ShapeProduction(const paddle::lite_api::shape_t& shape);
/** /**
* @brief DeepES agent with PaddleLite as backend. * @brief EvoKit agent with PaddleLite as backend.
* Users mainly focus on the following functions: * Users mainly focus on the following functions:
* 1. clone: clone an agent for multi-thread evaluation * 1. clone: clone an agent for multi-thread evaluation
* 2. add_noise: add noise into parameters. * 2. add_noise: add noise into parameters.
...@@ -88,7 +88,7 @@ protected: ...@@ -88,7 +88,7 @@ protected:
std::shared_ptr<PaddlePredictor> _sampling_predictor; std::shared_ptr<PaddlePredictor> _sampling_predictor;
std::shared_ptr<SamplingMethod> _sampling_method; std::shared_ptr<SamplingMethod> _sampling_method;
std::shared_ptr<Optimizer> _optimizer; std::shared_ptr<Optimizer> _optimizer;
std::shared_ptr<DeepESConfig> _config; std::shared_ptr<EvoKitConfig> _config;
std::shared_ptr<CxxConfig> _cxx_config; std::shared_ptr<CxxConfig> _cxx_config;
std::vector<std::string> _param_names; std::vector<std::string> _param_names;
// malloc memory of noise and neg_gradients in advance. // malloc memory of noise and neg_gradients in advance.
...@@ -98,6 +98,6 @@ protected: ...@@ -98,6 +98,6 @@ protected:
bool _is_sampling_agent; bool _is_sampling_agent;
}; };
} } // namespace
#endif /* DEEPES_PADDLE_ES_AGENT_H_ */ #endif
...@@ -12,8 +12,9 @@ ...@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "async_es_agent.h" #include "evo_kit/async_es_agent.h"
namespace deep_es {
namespace evo_kit {
AsyncESAgent::AsyncESAgent( AsyncESAgent::AsyncESAgent(
const std::string& model_dir, const std::string& model_dir,
...@@ -32,7 +33,8 @@ bool AsyncESAgent::_save() { ...@@ -32,7 +33,8 @@ bool AsyncESAgent::_save() {
bool success = true; bool success = true;
if (_is_sampling_agent) { if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Cloned AsyncESAgent cannot call `save`.Please use original AsyncESAgent."; LOG(ERROR) <<
"[EvoKit] Cloned AsyncESAgent cannot call `save`.Please use original AsyncESAgent.";
success = false; success = false;
return success; return success;
} }
...@@ -80,9 +82,9 @@ bool AsyncESAgent::_remove_expired_model(int max_to_keep) { ...@@ -80,9 +82,9 @@ bool AsyncESAgent::_remove_expired_model(int max_to_keep) {
int ret = system(rm_command.c_str()); int ret = system(rm_command.c_str());
if (ret == 0) { if (ret == 0) {
LOG(INFO) << "[DeepES] remove expired Model: " << dir; LOG(INFO) << "[EvoKit] remove expired Model: " << dir;
} else { } else {
LOG(ERROR) << "[DeepES] fail to remove expired Model: " << dir; LOG(ERROR) << "[EvoKit] fail to remove expired Model: " << dir;
success = false; success = false;
return success; return success;
} }
...@@ -132,7 +134,7 @@ bool AsyncESAgent::_load() { ...@@ -132,7 +134,7 @@ bool AsyncESAgent::_load() {
success = model_iter_id == 0 ? true : false; success = model_iter_id == 0 ? true : false;
if (!success) { if (!success) {
LOG(WARNING) << "[DeepES] current_model_iter_id is nonzero, but no model is \ LOG(WARNING) << "[EvoKit] current_model_iter_id is nonzero, but no model is \
found at the dir: " << model_path; found at the dir: " << model_path;
} }
...@@ -143,7 +145,7 @@ bool AsyncESAgent::_load() { ...@@ -143,7 +145,7 @@ bool AsyncESAgent::_load() {
int model_iter_id = _parse_model_iter_id(dir); int model_iter_id = _parse_model_iter_id(dir);
if (model_iter_id == -1) { if (model_iter_id == -1) {
LOG(WARNING) << "[DeepES] fail to parse model_iter_id: " << dir; LOG(WARNING) << "[EvoKit] fail to parse model_iter_id: " << dir;
success = false; success = false;
return success; return success;
} }
...@@ -152,7 +154,7 @@ bool AsyncESAgent::_load() { ...@@ -152,7 +154,7 @@ bool AsyncESAgent::_load() {
if (predictor == nullptr) { if (predictor == nullptr) {
success = false; success = false;
LOG(WARNING) << "[DeepES] fail to load model: " << dir; LOG(WARNING) << "[EvoKit] fail to load model: " << dir;
return success; return success;
} }
...@@ -201,11 +203,11 @@ bool AsyncESAgent::update( ...@@ -201,11 +203,11 @@ bool AsyncESAgent::update(
std::vector<SamplingInfo>& noisy_info, std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards) { std::vector<float>& noisy_rewards) {
CHECK(!_is_sampling_agent) << "[DeepES] Cloned ESAgent cannot call update function. \ CHECK(!_is_sampling_agent) << "[EvoKit] Cloned ESAgent cannot call update function. \
Please use original ESAgent."; Please use original ESAgent.";
bool success = _load(); bool success = _load();
CHECK(success) << "[DeepES] fail to load previous models."; CHECK(success) << "[EvoKit] fail to load previous models.";
int current_model_iter_id = _config->async_es().model_iter_id(); int current_model_iter_id = _config->async_es().model_iter_id();
...@@ -215,7 +217,7 @@ bool AsyncESAgent::update( ...@@ -215,7 +217,7 @@ bool AsyncESAgent::update(
if (model_iter_id != current_model_iter_id if (model_iter_id != current_model_iter_id
&& _previous_predictors.count(model_iter_id) == 0) { && _previous_predictors.count(model_iter_id) == 0) {
LOG(WARNING) << "[DeepES] The sample with model_dir_id: " << model_iter_id \ LOG(WARNING) << "[EvoKit] The sample with model_dir_id: " << model_iter_id \
<< " cannot match any local model"; << " cannot match any local model";
success = false; success = false;
return success; return success;
...@@ -230,7 +232,7 @@ bool AsyncESAgent::update( ...@@ -230,7 +232,7 @@ bool AsyncESAgent::update(
float reward = noisy_rewards[i]; float reward = noisy_rewards[i];
int model_iter_id = noisy_info[i].model_iter_id(); int model_iter_id = noisy_info[i].model_iter_id();
bool success = _sampling_method->resampling(key, _noise, _param_size); bool success = _sampling_method->resampling(key, _noise, _param_size);
CHECK(success) << "[DeepES] resampling error occurs at sample: " << i; CHECK(success) << "[EvoKit] resampling error occurs at sample: " << i;
float* delta = _param_delta[model_iter_id]; float* delta = _param_delta[model_iter_id];
// compute neg_gradients // compute neg_gradients
...@@ -261,7 +263,7 @@ bool AsyncESAgent::update( ...@@ -261,7 +263,7 @@ bool AsyncESAgent::update(
} }
success = _save(); success = _save();
CHECK(success) << "[DeepES] fail to save model."; CHECK(success) << "[EvoKit] fail to save model.";
return true; return true;
} }
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "es_agent.h" #include "evo_kit/es_agent.h"
#include <ctime> #include <ctime>
namespace deep_es { namespace evo_kit {
int64_t ShapeProduction(const paddle::lite_api::shape_t& shape) { int64_t ShapeProduction(const paddle::lite_api::shape_t& shape) {
int64_t res = 1; int64_t res = 1;
...@@ -56,7 +56,7 @@ ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) { ...@@ -56,7 +56,7 @@ ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) {
// Original agent can't be used to sample, so keep it same with _predictor for evaluating. // Original agent can't be used to sample, so keep it same with _predictor for evaluating.
_sampling_predictor = _predictor; _sampling_predictor = _predictor;
_config = std::make_shared<DeepESConfig>(); _config = std::make_shared<EvoKitConfig>();
load_proto_conf(config_path, *_config); load_proto_conf(config_path, *_config);
_sampling_method = create_sampling_method(*_config); _sampling_method = create_sampling_method(*_config);
...@@ -72,7 +72,7 @@ ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) { ...@@ -72,7 +72,7 @@ ESAgent::ESAgent(const std::string& model_dir, const std::string& config_path) {
std::shared_ptr<ESAgent> ESAgent::clone() { std::shared_ptr<ESAgent> ESAgent::clone() {
if (_is_sampling_agent) { if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] only original ESAgent can call `clone` function."; LOG(ERROR) << "[EvoKit] only original ESAgent can call `clone` function.";
return nullptr; return nullptr;
} }
...@@ -97,7 +97,7 @@ bool ESAgent::update( ...@@ -97,7 +97,7 @@ bool ESAgent::update(
std::vector<SamplingInfo>& noisy_info, std::vector<SamplingInfo>& noisy_info,
std::vector<float>& noisy_rewards) { std::vector<float>& noisy_rewards) {
if (_is_sampling_agent) { if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent."; LOG(ERROR) << "[EvoKit] Cloned ESAgent cannot call update function, please use original ESAgent.";
return false; return false;
} }
...@@ -109,7 +109,7 @@ bool ESAgent::update( ...@@ -109,7 +109,7 @@ bool ESAgent::update(
int key = noisy_info[i].key(0); int key = noisy_info[i].key(0);
float reward = noisy_rewards[i]; float reward = noisy_rewards[i];
bool success = _sampling_method->resampling(key, _noise, _param_size); bool success = _sampling_method->resampling(key, _noise, _param_size);
CHECK(success) << "[DeepES] resampling error occurs at sample: " << i; CHECK(success) << "[EvoKit] resampling error occurs at sample: " << i;
for (int64_t j = 0; j < _param_size; ++j) { for (int64_t j = 0; j < _param_size; ++j) {
_neg_gradients[j] += _noise[j] * reward; _neg_gradients[j] += _noise[j] * reward;
...@@ -139,14 +139,14 @@ bool ESAgent::add_noise(SamplingInfo& sampling_info) { ...@@ -139,14 +139,14 @@ bool ESAgent::add_noise(SamplingInfo& sampling_info) {
if (!_is_sampling_agent) { if (!_is_sampling_agent) {
LOG(ERROR) << LOG(ERROR) <<
"[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent."; "[EvoKit] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
success = false; success = false;
return success; return success;
} }
int key = 0; int key = 0;
success = _sampling_method->sampling(&key, _noise, _param_size); success = _sampling_method->sampling(&key, _noise, _param_size);
CHECK(success) << "[DeepES] sampling error occurs while add_noise."; CHECK(success) << "[EvoKit] sampling error occurs while add_noise.";
int model_iter_id = _config->async_es().model_iter_id(); int model_iter_id = _config->async_es().model_iter_id();
sampling_info.add_key(key); sampling_info.add_key(key);
sampling_info.set_model_iter_id(model_iter_id); sampling_info.set_model_iter_id(model_iter_id);
......
#!/bin/bash #!/bin/bash
if [ $# != 1 ]; then if [ $# != 1 ]; then
echo "You must choose one framework (paddle/torch) to compile DeepES." echo "You must choose one framework (paddle/torch) to compile EvoKit."
exit 0 exit 0
fi fi
...@@ -36,11 +36,9 @@ else ...@@ -36,11 +36,9 @@ else
fi fi
#----------------protobuf-------------# #----------------protobuf-------------#
cp ./core/proto/deepes.proto ./ cd core/proto/
protoc deepes.proto --cpp_out ./ protoc evo_kit/evo_kit.proto --cpp_out .
mv deepes.pb.h core/include cd -
mv deepes.pb.cc core/src
rm deepes.proto
#----------------build---------------# #----------------build---------------#
echo ${FLAGS} echo ${FLAGS}
......
...@@ -21,11 +21,12 @@ find_package(Torch REQUIRED ON) ...@@ -21,11 +21,12 @@ find_package(Torch REQUIRED ON)
# include and source # include and source
file(GLOB test_src "../test/src/*.cc") file(GLOB test_src "../test/src/*.cc")
file(GLOB core_src "../core/src/*.cc") file(GLOB core_src "../core/src/*.cc" "../core/proto/evo_kit/*.cc")
file(GLOB agent_src "../torch/src/*.cc") file(GLOB agent_src "../torch/src/*.cc")
include_directories("../torch/include") include_directories("../torch/include")
include_directories("../core/include") include_directories("../core/include")
include_directories("../core/proto")
include_directories("../benchmark") include_directories("../benchmark")
include_directories("../test/include") include_directories("../test/include")
......
...@@ -12,10 +12,9 @@ echo "Cannot find the torch library: ../libtorch" ...@@ -12,10 +12,9 @@ echo "Cannot find the torch library: ../libtorch"
fi fi
#----------------protobuf-------------# #----------------protobuf-------------#
cp ./core/proto/deepes.proto ./ cd core/proto/
protoc deepes.proto --cpp_out ./ protoc evo_kit/evo_kit.proto --cpp_out .
mv deepes.pb.h core/include cd -
mv deepes.pb.cc core/src
#----------------build---------------# #----------------build---------------#
......
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include <vector> #include <vector>
#include "optimizer_factory.h" #include "evo_kit/optimizer_factory.h"
#include <memory> #include <memory>
namespace deep_es { namespace evo_kit {
TEST(SGDOptimizersTest, Method_update) { TEST(SGDOptimizersTest, Method_update) {
std::shared_ptr<DeepESConfig> config = std::make_shared<DeepESConfig>(); std::shared_ptr<EvoKitConfig> config = std::make_shared<EvoKitConfig>();
auto optimizer_config = config->mutable_optimizer(); auto optimizer_config = config->mutable_optimizer();
optimizer_config->set_base_lr(1.0); optimizer_config->set_base_lr(1.0);
optimizer_config->set_type("sgd"); optimizer_config->set_type("sgd");
...@@ -39,7 +38,7 @@ TEST(SGDOptimizersTest, Method_update) { ...@@ -39,7 +38,7 @@ TEST(SGDOptimizersTest, Method_update) {
} }
TEST(AdamOptimizersTest, Method_update) { TEST(AdamOptimizersTest, Method_update) {
std::shared_ptr<DeepESConfig> config = std::make_shared<DeepESConfig>(); std::shared_ptr<EvoKitConfig> config = std::make_shared<EvoKitConfig>();
auto optimizer_config = config->mutable_optimizer(); auto optimizer_config = config->mutable_optimizer();
optimizer_config->set_base_lr(1.0); optimizer_config->set_base_lr(1.0);
optimizer_config->set_type("adam"); optimizer_config->set_type("adam");
......
...@@ -14,18 +14,17 @@ ...@@ -14,18 +14,17 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include <vector> #include <vector>
#include "sampling_method.h" #include "evo_kit/sampling_method.h"
#include "gaussian_sampling.h" #include "evo_kit/gaussian_sampling.h"
#include "cached_gaussian_sampling.h" #include "evo_kit/cached_gaussian_sampling.h"
#include <memory> #include <memory>
namespace deep_es { namespace evo_kit {
class SamplingTest : public ::testing::Test { class SamplingTest : public ::testing::Test {
protected: protected:
void init_sampling_method(bool cached) { void init_sampling_method(bool cached) {
config = std::make_shared<DeepESConfig>(); config = std::make_shared<EvoKitConfig>();
config->set_seed(1024); config->set_seed(1024);
auto sampling_config = config->mutable_gaussian_sampling(); auto sampling_config = config->mutable_gaussian_sampling();
sampling_config->set_std(1.0); sampling_config->set_std(1.0);
...@@ -39,7 +38,7 @@ class SamplingTest : public ::testing::Test { ...@@ -39,7 +38,7 @@ class SamplingTest : public ::testing::Test {
} }
std::shared_ptr<SamplingMethod> sampler; std::shared_ptr<SamplingMethod> sampler;
std::shared_ptr<DeepESConfig> config; std::shared_ptr<EvoKitConfig> config;
float array[3] = {1.0, 2.0, 3.0}; float array[3] = {1.0, 2.0, 3.0};
int cache_size = 100; // default cache_size 100 int cache_size = 100; // default cache_size 100
int key = 0; int key = 0;
......
...@@ -17,16 +17,16 @@ ...@@ -17,16 +17,16 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <omp.h> #include <omp.h>
#include "gaussian_sampling.h" #include "evo_kit/gaussian_sampling.h"
#include "evo_kit/es_agent.h"
#include "torch_demo_model.h" #include "torch_demo_model.h"
#include "es_agent.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <random> #include <random>
#include <math.h> #include <math.h>
namespace deep_es { namespace evo_kit {
// The fixture for testing class Foo. // The fixture for testing class Foo.
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include <vector> #include <vector>
#include "utils.h" #include "evo_kit/utils.h"
namespace deep_es { namespace evo_kit {
// Tests that the Utils::compute_centered_rank() method. // Tests that the Utils::compute_centered_rank() method.
TEST(UtilsTest, Method_compute_centered_ranks) { TEST(UtilsTest, Method_compute_centered_ranks) {
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "optimizer_factory.h" #include "evo_kit/optimizer_factory.h"
#include "sampling_factory.h" #include "evo_kit/sampling_factory.h"
#include "utils.h" #include "evo_kit/utils.h"
#include "deepes.pb.h" #include "evo_kit/evo_kit.pb.h"
namespace deep_es{ namespace evo_kit{
/** /**
* @brief DeepES agent for Torch. * @brief DeepES agent for Torch.
...@@ -45,7 +45,7 @@ public: ...@@ -45,7 +45,7 @@ public:
ESAgent(std::shared_ptr<T> model, std::string config_path): _model(model) { ESAgent(std::shared_ptr<T> model, std::string config_path): _model(model) {
_is_sampling_agent = false; _is_sampling_agent = false;
_config = std::make_shared<DeepESConfig>(); _config = std::make_shared<EvoKitConfig>();
load_proto_conf(config_path, *_config); load_proto_conf(config_path, *_config);
_sampling_method = create_sampling_method(*_config); _sampling_method = create_sampling_method(*_config);
_optimizer = create_optimizer(_config->optimizer()); _optimizer = create_optimizer(_config->optimizer());
...@@ -145,7 +145,7 @@ public: ...@@ -145,7 +145,7 @@ public:
auto params = _model->named_parameters(); auto params = _model->named_parameters();
int key = 0; int key = 0;
success = _sampling_method->sampling(&key, _noise, _param_size); success = _sampling_method->sampling(&key, _noise, _param_size);
CHECK(success) << "[DeepES] sampling error occurs while add_noise."; CHECK(success) << "[EvoKit] sampling error occurs while add_noise.";
sampling_info.add_key(key); sampling_info.add_key(key);
int64_t counter = 0; int64_t counter = 0;
for (auto& param: sampling_params) { for (auto& param: sampling_params) {
...@@ -184,7 +184,7 @@ private: ...@@ -184,7 +184,7 @@ private:
bool _is_sampling_agent; bool _is_sampling_agent;
std::shared_ptr<SamplingMethod> _sampling_method; std::shared_ptr<SamplingMethod> _sampling_method;
std::shared_ptr<Optimizer> _optimizer; std::shared_ptr<Optimizer> _optimizer;
std::shared_ptr<DeepESConfig> _config; std::shared_ptr<EvoKitConfig> _config;
int64_t _param_size; int64_t _param_size;
// malloc memory of noise and neg_gradients in advance. // malloc memory of noise and neg_gradients in advance.
float* _noise; float* _noise;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册