提交 5238c9fb 编写于 作者: C chengduoZH

input type should be different

上级 ba868854
...@@ -498,8 +498,8 @@ template class Pool3dGradFunctor< ...@@ -498,8 +498,8 @@ template class Pool3dGradFunctor<
* Ksize, strides, paddings are two elements. These two elements represent * Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively. * height and width, respectively.
*/ */
template <typename T> template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { class MaxPool2dWithIndexFunctor<platform::CPUPlace, T1, T2> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, std::vector<int>& ksize, const framework::Tensor& input, std::vector<int>& ksize,
...@@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
const int input_stride = input_height * input_width; const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width; const int output_stride = output_height * output_width;
const T* input_data = input.data<T>(); const T1* input_data = input.data<T1>();
T* output_data = output->mutable_data<T>(context.GetPlace()); T1* output_data = output->mutable_data<T1>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace()); T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
int wend = std::min(wstart + ksize_width, input_width); int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
T ele = static_cast<T>(-FLT_MAX); T1 ele = static_cast<T1>(-FLT_MAX);
int index = -1; int index = -1;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
...@@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
* Ksize, strides, paddings are two elements. These two elements represent * Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively. * height and width, respectively.
*/ */
template <typename T> template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> { class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
...@@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> { ...@@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
const int input_stride = input_height * input_width; const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width; const int output_stride = output_height * output_width;
const T* mask_data = mask.data<T>(); const T2* mask_data = mask.data<T2>();
const T* output_grad_data = output_grad.data<T>(); const T1* output_grad_data = output_grad.data<T1>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) { for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> { ...@@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
} }
}; };
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>; template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float, int>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>; template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float, int>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>; template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double, int>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>; template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double, int>;
/* /*
* All tensors are in NCDHW format. * All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent * Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively. * depth, height and width, respectively.
*/ */
template <typename T> template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { class MaxPool3dWithIndexFunctor<platform::CPUPlace, T1, T2> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, std::vector<int>& ksize, const framework::Tensor& input, std::vector<int>& ksize,
...@@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const int input_stride = input_depth * input_height * input_width; const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width; const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>(); const T1* input_data = input.data<T1>();
T* output_data = output->mutable_data<T>(context.GetPlace()); T1* output_data = output->mutable_data<T1>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace()); T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw; int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX); T1 ele = static_cast<T1>(-FLT_MAX);
int index = -1; int index = -1;
for (int d = dstart; d < dend; ++d) { for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
...@@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
* Ksize, strides, paddings are three elements. These three elements represent * Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively. * depth, height and width, respectively.
*/ */
template <typename T> template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> { class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
...@@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> { ...@@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
const int input_stride = input_depth * input_height * input_width; const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width; const int output_stride = output_depth * output_height * output_width;
const T* mask_data = mask.data<T>(); const T2* mask_data = mask.data<T2>();
const T* output_grad_data = output_grad.data<T>(); const T1* output_grad_data = output_grad.data<T1>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) { for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> { ...@@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
} }
}; };
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>; template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float, int>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>; template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float, int>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>; template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double, int>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>; template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double, int>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -658,13 +658,13 @@ template class Pool3dGradFunctor< ...@@ -658,13 +658,13 @@ template class Pool3dGradFunctor<
template class Pool3dGradFunctor< template class Pool3dGradFunctor<
platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>; platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
template <typename T> template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx( __global__ void KernelMaxPool2dWithIdx(
const int nthreads, const T* input_data, const int channels, const int nthreads, const T1* input_data, const int channels,
const int input_height, const int input_width, const int output_height, const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width, const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height, const int stride_height, const int stride_width, const int padding_height,
const int padding_width, T* output_data, T* mask_data) { const int padding_width, T1* output_data, T2* mask_data) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
int pw = index % output_width; int pw = index % output_width;
...@@ -681,7 +681,7 @@ __global__ void KernelMaxPool2dWithIdx( ...@@ -681,7 +681,7 @@ __global__ void KernelMaxPool2dWithIdx(
wstart = max(wstart, 0); wstart = max(wstart, 0);
input_data += (batch_idx * channels + c) * input_height * input_width; input_data += (batch_idx * channels + c) * input_height * input_width;
T ele = -FLT_MAX; T1 ele = -FLT_MAX;
int max_index = -1; int max_index = -1;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
...@@ -697,13 +697,13 @@ __global__ void KernelMaxPool2dWithIdx( ...@@ -697,13 +697,13 @@ __global__ void KernelMaxPool2dWithIdx(
} }
} }
template <typename T> template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad( __global__ void KernelMaxPool2DWithIdxGrad(
const int nthreads, const T* output_grad, const T* mask_data, const int nthreads, const T1* output_grad, const T2* mask_data,
const int channels, const int input_height, const int input_width, const int channels, const int input_height, const int input_width,
const int output_height, const int output_width, const int ksize_height, const int output_height, const int output_width, const int ksize_height,
const int ksize_width, const int stride_height, const int stride_width, const int ksize_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, T* input_grad) { const int padding_height, const int padding_width, T1* input_grad) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
int w_offset = index % input_width; int w_offset = index % input_width;
...@@ -724,7 +724,7 @@ __global__ void KernelMaxPool2DWithIdxGrad( ...@@ -724,7 +724,7 @@ __global__ void KernelMaxPool2DWithIdxGrad(
int pw_end = int pw_end =
min((w_offset + padding_width) / stride_width + 1, output_width); min((w_offset + padding_width) / stride_width + 1, output_width);
T gradient = 0; T1 gradient = 0;
int input_current_featuremap_idx = h_offset * input_width + w_offset; int input_current_featuremap_idx = h_offset * input_width + w_offset;
int output_idx = int output_idx =
(batch_idx * channels + c_offset) * output_height * output_width; (batch_idx * channels + c_offset) * output_height * output_width;
...@@ -746,8 +746,8 @@ __global__ void KernelMaxPool2DWithIdxGrad( ...@@ -746,8 +746,8 @@ __global__ void KernelMaxPool2DWithIdxGrad(
* Ksize, strides, paddings are two elements. These two elements represent * Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively. * height and width, respectively.
*/ */
template <typename T> template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> { class MaxPool2dWithIndexFunctor<platform::GPUPlace, T1, T2> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, std::vector<int>& ksize, const framework::Tensor& input, std::vector<int>& ksize,
...@@ -767,9 +767,9 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> { ...@@ -767,9 +767,9 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
const int padding_height = paddings[0]; const int padding_height = paddings[0];
const int padding_width = paddings[1]; const int padding_width = paddings[1];
const T* input_data = input.data<T>(); const T1* input_data = input.data<T1>();
T* output_data = output->mutable_data<T>(context.GetPlace()); T1* output_data = output->mutable_data<T1>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace()); T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
int nthreads = batch_size * output_channels * output_height * output_width; int nthreads = batch_size * output_channels * output_height * output_width;
int blocks = (nthreads + 1024 - 1) / 1024; int blocks = (nthreads + 1024 - 1) / 1024;
...@@ -777,9 +777,9 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> { ...@@ -777,9 +777,9 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
dim3 grid(blocks, 1); dim3 grid(blocks, 1);
KernelMaxPool2dWithIdx< KernelMaxPool2dWithIdx<
T><<<grid, threads, 0, T1, T2><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
nthreads, input_data, input_channels, input_height, input_width, nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height, output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, output_data, mask_data); stride_width, padding_height, padding_width, output_data, mask_data);
...@@ -791,8 +791,8 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> { ...@@ -791,8 +791,8 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
* Ksize, strides, paddings are two elements. These two elements represent * Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively. * height and width, respectively.
*/ */
template <typename T> template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> { class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T1, T2> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
...@@ -812,9 +812,9 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> { ...@@ -812,9 +812,9 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
const int padding_height = paddings[0]; const int padding_height = paddings[0];
const int padding_width = paddings[1]; const int padding_width = paddings[1];
const T* mask_data = mask.data<T>(); const T2* mask_data = mask.data<T2>();
const T* output_grad_data = output_grad.data<T>(); const T1* output_grad_data = output_grad.data<T1>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
int nthreads = batch_size * input_channels * input_height * input_width; int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024; int blocks = (nthreads + 1024 - 1) / 1024;
...@@ -822,30 +822,30 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> { ...@@ -822,30 +822,30 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
dim3 grid(blocks, 1); dim3 grid(blocks, 1);
KernelMaxPool2DWithIdxGrad< KernelMaxPool2DWithIdxGrad<
T><<<grid, threads, 0, T1, T2><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, output_grad_data, mask_data, .stream()>>>(
input_channels, input_height, input_width, nthreads, output_grad_data, mask_data, input_channels, input_height,
output_height, output_width, ksize_height, input_width, output_height, output_width, ksize_height, ksize_width,
ksize_width, stride_height, stride_width, stride_height, stride_width, padding_height, padding_width,
padding_height, padding_width, input_grad_data); input_grad_data);
} }
}; };
template class MaxPool2dWithIndexFunctor<platform::GPUPlace, float>; template class MaxPool2dWithIndexFunctor<platform::GPUPlace, float, int>;
template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, float>; template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, float, int>;
template class MaxPool2dWithIndexFunctor<platform::GPUPlace, double>; template class MaxPool2dWithIndexFunctor<platform::GPUPlace, double, int>;
template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, double>; template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, double, int>;
template <typename T> template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx( __global__ void KernelMaxPool3DWithIdx(
const int nthreads, const T* input_data, const int channels, const int nthreads, const T1* input_data, const int channels,
const int input_depth, const int input_height, const int input_width, const int input_depth, const int input_height, const int input_width,
const int output_depth, const int output_height, const int output_width, const int output_depth, const int output_height, const int output_width,
const int ksize_depth, const int ksize_height, const int ksize_width, const int ksize_depth, const int ksize_height, const int ksize_width,
const int stride_depth, const int stride_height, const int stride_width, const int stride_depth, const int stride_height, const int stride_width,
const int padding_depth, const int padding_height, const int padding_width, const int padding_depth, const int padding_height, const int padding_width,
T* output_data, T* mask_data) { T1* output_data, T2* mask_data) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
int pw = index % output_width; int pw = index % output_width;
...@@ -865,7 +865,7 @@ __global__ void KernelMaxPool3DWithIdx( ...@@ -865,7 +865,7 @@ __global__ void KernelMaxPool3DWithIdx(
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
T ele = -FLT_MAX; T1 ele = -FLT_MAX;
int max_index = -1; int max_index = -1;
input_data += input_data +=
(batch_idx * channels + c) * input_depth * input_height * input_width; (batch_idx * channels + c) * input_depth * input_height * input_width;
...@@ -885,15 +885,15 @@ __global__ void KernelMaxPool3DWithIdx( ...@@ -885,15 +885,15 @@ __global__ void KernelMaxPool3DWithIdx(
} }
} }
template <typename T> template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad( __global__ void KernelMaxPool3DWithIdxGrad(
const int nthreads, const T* output_grad, const T* mask, const int channels, const int nthreads, const T1* output_grad, const T2* mask,
const int input_depth, const int input_height, const int input_width, const int channels, const int input_depth, const int input_height,
const int output_depth, const int output_height, const int output_width, const int input_width, const int output_depth, const int output_height,
const int ksize_depth, const int ksize_height, const int ksize_width, const int output_width, const int ksize_depth, const int ksize_height,
const int stride_depth, const int stride_height, const int stride_width, const int ksize_width, const int stride_depth, const int stride_height,
const int padding_depth, const int padding_height, const int padding_width, const int stride_width, const int padding_depth, const int padding_height,
T* input_grad) { const int padding_width, T1* input_grad) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
int w_offset = index % input_width; int w_offset = index % input_width;
...@@ -922,7 +922,7 @@ __global__ void KernelMaxPool3DWithIdxGrad( ...@@ -922,7 +922,7 @@ __global__ void KernelMaxPool3DWithIdxGrad(
int pw_end = int pw_end =
min((w_offset + padding_width) / stride_width + 1, output_width); min((w_offset + padding_width) / stride_width + 1, output_width);
T gradient = 0; T1 gradient = 0;
int input_current_feature_map_idx = int input_current_feature_map_idx =
(d_offset * input_height + h_offset) * input_width + w_offset; (d_offset * input_height + h_offset) * input_width + w_offset;
int output_idx = (batch_idx * channels + c_offset) * output_depth * int output_idx = (batch_idx * channels + c_offset) * output_depth *
...@@ -949,8 +949,8 @@ __global__ void KernelMaxPool3DWithIdxGrad( ...@@ -949,8 +949,8 @@ __global__ void KernelMaxPool3DWithIdxGrad(
* Ksize, strides, paddings are three elements. These three elements represent * Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively. * depth, height and width, respectively.
*/ */
template <typename T> template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> { class MaxPool3dWithIndexFunctor<platform::GPUPlace, T1, T2> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, std::vector<int>& ksize, const framework::Tensor& input, std::vector<int>& ksize,
...@@ -975,9 +975,9 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> { ...@@ -975,9 +975,9 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
const int padding_height = paddings[1]; const int padding_height = paddings[1];
const int padding_width = paddings[2]; const int padding_width = paddings[2];
const T* input_data = input.data<T>(); const T1* input_data = input.data<T1>();
T* output_data = output->mutable_data<T>(context.GetPlace()); T1* output_data = output->mutable_data<T1>(context.GetPlace());
T* mask_data = mask->mutable_data<T>(context.GetPlace()); T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
int nthreads = batch_size * output_channels * output_depth * output_height * int nthreads = batch_size * output_channels * output_depth * output_height *
output_width; output_width;
...@@ -986,9 +986,9 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> { ...@@ -986,9 +986,9 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
dim3 grid(blocks, 1); dim3 grid(blocks, 1);
KernelMaxPool3DWithIdx< KernelMaxPool3DWithIdx<
T><<<grid, threads, 0, T1, T2><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
nthreads, input_data, input_channels, input_depth, input_height, nthreads, input_data, input_channels, input_depth, input_height,
input_width, output_depth, output_height, output_width, ksize_depth, input_width, output_depth, output_height, output_width, ksize_depth,
ksize_height, ksize_width, stride_depth, stride_height, stride_width, ksize_height, ksize_width, stride_depth, stride_height, stride_width,
...@@ -1001,8 +1001,8 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> { ...@@ -1001,8 +1001,8 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
* Ksize, strides, paddings are three elements. These three elements represent * Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively. * depth, height and width, respectively.
*/ */
template <typename T> template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> { class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T1, T2> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
...@@ -1027,9 +1027,9 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> { ...@@ -1027,9 +1027,9 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
const int padding_height = paddings[1]; const int padding_height = paddings[1];
const int padding_width = paddings[2]; const int padding_width = paddings[2];
const T* output_grad_data = output_grad.data<T>(); const T1* output_grad_data = output_grad.data<T1>();
const T* mask_data = mask.data<T>(); const T2* mask_data = mask.data<T2>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
int nthreads = int nthreads =
batch_size * input_channels * input_depth * input_height * input_width; batch_size * input_channels * input_depth * input_height * input_width;
...@@ -1038,9 +1038,9 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> { ...@@ -1038,9 +1038,9 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
dim3 grid(blocks, 1); dim3 grid(blocks, 1);
KernelMaxPool3DWithIdxGrad< KernelMaxPool3DWithIdxGrad<
T><<<grid, threads, 0, T1, T2><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
nthreads, output_grad_data, mask_data, input_channels, input_depth, nthreads, output_grad_data, mask_data, input_channels, input_depth,
input_height, input_width, output_depth, output_height, output_width, input_height, input_width, output_depth, output_height, output_width,
ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
...@@ -1049,10 +1049,10 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> { ...@@ -1049,10 +1049,10 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
} }
}; };
template class MaxPool3dWithIndexFunctor<platform::GPUPlace, float>; template class MaxPool3dWithIndexFunctor<platform::GPUPlace, float, int>;
template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, float>; template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, float, int>;
template class MaxPool3dWithIndexFunctor<platform::GPUPlace, double>; template class MaxPool3dWithIndexFunctor<platform::GPUPlace, double, int>;
template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, double>; template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, double, int>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -153,7 +153,7 @@ class MaxPool3dGradFunctor { ...@@ -153,7 +153,7 @@ class MaxPool3dGradFunctor {
* In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in * In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in
* NCDHW format. * NCDHW format.
*/ */
template <typename Place, typename T> template <typename Place, typename T1, typename T2>
class MaxPool2dWithIndexFunctor { class MaxPool2dWithIndexFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
...@@ -162,7 +162,7 @@ class MaxPool2dWithIndexFunctor { ...@@ -162,7 +162,7 @@ class MaxPool2dWithIndexFunctor {
framework::Tensor* output, framework::Tensor* mask); framework::Tensor* output, framework::Tensor* mask);
}; };
template <typename Place, typename T> template <typename Place, typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor { class MaxPool2dWithIndexGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
...@@ -172,7 +172,7 @@ class MaxPool2dWithIndexGradFunctor { ...@@ -172,7 +172,7 @@ class MaxPool2dWithIndexGradFunctor {
framework::Tensor* input_grad); framework::Tensor* input_grad);
}; };
template <typename Place, typename T> template <typename Place, typename T1, typename T2>
class MaxPool3dWithIndexFunctor { class MaxPool3dWithIndexFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
...@@ -181,7 +181,7 @@ class MaxPool3dWithIndexFunctor { ...@@ -181,7 +181,7 @@ class MaxPool3dWithIndexFunctor {
framework::Tensor* output, framework::Tensor* mask); framework::Tensor* output, framework::Tensor* mask);
}; };
template <typename Place, typename T> template <typename Place, typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor { class MaxPool3dWithIndexGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
......
...@@ -29,11 +29,11 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { ...@@ -29,11 +29,11 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null."); "Input(X) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null."); "Output(Out) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mask"), PADDLE_ENFORCE(ctx->HasOutput("Mask"),
"Mask(Output) of Pooling should not be null."); "Output(Mask) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X"); auto in_x_dims = ctx->GetInputDim("X");
...@@ -67,6 +67,14 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { ...@@ -67,6 +67,14 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->SetOutputDim("Mask", framework::make_ddim(output_shape)); ctx->SetOutputDim("Mask", framework::make_ddim(output_shape));
} }
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
}; };
class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
...@@ -80,6 +88,14 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { ...@@ -80,6 +88,14 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
"Input(X@GRAD) should not be null."); "Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
}; };
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -116,7 +132,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -116,7 +132,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>( AddAttr<bool>(
"global_pooling", "global_pooling",
"(bool, default false) Whether to use the global pooling. " "(bool, default:false) Whether to use the global pooling. "
"If global_pooling = true, ksize and paddings will be ignored.") "If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
...@@ -126,7 +142,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -126,7 +142,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"(vector<int>, defalut {0, 0}), paddings(height, width) of pooling " "(vector<int>, defalut:{0, 0}), paddings(height, width) of pooling "
"operator. " "operator. "
"If global_pooling = true, paddings and will be ignored.") "If global_pooling = true, paddings and will be ignored.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
...@@ -250,10 +266,10 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp, ...@@ -250,10 +266,10 @@ REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index, max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>); ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float, int>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index_grad, max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>) ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float, int>)
REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad, ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad,
...@@ -261,7 +277,7 @@ REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp, ...@@ -261,7 +277,7 @@ REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index, max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>); ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float, int>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index_grad, max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>) ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float, int>)
...@@ -18,14 +18,14 @@ namespace ops = paddle::operators; ...@@ -18,14 +18,14 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index, max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>); ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float, int>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index_grad, max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>) ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float, int>)
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index, max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>); ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float, int>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index_grad, max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>) ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float, int>)
...@@ -24,8 +24,8 @@ namespace operators { ...@@ -24,8 +24,8 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T1, typename T2>
class MaxPoolWithIndexKernel : public framework::OpKernel<T> { class MaxPoolWithIndexKernel : public framework::OpKernel<T1> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X"); const Tensor* in_x = context.Input<Tensor>("X");
...@@ -44,13 +44,13 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> { ...@@ -44,13 +44,13 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
switch (ksize.size()) { switch (ksize.size()) {
case 2: { case 2: {
paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T> paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T1, T2>
pool2d_forward; pool2d_forward;
pool2d_forward(context.device_context(), *in_x, ksize, strides, pool2d_forward(context.device_context(), *in_x, ksize, strides,
paddings, out, mask); paddings, out, mask);
} break; } break;
case 3: { case 3: {
paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T> paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T1, T2>
pool3d_forward; pool3d_forward;
pool3d_forward(context.device_context(), *in_x, ksize, strides, pool3d_forward(context.device_context(), *in_x, ksize, strides,
paddings, out, mask); paddings, out, mask);
...@@ -60,8 +60,8 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> { ...@@ -60,8 +60,8 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Place, typename T> template <typename Place, typename T1, typename T2>
class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> { class MaxPoolWithIndexGradKernel : public framework::OpKernel<T1> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* mask = context.Input<Tensor>("Mask"); const Tensor* mask = context.Input<Tensor>("Mask");
...@@ -80,19 +80,19 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> { ...@@ -80,19 +80,19 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
} }
if (in_x_grad) { if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace()); in_x_grad->mutable_data<T1>(context.GetPlace());
auto& device_ctx = context.device_context(); auto& device_ctx = context.device_context();
math::set_constant(device_ctx, in_x_grad, 0); math::set_constant(device_ctx, in_x_grad, 0);
switch (ksize.size()) { switch (ksize.size()) {
case 2: { case 2: {
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T> paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T1, T2>
pool2d_backward; pool2d_backward;
pool2d_backward(device_ctx, *out_grad, *mask, ksize, strides, pool2d_backward(device_ctx, *out_grad, *mask, ksize, strides,
paddings, in_x_grad); paddings, in_x_grad);
} break; } break;
case 3: { case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T> paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T1, T2>
pool3d_backward; pool3d_backward;
pool3d_backward(device_ctx, *out_grad, *mask, ksize, strides, pool3d_backward(device_ctx, *out_grad, *mask, ksize, strides,
paddings, in_x_grad); paddings, in_x_grad);
......
...@@ -3,11 +3,13 @@ import numpy as np ...@@ -3,11 +3,13 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=False):
N, C, D, H, W = x.shape N, C, D, H, W = x.shape
if global_pool == 1: if global_pool:
ksize = [D, H, W] ksize = [D, H, W]
paddings = [0, 0, 0]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1
...@@ -40,11 +42,13 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): ...@@ -40,11 +42,13 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0):
return out, mask return out, mask
def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=False):
N, C, H, W = x.shape N, C, H, W = x.shape
if global_pool == 1: if global_pool:
ksize = [H, W] ksize = [H, W]
paddings = [0, 0]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1
out = np.zeros((N, C, H_out, W_out)) out = np.zeros((N, C, H_out, W_out))
...@@ -74,13 +78,13 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): ...@@ -74,13 +78,13 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0):
class TestMaxPoolWithIndex_Op(OpTest): class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self): def setUp(self):
self.init_test_case() self.init_test_case()
if self.global_pool: self.init_global()
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype("float32")
output, mask = self.pool_forward_naive(input, self.ksize, self.strides, output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool) self.paddings, self.global_pool)
output = output.astype("float32") output = output.astype("float32")
mask = mask.astype("float32") mask = mask.astype("int32")
self.attrs = { self.attrs = {
'strides': self.strides, 'strides': self.strides,
...@@ -99,41 +103,24 @@ class TestMaxPoolWithIndex_Op(OpTest): ...@@ -99,41 +103,24 @@ class TestMaxPoolWithIndex_Op(OpTest):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def init_test_case(self): def init_test_case(self):
self.global_pool = True self.op_type = "max_pool3d_with_index"
self.index = "max_pool3d_with_index"
self.op_type = "%s" % self.index
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5] self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3] self.ksize = [3, 3, 3]
self.strides = [1, 1, 1] self.strides = [1, 1, 1]
self.paddings = [1, 1, 1] self.paddings = [1, 1, 1]
def init_global(self):
self.global_pool = False
class TestCase1(TestMaxPoolWithIndex_Op): class TestCase1(TestMaxPoolWithIndex_Op):
def init_test_case(self): def init_global(self):
self.global_pool = True self.global_pool = True
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase2(TestMaxPoolWithIndex_Op): class TestCase2(TestMaxPoolWithIndex_Op):
def init_test_case(self): def init_test_case(self):
self.global_pool = False
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase3(TestMaxPoolWithIndex_Op):
def init_test_case(self):
self.global_pool = False
self.op_type = "max_pool3d_with_index" self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7] self.shape = [2, 3, 7, 7, 7]
...@@ -141,32 +128,18 @@ class TestCase3(TestMaxPoolWithIndex_Op): ...@@ -141,32 +128,18 @@ class TestCase3(TestMaxPoolWithIndex_Op):
self.strides = [2, 2, 2] self.strides = [2, 2, 2]
self.paddings = [0, 0, 0] self.paddings = [0, 0, 0]
def init_global(self):
class TestCase4(TestMaxPoolWithIndex_Op):
def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase5(TestMaxPoolWithIndex_Op): class TestCase3(TestCase2):
def init_test_case(self): def init_global(self):
self.global_pool = True self.global_pool = False
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [2, 2, 2]
self.paddings = [0, 0, 0]
class TestCase6(TestMaxPoolWithIndex_Op): #----------------max_pool2d_with_index----------------
class TestCase4(TestMaxPoolWithIndex_Op):
def init_test_case(self): def init_test_case(self):
self.global_pool = False
self.op_type = "max_pool2d_with_index" self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7] self.shape = [2, 3, 7, 7]
...@@ -174,10 +147,17 @@ class TestCase6(TestMaxPoolWithIndex_Op): ...@@ -174,10 +147,17 @@ class TestCase6(TestMaxPoolWithIndex_Op):
self.strides = [1, 1] self.strides = [1, 1]
self.paddings = [1, 1] self.paddings = [1, 1]
def init_global(self):
self.global_pool = True
class TestCase7(TestMaxPoolWithIndex_Op):
def init_test_case(self): class TestCase5(TestMaxPoolWithIndex_Op):
def init_global(self):
self.global_pool = False self.global_pool = False
class TestCase6(TestMaxPoolWithIndex_Op):
def init_test_case(self):
self.op_type = "max_pool2d_with_index" self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7] self.shape = [2, 3, 7, 7]
...@@ -185,27 +165,13 @@ class TestCase7(TestMaxPoolWithIndex_Op): ...@@ -185,27 +165,13 @@ class TestCase7(TestMaxPoolWithIndex_Op):
self.strides = [2, 2] self.strides = [2, 2]
self.paddings = [0, 0] self.paddings = [0, 0]
def init_global(self):
class TestCase8(TestMaxPoolWithIndex_Op):
def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCase9(TestMaxPoolWithIndex_Op): class TestCase7(TestCase6):
def init_test_case(self): def init_global(self):
self.global_pool = True self.global_pool = False
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [2, 2]
self.paddings = [0, 0]
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册