Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
0698534b
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
68
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0698534b
编写于
4月 27, 2020
作者:
X
xiaoyao4573
提交者:
GitHub
4月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update docs (#253)
* update docs * update docs * update * update * update * update * update * update
上级
0a068653
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
167 addition
and
88 deletion
+167
-88
docs/EvoKit/minimal_example.rst
docs/EvoKit/minimal_example.rst
+89
-66
docs/EvoKit/online_example.rst
docs/EvoKit/online_example.rst
+78
-22
未找到文件。
docs/EvoKit/minimal_example.rst
浏览文件 @
0698534b
...
@@ -37,7 +37,7 @@ step1: 生成预测网络
...
@@ -37,7 +37,7 @@ step1: 生成预测网络
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
exe.run(fluid.default_startup_program())
fluid.io.save_inference_model(
fluid.io.save_inference_model(
dirname='
cartpole_
init_model',
dirname='init_model',
feeded_var_names=['obs'],
feeded_var_names=['obs'],
target_vars=[prob],
target_vars=[prob],
params_filename='params',
params_filename='params',
...
@@ -46,51 +46,70 @@ step1: 生成预测网络
...
@@ -46,51 +46,70 @@ step1: 生成预测网络
step2: 构造ESAgent
step2: 构造ESAgent
###################
###################
- 根据配置文件构造一个ESAgent
- 调用 ``load_inference_model`` 函数加载模型参数
- 调用 ``load_config`` 加载配置文件。
- 调用 ``load_inference_model`` 函数加载模型参数。
- 调用 ``init_solver`` 初始化solver。
配置文件主要是用于指定进化算法类型(比如Gaussian或者CMA),使用的optimizer类型(Adam或者SGD)。
配置文件主要是用于指定进化算法类型(比如Gaussian或者CMA),使用的optimizer类型(Adam或者SGD)。
.. code-block:: c++
.. code-block:: c++
ESAgent agent = ESAgent(config_path);
ESAgent agent = ESAgent();
agent->load_inference_model(model_dir);
agent.load_config(config);
agent.load_inference_model(model_dir);
agent.init_solver();
// 附:EvoKit配置项示范
solver {
type: BASIC_ES
optimizer { // 线下Adam更新
type: ADAM
base_lr: 0.05
adam {
beta1: 0.9
beta2: 0.999
epsilon: 1e-08
}
}
sampling { // 线上高斯采样
type: GAUSSIAN_SAMPLING
gaussian_sampling {
std: 0.5
cached: true
seed: 1024
cache_size : 100000
}
}
}
//附:DeepES配置项示范
seed: 1024 //随机种子,用于复现
gaussian_sampling { //高斯采样相关参数
std: 0.5
}
optimizer { //离线更新所用的optimizer 类型以及相关超级参数
type: "Adam"
base_lr: 0.05
}
step3: 生成用于采样的Agent
step3: 生成用于采样的Agent
###################
###################
主要关注三个接口:
主要关注三个接口:
-
clone():
生成一个用于sampling的agent。
-
调用 ``clone``
生成一个用于sampling的agent。
-
add_noise():
给这个agent的参数空间增加噪声,同时返回该噪声对应的唯一信息,这个信息得记录在log中,用于线下更新。
-
调用 ``add_noise``
给这个agent的参数空间增加噪声,同时返回该噪声对应的唯一信息,这个信息得记录在log中,用于线下更新。
-
predict():
提供预测接口。
-
调用 ``predict``
提供预测接口。
.. code-block:: c++
.. code-block:: c++
auto sampling_agent = agent.clone();
auto sampling_agent = agent.clone();
auto sampling_info = sampling_agent.add_noise();
auto sampling_info = sampling_agent.add_noise();
sampling_agent.predict(feature);
sampling_agent.predict(feature);
step4: 用采样的数据更新模型参数
step4: 用采样的数据更新模型参数
###################
###################
用户提供两组数据:
用户提供两组数据:
- 采样参数过程中用于线下复现采样噪声的key
- 采样参数过程中用于线下复现采样噪声的sampling_info
- 扰动参数后,新参数的评估结果
- 扰动参数后,新参数的评估结果
.. code-block:: c++
.. code-block:: c++
agent.update(info
, rewards);
agent.update(sampling_infos
, rewards);
主代码以及注释
主代码以及注释
#################
#################
...
@@ -99,55 +118,59 @@ step4: 用采样的数据更新模型参数
...
@@ -99,55 +118,59 @@ step4: 用采样的数据更新模型参数
.. code-block:: c++
.. code-block:: c++
int main(int argc, char* argv[]) {
int main(int argc, char* argv[]) {
std::vector<CartPole> envs;
std::vector<CartPole> envs;
// 构造10个环境,用于多线程训练
// 构造10个环境,用于多线程训练
for (int i = 0; i < ITER; ++i) {
for (int i = 0; i < ITER; ++i) {
envs.push_back(CartPole());
envs.push_back(CartPole());
}
}
// 初始化ESAgent
// 初始化ESAgent
std::string model_dir = "./demo/cartpole_init_model";
std::string model_dir = "./demo/cartpole/init_model";
std::string config_path = "./demo/cartpole_config.prototxt";
std::string config_path = "./demo/cartpole/config.prototxt";
std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>(model_dir, config_path);
std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>();
agent->load_config(config_path); // 加载配置
// 生成10个agent用于同时采样
std::vector< std::shared_ptr<ESAgent> > sampling_agents;
agent->load_inference_model(FLAGS_model_dir); // 加载初始预测模型
for (int i = 0; i < ITER; ++i) {
agent->init_solver(); // 初始化solver,注意要在load_inference_model后执行
sampling_agents.push_back(agent->clone());
}
// 生成10个agent用于同时采样
std::vector<std::shared_ptr<ESAgent>> sampling_agents;
std::vector<SamplingInfo> noisy_keys;
for (int i = 0; i < ITER; ++i) {
std::vector<float> noisy_rewards(ITER, 0.0f);
sampling_agents.push_back(agent->clone());
noisy_keys.resize(ITER);
}
omp_set_num_threads(10);
std::vector<SamplingInfo> sampling_infos;
// 共迭代100轮
std::vector<float> rewards(ITER, 0.0f);
for (int epoch = 0; epoch < 100; ++epoch) {
sampling_infos.resize(ITER);
#pragma omp parallel for schedule(dynamic, 1)
omp_set_num_threads(10);
for (int i = 0; i < ITER; ++i) {
std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
// 共迭代100轮
SamplingInfo key;
for (int epoch = 0; epoch < 100; ++epoch) {
bool success = sampling_agent->add_noise(key);
#pragma omp parallel for schedule(dynamic, 1)
float reward = evaluate(envs[i], sampling_agent);
for (int i = 0; i < ITER; ++i) {
// 保存采样的key以及对应的评估结果
std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
noisy_keys[i] = key;
SamplingInfo sampling_info;
noisy_rewards[i] = reward;
sampling_agent->add_noise(sampling_info);
}
float reward = evaluate(envs[i], sampling_agent);
// 更新模型参数,注意:参数更新后会自动同步到sampling_agent中
// 保存采样的sampling_info以及对应的评估结果reward
bool success = agent->update(noisy_keys, noisy_rewards);
sampling_infos[i] = sampling_info;
rewards[i] = reward;
int reward = evaluate(envs[0], agent);
}
LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
// 更新模型参数,注意:参数更新后会自动同步到sampling_agent中
}
agent->update(sampling_infos, rewards);
}
int reward = evaluate(envs[0], agent);
LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward; // 打印每一轮reward
}
}
如何运行demo
如何运行demo
#################
#################
- 下载代码
- 下载代码
在icode上clone代码,我们的仓库路径是: ``baidu/nlp/deep-es``
在icode上clone代码,我们的仓库路径是: ``baidu/nlp/deep-es``
``TO DO: 修改库路径``
- 编译demo
- 编译demo
...
@@ -159,7 +182,7 @@ step4: 用采样的数据更新模型参数
...
@@ -159,7 +182,7 @@ step4: 用采样的数据更新模型参数
``export LD_LIBRARY_PATH=./output/so/:$LD_LIBRARY_PATH``
``export LD_LIBRARY_PATH=./output/so/:$LD_LIBRARY_PATH``
运行demo: ``./output/bin/
evokit_demo
``
运行demo: ``./output/bin/
cartpole/train
``
问题解决
问题解决
####################
####################
...
...
docs/EvoKit/online_example.rst
浏览文件 @
0698534b
...
@@ -3,66 +3,122 @@ Example for Online Products
...
@@ -3,66 +3,122 @@ Example for Online Products
``本教程的目标: 演示通过EvoKit库上线后,如何迭代算法,更新模型参数。``
``本教程的目标: 演示通过EvoKit库上线后,如何迭代算法,更新模型参数。``
在
实际的
产品线中,线上无法实时拿到用户日志,经常是通过保存用户点击/时长日志,在线下根据用户数据更新模型,然后再推送到线上,完成算法的更新。
在产品线中,线上无法实时拿到用户日志,经常是通过保存用户点击/时长日志,在线下根据用户数据更新模型,然后再推送到线上,完成算法的更新。
本教程继续围绕经典的CartPole环境,展示如何通过在线采样/离线更新的方式,来更新迭代ES算法。
本教程继续围绕经典的CartPole环境,展示如何通过在线采样/离线更新的方式,来更新迭代ES算法。
demo的完整代码示例放在demp/online_example文件夹中。
demo的完整代码示例放在demp/online_example文件夹中。
``TO DO: 文件夹``
线上采样
初始化solver
---------------------
---------------------
构造solver,对它初始化,并保存到文件。初始化solver仅需在开始时调用一次。
.. code-block:: c++
这部分的逻辑与上一个demo极度相似,主要的区别是采样返回的key以及评估的reward通过二进制的方式记录到log文件中。
std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>();
agent->load_config(FLAGS_config_path);
agent->load_inference_model(FLAGS_model_dir);
agent->init_solver();
agent->save_solver(FLAGS_model_dir);
线上采样
---------------------
加载模型和solver,记录线上采样返回的sampling_info以及评估的reward,并通过二进制的方式记录到log文件中。
.. code-block:: c++
.. code-block:: c++
std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>();
agent->load_config(FLAGS_config_path);
agent->load_inference_model(FLAGS_model_dir);
agent->load_solver(FLAGS_model_dir);
#pragma omp parallel for schedule(dynamic, 1)
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
for (int i = 0; i < ITER; ++i) {
std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
SamplingInfo
key
;
SamplingInfo
sampling_info
;
bool success = sampling_agent->add_noise(key
);
sampling_agent->add_noise(sampling_info
);
float reward = evaluate(envs[i], sampling_agent);
float reward = evaluate(envs[i], sampling_agent);
noisy_keys[i] = key
;
sampling_infos[i] = sampling_info
;
noisy_
rewards[i] = reward;
rewards[i] = reward;
}
}
// save sampling information and log in binary fomrat
// save sampling information and log in binary fomrat
std::ofstream log_stream(FLAGS_log_path, std::ios::binary);
std::ofstream log_stream(FLAGS_log_path, std::ios::binary);
for (int i = 0; i < ITER; ++i) {
for (int i = 0; i < ITER; ++i) {
std::string data;
std::string data;
noisy_key
s[i].SerializeToString(&data);
sampling_info
s[i].SerializeToString(&data);
int size = data.size();
int size = data.size();
log_stream.write((char*) &noisy_
rewards[i], sizeof(float));
log_stream.write((char*) &
rewards[i], sizeof(float));
log_stream.write((char*) &size, sizeof(int));
log_stream.write((char*) &size, sizeof(int));
log_stream.write(data.c_str(), size);
log_stream.write(data.c_str(), size);
}
}
log_stream.close();
log_stream.close();
线下更新
线下更新
-----------------------
-----------------------
在加载好之前记录的log之后,调用 ``update`` 函数进行更新,然后通过 ``save_inference_model`` 和 ``save_solver`` 函数保存更新后的参数到本地,推送到线上。
在加载好之前记录的log之后,调用 ``update`` 函数进行更新,然后通过 ``save_inference_model`` 函数保存更新后的参数到本地,推送到线上。
.. code-block:: c++
.. code-block:: c++
std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>();
agent->load_config(FLAGS_config_path);
agent->load_inference_model(FLAGS_model_dir);
agent->load_solver(FLAGS_model_dir);
// load training data
// load training data
std::vector<SamplingInfo>
noisy_key
s;
std::vector<SamplingInfo>
sampling_info
s;
std::vector<float>
noisy_
rewards(ITER, 0.0f);
std::vector<float> rewards(ITER, 0.0f);
noisy_key
s.resize(ITER);
sampling_info
s.resize(ITER);
std::ifstream log_stream(FLAGS_log_path);
std::ifstream log_stream(FLAGS_log_path);
CHECK(log_stream.good()) << "[EvoKit] cannot open log: " << FLAGS_log_path;
CHECK(log_stream.good()) << "[EvoKit] cannot open log: " << FLAGS_log_path;
char buffer[1000];
char buffer[1000];
for (int i = 0; i < ITER; ++i) {
for (int i = 0; i < ITER; ++i) {
int size;
int size;
log_stream.read((char*) &
noisy_
rewards[i], sizeof(float));
log_stream.read((char*) &rewards[i], sizeof(float));
log_stream.read((char*) &size, sizeof(int));
log_stream.read((char*) &size, sizeof(int));
log_stream.read(buffer, size);
log_stream.read(buffer, size);
buffer[size] = 0;
buffer[size] = 0;
std::string data(buffer);
std::string data(buffer);
noisy_key
s[i].ParseFromString(data);
sampling_info
s[i].ParseFromString(data);
}
}
// update model and save parameter
// update model and save parameter
agent->update(
noisy_keys, noisy_
rewards);
agent->update(
sampling_infos,
rewards);
agent->save_inference_model(FLAGS_updated_model_dir);
agent->save_inference_model(FLAGS_updated_model_dir);
agent->save_solver(FLAGS_updated_model_dir);
主代码
-----------------------
将以上代码分别编译成可执行文件。
- 初始化solver: ``init_solver`` 。
- 线上采样: ``online_sampling`` 。
- 线下更新: ``offline update`` 。
.. code-block:: shell
#------------------------init solver------------------------
./init_solver \
--model_dir="./model_warehouse/model_dir_0" \
--config_path="config.prototxt"
for ((epoch=0;epoch<200;++epoch));do
#------------------------online sampling------------------------
./online_sampling \
--log_path="./sampling_log" \
--model_dir="./model_warehouse/model_dir_$epoch" \
--config_path="./config.prototxt"
#------------------------offline update------------------------
next_epoch=$((epoch+1))
./offline_update \
--log_path='./sampling_log' \
--model_dir="./model_warehouse/model_dir_$epoch" \
--updated_model_dir="./model_warehouse/model_dir_${next_epoch}" \
--config_path="./config.prototxt"
done
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录