Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
a8caedbe
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
67
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a8caedbe
编写于
4月 01, 2020
作者:
Z
zhoubo01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove depedence on predictor.clone()
上级
bc4c9c43
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
64 addition
and
58 deletion
+64
-58
deepes/demo/paddle/cartpole_async_solver.cc
deepes/demo/paddle/cartpole_async_solver.cc
+1
-16
deepes/demo/paddle/cartpole_solver_parallel.cc
deepes/demo/paddle/cartpole_solver_parallel.cc
+2
-16
deepes/demo/paddle/gen_cartpole_init_model.py
deepes/demo/paddle/gen_cartpole_init_model.py
+2
-0
deepes/include/paddle/async_es_agent.h
deepes/include/paddle/async_es_agent.h
+5
-3
deepes/include/paddle/es_agent.h
deepes/include/paddle/es_agent.h
+7
-6
deepes/include/utils.h
deepes/include/utils.h
+3
-0
deepes/src/paddle/async_es_agent.cc
deepes/src/paddle/async_es_agent.cc
+6
-5
deepes/src/paddle/es_agent.cc
deepes/src/paddle/es_agent.cc
+26
-12
deepes/src/utils.cc
deepes/src/utils.cc
+12
-0
未找到文件。
deepes/demo/paddle/cartpole_async_solver.cc
浏览文件 @
a8caedbe
...
@@ -24,20 +24,6 @@ using namespace paddle::lite_api;
...
@@ -24,20 +24,6 @@ using namespace paddle::lite_api;
const
int
ITER
=
10
;
const
int
ITER
=
10
;
std
::
shared_ptr
<
PaddlePredictor
>
create_paddle_predictor
(
const
std
::
string
&
model_dir
)
{
// 1. Create CxxConfig
CxxConfig
config
;
config
.
set_model_dir
(
model_dir
);
config
.
set_valid_places
({
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}
});
// 2. Create PaddlePredictor by CxxConfig
std
::
shared_ptr
<
PaddlePredictor
>
predictor
=
CreatePaddlePredictor
<
CxxConfig
>
(
config
);
return
predictor
;
}
// Use PaddlePredictor of CartPole model to predict the action.
// Use PaddlePredictor of CartPole model to predict the action.
std
::
vector
<
float
>
forward
(
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
const
float
*
obs
)
{
std
::
vector
<
float
>
forward
(
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
const
float
*
obs
)
{
std
::
unique_ptr
<
Tensor
>
input_tensor
(
std
::
move
(
predictor
->
GetInput
(
0
)));
std
::
unique_ptr
<
Tensor
>
input_tensor
(
std
::
move
(
predictor
->
GetInput
(
0
)));
...
@@ -86,8 +72,7 @@ int main(int argc, char* argv[]) {
...
@@ -86,8 +72,7 @@ int main(int argc, char* argv[]) {
envs
.
push_back
(
CartPole
());
envs
.
push_back
(
CartPole
());
}
}
std
::
shared_ptr
<
PaddlePredictor
>
paddle_predictor
=
create_paddle_predictor
(
"../demo/paddle/cartpole_init_model"
);
std
::
shared_ptr
<
AsyncESAgent
>
agent
=
std
::
make_shared
<
AsyncESAgent
>
(
"../demo/paddle/cartpole_init_model"
,
"../benchmark/cartpole_config.prototxt"
);
std
::
shared_ptr
<
AsyncESAgent
>
agent
=
std
::
make_shared
<
AsyncESAgent
>
(
paddle_predictor
,
"../benchmark/cartpole_config.prototxt"
);
// Clone agents to sample (explore).
// Clone agents to sample (explore).
std
::
vector
<
std
::
shared_ptr
<
AsyncESAgent
>
>
sampling_agents
;
std
::
vector
<
std
::
shared_ptr
<
AsyncESAgent
>
>
sampling_agents
;
...
...
deepes/demo/paddle/cartpole_solver_parallel.cc
浏览文件 @
a8caedbe
...
@@ -24,20 +24,6 @@ using namespace paddle::lite_api;
...
@@ -24,20 +24,6 @@ using namespace paddle::lite_api;
const
int
ITER
=
10
;
const
int
ITER
=
10
;
std
::
shared_ptr
<
PaddlePredictor
>
create_paddle_predictor
(
const
std
::
string
&
model_dir
)
{
// 1. Create CxxConfig
CxxConfig
config
;
config
.
set_model_dir
(
model_dir
);
config
.
set_valid_places
({
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}
});
// 2. Create PaddlePredictor by CxxConfig
std
::
shared_ptr
<
PaddlePredictor
>
predictor
=
CreatePaddlePredictor
<
CxxConfig
>
(
config
);
return
predictor
;
}
// Use PaddlePredictor of CartPole model to predict the action.
// Use PaddlePredictor of CartPole model to predict the action.
std
::
vector
<
float
>
forward
(
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
const
float
*
obs
)
{
std
::
vector
<
float
>
forward
(
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
const
float
*
obs
)
{
std
::
unique_ptr
<
Tensor
>
input_tensor
(
std
::
move
(
predictor
->
GetInput
(
0
)));
std
::
unique_ptr
<
Tensor
>
input_tensor
(
std
::
move
(
predictor
->
GetInput
(
0
)));
...
@@ -86,8 +72,8 @@ int main(int argc, char* argv[]) {
...
@@ -86,8 +72,8 @@ int main(int argc, char* argv[]) {
envs
.
push_back
(
CartPole
());
envs
.
push_back
(
CartPole
());
}
}
std
::
shared_ptr
<
PaddlePredictor
>
paddle_predictor
=
create_paddle_predictor
(
"../demo/paddle/cartpole_init_model"
);
//
std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
std
::
shared_ptr
<
ESAgent
>
agent
=
std
::
make_shared
<
ESAgent
>
(
paddle_predictor
,
"../benchmark/cartpole_config.prototxt"
);
std
::
shared_ptr
<
ESAgent
>
agent
=
std
::
make_shared
<
ESAgent
>
(
"../demo/paddle/cartpole_init_model"
,
"../benchmark/cartpole_config.prototxt"
);
// Clone agents to sample (explore).
// Clone agents to sample (explore).
std
::
vector
<
std
::
shared_ptr
<
ESAgent
>
>
sampling_agents
;
std
::
vector
<
std
::
shared_ptr
<
ESAgent
>
>
sampling_agents
;
...
...
deepes/demo/paddle/gen_cartpole_init_model.py
浏览文件 @
a8caedbe
...
@@ -36,4 +36,6 @@ if __name__ == '__main__':
...
@@ -36,4 +36,6 @@ if __name__ == '__main__':
dirname
=
'cartpole_init_model'
,
dirname
=
'cartpole_init_model'
,
feeded_var_names
=
[
'obs'
],
feeded_var_names
=
[
'obs'
],
target_vars
=
[
prob
],
target_vars
=
[
prob
],
params_filename
=
'param'
,
model_filename
=
'model'
,
executor
=
exe
)
executor
=
exe
)
deepes/include/paddle/async_es_agent.h
浏览文件 @
a8caedbe
...
@@ -28,7 +28,9 @@ namespace DeepES{
...
@@ -28,7 +28,9 @@ namespace DeepES{
*/
*/
class
AsyncESAgent
:
public
ESAgent
{
class
AsyncESAgent
:
public
ESAgent
{
public:
public:
AsyncESAgent
()
{}
AsyncESAgent
()
=
delete
;
AsyncESAgent
(
const
CxxConfig
&
cxx_config
);
~
AsyncESAgent
();
~
AsyncESAgent
();
...
@@ -40,8 +42,8 @@ class AsyncESAgent: public ESAgent {
...
@@ -40,8 +42,8 @@ class AsyncESAgent: public ESAgent {
* Please use the up-to-date configuration.
* Please use the up-to-date configuration.
*/
*/
AsyncESAgent
(
AsyncESAgent
(
std
::
shared_ptr
<
PaddlePredictor
>
predicto
r
,
const
std
::
string
&
model_di
r
,
std
::
string
config_path
);
const
std
::
string
&
config_path
);
/**
/**
* @brief: Clone an agent for sampling.
* @brief: Clone an agent for sampling.
...
...
deepes/include/paddle/es_agent.h
浏览文件 @
a8caedbe
...
@@ -38,14 +38,14 @@ int64_t ShapeProduction(const shape_t& shape);
...
@@ -38,14 +38,14 @@ int64_t ShapeProduction(const shape_t& shape);
*/
*/
class
ESAgent
{
class
ESAgent
{
public:
public:
ESAgent
();
ESAgent
()
=
delete
;
~
ESAgent
();
~
ESAgent
();
ESAgent
(
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
ESAgent
(
const
std
::
string
&
model_dir
,
const
std
::
string
&
config_path
);
std
::
string
config_path
);
ESAgent
(
const
CxxConfig
&
cxx_config
);
/**
/**
* @breif Clone a sampling agent
* @breif Clone a sampling agent
*
*
...
@@ -83,15 +83,16 @@ class ESAgent {
...
@@ -83,15 +83,16 @@ class ESAgent {
std
::
shared_ptr
<
PaddlePredictor
>
_predictor
;
std
::
shared_ptr
<
PaddlePredictor
>
_predictor
;
std
::
shared_ptr
<
PaddlePredictor
>
_sampling_predictor
;
std
::
shared_ptr
<
PaddlePredictor
>
_sampling_predictor
;
bool
_is_sampling_agent
;
std
::
shared_ptr
<
SamplingMethod
>
_sampling_method
;
std
::
shared_ptr
<
SamplingMethod
>
_sampling_method
;
std
::
shared_ptr
<
Optimizer
>
_optimizer
;
std
::
shared_ptr
<
Optimizer
>
_optimizer
;
std
::
shared_ptr
<
DeepESConfig
>
_config
;
std
::
shared_ptr
<
DeepESConfig
>
_config
;
int64_t
_param_size
;
std
::
shared_ptr
<
CxxConfig
>
_cxx_config
;
std
::
vector
<
std
::
string
>
_param_names
;
std
::
vector
<
std
::
string
>
_param_names
;
// malloc memory of noise and neg_gradients in advance.
// malloc memory of noise and neg_gradients in advance.
float
*
_noise
;
float
*
_noise
;
float
*
_neg_gradients
;
float
*
_neg_gradients
;
int64_t
_param_size
;
bool
_is_sampling_agent
;
};
};
}
}
...
...
deepes/include/utils.h
浏览文件 @
a8caedbe
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include <glog/logging.h>
#include <glog/logging.h>
#include "deepes.pb.h"
#include "deepes.pb.h"
#include <google/protobuf/text_format.h>
#include <google/protobuf/text_format.h>
#include <fstream>
namespace
DeepES
{
namespace
DeepES
{
...
@@ -29,6 +30,8 @@ namespace DeepES{
...
@@ -29,6 +30,8 @@ namespace DeepES{
*/
*/
bool
compute_centered_ranks
(
std
::
vector
<
float
>
&
reward
);
bool
compute_centered_ranks
(
std
::
vector
<
float
>
&
reward
);
std
::
string
read_file
(
const
std
::
string
&
filename
);
/* Load a protobuf-based configuration from the file.
/* Load a protobuf-based configuration from the file.
* Args:
* Args:
* config_file: file path.
* config_file: file path.
...
...
deepes/src/paddle/async_es_agent.cc
浏览文件 @
a8caedbe
...
@@ -16,8 +16,8 @@
...
@@ -16,8 +16,8 @@
namespace
DeepES
{
namespace
DeepES
{
AsyncESAgent
::
AsyncESAgent
(
AsyncESAgent
::
AsyncESAgent
(
std
::
shared_ptr
<
PaddlePredictor
>
predicto
r
,
const
std
::
string
&
model_di
r
,
std
::
string
config_path
)
:
ESAgent
(
predicto
r
,
config_path
)
{
const
std
::
string
&
config_path
)
:
ESAgent
(
model_di
r
,
config_path
)
{
_config_path
=
config_path
;
_config_path
=
config_path
;
}
}
AsyncESAgent
::~
AsyncESAgent
()
{
AsyncESAgent
::~
AsyncESAgent
()
{
...
@@ -154,15 +154,16 @@ std::shared_ptr<PaddlePredictor> AsyncESAgent::_load_previous_model(std::string
...
@@ -154,15 +154,16 @@ std::shared_ptr<PaddlePredictor> AsyncESAgent::_load_previous_model(std::string
return
predictor
;
return
predictor
;
}
}
AsyncESAgent
::
AsyncESAgent
(
const
CxxConfig
&
cxx_config
)
:
ESAgent
(
cxx_config
){
}
std
::
shared_ptr
<
AsyncESAgent
>
AsyncESAgent
::
clone
()
{
std
::
shared_ptr
<
AsyncESAgent
>
AsyncESAgent
::
clone
()
{
std
::
shared_ptr
<
PaddlePredictor
>
new_sampling_predictor
=
_predictor
->
Clone
();
std
::
shared_ptr
<
AsyncESAgent
>
new_agent
=
std
::
make_shared
<
AsyncESAgent
>
();
std
::
shared_ptr
<
AsyncESAgent
>
new_agent
=
std
::
make_shared
<
AsyncESAgent
>
(
*
_cxx_config
);
float
*
noise
=
new
float
[
_param_size
];
float
*
noise
=
new
float
[
_param_size
];
new_agent
->
_predictor
=
_predictor
;
new_agent
->
_predictor
=
_predictor
;
new_agent
->
_sampling_predictor
=
new_sampling_predictor
;
new_agent
->
_is_sampling_agent
=
true
;
new_agent
->
_is_sampling_agent
=
true
;
new_agent
->
_sampling_method
=
_sampling_method
;
new_agent
->
_sampling_method
=
_sampling_method
;
...
...
deepes/src/paddle/es_agent.cc
浏览文件 @
a8caedbe
...
@@ -23,22 +23,31 @@ int64_t ShapeProduction(const shape_t& shape) {
...
@@ -23,22 +23,31 @@ int64_t ShapeProduction(const shape_t& shape) {
return
res
;
return
res
;
}
}
ESAgent
::
ESAgent
()
{}
ESAgent
::~
ESAgent
()
{
ESAgent
::~
ESAgent
()
{
delete
[]
_noise
;
delete
[]
_noise
;
if
(
!
_is_sampling_agent
)
if
(
!
_is_sampling_agent
)
delete
[]
_neg_gradients
;
delete
[]
_neg_gradients
;
}
}
ESAgent
::
ESAgent
(
ESAgent
::
ESAgent
(
const
std
::
string
&
model_dir
,
const
std
::
string
&
config_path
)
{
std
::
shared_ptr
<
PaddlePredictor
>
predictor
,
// 1. Create CxxConfig
std
::
string
config_path
)
{
_cxx_config
=
std
::
make_shared
<
CxxConfig
>
();
std
::
string
model_path
=
model_dir
+
"/model"
;
std
::
string
param_path
=
model_dir
+
"/param"
;
std
::
string
model_buffer
=
read_file
(
model_path
);
std
::
string
param_buffer
=
read_file
(
param_path
);
_cxx_config
->
set_model_buffer
(
model_buffer
.
c_str
(),
model_buffer
.
size
(),
param_buffer
.
c_str
(),
param_buffer
.
size
());
_cxx_config
->
set_valid_places
({
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}
});
_predictor
=
CreatePaddlePredictor
<
CxxConfig
>
(
*
_cxx_config
);
_is_sampling_agent
=
false
;
_is_sampling_agent
=
false
;
_predictor
=
predictor
;
// Original agent can't be used to sample, so keep it same with _predictor for evaluating.
// Original agent can't be used to sample, so keep it same with _predictor for evaluating.
_sampling_predictor
=
predictor
;
_sampling_predictor
=
_
predictor
;
_config
=
std
::
make_shared
<
DeepESConfig
>
();
_config
=
std
::
make_shared
<
DeepESConfig
>
();
load_proto_conf
(
config_path
,
*
_config
);
load_proto_conf
(
config_path
,
*
_config
);
...
@@ -55,16 +64,21 @@ ESAgent::ESAgent(
...
@@ -55,16 +64,21 @@ ESAgent::ESAgent(
_neg_gradients
=
new
float
[
_param_size
];
_neg_gradients
=
new
float
[
_param_size
];
}
}
std
::
shared_ptr
<
ESAgent
>
ESAgent
::
clone
()
{
ESAgent
::
ESAgent
(
const
CxxConfig
&
cxx_config
)
{
std
::
shared_ptr
<
PaddlePredictor
>
new_sampling_predictor
=
_predictor
->
Clone
();
_sampling_predictor
=
CreatePaddlePredictor
<
CxxConfig
>
(
cxx_config
);
}
std
::
shared_ptr
<
ESAgent
>
new_agent
=
std
::
make_shared
<
ESAgent
>
();
std
::
shared_ptr
<
ESAgent
>
ESAgent
::
clone
()
{
if
(
_is_sampling_agent
)
{
LOG
(
ERROR
)
<<
"[DeepES] only original ESAgent can call `clone` function."
;
return
nullptr
;
}
std
::
shared_ptr
<
ESAgent
>
new_agent
=
std
::
make_shared
<
ESAgent
>
(
*
_cxx_config
);
float
*
noise
=
new
float
[
_param_size
];
float
*
noise
=
new
float
[
_param_size
];
new_agent
->
_predictor
=
_predictor
;
new_agent
->
_predictor
=
_predictor
;
new_agent
->
_sampling_predictor
=
new_sampling_predictor
;
new_agent
->
_cxx_config
=
_cxx_config
;
new_agent
->
_is_sampling_agent
=
true
;
new_agent
->
_is_sampling_agent
=
true
;
new_agent
->
_sampling_method
=
_sampling_method
;
new_agent
->
_sampling_method
=
_sampling_method
;
new_agent
->
_param_names
=
_param_names
;
new_agent
->
_param_names
=
_param_names
;
...
...
deepes/src/utils.cc
浏览文件 @
a8caedbe
...
@@ -52,4 +52,16 @@ std::vector<std::string> list_all_model_dirs(std::string path) {
...
@@ -52,4 +52,16 @@ std::vector<std::string> list_all_model_dirs(std::string path) {
return
model_dirs
;
return
model_dirs
;
}
}
std
::
string
read_file
(
const
std
::
string
&
filename
)
{
std
::
ifstream
ifile
(
filename
.
c_str
());
if
(
!
ifile
.
is_open
())
{
LOG
(
FATAL
)
<<
"Open file: ["
<<
filename
<<
"] failed."
;
}
std
::
ostringstream
buf
;
char
ch
;
while
(
buf
&&
ifile
.
get
(
ch
))
buf
.
put
(
ch
);
ifile
.
close
();
return
buf
.
str
();
}
}
//namespace
}
//namespace
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录