Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
147fe8cb
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
147fe8cb
编写于
8月 01, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add dataset
上级
8b7e1ed1
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
79 addition
and
10 deletion
+79
-10
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+75
-0
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
...uid/train/custom_trainer/feed/dataset/dataset_container.h
+3
-10
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
+1
-0
未找到文件。
paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
0 → 100644
浏览文件 @
147fe8cb
/* DataReader
* 对指定数据的读取
*/
#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
class
TrainerContext
;
struct
FeatureItem
{
uint64_t
feature_sign
;
uint16_t
slot_id
;
};
struct
SampleInstance
{
std
::
string
id
;
std
::
vector
<
float
>
lables
;
std
::
vector
<
FeatureItem
>
features
;
std
::
vector
<
float
>
embedx
;
};
class
DataItem
{
public:
DataItem
()
{}
virtual
~
DataItem
()
{}
std
::
string
id
;
//样本id标识,可用于shuffle
std
::
string
data
;
//样本数据, maybe压缩格式
};
class
DataParser
{
public:
DataParser
()
{}
virtual
~
DataParser
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
virtual
int
parse
(
const
std
::
string
&
str
,
DataItem
&
data
)
const
{
return
parse
(
str
.
c_str
(),
str
.
size
(),
data
);
}
virtual
int
parse
(
const
char
*
str
,
size_t
len
,
DataItem
&
data
)
const
=
0
;
virtual
int
parse_to_sample
(
const
DataItem
&
data
,
SampleInstance
&
instance
)
const
=
0
;
};
REGISTER_REGISTERER
(
DataParser
);
class
DataReader
{
public:
DataReader
()
{}
virtual
~
DataReader
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
,
std
::
shared_ptr
<
TrainerContext
>
context
)
=
0
;
//判断样本数据是否已就绪,就绪表明可以开始download
virtual
bool
is_data_ready
(
const
std
::
string
&
data_dir
)
=
0
;
//读取数据样本流中
virtual
int
read_all
(
const
std
::
string
&
data_dir
,
::
paddle
::
framework
::
Channel
<
DataItem
>
data_channel
)
=
0
;
virtual
const
DataParser
*
get_parser
()
{
return
_parser
.
get
();
}
private:
std
::
shared_ptr
<
DataParser
>
_parser
;
std
::
string
_data_converter_shell
;
};
REGISTER_REGISTERER
(
DataReader
);
//TODO
//HDFS/DISK Reader
}
//namespace feed
}
//namespace custom_trainer
}
//namespace paddle
paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
浏览文件 @
147fe8cb
...
...
@@ -9,27 +9,20 @@
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace
paddle
{
namespace
custom_trainer
{
namespace
feed
{
//单条样本的原始数据
class
DataItem
{
public:
DataItem
()
{}
virtual
~
DataItem
()
{}
std
::
string
id
;
//样本id标识,可用于shuffle
std
::
string
data
;
//样本完整数据
};
class
DatasetContainer
{
public:
DatasetContainer
()
{}
virtual
~
DatasetContainer
()
{}
virtual
int
initialize
(
const
YAML
::
Node
&
config
)
{
_dataset_config
=
config
;
//预取n轮样本数据
_prefetch_num
=
config
[
"prefetch_num"
].
as
<
int
>
();
_data_root_path
=
config
[
"root_path"
].
as
<
std
::
string
>
();
_data_path_generater
=
config
[
"_data_path_generater"
].
as
<
std
::
string
>
();
...
...
@@ -54,7 +47,7 @@ protected:
uint32_t
_current_dataset_idx
;
//当前样本数据idx
int
_current_epoch_id
=
-
1
;
int
_ready_epoch_id
=
-
1
;
//已下载完成的epoch_id
std
::
vector
<
std
::
shared_ptr
<::
paddle
::
framework
::
Dataset
>>
_dataset_list
;
std
::
vector
<
std
::
shared_ptr
<::
paddle
::
framework
::
Dataset
>>
_dataset_list
;
//预取的数据列表
};
}
//namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
浏览文件 @
147fe8cb
...
...
@@ -123,6 +123,7 @@ int SimpleExecutor::run() {
}
return
0
;
}
REGISTER_CLASS
(
Executor
,
SimpleExecutor
);
}
// namespace feed
}
// namespace custom_trainer
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录