Commit b415ec27 authored by D dongdaxiang

make Dataset* as an argument

Parent dd67ad08
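This commit makes the Dataset (and the Scope) an explicit argument that the caller passes into Executor::RunFromDataset and Trainer::Initialize, instead of having each trainer build its own readers and the executor use its internal root_scope_. A minimal C++ sketch of the resulting call pattern, based only on the signatures visible in the hunks below; the setup of the program, scope, dataset and trainer descriptor is assumed:

// Sketch only; not part of the commit. Uses the new signature
// Executor::RunFromDataset(program, scope, dataset, trainer_desc_str).
#include <string>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/executor.h"

namespace fw = paddle::framework;

void TrainOnce(fw::Executor* exe, const fw::ProgramDesc& main_program,
               fw::Scope* scope, fw::Dataset* dataset,
               const std::string& trainer_desc_str) {
  // Scope and Dataset are owned by the caller; the executor and the
  // trainers created inside only borrow them (the debug flag is gone).
  exe->RunFromDataset(main_program, scope, dataset, trainer_desc_str);
}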
@@ -30,7 +30,7 @@ add_subdirectory(io)
proto_library(framework_proto SRCS framework.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
-proto_library(trainer_desc_proto SRCS trainer_desc.proto)
+proto_library(trainer_desc_proto SRCS trainer_desc.proto data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
......
@@ -52,7 +52,7 @@ void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
data_feed_desc_str, &data_feed_desc_);
}
-std::vector<std::shared_ptr<paddle::framework::DataFeed>>
+const std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
Dataset::GetReaders() {
return readers_;
}
......
@@ -43,7 +43,7 @@ class Dataset {
return data_feed_desc_;
}
-virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>
+virtual const std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
......
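Dataset::GetReaders() now returns a const reference to the vector of DataFeed shared pointers that the Dataset owns, so trainers attach the same readers to their workers instead of instantiating readers themselves. A rough sketch of the consuming side, mirroring the trainer hunks below (BindReaders is a hypothetical helper, not part of this commit):

// Hypothetical helper showing how the const-reference readers are consumed.
#include <memory>
#include <vector>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker.h"

void BindReaders(
    paddle::framework::Dataset* dataset,
    std::vector<std::shared_ptr<paddle::framework::DeviceWorker>>* workers) {
  const auto& readers = dataset->GetReaders();  // const reference, no copy
  for (size_t i = 0; i < workers->size(); ++i) {
    (*workers)[i]->SetDataFeed(readers[i]);     // one reader per worker thread
  }
}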
@@ -15,6 +15,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
+#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
@@ -25,26 +26,18 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* data_set) {
thread_num_ = trainer_desc.thread_num();
workers_.resize(thread_num_);
-readers_.resize(thread_num_);
+const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
+data_set->GetReaders();
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
-readers_[i] =
-DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
workers_[i]->SetDeviceIndex(i);
-readers_[i]->Init(trainer_desc.data_desc());
-workers_[i]->SetDataFeed(readers_[i]);
+workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc);
}
-std::vector<std::string> filelist_vec;
-for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
-filelist_vec.push_back(trainer_desc.filelist(i));
-}
-readers_[0]->SetFileList(filelist_vec);
fleet_ptr_ = FleetWrapper::GetInstance();
pull_dense_worker_ = PullDenseWorker::GetInstance();
pull_dense_worker_->Initialize(trainer_desc);
......
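Note that DistMultiTrainer::Initialize no longer creates DataFeed objects or pushes trainer_desc.filelist() into them; both the readers and the file list are expected to be configured on the Dataset before the trainer sees it. A hypothetical caller-side preparation; only SetDataFeedDesc and GetReaders are confirmed by this diff, the other setters are assumptions:

// Hypothetical dataset preparation; SetFileList/SetThreadNum are assumptions.
paddle::framework::Dataset dataset;
dataset.SetDataFeedDesc(data_feed_desc_str);        // DataFeedDesc as protobuf text
dataset.SetFileList({"part-00000", "part-00001"});  // formerly TrainerDesc.filelist
dataset.SetThreadNum(thread_num);                   // one reader per worker thread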
@@ -116,10 +116,9 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
}
-void Executor::RunFromDataset(const ProgramDesc& main_program,
+void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
Dataset* dataset,
-const std::string& trainer_desc_str,
-const bool debug) {
+const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
@@ -132,9 +131,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program,
VLOG(3) << "Going to initialize trainer";
trainer->Initialize(trainer_desc, dataset);
VLOG(3) << "Set root scope here";
-trainer->SetScope(root_scope_);
-VLOG(3) << "Going to set debug";
-trainer->SetDebug(debug);
+trainer->SetScope(scope);
// prepare training environment and helper environment
VLOG(3) << "Try to init train environment";
trainer->InitTrainerEnv(main_program, place_);
@@ -146,7 +143,7 @@ void Executor::RunFromDataset(const ProgramDesc& main_program,
VLOG(3) << "Trainer going to finalize";
trainer->Finalize();
VLOG(3) << "Drop current scope kids";
-root_scope_->DropKids();
+scope->DropKids();
return;
}
......
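With these executor changes the scope is supplied by the caller rather than taken from root_scope_, and that same scope has its kids dropped when the run finishes; the debug flag is gone from the interface. A sketch of how a caller might drive it (Executor::Run(program, scope, block_id) is assumed from the pre-existing API; place, the programs, dataset and trainer_desc are placeholders):

// Sketch only: the caller owns the scope for the whole training run.
paddle::framework::Scope scope;
paddle::framework::Executor exe(place);
exe.Run(startup_program, &scope, /*block_id=*/0);                 // init parameters
exe.RunFromDataset(main_program, &scope, dataset, trainer_desc);  // train
// RunFromDataset calls scope.DropKids() itself before returning.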
@@ -114,16 +114,11 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program);
-void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset,
-const std::string& trainer_desc_str, const bool debug);
-public:
-std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
-Scope* root_scope_;
+void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
+Dataset* dataset, const std::string& trainer_desc_str);
private:
const platform::Place place_;
int actual_thread_num_;
};
} // namespace framework
......
@@ -26,31 +26,16 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
// get filelist from trainer_desc here
workers_.resize(thread_num_);
-/*
-if (NULL == dataset) {
-readers_.resize(thread_num_);
-for (int i = 0; i < thread_num_; ++i) {
-readers_[i] =
-DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
-readers_[i]->Init(trainer_desc.data_desc());
-}
-std::vector<std::string> filelist_vec;
-for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
-filelist_vec.push_back(trainer_desc.filelist(i));
-}
-readers_[0]->SetFileList(filelist_vec);
-} else {
-// readers_ = dataset.get_readers(); ?
-}
-*/
+const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
+dataset->GetReaders();
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
-workers_[i]->SetDataFeed(readers_[i]);
+workers_[i]->SetDataFeed(readers[i]);
}
// set debug here
}
// call only after all resources are set in current trainer
......
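Since the commented-out fallback for a null dataset is deleted instead of revived, MultiTrainer::Initialize now simply assumes a valid Dataset whose readers were already created, with at least one reader per worker thread. A caller-side guard could make that precondition explicit (this check is a sketch, not part of the commit):

// Sketch of a defensive check before handing the dataset to the trainer.
#include "paddle/fluid/platform/enforce.h"

PADDLE_ENFORCE(dataset != nullptr,
               "MultiTrainer::Initialize requires a non-null Dataset");
PADDLE_ENFORCE(static_cast<int>(dataset->GetReaders().size()) >=
                   trainer_desc.thread_num(),
               "the Dataset must provide at least one reader per thread");
trainer->Initialize(trainer_desc, dataset);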
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import core
__all__ = ['Fleet']
class Fleet(object):
    """
    Fleet is a python wrapper of the parameter-server runtime: it pairs a
    core.FleetWrapper with a PaddlePSInstance that coordinates workers and
    servers.
    """

    def __init__(self):
        self.instance_ = ps_instance.PaddlePSInstance()
        self.fleet_ = core.FleetWrapper()

    def stop(self):
        self.instance_.barrier_worker()
        if self.instance_.is_first_worker():
            self.fleet_.stop_server()
        self.instance_.barrier_worker()
        self.instance_.barrier_all()
        self.instance_.finalize()

    def init_pserver(self, dist_desc):
        self.dist_desc_str_ = text_format.MessageToString(dist_desc)
        self.dist_desc_ = dist_desc
        self.fleet_.init_server(self.dist_desc_str_)
        ip = self.fleet_.start_server()
        self.instance_.set_ip(ip)
        self.instance_.barrier_all()
        ips = self.instance_.gather_ips()
        self.fleet_.gather_servers(ips, self.instance_.get_node_cnt())
        self.instance_.barrier_all()

    def init_worker(self, dist_desc):
        self.dist_desc_str_ = text_format.MessageToString(dist_desc)
        self.dist_desc_ = dist_desc
        self.instance_.barrier_all()
        ips = self.instance_.gather_ips()
        self.fleet_.init_worker(self.dist_desc_str_, ips,
                                self.instance_.get_node_cnt(),
                                self.instance_._rankid)
        self.instance_.barrier_worker()

    def init_pserver_model(self):
        if self.instance_.is_first_worker():
            self.fleet_.init_model()
        self.instance_.barrier_worker()

    def save_pserver_model(self, save_path):
        self.fleet_.save_model(save_path)
@@ -610,3 +610,23 @@ class Executor(object):
    def _run_inference(self, exe, feed):
        return exe.run(feed)

    def run_from_dataset(self,
                         program=None,
                         dataset=None,
                         fetch_list=None,
                         scope=None,
                         opt_info=None):
        if scope is None:
            scope = global_scope()
        if fetch_list is None:
            fetch_list = []
        compiled = isinstance(program, compiler.CompiledProgram)
        if not compiled:
            trainer = TrainerFactory().create_trainer(opt_info)
            self._default_executor.run_from_dataset(program_desc,
                                                    trainer._desc())
        else:
            # For a compiled program, more runtime support still has to be
            # implemented.
            print("run_from_dataset currently does not support compiled "
                  "programs, we will support this later", file=sys.stderr)
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,5 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: Trainer is moved into fluid.contrib.trainer.
-__all__ = []
+__all__ = ["TrainerFactory"]
class TrainerFactory(object):
    def __init__(self):
        pass

    def create_trainer(self, opt_info=None):
        if opt_info is None:
            return MultiTrainer()
        else:
            if opt_info["optimizer"] == "DownpourSGD":
                trainer = DistMultiTrainer()
                trainer.gen_trainer_desc(
                    fleet_desc=opt_info["fleet"], worker="downpour")
                return trainer
            else:
                print("Currently only DownpourSGD is supported")