PaddlePaddle / PARL

Commit 1e2746fa
Authored on Mar 23, 2020 by zenghsh3
add _is_sampling_agent flag in original agent and cloned agent
Parent: 86d0fedb
Showing 12 changed files with 108 additions and 198 deletions (+108 −198)
deepes/demo/paddle/cartpole_init_model.zip          +0   -0
deepes/demo/paddle/cartpole_init_model/__model__    +0   -0
deepes/demo/paddle/cartpole_init_model/fc_0.b_0     +0   -0
deepes/demo/paddle/cartpole_init_model/fc_0.w_0     +0   -0
deepes/demo/paddle/cartpole_init_model/fc_1.b_0     +0   -0
deepes/demo/paddle/cartpole_init_model/fc_1.w_0     +0   -0
deepes/demo/paddle/cartpole_solver_parallel.cc      +9   -10
deepes/demo/torch/cartpole_solver_parallel.cc       +14  -10
deepes/include/paddle/es_agent.h                    +8   -25
deepes/include/torch/es_agent.h                     +44  -74
deepes/scripts/build.sh                             +3   -0
deepes/src/paddle/es_agent.cc                       +30  -79
deepes/demo/paddle/cartpole_init_model.zip  (new file, 0 → 100644)
    Binary file added.

deepes/demo/paddle/cartpole_init_model/__model__  (deleted, 100644 → 0)
    File deleted.

deepes/demo/paddle/cartpole_init_model/fc_0.b_0  (deleted, 100644 → 0)
    File deleted.

deepes/demo/paddle/cartpole_init_model/fc_0.w_0  (deleted, 100644 → 0)
    File deleted.

deepes/demo/paddle/cartpole_init_model/fc_1.b_0  (deleted, 100644 → 0)
    File deleted.

deepes/demo/paddle/cartpole_init_model/fc_1.w_0  (deleted, 100644 → 0)
    File deleted.
deepes/demo/paddle/cartpole_solver_parallel.cc

@@ -59,16 +59,13 @@ int arg_max(const std::vector<float>& vec) {
 }
 
-float evaluate(CartPole& env, std::shared_ptr<ESAgent> agent, bool is_eval=false) {
+float evaluate(CartPole& env, std::shared_ptr<ESAgent> agent) {
   float total_reward = 0.0;
   env.reset();
   const float* obs = env.getState();
 
   std::shared_ptr<PaddlePredictor> paddle_predictor;
-  if (is_eval)
-    paddle_predictor = agent->get_evaluate_predictor();  // For evaluate
-  else
-    paddle_predictor = agent->get_sample_predictor();  // For sampling (ES exploring)
+  paddle_predictor = agent->get_predictor();
 
   while (true) {
     std::vector<float> probs = forward(paddle_predictor, obs);

@@ -93,8 +90,9 @@ int main(int argc, char* argv[]) {
   std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
   std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");
 
-  std::vector<std::shared_ptr<ESAgent>> sampling_agents{agent};
-  for (int i = 0; i < (ITER - 1); ++i) {
+  // Clone agents to sample (explore).
+  std::vector<std::shared_ptr<ESAgent>> sampling_agents;
+  for (int i = 0; i < ITER; ++i) {
     sampling_agents.push_back(agent->clone());
   }

@@ -107,7 +105,8 @@ int main(int argc, char* argv[]) {
 #pragma omp parallel for schedule(dynamic, 1)
     for (int i = 0; i < ITER; ++i) {
       std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
-      SamplingKey key = sampling_agent->add_noise();
+      SamplingKey key;
+      bool success = sampling_agent->add_noise(key);
       float reward = evaluate(envs[i], sampling_agent);
       noisy_keys[i] = key;

@@ -115,9 +114,9 @@ int main(int argc, char* argv[]) {
     }
 
     // NOTE: all parameters of sampling_agents will be updated
-    agent->update(noisy_keys, noisy_rewards);
+    bool success = agent->update(noisy_keys, noisy_rewards);
 
-    int reward = evaluate(envs[0], agent, true);
+    int reward = evaluate(envs[0], agent);
     LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
   }
 }
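Taken together, the four hunks above change the Paddle demo so that the original agent only evaluates and updates, while cloned agents carry the noise. The following condensed sketch gathers the post-commit loop in one place; it assumes the DeepES demo headers and helpers from this repository (CartPole, SamplingKey, create_paddle_predictor, evaluate) and simplifies the bookkeeping, so read it as an illustration rather than the exact demo source.

```cpp
// Condensed sketch of the new sampling loop (assumes the DeepES demo headers
// and the helpers create_paddle_predictor / evaluate from this repo).
const int ITER = 10;

auto paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
auto agent = std::make_shared<ESAgent>(paddle_predictor,
                                       "../benchmark/cartpole_config.prototxt");

// Every sampling agent is a clone; the original agent never samples.
std::vector<std::shared_ptr<ESAgent>> sampling_agents;
for (int i = 0; i < ITER; ++i) {
  sampling_agents.push_back(agent->clone());
}

std::vector<CartPole> envs(ITER);            // demo environment, one per clone
std::vector<SamplingKey> noisy_keys(ITER);
std::vector<float> noisy_rewards(ITER);

for (int epoch = 0; epoch < 100; ++epoch) {
  for (int i = 0; i < ITER; ++i) {
    SamplingKey key;
    sampling_agents[i]->add_noise(key);      // allowed only on clones
    noisy_rewards[i] = evaluate(envs[i], sampling_agents[i]);
    noisy_keys[i] = key;
  }
  agent->update(noisy_keys, noisy_rewards);  // allowed only on the original
  float reward = evaluate(envs[0], agent);   // original agent holds the un-noised predictor
}
```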
deepes/demo/torch/cartpole_solver_parallel.cc

@@ -25,13 +25,13 @@
 using namespace DeepES;
 const int ITER = 10;
 
-float evaluate(CartPole& env, std::shared_ptr<ESAgent<Model>> agent, bool is_eval=false) {
+float evaluate(CartPole& env, std::shared_ptr<ESAgent<Model>> agent) {
   float total_reward = 0.0;
   env.reset();
   const float* obs = env.getState();
 
   while (true) {
     torch::Tensor obs_tensor = torch::tensor({obs[0], obs[1], obs[2], obs[3]});
-    torch::Tensor action = agent->predict(obs_tensor, is_eval);
+    torch::Tensor action = agent->predict(obs_tensor);
     int act = std::get<1>(action.max(-1)).item<long>();
     env.step(act);
     float reward = env.getReward();

@@ -52,9 +52,10 @@ int main(int argc, char* argv[]) {
   auto model = std::make_shared<Model>(4, 2);
 
   std::shared_ptr<ESAgent<Model>> agent = std::make_shared<ESAgent<Model>>(model, "../benchmark/cartpole_config.prototxt");
 
-  std::vector<std::shared_ptr<ESAgent<Model>>> sampling_agents = {agent};
-  for (int i = 0; i < ITER - 1; ++i) {
+  // Clone agents to sample (explore).
+  std::vector<std::shared_ptr<ESAgent<Model>>> sampling_agents;
+  for (int i = 0; i < ITER; ++i) {
     sampling_agents.push_back(agent->clone());
   }

@@ -66,15 +67,18 @@ int main(int argc, char* argv[]) {
 #pragma omp parallel for schedule(dynamic, 1)
     for (int i = 0; i < ITER; ++i) {
       auto sampling_agent = sampling_agents[i];
-      SamplingKey key = sampling_agent->add_noise();
+      SamplingKey key;
+      bool success = sampling_agent->add_noise(key);
       float reward = evaluate(envs[i], sampling_agent);
       noisy_keys[i] = key;
       noisy_rewards[i] = reward;
     }
 
-    agent->update(noisy_keys, noisy_rewards);
-    int reward = evaluate(envs[0], agent, true);
+    // Will also update parameters of sampling_agents
+    bool success = agent->update(noisy_keys, noisy_rewards);
+
+    // Use original agent to evalute (without noise).
+    int reward = evaluate(envs[0], agent);
     LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
   }
 }
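Because the Torch agent no longer takes an is_eval flag, the role split is carried entirely by which agent predict is called on: the original agent's _sampled_model aliases _model, so it predicts with clean parameters, while a clone predicts with its own noised copy. A minimal sketch of that usage, assuming the DeepES torch es_agent.h and a Model subclass of torch::nn::Module from this repository:

```cpp
auto model = std::make_shared<Model>(4, 2);
auto agent = std::make_shared<ESAgent<Model>>(model, "../benchmark/cartpole_config.prototxt");
auto sampler = agent->clone();                 // _is_sampling_agent == true on the clone

SamplingKey key;
sampler->add_noise(key);                       // OK: clones perturb their own model copy
// agent->add_noise(key);                      // would log an error and return false

torch::Tensor obs = torch::zeros({4});
torch::Tensor noisy_action = sampler->predict(obs);  // forward through noised parameters
torch::Tensor clean_action = agent->predict(obs);    // forward through original parameters
```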
deepes/include/paddle/es_agent.h

@@ -46,8 +46,8 @@ class ESAgent {
   // Return a cloned ESAgent, whose _predictor is same with this->_predictor
   // but _sample_predictor is pointed to a newly created object.
-  // This function mainly used to clone a new ESAgent to do sampling in multi-thread way.
-  // NOTE: when calling `update` function of current object or cloned one, both of their
+  // This function is used to clone a new ESAgent to sample in multi-thread way.
+  // NOTE: when calling `update` function of current object, both of their
   // parameters will be updated. Because their _predictor is point to same object.
   std::shared_ptr<ESAgent> clone();

@@ -57,34 +57,17 @@ class ESAgent {
       std::vector<float>& noisy_rewards);
 
   // parameters of _sample_predictor = parameters of _predictor + noise
-  SamplingKey add_noise();
+  bool add_noise(SamplingKey& sampling_key);
 
-  std::shared_ptr<SamplingMethod> get_sampling_method();
-  std::shared_ptr<Optimizer> get_optimizer();
-  std::shared_ptr<DeepESConfig> get_config();
-  int64_t get_param_size();
-  std::vector<std::string> get_param_names();
-  // Return paddle predict _sample_predictor (with addded noise)
-  std::shared_ptr<PaddlePredictor> get_sample_predictor();
-  // Return paddle predict _predictor (without addded noise)
-  std::shared_ptr<PaddlePredictor> get_evaluate_predictor();
-
-  void set_config(std::shared_ptr<DeepESConfig> config);
-  void set_sampling_method(std::shared_ptr<SamplingMethod> sampling_method);
-  void set_optimizer(std::shared_ptr<Optimizer> optimizer);
-  void set_param_size(int64_t param_size);
-  void set_param_names(std::vector<std::string> param_names);
-  void set_noise(float* noise);
-  void set_neg_gradients(float* neg_gradients);
-  void set_predictor(std::shared_ptr<PaddlePredictor> predictor,
-                     std::shared_ptr<PaddlePredictor> sample_predictor);
+  // Return paddle predict _sample_predictor
+  // if _is_sampling_agent is true, will return predictor with added noise;
+  // if _is_sampling_agent is false, will return predictor without added noise.
+  std::shared_ptr<PaddlePredictor> get_predictor();
 
 private:
   std::shared_ptr<PaddlePredictor> _predictor;
   std::shared_ptr<PaddlePredictor> _sample_predictor;
+  bool _is_sampling_agent;
   std::shared_ptr<SamplingMethod> _sampling_method;
   std::shared_ptr<Optimizer> _optimizer;
   std::shared_ptr<DeepESConfig> _config;
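The header change boils down to a role flag: agents constructed directly are updaters, agents produced by clone() are samplers, and each public call checks _is_sampling_agent before doing any work. The standalone toy below (hypothetical names, not part of the repo) compiles on its own and shows only that gating pattern, stripped of predictors, noise, and optimizers.

```cpp
// Standalone illustration of the flag pattern (hypothetical class, not the repo's).
#include <iostream>
#include <memory>

class Agent {
public:
    Agent() : is_sampling_agent_(false) {}

    std::shared_ptr<Agent> clone() const {
        auto copy = std::make_shared<Agent>();
        copy->is_sampling_agent_ = true;   // clones may only sample
        return copy;
    }

    bool add_noise() {
        if (!is_sampling_agent_) {
            std::cerr << "original agent cannot add noise\n";
            return false;
        }
        return true;                       // perturb the sampling copy here
    }

    bool update() {
        if (is_sampling_agent_) {
            std::cerr << "cloned agent cannot update\n";
            return false;
        }
        return true;                       // apply the gradient step here
    }

private:
    bool is_sampling_agent_;
};

int main() {
    Agent original;
    auto sampler = original.clone();
    original.update();      // ok
    sampler->add_noise();   // ok
    sampler->update();      // rejected, returns false
    original.add_noise();   // rejected, returns false
}
```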
deepes/include/torch/es_agent.h

@@ -26,101 +26,64 @@ namespace DeepES{
 
 /* DeepES agent for Torch.
  * Our implemtation is flexible to support any model that subclass torch::nn::Module.
- * That is, we can instantiate a agent by: es_agent = ESAgent<Model>(model);
- * After that, users can clone a agent for multi-thread processing, add parametric noise for exploration,
+ * That is, we can instantiate an agent by: es_agent = ESAgent<Model>(model);
+ * After that, users can clone an agent for multi-thread processing, add parametric noise for exploration,
  * and update the parameteres, according to the evaluation resutls of noisy parameters.
  */
 template <class T>
 class ESAgent {
 public:
-  ESAgent(): _param_size(0) {}
+  ESAgent() {}
 
   ~ESAgent() {
     delete[] _noise;
-    delete[] _neg_gradients;
+    if (!_is_sampling_agent)
+      delete[] _neg_gradients;
   }
 
   ESAgent(std::shared_ptr<T> model, std::string config_path): _model(model) {
+    _is_sampling_agent = false;
     _config = std::make_shared<DeepESConfig>();
     load_proto_conf(config_path, *_config);
     _sampling_method = std::make_shared<GaussianSampling>();
     _sampling_method->load_config(*_config);
     _optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr());
-    _param_size = 0;
-    _sampled_model = model->clone();
-    param_size();
+    // Origin agent can't be used to sample, so keep it same with _model for evaluating.
+    _sampled_model = model;
+    _param_size = _calculate_param_size();
 
     _noise = new float [_param_size];
     _neg_gradients = new float [_param_size];
   }
 
   std::shared_ptr<ESAgent> clone() {
-    std::shared_ptr<T> new_model = _model->clone();
     std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>();
-    new_agent->set_model(_model, new_model);
-    new_agent->set_sampling_method(_sampling_method);
-    new_agent->set_optimizer(_optimizer);
-    new_agent->set_config(_config);
-    new_agent->set_param_size(_param_size);
-    float* new_noise = new float [_param_size];
-    float* new_neg_gradients = new float [_param_size];
-    new_agent->set_noise(new_noise);
-    new_agent->set_neg_gradients(new_neg_gradients);
-    return new_agent;
-  }
-
-  void set_config(std::shared_ptr<DeepESConfig> config) {
-    _config = config;
-  }
-
-  void set_sampling_method(std::shared_ptr<SamplingMethod> sampling_method) {
-    _sampling_method = sampling_method;
-  }
-
-  void set_model(std::shared_ptr<T> model, std::shared_ptr<T> sampled_model) {
-    _model = model;
-    _sampled_model = sampled_model;
-  }
-
-  std::shared_ptr<SamplingMethod> get_sampling_method() {
-    return _sampling_method;
-  }
-
-  std::shared_ptr<Optimizer> get_optimizer() {
-    return _optimizer;
-  }
-
-  void set_optimizer(std::shared_ptr<Optimizer> optimizer) {
-    _optimizer = optimizer;
-  }
-
-  void set_param_size(int64_t param_size) {
-    _param_size = param_size;
-  }
-
-  void set_noise(float* noise) {
-    _noise = noise;
-  }
-
-  void set_neg_gradients(float* neg_gradients) {
-    _neg_gradients = neg_gradients;
-  }
-
-  torch::Tensor predict(const torch::Tensor& x, bool is_eval=false) {
-    if (is_eval) {
-      // predict with _model (without addding noise)
-      return _model->forward(x);
-    }
-    else {
-      // predict with _sampled_model (with adding noise)
-      return _sampled_model->forward(x);
-    }
+
+    new_agent->_model = _model;
+    std::shared_ptr<T> new_model = _model->clone();
+    new_agent->_sampled_model = new_model;
+
+    new_agent->_is_sampling_agent = true;
+    new_agent->_sampling_method = _sampling_method;
+    new_agent->_param_size = _param_size;
+
+    float* new_noise = new float [_param_size];
+    new_agent->_noise = new_noise;
+
+    return new_agent;
+  }
+
+  torch::Tensor predict(const torch::Tensor& x) {
+    return _sampled_model->forward(x);
   }
 
   bool update(std::vector<SamplingKey>& noisy_keys, std::vector<float>& noisy_rewards) {
+    if (_is_sampling_agent) {
+      LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
+      return false;
+    }
+
     compute_centered_ranks(noisy_rewards);
     memset(_neg_gradients, 0, _param_size * sizeof(float));

@@ -145,10 +108,16 @@ public:
       _optimizer->update(tensor_a, _neg_gradients + counter, tensor.size(0));
       counter += tensor.size(0);
     }
 
     return true;
   }
 
-  SamplingKey add_noise() {
-    SamplingKey sampling_key;
+  bool add_noise(SamplingKey& sampling_key) {
+    if (!_is_sampling_agent) {
+      LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
+      return false;
+    }
+
     auto sampled_params = _sampled_model->named_parameters();
     auto params = _model->named_parameters();
     int key = _sampling_method->sampling(_noise, _param_size);

@@ -165,23 +134,15 @@ public:
       }
       counter += tensor.size(0);
     }
-    return sampling_key;
+    return true;
   }
 
-  int64_t param_size() {
-    if (_param_size == 0) {
-      auto params = _model->named_parameters();
-      for (auto& param: params) {
-        torch::Tensor tensor = param.value().view({-1});
-        _param_size += tensor.size(0);
-      }
-    }
-    return _param_size;
-  }
-
 private:
   std::shared_ptr<T> _sampled_model;
   std::shared_ptr<T> _model;
+  bool _is_sampling_agent;
   std::shared_ptr<SamplingMethod> _sampling_method;
   std::shared_ptr<Optimizer> _optimizer;
   std::shared_ptr<DeepESConfig> _config;

@@ -189,6 +150,15 @@ private:
   // malloc memory of noise and neg_gradients in advance.
   float* _noise;
   float* _neg_gradients;
+
+  int64_t _calculate_param_size() {
+    auto params = _model->named_parameters();
+    for (auto& param: params) {
+      torch::Tensor tensor = param.value().view({-1});
+      _param_size += tensor.size(0);
+    }
+    return _param_size;
+  }
 };
 
 }
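The lazily-computed public param_size() is replaced by a private _calculate_param_size() that runs once in the constructor. As a point of reference, the following standalone function sketches the same computation with plain libtorch; the free-function name and signature here are illustrative, not the repo's.

```cpp
#include <torch/torch.h>

// Sum the element counts of all named parameters of a torch::nn::Module,
// mirroring what _calculate_param_size() does on the agent's _model.
int64_t calculate_param_size(const std::shared_ptr<torch::nn::Module>& model) {
    int64_t param_size = 0;
    for (const auto& param : model->named_parameters()) {
        // view({-1}) flattens the tensor; size(0) is then its element count.
        torch::Tensor flat = param.value().view({-1});
        param_size += flat.size(0);
    }
    return param_size;
}
```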
deepes/scripts/build.sh

@@ -12,6 +12,9 @@ if [ $1 = "paddle" ]; then
     echo "Please put the PaddleLite libraray to current folder according the instruction in README"
     exit 1
   fi
+
+  # Initialization model
+  unzip ./demo/paddle/cartpole_init_model.zip -d ./demo/paddle/
 
   FLAGS=" -DWITH_PADDLE=ON"
 elif [ $1 = "torch" ]; then
deepes/src/paddle/es_agent.cc

@@ -28,7 +28,7 @@ typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
 typedef paddle::lite_api::Tensor Tensor;
 typedef paddle::lite_api::shape_t shape_t;
 
-int64_t ShapeProduction(const shape_t& shape) {
+inline int64_t ShapeProduction(const shape_t& shape) {
   int64_t res = 1;
   for (auto i : shape) res *= i;
   return res;

@@ -38,15 +38,18 @@ ESAgent::ESAgent() {}
 
 ESAgent::~ESAgent() {
   delete[] _noise;
-  delete[] _neg_gradients;
+  if (!_is_sampling_agent)
+    delete[] _neg_gradients;
 }
 
 ESAgent::ESAgent(
     std::shared_ptr<PaddlePredictor> predictor,
     std::string config_path) {
 
+  _is_sampling_agent = false;
   _predictor = predictor;
-  _sample_predictor = predictor->Clone();
+  // Original agent can't be used to sample, so keep it same with _predictor for evaluating.
+  _sample_predictor = predictor;
 
   _config = std::make_shared<DeepESConfig>();
   load_proto_conf(config_path, *_config);

@@ -69,22 +72,27 @@ std::shared_ptr<ESAgent> ESAgent::clone() {
   std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>();
 
   float* new_noise = new float [_param_size];
-  float* new_neg_gradients = new float [_param_size];
 
-  new_agent->set_predictor(_predictor, new_sample_predictor);
-  new_agent->set_sampling_method(_sampling_method);
-  new_agent->set_optimizer(_optimizer);
-  new_agent->set_config(_config);
-  new_agent->set_param_size(_param_size);
-  new_agent->set_param_names(_param_names);
-  new_agent->set_noise(new_noise);
-  new_agent->set_neg_gradients(new_neg_gradients);
+  new_agent->_predictor = _predictor;
+  new_agent->_sample_predictor = new_sample_predictor;
+  new_agent->_is_sampling_agent = true;
+  new_agent->_sampling_method = _sampling_method;
+  new_agent->_param_names = _param_names;
+  new_agent->_param_size = _param_size;
+  new_agent->_noise = new_noise;
 
   return new_agent;
 }
 
 bool ESAgent::update(
     std::vector<SamplingKey>& noisy_keys,
     std::vector<float>& noisy_rewards) {
+  if (_is_sampling_agent) {
+    LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
+    return false;
+  }
+
   compute_centered_ranks(noisy_rewards);
   memset(_neg_gradients, 0, _param_size * sizeof(float));

@@ -110,11 +118,16 @@ bool ESAgent::update(
     _optimizer->update(tensor_data, _neg_gradients + counter, tensor_size);
     counter += tensor_size;
   }
 
   return true;
 }
 
-SamplingKey ESAgent::add_noise() {
-  SamplingKey sampling_key;
+bool ESAgent::add_noise(SamplingKey& sampling_key) {
+  if (!_is_sampling_agent) {
+    LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
+    return false;
+  }
+
   int key = _sampling_method->sampling(_noise, _param_size);
   sampling_key.add_key(key);
   int64_t counter = 0;

@@ -129,76 +142,14 @@ SamplingKey ESAgent::add_noise() {
     counter += tensor_size;
   }
-  return sampling_key;
-}
-
-std::shared_ptr<SamplingMethod> ESAgent::get_sampling_method() {
-  return _sampling_method;
-}
-
-std::shared_ptr<Optimizer> ESAgent::get_optimizer() {
-  return _optimizer;
-}
-
-std::shared_ptr<DeepESConfig> ESAgent::get_config() {
-  return _config;
-}
-
-int64_t ESAgent::get_param_size() {
-  return _param_size;
-}
-
-std::vector<std::string> ESAgent::get_param_names() {
-  return _param_names;
+  return true;
 }
 
-std::shared_ptr<PaddlePredictor> ESAgent::get_sample_predictor() {
+std::shared_ptr<PaddlePredictor> ESAgent::get_predictor() {
   return _sample_predictor;
 }
 
-std::shared_ptr<PaddlePredictor> ESAgent::get_evaluate_predictor() {
-  return _predictor;
-}
-
-void ESAgent::set_predictor(
-    std::shared_ptr<PaddlePredictor> predictor,
-    std::shared_ptr<PaddlePredictor> sample_predictor) {
-  _predictor = predictor;
-  _sample_predictor = sample_predictor;
-}
-
-void ESAgent::set_sampling_method(std::shared_ptr<SamplingMethod> sampling_method) {
-  _sampling_method = sampling_method;
-}
-
-void ESAgent::set_optimizer(std::shared_ptr<Optimizer> optimizer) {
-  _optimizer = optimizer;
-}
-
-void ESAgent::set_config(std::shared_ptr<DeepESConfig> config) {
-  _config = config;
-}
-
-void ESAgent::set_param_size(int64_t param_size) {
-  _param_size = param_size;
-}
-
-void ESAgent::set_param_names(std::vector<std::string> param_names) {
-  _param_names = param_names;
-}
-
-void ESAgent::set_noise(float* noise) {
-  _noise = noise;
-}
-
-void ESAgent::set_neg_gradients(float* neg_gradients) {
-  _neg_gradients = neg_gradients;
-}
-
 int64_t ESAgent::_calculate_param_size() {
   int64_t param_size = 0;
   for (std::string param_name: _param_names) {
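With these definitions, a single get_predictor() is enough for both roles: the constructor aliases _sample_predictor to _predictor on the original agent, while clone() hands the clone a fresh Clone() of the predictor and sets _is_sampling_agent. A short sketch of how calling code sees that; `base` stands for any PaddleLite predictor, and names other than the ESAgent API shown in this diff are illustrative.

```cpp
auto agent = std::make_shared<ESAgent>(base, "../benchmark/cartpole_config.prototxt");
auto sampler = agent->clone();

// Original agent: _sample_predictor aliases _predictor, so this is the
// un-noised network -- use it for evaluation.
auto eval_predictor = agent->get_predictor();

// Cloned agent: _sample_predictor is a fresh Clone() of the predictor;
// add_noise(key) perturbs it, so this is the exploration network.
SamplingKey key;
sampler->add_noise(key);
auto sample_predictor = sampler->get_predictor();
```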