未验证 提交 f8823c1a 编写于 作者: C co63oc 提交者: GitHub

Add unfold tests (#52963)

上级 183a74db
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/im2col.h"
namespace phi {
......@@ -71,7 +72,7 @@ __global__ void im2col(const T* data_im,
}
*data_col =
(rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
? 0
? T(0)
: data_im[im_idx];
data_col += col_height * col_width;
}
......@@ -173,7 +174,7 @@ __global__ void col2im(int n,
int input_channels = n / im_height / im_width;
if (index < n) {
T val = 0;
T val = static_cast<T>(0);
int w = (data_layout != DataLayout::kNHWC
? index % im_width + padding_width
: (index / input_channels) % im_width + padding_width);
......@@ -309,12 +310,24 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
double>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::float16>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::bfloat16>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
float>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::float16>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
phi::GPUContext,
phi::dtype::bfloat16>;
template <class T>
__global__ void im2colOCF(const T* im_data,
......@@ -560,13 +573,24 @@ template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
double>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::float16>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::bfloat16>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
float>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::float16>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
phi::GPUContext,
phi::dtype::bfloat16>;
} // namespace funcs
} // namespace phi
......@@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unfold_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
unfold_grad, GPU, ALL_LAYOUT, phi::UnfoldGradKernel, float, double) {}
PD_REGISTER_KERNEL(unfold_grad,
GPU,
ALL_LAYOUT,
phi::UnfoldGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,7 +15,15 @@
#include "paddle/phi/kernels/unfold_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unfold_kernel_impl.h"
PD_REGISTER_KERNEL(unfold, GPU, ALL_LAYOUT, phi::UnfoldKernel, float, double) {}
PD_REGISTER_KERNEL(unfold,
GPU,
ALL_LAYOUT,
phi::UnfoldKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
......@@ -42,7 +42,11 @@ class TestUnfoldOp(OpTest):
self.input_height,
self.input_width,
]
self.x = np.random.rand(*input_shape).astype(np.float64)
if self.dtype == np.uint16:
as_type = self.np_dtype
else:
as_type = self.dtype
self.x = np.random.rand(*input_shape).astype(as_type)
def calc_unfold(self):
output_shape = [0] * 3
......@@ -77,7 +81,11 @@ class TestUnfoldOp(OpTest):
+ 1
)
output_shape[2] = out_height * out_width
output = np.zeros(output_shape).astype(np.float64)
if self.dtype == np.uint16:
as_type = self.np_dtype
else:
as_type = self.dtype
output = np.zeros(output_shape).astype(as_type)
# ------------ calculate output -------------- #
for i in range(output_shape[0]):
for j in range(output_shape[1]):
......@@ -123,9 +131,13 @@ class TestUnfoldOp(OpTest):
def setUp(self):
self.op_type = 'unfold'
self.init_dtype()
self.python_api = paddle.nn.functional.unfold
self.set_data()
def init_dtype(self):
self.dtype = np.float64
def test_check_output(self):
self.check_output()
......@@ -133,6 +145,55 @@ class TestUnfoldOp(OpTest):
self.check_grad(['X'], 'Y')
class TestUnfoldFP16Op(TestUnfoldOp):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestUnfoldBF16Op(TestUnfoldOp):
# Notice: The test is time consuming, may cause timeout, modify the parameters to reduce the time
def init_data(self):
self.batch_size = 3
self.input_channels = 3
self.input_height = 5
self.input_width = 5
self.kernel_sizes = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1, 1, 1]
self.dilations = [1, 1]
input_shape = [
self.batch_size,
self.input_channels,
self.input_height,
self.input_width,
]
self.x = np.random.rand(*input_shape).astype(self.np_dtype)
def init_dtype(self):
self.dtype = np.uint16
self.np_dtype = np.float32
def setUp(self):
self.op_type = 'unfold'
self.init_dtype()
self.python_api = paddle.nn.functional.unfold
self.set_data()
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.outputs['Y'] = convert_float_to_uint16(self.outputs['Y'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Y')
class TestUnfoldAPI(TestUnfoldOp):
"""
This is for test on paddle.nn.Unfold
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册