提交 9f75f8e6 编写于 作者: P Peter Hawkins 提交者: Michael Case

[TF:XLA] Implement ExtractImagePatches.

PiperOrigin-RevId: 184033616
上级 2639cda7
......@@ -255,6 +255,18 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "extract_image_patches_op_test",
size = "small",
srcs = ["extract_image_patches_op_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "fft_test",
size = "medium",
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for ExtractImagePatches op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ExtractImagePatches(XLATestCase):
"""Functional tests for ExtractImagePatches op."""
def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):
"""Tests input-output pairs for the ExtractImagePatches op.
Args:
image: Input tensor with shape: [batch, in_rows, in_cols, depth].
ksizes: Patch size specified as: [ksize_rows, ksize_cols].
strides: Output strides, specified as [stride_rows, stride_cols].
rates: Atrous rates, specified as [rate_rows, rate_cols].
padding: Padding type.
patches: Expected output.
"""
ksizes = [1] + ksizes + [1]
strides = [1] + strides + [1]
rates = [1] + rates + [1]
with self.test_session():
image_placeholder = array_ops.placeholder(dtypes.float32)
with self.test_scope():
out_tensor = array_ops.extract_image_patches(
image_placeholder,
ksizes=ksizes,
strides=strides,
rates=rates,
padding=padding,
name="im2col")
feed_dict = {image_placeholder: image}
self.assertAllClose(patches, out_tensor.eval(feed_dict=feed_dict))
def testKsize1x1Stride1x1Rate1x1(self):
"""Verifies that for 1x1 kernel the output equals the input."""
# [2, 3, 4, 5]
image = np.reshape(range(120), [2, 3, 4, 5])
# [2, 3, 4, 5]
patches = np.reshape(range(120), [2, 3, 4, 5])
for padding in ["VALID", "SAME"]:
self._VerifyValues(
image,
ksizes=[1, 1],
strides=[1, 1],
rates=[1, 1],
padding=padding,
patches=patches)
def testKsize1x1Stride2x3Rate1x1(self):
"""Test for 1x1 kernel and strides."""
# [2, 4, 5, 3]
image = np.reshape(range(120), [2, 4, 5, 3])
# [2, 2, 2, 3]
patches = image[:, ::2, ::3, :]
for padding in ["VALID", "SAME"]:
self._VerifyValues(
image,
ksizes=[1, 1],
strides=[2, 3],
rates=[1, 1],
padding=padding,
patches=patches)
def testKsize2x2Stride1x1Rate1x1Valid(self):
"""Test for 2x2 kernel with VALID padding."""
# [1, 2, 2, 1]
image = [[[[1], [2]], [[3], [4]]]]
# [1, 1, 1, 4]
patches = [[[[1, 2, 3, 4]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[1, 1],
padding="VALID",
patches=patches)
def testKsize2x2Stride1x1Rate1x1Same(self):
"""Test for 2x2 kernel with SAME padding."""
# [1, 2, 2, 1]
image = [[[[1], [2]], [[3], [4]]]]
# [1, 2, 2, 4]
patches = [[[[1, 2, 3, 4], [2, 0, 4, 0]], [[3, 4, 0, 0], [4, 0, 0, 0]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[1, 1],
padding="SAME",
patches=patches)
def testKsize2x2Stride1x1Rate2x2Valid(self):
"""Test for 2x2 kernel with 2x2 dilation."""
# [1, 2, 2, 1]
image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
# [1, 2, 2, 4]
patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]],
[[4, 6, 12, 14], [5, 7, 13, 15]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[2, 2],
padding="VALID",
patches=patches)
if __name__ == "__main__":
test.main()
......@@ -71,6 +71,7 @@ Operator | Type Constraint
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
`FFT` |
`FFT2D` |
`FFT3D` |
......@@ -124,6 +125,8 @@ Operator | Type Constraint
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
......@@ -176,6 +179,7 @@ Operator | Type Constraint
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`
......
......@@ -71,6 +71,7 @@ Operator | Type Constraint
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
`FFT` |
`FFT2D` |
`FFT3D` |
......@@ -124,6 +125,8 @@ Operator | Type Constraint
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
......@@ -173,6 +176,7 @@ Operator | Type Constraint
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`
......
......@@ -31,6 +31,7 @@ tf_kernel_library(
"diag_op.cc",
"dynamic_stitch_op.cc",
"elu_op.cc",
"extract_image_patches_op.cc",
"fft_ops.cc",
"fill_op.cc",
"function_ops.cc",
......
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
namespace {
class ExtractImagePatchesOp : public XlaOpKernel {
public:
explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
}
void Compile(XlaOpKernelContext* ctx) override {
const TensorFormat data_format = FORMAT_NHWC;
const int num_dims = ksizes_.size();
OP_REQUIRES(
ctx, num_dims >= 3,
errors::InvalidArgument("Kernel size must have at least 3 dimensions"));
const int num_spatial_dims = num_dims - 2;
OP_REQUIRES(ctx, strides_.size() == num_dims,
errors::InvalidArgument("Sliding window strides field must "
"specify ",
num_dims, " dimensions"));
OP_REQUIRES(ctx, dilations_.size() == num_dims,
errors::InvalidArgument("Dilations field must "
"specify ",
num_dims, " dimensions"));
int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
OP_REQUIRES(
ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
errors::Unimplemented("Current implementation does not yet support "
"kernel sizes > 1 in the batch and depth "
"dimensions."));
OP_REQUIRES(
ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
errors::Unimplemented("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES(
ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
errors::Unimplemented("Current implementation does not support "
"dilations in the batch and depth dimensions."));
for (int i = 0; i < num_spatial_dims; ++i) {
int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
OP_REQUIRES(
ctx, ksizes_[input_dim] >= 0,
errors::Unimplemented("Kernel size values must be non-negative; ", i,
"th spatial dimension had dilation ",
dilations_[input_dim]));
OP_REQUIRES(ctx, strides_[input_dim] >= 1,
errors::Unimplemented("Stride values must be positive; ", i,
"th spatial dimension had dilation ",
dilations_[input_dim]));
OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
errors::Unimplemented("Dilation values must be positive; ", i,
"th spatial dimension had dilation ",
dilations_[input_dim]));
}
xla::PrimitiveType type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type));
const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(
ctx, input_shape.dims() == num_dims,
errors::InvalidArgument("input must be ", num_dims, "-dimensional",
input_shape.DebugString()));
const int64 depth = input_shape.dim_size(feature_dim);
xla::ComputationBuilder* builder = ctx->builder();
// The following code is equivalent to:
// eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
int64 kernel_size = 1;
std::vector<int64> lhs_shape(num_dims, 1);
for (int i = 0; i < num_spatial_dims; ++i) {
int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
lhs_shape[i] = ksizes_[input_dim];
kernel_size *= ksizes_[input_dim];
}
lhs_shape[num_spatial_dims] = depth;
lhs_shape[num_spatial_dims + 1] = 1;
// Builds an identity matrix as a broadcast equality of iotas.
// iota = np.arange(np.prod(ksize), depth)
// filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
xla::ComputationDataHandle iota;
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
kernel_size * depth, &iota));
auto lhs = builder->Reshape(iota, lhs_shape);
auto filter = builder->ConvertElementType(
builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);
xla::ConvolutionDimensionNumbers dims;
std::vector<int64> window_strides(num_spatial_dims);
std::vector<int64> lhs_dilation(num_spatial_dims, 1);
std::vector<int64> rhs_dilation(num_spatial_dims);
std::vector<std::pair<int64, int64>> padding(num_spatial_dims);
dims.set_input_batch_dimension(batch_dim);
dims.set_output_batch_dimension(batch_dim);
dims.set_input_feature_dimension(feature_dim);
dims.set_output_feature_dimension(feature_dim);
dims.set_kernel_input_feature_dimension(num_spatial_dims);
dims.set_kernel_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
dims.add_input_spatial_dimensions(dim);
dims.add_kernel_spatial_dimensions(i);
dims.add_output_spatial_dimensions(dim);
window_strides[i] = strides_.at(dim);
rhs_dilation[i] = dilations_.at(dim);
int64 unused_output_size;
OP_REQUIRES_OK(
ctx, GetWindowedOutputSizeVerboseV2(
input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i],
window_strides[i], padding_, &unused_output_size,
&padding[i].first, &padding[i].second));
}
xla::ComputationDataHandle conv =
builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
padding, lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
}
protected:
std::vector<int32> ksizes_;
std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
};
REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp);
} // namespace
} // namespace tensorflow
......@@ -84,7 +84,7 @@ class ExtractImagePatches(test.TestCase):
patches=patches)
def testKsize2x2Stride1x1Rate1x1Valid(self):
"""Test for 1x1 kernel ."""
"""Test for 2x2 kernel with VALID padding."""
# [1, 2, 2, 1]
image = [[[[1], [2]], [[3], [4]]]]
# [1, 1, 1, 4]
......@@ -98,7 +98,7 @@ class ExtractImagePatches(test.TestCase):
patches=patches)
def testKsize2x2Stride1x1Rate1x1Same(self):
"""Test for 1x1 kernel ."""
"""Test for 2x2 kernel with SAME padding."""
# [1, 2, 2, 1]
image = [[[[1], [2]], [[3], [4]]]]
# [1, 2, 2, 4]
......@@ -111,6 +111,20 @@ class ExtractImagePatches(test.TestCase):
padding="SAME",
patches=patches)
def testKsize2x2Stride1x1Rate2x2Valid(self):
"""Test for 2x2 kernel with 2x2 dilation."""
# [1, 2, 2, 1]
image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
# [1, 2, 2, 4]
patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]],
[[4, 6, 12, 14], [5, 7, 13, 15]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[2, 2],
padding="VALID",
patches=patches)
if __name__ == "__main__":
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册