Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
1fcde8e9
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1fcde8e9
编写于
8月 01, 2019
作者:
R
rensilin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
hide SimpleExecutor
Change-Id: I08245abbb5c3fdba91ef1bb0a24871d41594c1ce
上级
2f60e4a7
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
79 addition
and
93 deletion
+79
-93
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+0
-1
paddle/fluid/train/custom_trainer/feed/common/registerer.h
paddle/fluid/train/custom_trainer/feed/common/registerer.h
+1
-1
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
+67
-69
paddle/fluid/train/custom_trainer/feed/executor/executor.h
paddle/fluid/train/custom_trainer/feed/executor/executor.h
+0
-12
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
...luid/train/custom_trainer/feed/unit_test/test_executor.cc
+11
-10
未找到文件。
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
浏览文件 @
1fcde8e9
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
namespace
paddle
{
...
...
paddle/fluid/train/custom_trainer/feed/common/registerer.h
浏览文件 @
1fcde8e9
...
...
@@ -106,7 +106,7 @@ BaseClassMap& global_factory_map_cpp();
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name)
;
base_class##Registerer::CreateInstanceByName(name)
}
//namespace feed
}
//namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
浏览文件 @
1fcde8e9
...
...
@@ -43,86 +43,84 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load(
}
struct
SimpleExecutor
::
Context
{
Context
(
const
::
paddle
::
platform
::
Place
&
place
)
:
place
(
place
),
executor
(
place
)
{
}
const
::
paddle
::
platform
::
Place
&
place
;
::
paddle
::
framework
::
Executor
executor
;
::
std
::
unique_ptr
<::
paddle
::
framework
::
ProgramDesc
>
main_program
;
::
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepare_context
;
details
::
TensorArrayBatchCleaner
tensor_array_batch_cleaner
;
};
SimpleExecutor
::
SimpleExecutor
()
{
}
SimpleExecutor
::~
SimpleExecutor
()
{
}
int
SimpleExecutor
::
initialize
(
YAML
::
Node
exe_config
,
class
SimpleExecutor
:
public
Executor
{
public:
SimpleExecutor
()
{};
virtual
~
SimpleExecutor
()
{};
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
paddle
::
framework
::
InitDevices
(
false
);
if
(
exe_config
[
"num_threads"
])
{
paddle
::
platform
::
SetNumThreads
(
exe_config
[
"num_threads"
].
as
<
int
>
());
}
else
{
paddle
::
platform
::
SetNumThreads
(
1
);
}
if
(
!
exe_config
[
"startup_program"
]
||
!
exe_config
[
"main_program"
])
{
VLOG
(
2
)
<<
"fail to load config"
;
return
-
1
;
}
paddle
::
framework
::
InitDevices
(
false
);
if
(
exe_config
[
"num_threads"
])
{
paddle
::
platform
::
SetNumThreads
(
exe_config
[
"num_threads"
].
as
<
int
>
());
}
else
{
paddle
::
platform
::
SetNumThreads
(
1
);
}
try
{
_context
.
reset
(
new
SimpleExecutor
::
Context
(
context_ptr
->
cpu_place
));
auto
startup_program
=
Load
(
&
_context
->
executor
,
exe_config
[
"startup_program"
].
as
<
std
::
string
>
());
if
(
startup_program
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to load startup_program: "
<<
exe_config
[
"startup_program"
].
as
<
std
::
string
>
();
if
(
!
exe_config
[
"startup_program"
]
||
!
exe_config
[
"main_program"
])
{
VLOG
(
2
)
<<
"fail to load config"
;
return
-
1
;
}
_context
->
executor
.
Run
(
*
startup_program
,
this
->
scope
(),
0
,
false
,
true
);
_context
->
main_program
=
Load
(
&
_context
->
executor
,
exe_config
[
"main_program"
].
as
<
std
::
string
>
());
if
(
_context
->
main_program
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to load main_program: "
<<
exe_config
[
"main_program"
].
as
<
std
::
string
>
();
try
{
_context
.
reset
(
new
SimpleExecutor
::
Context
(
context_ptr
->
cpu_place
));
auto
startup_program
=
Load
(
&
_context
->
executor
,
exe_config
[
"startup_program"
].
as
<
std
::
string
>
());
if
(
startup_program
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to load startup_program: "
<<
exe_config
[
"startup_program"
].
as
<
std
::
string
>
();
return
-
1
;
}
_context
->
executor
.
Run
(
*
startup_program
,
this
->
scope
(),
0
,
false
,
true
);
_context
->
main_program
=
Load
(
&
_context
->
executor
,
exe_config
[
"main_program"
].
as
<
std
::
string
>
());
if
(
_context
->
main_program
==
nullptr
)
{
VLOG
(
2
)
<<
"fail to load main_program: "
<<
exe_config
[
"main_program"
].
as
<
std
::
string
>
();
return
-
1
;
}
_context
->
prepare_context
=
_context
->
executor
.
Prepare
(
*
_context
->
main_program
,
0
);
_context
->
executor
.
CreateVariables
(
*
_context
->
main_program
,
this
->
scope
(),
0
);
}
catch
(
::
paddle
::
platform
::
EnforceNotMet
&
err
)
{
VLOG
(
2
)
<<
err
.
what
();
_context
.
reset
(
nullptr
);
return
-
1
;
}
_context
->
prepare_context
=
_context
->
executor
.
Prepare
(
*
_context
->
main_program
,
0
);
_context
->
executor
.
CreateVariables
(
*
_context
->
main_program
,
this
->
scope
(),
0
);
}
catch
(
::
paddle
::
platform
::
EnforceNotMet
&
err
)
{
VLOG
(
2
)
<<
err
.
what
();
_context
.
reset
(
nullptr
);
return
-
1
;
}
return
0
;
}
int
SimpleExecutor
::
run
()
{
if
(
_context
==
nullptr
)
{
VLOG
(
2
)
<<
"need initialize before run"
;
return
-
1
;
return
0
;
}
try
{
_context
->
executor
.
RunPreparedContext
(
_context
->
prepare_context
.
get
(),
this
->
scope
(),
false
,
/* don't create local scope each time*/
false
/* don't create variable each time */
);
// For some other vector like containers not cleaned after each batch.
_context
->
tensor_array_batch_cleaner
.
CollectNoTensorVars
(
this
->
scope
());
_context
->
tensor_array_batch_cleaner
.
ResetNoTensorVars
();
}
catch
(
::
paddle
::
platform
::
EnforceNotMet
&
err
)
{
VLOG
(
2
)
<<
err
.
what
();
return
-
1
;
virtual
int
run
()
{
if
(
_context
==
nullptr
)
{
VLOG
(
2
)
<<
"need initialize before run"
;
return
-
1
;
}
try
{
_context
->
executor
.
RunPreparedContext
(
_context
->
prepare_context
.
get
(),
this
->
scope
(),
false
,
/* don't create local scope each time*/
false
/* don't create variable each time */
);
// For some other vector like containers not cleaned after each batch.
_context
->
tensor_array_batch_cleaner
.
CollectNoTensorVars
(
this
->
scope
());
_context
->
tensor_array_batch_cleaner
.
ResetNoTensorVars
();
}
catch
(
::
paddle
::
platform
::
EnforceNotMet
&
err
)
{
VLOG
(
2
)
<<
err
.
what
();
return
-
1
;
}
return
0
;
}
return
0
;
}
protected:
struct
Context
{
Context
(
const
::
paddle
::
platform
::
Place
&
place
)
:
place
(
place
),
executor
(
place
)
{
}
const
::
paddle
::
platform
::
Place
&
place
;
::
paddle
::
framework
::
Executor
executor
;
::
std
::
unique_ptr
<::
paddle
::
framework
::
ProgramDesc
>
main_program
;
::
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepare_context
;
details
::
TensorArrayBatchCleaner
tensor_array_batch_cleaner
;
};
std
::
unique_ptr
<
Context
>
_context
;
};
REGISTER_CLASS
(
Executor
,
SimpleExecutor
);
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.h
浏览文件 @
1fcde8e9
...
...
@@ -42,18 +42,6 @@ protected:
};
REGISTER_REGISTERER
(
Executor
);
class
SimpleExecutor
:
public
Executor
{
public:
SimpleExecutor
();
virtual
~
SimpleExecutor
();
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
int
run
();
protected:
struct
Context
;
std
::
unique_ptr
<
Context
>
_context
;
};
}
// namespace feed
}
// namespace custom_trainer
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
浏览文件 @
1fcde8e9
...
...
@@ -81,22 +81,23 @@ public:
};
TEST_F
(
SimpleExecutorTest
,
initialize
)
{
SimpleExecutor
executor
;
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_CLASS
(
Executor
,
"SimpleExecutor"
))
;
YAML
::
Node
config
=
YAML
::
Load
(
"[1, 2, 3]"
);
ASSERT_NE
(
0
,
executor
.
initialize
(
config
,
context_ptr
));
ASSERT_NE
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
config
=
YAML
::
Load
(
std
::
string
()
+
"{startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
ASSERT_EQ
(
0
,
executor
.
initialize
(
config
,
context_ptr
));
ASSERT_EQ
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
config
=
YAML
::
Load
(
std
::
string
()
+
"{thread_num: 2, startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
ASSERT_EQ
(
0
,
executor
.
initialize
(
config
,
context_ptr
));
ASSERT_EQ
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
}
TEST_F
(
SimpleExecutorTest
,
run
)
{
SimpleExecutor
executor
;
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_CLASS
(
Executor
,
"SimpleExecutor"
));
auto
config
=
YAML
::
Load
(
std
::
string
()
+
"{thread_num: 2, startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
ASSERT_EQ
(
0
,
executor
.
initialize
(
config
,
context_ptr
));
ASSERT_EQ
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
auto
x_var
=
executor
.
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"x"
);
executor
.
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
auto
x_var
=
executor
->
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"x"
);
executor
->
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
ASSERT_NE
(
nullptr
,
x_var
);
int
x_len
=
10
;
...
...
@@ -109,9 +110,9 @@ TEST_F(SimpleExecutorTest, run) {
}
std
::
cout
<<
std
::
endl
;
ASSERT_EQ
(
0
,
executor
.
run
());
ASSERT_EQ
(
0
,
executor
->
run
());
auto
mean_var
=
executor
.
var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
auto
mean_var
=
executor
->
var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
auto
mean
=
mean_var
.
data
<
float
>
()[
0
];
std
::
cout
<<
"mean: "
<<
mean
<<
std
::
endl
;
ASSERT_NEAR
(
4.5
,
mean
,
1e-9
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录