提交 1fcde8e9 编写于 作者: R rensilin

hide SimpleExecutor

Change-Id: I08245abbb5c3fdba91ef1bb0a24871d41594c1ce
上级 2f60e4a7
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h" #include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
namespace paddle { namespace paddle {
......
...@@ -106,7 +106,7 @@ BaseClassMap& global_factory_map_cpp(); ...@@ -106,7 +106,7 @@ BaseClassMap& global_factory_map_cpp();
void register_factory_##name() __attribute__((constructor)); void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \ #define CREATE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name); base_class##Registerer::CreateInstanceByName(name)
}//namespace feed }//namespace feed
}//namespace custom_trainer }//namespace custom_trainer
......
...@@ -43,86 +43,84 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load( ...@@ -43,86 +43,84 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load(
} }
struct SimpleExecutor::Context {
Context(const ::paddle::platform::Place& place) : place(place), executor(place) {
}
const ::paddle::platform::Place& place;
::paddle::framework::Executor executor;
::std::unique_ptr<::paddle::framework::ProgramDesc> main_program;
::std::unique_ptr<framework::ExecutorPrepareContext> prepare_context;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner;
};
SimpleExecutor::SimpleExecutor() {
}
SimpleExecutor::~SimpleExecutor() {
}
int SimpleExecutor::initialize(YAML::Node exe_config, class SimpleExecutor : public Executor {
public:
SimpleExecutor() {};
virtual ~SimpleExecutor() {};
virtual int initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr) { std::shared_ptr<TrainerContext> context_ptr) {
paddle::framework::InitDevices(false); paddle::framework::InitDevices(false);
if (exe_config["num_threads"]) { if (exe_config["num_threads"]) {
paddle::platform::SetNumThreads(exe_config["num_threads"].as<int>()); paddle::platform::SetNumThreads(exe_config["num_threads"].as<int>());
} else { } else {
paddle::platform::SetNumThreads(1); paddle::platform::SetNumThreads(1);
} }
if (!exe_config["startup_program"] ||
!exe_config["main_program"]) {
VLOG(2) << "fail to load config";
return -1;
}
try { if (!exe_config["startup_program"] ||
_context.reset(new SimpleExecutor::Context(context_ptr->cpu_place)); !exe_config["main_program"]) {
auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>()); VLOG(2) << "fail to load config";
if (startup_program == nullptr) {
VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>();
return -1; return -1;
} }
_context->executor.Run(*startup_program, this->scope(), 0, false, true);
_context->main_program = Load(&_context->executor, exe_config["main_program"].as<std::string>()); try {
if (_context->main_program == nullptr) { _context.reset(new SimpleExecutor::Context(context_ptr->cpu_place));
VLOG(2) << "fail to load main_program: " << exe_config["main_program"].as<std::string>(); auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>());
if (startup_program == nullptr) {
VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>();
return -1;
}
_context->executor.Run(*startup_program, this->scope(), 0, false, true);
_context->main_program = Load(&_context->executor, exe_config["main_program"].as<std::string>());
if (_context->main_program == nullptr) {
VLOG(2) << "fail to load main_program: " << exe_config["main_program"].as<std::string>();
return -1;
}
_context->prepare_context = _context->executor.Prepare(*_context->main_program, 0);
_context->executor.CreateVariables(*_context->main_program, this->scope(), 0);
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
_context.reset(nullptr);
return -1; return -1;
} }
_context->prepare_context = _context->executor.Prepare(*_context->main_program, 0);
_context->executor.CreateVariables(*_context->main_program, this->scope(), 0);
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
_context.reset(nullptr);
return -1;
}
return 0;
}
int SimpleExecutor::run() { return 0;
if (_context == nullptr) {
VLOG(2) << "need initialize before run";
return -1;
} }
try { virtual int run() {
_context->executor.RunPreparedContext(_context->prepare_context.get(), this->scope(), if (_context == nullptr) {
false, /* don't create local scope each time*/ VLOG(2) << "need initialize before run";
false /* don't create variable each time */); return -1;
}
// For some other vector like containers not cleaned after each batch. try {
_context->tensor_array_batch_cleaner.CollectNoTensorVars(this->scope()); _context->executor.RunPreparedContext(_context->prepare_context.get(), this->scope(),
_context->tensor_array_batch_cleaner.ResetNoTensorVars(); false, /* don't create local scope each time*/
} catch (::paddle::platform::EnforceNotMet& err) { false /* don't create variable each time */);
VLOG(2) << err.what();
return -1; // For some other vector like containers not cleaned after each batch.
_context->tensor_array_batch_cleaner.CollectNoTensorVars(this->scope());
_context->tensor_array_batch_cleaner.ResetNoTensorVars();
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
return -1;
}
return 0;
} }
return 0; protected:
} struct Context {
Context(const ::paddle::platform::Place& place) : place(place), executor(place) {
}
const ::paddle::platform::Place& place;
::paddle::framework::Executor executor;
::std::unique_ptr<::paddle::framework::ProgramDesc> main_program;
::std::unique_ptr<framework::ExecutorPrepareContext> prepare_context;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner;
};
std::unique_ptr<Context> _context;
};
REGISTER_CLASS(Executor, SimpleExecutor); REGISTER_CLASS(Executor, SimpleExecutor);
} // namespace feed } // namespace feed
......
...@@ -42,18 +42,6 @@ protected: ...@@ -42,18 +42,6 @@ protected:
}; };
REGISTER_REGISTERER(Executor); REGISTER_REGISTERER(Executor);
class SimpleExecutor : public Executor {
public:
SimpleExecutor();
virtual ~SimpleExecutor();
virtual int initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr);
virtual int run();
protected:
struct Context;
std::unique_ptr<Context> _context;
};
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
} // namespace paddle } // namespace paddle
...@@ -81,22 +81,23 @@ public: ...@@ -81,22 +81,23 @@ public:
}; };
TEST_F(SimpleExecutorTest, initialize) { TEST_F(SimpleExecutorTest, initialize) {
SimpleExecutor executor; std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
YAML::Node config = YAML::Load("[1, 2, 3]"); YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, executor.initialize(config, context_ptr)); ASSERT_NE(0, executor->initialize(config, context_ptr));
config = YAML::Load(std::string() + "{startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); config = YAML::Load(std::string() + "{startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, executor.initialize(config, context_ptr)); ASSERT_EQ(0, executor->initialize(config, context_ptr));
config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, executor.initialize(config, context_ptr)); ASSERT_EQ(0, executor->initialize(config, context_ptr));
} }
TEST_F(SimpleExecutorTest, run) { TEST_F(SimpleExecutorTest, run) {
SimpleExecutor executor; std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
auto config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}"); auto config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, executor.initialize(config, context_ptr)); ASSERT_EQ(0, executor->initialize(config, context_ptr));
auto x_var = executor.mutable_var<::paddle::framework::LoDTensor>("x"); auto x_var = executor->mutable_var<::paddle::framework::LoDTensor>("x");
executor.mutable_var<::paddle::framework::LoDTensor>("mean"); executor->mutable_var<::paddle::framework::LoDTensor>("mean");
ASSERT_NE(nullptr, x_var); ASSERT_NE(nullptr, x_var);
int x_len = 10; int x_len = 10;
...@@ -109,9 +110,9 @@ TEST_F(SimpleExecutorTest, run) { ...@@ -109,9 +110,9 @@ TEST_F(SimpleExecutorTest, run) {
} }
std::cout << std::endl; std::cout << std::endl;
ASSERT_EQ(0, executor.run()); ASSERT_EQ(0, executor->run());
auto mean_var = executor.var<::paddle::framework::LoDTensor>("mean"); auto mean_var = executor->var<::paddle::framework::LoDTensor>("mean");
auto mean = mean_var.data<float>()[0]; auto mean = mean_var.data<float>()[0];
std::cout << "mean: " << mean << std::endl; std::cout << "mean: " << mean << std::endl;
ASSERT_NEAR(4.5, mean, 1e-9); ASSERT_NEAR(4.5, mean, 1e-9);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册