提交 d8260b87 编写于 作者: R rensilin

execute

Change-Id: I316472bb3c9a2c9334876f3e4e6e9869aa4c3252
上级 29aec8e0
此差异已折叠。
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
namespace {
int ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
if (!fin) {
VLOG(4) << "Cannot open file " << filename;
return -1;
}
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
return 0;
}
std::unique_ptr<paddle::framework::ProgramDesc> Load(
paddle::framework::Executor* /*executor*/, const std::string& model_filename) {
VLOG(3) << "loading model from " << model_filename;
std::string program_desc_str;
if (ReadBinaryFile(model_filename, &program_desc_str) != 0) {
return nullptr;
}
std::unique_ptr<paddle::framework::ProgramDesc> main_program(
new paddle::framework::ProgramDesc(program_desc_str));
return main_program;
}
}
struct SimpleExecute::Context {
Context(const ::paddle::platform::Place& place) : place(place), executor(place) {
}
const ::paddle::platform::Place& place;
::paddle::framework::Executor executor;
::std::unique_ptr<::paddle::framework::ProgramDesc> main_program;
::std::unique_ptr<framework::ExecutorPrepareContext> prepare_context;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner;
};
SimpleExecute::SimpleExecute() {
}
SimpleExecute::~SimpleExecute() {
}
int SimpleExecute::initialize(YAML::Node& exe_config,
std::shared_ptr<TrainerContext> context_ptr) {
paddle::framework::InitDevices(false);
if (exe_config["num_threads"]) {
paddle::platform::SetNumThreads(exe_config["num_threads"].as<int>());
} else {
paddle::platform::SetNumThreads(1);
}
_context.reset(new SimpleExecute::Context(context_ptr->cpu_place));
auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>());
if (startup_program == nullptr) {
VLOG(4) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>();
return -1;
}
_context->executor.Run(*startup_program, this->scope(), 0, false, true);
_context->main_program = Load(&_context->executor, exe_config["main_program"].as<std::string>());
if (_context->main_program == nullptr) {
VLOG(4) << "fail to load main_program: " << exe_config["main_program"].as<std::string>();
return -1;
}
_context->prepare_context = _context->executor.Prepare(*_context->main_program, 0);
_context->executor.CreateVariables(*_context->main_program, this->scope(), 0);
return 0;
}
int SimpleExecute::run() {
_context->executor.RunPreparedContext(_context->prepare_context.get(), this->scope(),
false, /* don't create local scope each time*/
false /* don't create variable each time */);
// For some other vector like containers not cleaned after each batch.
_context->tensor_array_batch_cleaner.CollectNoTensorVars(this->scope());
_context->tensor_array_batch_cleaner.ResetNoTensorVars();
return 0;
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
#pragma once #pragma once
#include <functional> #include <functional>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h" #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
...@@ -23,16 +23,16 @@ public: ...@@ -23,16 +23,16 @@ public:
} }
//直接取var //直接取var
template <class T> template <class T>
T* var(const std::string& name) { const T& var(const std::string& name) {
return _scope.Var(name).Get<T>(); return _scope.Var(name)->Get<T>();
} }
template <class T> template <class T>
T* mutable_var(const std::string& name) { T* mutable_var(const std::string& name) {
return _scope.Var(name)->GetMutable<T>(); return _scope.Var(name)->GetMutable<T>();
} }
//执行n轮训练,每轮回调(epoch_id, _scope) //执行训练
virtual int run(uint32_t epoch_num, std::function<void(uint32_t, ::paddle::framework::Scope*)>) = 0; virtual int run() = 0;
protected: protected:
::paddle::framework::Scope _scope; ::paddle::framework::Scope _scope;
...@@ -41,13 +41,14 @@ REGISTER_REGISTERER(Execute); ...@@ -41,13 +41,14 @@ REGISTER_REGISTERER(Execute);
class SimpleExecute : public Execute { class SimpleExecute : public Execute {
public: public:
SimpleExecute() {} SimpleExecute();
virtual ~SimpleExecute() {} virtual ~SimpleExecute();
virtual int initialize(YAML::Node& exe_config, virtual int initialize(YAML::Node& exe_config,
std::shared_ptr<TrainerContext> context_ptr); std::shared_ptr<TrainerContext> context_ptr);
virtual int run(uint32_t epoch_num, std::function<void(uint32_t, ::paddle::framework::Scope*)>) = 0; virtual int run();
protected: protected:
::paddle::framework::Executor _execute; struct Context;
std::unique_ptr<Context> _context;
}; };
} // namespace feed } // namespace feed
......
#include <gtest/gtest.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
int32_t main(int32_t argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
::google::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging("paddle_trainer");
return RUN_ALL_TESTS();
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
TEST(testSimpleExecute, initialize) {
SimpleExecute execute;
auto context_ptr = std::make_shared<TrainerContext>();
auto config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, execute.initialize(config, context_ptr));
config = YAML::Load("{startup_program: ./data/startup_program, main_program: ./data/main_program}");
ASSERT_EQ(0, execute.initialize(config, context_ptr));
config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}");
ASSERT_EQ(0, execute.initialize(config, context_ptr));
}
TEST(testSimpleExecute, run) {
SimpleExecute execute;
auto context_ptr = std::make_shared<TrainerContext>();
auto config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}");
ASSERT_EQ(0, execute.initialize(config, context_ptr));
ASSERT_EQ(0, execute.run());
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册