提交 da767e10 · 作者: Benoit Steiner · 提交者: TensorFlower Gardener

Sped up the computation of the gradients of the convolution layer with respect to the inputs by 30% on GPU.
Change: 118221540
上级 90333773
......@@ -38,13 +38,44 @@ namespace Eigen {
*/
template <typename OutputBackward, typename Kernel>
EIGEN_ALWAYS_INLINE
static const typename internal::conditional<
internal::traits<OutputBackward>::Layout == ColMajor,
TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, internal::traits<OutputBackward>::NumDimensions>, const TensorContractionOp<const array<IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 3>, const TensorReverseOp<const array<bool, 4>, const Kernel> >, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 3>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > > >,
TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, internal::traits<OutputBackward>::NumDimensions>, const TensorContractionOp<const array<IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 3>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> >, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 3>, const TensorReverseOp<const array<bool, 4>, const Kernel> > > > >::type
SpatialConvolutionBackwardInput(const Kernel& kernel, const OutputBackward& output_backward, typename internal::traits<OutputBackward>::Index inputRows, typename internal::traits<OutputBackward>::Index inputCols, const DenseIndex stride = 1, const DenseIndex in_stride = 1) {
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::Layout == ColMajor,
TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp<
const array<
IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
3>,
const TensorReverseOp<const array<bool, 4>, const Kernel> > >,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
3>,
const TensorImagePatchOp<Dynamic, Dynamic,
const OutputBackward> > > >,
TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp<
const array<
IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
3>,
const TensorImagePatchOp<Dynamic, Dynamic,
const OutputBackward> >,
const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
3>,
const TensorReverseOp<const array<bool, 4>,
const Kernel> > > > > >::type
SpatialConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward,
typename internal::traits<OutputBackward>::Index inputRows,
typename internal::traits<OutputBackward>::Index inputCols,
const DenseIndex stride = 1, const DenseIndex in_stride = 1) {
typedef typename internal::traits<OutputBackward>::Index TensorIndex;
typedef typename internal::traits<OutputBackward>::Scalar OutScalar;
TensorRef<Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions, internal::traits<Kernel>::Layout, TensorIndex> > kern(kernel);
......@@ -167,9 +198,28 @@ SpatialConvolutionBackwardInput(const Kernel& kernel, const OutputBackward& outp
}
}
return choose(Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
kernel.reverse(kernel_reverse).reshape(kernel_dims).contract(output_backward.extract_image_patches(kernelRows, kernelCols, 1, 1, in_stride, in_stride, stride, stride, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)).reshape(pre_contract_dims), contract_dims).reshape(post_contract_dims),
output_backward.extract_image_patches(kernelRows, kernelCols, 1, 1, in_stride, in_stride, stride, stride, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)).reshape(pre_contract_dims).contract(kernel.reverse(kernel_reverse).reshape(kernel_dims), contract_dims).reshape(post_contract_dims));
return choose(
Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
kernel.reverse(kernel_reverse)
.reshape(kernel_dims)
.eval()
.contract(output_backward
.extract_image_patches(
kernelRows, kernelCols, 1, 1, in_stride, in_stride,
stride, stride, padding_top, padding_bottom,
padding_left, padding_right, OutScalar(0))
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
output_backward
.extract_image_patches(kernelRows, kernelCols, 1, 1, in_stride,
in_stride, stride, stride, padding_top,
padding_bottom, padding_left, padding_right,
OutScalar(0))
.reshape(pre_contract_dims)
.contract(kernel.reverse(kernel_reverse).reshape(kernel_dims).eval(),
contract_dims)
.reshape(post_contract_dims));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册