提交 ec23155a 编写于 作者: B Blake Hechtman 提交者: TensorFlower Gardener

[XLA] Grouped dims do not need to be modified when being swapped as they will

correspond one-to-one on both operands.

PiperOrigin-RevId: 339796841
Change-Id: Ia71bb1bf74cb728a036b393c6ed16b2721137c7b
上级 550434b7
......@@ -5220,13 +5220,31 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
for (int64 spatial_dim = 0;
spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
const int64 kernel_size = window_dims[spatial_dim].size();
const int64 dilated_kernel_size =
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
const bool can_be_group_or_contraction =
!window_dims[spatial_dim].window_reversal() &&
window_dims[spatial_dim].padding_low() == 0 &&
window_dims[spatial_dim].padding_high() == 0 &&
window_dims[spatial_dim].window_dilation() == 1;
const bool is_group_dim =
can_be_group_or_contraction &&
window_dims[spatial_dim].base_dilation() == kernel_size &&
window_dims[spatial_dim].stride() == kernel_size - 1;
const int64 input_size =
input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
const bool is_pure_contraction_dim =
kernel_size == input_size && can_be_group_or_contraction &&
window_dims[spatial_dim].base_dilation() == 1 &&
window_dims[spatial_dim].stride() == 1;
if (is_group_dim || is_pure_contraction_dim) {
*(swapped_window.add_dimensions()) = window_dims[spatial_dim];
continue;
}
const int64 dilated_kernel_size =
1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
const int64 dilated_input_size =
1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();
// Don't decide to swap if the input size is one, since many convolution
// implementations can easily hand that special case efficiently.
kernel_product *= kernel_size;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册