diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a590558492348001345c6485350aff15ce13f4ea..0f8a787e9f481f69e8db360249fbe834ec76a4a0 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -696,6 +696,17 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "depthwise_conv_op", + prefix = "depthwise_conv_op", + deps = [ + ":ops_util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:nn_ops_op_lib", + ], +) + tf_kernel_libraries( name = "nn", prefixes = [ @@ -714,6 +725,7 @@ tf_kernel_libraries( deps = [ ":conv_2d", ":conv_ops", + ":depthwise_conv_op", ":ops_util", ":pooling_ops", "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c3f90521b499ef48f1abdca8aab4e974c7d20666 --- /dev/null +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -0,0 +1,267 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/depthwise_conv_op.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/padding.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct LaunchDepthwiseConvOp; + +template +struct LaunchDepthwiseConvOp { + static void launch(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* input, const T* filter, T* output) { + // Naive for loop as a reference point without concerns about performance. + // Expected to be replaced later. + // TODO(andydavis): replace this with an optimized version + for (int b = 0; b < args.batch; ++b) { + for (int out_r = 0; out_r < args.out_rows; ++out_r) { + for (int out_c = 0; out_c < args.out_cols; ++out_c) { + for (int out_d = 0; out_d < args.out_depth; ++out_d) { + T sum = 0; + const int in_r_start = out_r * args.stride - args.pad_rows; + const int in_c_start = out_c * args.stride - args.pad_cols; + const int in_d = out_d / args.depth_multiplier; + const int filter_dm = out_d % args.depth_multiplier; + + for (int f_r = 0; f_r < args.filter_rows; ++f_r) { + for (int f_c = 0; f_c < args.filter_cols; ++f_c) { + int in_r = in_r_start + f_r; + int in_c = in_c_start + f_c; + + if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 && + in_c < args.in_cols) { + int input_offset = + in_d + + args.in_depth * + (in_c + args.in_cols * (in_r + args.in_rows * b)); + int filter_offset = + filter_dm + + args.depth_multiplier * + (in_d + + args.in_depth * (f_c + args.filter_cols * f_r)); + sum += input[input_offset] * filter[filter_offset]; + } + } + } + + int output_offset = + out_d + + args.out_depth * + (out_c + args.out_cols * (out_r + args.out_rows * b)); + output[output_offset] = sum; + } + } + } + } + } +}; + +#if GOOGLE_CUDA + +template +struct DepthwiseConv2dGPULaunch { + void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input, + const T* filter, T* output); +}; + +template +struct LaunchDepthwiseConvOp { + static void launch(OpKernelContext* ctx, const DepthwiseArgs args, + const T* input, const T* filter, T* output) { + const GPUDevice& d = ctx->eigen_device(); + DepthwiseConv2dGPULaunch().Run(d, args, input, filter, output); + auto stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream->ok(), + errors::Internal("Launch of gpu kernel for SplitOp failed")); + } +}; + +#endif + +template +class DepthwiseConv2dNativeOp : public BinaryOp { + public: + explicit DepthwiseConv2dNativeOp(OpKernelConstruction* context) + : BinaryOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES(context, strides_[1] == strides_[2], + errors::InvalidArgument( + "Current implementation only supports equal length " + "strides in the row and column dimensions.")); + OP_REQUIRES( + context, (strides_[0] == 1 && strides_[3] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, in_rows, in_cols, in_depth ] + const Tensor& input = context->input(0); + auto input_ptr = input.template flat().data(); + + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, in_depth, depth_multiplier] + const Tensor& filter = context->input(1); + auto filter_ptr = filter.template flat().data(); + + // For 2D convolution, there should be 4 dimensions. + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().DebugString())); + OP_REQUIRES(context, filter.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional: ", + filter.shape().DebugString())); + + // The last dimension for input is in_depth. It must be the same as the + // filter's in_depth. + const int32 in_depth = input.dim_size(3); + OP_REQUIRES( + context, in_depth == filter.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + in_depth, " vs ", filter.dim_size(2))); + + // The last dimension for filter is depth multiplier. + const int32 depth_multiplier = filter.dim_size(3); + + // The output depth is input depth x depth multipler + const int32 out_depth = in_depth * depth_multiplier; + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int32 input_rows = input.dim_size(1); + const int32 filter_rows = filter.dim_size(0); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int32 input_cols = input.dim_size(2); + const int32 filter_cols = filter.dim_size(1); + + // The first dimension for input is batch. + const int32 batch = input.dim_size(0); + + // For now we take the stride from the second dimension only (we + // assume row = col stride, and do not support striding on the + // batch or depth dimension). + const int32 stride = strides_[1]; + + int32 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + OP_REQUIRES_OK(context, + Get2dOutputSize(input_rows, input_cols, filter_rows, + filter_cols, stride, stride, padding_, + &out_rows, &out_cols, &pad_rows, &pad_cols)); + TensorShape out_shape({batch, out_rows, out_cols, out_depth}); + OP_REQUIRES( + context, out_shape.num_elements() <= 2147483647, + errors::InvalidArgument("total number of outputs should be within the " + "range of int which is used in the GPU kernel", + in_depth, " vs ", filter.dim_size(2))); + + // Output tensor is of the following dimensions: + // [ in_batch, out_rows, out_cols, out_depth ] + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + auto output_ptr = output->template flat().data(); + + DepthwiseArgs args; + args.batch = batch; + args.in_rows = input_rows; + args.in_cols = input_cols; + args.in_depth = in_depth; + args.filter_rows = filter_rows; + args.filter_cols = filter_cols; + args.depth_multiplier = depth_multiplier; + args.stride = stride; + args.pad_rows = pad_rows; + args.pad_cols = pad_cols; + args.out_rows = out_rows; + args.out_cols = out_cols; + args.out_depth = out_depth; + + VLOG(2) << "DepthwiseConv2dNative: " + << " Input: [" << batch << ", " << input_rows << ", " << input_cols + << ", " << in_depth << "]; Filter: [" << filter_rows << ", " + << filter_cols << ", " << in_depth << ", " << depth_multiplier + << "]; stride = " << stride << ", pad_rows = " << pad_rows + << ", pad_cols = " << pad_cols << ", output: [" << batch << ", " + << out_rows << ", " << out_cols << ", " << out_depth << "]" << endl; + + // If there is nothing to compute, return. + if (out_shape.num_elements() == 0) { + return; + } + LaunchDepthwiseConvOp::launch(context, args, input_ptr, + filter_ptr, output_ptr); + } + + private: + std::vector strides_; + Padding padding_; + + TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp); +}; + +REGISTER_KERNEL_BUILDER( + Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint("T"), + DepthwiseConv2dNativeOp); + +REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + DepthwiseConv2dNativeOp); + +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER( + Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint("T"), + DepthwiseConv2dNativeOp); + +REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative") + .Device(DEVICE_GPU) + .TypeConstraint("T"), + DepthwiseConv2dNativeOp); +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9733c5914d2ee7804a9ba768e7efc68b28c8dde1 --- /dev/null +++ b/tensorflow/core/kernels/depthwise_conv_op.h @@ -0,0 +1,53 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ + +struct DepthwiseArgs { + // Input layer dimensions + int batch; + int in_rows; + int in_cols; + int in_depth; + int filter_rows; + int filter_cols; + int depth_multiplier; + int stride; + int pad_rows; + int pad_cols; + + // Output layer dimensions + int out_rows; + int out_cols; + int out_depth; + + DepthwiseArgs() + : batch(0), + in_rows(0), + in_cols(0), + in_depth(0), + filter_rows(0), + filter_cols(0), + depth_multiplier(0), + stride(0), + pad_rows(0), + pad_cols(0), + out_rows(0), + out_cols(0), + out_depth(0) {} +}; + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..787e37c041dd3545ad188d3309b325d5072e1883 --- /dev/null +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -0,0 +1,122 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/depthwise_conv_op.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +#define UNROLL _Pragma("unroll") + +namespace tensorflow { + +namespace { + +typedef Eigen::GpuDevice GPUDevice; + +// A Cuda kernel to compute the depthwise convolution. +template +__global__ void DepthwiseConv2dGPUKernel(const DepthwiseArgs args, + const T* input, const T* filter, + T* output, int num_outputs) { + const int in_rows = args.in_rows; + const int in_cols = args.in_cols; + const int in_depth = args.in_depth; + const int filter_rows = args.filter_rows; + const int filter_cols = args.filter_cols; + const int depth_multiplier = args.depth_multiplier; + const int stride = args.stride; + const int pad_rows = args.pad_rows; + const int pad_cols = args.pad_cols; + const int out_rows = args.out_rows; + const int out_cols = args.out_cols; + const int out_depth = args.out_depth; + + CUDA_1D_KERNEL_LOOP(thread_id, num_outputs) { + // Compute the indexes of this thread in the output. + const int OD = thread_id % out_depth; + const int OC = (thread_id / out_depth) % out_cols; + const int OR = (thread_id / out_depth / out_cols) % out_rows; + const int OB = thread_id / out_depth / out_cols / out_rows; + // Compute the input depth and the index of depth multiplier. + const int in_d = OD / depth_multiplier; + const int multiplier = OD % depth_multiplier; + + // Decide if all input is valid, if yes, we can skip the boundary checks for + // each input. + const int input_row_start = OR * stride - pad_rows; + const int input_col_start = OC * stride - pad_cols; + const int input_row_end = input_row_start + filter_rows; + const int input_col_end = input_col_start + filter_cols; + + float sum = 0; + if (input_row_start >= 0 && input_col_start >= 0 && + input_row_end < in_rows && input_col_end < in_cols) { + UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { + UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { + int in_r = input_row_start + f_r; + int in_c = input_col_start + f_c; + + sum += input[in_d + + in_depth * (in_c + in_cols * (in_r + in_rows * OB))] * + filter[multiplier + + depth_multiplier * + (in_d + in_depth * (f_c + filter_cols * f_r))]; + } + } + } else { + UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { + UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { + int in_r = input_row_start + f_r; + int in_c = input_col_start + f_c; + + if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { + sum += input[in_d + + in_depth * (in_c + in_cols * (in_r + in_rows * OB))] * + filter[multiplier + + depth_multiplier * + (in_d + in_depth * (f_c + filter_cols * f_r))]; + } + } + } + } + output[OD + out_depth * (OC + out_cols * (OR + out_rows * OB))] = sum; + } +} +} // namespace + +// A simple launch pad to launch the Cuda kernel for depthwise convolution. +template +struct DepthwiseConv2dGPULaunch { + void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input, + const T* filter, T* output) { + const int num_outputs = + args.batch * args.out_rows * args.out_cols * args.out_depth; + CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d); + + DepthwiseConv2dGPUKernel< + T><<>>( + args, input, filter, output, num_outputs); + } +}; + +template struct DepthwiseConv2dGPULaunch; +template struct DepthwiseConv2dGPULaunch; + +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index d6a5f64e508313f51e8636f7e8fb451a54860a38..29e19af6f83d99d8c9eb9bbbb0d24b69ad84f4eb 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -269,6 +269,40 @@ data_format: Specify the data format of the input and output data. With the // -------------------------------------------------------------------------- +REGISTER_OP("DepthwiseConv2dNative") + .Input("input: T") + .Input("filter: T") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. + +Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +and a filter / kernel tensor of shape +`[filter_height, filter_width, in_channels, channel_multiplier]`, containing +`in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies +a different filter to each input channel (expanding from 1 channel to +`channel_multiplier` channels for each), then concatenates the results +together. Thus, the output has `in_channels * channel_multiplier` channels. + +for k in 0..in_channels-1 + for q in 0..channel_multiplier-1 + output[b, i, j, k * channel_multiplier + q] = + sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * + filter[di, dj, k, q] + +Must have `strides[0] = strides[3] = 1`. For the most common case of the same +horizontal and vertices strides, `strides = [1, stride, stride, 1]`. + +strides: 1-D of length 4. The stride of the sliding window for each dimension + of `input`. +padding: The type of padding algorithm to use. +)doc"); + +// -------------------------------------------------------------------------- + REGISTER_OP("L2Loss") .Input("t: T") .Output("output: T") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 827831fbb7256c44d86c7fdd2ee5fb3de14e7500..8d1b9f5d2334edb96adab55d6921468a99673082 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -2353,6 +2353,49 @@ op { summary: "DepthToSpace for tensors of type T." description: "Rearranges data from depth into blocks of spatial data.\nThis is the reverse transformation of SpaceToDepth. More specifically,\nthis op outputs a copy of the input tensor where values from the `depth`\ndimension are moved in spatial blocks to the `height` and `width` dimensions.\nThe attr `block_size` indicates the input block size and how the data is moved.\n\n * Chunks of data of size `block_size * block_size` from depth are rearranged\n into non-overlapping blocks of size `block_size x block_size`\n * The width the output tensor is `input_depth * block_size`, whereas the\n height is `input_height * block_size`.\n * The depth of the input tensor must be divisible by\n `block_size * block_size`.\n\nThat is, assuming the input is in the shape:\n`[batch, height, width, depth]`,\nthe shape of the output will be:\n`[batch, height*block_size, width*block_size, depth/(block_size*block_size)]`\n\nThis operation requires that the input tensor be of rank 4, and that\n`block_size` be >=1 and that `block_size * block_size` be a divisor of the\ninput depth.\n\nThis operation is useful for resizing the activations between convolutions\n(but keeping all data), e.g. instead of pooling. It is also useful for training\npurely convolutional models.\n\nFor example, given this input of shape `[1, 1, 1, 4]`, and a block size of 2:\n\n```prettyprint\nx = [[[[1, 2, 3, 4]]]]\n\n```\n\nThis operation will output a tensor of shape `[1, 2, 2, 1]`:\n\n```prettyprint\n [[[[1], [2]],\n [[3], [4]]]]\n```\n\nHere, the input has a batch of 1 and each batch element has shape `[1, 1, 4]`,\nthe corresponding output will have 2x2 elements and will have a depth of\n1 channel (1 = `4 / (block_size * block_size)`).\nThe output element shape is `[2, 2, 1]`.\n\nFor an input tensor with larger depth, here of shape `[1, 1, 1, 12]`, e.g.\n\n```prettyprint\nx = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]\n```\n\nThis operation, for block size of 2, will return the following tensor of shape\n`[1, 2, 2, 3]`\n\n```prettyprint\n [[[[1, 2, 3], [4, 5, 6]],\n [[7, 8, 9], [10, 11, 12]]]]\n\n```\n\nSimilarly, for the following input of shape `[1 2 2 4]`, and a block size of 2:\n\n```prettyprint\nx = [[[[1, 2, 3, 4],\n [5, 6, 7, 8]],\n [[9, 10, 11, 12],\n [13, 14, 15, 16]]]]\n```\n\nthe operator will return the following tensor of shape `[1 4 4 1]`:\n\n```prettyprint\nx = [[ [1], [2], [5], [6]],\n [ [3], [4], [7], [8]],\n [ [9], [10], [13], [14]],\n [ [11], [12], [15], [16]]]\n\n```" } +op { + name: "DepthwiseConv2dNative" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "filter" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "strides" + type: "list(int)" + description: "1-D of length 4. The stride of the sliding window for each dimension\nof `input`." + } + attr { + name: "padding" + type: "string" + description: "The type of padding algorithm to use." + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + summary: "Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors." + description: "Given an input tensor of shape `[batch, in_height, in_width, in_channels]`\nand a filter / kernel tensor of shape\n`[filter_height, filter_width, in_channels, channel_multiplier]`, containing\n`in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies\na different filter to each input channel (expanding from 1 channel to\n`channel_multiplier` channels for each), then concatenates the results\ntogether. Thus, the output has `in_channels * channel_multiplier` channels.\n\nfor k in 0..in_channels-1\n for q in 0..channel_multiplier-1\n output[b, i, j, k * channel_multiplier + q] =\n sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *\n filter[di, dj, k, q]\n\nMust have `strides[0] = strides[3] = 1`. For the most common case of the same\nhorizontal and vertices strides, `strides = [1, stride, stride, 1]`." +} op { name: "DeserializeManySparse" input_arg { diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ba95a0198e039cedf6dd236fc158d8fc8785cb --- /dev/null +++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py @@ -0,0 +1,205 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for depthwise convolutional operations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + + +def ConfigsToTest(): + """Iterator for different convolution shapes, strides and paddings. + + Yields: + Tuple (input_size, filter_size, out_size, stride, padding), the depthwise + convolution parameters. + """ + input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], + [4, 35, 35, 192], [4, 147, 147, 24], [3, 299, 299, 3]] + filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [5, 5, 192, 1], + [3, 3, 24, 8], [1, 1, 3, 8]] + out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 11, 11, 192], + [4, 74, 74, 192], [3, 299, 299, 24]] + strides = [1, 1, 1, 3, 2, 1] + # pylint: disable=invalid-name + VALID = "VALID" + SAME = "SAME" + # pylint: enable=invalid-name + paddings = [SAME, SAME, SAME, VALID, SAME, SAME] + for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, + paddings): + yield i, f, o, s, p + + +class DepthwiseConv2DTest(tf.test.TestCase): + + # This is testing against the output of the implementation using the + # combination of conv_2d and slicing ops. + def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride, padding, + use_gpu): + """Verifies the output values of the convolution function. + + Args: + tensor_in_sizes: Input tensor dimensions in + [batch, input_rows, input_cols, input_depth]. + filter_in_sizes: Filter tensor dimensions in + [filter_rows, filter_cols, input_depth, depth_multiplier]. + stride: Stride. + padding: Padding type. + use_gpu: Whether to use GPU. + """ + total_size_1 = 1 + total_size_2 = 1 + for s in tensor_in_sizes: + total_size_1 *= s + for s in filter_in_sizes: + total_size_2 *= s + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] + x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] + with self.test_session(use_gpu=use_gpu) as sess: + t1 = tf.constant(x1, shape=tensor_in_sizes) + t1.set_shape(tensor_in_sizes) + t2 = tf.constant(x2, shape=filter_in_sizes) + conv_native = tf.nn.depthwise_conv2d_native( + t1, + t2, + strides=[1, stride, stride, 1], + padding=padding) + + conv_gold = tf.nn.depthwise_conv2d(t1, + t2, + strides=[1, stride, stride, 1], + padding=padding) + native_result = sess.run(conv_native) + gold_result = sess.run(conv_gold) + + self.assertArrayNear(np.ravel(native_result), np.ravel(gold_result), 1e-5) + self.assertShapeEqual(native_result, conv_native) + self.assertShapeEqual(native_result, conv_gold) + + def testDepthwiseConv2D(self): + for _, (input_size, filter_size, _, stride, + padding) in enumerate(ConfigsToTest()): + self._VerifyValues(input_size, filter_size, stride, padding, use_gpu=True) + + # This is testing against hand calculated results. + def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding, + expected, use_gpu): + """Verifies the output values of the depthwise convolution function. + + Args: + tensor_in_sizes: Input tensor dimensions in + [batch, input_rows, input_cols, input_depth]. + filter_in_sizes: Filter tensor dimensions in + [filter_rows, filter_cols, input_depth, depth_multiplier]. + stride: Stride. + padding: Padding type. + expected: An array containing the expected operation outputs. + use_gpu: Whether to use GPU. + """ + total_size_1 = 1 + total_size_2 = 1 + for s in tensor_in_sizes: + total_size_1 *= s + for s in filter_in_sizes: + total_size_2 *= s + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] + x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] + with self.test_session(use_gpu=use_gpu) as sess: + t1 = tf.constant(x1, shape=tensor_in_sizes) + t1.set_shape(tensor_in_sizes) + t2 = tf.constant(x2, shape=filter_in_sizes) + conv = tf.nn.depthwise_conv2d_native(t1, + t2, + strides=[1, stride, stride, 1], + padding=padding) + value = sess.run(conv) + print("value = ", value) + self.assertArrayNear(expected, np.ravel(value), 1e-5) + self.assertShapeEqual(value, conv) + + def testConv2D2x2Filter(self): + # The inputs look like this (it's a 3 x 2 matrix, each of depth 2): + # + # [ (1.0, 2.0), (3.0, 4.0), ( 5.0, 6.0) ] + # [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ] + # We can view this as two inputs + # + # input depth 0: + # + # [ 1.0, 3.0, 5.0 ] + # [ 7.0, 9.0, 11.0 ] + # + # input depth 1: + # + # [ 2.0, 4.0, 6.0 ] + # [ 8.0, 10.0, 12.0 ] + # + # The filter looks like this (it has two 2 x 2 patches, each generating 2 + # depths): + # + # filter #0: + # + # [ (1.0, 3.0), ( 5.0, 7.0)] + # [ (9.0, 11.0), (13.0, 15.0)] + # + # filter #1: + # + # [ ( 2.0, 4.0), ( 6.0, 8.0)] + # [ (10.0, 12.0), (14.0, 16.0)] + # + # So the outputs are: + # + # (position 0, 0: in_depth 0, output_depth 0 -- using filter #0) + # 1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196 + # (position 0, 0: in_depth 0, output_depth 1 -- using filter #1) + # 1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216 + # (position 0, 0: in_depth 1, output_depth 2 -- using filter #0) + # 2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272 + # (position 0, 0: in_depth 1, output_depth 3 -- using filter #1) + # 2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296 + # + # (position 1, 0: in_depth 0, output_depth 0 -- using filter #0) + # 3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252 + # (position 1, 0: in_depth 0, output_depth 1 -- using filter #1) + # 3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280 + # (position 1, 0: in_depth 1, output_depth 2 -- using filter #0) + # 4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344 + # (position 1, 0: in_depth 1, output_depth 3 -- using filter #1) + # 4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376 + expected_output = [196, 216, 272, 296, 252, 280, 344, 376] + self._VerifyHandValues(tensor_in_sizes=[1, 2, 3, 2], + filter_in_sizes=[2, 2, 2, 2], + stride=1, + padding="VALID", + expected=expected_output, + use_gpu=False) + + self._VerifyHandValues(tensor_in_sizes=[1, 2, 3, 2], + filter_in_sizes=[2, 2, 2, 2], + stride=1, + padding="VALID", + expected=expected_output, + use_gpu=True) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py index 233e8a5edfebcd6f500d1f9019e0b3f09a461495..5e3d9ff059edb602ebeea0ad68d945e342269e5d 100644 --- a/tensorflow/python/ops/common_shapes.py +++ b/tensorflow/python/ops/common_shapes.py @@ -225,6 +225,61 @@ def conv2d_shape(op): return [tensor_shape.TensorShape(output_shape)] +def depthwise_conv2d_native_shape(op): + """Shape function for a DepthwiseConv2D op. + + This op has two inputs: + + * input, a 4D tensor with shape = [batch_size, rows, cols, depth_in] + * filter, a 4D tensor with shape = [filter_rows, filter_cols, + depth_in, depthwise_multiplier] + + The output is a 4D tensor with shape = [batch_size, out_rows, + out_cols, depth_in*depthwise_multiplier], where out_rows and out_cols depend + on the value of the op's "padding" and "strides" attrs. + + Args: + op: A DepthwiseConv2dNative Operation. + + Returns: + A list containing the Shape of the DepthwiseConv2DNative output. + + Raises: + ValueError: If the shapes of the input or filter are incompatible. + """ + input_shape = op.inputs[0].get_shape().with_rank(4) + filter_shape = op.inputs[1].get_shape().with_rank(4) + + batch_size = input_shape[0] + in_rows = input_shape[1] + in_cols = input_shape[2] + + filter_rows = filter_shape[0] + filter_cols = filter_shape[1] + depth_out = filter_shape[3] * filter_shape[2] + # Check that the input depths are compatible. + input_shape[3].assert_is_compatible_with(filter_shape[2]) + + stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") + if stride_b != 1 or stride_d != 1: + raise ValueError("Current implementation does not yet support " + "strides in the batch and depth dimensions.") + if stride_r != stride_c: + # TODO(shlens): Add support for this. + raise ValueError("Current implementation only supports equal length " + "strides in the row and column dimensions.") + + # TODO(mrry,shlens): Raise an error if the stride would cause + # information in the input to be ignored. This will require a change + # in the kernel implementation. + stride = stride_r + padding = op.get_attr("padding") + out_rows, out_cols = get2d_conv_output_size( + in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding) + + return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] + + def separable_conv2d_shape(op): """Shape function for a SeparableConv2D op. diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 379408b83bb00438c352cc988d15360c8209a51c..efb1fcb0b3272b3cc35363aaa228c82e5c683a0c 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -398,6 +398,8 @@ def _BatchNormGradShape(op): ops.RegisterShape("Conv2D")(common_shapes.conv2d_shape) +ops.RegisterShape("DepthwiseConv2dNative")( + common_shapes.depthwise_conv2d_native_shape) ops.RegisterShape("AvgPool")(common_shapes.avg_pool_shape) ops.RegisterShape("MaxPool")(common_shapes.max_pool_shape)