提交 dd67ad08 编写于 作者: X xjqbest 提交者: dongdaxiang

modify c++ and python dataset related code & fix bug

上级 cc4def6b
...@@ -206,7 +206,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory. ...@@ -206,7 +206,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.
DEPS op_registry device_context scope framework_proto DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto feed_fetch_method graph_to_program_pass data_feed_proto
variable_helper timer) variable_helper timer fs shell)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor) cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
......
...@@ -59,6 +59,12 @@ void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list, ...@@ -59,6 +59,12 @@ void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
fleet_ptr_->GatherServers(host_sign_list, node_num); fleet_ptr_->GatherServers(host_sign_list, node_num);
} }
// todo InitModel
void AsyncExecutor::InitModel() { }
// todo SaveModel
void AsyncExecutor::SaveModel(const std::string& path) { }
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str, const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
...@@ -154,5 +160,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -154,5 +160,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
return; return;
} }
// todo RunFromDataset
void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
Dataset* data_set,
const std::string& trainer_desc_str,
const bool debug) { }
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#include <stdio_ext.h> #include <stdio_ext.h>
#include <utility>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
...@@ -135,6 +136,7 @@ int PrivateQueueDataFeed<T>::Next() { ...@@ -135,6 +136,7 @@ int PrivateQueueDataFeed<T>::Next() {
return batch_size_; return batch_size_;
} }
// explicit instantiation
template class PrivateQueueDataFeed<std::vector<MultiSlotType>>; template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
template <typename T> template <typename T>
...@@ -220,8 +222,6 @@ void InMemoryDataFeed<T>::LocalShuffle() { ...@@ -220,8 +222,6 @@ void InMemoryDataFeed<T>::LocalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end()); std::random_shuffle(memory_data_.begin(), memory_data_.end());
} }
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
// todo global shuffle // todo global shuffle
/* /*
template <typename T> template <typename T>
...@@ -242,6 +242,9 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) { ...@@ -242,6 +242,9 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
} }
*/ */
// explicit instantiation
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
void MultiSlotDataFeed::Init( void MultiSlotDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) { const paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false; finish_init_ = false;
......
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. */ * limitations under the License. */
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
...@@ -44,9 +47,9 @@ void Dataset::SetThreadNum(int thread_num) { ...@@ -44,9 +47,9 @@ void Dataset::SetThreadNum(int thread_num) {
void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
void Dataset::SetDataFeedDesc( void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
const paddle::framework::DataFeedDesc& data_feed_desc) { google::protobuf::TextFormat::ParseFromString(
data_feed_desc_ = data_feed_desc; data_feed_desc_str, &data_feed_desc_);
} }
std::vector<std::shared_ptr<paddle::framework::DataFeed>> std::vector<std::shared_ptr<paddle::framework::DataFeed>>
......
...@@ -34,8 +34,7 @@ class Dataset { ...@@ -34,8 +34,7 @@ class Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist); virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num); virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num); virtual void SetTrainerNum(int trainer_num);
virtual void SetDataFeedDesc( virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
const paddle::framework::DataFeedDesc& data_feed_desc);
virtual const std::vector<std::string>& GetFileList() { return filelist_; } virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; } virtual int GetThreadNum() { return thread_num_; }
......
...@@ -22,7 +22,7 @@ namespace paddle { ...@@ -22,7 +22,7 @@ namespace paddle {
namespace framework { namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
const Dataset& data_set) { Dataset* data_set) {
thread_num_ = trainer_desc.thread_num(); thread_num_ = trainer_desc.thread_num();
workers_.resize(thread_num_); workers_.resize(thread_num_);
readers_.resize(thread_num_); readers_.resize(thread_num_);
......
...@@ -14,11 +14,9 @@ limitations under the License. */ ...@@ -14,11 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include <deque> #include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <unordered_map>
#include <utility> #include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
...@@ -119,7 +117,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, ...@@ -119,7 +117,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
} }
void Executor::RunFromDataset(const ProgramDesc& main_program, void Executor::RunFromDataset(const ProgramDesc& main_program,
const Dataset& dataset, Dataset* dataset,
const std::string& trainer_desc_str, const std::string& trainer_desc_str,
const bool debug) { const bool debug) {
VLOG(3) << "Start to RunFromDataset in executor"; VLOG(3) << "Start to RunFromDataset in executor";
......
...@@ -19,6 +19,8 @@ limitations under the License. */ ...@@ -19,6 +19,8 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <unordered_map>
#include <memory>
#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
...@@ -112,7 +114,7 @@ class Executor { ...@@ -112,7 +114,7 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program); void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, const Dataset& dataset, void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset,
const std::string& trainer_desc_str, const bool debug); const std::string& trainer_desc_str, const bool debug);
public: public:
......
...@@ -22,7 +22,7 @@ namespace paddle { ...@@ -22,7 +22,7 @@ namespace paddle {
namespace framework { namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
const Dataset& dataset) { Dataset* dataset) {
thread_num_ = trainer_desc.thread_num(); thread_num_ = trainer_desc.thread_num();
// get filelist from trainer_desc here // get filelist from trainer_desc here
workers_.resize(thread_num_); workers_.resize(thread_num_);
......
...@@ -42,7 +42,7 @@ class TrainerBase { ...@@ -42,7 +42,7 @@ class TrainerBase {
void SetScope(Scope* root_scope); void SetScope(Scope* root_scope);
void SetDebug(const bool debug) { debug_ = debug; } void SetDebug(const bool debug) { debug_ = debug; }
virtual void Initialize(const TrainerDesc& trainer_desc, virtual void Initialize(const TrainerDesc& trainer_desc,
const Dataset& data_set) = 0; Dataset* data_set) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program, virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) = 0; const platform::Place& place) = 0;
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0; virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
...@@ -62,7 +62,7 @@ class MultiTrainer : public TrainerBase { ...@@ -62,7 +62,7 @@ class MultiTrainer : public TrainerBase {
MultiTrainer() {} MultiTrainer() {}
virtual ~MultiTrainer() {} virtual ~MultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc, virtual void Initialize(const TrainerDesc& trainer_desc,
const Dataset& data_set); Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program, virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place); const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {} virtual void InitOtherEnv(const ProgramDesc& main_program) {}
...@@ -81,7 +81,7 @@ class DistMultiTrainer : public MultiTrainer { ...@@ -81,7 +81,7 @@ class DistMultiTrainer : public MultiTrainer {
DistMultiTrainer() {} DistMultiTrainer() {}
virtual ~DistMultiTrainer() {} virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc, virtual void Initialize(const TrainerDesc& trainer_desc,
const Dataset& data_set); Dataset* data_set);
virtual void InitOtherEnv(const ProgramDesc& main_program); virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Finalize(); virtual void Finalize();
......
...@@ -24,6 +24,9 @@ from .executor import * ...@@ -24,6 +24,9 @@ from .executor import *
from . import data_feed_desc from . import data_feed_desc
from .data_feed_desc import * from .data_feed_desc import *
from . import dataset
from .dataset import *
from . import async_executor from . import async_executor
from .async_executor import * from .async_executor import *
......
...@@ -139,10 +139,6 @@ class DataFeedDesc(object): ...@@ -139,10 +139,6 @@ class DataFeedDesc(object):
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[ self.proto_desc.multi_slot_desc.slots[self.__name_to_index[
name]].is_used = True name]].is_used = True
def global_shuffle(self):
self.data.global_shuffle()
pass
def desc(self): def desc(self):
""" """
Returns a protobuf message for this DataFeedDesc Returns a protobuf message for this DataFeedDesc
......
...@@ -23,9 +23,9 @@ class DatasetFactory(object): ...@@ -23,9 +23,9 @@ class DatasetFactory(object):
pass pass
def create_dataset(self, datafeed_class): def create_dataset(self, datafeed_class):
datafeed_class = datafeed_class.capitalize()
try: try:
dataset = globals()[datafeed_class]() dataset = globals()[datafeed_class]()
return dataset
except: except:
raise ValueError("datafeed class %s does not exist" % raise ValueError("datafeed class %s does not exist" %
datafeed_class) datafeed_class)
...@@ -37,6 +37,7 @@ class DatasetBase(object): ...@@ -37,6 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance # to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat" self.proto_desc.pipe_command = "cat"
self.dataset = core.Dataset()
def set_pipe_command(self, pipe_command): def set_pipe_command(self, pipe_command):
""" """
...@@ -60,17 +61,23 @@ class DatasetBase(object): ...@@ -60,17 +61,23 @@ class DatasetBase(object):
""" """
self.proto_desc.batch_size = batch_size self.proto_desc.batch_size = batch_size
def set_thread(self, thread_num):
self.dataset.set_thread_num(thread_num)
def set_filelist(self, filelist):
self.dataset.set_filelist(filelist)
def set_use_var(self, var_list): def set_use_var(self, var_list):
multi_slot = self.proto_desc.multi_slot_desc() multi_slot = self.proto_desc.multi_slot_desc
for var in var_list: for var in var_list:
slot_var = multi_slot.add() slot_var = multi_slot.slots.add()
slot_var.is_used = True slot_var.is_used = True
slot_var.name = var.name slot_var.name = var.name
if var.lod_level == 0: if var.lod_level == 0:
slot_var.is_dense = True slot_var.is_dense = True
if var.dtype == core.VarType.FP32: if var.dtype == core.VarDesc.VarType.FP32:
slot_var.type = "float32" slot_var.type = "float32"
elif var.dtype == core.VarType.INT64: elif var.dtype == core.VarDesc.VarType.INT64:
slot_var.type = "uint64" slot_var.type = "uint64"
else: else:
raise ValueError( raise ValueError(
...@@ -93,17 +100,24 @@ class DatasetBase(object): ...@@ -93,17 +100,24 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase): class InMemoryDataset(DatasetBase):
def __init__(self): def __init__(self):
super(InMemoryDataset.__init__()) super(InMemoryDataset, self).__init__()
self.proto_desc.name = "InMemoryDataFeed" self.proto_desc.name = "MultiSlotInMemoryDataFeed"
def load_into_memory(self):
self.dataset.set_data_feed_desc(self.desc())
self.dataset.load_into_memory()
def local_shuffle(self): def local_shuffle(self):
pass self.dataset.local_shuffle()
def global_shuffle(self): def global_shuffle(self):
pass from .distributed import ps_instance
instance = ps_instance.PaddlePSInstance(1, 2)
self.dataset.set_trainer_num(instance.get_worker_num())
self.global_shuffle()
class QueueDataset(DatasetBase): class QueueDataset(DatasetBase):
def __init__(self): def __init__(self):
super(QueueDataset.__init__()) super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed" self.proto_desc.name = "MultiSlotDataFeed"
...@@ -121,6 +121,18 @@ class PaddlePSInstance(object): ...@@ -121,6 +121,18 @@ class PaddlePSInstance(object):
""" """
return self._nodes return self._nodes
def get_worker_num(self):
"""
Return worker num
"""
return self._worker_num
def get_server_num(self):
"""
Return server num
"""
return self._server_num
def barrier_all(self): def barrier_all(self):
""" """
barrier workers and servers barrier workers and servers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册