提交 2ad117a9 编写于 作者: Z zenghsh3

refine comments

上级 f2e6ef0e
......@@ -25,15 +25,18 @@
namespace DeepES {
/* DeepES agent for PaddleLite.
* Users can use `add_noise` function to add noise to parameters and use `get_sample_predictor`
* function to get a predictor with added noise to explore.
* Then can use `update` function to update parameters based on ES algorithm.
* Users also can `clone` multi agents to sample in multi-thread way.
*/
typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
/**
* @brief DeepES agent for PaddleLite.
*
* Users use `clone` fucntion to clone a sampling agent, which can call `add_noise`
* function to add noise to copied parameters and call `get_predictor` fucntion to
* get a paddle predictor with added noise.
*
* Then can use `update` function to update parameters based on ES algorithm.
* Note: parameters of cloned agents will also be updated.
*/
class ESAgent {
public:
ESAgent();
......@@ -44,24 +47,34 @@ class ESAgent {
std::shared_ptr<PaddlePredictor> predictor,
std::string config_path);
// Return a cloned ESAgent, whose _predictor is same with this->_predictor
// but _sample_predictor is pointed to a newly created object.
// This function is used to clone a new ESAgent to sample in multi-thread way.
// NOTE: when calling `update` function of current object, both of their
// parameters will be updated. Because their _predictor is point to same object.
/**
* @breif Clone a sampling agent
*
* Only cloned ESAgent can call `add_noise` function.
* Each cloned ESAgent will have a copy of original parameters.
* (support sampling in multi-thread way)
*/
std::shared_ptr<ESAgent> clone();
// Update parameters of _predictor
/**
* @brief Update parameters of predictor based on ES algorithm.
*
* Only not cloned ESAgent can call `update` function.
* Parameters of cloned agents will also be updated.
*/
bool update(
std::vector<SamplingKey>& noisy_keys,
std::vector<float>& noisy_rewards);
// parameters of _sample_predictor = parameters of _predictor + noise
// copied parameters = original parameters + noise
bool add_noise(SamplingKey& sampling_key);
// Return paddle predict _sample_predictor
// if _is_sampling_agent is true, will return predictor with added noise;
// if _is_sampling_agent is false, will return predictor without added noise.
/**
* @brief Get paddle predict
*
* if _is_sampling_agent is true, will return predictor with added noise;
* if _is_sampling_agent is false, will return predictor without added noise.
*/
std::shared_ptr<PaddlePredictor> get_predictor();
private:
......
......@@ -24,12 +24,13 @@
namespace DeepES{
/* DeepES agent for Torch.
/**
* @brief DeepES agent for Torch.
*
* Our implemtation is flexible to support any model that subclass torch::nn::Module.
* That is, we can instantiate an agent by: es_agent = ESAgent<Model>(model);
* After that, users can clone an agent for multi-thread processing, add parametric noise for exploration,
* and update the parameteres, according to the evaluation resutls of noisy parameters.
*
*/
template <class T>
class ESAgent{
......@@ -57,6 +58,13 @@ public:
_neg_gradients = new float [_param_size];
}
/**
* @breif Clone a sampling agent
*
* Only cloned ESAgent can call `add_noise` function.
* Each cloned ESAgent will have a copy of original parameters.
* (support sampling in multi-thread way)
*/
std::shared_ptr<ESAgent> clone() {
std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>();
......@@ -74,10 +82,22 @@ public:
return new_agent;
}
/**
* @brief Use the model to predict.
*
* if _is_sampling_agent is true, will use the sampling model with added noise;
* if _is_sampling_agent is false, will use the original model without added noise.
*/
torch::Tensor predict(const torch::Tensor& x) {
return _sampled_model->forward(x);
}
/**
* @brief Update parameters of model based on ES algorithm.
*
* Only not cloned ESAgent can call `update` function.
* Parameters of cloned agents will also be updated.
*/
bool update(std::vector<SamplingKey>& noisy_keys, std::vector<float>& noisy_rewards) {
if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
......@@ -112,6 +132,7 @@ public:
return true;
}
// copied parameters = original parameters + noise
bool add_noise(SamplingKey& sampling_key) {
if (!_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
......
......@@ -13,14 +13,7 @@
// limitations under the License.
#include <vector>
#include <iostream>
#include "es_agent.h"
#include "paddle_api.h"
#include "optimizer.h"
#include "utils.h"
#include "gaussian_sampling.h"
#include "deepes.pb.h"
namespace DeepES {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册