提交 5a130818 编写于 作者: B Benoit Steiner 提交者: TensorFlower Gardener

Leverage index list to further speed up the computation of the convolution

gradients by 1 to 10% depending on the size of the convolution kernel.
Change: 118505660
上级 89aab41d
......@@ -44,6 +44,16 @@ namespace Eigen {
* that the same order is used in the input, the kernel, and the output.
*
*/
// Types describing, for each of the 4 kernel dimensions, the "reverse" argument
// handed to TensorReverseOp in SpatialConvolutionBackwardInput. There is one
// type per storage layout: ReverseColMajor and ReverseRowMajor.
#ifdef EIGEN_HAS_INDEX_LIST
// With index-list support, the per-dimension values are carried by the type
// itself as compile-time constants, so no run-time initialization is needed.
// NOTE(review): presumably type2index<0> means "do not reverse" and
// type2index<1> means "reverse", matching the `false` entries assigned in the
// non-index-list path -- confirm against the full function body.
typedef IndexList<type2index<0>, type2index<0>, type2index<1>, type2index<1> >
ReverseColMajor;
typedef IndexList<type2index<1>, type2index<1>, type2index<0>, type2index<0> >
ReverseRowMajor;
#else
// Fallback without index lists: both layouts share the same array type and the
// individual flags have to be filled in at run time by the user of the type.
typedef array<bool, 4> ReverseColMajor;
typedef array<bool, 4> ReverseRowMajor;
#endif
template <typename OutputBackward, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::Layout == ColMajor,
......@@ -59,7 +69,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const TensorShufflingOp<
const array<
typename internal::traits<OutputBackward>::Index, 4>,
const TensorReverseOp<const array<bool, 4>,
const TensorReverseOp<const ReverseColMajor,
const Kernel> > > >,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
......@@ -83,7 +93,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const TensorShufflingOp<
const array<
typename internal::traits<OutputBackward>::Index, 4>,
const TensorReverseOp<const array<bool, 4>,
const TensorReverseOp<const ReverseRowMajor,
const Kernel> > > > > > >::type
SpatialConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward,
......@@ -161,7 +171,11 @@ SpatialConvolutionBackwardInput(
// TODO(yangke): we can make things slightly faster by collapsing the
// dimensions
// where we don't reverse. Try that once we have a faster compiler.
array<bool, 4> kernel_reverse;
typedef typename internal::conditional<isColMajor, ReverseColMajor,
ReverseRowMajor>::type Reverse;
Reverse kernel_reverse;
#ifndef EIGEN_HAS_INDEX_LIST
if (isColMajor) {
kernel_reverse[0] = false;
kernel_reverse[1] = false;
......@@ -173,6 +187,7 @@ SpatialConvolutionBackwardInput(
kernel_reverse[2] = false;
kernel_reverse[3] = false;
}
#endif
// Reorder the dimensions to filters X patch_rows X patch_cols X channels
array<TensorIndex, 4> kernel_shuffle;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册