diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5a04c11636e84fad60cb68f849b85c2a13e46d80..e736d8cd6ff584361902450736a65742b30d7f7d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1223,6 +1223,9 @@ tf_kernel_library( "conv_grad_ops.cc", "conv_grad_ops_3d.cc", ], + hdrs = [ + "conv_grad_ops.h", + ], prefix = "conv_ops", deps = [ ":bounds_check", @@ -1800,6 +1803,7 @@ filegroup( "control_flow_ops.cc", "conv_2d.h", "conv_grad_ops.cc", + "conv_grad_ops.h", "conv_ops.cc", "cwise_op_add.cc", "cwise_op_div.cc", diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index d68112245d38cf1d8cc10c1e21cd93a9dc12f15b..508ffc04029b0de7cd1e4dd969049cffc83b0f49 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -18,8 +18,11 @@ limitations under the License. #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS +#include "tensorflow/core/kernels/conv_grad_ops.h" + #include #include + #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -263,82 +266,81 @@ typedef Eigen::GpuDevice GPUDevice; // The case for SAME padding is in fact very similar to VALID -- we just // need to pad the input tensor a bit when computing the filter_backprop. -// Common code between the two kernels: verifies that the dimensions all match -// and extract the padded rows and columns. -#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \ - const Tensor& out_backprop = context->input(2); \ - OP_REQUIRES( \ - context, input_shape.dims() == 4, \ - errors::InvalidArgument(label, ": input must be 4-dimensional")); \ - OP_REQUIRES( \ - context, filter_shape.dims() == 4, \ - errors::InvalidArgument(label, ": filter must be 4-dimensional")); \ - OP_REQUIRES( \ - context, out_backprop.dims() == 4, \ - errors::InvalidArgument(label, ": out_backprop must be 4-dimensional")); \ - const int64 batch = GetTensorDim(input_shape, data_format_, 'N'); \ - OP_REQUIRES( \ - context, batch == GetTensorDim(out_backprop, data_format_, 'N'), \ - errors::InvalidArgument( \ - label, ": input and out_backprop must have the same batch size")); \ - const int64 input_rows = GetTensorDim(input_shape, data_format_, 'H'); \ - const int64 input_cols = GetTensorDim(input_shape, data_format_, 'W'); \ - const int64 filter_rows = filter_shape.dim_size(0); \ - const int64 filter_cols = filter_shape.dim_size(1); \ - const int64 output_rows = GetTensorDim(out_backprop, data_format_, 'H'); \ - const int64 output_cols = GetTensorDim(out_backprop, data_format_, 'W'); \ - const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \ - OP_REQUIRES(context, in_depth == filter_shape.dim_size(2), \ - errors::InvalidArgument( \ - label, ": input and filter must have the same depth")); \ - const int64 out_depth = filter_shape.dim_size(3); \ - OP_REQUIRES( \ - context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \ - errors::InvalidArgument( \ - label, ": filter and out_backprop must have the same out_depth")); \ - const auto stride_rows = GetTensorDim(strides_, data_format_, 'H'); \ - const auto stride_cols = GetTensorDim(strides_, data_format_, 'W'); \ - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; \ - OP_REQUIRES_OK(context, \ - GetWindowedOutputSize(input_rows, filter_rows, stride_rows, \ - padding_, &out_rows, &pad_rows)); \ - OP_REQUIRES_OK(context, \ - GetWindowedOutputSize(input_cols, filter_cols, stride_cols, \ - padding_, &out_cols, &pad_cols)); \ - OP_REQUIRES( \ - context, output_rows == out_rows, \ - errors::InvalidArgument( \ - label, ": Number of rows of out_backprop doesn't match computed: ", \ - "actual = ", output_rows, ", computed = ", out_rows)); \ - OP_REQUIRES( \ - context, output_cols == out_cols, \ - errors::InvalidArgument( \ - label, ": Number of cols of out_backprop doesn't match computed: ", \ - "actual = ", output_cols, ", computed = ", out_cols)); \ - const auto expanded_out_rows = (output_rows - 1) * stride_rows + 1; \ - const auto expanded_out_cols = (output_cols - 1) * stride_cols + 1; \ - const auto padded_out_rows = input_rows + filter_rows - 1; \ - const auto padded_out_cols = input_cols + filter_cols - 1; \ - const int top_pad_rows = filter_rows - 1 - pad_rows; \ - const int left_pad_cols = filter_cols - 1 - pad_cols; \ - const int bottom_pad_rows = \ - padded_out_rows - expanded_out_rows - top_pad_rows; \ - const int right_pad_cols = \ - padded_out_cols - expanded_out_cols - left_pad_cols; \ - Eigen::DSizes strides{1, stride_rows, stride_cols, 1}; \ - VLOG(2) << "Conv2d: " << label \ - << ": expanded_out_rows = " << expanded_out_rows \ - << ", expanded_out_cols = " << expanded_out_cols \ - << ", filter_rows = " << filter_rows \ - << ", filter_cols = " << filter_cols \ - << ", padded_out_rows = " << padded_out_rows \ - << ", padded_out_cols = " << padded_out_cols \ - << ", top_pad_rows = " << top_pad_rows \ - << ", left_pad_cols = " << left_pad_cols \ - << ", bottom_pad_rows = " << bottom_pad_rows \ - << ", right_pad_cols = " << right_pad_cols \ - << ", strides_rows = " << strides[1] \ - << ", strides_cols = " << strides[2] +static Status ConvBackpropExtractAndVerifyDimension( + StringPiece label, const TensorShape& input_shape, + const TensorShape& filter_shape, const TensorShape& output_shape, + const std::vector& strides, Padding padding, int spatial_dim, + int filter_spatial_dim, ConvBackpropSpatialDimension* dim) { + dim->input_size = input_shape.dim_size(spatial_dim); + dim->filter_size = filter_shape.dim_size(filter_spatial_dim); + dim->output_size = output_shape.dim_size(spatial_dim); + dim->stride = strides[spatial_dim]; + int64 out_size = 0, pad_size = 0; + TF_RETURN_IF_ERROR(GetWindowedOutputSize(dim->input_size, dim->filter_size, + dim->stride, padding, &out_size, + &pad_size)); + if (dim->output_size != out_size) { + return errors::InvalidArgument( + label, ": Size of out_backprop doesn't match computed: ", "actual = ", + dim->output_size, ", computed = ", out_size); + } + + dim->expanded_output_size = (dim->output_size - 1) * dim->stride + 1; + const auto padded_out_size = dim->input_size + dim->filter_size - 1; + dim->pad_before = dim->filter_size - 1 - pad_size; + dim->pad_after = + padded_out_size - dim->expanded_output_size - dim->pad_before; + VLOG(2) << label << ": expanded_out = " << dim->expanded_output_size + << ", filter = " << dim->filter_size + << ", padded_out = " << padded_out_size + << ", pad_before = " << dim->pad_before + << ", pad_after = " << dim->pad_after + << ", strides = " << dim->stride; + return Status::OK(); +} + +Status Conv2DBackpropComputeDimensions( + StringPiece label, const TensorShape& input_shape, + const TensorShape& filter_shape, const TensorShape& out_backprop_shape, + const std::vector& strides, Padding padding, + TensorFormat data_format, Conv2DBackpropDimensions* dims) { + if (input_shape.dims() != 4) { + return errors::InvalidArgument(label, ": input must be 4-dimensional"); + } + if (filter_shape.dims() != 4) { + return errors::InvalidArgument(label, ": filter must be 4-dimensional"); + } + if (out_backprop_shape.dims() != 4) { + errors::InvalidArgument(label, ": out_backprop must be 4-dimensional"); + } + dims->batch_size = GetTensorDim(input_shape, data_format, 'N'); + if (dims->batch_size != GetTensorDim(out_backprop_shape, data_format, 'N')) { + return errors::InvalidArgument( + label, ": input and out_backprop must have the same batch size"); + } + + dims->in_depth = GetTensorDim(input_shape, data_format, 'C'); + if (dims->in_depth != filter_shape.dim_size(2)) { + return errors::InvalidArgument( + label, ": input and filter must have the same depth"); + } + dims->out_depth = filter_shape.dim_size(3); + if (dims->out_depth != GetTensorDim(out_backprop_shape, data_format, 'C')) { + return errors::InvalidArgument( + label, ": filter and out_backprop must have the same out_depth"); + } + + const int row_dim = GetTensorDimIndex(data_format, 'H'); + const int col_dim = GetTensorDimIndex(data_format, 'W'); + const int filter_row_dim = 0, filter_col_dim = 1; + TF_RETURN_IF_ERROR(ConvBackpropExtractAndVerifyDimension( + label, input_shape, filter_shape, out_backprop_shape, strides, padding, + row_dim, filter_row_dim, &dims->rows)); + TF_RETURN_IF_ERROR(ConvBackpropExtractAndVerifyDimension( + label, input_shape, filter_shape, out_backprop_shape, strides, padding, + col_dim, filter_col_dim, &dims->cols)); + return Status::OK(); +} // The fast versions using eigen computations directly. They are only enabled // for CPU for now since nvcc times out when trying to compile them. @@ -370,6 +372,7 @@ class Conv2DFastBackpropInputOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input_sizes = context->input(0); const Tensor& filter = context->input(1); + const Tensor& out_backprop = context->input(2); OP_REQUIRES( context, TensorShapeUtils::IsVector(input_sizes.shape()), errors::InvalidArgument( @@ -378,16 +381,21 @@ class Conv2DFastBackpropInputOp : public OpKernel { TensorShape input_shape; OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( input_sizes.vec(), &input_shape)); - const TensorShape& filter_shape = filter.shape(); - EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput"); + Conv2DBackpropDimensions dims; + OP_REQUIRES_OK(context, Conv2DBackpropComputeDimensions( + "Conv2DFastBackpropInput", input_shape, + filter.shape(), out_backprop.shape(), strides_, + padding_, data_format_, &dims)); + Tensor* in_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); functor::SpatialConvolutionBackwardInput()( context->eigen_device(), in_backprop->tensor(), - filter.tensor(), out_backprop.tensor(), input_rows, - input_cols, stride_rows, stride_cols); + filter.tensor(), out_backprop.tensor(), + dims.rows.input_size, dims.cols.input_size, dims.rows.stride, + dims.cols.stride); } private: @@ -425,6 +433,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input_sizes = context->input(0); const Tensor& filter = context->input(1); + const Tensor& out_backprop = context->input(2); OP_REQUIRES( context, TensorShapeUtils::IsVector(input_sizes.shape()), errors::InvalidArgument( @@ -433,30 +442,35 @@ class Conv2DCustomBackpropInputOp : public OpKernel { TensorShape input_shape; OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( input_sizes.vec(), &input_shape)); - const TensorShape& filter_shape = filter.shape(); - EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput"); + Conv2DBackpropDimensions dims; + OP_REQUIRES_OK(context, Conv2DBackpropComputeDimensions( + "Conv2DCustomBackpropInput", input_shape, + filter.shape(), out_backprop.shape(), strides_, + padding_, data_format_, &dims)); + Tensor* in_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); // TODO(andydavis) Consider moving code shared with // Conv2DCustomBackpropFilterOp into a shared helper function. - int64 pad_top; - int64 pad_bottom; - int64 pad_left; - int64 pad_right; + int64 pad_top, pad_bottom; + int64 pad_left, pad_right; OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - input_rows, filter_rows, stride_rows, padding_, - &out_rows, &pad_top, &pad_bottom)); + dims.rows.input_size, dims.rows.filter_size, + dims.rows.stride, padding_, + &dims.rows.output_size, &pad_top, &pad_bottom)); OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - input_cols, filter_cols, stride_cols, padding_, - &out_cols, &pad_left, &pad_right)); + dims.cols.input_size, dims.cols.filter_size, + dims.cols.stride, padding_, + &dims.cols.output_size, &pad_left, &pad_right)); // The total dimension size of each kernel. - const int filter_total_size = filter_rows * filter_cols * in_depth; + const int filter_total_size = + dims.rows.filter_size * dims.cols.filter_size * dims.in_depth; // The output image size is the spatial size of the output. - const int output_image_size = out_rows * out_cols; + const int output_image_size = dims.rows.output_size * dims.cols.output_size; // TODO(andydavis) Get L2/L3 cache sizes from device. const size_t l2_cache_size = 256LL << 10; @@ -466,9 +480,9 @@ class Conv2DCustomBackpropInputOp : public OpKernel { const size_t target_working_set_size = l3_cache_size / sizeof(T); // Calculate size of matrices involved in MatMul: C = A x B. - const size_t size_A = output_image_size * out_depth; + const size_t size_A = output_image_size * dims.out_depth; - const size_t size_B = filter_total_size * out_depth; + const size_t size_B = filter_total_size * dims.out_depth; const size_t size_C = output_image_size * filter_total_size; @@ -490,7 +504,8 @@ class Conv2DCustomBackpropInputOp : public OpKernel { // TODO(andydavis) Explore alternatives to branching the code in this way // (i.e. run multiple, parallel tensor contractions in another thread pool). const bool use_parallel_contraction = - batch == 1 || thread_work_unit_size >= min_thread_work_unit_size; + dims.batch_size == 1 || + thread_work_unit_size >= min_thread_work_unit_size; const size_t shard_size = use_parallel_contraction @@ -507,9 +522,11 @@ class Conv2DCustomBackpropInputOp : public OpKernel { &col_buffer)); // The input offset corresponding to a single input image. - const int input_offset = input_rows * input_cols * in_depth; + const int input_offset = + dims.rows.input_size * dims.cols.input_size * dims.in_depth; // The output offset corresponding to a single output image. - const int output_offset = out_rows * out_cols * out_depth; + const int output_offset = + dims.rows.output_size * dims.cols.output_size * dims.out_depth; const T* filter_data = filter.template flat().data(); T* col_buffer_data = col_buffer.template flat().data(); @@ -533,19 +550,21 @@ class Conv2DCustomBackpropInputOp : public OpKernel { contract_dims[0].first = 1; contract_dims[0].second = 1; - for (int image_id = 0; image_id < batch; ++image_id) { + for (int image_id = 0; image_id < dims.batch_size; ++image_id) { // Compute gradient into col_buffer. TensorMap C(col_buffer_data, output_image_size, filter_total_size); ConstTensorMap A(out_backprop_data + output_offset * image_id, - output_image_size, out_depth); - ConstTensorMap B(filter_data, filter_total_size, out_depth); + output_image_size, dims.out_depth); + ConstTensorMap B(filter_data, filter_total_size, dims.out_depth); C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); - Col2im(col_buffer_data, in_depth, input_rows, input_cols, - filter_rows, filter_cols, pad_top, pad_left, pad_bottom, - pad_right, stride_rows, stride_cols, input_backprop_data); + Col2im(col_buffer_data, dims.in_depth, dims.rows.input_size, + dims.cols.input_size, dims.rows.filter_size, + dims.cols.filter_size, pad_top, pad_left, pad_bottom, + pad_right, dims.rows.stride, dims.cols.stride, + input_backprop_data); input_backprop_data += input_offset; } @@ -557,14 +576,14 @@ class Conv2DCustomBackpropInputOp : public OpKernel { Eigen::RowMajor>> ConstMatrixMap; - for (int image_id = 0; image_id < batch; image_id += shard_size) { - const int shard_limit = std::min(static_cast(shard_size), - static_cast(batch) - image_id); + for (int image_id = 0; image_id < dims.batch_size; + image_id += shard_size) { + const int shard_limit = + std::min(static_cast(shard_size), + static_cast(dims.batch_size) - image_id); - auto shard = [&in_depth, &input_rows, &input_cols, &filter_rows, - &filter_cols, &pad_top, &pad_left, &pad_bottom, - &pad_right, &stride_rows, &stride_cols, - &output_image_size, &filter_total_size, &out_depth, + auto shard = [&dims, &pad_top, &pad_left, &pad_bottom, &pad_right, + &output_image_size, &filter_total_size, &input_backprop_data, &col_buffer_data, &out_backprop_data, &filter_data, &input_offset, &output_offset, &size_C](int64 start, int64 limit) { @@ -576,14 +595,16 @@ class Conv2DCustomBackpropInputOp : public OpKernel { // Compute gradient into 'im2col_buf'. MatrixMap C(im2col_buf, output_image_size, filter_total_size); - ConstMatrixMap A(out_data, output_image_size, out_depth); - ConstMatrixMap B(filter_data, filter_total_size, out_depth); + ConstMatrixMap A(out_data, output_image_size, dims.out_depth); + ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth); C.noalias() = A * B.transpose(); - Col2im(im2col_buf, in_depth, input_rows, input_cols, filter_rows, - filter_cols, pad_top, pad_left, pad_bottom, pad_right, - stride_rows, stride_cols, input_data); + Col2im(im2col_buf, dims.in_depth, dims.rows.input_size, + dims.cols.input_size, dims.rows.filter_size, + dims.cols.filter_size, pad_top, pad_left, pad_bottom, + pad_right, dims.rows.stride, dims.cols.stride, + input_data); } }; Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, @@ -648,25 +669,31 @@ class Conv2DFastBackpropFilterOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const Tensor& filter_sizes = context->input(1); + const Tensor& out_backprop = context->input(2); OP_REQUIRES( context, TensorShapeUtils::IsVector(filter_sizes.shape()), errors::InvalidArgument( "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ", filter_sizes.dims())); - const TensorShape& input_shape = input.shape(); TensorShape filter_shape; OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( filter_sizes.vec(), &filter_shape)); - EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropFilter"); + Conv2DBackpropDimensions dims; + OP_REQUIRES_OK(context, Conv2DBackpropComputeDimensions( + "Conv2DFastBackpropFilter", input.shape(), + filter_shape, out_backprop.shape(), strides_, + padding_, data_format_, &dims)); + Tensor* filter_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); functor::SpatialConvolutionBackwardKernel()( context->eigen_device(), filter_backprop->tensor(), - input.tensor(), out_backprop.tensor(), filter_rows, - filter_cols, stride_rows, stride_cols); + input.tensor(), out_backprop.tensor(), + dims.rows.filter_size, dims.cols.filter_size, dims.rows.stride, + dims.cols.stride); } private: @@ -704,37 +731,43 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const Tensor& filter_sizes = context->input(1); + const Tensor& out_backprop = context->input(2); OP_REQUIRES( context, TensorShapeUtils::IsVector(filter_sizes.shape()), errors::InvalidArgument( "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, " "not ", filter_sizes.dims())); - const TensorShape& input_shape = input.shape(); TensorShape filter_shape; OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( filter_sizes.vec(), &filter_shape)); - EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DCustomBackpropFilter"); + Conv2DBackpropDimensions dims; + OP_REQUIRES_OK(context, Conv2DBackpropComputeDimensions( + "Conv2DCustomBackpropFilter", input.shape(), + filter_shape, out_backprop.shape(), strides_, + padding_, data_format_, &dims)); + Tensor* filter_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); - int64 pad_top; - int64 pad_bottom; - int64 pad_left; - int64 pad_right; + int64 pad_top, pad_bottom; + int64 pad_left, pad_right; OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - input_rows, filter_rows, stride_rows, padding_, - &out_rows, &pad_top, &pad_bottom)); + dims.rows.input_size, dims.rows.filter_size, + dims.rows.stride, padding_, + &dims.rows.output_size, &pad_top, &pad_bottom)); OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( - input_cols, filter_cols, stride_cols, padding_, - &out_cols, &pad_left, &pad_right)); + dims.cols.input_size, dims.cols.filter_size, + dims.cols.stride, padding_, + &dims.cols.output_size, &pad_left, &pad_right)); // The total dimension size of each kernel. - const int filter_total_size = filter_rows * filter_cols * in_depth; + const int filter_total_size = + dims.rows.filter_size * dims.cols.filter_size * dims.in_depth; // The output image size is the spatial size of the output. - const int output_image_size = out_rows * out_cols; + const int output_image_size = dims.rows.output_size * dims.cols.output_size; // Shard 'batch' images into 'shard_size' groups of images to be fed // into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache @@ -749,9 +782,9 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { const size_t size_A = output_image_size * filter_total_size; - const size_t size_B = output_image_size * out_depth; + const size_t size_B = output_image_size * dims.out_depth; - const size_t size_C = filter_total_size * out_depth; + const size_t size_C = filter_total_size * dims.out_depth; const size_t work_unit_size = size_A + size_B + size_C; @@ -768,9 +801,11 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { &col_buffer)); // The input offset corresponding to a single input image. - const int input_offset = input_rows * input_cols * in_depth; + const int input_offset = + dims.rows.input_size * dims.cols.input_size * dims.in_depth; // The output offset corresponding to a single output image. - const int output_offset = out_rows * out_cols * out_depth; + const int output_offset = + dims.rows.output_size * dims.cols.output_size * dims.out_depth; const T* input_data = input.template flat().data(); T* col_buffer_data = col_buffer.template flat().data(); @@ -784,7 +819,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { Eigen::Unaligned> ConstTensorMap; - TensorMap C(filter_backprop_data, filter_total_size, out_depth); + TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth); C.setZero(); // Initialize contraction dims (we need to transpose 'A' below). @@ -794,14 +829,13 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); - for (int image_id = 0; image_id < batch; image_id += shard_size) { - const int shard_limit = std::min(static_cast(shard_size), - static_cast(batch) - image_id); + for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) { + const int shard_limit = + std::min(static_cast(shard_size), + static_cast(dims.batch_size) - image_id); - auto shard = [&input_data, &col_buffer_data, &in_depth, &input_rows, - &input_cols, &filter_rows, &filter_cols, &pad_top, - &pad_left, &pad_bottom, &pad_right, &stride_rows, - &stride_cols, &input_offset, + auto shard = [&input_data, &col_buffer_data, &dims, &pad_top, &pad_left, + &pad_bottom, &pad_right, &input_offset, &size_A](int64 start, int64 limit) { for (int shard_id = start; shard_id < limit; ++shard_id) { const T* input_data_shard = input_data + shard_id * input_offset; @@ -809,9 +843,11 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { // When we compute the gradient with respect to the filters, we need // to do im2col to allow gemm-type computation. - Im2col(input_data_shard, in_depth, input_rows, input_cols, - filter_rows, filter_cols, pad_top, pad_left, pad_bottom, - pad_right, stride_rows, stride_cols, col_data_shard); + Im2col(input_data_shard, dims.in_depth, dims.rows.input_size, + dims.cols.input_size, dims.rows.filter_size, + dims.cols.filter_size, pad_top, pad_left, pad_bottom, + pad_right, dims.rows.stride, dims.cols.stride, + col_data_shard); } }; Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, @@ -820,7 +856,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { ConstTensorMap A(col_buffer_data, output_image_size * shard_limit, filter_total_size); ConstTensorMap B(out_backprop_data, output_image_size * shard_limit, - out_depth); + dims.out_depth); // Gradient with respect to filter. C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims); @@ -894,6 +930,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel { using perftools::gputools::dnn::kDefaultAlgorithm; const Tensor& input_sizes = context->input(0); const Tensor& filter = context->input(1); + const Tensor& out_backprop = context->input(2); OP_REQUIRES( context, TensorShapeUtils::IsVector(input_sizes.shape()), errors::InvalidArgument( @@ -904,19 +941,28 @@ class Conv2DSlowBackpropInputOp : public OpKernel { input_sizes.vec(), &input_shape)); const TensorShape& filter_shape = filter.shape(); - EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput"); + Conv2DBackpropDimensions dims; + OP_REQUIRES_OK(context, Conv2DBackpropComputeDimensions( + "Conv2DSlowBackpropInput", input_shape, + filter_shape, out_backprop.shape(), strides_, + padding_, data_format_, &dims)); + Tensor* in_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); const int padding_rows = - (padding_ == VALID) ? 0 - : std::max(0, (output_rows - 1) * stride_rows + - filter_rows - input_rows); + (padding_ == VALID) + ? 0 + : std::max(0, (dims.rows.output_size - 1) * dims.rows.stride + + dims.rows.filter_size - + dims.rows.input_size); const int padding_cols = - (padding_ == VALID) ? 0 - : std::max(0, (output_cols - 1) * stride_cols + - filter_cols - input_cols); + (padding_ == VALID) + ? 0 + : std::max(0, (dims.cols.output_size - 1) * dims.cols.stride + + dims.cols.filter_size - + dims.cols.input_size); // TODO(keveman): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts @@ -934,12 +980,14 @@ class Conv2DSlowBackpropInputOp : public OpKernel { return; } - if (filter_rows == 1 && filter_cols == 1 && stride_rows == 1 && - stride_cols == 1 && data_format_ == FORMAT_NHWC) { + if (dims.rows.filter_size == 1 && dims.cols.filter_size == 1 && + dims.rows.stride == 1 && dims.cols.stride == 1 && + data_format_ == FORMAT_NHWC) { // 1x1 filter, so call cublas directly. - const uint64 m = batch * input_rows * input_cols; - const uint64 k = out_depth; - const uint64 n = in_depth; + const uint64 m = + dims.batch_size * dims.rows.input_size * dims.cols.input_size; + const uint64 k = dims.out_depth; + const uint64 n = dims.in_depth; auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), out_backprop.template flat().size()); @@ -969,33 +1017,33 @@ class Conv2DSlowBackpropInputOp : public OpKernel { // If a padding dimension is odd, we have one more element on the right // side or the bottom side. This is unsupported in cudnn. Therefore, // we pad that extra element and make it compatible. - compatible_input_shape = - ShapeFromFormat(data_format_, batch, input_rows + rows_odd, - input_cols + cols_odd, in_depth); + compatible_input_shape = ShapeFromFormat( + data_format_, dims.batch_size, dims.rows.input_size + rows_odd, + dims.cols.input_size + cols_odd, dims.in_depth); } else { compatible_input_shape = input_shape; } perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(batch) + input_desc.set_count(dims.batch_size) .set_height(GetTensorDim(compatible_input_shape, data_format_, 'H')) .set_width(GetTensorDim(compatible_input_shape, data_format_, 'W')) - .set_feature_map_count(in_depth) + .set_feature_map_count(dims.in_depth) .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(batch) - .set_height(output_rows) - .set_width(output_cols) - .set_feature_map_count(out_depth) + output_desc.set_count(dims.batch_size) + .set_height(dims.rows.output_size) + .set_width(dims.cols.output_size) + .set_feature_map_count(dims.out_depth) .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(filter_rows) - .set_input_filter_width(filter_cols) - .set_input_feature_map_count(in_depth) - .set_output_feature_map_count(out_depth); + filter_desc.set_input_filter_height(dims.rows.filter_size) + .set_input_filter_width(dims.cols.filter_size) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(stride_rows) - .set_horizontal_filter_stride(stride_cols) + conv_desc.set_vertical_filter_stride(dims.rows.stride) + .set_horizontal_filter_stride(dims.cols.stride) .set_zero_padding_height(padding_rows / 2) .set_zero_padding_width(padding_cols / 2); @@ -1013,11 +1061,12 @@ class Conv2DSlowBackpropInputOp : public OpKernel { // the second TransformDepth performs // (B x D x R x C) => (B x R x C x D). Tensor transformed_filter; - OP_REQUIRES_OK( - context, context->allocate_temp(DataTypeToEnum::value, - TensorShape({out_depth, in_depth, - filter_rows, filter_cols}), - &transformed_filter)); + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({dims.out_depth, dims.in_depth, + dims.rows.filter_size, + dims.cols.filter_size}), + &transformed_filter)); functor::TransformFilter()( context->eigen_device(), To32Bit(filter.tensor()), @@ -1028,8 +1077,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp( DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, batch, output_rows, - output_cols, out_depth), + ShapeFromFormat(FORMAT_NCHW, dims.batch_size, + dims.rows.output_size, + dims.cols.output_size, dims.out_depth), &transformed_out_backprop)); functor::NHWCToNCHW()( @@ -1069,18 +1119,18 @@ class Conv2DSlowBackpropInputOp : public OpKernel { context); int device_id = stream->parent()->device_ordinal(); ConvParameters conv_parameters = { - batch, // batch - in_depth, // in_depths - input_desc.height(), // in_rows - input_desc.width(), // in_cols - out_depth, // out_depths - filter_rows, // filter_rows - filter_cols, // filter_cols - stride_rows, // stride_rows - stride_cols, // stride_cols - padding_rows, // padding_rows - padding_cols, // padding_cols - device_id, // device_id + dims.batch_size, // batch + dims.in_depth, // in_depths + input_desc.height(), // in_rows + input_desc.width(), // in_cols + dims.out_depth, // out_depths + dims.rows.filter_size, // filter_rows + dims.cols.filter_size, // filter_cols + dims.rows.stride, // stride_rows + dims.cols.stride, // stride_cols + padding_rows, // padding_rows + padding_cols, // padding_cols + device_id, // device_id }; AlgorithmConfig algorithm_config; if (cudnn_use_autotune_ && @@ -1221,6 +1271,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { using perftools::gputools::dnn::kDefaultAlgorithm; const Tensor& input = context->input(0); const Tensor& filter_sizes = context->input(1); + const Tensor& out_backprop = context->input(2); OP_REQUIRES( context, TensorShapeUtils::IsVector(filter_sizes.shape()), errors::InvalidArgument( @@ -1231,19 +1282,28 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( filter_sizes.vec(), &filter_shape)); - EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropFilter"); + Conv2DBackpropDimensions dims; + OP_REQUIRES_OK(context, Conv2DBackpropComputeDimensions( + "Conv2DSlowBackpropFilter", input.shape(), + filter_shape, out_backprop.shape(), strides_, + padding_, data_format_, &dims)); + Tensor* filter_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); const int padding_rows = - (padding_ == VALID) ? 0 - : std::max(0, (output_rows - 1) * stride_rows + - filter_rows - input_rows); + (padding_ == VALID) + ? 0 + : std::max(0, (dims.rows.output_size - 1) * dims.rows.stride + + dims.rows.filter_size - + dims.rows.input_size); const int padding_cols = - (padding_ == VALID) ? 0 - : std::max(0, (output_cols - 1) * stride_cols + - filter_cols - input_cols); + (padding_ == VALID) + ? 0 + : std::max(0, (dims.cols.output_size - 1) * dims.cols.stride + + dims.cols.filter_size - + dims.cols.input_size); // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only // calling it when that is true. Remove this check when (if?) cuDNN starts @@ -1261,11 +1321,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { return; } - if (filter_rows == 1 && filter_cols == 1 && stride_rows == 1 && - stride_cols == 1 && data_format_ == FORMAT_NHWC) { - const uint64 m = in_depth; - const uint64 k = batch * input_rows * input_cols; - const uint64 n = out_depth; + if (dims.rows.filter_size == 1 && dims.cols.filter_size == 1 && + dims.rows.stride == 1 && dims.cols.stride == 1 && + data_format_ == FORMAT_NHWC) { + const uint64 m = dims.in_depth; + const uint64 k = + dims.batch_size * dims.rows.input_size * dims.cols.input_size; + const uint64 n = dims.out_depth; // The shape of output backprop is // [batch, out_rows, out_cols, out_depth] @@ -1307,8 +1369,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { context, context->allocate_temp( DataTypeToEnum::value, - ShapeFromFormat(data_format_, batch, input_rows + rows_odd, - input_cols + cols_odd, in_depth), + ShapeFromFormat(data_format_, dims.batch_size, + dims.rows.input_size + rows_odd, + dims.cols.input_size + cols_odd, dims.in_depth), &compatible_input)); functor::PadInput()( @@ -1320,25 +1383,25 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { } perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(batch) + input_desc.set_count(dims.batch_size) .set_height(GetTensorDim(compatible_input, data_format_, 'H')) .set_width(GetTensorDim(compatible_input, data_format_, 'W')) - .set_feature_map_count(in_depth) + .set_feature_map_count(dims.in_depth) .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(batch) - .set_height(output_rows) - .set_width(output_cols) - .set_feature_map_count(out_depth) + output_desc.set_count(dims.batch_size) + .set_height(dims.rows.output_size) + .set_width(dims.cols.output_size) + .set_feature_map_count(dims.out_depth) .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(filter_rows) - .set_input_filter_width(filter_cols) - .set_input_feature_map_count(in_depth) - .set_output_feature_map_count(out_depth); + filter_desc.set_input_filter_height(dims.rows.filter_size) + .set_input_filter_width(dims.cols.filter_size) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(stride_rows) - .set_horizontal_filter_stride(stride_cols) + conv_desc.set_vertical_filter_stride(dims.rows.stride) + .set_horizontal_filter_stride(dims.cols.stride) .set_zero_padding_height(padding_rows / 2) .set_zero_padding_width(padding_cols / 2); @@ -1357,19 +1420,21 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // (B x D x R x C) => (B x R x C x D). Tensor pre_transformed_filter_backprop; - OP_REQUIRES_OK( - context, context->allocate_temp(DataTypeToEnum::value, - TensorShape({out_depth, in_depth, - filter_rows, filter_cols}), - &pre_transformed_filter_backprop)); + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({dims.out_depth, dims.in_depth, + dims.rows.filter_size, + dims.cols.filter_size}), + &pre_transformed_filter_backprop)); Tensor transformed_out_backprop; if (data_format_ == FORMAT_NHWC) { OP_REQUIRES_OK(context, context->allocate_temp( DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, batch, output_rows, - output_cols, out_depth), + ShapeFromFormat(FORMAT_NCHW, dims.batch_size, + dims.rows.output_size, + dims.cols.output_size, dims.out_depth), &transformed_out_backprop)); functor::NHWCToNCHW()( context->eigen_device(), out_backprop.tensor(), @@ -1413,18 +1478,18 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { ); int device_id = stream->parent()->device_ordinal(); ConvParameters conv_parameters = { - batch, // batch - in_depth, // in_depths - input_desc.height(), // in_rows - input_desc.width(), // in_cols - out_depth, // out_depths - filter_rows, // filter_rows - filter_cols, // filter_cols - stride_rows, // stride_rows - stride_cols, // stride_cols - padding_rows, // padding_rows - padding_cols, // padding_cols - device_id, // device_id + dims.batch_size, // batch + dims.in_depth, // in_depths + input_desc.height(), // in_rows + input_desc.width(), // in_cols + dims.out_depth, // out_depths + dims.rows.filter_size, // filter_rows + dims.cols.filter_size, // filter_cols + dims.rows.stride, // stride_rows + dims.cols.stride, // stride_cols + padding_rows, // padding_rows + padding_cols, // padding_cols + device_id, // device_id }; AlgorithmConfig algorithm_config; if (cudnn_use_autotune_ && diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..f0ec8c2a7504bf32351aff8177211b6c684285e7 --- /dev/null +++ b/tensorflow/core/kernels/conv_grad_ops.h @@ -0,0 +1,65 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_ + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +// Information about a single spatial dimension for a convolution +// backpropagation. +struct ConvBackpropSpatialDimension { + int64 input_size; + int64 filter_size; + int64 output_size; + int64 stride; + int64 expanded_output_size; + + // Number of padding elements to be added before/after this dimension of + // the input when computing Conv2DBackpropInput. + int64 pad_before, pad_after; +}; + +// Computed dimensions for a Conv2D backpropagation. +struct Conv2DBackpropDimensions { + // Information about each spatial dimension. + ConvBackpropSpatialDimension rows, cols; + + // Batch size. + int64 batch_size; + + // Input and output feature depth. + int64 in_depth, out_depth; +}; + +// Common code between implementations of Conv2DBackpropInput and +// Conv2DBackpropFilter. Verifies that the dimensions all match, and computes +// sizes/padding for rows and columns. +Status Conv2DBackpropComputeDimensions( + StringPiece label, const TensorShape& input_shape, + const TensorShape& filter_shape, const TensorShape& out_backprop_shape, + const std::vector& strides, Padding padding, + TensorFormat data_format, Conv2DBackpropDimensions* dims); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_