Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
2f6d1e10
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
67
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2f6d1e10
编写于
3月 31, 2020
作者:
Z
zhoubo01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix comments
上级
752974cb
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
19 addition
and
19 deletion
+19
-19
deepes/benchmark/cartpole_config.prototxt
deepes/benchmark/cartpole_config.prototxt
+1
-1
deepes/include/paddle/async_es_agent.h
deepes/include/paddle/async_es_agent.h
+6
-3
deepes/include/paddle/es_agent.h
deepes/include/paddle/es_agent.h
+1
-3
deepes/include/utils.h
deepes/include/utils.h
+4
-3
deepes/scripts/build.sh
deepes/scripts/build.sh
+0
-2
deepes/src/paddle/async_es_agent.cc
deepes/src/paddle/async_es_agent.cc
+7
-4
deepes/src/paddle/es_agent.cc
deepes/src/paddle/es_agent.cc
+0
-3
未找到文件。
deepes/benchmark/cartpole_config.prototxt
浏览文件 @
2f6d1e10
...
...
@@ -11,5 +11,5 @@ optimizer {
epsilon: 1e-08
}
async_es {
model_iter_id:
0
model_iter_id:
99
}
deepes/include/paddle/async_es_agent.h
浏览文件 @
2f6d1e10
...
...
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef
_
ASYNC_ES_AGENT_H
#define
_
ASYNC_ES_AGENT_H
#ifndef ASYNC_ES_AGENT_H
#define ASYNC_ES_AGENT_H
#include "es_agent.h"
#include <map>
...
...
@@ -49,7 +49,10 @@ class AsyncESAgent: public ESAgent {
std
::
shared_ptr
<
AsyncESAgent
>
clone
();
/**
* @brief: Clone an agent for sampling.
* @brief: update parameters given data collected during evaluation.
* @args:
* noisy_info: sampling information returned by add_noise function.
* noisy_reward: evaluation rewards.
*/
bool
update
(
std
::
vector
<
SamplingInfo
>&
noisy_info
,
...
...
deepes/include/paddle/es_agent.h
浏览文件 @
2f6d1e10
...
...
@@ -21,15 +21,13 @@
#include "gaussian_sampling.h"
#include "deepes.pb.h"
#include <vector>
using
namespace
paddle
::
lite_api
;
using
namespace
paddle
::
lite_api
;
namespace
DeepES
{
int64_t
ShapeProduction
(
const
shape_t
&
shape
);
typedef
paddle
::
lite_api
::
PaddlePredictor
PaddlePredictor
;
/**
* @brief DeepES agent with PaddleLite as backend.
* Users mainly focus on the following functions:
...
...
deepes/include/utils.h
浏览文件 @
2f6d1e10
...
...
@@ -40,7 +40,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
bool
success
=
true
;
std
::
ifstream
fin
(
config_file
);
if
(
!
fin
||
fin
.
fail
())
{
LOG
(
FATAL
)
<<
"open prototxt config failed: "
<<
config_file
;
LOG
(
ERROR
)
<<
"open prototxt config failed: "
<<
config_file
;
success
=
false
;
}
else
{
fin
.
seekg
(
0
,
std
::
ios
::
end
);
...
...
@@ -52,7 +52,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
std
::
string
proto_str
(
file_content_buffer
,
file_size
);
if
(
!
google
::
protobuf
::
TextFormat
::
ParseFromString
(
proto_str
,
&
proto_config
))
{
LOG
(
FATAL
)
<<
"Failed to load config: "
<<
config_file
;
LOG
(
ERROR
)
<<
"Failed to load config: "
<<
config_file
;
success
=
false
;
}
delete
[]
file_content_buffer
;
...
...
@@ -66,7 +66,7 @@ bool save_proto_conf(const std::string& config_file, T&proto_config) {
bool
success
=
true
;
std
::
ofstream
ofs
(
config_file
,
std
::
ofstream
::
out
);
if
(
!
ofs
||
ofs
.
fail
())
{
LOG
(
FATAL
)
<<
"open prototxt config failed: "
<<
config_file
;
LOG
(
ERROR
)
<<
"open prototxt config failed: "
<<
config_file
;
success
=
false
;
}
else
{
std
::
string
config_str
;
...
...
@@ -76,6 +76,7 @@ bool save_proto_conf(const std::string& config_file, T&proto_config) {
}
ofs
<<
config_str
;
}
return
success
;
}
std
::
vector
<
std
::
string
>
list_all_model_dirs
(
std
::
string
path
);
...
...
deepes/scripts/build.sh
浏览文件 @
2f6d1e10
...
...
@@ -32,8 +32,6 @@ else
exit
0
fi
#export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
#----------------protobuf-------------#
cp
./src/proto/deepes.proto ./
protoc deepes.proto
--cpp_out
./
...
...
deepes/src/paddle/async_es_agent.cc
浏览文件 @
2f6d1e10
...
...
@@ -30,7 +30,7 @@ AsyncESAgent::~AsyncESAgent() {
bool
AsyncESAgent
::
_save
()
{
bool
success
=
true
;
if
(
_is_sampling_agent
)
{
LOG
(
ERROR
)
<<
"[DeepES] Original AsyncESAgent cannot call
add_noise function, p
lease use cloned AsyncESAgent."
;
LOG
(
ERROR
)
<<
"[DeepES] Original AsyncESAgent cannot call
`save`.P
lease use cloned AsyncESAgent."
;
success
=
false
;
return
success
;
}
...
...
@@ -49,7 +49,7 @@ bool AsyncESAgent::_save() {
model_name
=
"model_iter_id-"
+
std
::
to_string
(
model_iter_id
);
std
::
string
model_path
=
_config
->
async_es
().
model_warehouse
()
+
"/"
+
model_name
;
LOG
(
INFO
)
<<
"[save]model_path: "
<<
model_path
;
_predictor
->
SaveOptimizedModel
(
model_path
,
LiteModelType
::
kProtobuf
);
_predictor
->
SaveOptimizedModel
(
model_path
,
paddle
::
lite_api
::
LiteModelType
::
kProtobuf
);
// save config
auto
async_es
=
_config
->
mutable_async_es
();
async_es
->
set_model_iter_id
(
model_iter_id
);
...
...
@@ -93,15 +93,17 @@ bool AsyncESAgent::_compute_model_diff() {
std
::
shared_ptr
<
PaddlePredictor
>
old_predictor
=
kv
.
second
;
float
*
diff
=
new
float
[
_param_size
];
memset
(
diff
,
0
,
_param_size
*
sizeof
(
float
));
for
(
std
::
string
param_name
:
_param_names
)
{
int
offset
=
0
;
for
(
const
std
::
string
&
param_name
:
_param_names
)
{
auto
des_tensor
=
old_predictor
->
GetTensor
(
param_name
);
auto
src_tensor
=
_predictor
->
GetTensor
(
param_name
);
const
float
*
des_data
=
des_tensor
->
data
<
float
>
();
const
float
*
src_data
=
src_tensor
->
data
<
float
>
();
int64_t
tensor_size
=
ShapeProduction
(
src_tensor
->
shape
());
for
(
int
i
=
0
;
i
<
tensor_size
;
++
i
)
{
diff
[
i
]
=
des_data
[
i
]
-
src_data
[
i
];
diff
[
i
+
offset
]
=
des_data
[
i
]
-
src_data
[
i
];
}
offset
+=
tensor_size
;
}
_param_delta
[
model_iter_id
]
=
diff
;
}
...
...
@@ -206,6 +208,7 @@ bool AsyncESAgent::update(
float
reward
=
noisy_rewards
[
i
];
int
model_iter_id
=
noisy_info
[
i
].
model_iter_id
();
bool
success
=
_sampling_method
->
resampling
(
key
,
_noise
,
_param_size
);
CHECK
(
success
)
<<
"[DeepES] resampling error occurs at sample: "
<<
i
;
float
*
delta
=
_param_delta
[
model_iter_id
];
// compute neg_gradients
if
(
model_iter_id
==
current_model_iter_id
)
{
...
...
deepes/src/paddle/es_agent.cc
浏览文件 @
2f6d1e10
...
...
@@ -17,9 +17,6 @@
namespace
DeepES
{
typedef
paddle
::
lite_api
::
Tensor
Tensor
;
typedef
paddle
::
lite_api
::
shape_t
shape_t
;
int64_t
ShapeProduction
(
const
shape_t
&
shape
)
{
int64_t
res
=
1
;
for
(
auto
i
:
shape
)
res
*=
i
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录