未验证 提交 43c9561e 编写于 作者: G guru4elephant 提交者: GitHub

add inductive shape index (#17435)

add inductive shape index
上级 712bfb17
......@@ -477,17 +477,17 @@ void MultiSlotDataFeed::Init(
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i);
for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) > 0) {
total_dims_without_inductive_[i] *= slot.shape(j);
}
if (slot.shape(i) == -1) {
inductive_shape_index_[i] = i;
if (slot.shape(j) == -1) {
inductive_shape_index_[i] = j;
}
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
local_shape.push_back(slot.shape(i));
for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(j));
}
use_slots_shape_.push_back(local_shape);
}
......@@ -811,22 +811,24 @@ void MultiSlotInMemoryDataFeed::Init(
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
total_dims_without_inductive_[i] = 1;
inductive_shape_index_[i] = -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
for (size_t i = 0; i < slot.shape_size(); ++i) {
if (slot.shape(i) > 0) {
total_dims_without_inductive_[i] *= slot.shape(i);
for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) > 0) {
total_dims_without_inductive_[i] *= slot.shape(j);
}
if (slot.shape(i) == -1) {
inductive_shape_index_[i] = i;
if (slot.shape(j) == -1) {
inductive_shape_index_[i] = j;
}
}
}
for (size_t i = 0; i < slot.shape_size(); ++i) {
local_shape.push_back(slot.shape(i));
for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(j));
}
use_slots_shape_.push_back(local_shape);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册