提交 1fcde8e9 编写于 作者: R rensilin

hide SimpleExecutor

Change-Id: I08245abbb5c3fdba91ef1bb0a24871d41594c1ce
上级 2f60e4a7
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
namespace paddle {
......
......@@ -106,7 +106,7 @@ BaseClassMap& global_factory_map_cpp();
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name);
base_class##Registerer::CreateInstanceByName(name)
}//namespace feed
}//namespace custom_trainer
......
......@@ -43,86 +43,84 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load(
}
struct SimpleExecutor::Context {
Context(const ::paddle::platform::Place& place) : place(place), executor(place) {
}
const ::paddle::platform::Place& place;
::paddle::framework::Executor executor;
::std::unique_ptr<::paddle::framework::ProgramDesc> main_program;
::std::unique_ptr<framework::ExecutorPrepareContext> prepare_context;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner;
};
SimpleExecutor::SimpleExecutor() {
}
SimpleExecutor::~SimpleExecutor() {
}
int SimpleExecutor::initialize(YAML::Node exe_config,
class SimpleExecutor : public Executor {
public:
SimpleExecutor() {};
virtual ~SimpleExecutor() {};
virtual int initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr) {
paddle::framework::InitDevices(false);
if (exe_config["num_threads"]) {
paddle::platform::SetNumThreads(exe_config["num_threads"].as<int>());
} else {
paddle::platform::SetNumThreads(1);
}
if (!exe_config["startup_program"] ||
!exe_config["main_program"]) {
VLOG(2) << "fail to load config";
return -1;
}
paddle::framework::InitDevices(false);
if (exe_config["num_threads"]) {
paddle::platform::SetNumThreads(exe_config["num_threads"].as<int>());
} else {
paddle::platform::SetNumThreads(1);
}
try {
_context.reset(new SimpleExecutor::Context(context_ptr->cpu_place));
auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>());
if (startup_program == nullptr) {
VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>();
if (!exe_config["startup_program"] ||
!exe_config["main_program"]) {
VLOG(2) << "fail to load config";
return -1;
}
_context->executor.Run(*startup_program, this->scope(), 0, false, true);
_context->main_program = Load(&_context->executor, exe_config["main_program"].as<std::string>());
if (_context->main_program == nullptr) {
VLOG(2) << "fail to load main_program: " << exe_config["main_program"].as<std::string>();
try {
_context.reset(new SimpleExecutor::Context(context_ptr->cpu_place));
auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>());
if (startup_program == nullptr) {
VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>();
return -1;
}
_context->executor.Run(*startup_program, this->scope(), 0, false, true);
_context->main_program = Load(&_context->executor, exe_config["main_program"].as<std::string>());
if (_context->main_program == nullptr) {
VLOG(2) << "fail to load main_program: " << exe_config["main_program"].as<std::string>();
return -1;
}
_context->prepare_context = _context->executor.Prepare(*_context->main_program, 0);
_context->executor.CreateVariables(*_context->main_program, this->scope(), 0);
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
_context.reset(nullptr);
return -1;
}
_context->prepare_context = _context->executor.Prepare(*_context->main_program, 0);
_context->executor.CreateVariables(*_context->main_program, this->scope(), 0);
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
_context.reset(nullptr);
return -1;
}
return 0;
}
int SimpleExecutor::run() {
if (_context == nullptr) {
VLOG(2) << "need initialize before run";
return -1;
return 0;
}
try {
_context->executor.RunPreparedContext(_context->prepare_context.get(), this->scope(),
false, /* don't create local scope each time*/
false /* don't create variable each time */);
// For some other vector like containers not cleaned after each batch.
_context->tensor_array_batch_cleaner.CollectNoTensorVars(this->scope());
_context->tensor_array_batch_cleaner.ResetNoTensorVars();
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
return -1;
virtual int run() {
if (_context == nullptr) {
VLOG(2) << "need initialize before run";
return -1;
}
try {
_context->executor.RunPreparedContext(_context->prepare_context.get(), this->scope(),
false, /* don't create local scope each time*/
false /* don't create variable each time */);
// For some other vector like containers not cleaned after each batch.
_context->tensor_array_batch_cleaner.CollectNoTensorVars(this->scope());
_context->tensor_array_batch_cleaner.ResetNoTensorVars();
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
return -1;
}
return 0;
}
return 0;
}
protected:
struct Context {
Context(const ::paddle::platform::Place& place) : place(place), executor(place) {
}
const ::paddle::platform::Place& place;
::paddle::framework::Executor executor;
::std::unique_ptr<::paddle::framework::ProgramDesc> main_program;
::std::unique_ptr<framework::ExecutorPrepareContext> prepare_context;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner;
};
std::unique_ptr<Context> _context;
};
REGISTER_CLASS(Executor, SimpleExecutor);
} // namespace feed
......
......@@ -42,18 +42,6 @@ protected:
};
REGISTER_REGISTERER(Executor);
class SimpleExecutor : public Executor {
public:
SimpleExecutor();
virtual ~SimpleExecutor();
virtual int initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr);
virtual int run();
protected:
struct Context;
std::unique_ptr<Context> _context;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -81,22 +81,23 @@ public:
};
TEST_F(SimpleExecutorTest, initialize) {
SimpleExecutor executor;
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, executor.initialize(config, context_ptr));
ASSERT_NE(0, executor->initialize(config, context_ptr));
config = YAML::Load(std::string() + "{startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, executor.initialize(config, context_ptr));
ASSERT_EQ(0, executor->initialize(config, context_ptr));
config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, executor.initialize(config, context_ptr));
ASSERT_EQ(0, executor->initialize(config, context_ptr));
}
TEST_F(SimpleExecutorTest, run) {
SimpleExecutor executor;
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
auto config = YAML::Load(std::string() + "{thread_num: 2, startup_program: " + startup_program_path + ", main_program: " + main_program_path + "}");
ASSERT_EQ(0, executor.initialize(config, context_ptr));
ASSERT_EQ(0, executor->initialize(config, context_ptr));
auto x_var = executor.mutable_var<::paddle::framework::LoDTensor>("x");
executor.mutable_var<::paddle::framework::LoDTensor>("mean");
auto x_var = executor->mutable_var<::paddle::framework::LoDTensor>("x");
executor->mutable_var<::paddle::framework::LoDTensor>("mean");
ASSERT_NE(nullptr, x_var);
int x_len = 10;
......@@ -109,9 +110,9 @@ TEST_F(SimpleExecutorTest, run) {
}
std::cout << std::endl;
ASSERT_EQ(0, executor.run());
ASSERT_EQ(0, executor->run());
auto mean_var = executor.var<::paddle::framework::LoDTensor>("mean");
auto mean_var = executor->var<::paddle::framework::LoDTensor>("mean");
auto mean = mean_var.data<float>()[0];
std::cout << "mean: " << mean << std::endl;
ASSERT_NEAR(4.5, mean, 1e-9);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册