Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
225efa67
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
225efa67
编写于
3月 12, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove dims in base class
上级
2ea4a5d9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
12 addition
and
59 deletion
+12
-59
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+2
-18
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+1
-9
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+2
-24
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+3
-2
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+4
-6
未找到文件。
paddle/fluid/framework/operator.cc
浏览文件 @
225efa67
...
...
@@ -442,15 +442,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
ReaderHolder
>
())
{
return
var
->
Get
<
ReaderHolder
>
().
shapes
();
}
else
{
PADDLE_THROW
(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s."
,
name
,
var
->
Type
().
name
());
}
PADDLE_THROW
(
"Only compile time support this method"
);
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
...
...
@@ -467,15 +459,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
ReaderHolder
>
())
{
var
->
GetMutable
<
ReaderHolder
>
()
->
set_shapes
(
dims
);
}
else
{
PADDLE_THROW
(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s."
,
name
,
var
->
Type
().
name
());
}
PADDLE_THROW
(
"Only compile time support this method"
);
}
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
override
{
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
225efa67
...
...
@@ -16,14 +16,6 @@
namespace
paddle
{
namespace
framework
{
DDim
ReaderBase
::
shape
(
size_t
idx
)
const
{
PADDLE_ENFORCE_LT
(
idx
,
shapes_
.
size
(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements."
,
idx
,
shapes_
.
size
());
return
shapes_
[
idx
];
}
ReaderBase
::~
ReaderBase
()
{}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
225efa67
...
...
@@ -22,34 +22,18 @@ namespace framework {
class
ReaderBase
{
public:
explicit
ReaderBase
(
const
std
::
vector
<
DDim
>&
shapes
)
:
shapes_
(
shapes
)
{
PADDLE_ENFORCE
(
!
shapes_
.
empty
());
}
virtual
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
virtual
void
ReInit
()
=
0
;
DDim
shape
(
size_t
idx
)
const
;
std
::
vector
<
DDim
>
shapes
()
const
{
return
shapes_
;
}
void
set_shapes
(
const
std
::
vector
<
DDim
>&
shapes
)
{
shapes_
=
shapes
;
}
virtual
bool
HasNext
()
const
=
0
;
virtual
~
ReaderBase
()
{}
protected:
std
::
vector
<
DDim
>
shapes_
;
};
class
FileReader
:
public
ReaderBase
{
public:
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
shapes
)
:
ReaderBase
(
shapes
)
{}
virtual
~
ReaderBase
();
};
class
DecoratedReader
:
public
ReaderBase
{
public:
explicit
DecoratedReader
(
ReaderBase
*
reader
)
:
ReaderBase
(
reader
->
shapes
()),
reader_
(
reader
)
{
explicit
DecoratedReader
(
ReaderBase
*
reader
)
:
ReaderBase
(),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
...
...
@@ -72,12 +56,6 @@ class ReaderHolder {
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
{
reader_
->
ReInit
();
}
DDim
shape
(
size_t
idx
)
const
{
return
reader_
->
shape
(
idx
);
}
std
::
vector
<
DDim
>
shapes
()
const
{
return
reader_
->
shapes
();
}
void
set_shapes
(
const
std
::
vector
<
DDim
>&
shapes
)
{
reader_
->
set_shapes
(
shapes
);
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
private:
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
225efa67
...
...
@@ -19,11 +19,11 @@ namespace operators {
namespace
reader
{
template
<
typename
T
>
class
RandomDataGenerator
:
public
framework
::
FileReader
{
class
RandomDataGenerator
:
public
framework
::
ReaderBase
{
public:
RandomDataGenerator
(
const
std
::
vector
<
framework
::
DDim
>&
shapes
,
float
min
,
float
max
)
:
FileReader
(
shapes
),
min_
(
min
),
max_
(
max
)
{
:
framework
::
ReaderBase
(),
min_
(
min
),
max_
(
max
),
shapes_
(
shapes
)
{
PADDLE_ENFORCE_LE
(
min
,
max
,
"'min' shouldn't be greater than 'max'.(%f vs %f)"
,
min
,
max
);
unsigned
int
seed
=
std
::
random_device
()();
...
...
@@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
float
max_
;
std
::
minstd_rand
engine_
;
std
::
uniform_real_distribution
<
float
>
dist_
;
std
::
vector
<
framework
::
DDim
>
shapes_
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
225efa67
...
...
@@ -18,11 +18,10 @@
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
RecordIOFileReader
:
public
framework
::
FileReader
{
class
RecordIOFileReader
:
public
framework
::
ReaderBase
{
public:
RecordIOFileReader
(
const
std
::
string
&
filename
,
const
std
::
vector
<
framework
::
DDim
>&
shapes
)
:
FileReader
(
shapes
),
explicit
RecordIOFileReader
(
const
std
::
string
&
filename
)
:
ReaderBase
(),
scanner_
(
filename
),
dev_ctx_
(
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()))
{}
...
...
@@ -54,12 +53,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
int
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RecordIOFileReader
(
filename
,
shapes
));
out
->
Reset
(
new
RecordIOFileReader
(
filename
));
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录