提交 cb01a295 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fix tpu_ops.all_to_all op output shape.

PiperOrigin-RevId: 262676072
上级 e6dc56ce
......@@ -40,6 +40,9 @@ REGISTER_OP("AllToAll")
}
int concat_dimension;
int split_dimension;
int split_count;
TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count));
TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
......@@ -58,14 +61,13 @@ REGISTER_OP("AllToAll")
dims.resize(rank);
for (int32 i = 0; i < rank; ++i) {
int64 in_idx = i;
dims[i] = c->Dim(input, i);
if (i == concat_dimension) {
in_idx = split_dimension;
} else if (i == split_dimension) {
in_idx = concat_dimension;
dims[i] = c->MakeDim(c->Value(dims[i]) * split_count);
}
if (i == split_dimension) {
dims[i] = c->MakeDim(c->Value(dims[i]) / split_count);
}
dims[i] = c->Dim(input, in_idx);
}
c->set_output(0, c->MakeShape(dims));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册