diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index b5f7e6c22405d6928f0e423458d6cd720f2d09a8..01fd8b59d13dafc61ddfbbf4ba95a747a378aab8 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -466,6 +466,17 @@ void MultiSlotDataFeed::Init(
     if (slot.is_used()) {
       use_slots_.push_back(all_slots_[i]);
       use_slots_is_dense_.push_back(slot.is_dense());
+      std::vector<int> local_shape;
+      if (slot.is_dense()) {
+        // for batch size holder if is_dense
+        if (slot.shape(0) > 0) {
+          local_shape.push_back(0);
+        }
+      }
+      for (size_t i = 0; i < slot.shape_size(); ++i) {
+        local_shape.push_back(slot.shape(i));
+      }
+      use_slots_shape_.push_back(local_shape);
     }
   }
   feed_vec_.resize(use_slots_.size());
@@ -752,8 +763,8 @@ void MultiSlotDataFeed::PutToFeedVec(
     LoD data_lod{offset};
     feed_vec_[i]->set_lod(data_lod);
     if (use_slots_is_dense_[i]) {
-      int dim = total_instance / batch_size_;
-      feed_vec_[i]->Resize({batch_size_, dim});
+      use_slots_shape_[i][0] = batch_size_;
+      feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
     }
   }
 #endif
@@ -785,6 +796,16 @@ void MultiSlotInMemoryDataFeed::Init(
     if (slot.is_used()) {
       use_slots_.push_back(all_slots_[i]);
       use_slots_is_dense_.push_back(slot.is_dense());
+      std::vector<int> local_shape;
+      if (slot.is_dense()) {
+        if (slot.shape(0) > 0) {
+          local_shape.push_back(0);
+        }
+      }
+      for (size_t i = 0; i < slot.shape_size(); ++i) {
+        local_shape.push_back(slot.shape(i));
+      }
+      use_slots_shape_.push_back(local_shape);
     }
   }
   feed_vec_.resize(use_slots_.size());
@@ -940,8 +961,8 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
     LoD data_lod{offset};
     feed_vec_[i]->set_lod(data_lod);
     if (use_slots_is_dense_[i]) {
-      int dim = total_instance / batch_size_;
-      feed_vec_[i]->Resize({batch_size_, dim});
+      use_slots_shape_[i][0] = batch_size_;
+      feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
     }
   }
 #endif
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index 648c874a0b8763b18118e18adf3b3e93acfd104b..d098c7858a98c644bd3cad78d3cf1e3b35ca026b 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -142,6 +142,7 @@ class DataFeed {
   // object)
   std::vector<std::string> all_slots_;
   std::vector<std::string> all_slots_type_;
+  std::vector<std::vector<int>> use_slots_shape_;
   std::vector<int>
       use_slots_index_;  // -1: not used; >=0: the index of use_slots_
diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto
index 77911306299b77748a2ad9437d49680748885003..03996e0e20a1729ee300a5ad37abc325876930b7 100644
--- a/paddle/fluid/framework/data_feed.proto
+++ b/paddle/fluid/framework/data_feed.proto
@@ -19,6 +19,7 @@ message Slot {
   required string type = 2;
   optional bool is_dense = 3 [ default = false ];
   optional bool is_used = 4 [ default = false ];
+  repeated int32 shape = 5; // we can define N-D Tensor
 }
 
 message MultiSlotDesc { repeated Slot slots = 1; }
diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h
index b4f44e56405a51082e60afd69fb6f011dab44b86..1c03b368868d2aaccb9d14a57f4b859afb7589b8 100644
--- a/paddle/fluid/imperative/nccl_context.h
+++ b/paddle/fluid/imperative/nccl_context.h
@@ -16,6 +16,7 @@
 // network header files
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
 #include <arpa/inet.h>
+#include <netdb.h>
 #include <netinet/in.h>
 #include <stdlib.h>
 #include <sys/socket.h>
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index e655fd4a976a8a6fa2811ddc43de3d1f231229d5..276f28302675ba19ccf50285099df4c4e6590ddf 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -136,6 +136,8 @@ class DatasetBase(object):
             slot_var.name = var.name
             if var.lod_level == 0:
                 slot_var.is_dense = True
+                print(var.shape)
+                slot_var.shape.extend(var.shape)
             if var.dtype == core.VarDesc.VarType.FP32:
                 slot_var.type = "float"
             elif var.dtype == core.VarDesc.VarType.INT64:
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index e15197037e1d901855883919b02a1574b7bc9a29..fa8b49a021294e8555e979459615b1956d9b2b55 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -712,10 +712,6 @@ class Executor(object):
         if dataset == None:
             raise RuntimeError("dataset is needed and should be initialized")
 
-        if not isinstance(self.place, core.CPUPlace):
-            raise RuntimeError("infer_from_dataset is verified on CPUPlace"
-                               "We will open CUDAPlace in the future")
-
         scope, trainer = self._prepare_trainer(
             program=program,
             dataset=dataset,
@@ -796,10 +792,6 @@ class Executor(object):
         if dataset == None:
             raise RuntimeError("dataset is need and should be initialized")
 
-        if not isinstance(self.place, core.CPUPlace):
-            raise RuntimeError("train_from_dataset is verified on CPUPlace"
-                               "We will open CUDAPlace in the future")
-
         scope, trainer = self._prepare_trainer(
             program=program,
             dataset=dataset,
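
For context on how the new shape plumbing is exercised end to end, below is a minimal Python sketch against the Paddle 1.x fluid API; it is not part of the patch, and the variable names, file name, and thread count are illustrative. A dense variable's shape is copied into the slot description by set_use_var (the dataset.py change above), and PutToFeedVec then resizes the feed tensor to that N-D shape with the leading dimension replaced by the batch size, instead of the old flat [batch_size, dim].

import paddle.fluid as fluid

# Dense slot (lod_level=0): set_use_var() below copies var.shape, here
# (-1, 1, 28, 28) because append_batch_size=True, into slot_var.shape, so the
# C++ side resizes the feed tensor to [batch_size, 1, 28, 28] rather than a
# flattened 2-D tensor.
image = fluid.layers.data(
    name="image", shape=[1, 28, 28], dtype="float32", lod_level=0)
# Sparse slot (lod_level > 0): its handling is unchanged by this patch.
label = fluid.layers.data(
    name="label", shape=[1], dtype="int64", lod_level=1)

dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_use_var([image, label])  # copies shapes into the DataFeedDesc
dataset.set_batch_size(32)
dataset.set_thread(1)
dataset.set_filelist(["train_data.txt"])  # illustrative file name

exe = fluid.Executor(fluid.CPUPlace())  # CUDAPlace is no longer rejected here
exe.run(fluid.default_startup_program())
# ... build a network on top of `image`/`label` as usual, then:
exe.train_from_dataset(
    program=fluid.default_main_program(), dataset=dataset)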