From ec23155affb294af2713ebb51d7fff2b7c7fe27a Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Thu, 29 Oct 2020 19:55:02 -0700 Subject: [PATCH] [XLA] Grouped dims do not need to be modified when being swapped as they will correspond one-to-one on both operands. PiperOrigin-RevId: 339796841 Change-Id: Ia71bb1bf74cb728a036b393c6ed16b2721137c7b --- .../xla/service/algebraic_simplifier.cc | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 27c308319f7..046701c564f 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -5220,13 +5220,31 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( for (int64 spatial_dim = 0; spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) { const int64 kernel_size = window_dims[spatial_dim].size(); - const int64 dilated_kernel_size = - 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation(); - + const bool can_be_group_or_contraction = + !window_dims[spatial_dim].window_reversal() && + window_dims[spatial_dim].padding_low() == 0 && + window_dims[spatial_dim].padding_high() == 0 && + window_dims[spatial_dim].window_dilation() == 1; + const bool is_group_dim = + can_be_group_or_contraction && + window_dims[spatial_dim].base_dilation() == kernel_size && + window_dims[spatial_dim].stride() == kernel_size - 1; const int64 input_size = input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim)); + const bool is_pure_contraction_dim = + kernel_size == input_size && can_be_group_or_contraction && + window_dims[spatial_dim].base_dilation() == 1 && + window_dims[spatial_dim].stride() == 1; + if (is_group_dim || is_pure_contraction_dim) { + *(swapped_window.add_dimensions()) = window_dims[spatial_dim]; + continue; + } + + const int64 dilated_kernel_size = + 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation(); const int64 dilated_input_size = 1 + (input_size - 1) * window_dims[spatial_dim].base_dilation(); + // Don't decide to swap if the input size is one, since many convolution // implementations can easily hand that special case efficiently. kernel_product *= kernel_size; -- GitLab