Commit dd67ad08 authored by xjqbest, committed by dongdaxiang

Modify C++ and Python dataset-related code and fix bugs

Parent cc4def6b
......@@ -206,7 +206,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
- variable_helper timer)
+ variable_helper timer fs shell)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
......
......@@ -59,6 +59,12 @@ void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
fleet_ptr_->GatherServers(host_sign_list, node_num);
}
+ // todo InitModel
+ void AsyncExecutor::InitModel() { }
+ // todo SaveModel
+ void AsyncExecutor::SaveModel(const std::string& path) { }
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist,
......@@ -154,5 +160,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
return;
}
+ // todo RunFromDataset
+ void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
+                                    Dataset* data_set,
+                                    const std::string& trainer_desc_str,
+                                    const bool debug) { }
} // end namespace framework
} // end namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.h"
#include <stdio_ext.h>
#include <utility>
#include "gflags/gflags.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
......@@ -135,6 +136,7 @@ int PrivateQueueDataFeed<T>::Next() {
return batch_size_;
}
+ // explicit instantiation
template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
template <typename T>
......@@ -220,8 +222,6 @@ void InMemoryDataFeed<T>::LocalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
}
- template class InMemoryDataFeed<std::vector<MultiSlotType>>;
// todo global shuffle
/*
template <typename T>
......@@ -242,6 +242,9 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
}
*/
+ // explicit instantiation
+ template class InMemoryDataFeed<std::vector<MultiSlotType>>;
void MultiSlotDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false;
......
......@@ -12,6 +12,9 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_feed_factory.h"
......@@ -44,9 +47,9 @@ void Dataset::SetThreadNum(int thread_num) {
void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
- void Dataset::SetDataFeedDesc(
-     const paddle::framework::DataFeedDesc& data_feed_desc) {
-   data_feed_desc_ = data_feed_desc;
+ void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
+   google::protobuf::TextFormat::ParseFromString(
+       data_feed_desc_str, &data_feed_desc_);
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>>
......
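Review note: `Dataset::SetDataFeedDesc` now takes a text-format protobuf string instead of a `DataFeedDesc` message, so the Python side can pass `self.desc()` straight through the binding. A minimal sketch of what the caller is expected to produce (the `paddle.fluid.proto` import path is an assumption):

```python
from google.protobuf import text_format
from paddle.fluid.proto import data_feed_pb2  # import path assumed

# Build a DataFeedDesc and serialize it to the text format that the
# C++ side parses with google::protobuf::TextFormat::ParseFromString.
desc = data_feed_pb2.DataFeedDesc()
desc.name = "MultiSlotInMemoryDataFeed"
desc.batch_size = 32
desc.pipe_command = "cat"
desc_str = text_format.MessageToString(desc)
```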
......@@ -34,8 +34,7 @@ class Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
-   virtual void SetDataFeedDesc(
-       const paddle::framework::DataFeedDesc& data_feed_desc);
+   virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
......
......@@ -22,7 +22,7 @@ namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
-                                   const Dataset& data_set) {
+                                   Dataset* data_set) {
thread_num_ = trainer_desc.thread_num();
workers_.resize(thread_num_);
readers_.resize(thread_num_);
......
......@@ -14,11 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
......@@ -119,7 +117,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
void Executor::RunFromDataset(const ProgramDesc& main_program,
-                               const Dataset& dataset,
+                               Dataset* dataset,
const std::string& trainer_desc_str,
const bool debug) {
VLOG(3) << "Start to RunFromDataset in executor";
......
......@@ -19,6 +19,8 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
+ #include <unordered_map>
+ #include <memory>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
......@@ -112,7 +114,7 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program);
-   void RunFromDataset(const ProgramDesc& main_program, const Dataset& dataset,
+   void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset,
const std::string& trainer_desc_str, const bool debug);
public:
......
......@@ -22,7 +22,7 @@ namespace paddle {
namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
-                                const Dataset& dataset) {
+                                Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
// get filelist from trainer_desc here
workers_.resize(thread_num_);
......
......@@ -42,7 +42,7 @@ class TrainerBase {
void SetScope(Scope* root_scope);
void SetDebug(const bool debug) { debug_ = debug; }
virtual void Initialize(const TrainerDesc& trainer_desc,
-                           const Dataset& data_set) = 0;
+                           Dataset* data_set) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) = 0;
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
......@@ -62,7 +62,7 @@ class MultiTrainer : public TrainerBase {
MultiTrainer() {}
virtual ~MultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc,
-                           const Dataset& data_set);
+                           Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {}
......@@ -81,7 +81,7 @@ class DistMultiTrainer : public MultiTrainer {
DistMultiTrainer() {}
virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc,
-                           const Dataset& data_set);
+                           Dataset* data_set);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Finalize();
......
......@@ -24,6 +24,9 @@ from .executor import *
from . import data_feed_desc
from .data_feed_desc import *
+ from . import dataset
+ from .dataset import *
from . import async_executor
from .async_executor import *
......
......@@ -139,10 +139,6 @@ class DataFeedDesc(object):
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[
name]].is_used = True
-     def global_shuffle(self):
-         self.data.global_shuffle()
-         pass
def desc(self):
"""
Returns a protobuf message for this DataFeedDesc
......
......@@ -23,9 +23,9 @@ class DatasetFactory(object):
pass
def create_dataset(self, datafeed_class):
datafeed_class = datafeed_class.capitalize()
try:
dataset = globals()[datafeed_class]()
return dataset
except:
raise ValueError("datafeed class %s does not exist" %
datafeed_class)
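Review note: `str.capitalize()` lowercases every character after the first, so `"InMemoryDataset".capitalize()` yields `"Inmemorydataset"`, which is not a name in this module's `globals()`. A lookup that keeps the caller's casing would look like this (a sketch, not the committed code):

```python
def create_dataset(self, datafeed_class):
    # Use the class name exactly as passed; capitalize() would mangle
    # CamelCase names such as "InMemoryDataset".
    try:
        return globals()[datafeed_class]()
    except KeyError:
        raise ValueError("datafeed class %s does not exist" % datafeed_class)
```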
......@@ -37,6 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat"
+         self.dataset = core.Dataset()
def set_pipe_command(self, pipe_command):
"""
......@@ -60,17 +61,23 @@ class DatasetBase(object):
"""
self.proto_desc.batch_size = batch_size
+     def set_thread(self, thread_num):
+         self.dataset.set_thread_num(thread_num)
+     def set_filelist(self, filelist):
+         self.dataset.set_filelist(filelist)
def set_use_var(self, var_list):
-         multi_slot = self.proto_desc.multi_slot_desc()
+         multi_slot = self.proto_desc.multi_slot_desc
for var in var_list:
-             slot_var = multi_slot.add()
+             slot_var = multi_slot.slots.add()
slot_var.is_used = True
slot_var.name = var.name
if var.lod_level == 0:
slot_var.is_dense = True
-             if var.dtype == core.VarType.FP32:
+             if var.dtype == core.VarDesc.VarType.FP32:
slot_var.type = "float32"
-             elif var.dtype == core.VarType.INT64:
+             elif var.dtype == core.VarDesc.VarType.INT64:
slot_var.type = "uint64"
else:
raise ValueError(
......@@ -93,17 +100,24 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase):
def __init__(self):
-         super(InMemoryDataset.__init__())
-         self.proto_desc.name = "InMemoryDataFeed"
+         super(InMemoryDataset, self).__init__()
+         self.proto_desc.name = "MultiSlotInMemoryDataFeed"
+     def load_into_memory(self):
+         self.dataset.set_data_feed_desc(self.desc())
+         self.dataset.load_into_memory()
def local_shuffle(self):
-         pass
+         self.dataset.local_shuffle()
def global_shuffle(self):
-         pass
+         from .distributed import ps_instance
+         instance = ps_instance.PaddlePSInstance(1, 2)
+         self.dataset.set_trainer_num(instance.get_worker_num())
+         self.global_shuffle()
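Review note: the last line of `global_shuffle` calls `self.global_shuffle()`, which recurses without terminating; presumably it should delegate to the underlying `core.Dataset`. A corrected sketch, assuming the binding exposes a `global_shuffle` method:

```python
def global_shuffle(self):
    from .distributed import ps_instance
    instance = ps_instance.PaddlePSInstance(1, 2)  # args taken from the diff
    self.dataset.set_trainer_num(instance.get_worker_num())
    # Assumed core.Dataset binding; the committed line calls
    # self.global_shuffle() instead, which never returns.
    self.dataset.global_shuffle()
```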
class QueueDataset(DatasetBase):
def __init__(self):
-         super(QueueDataset.__init__())
+         super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed"
......@@ -121,6 +121,18 @@ class PaddlePSInstance(object):
"""
return self._nodes
+     def get_worker_num(self):
+         """
+         Return worker num
+         """
+         return self._worker_num
+     def get_server_num(self):
+         """
+         Return server num
+         """
+         return self._server_num
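The new getters let `InMemoryDataset.global_shuffle` size `set_trainer_num` from the PS cluster. A small usage sketch (import path assumed):

```python
from paddle.fluid.distributed import ps_instance  # import path assumed

instance = ps_instance.PaddlePSInstance(1, 2)  # constructor args as in the diff
print(instance.get_worker_num(), instance.get_server_num())
```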
def barrier_all(self):
"""
barrier workers and servers
......