From 43c9561e9a383d08534a84d01808ff61d1e261e2 Mon Sep 17 00:00:00 2001 From: guru4elephant <35550832+guru4elephant@users.noreply.github.com> Date: Thu, 16 May 2019 15:36:36 +0800 Subject: [PATCH] add inductive shape index (#17435) add inductive shape index --- paddle/fluid/framework/data_feed.cc | 30 +++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index bdefb1df8ce..4f40786a959 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -477,17 +477,17 @@ void MultiSlotDataFeed::Init( use_slots_is_dense_.push_back(slot.is_dense()); std::vector local_shape; if (slot.is_dense()) { - for (size_t i = 0; i < slot.shape_size(); ++i) { - if (slot.shape(i) > 0) { - total_dims_without_inductive_[i] *= slot.shape(i); + for (size_t j = 0; j < slot.shape_size(); ++j) { + if (slot.shape(j) > 0) { + total_dims_without_inductive_[i] *= slot.shape(j); } - if (slot.shape(i) == -1) { - inductive_shape_index_[i] = i; + if (slot.shape(j) == -1) { + inductive_shape_index_[i] = j; } } } - for (size_t i = 0; i < slot.shape_size(); ++i) { - local_shape.push_back(slot.shape(i)); + for (size_t j = 0; j < slot.shape_size(); ++j) { + local_shape.push_back(slot.shape(j)); } use_slots_shape_.push_back(local_shape); } @@ -811,22 +811,24 @@ void MultiSlotInMemoryDataFeed::Init( all_slots_[i] = slot.name(); all_slots_type_[i] = slot.type(); use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1; + total_dims_without_inductive_[i] = 1; + inductive_shape_index_[i] = -1; if (slot.is_used()) { use_slots_.push_back(all_slots_[i]); use_slots_is_dense_.push_back(slot.is_dense()); std::vector local_shape; if (slot.is_dense()) { - for (size_t i = 0; i < slot.shape_size(); ++i) { - if (slot.shape(i) > 0) { - total_dims_without_inductive_[i] *= slot.shape(i); + for (size_t j = 0; j < slot.shape_size(); ++j) { + if (slot.shape(j) > 0) { + total_dims_without_inductive_[i] *= slot.shape(j); } - if (slot.shape(i) == -1) { - inductive_shape_index_[i] = i; + if (slot.shape(j) == -1) { + inductive_shape_index_[i] = j; } } } - for (size_t i = 0; i < slot.shape_size(); ++i) { - local_shape.push_back(slot.shape(i)); + for (size_t j = 0; j < slot.shape_size(); ++j) { + local_shape.push_back(slot.shape(j)); } use_slots_shape_.push_back(local_shape); } -- GitLab