Commit b415ec27 authored by dongdaxiang

make Dataset* as an argument

Parent dd67ad08
@@ -30,7 +30,7 @@ add_subdirectory(io)
 proto_library(framework_proto SRCS framework.proto)
 proto_library(data_feed_proto SRCS data_feed.proto)
 proto_library(async_executor_proto SRCS data_feed.proto)
-proto_library(trainer_desc_proto SRCS trainer_desc.proto)
+proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto)
 cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
......
@@ -52,7 +52,7 @@ void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
                                                 data_feed_desc_str, &data_feed_desc_);
 }
-std::vector<std::shared_ptr<paddle::framework::DataFeed>>
+const std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
 Dataset::GetReaders() {
   return readers_;
 }
......
@@ -43,7 +43,7 @@ class Dataset {
     return data_feed_desc_;
   }
-  virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>
+  virtual const std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
   GetReaders();
   virtual void LoadIntoMemory();
   virtual void LocalShuffle();
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/data_feed_factory.h"
+#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/trainer.h"
@@ -25,26 +26,18 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
                                   Dataset* data_set) {
   thread_num_ = trainer_desc.thread_num();
   workers_.resize(thread_num_);
-  readers_.resize(thread_num_);
+  const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
+      data_set->GetReaders();
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
-    readers_[i] =
-        DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
     workers_[i]->SetDeviceIndex(i);
-    readers_[i]->Init(trainer_desc.data_desc());
-    workers_[i]->SetDataFeed(readers_[i]);
+    workers_[i]->SetDataFeed(readers[i]);
     workers_[i]->Initialize(trainer_desc);
   }
-  std::vector<std::string> filelist_vec;
-  for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
-    filelist_vec.push_back(trainer_desc.filelist(i));
-  }
-  readers_[0]->SetFileList(filelist_vec);
 fleet_ptr_ = FleetWrapper::GetInstance();
 pull_dense_worker_ = PullDenseWorker::GetInstance();
 pull_dense_worker_->Initialize(trainer_desc);
......
@@ -116,10 +116,9 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
   }
 }
-void Executor::RunFromDataset(const ProgramDesc& main_program,
+void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
                               Dataset* dataset,
-                              const std::string& trainer_desc_str,
-                              const bool debug) {
+                              const std::string& trainer_desc_str) {
   VLOG(3) << "Start to RunFromDataset in executor";
   TrainerDesc trainer_desc;
   google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
@@ -132,9 +131,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program,
   VLOG(3) << "Going to initialize trainer";
   trainer->Initialize(trainer_desc, dataset);
   VLOG(3) << "Set root scope here";
-  trainer->SetScope(root_scope_);
-  VLOG(3) << "Going to set debug";
-  trainer->SetDebug(debug);
+  trainer->SetScope(scope);
   // prepare training environment and helper environment
   VLOG(3) << "Try to init train environment";
   trainer->InitTrainerEnv(main_program, place_);
@@ -146,7 +143,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program,
   VLOG(3) << "Trainer going to finalize";
   trainer->Finalize();
   VLOG(3) << "Drop current scope kids";
-  root_scope_->DropKids();
+  scope->DropKids();
   return;
 }
......
@@ -114,16 +114,11 @@ class Executor {
   void EnableMKLDNN(const ProgramDesc& program);
-  void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset,
-                      const std::string& trainer_desc_str, const bool debug);
- public:
-  std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
-  Scope* root_scope_;
+  void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
+                      Dataset* dataset, const std::string& trainer_desc_str);
  private:
   const platform::Place place_;
-  int actual_thread_num_;
 };
 }  // namespace framework
......
@@ -26,31 +26,16 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
   thread_num_ = trainer_desc.thread_num();
   // get filelist from trainer_desc here
   workers_.resize(thread_num_);
-  /*
-  if (NULL == dataset) {
-    readers_.resize(thread_num_);
-    for (int i = 0; i < thread_num_; ++i) {
-      readers_[i] =
-          DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
-      readers_[i]->Init(trainer_desc.data_desc());
-    }
-    std::vector<std::string> filelist_vec;
-    for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
-      filelist_vec.push_back(trainer_desc.filelist(i));
-    }
-    readers_[0]->SetFileList(filelist_vec);
-  } else {
-    // readers_ = dataset.get_readers(); ?
-  }
-  */
+  const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
+      dataset->GetReaders();
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
     workers_[i]->SetDeviceIndex(i);
-    workers_[i]->SetDataFeed(readers_[i]);
+    workers_[i]->SetDataFeed(readers[i]);
   }
-  // set debug here
 }
 // call only after all resources are set in current trainer
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.protobuf import text_format

from .. import core
from . import ps_instance  # assumed location of the PaddlePSInstance helper

__all__ = ['Fleet']


class Fleet(object):
    """
    Fleet wraps the distributed parameter-server runtime: it pairs a
    core.FleetWrapper with a PaddlePSInstance and coordinates servers
    and workers through barriers.
    """

    def __init__(self):
        self.instance_ = ps_instance.PaddlePSInstance()
        self.fleet_ = core.FleetWrapper()

    def stop(self):
        self.instance_.barrier_worker()
        if self.instance_.is_first_worker():
            self.fleet_.stop_server()
        self.instance_.barrier_worker()
        self.instance_.barrier_all()
        self.instance_.finalize()

    def init_pserver(self, dist_desc):
        self.dist_desc_str_ = text_format.MessageToString(dist_desc)
        self.dist_desc_ = dist_desc
        self.fleet_.init_server(self.dist_desc_str_)
        ip = self.fleet_.start_server()
        self.instance_.set_ip(ip)
        self.instance_.barrier_all()
        ips = self.instance_.gather_ips()
        self.fleet_.gather_servers(ips, self.instance_.get_node_cnt())
        self.instance_.barrier_all()

    def init_worker(self, dist_desc):
        self.dist_desc_str_ = text_format.MessageToString(dist_desc)
        self.dist_desc_ = dist_desc
        self.instance_.barrier_all()
        ips = self.instance_.gather_ips()
        self.fleet_.init_worker(self.dist_desc_str_, ips,
                                self.instance_.get_node_cnt(),
                                self.instance_._rankid)
        self.instance_.barrier_worker()

    def init_pserver_model(self):
        if self.instance_.is_first_worker():
            self.fleet_.init_model()
        self.instance_.barrier_worker()

    def save_pserver_model(self, save_path):
        self.fleet_.save_model(save_path)
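The Fleet wrapper above is meant to be driven differently on parameter-server and worker nodes. Below is a minimal usage sketch under stated assumptions: the role flag and a prepared `dist_desc` protobuf are assumed to come from the launcher and are not part of this commit.

    # Hypothetical driver for the Fleet wrapper above; not part of this commit.
    # Assumes `dist_desc` is an already-built distributed-training description
    # protobuf and `is_parameter_server` is a role flag set by the launcher.
    fleet = Fleet()
    if is_parameter_server:
        fleet.init_pserver(dist_desc)       # start a server, exchange IPs, barrier
    else:
        fleet.init_worker(dist_desc)        # connect this worker to the server list
        fleet.init_pserver_model()          # first worker initializes the model
        # ... run training here ...
        fleet.save_pserver_model("./model_dir")
        fleet.stop()                        # first worker stops servers, then finalize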
@@ -610,3 +610,23 @@ class Executor(object):
     def _run_inference(self, exe, feed):
         return exe.run(feed)
+
+    def run_from_dataset(self,
+                         program=None,
+                         dataset=None,
+                         fetch_list=None,
+                         scope=None,
+                         opt_info=None):
+        if scope is None:
+            scope = global_scope()
+        if fetch_list is None:
+            fetch_list = []
+        compiled = isinstance(program, compiler.CompiledProgram)
+        if not compiled:
+            trainer = TrainerFactory().create_trainer(opt_info)
+            self._default_executor.run_from_dataset(program.desc,
+                                                    trainer._desc())
+        else:
+            # For compiled program, more runtime should be implemented
+            print("run_from_dataset currently does not support compiled program,"
+                  " we will support this later", file=sys.stderr)
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,5 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# NOTE: Trainer is moved into fluid.contrib.trainer.
-__all__ = []
+__all__ = ["TrainerFactory"]
+
+
+class TrainerFactory(object):
+    def __init__(self):
+        pass
+
+    def create_trainer(self, opt_info=None):
+        if opt_info == None:
+            return MultiTrainer()
+        else:
+            if opt_info["optimizer"] == "DownpourSGD":
+                trainer = DistMultiTrainer()
+                trainer.gen_trainer_desc(
+                    fleet_desc=opt_info["fleet"], worker="downpour")
+                return trainer
+            else:
+                print("Currently only support DownpourSGD")