Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
752974cb
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
67
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
752974cb
编写于
3月 30, 2020
作者:
Z
zhoubo01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename AsyncAgent to AsyncESAgent
上级
7793b479
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
25 addition
and
24 deletion
+25
-24
deepes/CMakeLists.txt
deepes/CMakeLists.txt
+1
-0
deepes/demo/paddle/cartpole_async_solver.cc
deepes/demo/paddle/cartpole_async_solver.cc
+4
-4
deepes/include/paddle/async_es_agent.h
deepes/include/paddle/async_es_agent.h
+6
-6
deepes/scripts/build.sh
deepes/scripts/build.sh
+1
-1
deepes/src/paddle/async_es_agent.cc
deepes/src/paddle/async_es_agent.cc
+13
-13
未找到文件。
deepes/CMakeLists.txt
浏览文件 @
752974cb
...
@@ -45,6 +45,7 @@ if (WITH_PADDLE)
...
@@ -45,6 +45,7 @@ if (WITH_PADDLE)
file
(
GLOB framework_src
"src/paddle/*.cc"
)
file
(
GLOB framework_src
"src/paddle/*.cc"
)
set
(
demo
"
${
PROJECT_SOURCE_DIR
}
/demo/paddle/cartpole_solver_parallel.cc"
)
set
(
demo
"
${
PROJECT_SOURCE_DIR
}
/demo/paddle/cartpole_solver_parallel.cc"
)
#set(demo "${PROJECT_SOURCE_DIR}/demo/paddle/cartpole_async_solver.cc")
########## Torch config ##########
########## Torch config ##########
elseif
(
WITH_TORCH
)
elseif
(
WITH_TORCH
)
list
(
APPEND CMAKE_PREFIX_PATH
"./libtorch"
)
list
(
APPEND CMAKE_PREFIX_PATH
"./libtorch"
)
...
...
deepes/demo/paddle/cartpole_async_solver.cc
浏览文件 @
752974cb
...
@@ -58,7 +58,7 @@ int arg_max(const std::vector<float>& vec) {
...
@@ -58,7 +58,7 @@ int arg_max(const std::vector<float>& vec) {
}
}
float
evaluate
(
CartPole
&
env
,
std
::
shared_ptr
<
AsyncAgent
>
agent
)
{
float
evaluate
(
CartPole
&
env
,
std
::
shared_ptr
<
Async
ES
Agent
>
agent
)
{
float
total_reward
=
0.0
;
float
total_reward
=
0.0
;
env
.
reset
();
env
.
reset
();
const
float
*
obs
=
env
.
getState
();
const
float
*
obs
=
env
.
getState
();
...
@@ -87,10 +87,10 @@ int main(int argc, char* argv[]) {
...
@@ -87,10 +87,10 @@ int main(int argc, char* argv[]) {
}
}
std
::
shared_ptr
<
PaddlePredictor
>
paddle_predictor
=
create_paddle_predictor
(
"../demo/paddle/cartpole_init_model"
);
std
::
shared_ptr
<
PaddlePredictor
>
paddle_predictor
=
create_paddle_predictor
(
"../demo/paddle/cartpole_init_model"
);
std
::
shared_ptr
<
Async
Agent
>
agent
=
std
::
make_shared
<
Async
Agent
>
(
paddle_predictor
,
"../benchmark/cartpole_config.prototxt"
);
std
::
shared_ptr
<
Async
ESAgent
>
agent
=
std
::
make_shared
<
AsyncES
Agent
>
(
paddle_predictor
,
"../benchmark/cartpole_config.prototxt"
);
// Clone agents to sample (explore).
// Clone agents to sample (explore).
std
::
vector
<
std
::
shared_ptr
<
AsyncAgent
>
>
sampling_agents
;
std
::
vector
<
std
::
shared_ptr
<
Async
ES
Agent
>
>
sampling_agents
;
for
(
int
i
=
0
;
i
<
ITER
;
++
i
)
{
for
(
int
i
=
0
;
i
<
ITER
;
++
i
)
{
sampling_agents
.
push_back
(
agent
->
clone
());
sampling_agents
.
push_back
(
agent
->
clone
());
}
}
...
@@ -113,7 +113,7 @@ int main(int argc, char* argv[]) {
...
@@ -113,7 +113,7 @@ int main(int argc, char* argv[]) {
}
}
#pragma omp parallel for schedule(dynamic, 1)
#pragma omp parallel for schedule(dynamic, 1)
for
(
int
i
=
0
;
i
<
ITER
;
++
i
)
{
for
(
int
i
=
0
;
i
<
ITER
;
++
i
)
{
std
::
shared_ptr
<
AsyncAgent
>
sampling_agent
=
sampling_agents
[
i
];
std
::
shared_ptr
<
Async
ES
Agent
>
sampling_agent
=
sampling_agents
[
i
];
SamplingInfo
info
;
SamplingInfo
info
;
bool
success
=
sampling_agent
->
add_noise
(
info
);
bool
success
=
sampling_agent
->
add_noise
(
info
);
float
reward
=
evaluate
(
envs
[
i
],
sampling_agent
);
float
reward
=
evaluate
(
envs
[
i
],
sampling_agent
);
...
...
deepes/include/paddle/async_es_agent.h
浏览文件 @
752974cb
...
@@ -26,27 +26,27 @@ namespace DeepES{
...
@@ -26,27 +26,27 @@ namespace DeepES{
* 2. add_noise: add noise into parameters.
* 2. add_noise: add noise into parameters.
* 3. update: update parameters given data collected during evaluation.
* 3. update: update parameters given data collected during evaluation.
*/
*/
class
AsyncAgent
:
public
ESAgent
{
class
Async
ES
Agent
:
public
ESAgent
{
public:
public:
AsyncAgent
()
{}
Async
ES
Agent
()
{}
~
AsyncAgent
();
~
Async
ES
Agent
();
/**
/**
* @args:
* @args:
* predictor: predictor created by users for prediction.
* predictor: predictor created by users for prediction.
* config_path: the path of configuration file.
* config_path: the path of configuration file.
* Note that AsyncAgent will update the configuration file after calling the update function.
* Note that Async
ES
Agent will update the configuration file after calling the update function.
* Please use the up-to-date configuration.
* Please use the up-to-date configuration.
*/
*/
AsyncAgent
(
Async
ES
Agent
(
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
std
::
string
config_path
);
std
::
string
config_path
);
/**
/**
* @brief: Clone an agent for sampling.
* @brief: Clone an agent for sampling.
*/
*/
std
::
shared_ptr
<
AsyncAgent
>
clone
();
std
::
shared_ptr
<
Async
ES
Agent
>
clone
();
/**
/**
* @brief: Clone an agent for sampling.
* @brief: Clone an agent for sampling.
...
...
deepes/scripts/build.sh
浏览文件 @
752974cb
...
@@ -47,7 +47,7 @@ rm -rf build
...
@@ -47,7 +47,7 @@ rm -rf build
mkdir
build
mkdir
build
cd
build
cd
build
cmake ../
${
FLAGS
}
cmake ../
${
FLAGS
}
make
-j10
make
-j10
#-----------------run----------------#
#-----------------run----------------#
./parallel_main
./parallel_main
deepes/src/paddle/async_es_agent.cc
浏览文件 @
752974cb
...
@@ -15,22 +15,22 @@
...
@@ -15,22 +15,22 @@
#include "async_es_agent.h"
#include "async_es_agent.h"
namespace
DeepES
{
namespace
DeepES
{
Async
Agent
::
Async
Agent
(
Async
ESAgent
::
AsyncES
Agent
(
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
std
::
string
config_path
)
:
ESAgent
(
predictor
,
config_path
)
{
std
::
string
config_path
)
:
ESAgent
(
predictor
,
config_path
)
{
_config_path
=
config_path
;
_config_path
=
config_path
;
}
}
Async
Agent
::~
Async
Agent
()
{
Async
ESAgent
::~
AsyncES
Agent
()
{
for
(
const
auto
kv
:
_param_delta
)
{
for
(
const
auto
kv
:
_param_delta
)
{
float
*
delta
=
kv
.
second
;
float
*
delta
=
kv
.
second
;
delete
[]
delta
;
delete
[]
delta
;
}
}
}
}
bool
AsyncAgent
::
_save
()
{
bool
Async
ES
Agent
::
_save
()
{
bool
success
=
true
;
bool
success
=
true
;
if
(
_is_sampling_agent
)
{
if
(
_is_sampling_agent
)
{
LOG
(
ERROR
)
<<
"[DeepES] Original Async
Agent cannot call add_noise function, please use cloned Async
Agent."
;
LOG
(
ERROR
)
<<
"[DeepES] Original Async
ESAgent cannot call add_noise function, please use cloned AsyncES
Agent."
;
success
=
false
;
success
=
false
;
return
success
;
return
success
;
}
}
...
@@ -55,7 +55,7 @@ bool AsyncAgent::_save() {
...
@@ -55,7 +55,7 @@ bool AsyncAgent::_save() {
async_es
->
set_model_iter_id
(
model_iter_id
);
async_es
->
set_model_iter_id
(
model_iter_id
);
success
=
save_proto_conf
(
_config_path
,
*
_config
);
success
=
save_proto_conf
(
_config_path
,
*
_config
);
if
(
!
success
)
{
if
(
!
success
)
{
LOG
(
ERROR
)
<<
"[]unable to save config for AsyncAgent"
;
LOG
(
ERROR
)
<<
"[]unable to save config for Async
ES
Agent"
;
success
=
false
;
success
=
false
;
return
success
;
return
success
;
}
}
...
@@ -64,7 +64,7 @@ bool AsyncAgent::_save() {
...
@@ -64,7 +64,7 @@ bool AsyncAgent::_save() {
return
success
;
return
success
;
}
}
bool
AsyncAgent
::
_remove_expired_model
(
int
max_to_keep
)
{
bool
Async
ES
Agent
::
_remove_expired_model
(
int
max_to_keep
)
{
bool
success
=
true
;
bool
success
=
true
;
std
::
string
model_path
=
_config
->
async_es
().
model_warehouse
();
std
::
string
model_path
=
_config
->
async_es
().
model_warehouse
();
std
::
vector
<
std
::
string
>
model_dirs
=
list_all_model_dirs
(
model_path
);
std
::
vector
<
std
::
string
>
model_dirs
=
list_all_model_dirs
(
model_path
);
...
@@ -86,7 +86,7 @@ bool AsyncAgent::_remove_expired_model(int max_to_keep) {
...
@@ -86,7 +86,7 @@ bool AsyncAgent::_remove_expired_model(int max_to_keep) {
return
success
;
return
success
;
}
}
bool
AsyncAgent
::
_compute_model_diff
()
{
bool
Async
ES
Agent
::
_compute_model_diff
()
{
bool
success
=
true
;
bool
success
=
true
;
for
(
const
auto
&
kv
:
_previous_predictors
)
{
for
(
const
auto
&
kv
:
_previous_predictors
)
{
int
model_iter_id
=
kv
.
first
;
int
model_iter_id
=
kv
.
first
;
...
@@ -108,7 +108,7 @@ bool AsyncAgent::_compute_model_diff() {
...
@@ -108,7 +108,7 @@ bool AsyncAgent::_compute_model_diff() {
return
success
;
return
success
;
}
}
bool
AsyncAgent
::
_load
()
{
bool
Async
ES
Agent
::
_load
()
{
bool
success
=
true
;
bool
success
=
true
;
std
::
string
model_path
=
_config
->
async_es
().
model_warehouse
();
std
::
string
model_path
=
_config
->
async_es
().
model_warehouse
();
std
::
vector
<
std
::
string
>
model_dirs
=
list_all_model_dirs
(
model_path
);
std
::
vector
<
std
::
string
>
model_dirs
=
list_all_model_dirs
(
model_path
);
...
@@ -140,7 +140,7 @@ bool AsyncAgent::_load() {
...
@@ -140,7 +140,7 @@ bool AsyncAgent::_load() {
return
success
;
return
success
;
}
}
std
::
shared_ptr
<
PaddlePredictor
>
AsyncAgent
::
_load_previous_model
(
std
::
string
model_dir
)
{
std
::
shared_ptr
<
PaddlePredictor
>
Async
ES
Agent
::
_load_previous_model
(
std
::
string
model_dir
)
{
// 1. Create CxxConfig
// 1. Create CxxConfig
CxxConfig
config
;
CxxConfig
config
;
config
.
set_model_file
(
model_dir
+
"/model"
);
config
.
set_model_file
(
model_dir
+
"/model"
);
...
@@ -155,10 +155,10 @@ std::shared_ptr<PaddlePredictor> AsyncAgent::_load_previous_model(std::string mo
...
@@ -155,10 +155,10 @@ std::shared_ptr<PaddlePredictor> AsyncAgent::_load_previous_model(std::string mo
return
predictor
;
return
predictor
;
}
}
std
::
shared_ptr
<
Async
Agent
>
Async
Agent
::
clone
()
{
std
::
shared_ptr
<
Async
ESAgent
>
AsyncES
Agent
::
clone
()
{
std
::
shared_ptr
<
PaddlePredictor
>
new_sampling_predictor
=
_predictor
->
Clone
();
std
::
shared_ptr
<
PaddlePredictor
>
new_sampling_predictor
=
_predictor
->
Clone
();
std
::
shared_ptr
<
Async
Agent
>
new_agent
=
std
::
make_shared
<
Async
Agent
>
();
std
::
shared_ptr
<
Async
ESAgent
>
new_agent
=
std
::
make_shared
<
AsyncES
Agent
>
();
float
*
noise
=
new
float
[
_param_size
];
float
*
noise
=
new
float
[
_param_size
];
...
@@ -175,7 +175,7 @@ std::shared_ptr<AsyncAgent> AsyncAgent::clone() {
...
@@ -175,7 +175,7 @@ std::shared_ptr<AsyncAgent> AsyncAgent::clone() {
return
new_agent
;
return
new_agent
;
}
}
bool
AsyncAgent
::
update
(
bool
Async
ES
Agent
::
update
(
std
::
vector
<
SamplingInfo
>&
noisy_info
,
std
::
vector
<
SamplingInfo
>&
noisy_info
,
std
::
vector
<
float
>&
noisy_rewards
)
{
std
::
vector
<
float
>&
noisy_rewards
)
{
...
@@ -237,7 +237,7 @@ bool AsyncAgent::update(
...
@@ -237,7 +237,7 @@ bool AsyncAgent::update(
return
true
;
return
true
;
}
}
int
AsyncAgent
::
_parse_model_iter_id
(
const
std
::
string
&
model_path
)
{
int
Async
ES
Agent
::
_parse_model_iter_id
(
const
std
::
string
&
model_path
)
{
int
model_iter_id
=
-
1
;
int
model_iter_id
=
-
1
;
int
pow
=
1
;
int
pow
=
1
;
for
(
int
i
=
model_path
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
model_path
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录