提交 8b7e1ed1 编写于 作者: X xiexionghang

for runnable trainer

上级 b1a8a46e
train_thread_num : 10
environment :
environment_class : MPIRuntimeEnvironment
epoch:
epoch_class : HourlyEpochAccessor
...@@ -43,7 +43,7 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load( ...@@ -43,7 +43,7 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load(
} }
struct SimpleExecute::Context { struct SimpleExecutor::Context {
Context(const ::paddle::platform::Place& place) : place(place), executor(place) { Context(const ::paddle::platform::Place& place) : place(place), executor(place) {
} }
const ::paddle::platform::Place& place; const ::paddle::platform::Place& place;
...@@ -54,15 +54,15 @@ struct SimpleExecute::Context { ...@@ -54,15 +54,15 @@ struct SimpleExecute::Context {
}; };
SimpleExecute::SimpleExecute() { SimpleExecutor::SimpleExecutor() {
} }
SimpleExecute::~SimpleExecute() { SimpleExecutor::~SimpleExecutor() {
} }
int SimpleExecute::initialize(YAML::Node exe_config, int SimpleExecutor::initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr) { std::shared_ptr<TrainerContext> context_ptr) {
paddle::framework::InitDevices(false); paddle::framework::InitDevices(false);
...@@ -79,7 +79,7 @@ int SimpleExecute::initialize(YAML::Node exe_config, ...@@ -79,7 +79,7 @@ int SimpleExecute::initialize(YAML::Node exe_config,
} }
try { try {
_context.reset(new SimpleExecute::Context(context_ptr->cpu_place)); _context.reset(new SimpleExecutor::Context(context_ptr->cpu_place));
auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>()); auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>());
if (startup_program == nullptr) { if (startup_program == nullptr) {
VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>(); VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>();
...@@ -104,7 +104,7 @@ int SimpleExecute::initialize(YAML::Node exe_config, ...@@ -104,7 +104,7 @@ int SimpleExecute::initialize(YAML::Node exe_config,
return 0; return 0;
} }
int SimpleExecute::run() { int SimpleExecutor::run() {
if (_context == nullptr) { if (_context == nullptr) {
VLOG(2) << "need initialize before run"; VLOG(2) << "need initialize before run";
return -1; return -1;
......
...@@ -44,8 +44,8 @@ REGISTER_REGISTERER(Executor); ...@@ -44,8 +44,8 @@ REGISTER_REGISTERER(Executor);
class SimpleExecutor : public Executor { class SimpleExecutor : public Executor {
public: public:
SimpleExecute(); SimpleExecutor();
virtual ~SimpleExecute(); virtual ~SimpleExecutor();
virtual int initialize(YAML::Node exe_config, virtual int initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr); std::shared_ptr<TrainerContext> context_ptr);
virtual int run(); virtual int run();
......
...@@ -22,8 +22,8 @@ namespace paddle { ...@@ -22,8 +22,8 @@ namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
TEST(testSimpleExecute, initialize) { TEST(testSimpleExecutor, initialize) {
SimpleExecute execute; SimpleExecutor execute;
auto context_ptr = std::make_shared<TrainerContext>(); auto context_ptr = std::make_shared<TrainerContext>();
YAML::Node config = YAML::Load("[1, 2, 3]"); YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, execute.initialize(config, context_ptr)); ASSERT_NE(0, execute.initialize(config, context_ptr));
...@@ -54,8 +54,8 @@ void next_batch(int batch_size, const paddle::platform::Place& place, paddle::fr ...@@ -54,8 +54,8 @@ void next_batch(int batch_size, const paddle::platform::Place& place, paddle::fr
} }
} }
TEST(testSimpleExecute, run) { TEST(testSimpleExecutor, run) {
SimpleExecute execute; SimpleExecutor execute;
auto context_ptr = std::make_shared<TrainerContext>(); auto context_ptr = std::make_shared<TrainerContext>();
auto config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}"); auto config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}");
ASSERT_EQ(0, execute.initialize(config, context_ptr)); ASSERT_EQ(0, execute.initialize(config, context_ptr));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册