Commit 8abe7db5 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Enable fp16 for most of the pooling ops (MaxPool, AvgPool, associated
gradients, some variants, etc.).
Change: 123967787
Parent: 9bedadce
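Every kernel-registration hunk below follows the same shape: a single float (or unconstrained) registration is replaced by one registration per supported element type, with a TypeConstraint on the op's "T" attr routing dispatch to the right template instantiation. A hypothetical condensation of the idiom (this commit writes the pairs out longhand):

    // Hypothetical macro, for illustration only; the diff below registers
    // each <op, device, type> combination explicitly.
    #define REGISTER_CPU_AVG_POOL(T)                                   \
      REGISTER_KERNEL_BUILDER(                                         \
          Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
          AvgPoolingOp<CPUDevice, T>)
    REGISTER_CPU_AVG_POOL(float);
    REGISTER_CPU_AVG_POOL(Eigen::half);
    #undef REGISTER_CPU_AVG_POOL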
package(default_visibility = ["//visibility:public"])
archive_dir = "eigen-eigen-0c0b79ecd74c"
archive_dir = "eigen-eigen-d02e6a705c30"
cc_library(
name = "eigen",
......
@@ -7,7 +7,7 @@
include (ExternalProject)
set(eigen_archive_hash "0c0b79ecd74c")
set(eigen_archive_hash "d02e6a705c30")
set(eigen_INCLUDE_DIRS
${CMAKE_CURRENT_BINARY_DIR}
@@ -16,7 +16,7 @@ set(eigen_INCLUDE_DIRS
${tensorflow_source_dir}/third_party/eigen3
)
set(eigen_URL https://bitbucket.org/eigen/eigen/get/${eigen_archive_hash}.tar.gz)
set(eigen_HASH SHA256=b4b5884b03bd4bae114d02b36e2435ad1504ed8e51431d16c876b6f6a365882b)
set(eigen_HASH SHA256=532956172daa8aba87c750791ff89a5c38cdb07e2525afe17ecb4bef812d67cf)
set(eigen_BUILD ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen)
set(eigen_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/eigen/install)
......
@@ -99,12 +99,10 @@ class AvgPoolingOp : public UnaryOp<T> {
TensorFormat data_format_;
};
REGISTER_KERNEL_BUILDER(
Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<float>("T"),
AvgPoolingOp<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
AvgPoolingOp<CPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("AvgPool")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
AvgPoolingOp<CPUDevice, float>);
#if GOOGLE_CUDA
template <typename T>
@@ -183,17 +181,14 @@ namespace functor {
const Eigen::PaddingType& padding); \
extern template struct SpatialAvgPooling<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
} // namespace functor
REGISTER_KERNEL_BUILDER(
Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
AvgPoolingOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<float>("T"),
AvgPoolingOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("AvgPool")
.Device(DEVICE_GPU)
.TypeConstraint<float>("T"),
AvgPoolingOp<GPUDevice, float>);
#endif // GOOGLE_CUDA
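The DECLARE_GPU_SPEC lines above and the DEFINE_GPU_KERNELS lines in the .cu.cc hunk further down are two halves of one pattern; a minimal sketch of the split (the placement comments are descriptive, not file names taken from the diff):

    // Host-compiled .cc: promise that the specialization is instantiated
    // in another translation unit, so it is declared but not emitted here.
    extern template struct functor::SpatialAvgPooling<GPUDevice, Eigen::half>;

    // nvcc-compiled .cu.cc: the single real instantiation that the
    // REGISTER_KERNEL_BUILDER lines link against.
    template struct functor::SpatialAvgPooling<GPUDevice, Eigen::half>;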
// The operation to compute AvgPool gradients.
@@ -305,7 +300,7 @@ class AvgPoolingGradOp : public OpKernel {
GetBroadcastSize(c, in_cols, window_cols, col_stride,
pad_cols, &cindex, &csize));
T divide_coeff(1.0 / (rsize * csize));
T divide_coeff = 1.0 / (rsize * csize);
int64 output_index =
(b * out_backprop_rows + r) * out_backprop_cols + c;
for (int64 r_dst = rindex; r_dst < rindex + rsize; ++r_dst) {
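In the hunk above, copy-initialization from a double expression gives way to direct construction of T, presumably because Eigen::half's numeric constructors are explicit and reject the `T x = expr;` form. A standalone sketch of the idiom (hypothetical helper, assuming a reasonably recent Eigen):

    #include <iostream>
    #include <Eigen/Core>  // Eigen::half (header location varies by Eigen version)

    template <typename T>
    T MakeCoeff(double rsize, double csize) {
      // Direct construction compiles for float and Eigen::half alike;
      // "T c = 1.0 / (rsize * csize);" would need an implicit conversion
      // from double that Eigen::half does not provide.
      return T(1.0 / (rsize * csize));
    }

    int main() {
      std::cout << MakeCoeff<float>(3, 3) << "\n";
      std::cout << static_cast<float>(MakeCoeff<Eigen::half>(3, 3)) << "\n";
    }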
@@ -342,11 +337,6 @@ class AvgPoolingGradOp : public OpKernel {
TensorFormat data_format_;
};
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
.Device(DEVICE_CPU)
.TypeConstraint<Eigen::half>("T")
.HostMemory("orig_input_shape"),
AvgPoolingGradOp<CPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T")
@@ -426,12 +416,6 @@ REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
.HostMemory("orig_input_shape")
.Label("cudnn"),
AvgPoolingGradOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
.Device(DEVICE_GPU)
.TypeConstraint<Eigen::half>("T")
.HostMemory("orig_input_shape")
.Label("cudnn"),
AvgPoolingGradOp<GPUDevice, Eigen::half>);
// A custom GPU kernel based AvgPoolingGrad implementation. It includes the
// padding as the candidates for the pooling operation.
@@ -548,11 +532,6 @@ REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
.TypeConstraint<float>("T")
.HostMemory("orig_input_shape"),
AvgPoolingGradOpCustomGPUKernel<float>);
REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
.Device(DEVICE_GPU)
.TypeConstraint<Eigen::half>("T")
.HostMemory("orig_input_shape"),
AvgPoolingGradOpCustomGPUKernel<Eigen::half>);
#endif // GOOGLE_CUDA
......
@@ -33,7 +33,6 @@ typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_KERNELS(T) \
template struct functor::SpatialAvgPooling<GPUDevice, T>;
DEFINE_GPU_KERNELS(Eigen::half)
DEFINE_GPU_KERNELS(float)
#undef DEFINE_GPU_KERNELS
@@ -58,7 +57,7 @@ __global__ void AvePoolBackwardNHWC(const int nthreads,
const int phend = min(h / stride_h + 1, pooled_height);
const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int pwend = min(w / stride_w + 1, pooled_width);
dtype gradient(0);
dtype gradient = 0;
const dtype* const top_diff_slice =
top_diff + n * pooled_height * pooled_width * channels + c;
for (int ph = phstart; ph < phend; ++ph) {
@@ -105,12 +104,6 @@ template bool RunAvePoolBackwardNHWC(
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
float* const bottom_diff, const GPUDevice& d);
template bool RunAvePoolBackwardNHWC(
const Eigen::half* const top_diff, const int num, const int height,
const int width, const int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
Eigen::half* const bottom_diff, const GPUDevice& d);
} // end namespace tensorflow
......
@@ -309,7 +309,7 @@ struct AvgPoolMeanReducer {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE AvgPoolMeanReducer() : scalarCount_(0) {
typedef typename packet_traits<T>::type Packet;
packetCount_ = pset1<Packet>(T(0.0));
packetCount_ = pset1<Packet>(0.0);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) {
......
@@ -160,7 +160,7 @@ static void SpatialMaxPoolWithArgMaxHelper(
const int in_end = limit * in_size;
EigenMatrixMap in_shard(input_backprop_flat.data() + in_start, 1,
in_end - in_start);
in_shard.setConstant(T(0));
in_shard.setConstant(0);
// Backpropagate.
const int out_size = out_height * out_width * depth;
@@ -187,12 +187,8 @@ static void SpatialMaxPoolWithArgMaxHelper(
params.tensor_in_batch, shard_cost, shard);
}
REGISTER_KERNEL_BUILDER(
Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<float>("T"),
MaxPoolingOp<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
MaxPoolingOp<CPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_CPU),
MaxPoolingOp<CPUDevice, float>);
#if GOOGLE_CUDA
// Forward declarations for the functor specializations for GPU.
@@ -216,7 +212,6 @@ DECLARE_GPU_SPEC(float);
// kernel_label_map.
REGISTER_KERNEL_BUILDER(Name("MaxPool")
.Device(DEVICE_GPU)
.TypeConstraint<float>("T")
.Label("eigen_tensor"),
MaxPoolingOp<Eigen::GpuDevice, float>);
#endif // GOOGLE_CUDA
@@ -302,16 +297,11 @@ class MaxPoolingGradOp : public OpKernel {
TensorFormat data_format_;
};
REGISTER_KERNEL_BUILDER(
Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
MaxPoolingGradOp<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
MaxPoolingGradOp<CPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_CPU),
MaxPoolingGradOp<CPUDevice, float>);
#ifdef GOOGLE_CUDA
template <typename T>
static void MaxPoolingBackwardCustomKernel(
OpKernelContext* context, const std::vector<int32>& size,
const std::vector<int32>& stride, Padding padding, const Tensor* tensor_in,
@@ -328,12 +318,12 @@ static void MaxPoolingBackwardCustomKernel(
}
MaxPoolBackwardNoMask(
tensor_in->flat<T>().data(), params.tensor_in_batch,
tensor_in->flat<float>().data(), params.tensor_in_batch,
params.tensor_in_rows, params.tensor_in_cols, params.depth,
params.out_height, params.out_width, params.window_rows,
params.window_cols, params.row_stride, params.col_stride, params.pad_rows,
params.pad_cols, out_backprop.flat<T>().data(),
output->flat<T>().data(), context->eigen_device<Eigen::GpuDevice>());
params.pad_cols, out_backprop.flat<float>().data(),
output->flat<float>().data(), context->eigen_device<Eigen::GpuDevice>());
}
template <class T>
@@ -388,8 +378,8 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
} else {
CHECK(data_format_ == FORMAT_NHWC)
<< "Non-Cudnn MaxPoolGrad only supports NHWC format";
MaxPoolingBackwardCustomKernel<T>(context, ksize_, stride_, padding_,
&tensor_in, out_backprop, output_shape);
MaxPoolingBackwardCustomKernel(context, ksize_, stride_, padding_,
&tensor_in, out_backprop, output_shape);
}
}
@@ -401,12 +391,8 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
bool use_dnn_;
};
REGISTER_KERNEL_BUILDER(
Name("MaxPoolGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
MaxPoolingGradOp<Eigen::GpuDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("MaxPoolGrad").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
MaxPoolingGradOp<Eigen::GpuDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_GPU),
MaxPoolingGradOp<Eigen::GpuDevice, float>);
#endif // GOOGLE_CUDA
@@ -639,12 +625,8 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
}
};
REGISTER_KERNEL_BUILDER(
Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<float>("T"),
MaxPoolingNoMaskOp<Eigen::GpuDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
MaxPoolingNoMaskOp<Eigen::GpuDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_GPU),
MaxPoolingNoMaskOp<Eigen::GpuDevice, float>);
template <typename T>
struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
@@ -667,14 +649,8 @@ struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")
.Device(DEVICE_GPU)
.TypeConstraint<int64>("Targmax")
.TypeConstraint<float>("T"),
.TypeConstraint<int64>("Targmax"),
MaxPoolingWithArgmaxOp<Eigen::GpuDevice, float>);
REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")
.Device(DEVICE_GPU)
.TypeConstraint<int64>("Targmax")
.TypeConstraint<Eigen::half>("T"),
MaxPoolingWithArgmaxOp<Eigen::GpuDevice, Eigen::half>);
template <typename T>
struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
@@ -699,18 +675,10 @@ struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
}
};
REGISTER_KERNEL_BUILDER(
Name("MaxPoolGradWithArgmax")
.Device(DEVICE_GPU)
.TypeConstraint<float>("T")
.TypeConstraint<int64>("Targmax"),
MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("MaxPoolGradWithArgmax")
.Device(DEVICE_GPU)
.TypeConstraint<Eigen::half>("T")
.TypeConstraint<int64>("Targmax"),
MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")
.Device(DEVICE_GPU)
.TypeConstraint<int64>("Targmax"),
MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, float>);
#endif // GOOGLE_CUDA
......
@@ -110,7 +110,7 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
dtype maxval = Eigen::NumTraits<dtype>::lowest();
dtype maxval = -FLT_MAX;
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * height * width * channels;
for (int h = hstart; h < hend; ++h) {
@@ -149,7 +149,7 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
dtype maxval = Eigen::NumTraits<dtype>::lowest();
dtype maxval = -FLT_MAX;
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * height * width * channels;
for (int h = hstart; h < hend; ++h) {
@@ -165,8 +165,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
// Atomically accumulate the bottom diff. The index could still be
// uninitialized, if all the bottom_data are NaN.
if (maxidx != -1) {
CudaAtomicAdd(bottom_diff + n * height * width * channels + maxidx,
top_diff[index]);
atomicAdd(bottom_diff + n * height * width * channels + maxidx,
top_diff[index]);
}
}
}
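The change from raw atomicAdd to TensorFlow's CudaAtomicAdd wrapper is what lets these accumulating backward kernels stay type-generic: CUDA's built-in atomicAdd has no Eigen::half overload, so the wrapper falls back to a compare-and-swap loop for that type. A host-side C++ analogue of the CAS technique, as a sketch rather than the actual device code:

    #include <atomic>
    #include <cstdio>

    // Host-side analogue of an atomic add built from compare-and-swap, the
    // technique a device wrapper uses for types lacking a native atomicAdd.
    void AtomicAddViaCas(std::atomic<float>* target, float value) {
      float observed = target->load();
      // compare_exchange_weak reloads `observed` on failure, so the loop
      // retries until no other thread raced between our read and our write.
      while (!target->compare_exchange_weak(observed, observed + value)) {
      }
    }

    int main() {
      std::atomic<float> acc{0.0f};
      AtomicAddViaCas(&acc, 1.5f);
      AtomicAddViaCas(&acc, 2.5f);
      std::printf("%g\n", acc.load());  // prints 4
    }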
@@ -185,8 +185,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
// bottom_offset: the pre-computed per-image offset of the maxpool input.
// This is equal to H*W*C.
// bottom_diff: the gradient with respect to the input.
// This function relies on CudaAtomicAdd to avoid race conditions. Also, before
// the kernel is run, you will need to make sure that bottom_diff is filled with
// This function relies on atomicAdd to avoid race conditions. Also, before the
// kernel is run, you will need to make sure that bottom_diff is filled with
// zero first.
template <typename dtype>
__global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
@@ -194,8 +194,8 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
const int bottom_offset, dtype* bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int image_id = (index / top_offset);
CudaAtomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
top_diff[index]);
atomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
top_diff[index]);
}
}
@@ -219,23 +219,6 @@ bool MaxPoolForwardWithOptionalArgmax(
return d.ok();
}
bool MaxPoolForwardWithOptionalArgmax(
const Eigen::half* bottom_data, const int batch, const int height,
const int width, const int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
const int output_size = batch * channels * pooled_height * pooled_width;
MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(
output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask);
return d.ok();
}
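Each launcher in this file sizes its grid with the same ceil-division expression; stated once as a standalone helper (hypothetical name, not in the diff):

    // Smallest block count whose total thread count covers all work items:
    // ceil(work_items / threads_per_block) in integer arithmetic.
    inline int NumBlocks(int work_items, int threads_per_block) {
      return (work_items + threads_per_block - 1) / threads_per_block;
    }
    // NumBlocks(output_size, kThreadsPerBlock) reproduces the
    // (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock above.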
bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
const int height, const int width,
const int channels, const int pooled_height,
@@ -260,30 +243,6 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
return d.ok();
}
bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch,
const int height, const int width,
const int channels, const int pooled_height,
const int pooled_width, const int kernel_h,
const int kernel_w, const int stride_h,
const int stride_w, const int pad_t, const int pad_l,
const Eigen::half* top_diff, Eigen::half* bottom_diff,
const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
const int bottom_size = batch * channels * height * width;
const int top_size = batch * channels * pooled_height * pooled_width;
SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(
top_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_diff, bottom_diff);
return d.ok();
}
bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
const float* top_diff, const int64* mask,
const int top_offset, const int bottom_offset,
@@ -297,27 +256,12 @@ bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
return d.ok();
}
bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
const Eigen::half* top_diff, const int64* mask,
const int top_offset, const int bottom_offset,
Eigen::half* bottom_diff,
const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(
output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff);
return d.ok();
}
typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_KERNELS(T) \
template struct functor::SpatialMaxPooling<GPUDevice, T>;
DEFINE_GPU_KERNELS(float)
DEFINE_GPU_KERNELS(Eigen::half)
#undef DEFINE_GPU_KERNELS
......
@@ -37,24 +37,11 @@ bool MaxPoolForwardWithOptionalArgmax(
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
float* top_data, int64* mask, const Eigen::GpuDevice& d);
bool MaxPoolForwardWithOptionalArgmax(
const Eigen::half* bottom_data, const int batch, const int height,
const int width, const int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d);
bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
const float* top_diff, const int64* mask,
const int top_offset, const int bottom_offset,
float* bottom_diff, const Eigen::GpuDevice& d);
bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
const Eigen::half* top_diff, const int64* mask,
const int top_offset, const int bottom_offset,
Eigen::half* bottom_diff,
const Eigen::GpuDevice& d);
bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
const int height, const int width,
const int channels, const int pooled_height,
@@ -64,15 +51,6 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
const float* top_diff, float* bottom_diff,
const Eigen::GpuDevice& d);
bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch,
const int height, const int width,
const int channels, const int pooled_height,
const int pooled_width, const int kernel_h,
const int kernel_w, const int stride_h,
const int stride_w, const int pad_t, const int pad_l,
const Eigen::half* top_diff, Eigen::half* bottom_diff,
const Eigen::GpuDevice& d);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_
@@ -124,7 +124,6 @@ namespace functor {
extern template struct TransformDepth<GPUDevice, T, Eigen::DenseIndex>;
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
#undef DECLARE_GPU_SPEC
} // namespace functor
@@ -369,9 +368,7 @@ void DnnPoolingGradOp<T>::Compute(
}
}
template class DnnPoolingOp<Eigen::half>;
template class DnnPoolingOp<float>;
template class DnnPoolingGradOp<Eigen::half>;
template class DnnPoolingGradOp<float>;
#endif // GOOGLE_CUDA
......
@@ -311,7 +311,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
}
}
}
DCHECK_GT(out_count.minCoeff(), T(0));
DCHECK_GT(out_count.minCoeff(), 0);
out_mat.array().rowwise() /= out_count.transpose().array();
}
......
@@ -2933,63 +2933,6 @@ op {
}
}
}
op {
name: "AvgPool"
input_arg {
name: "value"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
type: DT_DOUBLE
}
}
}
}
op {
name: "AvgPool3D"
input_arg {
@@ -3211,67 +3154,6 @@ op {
}
}
}
op {
name: "AvgPoolGrad"
input_arg {
name: "orig_input_shape"
type: DT_INT32
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
type: DT_DOUBLE
}
}
}
}
op {
name: "BatchCholesky"
input_arg {
@@ -11785,124 +11667,6 @@ op {
}
}
}
op {
name: "MaxPool"
input_arg {
name: "input"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
}
op {
name: "MaxPool"
input_arg {
name: "input"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
}
op {
name: "MaxPool3D"
input_arg {
@@ -12116,73 +11880,6 @@ op {
}
}
}
op {
name: "MaxPoolGrad"
input_arg {
name: "orig_input"
type_attr: "T"
}
input_arg {
name: "orig_output"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
}
op {
name: "MaxPoolGradWithArgmax"
input_arg {
@@ -12234,70 +11931,6 @@ op {
}
}
}
op {
name: "MaxPoolGradWithArgmax"
input_arg {
name: "input"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
input_arg {
name: "argmax"
type_attr: "Targmax"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "Targmax"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
}
op {
name: "MaxPoolWithArgmax"
input_arg {
@@ -12348,69 +11981,6 @@ op {
}
}
}
op {
name: "MaxPoolWithArgmax"
input_arg {
name: "input"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
output_arg {
name: "argmax"
type_attr: "Targmax"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "Targmax"
type: "type"
default_value {
type: DT_INT64
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
}
op {
name: "Maximum"
input_arg {
......
@@ -154,25 +154,22 @@ Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FDH::Define(
// Arg defs
{"input: T", "grad: T"},
{"input: float", "grad: float"},
// Ret val defs
{"output: T"},
{"output: float"},
// Attr defs
{"T: {float, half} = DT_FLOAT",
"ksize: list(int) >= 4",
{"ksize: list(int) >= 4",
"strides: list(int) >= 4",
GetPaddingAttrString()},
// Nodes
{
// Invoke MaxPool again to recompute the outputs (removed by CSE?).
{{"maxpool"}, "MaxPool", {"input"},
/*Attrs=*/{{"T", "$T"},
{"ksize", "$ksize"},
/*Attrs=*/{{"ksize", "$ksize"},
{"strides", "$strides"},
{"padding", "$padding"}}},
{{"output"}, "MaxPoolGrad", {"input", "maxpool", "grad"},
/*Attrs=*/{{"T", "$T"},
{"ksize", "$ksize"},
/*Attrs=*/{{"ksize", "$ksize"},
{"strides", "$strides"},
{"padding", "$padding"}}}
});
......
@@ -28,7 +28,7 @@ REGISTER_OP("AvgPool")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {float, half, double}")
.Attr("T: {float, double}")
.Doc(R"doc(
Performs average pooling on the input.
@@ -55,7 +55,7 @@ REGISTER_OP("AvgPoolGrad")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {float, half, double}")
.Attr("T: {float, double}")
.Doc(R"doc(
Computes gradients of the average pooling function.
@@ -642,13 +625,12 @@ output: The gradients for LRN.
// --------------------------------------------------------------------------
REGISTER_OP("MaxPool")
.Attr("T: {float, half} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Input("input: T")
.Output("output: T")
.Input("input: float")
.Output("output: float")
.Doc(R"doc(
Performs max pooling on the input.
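The registration above swaps the fixed `input: float` signature for a `T` type attr with a DT_FLOAT default; the default is what keeps previously serialized GraphDefs, which carry no "T" attr, loading and running unchanged. A minimal sketch of the pattern with a hypothetical op name:

    // Hypothetical op: "= DT_FLOAT" means old graph nodes that predate the
    // attr still validate, with T filled in as float at load time.
    REGISTER_OP("ExamplePool")
        .Attr("T: {float, half} = DT_FLOAT")
        .Input("input: T")
        .Output("output: T");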
@@ -670,11 +669,10 @@ REGISTER_OP("MaxPoolGrad")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Input("orig_input: T")
.Input("orig_output: T")
.Input("grad: T")
.Output("output: T")
.Attr("T: {float, half} = DT_FLOAT")
.Input("orig_input: float")
.Input("orig_output: float")
.Input("grad: float")
.Output("output: float")
.Doc(R"doc(
Computes gradients of the maxpooling function.
@@ -698,10 +696,9 @@ REGISTER_OP("MaxPoolWithArgmax")
.Attr("strides: list(int) >= 4")
.Attr("Targmax: {int32, int64} = DT_INT64")
.Attr(GetPaddingAttrString())
.Input("input: T")
.Output("output: T")
.Input("input: float")
.Output("output: float")
.Output("argmax: Targmax")
.Attr("T: {float, half} = DT_FLOAT")
.Doc(R"doc(
Performs max pooling on the input and outputs both max values and indices.
@@ -723,11 +720,10 @@ REGISTER_OP("MaxPoolGradWithArgmax")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr("Targmax: {int32, int64}")
.Input("input: T")
.Input("grad: T")
.Input("input: float")
.Input("grad: float")
.Input("argmax: Targmax")
.Output("output: T")
.Attr("T: {float, half} = DT_FLOAT")
.Output("output: float")
.Doc(R"doc(
Computes gradients of the maxpooling function.
......
@@ -1170,7 +1170,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
type: DT_DOUBLE
}
}
@@ -1367,7 +1366,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
type: DT_DOUBLE
}
}
@@ -6472,25 +6470,12 @@ op {
input_arg {
name: "input"
description: "4-D input to pool over."
type_attr: "T"
type: DT_FLOAT
}
output_arg {
name: "output"
description: "The max pooled output tensor."
type_attr: "T"
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
type: DT_FLOAT
}
attr {
name: "ksize"
@@ -6669,22 +6654,22 @@ op {
input_arg {
name: "orig_input"
description: "The original input tensor."
type_attr: "T"
type: DT_FLOAT
}
input_arg {
name: "orig_output"
description: "The original output tensor."
type_attr: "T"
type: DT_FLOAT
}
input_arg {
name: "grad"
description: "4-D. Gradients w.r.t. the output of `max_pool`."
type_attr: "T"
type: DT_FLOAT
}
output_arg {
name: "output"
description: "Gradients w.r.t. the input to `max_pool`."
type_attr: "T"
type: DT_FLOAT
}
attr {
name: "ksize"
@@ -6725,19 +6710,6 @@ op {
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
summary: "Computes gradients of the maxpooling function."
}
op {
@@ -6745,12 +6717,12 @@ op {
input_arg {
name: "input"
description: "The original input."
type_attr: "T"
type: DT_FLOAT
}
input_arg {
name: "grad"
description: "4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the\noutput of `max_pool`."
type_attr: "T"
type: DT_FLOAT
}
input_arg {
name: "argmax"
@@ -6760,7 +6732,7 @@ op {
output_arg {
name: "output"
description: "Gradients w.r.t. the input of `max_pool`."
type_attr: "T"
type: DT_FLOAT
}
attr {
name: "ksize"
@@ -6797,19 +6769,6 @@ op {
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
summary: "Computes gradients of the maxpooling function."
}
op {
@@ -6817,12 +6776,12 @@ op {
input_arg {
name: "input"
description: "4-D with shape `[batch, height, width, channels]`. Input to pool over."
type_attr: "T"
type: DT_FLOAT
}
output_arg {
name: "output"
description: "The max pooled output tensor."
type_attr: "T"
type: DT_FLOAT
}
output_arg {
name: "argmax"
@@ -6867,19 +6826,6 @@ op {
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
summary: "Performs max pooling on the input and outputs both max values and indices."
description: "The indices in `argmax` are flattened, so that a maximum value at position\n`[b, y, x, c]` becomes flattened index\n`((b * height + y) * width + x) * channels + c`."
}
......
@@ -99,8 +99,8 @@ def GetShrunkInceptionMaxPoolShapes(shrink=30):
class PoolingTest(tf.test.TestCase):
def _VerifyOneType(self, pool_func, input_sizes, ksize, strides, padding,
data_format, data_type, expected, use_gpu):
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
data_format, expected, use_gpu):
"""Verifies the output values of the pooling function.
Args:
@@ -111,7 +111,6 @@ class PoolingTest(tf.test.TestCase):
strides: The stride dimensions
padding: Padding type.
data_format: The data format we use to run the pooling operation.
data_type: The data type to use to run the pooling operation.
expected: An array containing the expected operation outputs.
use_gpu: Whether we are running on GPU.
"""
@@ -122,7 +121,7 @@ class PoolingTest(tf.test.TestCase):
# numbers from 1.
x = [f * 1.0 for f in range(1, total_size + 1)]
with self.test_session(use_gpu=use_gpu) as sess:
t = tf.constant(x, shape=input_sizes, dtype=data_type)
t = tf.constant(x, shape=input_sizes)
if data_format == "NCHW":
t = NHWCToNCHW(t)
ksize = NHWCToNCHW(ksize)
@@ -132,31 +131,9 @@ class PoolingTest(tf.test.TestCase):
if data_format == "NCHW":
t = NCHWToNHWC(t)
actual = t.eval()
self.assertAllCloseAccordingToType(expected, actual.flatten())
self.assertAllClose(expected, actual.flatten())
self.assertShapeEqual(actual, t)
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
data_format, expected, use_gpu):
"""Verifies the output values of the pooling function.
Args:
pool_func: Function to be called, co.MaxPool, co.AvgPool,
or the Lua version.
input_sizes: Input tensor dimensions.
ksize: The kernel size dimensions
strides: The stride dimensions
padding: Padding type.
data_format: The data format we use to run the pooling operation.
expected: An array containing the expected operation outputs.
use_gpu: Whether we are running on GPU.
"""
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, tf.float32, expected, use_gpu)
if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv():
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, tf.float16, expected, use_gpu)
def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
expected, use_gpu):
"""Verifies the output values of the pooling function.
@@ -395,41 +372,33 @@ class PoolingTest(tf.test.TestCase):
def testKernelSmallerThanStrideValid(self):
for use_gpu in [True, False]:
self._VerifyValues(tf.nn.max_pool,
input_sizes=[1, 7, 7, 1],
ksize=[1, 2, 2, 1],
strides=[1, 3, 3, 1],
padding="VALID",
expected=[9, 12, 30, 33],
use_gpu=use_gpu)
self._VerifyValues(tf.nn.avg_pool,
input_sizes=[1, 7, 7, 1],
ksize=[1, 2, 2, 1],
strides=[1, 3, 3, 1],
padding="VALID",
expected=[5, 8, 26, 29],
use_gpu=use_gpu)
def testKernelSmallerThanStrideSame(self):
for use_gpu in [True, False]:
for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
self._VerifyValues(pool_func,
input_sizes=[1, 3, 3, 1],
ksize=[1, 1, 1, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=[1, 3, 7, 9],
self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 7, 7, 1],
ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1],
padding="VALID",
expected=[9, 12, 30, 33],
use_gpu=use_gpu)
self._VerifyValues(pool_func,
input_sizes=[1, 4, 4, 1],
ksize=[1, 1, 1, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=[1, 3, 9, 11],
self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 7, 7, 1],
ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1],
padding="VALID",
expected=[5, 8, 26, 29],
use_gpu=use_gpu)
def testKernelSmallerThanStrideSame(self):
for use_gpu in [True, False]:
for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
self._VerifyValues(pool_func, input_sizes=[1, 3, 3, 1],
ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1],
padding="SAME",
expected=[1, 3, 7, 9],
use_gpu=use_gpu)
self._VerifyValues(pool_func, input_sizes=[1, 4, 4, 1],
ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1],
padding="SAME",
expected=[1, 3, 9, 11],
use_gpu=use_gpu)
def _testDepthwiseMaxPoolInvalidConfig(self, in_size, ksize, strides,
error_msg, use_gpu=False):
t = tf.constant(1.0, shape=in_size)
@@ -456,50 +425,43 @@ class PoolingTest(tf.test.TestCase):
# The following are tests that verify that the CPU and GPU implementations
# produce the same results.
def _CompareMaxPoolingFwd(self, input_shape, ksize, strides, padding):
for dtype in np.float32, np.float16:
tensor_input = np.random.rand(*input_shape).astype(dtype)
with self.test_session(use_gpu=True):
t = tf.constant(tensor_input, shape=input_shape)
out_op, _ = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
gpu_val = out_op.eval()
with self.test_session(use_gpu=False):
t = tf.constant(tensor_input, shape=input_shape)
out_op = tf.nn.max_pool(t, ksize, strides, padding)
cpu_val = out_op.eval()
self.assertAllCloseAccordingToType(cpu_val, gpu_val)
tensor_input = np.random.rand(*input_shape).astype(np.float32)
with self.test_session(use_gpu=True):
t = tf.constant(tensor_input, shape=input_shape)
out_op, _ = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
gpu_val = out_op.eval()
with self.test_session(use_gpu=False):
t = tf.constant(tensor_input, shape=input_shape)
out_op = tf.nn.max_pool(t, ksize, strides, padding)
cpu_val = out_op.eval()
self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5)
def _CompareMaxPoolingBk(self, input_shape, output_shape, ksize, strides,
padding):
for dtype in np.float32, np.float16:
# Generate numbers in a narrow range, so that there are many duplicates
# in the input.
tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)
tensor_output = np.random.rand(*output_shape).astype(dtype)
with self.test_session(use_gpu=True):
t = tf.constant(tensor_input, shape=input_shape)
_, argmax_op = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
argmax = argmax_op.eval()
grad_in = tf.constant(tensor_output, shape=output_shape)
out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax,
ksize, strides, padding)
gpu_val = out_op.eval()
self.assertShapeEqual(gpu_val, out_op)
with self.test_session(use_gpu=False):
t = tf.constant(tensor_input, shape=input_shape)
out_op = tf.nn.max_pool(t, ksize, strides, padding)
orig_out = out_op.eval()
grad_in = tf.constant(tensor_output, shape=output_shape)
out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize, strides,
padding)
cpu_val = out_op.eval()
self.assertShapeEqual(cpu_val, out_op)
if dtype == np.float16:
# The CPU version accumulates its gradient on fp16, so it's less
# accurate than the GPU version that does the accumulation on fp32
self.assertAllClose(cpu_val, gpu_val, rtol=0.01, atol=0.01)
else:
self.assertAllClose(cpu_val, gpu_val)
# Generate numbers in a narrow range, so that there are many duplicates
# in the input.
tensor_input = np.random.random_integers(0, 3,
input_shape).astype(np.float32)
tensor_output = np.random.rand(*output_shape).astype(np.float32)
with self.test_session(use_gpu=True):
t = tf.constant(tensor_input, shape=input_shape)
_, argmax_op = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
argmax = argmax_op.eval()
grad_in = tf.constant(tensor_output, shape=output_shape)
out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax,
ksize, strides, padding)
gpu_val = out_op.eval()
self.assertShapeEqual(gpu_val, out_op)
with self.test_session(use_gpu=False):
t = tf.constant(tensor_input, shape=input_shape)
out_op = tf.nn.max_pool(t, ksize, strides, padding)
orig_out = out_op.eval()
grad_in = tf.constant(tensor_output, shape=output_shape)
out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize,
strides, padding)
cpu_val = out_op.eval()
self.assertShapeEqual(cpu_val, out_op)
self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5)
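The new dtype loop above loosens the CPU/GPU comparison tolerance to 0.01 for np.float16 because, as its comment notes, the CPU kernel accumulates the gradient in fp16 while the GPU accumulates in fp32. A small standalone C++ illustration of the drift (a demo under that assumption, not test code):

    #include <iostream>
    #include <Eigen/Core>  // Eigen::half (header location varies by Eigen version)

    // Summing many small terms in half precision drifts visibly: each
    // addition rounds to ~11 bits of mantissa, so the fp16 accumulator
    // lags the float one well before 1000 iterations.
    int main() {
      Eigen::half acc_h(0.0f);
      float acc_f = 0.0f;
      for (int i = 0; i < 1000; ++i) {
        acc_h = acc_h + Eigen::half(0.01f);
        acc_f += 0.01f;
      }
      std::cout << static_cast<float>(acc_h) << " vs " << acc_f << "\n";
    }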
def testMaxPoolingWithArgmax(self):
# MaxPoolWithArgMax is implemented only on GPU.
......
@@ -1874,40 +1874,6 @@ bool CudnnSupport::DoPoolForward(
return true;
}
bool CudnnSupport::DoPoolForward(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<Eigen::half>* output_data) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
}
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
status = dynload::cudnnPoolingForward(
parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue forward pooling on stream: "
<< ToString(status);
return false;
}
return true;
}
bool CudnnSupport::DoPoolBackward(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
@@ -1946,43 +1912,6 @@ bool CudnnSupport::DoPoolBackward(
return true;
}
bool CudnnSupport::DoPoolBackward(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<Eigen::half>& output_data,
const DeviceMemory<Eigen::half>& input_diff_data,
DeviceMemory<Eigen::half>* output_diff_data) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
return false;
}
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
status = dynload::cudnnPoolingBackward(
parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
src_desc.handle(), output_diff_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "failed to enqueue backward pooling on stream: "
<< ToString(status);
return false;
}
return true;
}
bool CudnnSupport::DoNormalize(
Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
......
@@ -201,13 +201,6 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) override;
bool DoPoolForward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<Eigen::half>* output_data) override;
bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
@@ -217,15 +210,6 @@ class CudnnSupport : public dnn::DnnSupport {
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data) override;
bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<Eigen::half>& output_data,
const DeviceMemory<Eigen::half>& input_diff_data,
DeviceMemory<Eigen::half>* output_diff_data) override;
bool DoNormalize(Stream* stream,
const dnn::NormalizeDescriptor& normalize_descriptor,
const DeviceMemory<float>& input_data,
......
@@ -1011,13 +1011,6 @@ class DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) = 0;
virtual bool DoPoolForward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<Eigen::half>* output_data) = 0;
// Performs differentiation of the pooling operation.
virtual bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
@@ -1028,15 +1021,6 @@ class DnnSupport {
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data) = 0;
virtual bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<Eigen::half>& output_data,
const DeviceMemory<Eigen::half>& input_diff_data,
DeviceMemory<Eigen::half>* output_diff_data) = 0;
// Applies local response normalization to the values from
// input_data and writes the result to output_data. See comments on
// NormalizeDescriptor for a description of local response
......
@@ -909,30 +909,6 @@ Stream &Stream::ThenPoolForward(
return *this;
}
Stream &Stream::ThenPoolForward(
const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<Eigen::half> *output_data) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
input_data, output_dimensions,
output_data));
} else {
SetError();
LOG(WARNING)
<< "attempting to perform DNN operation using StreamExecutor "
"without DNN support";
}
}
return *this;
}
Stream &Stream::ThenPoolBackward(
const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
@@ -960,33 +936,6 @@ Stream &Stream::ThenPoolBackward(
return *this;
}
Stream &Stream::ThenPoolBackward(
const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<Eigen::half> &output_data,
const DeviceMemory<Eigen::half> &input_diff_data,
DeviceMemory<Eigen::half> *output_diff_data) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
PARAM(input_diff_data), PARAM(output_diff_data));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
input_diff_data, output_diff_data));
} else {
SetError();
LOG(WARNING)
<< "attempting to perform DNN operation using StreamExecutor "
"without DNN support";
}
}
return *this;
}
Stream &Stream::ThenNormalize(
const dnn::NormalizeDescriptor &normalize_descriptor,
const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
......
@@ -418,12 +418,6 @@ class Stream {
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<float> *output_data);
Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<Eigen::half> *output_data);
Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<float> &input_data,
@@ -432,14 +426,6 @@ class Stream {
const DeviceMemory<float> &input_diff_data,
DeviceMemory<float> *output_diff_data);
Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<Eigen::half> &output_data,
const DeviceMemory<Eigen::half> &input_diff_data,
DeviceMemory<Eigen::half> *output_diff_data);
Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor,
const DeviceMemory<float> &input_data,
DeviceMemory<float> *output_data);
......
@@ -6,8 +6,8 @@
def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.new_http_archive(
name = "eigen_archive",
url = "https://bitbucket.org/eigen/eigen/get/0c0b79ecd74c.tar.gz",
sha256 = "b4b5884b03bd4bae114d02b36e2435ad1504ed8e51431d16c876b6f6a365882b",
url = "https://bitbucket.org/eigen/eigen/get/d02e6a705c30.tar.gz",
sha256 = "532956172daa8aba87c750791ff89a5c38cdb07e2525afe17ecb4bef812d67cf",
build_file = path_prefix + "eigen.BUILD",
)
@@ -178,3 +178,4 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
sha256 = "36658cb768a54c1d4dec43c3116c27ed893e88b02ecfcb44f2166f9c0b7f2a0d",
build_file = path_prefix + "zlib.BUILD",
)
#include "eigen-eigen-0c0b79ecd74c/Eigen/Cholesky"
#include "eigen-eigen-d02e6a705c30/Eigen/Cholesky"
#include "eigen-eigen-0c0b79ecd74c/Eigen/Core"
#include "eigen-eigen-d02e6a705c30/Eigen/Core"
#include "eigen-eigen-0c0b79ecd74c/Eigen/Eigenvalues"
#include "eigen-eigen-d02e6a705c30/Eigen/Eigenvalues"
#include "eigen-eigen-0c0b79ecd74c/Eigen/LU"
#include "eigen-eigen-d02e6a705c30/Eigen/LU"
#include "eigen-eigen-0c0b79ecd74c/Eigen/QR"
#include "eigen-eigen-d02e6a705c30/Eigen/QR"
#include "eigen-eigen-0c0b79ecd74c/unsupported/Eigen/CXX11/Tensor"
#include "eigen-eigen-d02e6a705c30/unsupported/Eigen/CXX11/Tensor"