Unverified commit 4d198acb, authored by zhangyikun02, committed by GitHub

pool2d support fp16 on xpu and update pool2d unittest, test=kunlun (#40841)

Parent: d1c1d731
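Note: this commit templates the XPU pool2d forward and backward kernels over the element type via XPUTypeTrait, registers float16 kernels alongside the existing float32 ones, and moves the unit test onto the XPUOpTestWrapper/create_test_class harness so every dtype the op supports on Kunlun is exercised. A minimal usage sketch of what this enables (not part of the commit; assumes a Paddle build with XPU support and a device at the assumed index 'xpu:0'):

    # Hedged sketch: run pool2d in float16 on a Kunlun XPU after this change.
    import paddle
    import paddle.nn.functional as F

    paddle.set_device('xpu:0')                        # assumed device index
    x = paddle.randn([2, 3, 7, 7]).astype('float16')  # NCHW input
    y = F.max_pool2d(x, kernel_size=3, stride=1, padding=0)
    print(y.shape, y.dtype)                           # [2, 3, 5, 5], paddle.float16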
@@ -37,6 +37,8 @@ xpu::Pooling_t XPUPoolingType(const std::string& pooltype, bool exclusive,
 template <typename DeviceContext, typename T>
 class PoolXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* in_x = context.Input<Tensor>("X");
@@ -68,17 +70,19 @@ class PoolXPUKernel : public framework::OpKernel<T> {
     const int c = in_x->dims()[1];
     const int in_h = in_x->dims()[2];
     const int in_w = in_x->dims()[3];
-    const float* input = in_x->data<float>();
+    auto input = reinterpret_cast<const XPUType*>(in_x->data<T>());
     out->mutable_data<T>(context.GetPlace());
-    float* output = out->data<float>();
+    auto output = reinterpret_cast<XPUType*>(out->data<T>());
     auto& dev_ctx = context.template device_context<DeviceContext>();
     int r = xpu::Error_t::SUCCESS;
     if (pooling_type == "max") {
-      r = xpu::max_pool2d(dev_ctx.x_context(), input, output, index_data, n, c,
-                          in_h, in_w, ksize, strides, paddings, true);
+      r = xpu::max_pool2d<XPUType>(dev_ctx.x_context(), input, output,
+                                   index_data, n, c, in_h, in_w, ksize, strides,
+                                   paddings, true);
     } else if (pooling_type == "avg") {
-      r = xpu::avg_pool2d(dev_ctx.x_context(), input, output, n, c, in_h, in_w,
-                          ksize, strides, paddings, !exclusive, true);
+      r = xpu::avg_pool2d<XPUType>(dev_ctx.x_context(), input, output, n, c,
+                                   in_h, in_w, ksize, strides, paddings,
+                                   !exclusive, true);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Unsupported pooling type for kunlun ", pooling_type));
@@ -92,6 +96,8 @@ class PoolXPUKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 class PoolGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* in_x = context.Input<Tensor>("X");
@@ -130,21 +136,21 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
     const int c = in_x->dims()[1];
     const int in_h = in_x->dims()[2];
     const int in_w = in_x->dims()[3];
-    const float* input = in_x->data<float>();
-    const float* output = out->data<float>();
-    const float* output_grad = out_grad->data<float>();
+    auto input = reinterpret_cast<const XPUType*>(in_x->data<T>());
+    auto output = reinterpret_cast<const XPUType*>(out->data<T>());
+    auto output_grad = reinterpret_cast<const XPUType*>(out_grad->data<T>());
     in_x_grad->mutable_data<T>(context.GetPlace());
-    float* input_grad = in_x_grad->data<float>();
+    auto input_grad = reinterpret_cast<XPUType*>(in_x_grad->data<T>());
     auto& dev_ctx = context.template device_context<DeviceContext>();
     int r = xpu::Error_t::SUCCESS;
     if (pooling_type == "max") {
-      r = xpu::max_pool2d_grad(dev_ctx.x_context(), input, output, index_data,
-                               output_grad, input_grad, n, c, in_h, in_w, ksize,
-                               strides, paddings, true);
+      r = xpu::max_pool2d_grad<XPUType>(
+          dev_ctx.x_context(), input, output, index_data, output_grad,
+          input_grad, n, c, in_h, in_w, ksize, strides, paddings, true);
     } else if (pooling_type == "avg") {
-      r = xpu::avg_pool2d_grad(dev_ctx.x_context(), input, output, output_grad,
-                               input_grad, n, c, in_h, in_w, ksize, strides,
-                               paddings, !exclusive, true);
+      r = xpu::avg_pool2d_grad<XPUType>(
+          dev_ctx.x_context(), input, output, output_grad, input_grad, n, c,
+          in_h, in_w, ksize, strides, paddings, !exclusive, true);
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Unsupported pooling type for kunlun ", pooling_type));
@@ -161,9 +167,13 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;

 REGISTER_OP_XPU_KERNEL(
-    pool2d, ops::PoolXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    pool2d, ops::PoolXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::PoolXPUKernel<paddle::platform::XPUDeviceContext,
+                       paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     pool2d_grad,
-    ops::PoolGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::PoolGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::PoolGradXPUKernel<paddle::platform::XPUDeviceContext,
+                           paddle::platform::float16>);

 #endif
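The kernels above rely on XPUTypeTrait to map the framework element type T to the type the XPU runtime APIs expect, which is what lets one templated kernel serve both float and float16. A simplified sketch of the trait pattern (abridged; the exact definition lives in Paddle's XPU headers, so treat the float16 mapping below as an assumption):

    // Default: the XPU-side type is T itself.
    template <typename T>
    class XPUTypeTrait {
     public:
      using Type = T;
    };

    // paddle::platform::float16 maps to the XPU runtime's fp16 type, so the
    // data<T>() pointers above can be reinterpret_cast to XPUType*.
    template <>
    class XPUTypeTrait<paddle::platform::float16> {
     public:
      using Type = float16;  // assumed name of the runtime fp16 type
    };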
@@ -25,6 +25,7 @@ from op_test_xpu import XPUOpTest
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
 from test_pool2d_op import adaptive_start_index, adaptive_end_index
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
 import paddle

 paddle.enable_static()
@@ -246,13 +247,19 @@ def pool2D_forward_naive(x,
     return out


-class TestPool2D_Op(XPUOpTest):
+class XPUTestPool2D_Op(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'pool2d'
+        self.use_dynamic_create_class = False
+
+    class TestPool2D_Op(XPUOpTest):
         def setUp(self):
             self.op_type = "pool2d"
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
             self.use_cudnn = False
             self.init_kernel_type()
             self.use_mkldnn = False
-            self.init_data_type()
             self.init_test_case()
             self.padding_algorithm = "EXPLICIT"
             self.init_paddings()
@@ -267,9 +274,10 @@ class TestPool2D_Op(XPUOpTest):
             input = np.random.random(self.shape).astype(self.dtype)
             output = pool2D_forward_naive(
-                input, self.ksize, self.strides, self.paddings, self.global_pool,
-                self.ceil_mode, self.exclusive, self.adaptive, self.data_format,
-                self.pool_type, self.padding_algorithm).astype(self.dtype)
+                input, self.ksize, self.strides, self.paddings,
+                self.global_pool, self.ceil_mode, self.exclusive, self.adaptive,
+                self.data_format, self.pool_type,
+                self.padding_algorithm).astype(self.dtype)
             self.inputs = {'X': XPUOpTest.np_dtype_to_fluid_dtype(input)}

             self.attrs = {
@@ -289,20 +297,11 @@ class TestPool2D_Op(XPUOpTest):
             self.outputs = {'Out': output}

-    def has_xpu(self):
-        return core.is_compiled_with_xpu()
-
         def test_check_output(self):
-            if self.has_xpu():
-                place = core.XPUPlace(0)
-                self.check_output_with_place(place)
-                return
+            self.check_output_with_place(self.place)

         def test_check_grad(self):
-            if self.has_xpu():
-                place = core.XPUPlace(0)
-                self.check_grad_with_place(place, set(['X']), 'Out')
-                return
+            self.check_grad_with_place(self.place, set(['X']), 'Out')

         def init_data_format(self):
             self.data_format = "NCHW"
@@ -321,9 +320,6 @@ class TestPool2D_Op(XPUOpTest):
         def init_kernel_type(self):
             self.use_cudnn = False

-    def init_data_type(self):
-        self.dtype = np.float32
-
         def init_pool_type(self):
             self.pool_type = "avg"
             self.pool2D_forward_naive = avg_pool2D_forward_naive
@@ -340,8 +336,7 @@ class TestPool2D_Op(XPUOpTest):
         def init_adaptive(self):
             self.adaptive = False

-class TestCase1(TestPool2D_Op):
+    class TestCase1(TestPool2D_Op):
         def init_test_case(self):
             self.ksize = [3, 3]
             self.strides = [1, 1]
@@ -359,8 +354,7 @@ class TestCase1(TestPool2D_Op):
         def init_shape(self):
             self.shape = [2, 3, 7, 7]

-class TestCase2(TestPool2D_Op):
+    class TestCase2(TestPool2D_Op):
         def init_test_case(self):
             self.ksize = [3, 3]
             self.strides = [1, 1]
@@ -378,26 +372,22 @@ class TestCase2(TestPool2D_Op):
         def init_shape(self):
             self.shape = [2, 3, 7, 7]

-class TestCase3(TestPool2D_Op):
+    class TestCase3(TestPool2D_Op):
         def init_pool_type(self):
             self.pool_type = "max"
             self.pool2D_forward_naive = max_pool2D_forward_naive

-class TestCase4(TestCase1):
+    class TestCase4(TestCase1):
         def init_pool_type(self):
             self.pool_type = "max"
             self.pool2D_forward_naive = max_pool2D_forward_naive

-class TestCase5(TestCase2):
+    class TestCase5(TestCase2):
         def init_pool_type(self):
             self.pool_type = "max"
             self.pool2D_forward_naive = max_pool2D_forward_naive

-class TestPool2D_AsyPadding(TestPool2D_Op):
+    class TestPool2D_AsyPadding(TestPool2D_Op):
         def init_test_case(self):
             self.ksize = [3, 3]
             self.strides = [1, 1]
@@ -406,8 +396,7 @@ class TestPool2D_AsyPadding(TestPool2D_Op):
         def init_shape(self):
             self.shape = [2, 3, 5, 5]

-class TestCase1_AsyPadding(TestCase1):
+    class TestCase1_AsyPadding(TestCase1):
         def init_test_case(self):
             self.ksize = [3, 3]
             self.strides = [1, 1]
@@ -416,8 +405,7 @@ class TestCase1_AsyPadding(TestCase1):
         def init_shape(self):
             self.shape = [2, 3, 7, 7]

-class TestCase2_AsyPadding(TestCase2):
+    class TestCase2_AsyPadding(TestCase2):
         def init_test_case(self):
             self.ksize = [3, 3]
             self.strides = [1, 1]
@@ -426,8 +414,7 @@ class TestCase2_AsyPadding(TestCase2):
         def init_shape(self):
             self.shape = [2, 3, 7, 7]

-class TestCase3_AsyPadding(TestCase3):
+    class TestCase3_AsyPadding(TestCase3):
         def init_test_case(self):
             self.ksize = [3, 3]
             self.strides = [1, 1]
@@ -436,8 +423,7 @@ class TestCase3_AsyPadding(TestCase3):
         def init_shape(self):
             self.shape = [2, 3, 5, 5]

-class TestCase4_AsyPadding(TestCase4):
+    class TestCase4_AsyPadding(TestCase4):
         def init_test_case(self):
             self.ksize = [3, 3]
             self.strides = [1, 1]
@@ -446,8 +432,7 @@ class TestCase4_AsyPadding(TestCase4):
         def init_shape(self):
             self.shape = [2, 3, 7, 7]

-class TestCase5_AsyPadding((TestCase5)):
+    class TestCase5_AsyPadding(TestCase5):
         def init_test_case(self):
             self.ksize = [3, 3]
             self.strides = [1, 1]
@@ -456,8 +441,7 @@ class TestCase5_AsyPadding(TestCase5):
         def init_shape(self):
             self.shape = [2, 3, 7, 7]

-class TestAvgInclude_AsyPadding(TestCase2):
+    class TestAvgInclude_AsyPadding(TestCase2):
         def init_exclusive(self):
             self.exclusive = False
@@ -470,5 +454,9 @@ class TestAvgInclude_AsyPadding(TestCase2):
             self.shape = [2, 3, 7, 7]

+support_types = get_xpu_op_support_types('pool2d')
+for stype in support_types:
+    create_test_class(globals(), XPUTestPool2D_Op, stype)
+
 if __name__ == '__main__':
     unittest.main()
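For context, get_xpu_op_support_types queries which dtypes the current XPU build registers for an op, and create_test_class stamps out one concrete TestCase per dtype from the wrapper's inner classes; the injected in_type attribute is what setUp reads into self.dtype. A rough, hypothetical re-implementation to illustrate the mechanism (the real helper in xpu/get_test_cover_info.py does more, e.g. name mangling and skip logic):

    import inspect
    import unittest

    def create_test_class_sketch(scope, wrapper_cls, dtype):
        # For each inner TestCase class of the wrapper, emit a dtype-specific
        # subclass whose in_type attribute drives self.dtype in setUp.
        wrapper = wrapper_cls()
        for name, cls in inspect.getmembers(wrapper_cls, inspect.isclass):
            if issubclass(cls, unittest.TestCase):
                test_name = '%s_%s_%s' % (wrapper.op_name, name, dtype)
                scope[test_name] = type(test_name, (cls,), {'in_type': dtype})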