Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
1fcde8e9
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1fcde8e9
编写于
8月 01, 2019
作者:
R
rensilin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
hide SimpleExecutor
Change-Id: I08245abbb5c3fdba91ef1bb0a24871d41594c1ce
上级
2f60e4a7
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
79 addition
and
93 deletion
+79
-93
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
...luid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+0
-1
paddle/fluid/train/custom_trainer/feed/common/registerer.h
paddle/fluid/train/custom_trainer/feed/common/registerer.h
+1
-1
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
+67
-69
paddle/fluid/train/custom_trainer/feed/executor/executor.h
paddle/fluid/train/custom_trainer/feed/executor/executor.h
+0
-12
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
...luid/train/custom_trainer/feed/unit_test/test_executor.cc
+11
-10
未找到文件。
paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
浏览文件 @
1fcde8e9
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/train/custom_trainer/feed/common/registerer.h
浏览文件 @
1fcde8e9
...
@@ -106,7 +106,7 @@ BaseClassMap& global_factory_map_cpp();
...
@@ -106,7 +106,7 @@ BaseClassMap& global_factory_map_cpp();
void register_factory_##name() __attribute__((constructor));
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
#define CREATE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name)
;
base_class##Registerer::CreateInstanceByName(name)
}
//namespace feed
}
//namespace feed
}
//namespace custom_trainer
}
//namespace custom_trainer
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.cc
浏览文件 @
1fcde8e9
...
@@ -43,26 +43,12 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load(
...
@@ -43,26 +43,12 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load(
}
}
struct
SimpleExecutor
::
Context
{
Context
(
const
::
paddle
::
platform
::
Place
&
place
)
:
place
(
place
),
executor
(
place
)
{
}
const
::
paddle
::
platform
::
Place
&
place
;
::
paddle
::
framework
::
Executor
executor
;
::
std
::
unique_ptr
<::
paddle
::
framework
::
ProgramDesc
>
main_program
;
::
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepare_context
;
details
::
TensorArrayBatchCleaner
tensor_array_batch_cleaner
;
};
SimpleExecutor
::
SimpleExecutor
()
{
}
SimpleExecutor
::~
SimpleExecutor
()
{
}
int
SimpleExecutor
::
initialize
(
YAML
::
Node
exe_config
,
class
SimpleExecutor
:
public
Executor
{
public:
SimpleExecutor
()
{};
virtual
~
SimpleExecutor
()
{};
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
paddle
::
framework
::
InitDevices
(
false
);
paddle
::
framework
::
InitDevices
(
false
);
...
@@ -102,9 +88,8 @@ int SimpleExecutor::initialize(YAML::Node exe_config,
...
@@ -102,9 +88,8 @@ int SimpleExecutor::initialize(YAML::Node exe_config,
}
}
return
0
;
return
0
;
}
}
virtual
int
run
()
{
int
SimpleExecutor
::
run
()
{
if
(
_context
==
nullptr
)
{
if
(
_context
==
nullptr
)
{
VLOG
(
2
)
<<
"need initialize before run"
;
VLOG
(
2
)
<<
"need initialize before run"
;
return
-
1
;
return
-
1
;
...
@@ -122,7 +107,20 @@ int SimpleExecutor::run() {
...
@@ -122,7 +107,20 @@ int SimpleExecutor::run() {
return
-
1
;
return
-
1
;
}
}
return
0
;
return
0
;
}
}
protected:
struct
Context
{
Context
(
const
::
paddle
::
platform
::
Place
&
place
)
:
place
(
place
),
executor
(
place
)
{
}
const
::
paddle
::
platform
::
Place
&
place
;
::
paddle
::
framework
::
Executor
executor
;
::
std
::
unique_ptr
<::
paddle
::
framework
::
ProgramDesc
>
main_program
;
::
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prepare_context
;
details
::
TensorArrayBatchCleaner
tensor_array_batch_cleaner
;
};
std
::
unique_ptr
<
Context
>
_context
;
};
REGISTER_CLASS
(
Executor
,
SimpleExecutor
);
REGISTER_CLASS
(
Executor
,
SimpleExecutor
);
}
// namespace feed
}
// namespace feed
...
...
paddle/fluid/train/custom_trainer/feed/executor/executor.h
浏览文件 @
1fcde8e9
...
@@ -42,18 +42,6 @@ protected:
...
@@ -42,18 +42,6 @@ protected:
};
};
REGISTER_REGISTERER
(
Executor
);
REGISTER_REGISTERER
(
Executor
);
class
SimpleExecutor
:
public
Executor
{
public:
SimpleExecutor
();
virtual
~
SimpleExecutor
();
virtual
int
initialize
(
YAML
::
Node
exe_config
,
std
::
shared_ptr
<
TrainerContext
>
context_ptr
);
virtual
int
run
();
protected:
struct
Context
;
std
::
unique_ptr
<
Context
>
_context
;
};
}
// namespace feed
}
// namespace feed
}
// namespace custom_trainer
}
// namespace custom_trainer
}
// namespace paddle
}
// namespace paddle
paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc
浏览文件 @
1fcde8e9
...
@@ -81,22 +81,23 @@ public:
...
@@ -81,22 +81,23 @@ public:
};
};
TEST_F
(
SimpleExecutorTest
,
initialize
)
{
TEST_F
(
SimpleExecutorTest
,
initialize
)
{
SimpleExecutor
executor
;
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_CLASS
(
Executor
,
"SimpleExecutor"
))
;
YAML
::
Node
config
=
YAML
::
Load
(
"[1, 2, 3]"
);
YAML
::
Node
config
=
YAML
::
Load
(
"[1, 2, 3]"
);
ASSERT_NE
(
0
,
executor
.
initialize
(
config
,
context_ptr
));
ASSERT_NE
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
config
=
YAML
::
Load
(
std
::
string
()
+
"{startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
config
=
YAML
::
Load
(
std
::
string
()
+
"{startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
ASSERT_EQ
(
0
,
executor
.
initialize
(
config
,
context_ptr
));
ASSERT_EQ
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
config
=
YAML
::
Load
(
std
::
string
()
+
"{thread_num: 2, startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
config
=
YAML
::
Load
(
std
::
string
()
+
"{thread_num: 2, startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
ASSERT_EQ
(
0
,
executor
.
initialize
(
config
,
context_ptr
));
ASSERT_EQ
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
}
}
TEST_F
(
SimpleExecutorTest
,
run
)
{
TEST_F
(
SimpleExecutorTest
,
run
)
{
SimpleExecutor
executor
;
std
::
unique_ptr
<
Executor
>
executor
(
CREATE_CLASS
(
Executor
,
"SimpleExecutor"
));
auto
config
=
YAML
::
Load
(
std
::
string
()
+
"{thread_num: 2, startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
auto
config
=
YAML
::
Load
(
std
::
string
()
+
"{thread_num: 2, startup_program: "
+
startup_program_path
+
", main_program: "
+
main_program_path
+
"}"
);
ASSERT_EQ
(
0
,
executor
.
initialize
(
config
,
context_ptr
));
ASSERT_EQ
(
0
,
executor
->
initialize
(
config
,
context_ptr
));
auto
x_var
=
executor
.
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"x"
);
auto
x_var
=
executor
->
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"x"
);
executor
.
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
executor
->
mutable_var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
ASSERT_NE
(
nullptr
,
x_var
);
ASSERT_NE
(
nullptr
,
x_var
);
int
x_len
=
10
;
int
x_len
=
10
;
...
@@ -109,9 +110,9 @@ TEST_F(SimpleExecutorTest, run) {
...
@@ -109,9 +110,9 @@ TEST_F(SimpleExecutorTest, run) {
}
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
ASSERT_EQ
(
0
,
executor
.
run
());
ASSERT_EQ
(
0
,
executor
->
run
());
auto
mean_var
=
executor
.
var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
auto
mean_var
=
executor
->
var
<::
paddle
::
framework
::
LoDTensor
>
(
"mean"
);
auto
mean
=
mean_var
.
data
<
float
>
()[
0
];
auto
mean
=
mean_var
.
data
<
float
>
()[
0
];
std
::
cout
<<
"mean: "
<<
mean
<<
std
::
endl
;
std
::
cout
<<
"mean: "
<<
mean
<<
std
::
endl
;
ASSERT_NEAR
(
4.5
,
mean
,
1e-9
);
ASSERT_NEAR
(
4.5
,
mean
,
1e-9
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录