Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6f0dfd89
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6f0dfd89
编写于
3月 16, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Single GPU ParallelExecutor complete
上级
d84ddcf1
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
173 addition
and
34 deletion
+173
-34
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/external/threadpool.cmake
cmake/external/threadpool.cmake
+30
-0
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+133
-32
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+4
-0
paddle/fluid/operators/read_op.cc
paddle/fluid/operators/read_op.cc
+4
-1
未找到文件。
CMakeLists.txt
浏览文件 @
6f0dfd89
...
@@ -146,6 +146,7 @@ include(external/cares)
...
@@ -146,6 +146,7 @@ include(external/cares)
include
(
external/grpc
)
include
(
external/grpc
)
include
(
external/snappy
)
# download snappy
include
(
external/snappy
)
# download snappy
include
(
external/snappystream
)
include
(
external/snappystream
)
include
(
external/threadpool
)
include
(
cudnn
)
# set cudnn libraries, must before configure
include
(
cudnn
)
# set cudnn libraries, must before configure
include
(
cupti
)
include
(
cupti
)
...
...
cmake/external/threadpool.cmake
0 → 100644
浏览文件 @
6f0dfd89
INCLUDE
(
ExternalProject
)
SET
(
THREADPOOL_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/threadpool
)
SET
(
THREADPOOL_INCLUDE_DIR
${
THREADPOOL_SOURCE_DIR
}
/src/extern_threadpool
)
INCLUDE_DIRECTORIES
(
${
THREADPOOL_INCLUDE_DIR
}
)
ExternalProject_Add
(
extern_threadpool
${
EXTERNAL_PROJECT_LOG_ARGS
}
GIT_REPOSITORY
"https://github.com/progschj/ThreadPool.git"
GIT_TAG 9a42ec1329f259a5f4881a291db1dcb8f2ad9040
PREFIX
${
THREADPOOL_SOURCE_DIR
}
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_COMMAND
""
INSTALL_COMMAND
""
TEST_COMMAND
""
)
if
(
${
CMAKE_VERSION
}
VERSION_LESS
"3.3.0"
)
set
(
dummyfile
${
CMAKE_CURRENT_BINARY_DIR
}
/threadpool_dummy.c
)
file
(
WRITE
${
dummyfile
}
"const char *dummy_threadpool =
\"
${
dummyfile
}
\"
;"
)
add_library
(
simple_threadpool STATIC
${
dummyfile
}
)
else
()
add_library
(
simple_threadpool INTERFACE
)
endif
()
add_dependencies
(
simple_threadpool extern_threadpool
)
LIST
(
APPEND external_project_dependencies simple_threadpool
)
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
6f0dfd89
...
@@ -87,7 +87,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
...
@@ -87,7 +87,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method
)
framework_proto backward glog lod_rank_table feed_fetch_method
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method executor
)
framework_proto backward glog lod_rank_table feed_fetch_method executor
simple_threadpool
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
6f0dfd89
...
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
...
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "ThreadPool.h"
#include "executor.h"
#include "lod_tensor.h"
#include "lod_tensor.h"
#include "op_registry.h"
#include "op_registry.h"
#include "threadpool.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -49,7 +50,7 @@ struct VarHandle : public VarHandleBase {
...
@@ -49,7 +50,7 @@ struct VarHandle : public VarHandleBase {
};
};
struct
DependencyVarHandle
:
public
VarHandleBase
{
struct
DependencyVarHandle
:
public
VarHandleBase
{
std
::
string
DebugString
()
const
override
{
return
"Dep
s var
"
;
}
std
::
string
DebugString
()
const
override
{
return
"Dep
endency Variable
"
;
}
};
};
struct
OpHandle
{
struct
OpHandle
{
...
@@ -75,7 +76,7 @@ struct OpHandle {
...
@@ -75,7 +76,7 @@ struct OpHandle {
virtual
~
OpHandle
()
{}
virtual
~
OpHandle
()
{}
virtual
void
Run
()
{}
virtual
void
Run
()
{
PADDLE_THROW
(
"Not implemented"
);
}
virtual
void
Wait
()
{}
virtual
void
Wait
()
{}
};
};
...
@@ -84,14 +85,15 @@ struct ComputationOpHandle : public OpHandle {
...
@@ -84,14 +85,15 @@ struct ComputationOpHandle : public OpHandle {
Scope
*
scope_
;
Scope
*
scope_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
explicit
ComputationOpHandle
(
const
OpDesc
&
op_desc
,
platform
::
Place
place
)
explicit
ComputationOpHandle
(
const
OpDesc
&
op_desc
,
Scope
*
scope
,
platform
::
Place
place
)
:
op_
(
framework
::
OpRegistry
::
CreateOp
(
op_desc
)),
:
op_
(
framework
::
OpRegistry
::
CreateOp
(
op_desc
)),
scope_
(
nullptr
),
scope_
(
scope
),
place_
(
place
)
{}
place_
(
place
)
{}
void
Run
()
override
{
void
Run
()
override
{
// Wait other op if necessary
// Wait other op if necessary
LOG
(
INFO
)
<<
DebugString
();
LOG
(
INFO
)
<<
"Run "
<<
this
<<
" "
<<
DebugString
();
auto
*
cur_ctx
=
dev_ctx_
[
place_
];
auto
*
cur_ctx
=
dev_ctx_
[
place_
];
for
(
auto
*
in
:
inputs_
)
{
for
(
auto
*
in
:
inputs_
)
{
if
(
in
->
generated_op_
&&
in
->
generated_op_
->
dev_ctx_
[
place_
]
!=
cur_ctx
)
{
if
(
in
->
generated_op_
&&
in
->
generated_op_
->
dev_ctx_
[
place_
]
!=
cur_ctx
)
{
...
@@ -100,12 +102,49 @@ struct ComputationOpHandle : public OpHandle {
...
@@ -100,12 +102,49 @@ struct ComputationOpHandle : public OpHandle {
}
}
op_
->
Run
(
*
scope_
,
place_
);
op_
->
Run
(
*
scope_
,
place_
);
LOG
(
INFO
)
<<
"Done "
<<
this
;
}
}
};
};
struct
ScaleLossGradOpHandle
:
public
OpHandle
{};
struct
ScaleLossGradOpHandle
:
public
OpHandle
{
float
coeff_
;
Scope
*
scope_
;
platform
::
Place
place_
;
explicit
ScaleLossGradOpHandle
(
size_t
num_dev
,
Scope
*
scope
,
platform
::
Place
place
)
:
coeff_
(
static_cast
<
float
>
(
1.0
/
num_dev
)),
scope_
(
scope
),
place_
(
place
)
{}
void
Run
()
override
{
LOG
(
INFO
)
<<
"Run Scale Loss Grad"
;
std
::
string
var_name
=
static_cast
<
VarHandle
*>
(
this
->
outputs_
[
0
])
->
name_
;
struct
NCCLAllReduceOpHandle
:
public
OpHandle
{};
float
*
tmp
=
scope_
->
FindVar
(
var_name
)
->
GetMutable
<
framework
::
LoDTensor
>
()
->
mutable_data
<
float
>
(
make_ddim
({
1
}),
place_
);
if
(
platform
::
is_cpu_place
(
place_
))
{
*
tmp
=
coeff_
;
}
else
{
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
tmp
,
platform
::
CPUPlace
(),
&
coeff_
,
sizeof
(
float
),
static_cast
<
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctx_
[
place_
])
->
stream
());
}
}
};
struct
NCCLAllReduceOpHandle
:
public
OpHandle
{
void
Run
()
override
{
if
(
this
->
inputs_
.
size
()
==
1
)
{
return
;
// No need to all reduce when GPU count = 1;
}
}
};
class
ParallelExecutorPrivate
{
class
ParallelExecutorPrivate
{
public:
public:
...
@@ -182,7 +221,10 @@ class ParallelExecutorPrivate {
...
@@ -182,7 +221,10 @@ class ParallelExecutorPrivate {
std
::
vector
<
std
::
unique_ptr
<
OpHandle
>>
ops_
;
std
::
vector
<
std
::
unique_ptr
<
OpHandle
>>
ops_
;
// Use a simpler thread pool, might be faster.
ThreadPool
pool_
;
ThreadPool
pool_
;
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
};
};
// TODO(yy): Move this function somewhere
// TODO(yy): Move this function somewhere
...
@@ -217,6 +259,19 @@ ParallelExecutor::ParallelExecutor(
...
@@ -217,6 +259,19 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
// ncclOp
ConstructDependencyGraph
(
params
,
main_program
,
loss_var_name
);
ConstructDependencyGraph
(
params
,
main_program
,
loss_var_name
);
// Step 3. Create vars in each scope;
for
(
auto
&
pair
:
member_
->
local_scopes_
)
{
auto
*
scope
=
pair
.
second
;
for
(
auto
*
var
:
main_program
.
Block
(
0
).
AllVars
())
{
if
(
scope
->
FindVar
(
var
->
Name
())
!=
nullptr
)
{
continue
;
}
InitializeVariable
(
scope
->
Var
(
var
->
Name
()),
var
->
GetType
());
}
}
}
}
void
ParallelExecutor
::
ConstructDependencyGraph
(
void
ParallelExecutor
::
ConstructDependencyGraph
(
...
@@ -240,7 +295,8 @@ void ParallelExecutor::ConstructDependencyGraph(
...
@@ -240,7 +295,8 @@ void ParallelExecutor::ConstructDependencyGraph(
}
}
for
(
auto
&
pair
:
member_
->
local_scopes_
)
{
for
(
auto
&
pair
:
member_
->
local_scopes_
)
{
member_
->
ops_
.
emplace_back
(
new
ComputationOpHandle
(
*
op
,
pair
.
first
));
member_
->
ops_
.
emplace_back
(
new
ComputationOpHandle
(
*
op
,
pair
.
second
,
pair
.
first
));
auto
*
op_handle
=
member_
->
ops_
.
back
().
get
();
auto
*
op_handle
=
member_
->
ops_
.
back
().
get
();
op_handle
->
dev_ctx_
[
pair
.
first
]
=
const_cast
<
platform
::
DeviceContext
*>
(
op_handle
->
dev_ctx_
[
pair
.
first
]
=
const_cast
<
platform
::
DeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
pair
.
first
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
pair
.
first
));
...
@@ -263,16 +319,20 @@ void ParallelExecutor::ConstructDependencyGraph(
...
@@ -263,16 +319,20 @@ void ParallelExecutor::ConstructDependencyGraph(
if
(
is_forwarding
)
{
if
(
is_forwarding
)
{
if
(
var_names
.
size
()
==
1
&&
var_names
[
0
]
==
loss_var_name
)
{
if
(
var_names
.
size
()
==
1
&&
var_names
[
0
]
==
loss_var_name
)
{
// Insert ScaleCost OpHandle
// Insert ScaleCost OpHandle
member_
->
ops_
.
emplace_back
(
new
ScaleLossGradOpHandle
());
member_
->
ops_
.
emplace_back
(
new
ScaleLossGradOpHandle
(
this
->
member_
->
local_scopes_
.
size
(),
pair
.
second
,
pair
.
first
));
op_handle
=
member_
->
ops_
.
back
().
get
();
op_handle
=
member_
->
ops_
.
back
().
get
();
op_handle
->
dev_ctx_
[
pair
.
first
]
=
op_handle
->
dev_ctx_
[
pair
.
first
]
=
member_
->
CommunicationDevCtx
(
pair
.
first
);
member_
->
CommunicationDevCtx
(
pair
.
first
);
auto
&
place
=
pair
.
first
;
auto
&
place
=
pair
.
first
;
VarHandle
*
loss
=
GetVarHandle
(
loss_var_name
,
place
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
loss
->
pending_ops_
.
emplace_back
(
op_handle
);
// factor. So it does not depend on any other operators.
op_handle
->
inputs_
.
emplace_back
(
loss
);
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
GenerateVar
(
op_handle
,
loss_var_name
+
"@GRAD"
,
place
);
GenerateVar
(
op_handle
,
loss_var_name
+
"@GRAD"
,
place
);
change_forward
=
true
;
change_forward
=
true
;
LOG
(
INFO
)
<<
"Scale Loss "
<<
op_handle
->
DebugString
();
LOG
(
INFO
)
<<
"Scale Loss "
<<
op_handle
->
DebugString
();
...
@@ -341,11 +401,25 @@ void ParallelExecutor::ConstructDependencyGraph(
...
@@ -341,11 +401,25 @@ void ParallelExecutor::ConstructDependencyGraph(
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
auto
*
write_op
=
it_new
->
second
.
generated_op_
;
auto
*
write_op
=
it_new
->
second
.
generated_op_
;
auto
&
read_ops
=
it_old
->
second
.
pending_ops_
;
auto
&
read_ops
=
it_old
->
second
.
pending_ops_
;
auto
*
ex_write_op
=
it_old
->
second
.
generated_op_
;
if
(
ex_write_op
==
nullptr
)
{
// Nobody write this var.
continue
;
}
LOG
(
INFO
)
<<
"Link "
<<
it_new
->
second
.
DebugString
()
<<
" From "
<<
it_old
->
second
.
version_
<<
" To "
<<
it_new
->
second
.
version_
;
for
(
auto
*
read_op
:
read_ops
)
{
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
auto
*
dep_var
=
new
DependencyVarHandle
();
auto
*
dep_var
=
new
DependencyVarHandle
();
dep_var
->
generated_op_
=
read_op
;
dep_var
->
generated_op_
=
read_op
;
read_op
->
outputs_
.
emplace_back
(
dep_var
);
read_op
->
outputs_
.
emplace_back
(
dep_var
);
...
@@ -448,7 +522,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
...
@@ -448,7 +522,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
std
::
vector
<
LoDTensor
>
ParallelExecutor
::
Run
(
std
::
vector
<
LoDTensor
>
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
// Version --> VarHandle
// Version --> VarHandle
member_
->
exception_
.
reset
();
std
::
unordered_map
<
VarHandleBase
*
,
bool
>
pending_vars
;
std
::
unordered_map
<
VarHandleBase
*
,
bool
>
pending_vars
;
std
::
unordered_map
<
OpHandle
*
,
size_t
>
pending_ops
;
std
::
unordered_map
<
OpHandle
*
,
size_t
>
pending_ops
;
...
@@ -465,8 +539,18 @@ std::vector<LoDTensor> ParallelExecutor::Run(
...
@@ -465,8 +539,18 @@ std::vector<LoDTensor> ParallelExecutor::Run(
pending_vars
[
var
.
get
()]
=
var
->
generated_op_
==
nullptr
;
pending_vars
[
var
.
get
()]
=
var
->
generated_op_
==
nullptr
;
}
}
std
::
vector
<
OpHandle
*>
to_run
;
for
(
auto
&
op
:
member_
->
ops_
)
{
for
(
auto
&
op
:
member_
->
ops_
)
{
pending_ops
.
insert
({
op
.
get
(),
op
->
inputs_
.
size
()});
if
(
op
->
inputs_
.
empty
())
{
// Special case, Op has no input.
to_run
.
emplace_back
(
op
.
get
());
}
else
{
pending_ops
.
insert
({
op
.
get
(),
op
->
inputs_
.
size
()});
}
}
for
(
auto
*
op
:
to_run
)
{
RunOp
(
pending_vars
,
op
);
}
}
while
(
!
pending_ops
.
empty
())
{
while
(
!
pending_ops
.
empty
())
{
...
@@ -478,13 +562,19 @@ std::vector<LoDTensor> ParallelExecutor::Run(
...
@@ -478,13 +562,19 @@ std::vector<LoDTensor> ParallelExecutor::Run(
}
}
if
(
ready_var
==
nullptr
)
{
if
(
ready_var
==
nullptr
)
{
member_
->
pool_
.
Wait
();
// Wait thread pool;
// FIXME use conditional var instead of busy wait.
if
(
member_
->
exception_
)
{
throw
*
member_
->
exception_
;
}
std
::
this_thread
::
yield
();
continue
;
continue
;
}
}
pending_vars
.
erase
(
ready_var
);
pending_vars
.
erase
(
ready_var
);
std
::
vector
<
OpHandle
*>
to_run
;
to_run
.
clear
()
;
for
(
auto
*
op
:
ready_var
->
pending_ops_
)
{
for
(
auto
*
op
:
ready_var
->
pending_ops_
)
{
auto
&
deps
=
pending_ops
[
op
];
auto
&
deps
=
pending_ops
[
op
];
...
@@ -496,24 +586,35 @@ std::vector<LoDTensor> ParallelExecutor::Run(
...
@@ -496,24 +586,35 @@ std::vector<LoDTensor> ParallelExecutor::Run(
for
(
auto
*
op
:
to_run
)
{
for
(
auto
*
op
:
to_run
)
{
pending_ops
.
erase
(
op
);
pending_ops
.
erase
(
op
);
RunOp
(
pending_vars
,
op
);
std
::
vector
<
bool
*>
ready_buffer
;
for
(
auto
*
var
:
op
->
outputs_
)
{
ready_buffer
.
emplace_back
(
&
pending_vars
[
var
]);
}
auto
op_run
=
[
ready_buffer
,
op
]
{
// TODO(yy) Check Previous Op has same dev ctx.
op
->
Run
();
for
(
auto
*
ready
:
ready_buffer
)
{
*
ready
=
true
;
}
};
member_
->
pool_
.
Run
(
op_run
);
}
}
}
}
return
std
::
vector
<
LoDTensor
>
();
return
std
::
vector
<
LoDTensor
>
();
}
}
void
ParallelExecutor
::
RunOp
(
std
::
unordered_map
<
VarHandleBase
*
,
bool
>
&
pending_vars
,
OpHandle
*
op
)
const
{
std
::
vector
<
bool
*>
ready_buffer
;
for
(
auto
*
var
:
op
->
outputs_
)
{
ready_buffer
.
emplace_back
(
&
pending_vars
[
var
]);
}
auto
op_run
=
[
ready_buffer
,
op
,
this
]
{
try
{
// TODO(yy) Check Previous Op has same dev ctx.
op
->
Run
();
for
(
auto
*
ready
:
ready_buffer
)
{
*
ready
=
true
;
}
}
catch
(
platform
::
EnforceNotMet
ex
)
{
member_
->
exception_
.
reset
(
new
platform
::
EnforceNotMet
(
ex
));
}
catch
(...)
{
LOG
(
FATAL
)
<<
"Unknown exception catched"
;
}
};
member_
->
pool_
.
enqueue
(
op_run
);
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/parallel_executor.h
浏览文件 @
6f0dfd89
...
@@ -31,6 +31,7 @@ namespace framework {
...
@@ -31,6 +31,7 @@ namespace framework {
class
ParallelExecutorPrivate
;
class
ParallelExecutorPrivate
;
class
VarHandle
;
class
VarHandle
;
class
OpHandle
;
class
OpHandle
;
class
VarHandleBase
;
class
ParallelExecutor
{
class
ParallelExecutor
{
public:
public:
explicit
ParallelExecutor
(
const
std
::
vector
<
platform
::
Place
>&
places
,
explicit
ParallelExecutor
(
const
std
::
vector
<
platform
::
Place
>&
places
,
...
@@ -57,6 +58,9 @@ class ParallelExecutor {
...
@@ -57,6 +58,9 @@ class ParallelExecutor {
const
std
::
string
&
loss_var_name
)
const
;
const
std
::
string
&
loss_var_name
)
const
;
void
BuildNCCLCommunicator
()
const
;
void
BuildNCCLCommunicator
()
const
;
void
RunOp
(
std
::
unordered_map
<
VarHandleBase
*
,
bool
>&
pending_vars
,
OpHandle
*
op
)
const
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/operators/read_op.cc
浏览文件 @
6f0dfd89
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -59,7 +60,9 @@ class ReadOp : public framework::OperatorBase {
...
@@ -59,7 +60,9 @@ class ReadOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
framework
::
ReaderHolder
*
reader
=
framework
::
ReaderHolder
*
reader
=
scope
.
FindVar
(
Input
(
"Reader"
))
->
GetMutable
<
framework
::
ReaderHolder
>
();
detail
::
Ref
(
scope
.
FindVar
(
Input
(
"Reader"
)),
"Cannot find reader variable %s"
,
Input
(
"Reader"
))
.
GetMutable
<
framework
::
ReaderHolder
>
();
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
framework
::
LoDTensor
>
ins
;
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
reader
->
ReadNext
(
&
ins
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录