未验证 提交 43c9561e 编写于 作者: G guru4elephant 提交者: GitHub

add inductive shape index (#17435)

add inductive shape index
上级 712bfb17
...@@ -477,17 +477,17 @@ void MultiSlotDataFeed::Init( ...@@ -477,17 +477,17 @@ void MultiSlotDataFeed::Init(
use_slots_is_dense_.push_back(slot.is_dense()); use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape; std::vector<int> local_shape;
if (slot.is_dense()) { if (slot.is_dense()) {
for (size_t i = 0; i < slot.shape_size(); ++i) { for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(i) > 0) { if (slot.shape(j) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i); total_dims_without_inductive_[i] *= slot.shape(j);
} }
if (slot.shape(i) == -1) { if (slot.shape(j) == -1) {
inductive_shape_index_[i] = i; inductive_shape_index_[i] = j;
} }
} }
} }
for (size_t i = 0; i < slot.shape_size(); ++i) { for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(i)); local_shape.push_back(slot.shape(j));
} }
use_slots_shape_.push_back(local_shape); use_slots_shape_.push_back(local_shape);
} }
...@@ -811,22 +811,24 @@ void MultiSlotInMemoryDataFeed::Init( ...@@ -811,22 +811,24 @@ void MultiSlotInMemoryDataFeed::Init(
all_slots_[i] = slot.name(); all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type(); all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1; use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
total_dims_without_inductive_[i] = 1;
inductive_shape_index_[i] = -1;
if (slot.is_used()) { if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]); use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense()); use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape; std::vector<int> local_shape;
if (slot.is_dense()) { if (slot.is_dense()) {
for (size_t i = 0; i < slot.shape_size(); ++i) { for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(i) > 0) { if (slot.shape(j) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i); total_dims_without_inductive_[i] *= slot.shape(j);
} }
if (slot.shape(i) == -1) { if (slot.shape(j) == -1) {
inductive_shape_index_[i] = i; inductive_shape_index_[i] = j;
} }
} }
} }
for (size_t i = 0; i < slot.shape_size(); ++i) { for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(i)); local_shape.push_back(slot.shape(j));
} }
use_slots_shape_.push_back(local_shape); use_slots_shape_.push_back(local_shape);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册