diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 9c67df7bdfb2c4e5d1c9fe60676c412ab11b4fa5..a84b3bccefa3f646dc3310f22711bdd76695b20f 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -301,6 +301,7 @@ op_library(fusion_lstm_op DEPS cpu_lstm_compute) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(layer_norm_op DEPS cub) + op_library(reduce_mean_op DEPS cub) else() op_library(conv_op DEPS vol2col im2col) endif() diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index b3140116dfe6a17a400bb88219ff43b249ecb32a..ef76106f17218a03d24ebc0eca43dbb0ae935093 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -380,7 +380,8 @@ class DepthwiseConvKernel : public framework::OpKernel { math::DepthwiseConvFunctor depthwiseConv; auto& dev_ctx = context.template device_context(); - depthwiseConv(dev_ctx, *input, filter, strides, paddings, output); + depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations, + output); } }; @@ -415,14 +416,14 @@ class DepthwiseConvGradKernel : public framework::OpKernel { input_grad->mutable_data(context.GetPlace()); set_zero(dev_ctx, input_grad, static_cast(0)); depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides, - paddings, input_grad); + paddings, dilations, input_grad); } if (filter_grad) { filter_grad->mutable_data(context.GetPlace()); set_zero(dev_ctx, filter_grad, static_cast(0)); depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides, paddings, - filter_grad); + dilations, filter_grad); } } }; diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index 0d9c6a62fec1ea24bee5c24b4a7b792781f14d9e..88c578b1410558b9adcd55f1cd6b53fb9cb124e2 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -345,7 +345,7 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel { math::DepthwiseConvInputGradFunctor depthwiseConvInputGrad; depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings, - output); + dilations, output); } }; @@ -367,10 +367,11 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); if (input_grad) { math::DepthwiseConvFunctor depthwiseConv; - depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings, + depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings, dilations, input_grad); } @@ -382,7 +383,7 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel { math::DepthwiseConvFilterGradFunctor depthwiseConvFilterGrad; depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings, - filter_grad); + dilations, filter_grad); } } }; diff --git a/paddle/fluid/operators/cub_reduce.h b/paddle/fluid/operators/cub_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..16fdad775f7befaac04b1ac59a601f04e0ab2bdc --- /dev/null +++ b/paddle/fluid/operators/cub_reduce.h @@ -0,0 +1,322 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include // NOLINT +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { + +namespace detail { +template +struct Array { + public: + HOSTDEVICE inline Array() {} + + HOSTDEVICE inline T& operator[](size_t index) { return data_[index]; } + + HOSTDEVICE inline const T& operator[](size_t index) const { + return data_[index]; + } + + HOSTDEVICE constexpr inline size_t size() const { return ElementCount; } + + template + static inline Array From(const VectorLikeType& vec) { + PADDLE_ENFORCE_EQ(vec.size(), ElementCount, "size not match"); + size_t n = static_cast(vec.size()); + Array ret; + for (size_t i = 0; i < n; ++i) ret[i] = vec[i]; + return ret; + } + + private: + T data_[ElementCount]; +}; + +// reduce the last axis of 2d array +template +__global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer, + TransformOp transformer, Ty init, + int reduce_num) { + __shared__ typename cub::BlockReduce::TempStorage temp_storage; + int idx_x = blockIdx.x * reduce_num; + int idx_y = threadIdx.x; + Ty reduce_var = init; + for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) + reduce_var = reducer(reduce_var, transformer(x[idx_x + idx_y])); + + reduce_var = + cub::BlockReduce(temp_storage).Reduce(reduce_var, reducer); + + if (threadIdx.x == 0) { + y[blockIdx.x] = reduce_var; + } +} + +template +__global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer, + TransformOp transformer, Ty init, int reduce_num, + Array x_strides, + Array reduce_dim, + Array reduce_strides, + Array left_dim, + Array left_strides) { + __shared__ typename cub::BlockReduce::TempStorage temp_storage; + Array sub_index; + int left_idx = blockIdx.x; + for (int i = 0; i < Rank - ReduceRank; ++i) { + sub_index[left_dim[i]] = left_idx / left_strides[i]; + left_idx %= left_strides[i]; + } + + int reduce_idx = threadIdx.x; + for (int j = 0; j < ReduceRank; ++j) { + sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j]; + reduce_idx %= reduce_strides[j]; + } + + int idx_x = 0; + for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]); + Ty reduce_var = static_cast(transformer(x[idx_x])); + + for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) { + int reduce_idx = i; + for (int j = 0; j < ReduceRank; ++j) { + sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j]; + reduce_idx %= reduce_strides[j]; + } + + int idx_x = 0; + for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]); + reduce_var = static_cast(reducer(reduce_var, transformer(x[idx_x]))); + } + + reduce_var = + cub::BlockReduce(temp_storage).Reduce(reduce_var, reducer); + + if (threadIdx.x == 0) { + y[blockIdx.x] = reduce_var; + } +} + +static inline std::vector GetStrides(const std::vector& dims) { + int n = static_cast(dims.size()); + if (n == 0) return std::vector(); + std::vector strides(n); + strides.back() = 1; + for (int i = n - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * dims[i + 1]; + } + return strides; +} + +static inline std::vector GetStrides(const std::vector& dims, + const std::vector& idx) { + int n = static_cast(idx.size()); + if (n == 0) return std::vector(); + std::vector strides(n); + strides.back() = 1; + for (int i = n - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * dims[idx[i + 1]]; + } + return strides; +} + +constexpr int kMaxBlockDim = 512; + +static inline int GetDesiredBlockDim(int block_dim) { + return block_dim >= kMaxBlockDim + ? kMaxBlockDim + : (1 << static_cast(std::log2(block_dim))); +} + +template +static void TensorReduceImpl( + const Tx* x_data, Ty* y_data, const platform::Place& place, + const ReduceOp& reducer, const TransformOp& transformer, const Ty& init, + int left_num, int reduce_num, const std::vector& x_strides, + const std::vector& reduce_dim, const std::vector& reduce_strides, + const std::vector& left_dim, const std::vector& left_strides, + cudaStream_t stream) { +#define CUB_RANK_CASE(i, ...) \ + case i: { \ + constexpr auto kRank = i; \ + switch (reduce_rank) { __VA_ARGS__; } \ + } break + +#define CUB_REDUCE_RANK_CASE(i, ...) \ + case i: { \ + constexpr auto kReduceRank = i; \ + ReduceKernel<<>>( \ + x_data, y_data, reducer, transformer, init, reduce_num, \ + Array::From(x_strides), \ + Array::From(reduce_dim), \ + Array::From(reduce_strides), \ + Array::From(left_dim), \ + Array::From(left_strides)); \ + } break + + int rank = x_strides.size(); + int reduce_rank = reduce_strides.size(); + if (rank == reduce_rank) { + cub::TransformInputIterator trans_x( + x_data, transformer); + size_t temp_storage_bytes = 0; + cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data, + reduce_num, reducer, init, stream); + framework::Tensor tmp; + auto* temp_storage = tmp.mutable_data( + framework::make_ddim({static_cast(temp_storage_bytes)}), + place); + cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data, + reduce_num, reducer, init, stream); + return; + } + if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) { + ReduceKernel2D<<>>( + x_data, y_data, reducer, transformer, init, reduce_num); + return; + } + /* + if (rank == 3 && reduce_rank == 1 && reduce_dim[0] == 1) { + // TODO(liangdun): we can optimize 3d case which the 2nd axis is reduced. + // Currently, it is handled by code below, but inefficient + return; + } + */ + + switch (rank) { + CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1);); + + CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);); + + CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2); + CUB_REDUCE_RANK_CASE(3);); + + CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2); + CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);); + + CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2); + CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4); + CUB_REDUCE_RANK_CASE(5);); + + CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2); + CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4); + CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6);); + + CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2); + CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4); + CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6);); + + CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2); + CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4); + CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6); + CUB_REDUCE_RANK_CASE(7); CUB_REDUCE_RANK_CASE(8);); + } + +#undef CUB_REDUCE_RANK_CASE +#undef CUB_RANK_CASE +} + +} // namespace detail + +template +void TensorReduce(const framework::Tensor& x, framework::Tensor* y, + std::vector origin_reduce_dims, const Ty& init, + const ReduceOp& reducer, const TransformOp& transformer, + cudaStream_t stream) { + auto x_dim = framework::vectorize2int(x.dims()); + std::vector new_x_dim, new_reduce_dims; + int is_reduced = 0; + for (auto e : origin_reduce_dims) { + auto pos = e >= 0 ? e : e + x_dim.size(); + is_reduced |= 1 << e; + } + for (int i = 0; i < x_dim.size(); i++) { + if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) { + new_x_dim.push_back(x_dim[i]); + if ((is_reduced >> i) & 1) + new_reduce_dims.push_back(new_x_dim.size() - 1); + } else { + new_x_dim[new_x_dim.size() - 1] *= x_dim[i]; + } + } + x_dim = new_x_dim; + origin_reduce_dims = new_reduce_dims; + int x_rank = static_cast(x_dim.size()); + std::set left_set, reduce_set; + for (int i = 0; i < x_rank; ++i) left_set.insert(i); + + for (auto e : origin_reduce_dims) { + left_set.erase(e); + reduce_set.insert(e); + } + + std::vector reduce_dim(reduce_set.begin(), reduce_set.end()); + std::vector left_dim(left_set.begin(), left_set.end()); + + std::vector x_strides = detail::GetStrides(x_dim); + std::vector reduce_strides = detail::GetStrides(x_dim, reduce_dim); + std::vector left_strides = detail::GetStrides(x_dim, left_dim); + int reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]]; + int left_num = 1; + if (left_dim.size()) left_num = left_strides[0] * x_dim[left_dim[0]]; + + std::vector y_dim(left_dim.size()); + for (int i = 0; i < left_dim.size(); ++i) { + y_dim[i] = x_dim[left_dim[i]]; + } + auto x_data = x.data(); + auto y_data = y->mutable_data(x.place()); + if (reduce_num == 1) return; + +#define CUB_BLOCK_DIM_CASE(block_dim) \ + case block_dim: { \ + constexpr auto kBlockDim = block_dim; \ + detail::TensorReduceImpl( \ + x_data, y_data, x.place(), reducer, transformer, init, left_num, \ + reduce_num, x_strides, reduce_dim, reduce_strides, left_dim, \ + left_strides, stream); \ + } break + + switch (detail::GetDesiredBlockDim(reduce_num)) { + CUB_BLOCK_DIM_CASE(512); + CUB_BLOCK_DIM_CASE(256); + CUB_BLOCK_DIM_CASE(128); + CUB_BLOCK_DIM_CASE(64); + CUB_BLOCK_DIM_CASE(32); + CUB_BLOCK_DIM_CASE(16); + CUB_BLOCK_DIM_CASE(8); + CUB_BLOCK_DIM_CASE(4); + CUB_BLOCK_DIM_CASE(2); + } +#undef CUB_BLOCK_DIM_CASE +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index 027e2de48d229761f12f974dc73625c8ea1b3567..3be389912307f7aac6dda6d1018943eb8f08696d 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/platform/cuda_primitives.h" @@ -20,149 +21,268 @@ namespace paddle { namespace operators { namespace math { +template +__inline__ __device__ T warpReduceSum(T val) { +#if CUDA_VERSION < 9000 + for (int offset = 16; offset > 0; offset /= 2) + val += __shfl_down(val, offset); + return val; +#else +#define FULL_MASK 0xffffffff + for (int offset = 16; offset > 0; offset /= 2) + val += __shfl_down_sync(FULL_MASK, val, offset); + return val; +#endif +} +__forceinline__ __device__ unsigned lane_id() { + unsigned ret; + asm volatile("mov.u32 %0, %laneid;" : "=r"(ret)); + return ret; +} + +__forceinline__ __device__ unsigned warp_id() { + unsigned ret; + asm volatile("mov.u32 %0, %warpid;" : "=r"(ret)); + return ret; +} + // A Cuda kernel to compute the depthwise convolution forward pass // in NCHW format. template -__global__ void KernelDepthwiseConv( - const int nthreads, const T* const input_data, const T* const filter_data, - const int batch_size, const int output_channels, const int output_height, - const int output_width, const int input_channels, const int input_height, - const int input_width, const int filter_multiplier, const int filter_height, +__device__ __inline__ void KernelDepthwiseConv( + const T* const input_data, const T* const filter_data, const int batch_size, + const int output_channels, const int output_height, const int output_width, + const int input_channels, const int input_height, const int input_width, + const int filter_multiplier, const int filter_height, const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, T* const output_data) { - int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - - if (index < nthreads) { - const int batch = index / output_channels / output_height / output_width; - const int c_out = (index / output_height / output_width) % output_channels; - const int h_out = (index / output_width) % output_height; - const int w_out = index % output_width; - - const int c_in = c_out / filter_multiplier; - const T* weight = filter_data + c_out * filter_height * filter_width; - T value = 0; - const int h_in_start = -padding_height + h_out * stride_height; - const int w_in_start = -padding_width + w_out * stride_width; - const int h_in_end = h_in_start + filter_height; - const int w_in_end = w_in_start + filter_width; - - const int in_offset = - ((batch * input_channels + c_in) * input_height) * input_width; - - const int h_end = h_in_end < input_height ? h_in_end : input_height; - const int w_end = w_in_end < input_width ? w_in_end : input_width; - const int h_start = h_in_start > 0 ? h_in_start : 0; - const int w_start = w_in_start > 0 ? w_in_start : 0; - - for (int h_in = h_start; h_in < h_end; h_in++) { - for (int w_in = w_start; w_in < w_end; w_in++) { - const int offset = in_offset + h_in * input_width + w_in; - value += - weight[(h_in - h_in_start) * filter_width + (w_in - w_in_start)] * - input_data[offset]; + const int padding_height, const int padding_width, const int dilate_height, + const int dilate_width, T* const output_data) { + for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { + for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { + const int batch = blockIdx.y; + const int c_out = blockIdx.x; + + const int c_in = c_out / filter_multiplier; + const T* weight = filter_data + c_out * filter_height * filter_width; + T value = 0; + const int h_in_start = -padding_height + h_out * stride_height; + const int w_in_start = -padding_width + w_out * stride_width; + const int h_in_end = h_in_start + filter_height * dilate_height; + const int w_in_end = w_in_start + filter_width * dilate_width; + + const int in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; + + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? w_in_start : 0; + int weight_offset = 0; + + for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) { + for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { + if (h_in >= h_start && h_in < h_end && w_in >= w_start && + w_in < w_end) { + const int offset = in_offset + h_in * input_width + w_in; + value += weight[weight_offset] * input_data[offset]; + } + weight_offset++; + } } + int index = + ((batch * gridDim.x + c_out) * output_height + h_out) * output_width + + w_out; + output_data[index] = value; } - output_data[index] = value; } } +template +__global__ void KernelDepthwiseConvSp( + const T* const input_data, const T* const filter_data, const int batch_size, + const int output_channels, const int output_height, const int output_width, + const int input_channels, const int input_height, const int input_width, + const int filter_multiplier, const int filter_height, + const int filter_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width, const int dilate_height, + const int dilate_width, T* const output_data) { + if (c_filter_multiplier == 0) + KernelDepthwiseConv(input_data, filter_data, batch_size, output_channels, + output_height, output_width, input_channels, + input_height, input_width, filter_multiplier, + filter_height, filter_width, stride_height, + stride_width, padding_height, padding_width, + dilate_height, dilate_width, output_data); + + else + KernelDepthwiseConv(input_data, filter_data, batch_size, output_channels, + output_height, output_width, input_channels, + input_height, input_width, c_filter_multiplier, + filter_height, filter_height, c_stride, c_stride, + padding_height, padding_width, dilate_height, + dilate_width, output_data); +} + // CUDA kernel to compute the depthwise convolution backprop w.r.t input. template -__global__ void KernelDepthwiseConvInputGrad( - const int nthreads, const T* const output_grad_data, - const T* const filter_data, const int batch_size, const int output_channels, - const int output_height, const int output_width, const int input_channels, - const int input_height, const int input_width, const int filter_multiplier, - const int filter_height, const int filter_width, const int stride_height, - const int stride_width, const int padding_height, const int padding_width, - T* const input_grad_data) { - int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - if (index < nthreads) { - const int batch = index / input_channels / input_height / input_width; - const int c_in = (index / input_height / input_width) % input_channels; - const int h_in = (index / input_width) % input_height; - const int w_in = index % input_width; - - const int c_out_start = c_in * filter_multiplier; - - int h_out_start = - (h_in - filter_height + padding_height + stride_height) / stride_height; - h_out_start = 0 > h_out_start ? 0 : h_out_start; - - int h_out_end = (h_in + padding_height) / stride_height; - h_out_end = output_height - 1 < h_out_end ? output_height - 1 : h_out_end; - - int w_out_start = - (w_in - filter_width + padding_width + stride_width) / stride_width; - w_out_start = 0 > w_out_start ? 0 : w_out_start; - - int w_out_end = (w_in + padding_width) / stride_width; - w_out_end = output_width - 1 < w_out_end ? output_width - 1 : w_out_end; - - T value = 0; - - for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier; - c_out++) { - for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) { - const int filter_h = h_in + padding_height - h_out * stride_height; - for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) { - const int filter_w = w_in + padding_width - w_out * stride_width; - const int filter_offset = c_out * filter_height * filter_width + - filter_h * filter_width + filter_w; - const int output_grad_offset = - ((batch * output_channels + c_out) * output_height + h_out) * - output_width + - w_out; - value += - output_grad_data[output_grad_offset] * filter_data[filter_offset]; +__device__ __inline__ void KernelDepthwiseConvInputGrad( + const T* const output_grad_data, const T* const filter_data, + const int batch_size, const int output_channels, const int output_height, + const int output_width, const int input_channels, const int input_height, + const int input_width, const int filter_multiplier, const int filter_height, + const int filter_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width, const int dilate_height, + const int dilate_width, T* const input_grad_data) { + for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { + for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { + const int batch = blockIdx.y; + const int c_in = blockIdx.x; + + const int c_out_start = c_in * filter_multiplier; + + int h_out_start = + h_in - (filter_height - 1) * dilate_height + padding_height; + + int h_out_end = h_in + padding_height; + + int w_out_start = + w_in - (filter_width - 1) * dilate_width + padding_width; + + int w_out_end = w_in + padding_width; + + T value = 0; + + for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier; + c_out++) { + int filter_offset = (c_out + 1) * filter_height * filter_width; + for (int h_out = h_out_start; h_out <= h_out_end; + h_out += dilate_height) { + for (int w_out = w_out_start; w_out <= w_out_end; + w_out += dilate_width) { + filter_offset--; + int s_h_out = h_out / stride_height; + int s_w_out = w_out / stride_width; + if (h_out % stride_height == 0 && w_out % stride_width == 0 && + s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && + s_w_out < output_width) { + const int output_grad_offset = + ((batch * output_channels + c_out) * output_height + + s_h_out) * + output_width + + s_w_out; + value += output_grad_data[output_grad_offset] * + filter_data[filter_offset]; + } + } } } + int index = + ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + + w_in; + input_grad_data[index] = value; } - input_grad_data[index] += value; } } +template +__global__ void KernelDepthwiseConvInputGradSp( + const T* const output_grad_data, const T* const filter_data, + const int batch_size, const int output_channels, const int output_height, + const int output_width, const int input_channels, const int input_height, + const int input_width, const int filter_multiplier, const int filter_height, + const int filter_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width, const int dilate_height, + const int dilate_width, T* const input_grad_data) { + if (c_filter_multiplier == 0) + KernelDepthwiseConvInputGrad( + output_grad_data, filter_data, batch_size, output_channels, + output_height, output_width, input_channels, input_height, input_width, + filter_multiplier, filter_height, filter_width, stride_height, + stride_width, padding_height, padding_width, dilate_height, + dilate_width, input_grad_data); + else + KernelDepthwiseConvInputGrad( + output_grad_data, filter_data, batch_size, output_channels, + output_height, output_width, input_channels, input_height, input_width, + c_filter_multiplier, filter_height, filter_width, c_stride, c_stride, + padding_height, padding_width, dilate_height, dilate_width, + input_grad_data); +} + // Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. template -__global__ void KernelDepthwiseConvFilterGrad( - const int nthreads, const T* const output_grad_data, - const T* const input_data, const int num, const int output_channels, - const int output_height, const int output_width, const int input_channels, - const int input_height, const int input_width, const int filter_multiplier, - const int filter_height, const int filter_width, const int stride_height, - const int stride_width, const int padding_height, const int padding_width, - T* const filter_grad_data) { - int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - if (index < nthreads) { - const int w_out = index % output_width; - const int h_out = (index / output_width) % output_height; - const int c_out = (index / output_width / output_height) % output_channels; - const int batch = (index / output_width / output_height / output_channels); - const int c_in = c_out / filter_multiplier; - const int h_in_start = -padding_height + h_out * stride_height; - const int w_in_start = -padding_width + w_out * stride_width; - const int h_in_end = - -padding_height + h_out * stride_height + filter_height; - const int w_in_end = -padding_width + w_out * stride_width + filter_width; - const int in_offset = - (batch * input_channels + c_in) * input_height * input_width; - - T* addr_offset = filter_grad_data + c_out * filter_height * filter_width; - const int h_end = h_in_end < input_height ? h_in_end : input_height; - const int w_end = w_in_end < input_width ? w_in_end : input_width; - const int h_start = h_in_start > 0 ? h_in_start : 0; - const int w_start = w_in_start > 0 ? w_in_start : 0; - - for (int h_in = h_start; h_in < h_end; h_in++) { - for (int w_in = w_start; w_in < w_end; w_in++) { - const int offset = in_offset + h_in * input_width + w_in; - const T diff_temp = output_grad_data[index] * input_data[offset]; - T* addr = addr_offset + (h_in - h_in_start) * filter_width + - (w_in - w_in_start); - paddle::platform::CudaAtomicAdd(addr, diff_temp); +__device__ __inline__ void KernelDepthwiseConvFilterGrad( + const T* output_grad_data, const T* input_data, const int num, + const int output_channels, const int output_height, const int output_width, + const int input_channels, const int input_height, const int input_width, + const int filter_multiplier, const int filter_height, + const int filter_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width, const int dilate_height, + const int dilate_width, T* filter_grad_data) { + T s = 0; + + int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; + int lid = lane_id(); + + for (int image_w = threadIdx.x; image_w < output_width; + image_w += blockDim.x) { + for (int bid = 0; bid < num; bid++) { + for (int image_h = threadIdx.y; image_h < output_height; + image_h += blockDim.y) { + int kernel_id = blockIdx.z; + int kernel_h = blockIdx.y * dilate_height - padding_height; + int kernel_w = blockIdx.x * dilate_width - padding_width; + + int image_hk = image_h * stride_height + kernel_h; + int image_wk = image_w * stride_width + kernel_w; + if (image_hk < 0 || image_hk >= input_height) continue; + if (image_wk < 0 || image_wk >= input_width) continue; +#define gaid(N, C, H, W) \ + ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W)) + + s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * + input_data[((bid * (gridDim.z / filter_multiplier) + + kernel_id / filter_multiplier) * + input_height + + image_hk) * + input_width + + image_wk]; + +#undef gaid } } } +#if __CUDA_ARCH__ >= 530 + s = warpReduceSum(s); + if (lid == 0) paddle::platform::CudaAtomicAdd(&filter_grad_data[gbid], s); +#else + paddle::platform::CudaAtomicAdd(&filter_grad_data[gbid], s); +#endif +} + +template +__global__ void KernelDepthwiseConvFilterGradSp( + const T* output_grad_data, const T* input_data, const int num, + const int output_channels, const int output_height, const int output_width, + const int input_channels, const int input_height, const int input_width, + const int filter_multiplier, const int filter_height, + const int filter_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width, const int dilate_height, + const int dilate_width, T* filter_grad_data) { + if (c_filter_multiplier == 0) + KernelDepthwiseConvFilterGrad( + output_grad_data, input_data, num, output_channels, output_height, + output_width, input_channels, input_height, input_width, + filter_multiplier, filter_height, filter_width, stride_height, + stride_width, padding_height, padding_width, dilate_height, + dilate_width, filter_grad_data); + else + KernelDepthwiseConvFilterGrad( + output_grad_data, input_data, num, output_channels, output_height, + output_width, input_channels, input_height, input_width, + c_filter_multiplier, filter_height, filter_width, stride_height, + stride_width, padding_height, padding_width, dilate_height, + dilate_width, filter_grad_data); } /* @@ -177,7 +297,9 @@ class DepthwiseConvFunctor { const framework::Tensor& input, const framework::Tensor& filter, const std::vector& strides, - const std::vector& paddings, framework::Tensor* output) { + const std::vector& paddings, + const std::vector& dilations, + framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -191,22 +313,37 @@ class DepthwiseConvFunctor { const int stride_width = strides[1]; const int padding_height = paddings[0]; const int padding_width = paddings[1]; + const int dilate_height = dilations[0]; + const int dilate_width = dilations[1]; const T* input_data = input.data(); const T* filter_data = filter.data(); T* output_data = output->mutable_data(context.GetPlace()); - int nthreads = batch_size * output_channels * output_height * output_width; - int blocks = (nthreads + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); - - KernelDepthwiseConv<<>>( - nthreads, input_data, filter_data, batch_size, output_channels, - output_height, output_width, input_channels, input_height, input_width, - output_channels / input_channels, ksize_height, ksize_width, - stride_height, stride_width, padding_height, padding_width, - output_data); + int thread = 512; + int blocks = std::min(std::max(thread / output_width, 1), output_height); + dim3 threads(std::min(output_width, thread), blocks, 1); + dim3 grid(output_channels, batch_size, 1); + int filter_multiplier = output_channels / input_channels; +#define check_case(c_filter_multiplier, c_stride) \ + if (c_filter_multiplier == 0 || \ + filter_multiplier == c_filter_multiplier && \ + stride_height == stride_width && stride_height == c_stride) { \ + KernelDepthwiseConvSp<<>>( \ + input_data, filter_data, batch_size, output_channels, output_height, \ + output_width, input_channels, input_height, input_width, \ + filter_multiplier, ksize_height, ksize_width, stride_height, \ + stride_width, padding_height, padding_width, dilate_height, \ + dilate_width, output_data); \ + return; \ + } + check_case(1, 1); + check_case(1, 2); + // NOTE(liangdun): 0,0 for other case + // add other case if needed, e.g. check_case(2^n,1) + check_case(0, 0); +#undef check_case } }; @@ -219,6 +356,7 @@ class DepthwiseConvInputGradFunctor { const framework::Tensor& output_grad, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; @@ -233,22 +371,39 @@ class DepthwiseConvInputGradFunctor { const int stride_width = strides[1]; const int padding_height = paddings[0]; const int padding_width = paddings[1]; + const int dilate_height = dilations[0]; + const int dilate_width = dilations[1]; const T* filter_data = filter.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad->mutable_data(context.GetPlace()); - int nthreads = batch_size * input_channels * input_height * input_width; - int blocks = (nthreads + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); - - KernelDepthwiseConvInputGrad<<>>( - nthreads, output_grad_data, filter_data, batch_size, output_channels, - output_height, output_width, input_channels, input_height, input_width, - output_channels / input_channels, ksize_height, ksize_width, - stride_height, stride_width, padding_height, padding_width, - input_grad_data); + int thread = 512; + int blocks = std::min(std::max(thread / input_width, 1), input_height); + dim3 threads(std::min(input_width, thread), blocks, 1); + dim3 grid(input_channels, batch_size, 1); + int filter_multiplier = output_channels / input_channels; + +#define check_case(c_filter_multiplier, c_stride) \ + if (c_filter_multiplier == 0 || \ + filter_multiplier == c_filter_multiplier && \ + stride_height == stride_width && stride_height == c_stride) { \ + KernelDepthwiseConvInputGradSp< \ + T, c_filter_multiplier, \ + c_stride><<>>( \ + output_grad_data, filter_data, batch_size, output_channels, \ + output_height, output_width, input_channels, input_height, \ + input_width, filter_multiplier, ksize_height, ksize_width, \ + stride_height, stride_width, padding_height, padding_width, \ + dilate_height, dilate_width, input_grad_data); \ + return; \ + } + check_case(1, 1); + check_case(1, 2); + // NOTE(liangdun): 0,0 for other case + // add other case if needed, e.g. check_case(2^n,1) + check_case(0, 0); +#undef check_case } }; @@ -260,6 +415,7 @@ class DepthwiseConvFilterGradFunctor { const framework::Tensor& output_grad, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, framework::Tensor* filter_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; @@ -274,23 +430,34 @@ class DepthwiseConvFilterGradFunctor { const int stride_width = strides[1]; const int padding_height = paddings[0]; const int padding_width = paddings[1]; + const int dilate_height = dilations[0]; + const int dilate_width = dilations[1]; const T* input_data = input.data(); const T* output_grad_data = output_grad.data(); T* filter_grad_data = filter_grad->mutable_data(context.GetPlace()); - int nthreads = batch_size * output_channels * output_height * output_width; - - int blocks = (nthreads + 1024 - 1) / 1024; - dim3 threads(1024, 1); - dim3 grid(blocks, 1); - - KernelDepthwiseConvFilterGrad<<>>( - nthreads, output_grad_data, input_data, batch_size, output_channels, - output_height, output_width, input_channels, input_height, input_width, - output_channels / input_channels, ksize_height, ksize_width, - stride_height, stride_width, padding_height, padding_width, - filter_grad_data); + int block_size = 512; + int crop_output_height = + std::min(std::max(block_size / output_width, 1), output_height); + dim3 grid(ksize_width, ksize_height, output_channels); + dim3 threads(std::min(output_width, block_size), crop_output_height, 1); + int filter_multiplier = output_channels / input_channels; + +#define check_case(c_filter_multiplier) \ + if (c_filter_multiplier == 0 || c_filter_multiplier == filter_multiplier) { \ + KernelDepthwiseConvFilterGradSp< \ + T, c_filter_multiplier><<>>( \ + output_grad_data, input_data, batch_size, output_channels, \ + output_height, output_width, input_channels, input_height, \ + input_width, filter_multiplier, ksize_height, ksize_width, \ + stride_height, stride_width, padding_height, padding_width, \ + dilate_height, dilate_width, filter_grad_data); \ + return; \ + } + check_case(1); + check_case(0); +#undef check_case } }; diff --git a/paddle/fluid/operators/math/depthwise_conv.h b/paddle/fluid/operators/math/depthwise_conv.h index 97aec401889a56d3fc9ac08e766d931bb3725b01..71f6fcb23df1942d6dcf7177165f2ec1022a9b35 100644 --- a/paddle/fluid/operators/math/depthwise_conv.h +++ b/paddle/fluid/operators/math/depthwise_conv.h @@ -32,7 +32,8 @@ class DepthwiseConvFunctor { void operator()(const DeviceContext& context, const framework::Tensor& input, const framework::Tensor& filter, const std::vector& strides, - const std::vector& paddings, framework::Tensor* output); + const std::vector& paddings, + const std::vector& dilations, framework::Tensor* output); }; template @@ -43,6 +44,7 @@ class DepthwiseConvInputGradFunctor { const framework::Tensor& output_grad, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, framework::Tensor* input_grad); }; @@ -53,6 +55,7 @@ class DepthwiseConvFilterGradFunctor { const framework::Tensor& output_grad, const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, framework::Tensor* filter_grad); }; diff --git a/paddle/fluid/operators/reduce_mean_op.cu b/paddle/fluid/operators/reduce_mean_op.cu index 960cb3235be7f4cc98b97d3b088ceaeb3d4a4209..59b30244839849d79e3e531953134633503c4090 100644 --- a/paddle/fluid/operators/reduce_mean_op.cu +++ b/paddle/fluid/operators/reduce_mean_op.cu @@ -12,17 +12,64 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include "paddle/fluid/operators/cub_reduce.h" #include "paddle/fluid/operators/reduce_mean_op.h" -REGISTER_OP_CUDA_KERNEL(reduce_mean, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel); +namespace paddle { +namespace operators { + +template +struct DivideFunctor { + HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {} + + HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; } + + private: + T n_inv; +}; + +template +class ReduceMeanKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + bool reduce_all = context.Attr("reduce_all"); + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + + auto dims = context.Attr>("dim"); + bool keep_dim = context.Attr("keep_dim"); + + std::vector reduce_dims; + if (reduce_all) { + reduce_dims.resize(input->dims().size()); + for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i; + } else { + for (auto e : dims) { + reduce_dims.push_back(e >= 0 ? e : e + input->dims().size()); + } + } + + int reduce_num = 1; + for (int i = 0; i < reduce_dims.size(); ++i) { + reduce_num *= input->dims()[reduce_dims[i]]; + } + + auto stream = context.cuda_device_context().stream(); + TensorReduce>( + *input, output, reduce_dims, static_cast(0), cub::Sum(), + DivideFunctor(reduce_num), stream); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel, + ops::ReduceMeanKernel, + ops::ReduceMeanKernel, + ops::ReduceMeanKernel); + REGISTER_OP_CUDA_KERNEL( reduce_mean_grad, ops::ReduceGradKernel, diff --git a/paddle/fluid/operators/reduce_sum_op.cu b/paddle/fluid/operators/reduce_sum_op.cu index f2e16955a50dc6a7feda9fbaf968c929ef3d8a4f..53cd9e9419dd9aecee730917ae21d7a4ab332ffc 100644 --- a/paddle/fluid/operators/reduce_sum_op.cu +++ b/paddle/fluid/operators/reduce_sum_op.cu @@ -12,17 +12,59 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/operators/cub_reduce.h" #include "paddle/fluid/operators/reduce_sum_op.h" -REGISTER_OP_CUDA_KERNEL(reduce_sum, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel, - ops::ReduceKernel); +namespace paddle { +namespace operators { + +template +struct IdentityFunctor { + HOSTDEVICE explicit inline IdentityFunctor() {} + + HOSTDEVICE inline T operator()(const T& x) const { return x; } +}; + +template +class ReduceSumKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + bool reduce_all = context.Attr("reduce_all"); + auto* input = context.Input("X"); + auto* output = context.Output("Out"); + + auto dims = context.Attr>("dim"); + bool keep_dim = context.Attr("keep_dim"); + + std::vector reduce_dims; + if (reduce_all) { + reduce_dims.resize(input->dims().size()); + for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i; + } else { + for (auto e : dims) { + reduce_dims.push_back(e >= 0 ? e : e + input->dims().size()); + } + } + + int reduce_num = 1; + for (int i = 0; i < reduce_dims.size(); ++i) { + reduce_num *= input->dims()[reduce_dims[i]]; + } + + auto stream = context.cuda_device_context().stream(); + TensorReduce>( + *input, output, reduce_dims, static_cast(0), cub::Sum(), + IdentityFunctor(), stream); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel, + ops::ReduceSumKernel, ops::ReduceSumKernel, + ops::ReduceSumKernel); + REGISTER_OP_CUDA_KERNEL( reduce_sum_grad, ops::ReduceGradKernel, diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 6a2732e9399aa5a93f4c47eb73bfd23dba608c3d..2ecc2504a8c9c5ecfc32cee96df9e368ff219cbb 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -67,6 +67,7 @@ class TestConv2dOp(OpTest): def setUp(self): self.op_type = "conv2d" self.use_cudnn = False + self.use_cuda = False self.use_mkldnn = False self.data_format = "AnyLayout" self.dtype = np.float32 @@ -101,24 +102,25 @@ class TestConv2dOp(OpTest): } self.outputs = {'Output': output} - def testcudnn(self): - return core.is_compiled_with_cuda() and self.use_cudnn + def testcuda(self): + return core.is_compiled_with_cuda() and (self.use_cudnn or + self.use_cuda) def test_check_output(self): - place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace() + place = core.CUDAPlace(0) if self.testcuda() else core.CPUPlace() self.check_output_with_place(place, atol=1e-5) def test_check_grad(self): if self.dtype == np.float16: return - place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace() + place = core.CUDAPlace(0) if self.testcuda() else core.CPUPlace() self.check_grad_with_place( place, set(['Input', 'Filter']), 'Output', max_relative_error=0.02) def test_check_grad_no_filter(self): if self.dtype == np.float16: return - place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace() + place = core.CUDAPlace(0) if self.testcuda() else core.CPUPlace() self.check_grad_with_place( place, ['Input'], 'Output', @@ -128,7 +130,7 @@ class TestConv2dOp(OpTest): def test_check_grad_no_input(self): if self.dtype == np.float16: return - place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace() + place = core.CUDAPlace(0) if self.testcuda() else core.CPUPlace() self.check_grad_with_place( place, ['Filter'], 'Output', @@ -325,18 +327,33 @@ class TestFP16CUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): class TestDepthwiseConv(TestConv2dOp): def init_test_case(self): + self.use_cuda = True self.pad = [1, 1] self.stride = [2, 2] self.input_size = [2, 3, 5, 5] # NCHW self.groups = 3 assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups - self.filter_size = [6, f_c, 3, 3] + self.filter_size = [3, f_c, 3, 3] self.op_type = "depthwise_conv2d" class TestDepthwiseConv2(TestConv2dOp): def init_test_case(self): + self.use_cuda = True + self.pad = [1, 1] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [3, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + +class TestDepthwiseConv3(TestConv2dOp): + def init_test_case(self): + self.use_cuda = True self.pad = [1, 1] self.stride = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW @@ -347,6 +364,34 @@ class TestDepthwiseConv2(TestConv2dOp): self.op_type = "depthwise_conv2d" +class TestDepthwiseConvWithDilation(TestConv2dOp): + def init_test_case(self): + self.use_cuda = True + self.pad = [1, 1] + self.stride = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + self.dilations = [2, 2] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + +class TestDepthwiseConvWithDilation2(TestConv2dOp): + def init_test_case(self): + self.use_cuda = True + self.pad = [1, 1] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + self.groups = 3 + self.dilations = [2, 2] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.op_type = "depthwise_conv2d" + + # Please Don't remove the following code. # Currently, CI use cudnn V5.0 which not support dilation conv. # class TestCUDNNWithDilation(TestWithDilation):