PaddlePaddle / PARL
Commit 2ad117a9
Authored Mar 23, 2020 by zenghsh3
refine comments
Parent: f2e6ef0e
Showing 3 changed files with 53 additions and 26 deletions
deepes/include/paddle/es_agent.h  +30 -17
deepes/include/torch/es_agent.h   +23 -2
deepes/src/paddle/es_agent.cc     +0 -7
deepes/include/paddle/es_agent.h @ 2ad117a9
@@ -25,15 +25,18 @@
 namespace DeepES {
-/* DeepES agent for PaddleLite.
- * Users can use `add_noise` function to add noise to parameters and use `get_sample_predictor`
- * function to get a predictor with added noise to explore.
- * Then can use `update` function to update parameters based on ES algorithm.
- * Users also can `clone` multi agents to sample in multi-thread way.
- */
 typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
+/**
+ * @brief DeepES agent for PaddleLite.
+ *
+ * Users call the `clone` function to create a sampling agent, which can call the `add_noise`
+ * function to add noise to its copied parameters and the `get_predictor` function to
+ * get a Paddle predictor with the added noise.
+ *
+ * Users can then call the `update` function to update the parameters based on the ES algorithm.
+ * Note: parameters of cloned agents will also be updated.
+ */
 class ESAgent {
 public:
   ESAgent();
@@ -44,24 +47,34 @@ class ESAgent {
       std::shared_ptr<PaddlePredictor> predictor, std::string config_path);
-  // Return a cloned ESAgent, whose _predictor is same with this->_predictor
-  // but _sample_predictor is pointed to a newly created object.
-  // This function is used to clone a new ESAgent to sample in multi-thread way.
-  // NOTE: when calling `update` function of current object, both of their
-  // parameters will be updated. Because their _predictor is point to same object.
+  /**
+   * @brief Clone a sampling agent.
+   *
+   * Only a cloned ESAgent can call the `add_noise` function.
+   * Each cloned ESAgent has its own copy of the original parameters
+   * (this supports sampling in multiple threads).
+   */
   std::shared_ptr<ESAgent> clone();
-  // Update parameters of _predictor
+  /**
+   * @brief Update the parameters of the predictor based on the ES algorithm.
+   *
+   * Only the original (non-cloned) ESAgent can call the `update` function.
+   * Parameters of cloned agents will also be updated.
+   */
   bool update(
       std::vector<SamplingKey>& noisy_keys,
       std::vector<float>& noisy_rewards);
-  // parameters of _sample_predictor = parameters of _predictor + noise
+  // copied parameters = original parameters + noise
   bool add_noise(SamplingKey& sampling_key);
-  // Return paddle predict _sample_predictor
-  // if _is_sampling_agent is true, will return predictor with added noise;
-  // if _is_sampling_agent is false, will return predictor without added noise.
+  /**
+   * @brief Get the Paddle predictor.
+   *
+   * If _is_sampling_agent is true, returns the predictor with added noise;
+   * if _is_sampling_agent is false, returns the predictor without added noise.
+   */
   std::shared_ptr<PaddlePredictor> get_predictor();
 private:
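The role split these comments describe (the original agent owns the parameters and is the only one allowed to `update`; clones hold a private copy and are the only ones allowed to `add_noise`) can be sketched with plain vectors. This is a minimal illustration, not the DeepES implementation: `ToyAgent`, `get_params`, and the vector-based `add_noise`/`update` signatures are hypothetical, and the real ES step is replaced by adding a caller-supplied delta. The shared base pointer follows the removed comment's remark that a clone's `_predictor` points to the same object as the original's.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for the original/clone role split; not DeepES code.
class ToyAgent {
public:
  explicit ToyAgent(std::vector<float> params)
      : _params(std::make_shared<std::vector<float>>(std::move(params))),
        _sampled(*_params),
        _is_sampling_agent(false) {}

  // A clone shares the base parameters and owns a private copy to perturb.
  std::shared_ptr<ToyAgent> clone() const {
    auto agent = std::make_shared<ToyAgent>(*this);  // copies the shared_ptr, not the base vector
    agent->_sampled = *_params;
    agent->_is_sampling_agent = true;
    return agent;
  }

  // copied parameters = original parameters + noise (clones only).
  bool add_noise(const std::vector<float>& noise) {
    if (!_is_sampling_agent || noise.size() != _params->size()) return false;
    for (std::size_t i = 0; i < noise.size(); ++i) {
      _sampled[i] = (*_params)[i] + noise[i];
    }
    return true;
  }

  // Placeholder for the ES step: adds `delta` to the shared base (original only).
  bool update(const std::vector<float>& delta) {
    if (_is_sampling_agent || delta.size() != _params->size()) return false;
    for (std::size_t i = 0; i < delta.size(); ++i) {
      (*_params)[i] += delta[i];
    }
    return true;
  }

  // Clone: noisy copy. Original: shared base parameters.
  const std::vector<float>& get_params() const {
    if (_is_sampling_agent) return _sampled;
    return *_params;
  }

private:
  std::shared_ptr<std::vector<float>> _params;  // shared between original and clones
  std::vector<float> _sampled;                  // private noisy copy
  bool _is_sampling_agent;
};
```

Because the base vector is shared, calling `update` on the original and then `add_noise` on a clone applies the noise to the updated parameters. A multi-threaded caller would still have to ensure that no clone reads the shared vector while `update` is running.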
deepes/include/torch/es_agent.h @ 2ad117a9
@@ -24,12 +24,13 @@
 namespace DeepES {
-/* DeepES agent for Torch.
+/**
+ * @brief DeepES agent for Torch.
+ *
  * Our implementation is flexible enough to support any model that subclasses torch::nn::Module.
  * That is, we can instantiate an agent by: es_agent = ESAgent<Model>(model);
  * After that, users can clone an agent for multi-thread processing, add parametric noise for exploration,
  * and update the parameters according to the evaluation results of the noisy parameters.
  *
  */
 template <class T>
 class ESAgent {
@@ -57,6 +58,13 @@ public:
     _neg_gradients = new float[_param_size];
   }
+  /**
+   * @brief Clone a sampling agent.
+   *
+   * Only a cloned ESAgent can call the `add_noise` function.
+   * Each cloned ESAgent has its own copy of the original parameters
+   * (this supports sampling in multiple threads).
+   */
   std::shared_ptr<ESAgent> clone() {
     std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>();
@@ -74,10 +82,22 @@ public:
     return new_agent;
   }
+  /**
+   * @brief Use the model to predict.
+   *
+   * If _is_sampling_agent is true, uses the sampling model with added noise;
+   * if _is_sampling_agent is false, uses the original model without added noise.
+   */
   torch::Tensor predict(const torch::Tensor& x) {
     return _sampled_model->forward(x);
   }
+  /**
+   * @brief Update the parameters of the model based on the ES algorithm.
+   *
+   * Only the original (non-cloned) ESAgent can call the `update` function.
+   * Parameters of cloned agents will also be updated.
+   */
   bool update(
       std::vector<SamplingKey>& noisy_keys,
       std::vector<float>& noisy_rewards) {
     if (_is_sampling_agent) {
       LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
@@ -112,6 +132,7 @@ public:
     return true;
   }
+  // copied parameters = original parameters + noise
   bool add_noise(SamplingKey& sampling_key) {
     if (!_is_sampling_agent) {
       LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
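The pairing of `add_noise(sampling_key)` with `update(noisy_keys, noisy_rewards)` suggests that a key is enough to identify the noise that was applied. One way that can work is sketched below, under loud assumptions: the key is treated as a plain RNG seed, the noise is Gaussian, and the update is a basic evolution-strategies estimate (rewards weighted by their noise vectors). `noise_from_seed`, `perturb`, `es_update`, `sigma`, and `lr` are all hypothetical names; the diff does not show how `SamplingKey` or the DeepES optimizer actually work.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Regenerate the Gaussian noise vector that a seed stands for (assumed key semantics).
std::vector<float> noise_from_seed(unsigned seed, std::size_t size) {
  std::mt19937 rng(seed);
  std::normal_distribution<float> gauss(0.0f, 1.0f);
  std::vector<float> eps(size);
  for (float& e : eps) e = gauss(rng);
  return eps;
}

// copied parameters = original parameters + sigma * noise
std::vector<float> perturb(const std::vector<float>& params, unsigned seed, float sigma) {
  const std::vector<float> eps = noise_from_seed(seed, params.size());
  std::vector<float> noisy(params);
  for (std::size_t i = 0; i < noisy.size(); ++i) noisy[i] += sigma * eps[i];
  return noisy;
}

// Basic ES step: params += lr / (n * sigma) * sum_k rewards[k] * eps_k,
// where eps_k is rebuilt from seeds[k] instead of being stored or transmitted.
bool es_update(std::vector<float>& params, const std::vector<unsigned>& seeds,
               const std::vector<float>& rewards, float sigma, float lr) {
  if (seeds.empty() || seeds.size() != rewards.size() || sigma <= 0.0f) return false;
  std::vector<float> grad(params.size(), 0.0f);
  for (std::size_t k = 0; k < seeds.size(); ++k) {
    const std::vector<float> eps = noise_from_seed(seeds[k], params.size());
    for (std::size_t i = 0; i < grad.size(); ++i) grad[i] += rewards[k] * eps[i];
  }
  const float scale = lr / (static_cast<float>(seeds.size()) * sigma);
  for (std::size_t i = 0; i < params.size(); ++i) params[i] += scale * grad[i];
  return true;
}
```

In this sketch each sampling thread would call `perturb` with its own seed, evaluate the noisy parameters, and report only the `(seed, reward)` pair; the original agent then calls `es_update` with the collected pairs. Passing seeds rather than full noise vectors keeps the data exchanged between threads small.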
deepes/src/paddle/es_agent.cc @ 2ad117a9
@@ -13,14 +13,7 @@
// limitations under the License.
#include <vector>
#include <iostream>
#include "es_agent.h"
#include "paddle_api.h"
#include "optimizer.h"
#include "utils.h"
#include "gaussian_sampling.h"
#include "deepes.pb.h"
namespace DeepES {