From 8b7e1ed1e669515d513b6f7cc1dff64c60cbfc70 Mon Sep 17 00:00:00 2001 From: xiexionghang Date: Thu, 1 Aug 2019 16:22:52 +0800 Subject: [PATCH] for runnable trainer --- .../fluid/train/custom_trainer/feed/conf/gflags.conf | 1 + .../train/custom_trainer/feed/conf/trainer.yaml | 6 ++++++ .../train/custom_trainer/feed/executor/executor.cc | 12 ++++++------ .../train/custom_trainer/feed/executor/executor.h | 4 ++-- .../custom_trainer/feed/unit_test/test_executor.cc | 8 ++++---- 5 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 paddle/fluid/train/custom_trainer/feed/conf/gflags.conf create mode 100644 paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml diff --git a/paddle/fluid/train/custom_trainer/feed/conf/gflags.conf b/paddle/fluid/train/custom_trainer/feed/conf/gflags.conf new file mode 100644 index 00000000..f7ac9283 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/conf/gflags.conf @@ -0,0 +1 @@ +-v=10 diff --git a/paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml b/paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml new file mode 100644 index 00000000..85236102 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/conf/trainer.yaml @@ -0,0 +1,6 @@ +train_thread_num : 10 + +environment : + environment_class : MPIRuntimeEnvironment +epoch: + epoch_class : HourlyEpochAccessor diff --git a/paddle/fluid/train/custom_trainer/feed/executor/executor.cc b/paddle/fluid/train/custom_trainer/feed/executor/executor.cc index 569b6d82..782ec620 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/executor/executor.cc @@ -43,7 +43,7 @@ std::unique_ptr Load( } -struct SimpleExecute::Context { +struct SimpleExecutor::Context { Context(const ::paddle::platform::Place& place) : place(place), executor(place) { } const ::paddle::platform::Place& place; @@ -54,15 +54,15 @@ struct SimpleExecute::Context { }; -SimpleExecute::SimpleExecute() { +SimpleExecutor::SimpleExecutor() { } -SimpleExecute::~SimpleExecute() { +SimpleExecutor::~SimpleExecutor() { } -int SimpleExecute::initialize(YAML::Node exe_config, +int SimpleExecutor::initialize(YAML::Node exe_config, std::shared_ptr context_ptr) { paddle::framework::InitDevices(false); @@ -79,7 +79,7 @@ int SimpleExecute::initialize(YAML::Node exe_config, } try { - _context.reset(new SimpleExecute::Context(context_ptr->cpu_place)); + _context.reset(new SimpleExecutor::Context(context_ptr->cpu_place)); auto startup_program = Load(&_context->executor, exe_config["startup_program"].as()); if (startup_program == nullptr) { VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as(); @@ -104,7 +104,7 @@ int SimpleExecute::initialize(YAML::Node exe_config, return 0; } -int SimpleExecute::run() { +int SimpleExecutor::run() { if (_context == nullptr) { VLOG(2) << "need initialize before run"; return -1; diff --git a/paddle/fluid/train/custom_trainer/feed/executor/executor.h b/paddle/fluid/train/custom_trainer/feed/executor/executor.h index 64eb7a76..ec20401f 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/executor.h +++ b/paddle/fluid/train/custom_trainer/feed/executor/executor.h @@ -44,8 +44,8 @@ REGISTER_REGISTERER(Executor); class SimpleExecutor : public Executor { public: - SimpleExecute(); - virtual ~SimpleExecute(); + SimpleExecutor(); + virtual ~SimpleExecutor(); virtual int initialize(YAML::Node exe_config, std::shared_ptr context_ptr); virtual int run(); diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc index ffecb3e8..479bbada 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc @@ -22,8 +22,8 @@ namespace paddle { namespace custom_trainer { namespace feed { -TEST(testSimpleExecute, initialize) { - SimpleExecute execute; +TEST(testSimpleExecutor, initialize) { + SimpleExecutor execute; auto context_ptr = std::make_shared(); YAML::Node config = YAML::Load("[1, 2, 3]"); ASSERT_NE(0, execute.initialize(config, context_ptr)); @@ -54,8 +54,8 @@ void next_batch(int batch_size, const paddle::platform::Place& place, paddle::fr } } -TEST(testSimpleExecute, run) { - SimpleExecute execute; +TEST(testSimpleExecutor, run) { + SimpleExecutor execute; auto context_ptr = std::make_shared(); auto config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}"); ASSERT_EQ(0, execute.initialize(config, context_ptr)); -- GitLab