diff --git a/paddle/fluid/operators/batch_norm_op_xpu.cc b/paddle/fluid/operators/batch_norm_op_xpu.cc index ff6bb22d3957ccb087505a38c292571845812628..526fc7364cdd8407a2e25e966adc0043b9d5206f 100644 --- a/paddle/fluid/operators/batch_norm_op_xpu.cc +++ b/paddle/fluid/operators/batch_norm_op_xpu.cc @@ -139,16 +139,14 @@ class BatchNormGradXPUKernel : public framework::OpKernel { auto* dscale_data = dscale->mutable_data(ctx.GetPlace()); auto* dbias_data = dbias->mutable_data(ctx.GetPlace()); auto& dev_ctx = ctx.template device_context(); - int r = xpu::batch_norm_backward(dev_ctx.x_context(), N, C, H, W, x_data, - dy_data, scale_data, saved_mean_data, - saved_inv_variance_data, dx_data, - dscale_data, dbias_data); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU API(batch_norm_infer_forward) return " - "wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + int r = xpu::batch_norm_grad(dev_ctx.x_context(), x_data, dy_data, + dx_data, N, C, H, W, scale_data, + saved_mean_data, saved_inv_variance_data, + dscale_data, dbias_data, true); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(batch_norm_grad) return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); } }; diff --git a/paddle/fluid/operators/pool_op_xpu.cc b/paddle/fluid/operators/pool_op_xpu.cc index 096a81db9bd66a3de02549e9e1e0cd6e022fb249..402dd6c10803947f73e593d215d28246a81c6706 100644 --- a/paddle/fluid/operators/pool_op_xpu.cc +++ b/paddle/fluid/operators/pool_op_xpu.cc @@ -30,6 +30,7 @@ xpu::Pooling_t XPUPoolingType(const std::string& pooltype, bool exclusive, "Pool op only supports 2D and 3D input.")); } } + template class PoolXPUKernel : public framework::OpKernel { public: @@ -41,7 +42,6 @@ class PoolXPUKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); bool exclusive = context.Attr("exclusive"); - bool is_test = context.Attr("is_test"); bool adaptive = context.Attr("adaptive"); PADDLE_ENFORCE_EQ( ksize.size(), 2, @@ -60,36 +60,32 @@ class PoolXPUKernel : public framework::OpKernel { ksize[i] = static_cast(in_x->dims()[i + 2]); } } - const int c = in_x->dims()[0] * in_x->dims()[1]; + const int n = in_x->dims()[0]; + const int c = in_x->dims()[1]; const int in_h = in_x->dims()[2]; const int in_w = in_x->dims()[3]; - const int out_h = out->dims()[2]; - const int out_w = out->dims()[3]; - const int win_h = ksize[0]; - const int win_w = ksize[1]; - const int stride_h = strides[0]; - const int stride_w = strides[1]; - const int pad_up = paddings[0]; - const int pad_down = paddings[0]; - const int pad_left = paddings[1]; - const int pad_right = paddings[1]; const float* input = in_x->data(); out->mutable_data(context.GetPlace()); float* output = out->data(); - xpu::Pooling_t pool_type = XPUPoolingType(pooling_type, exclusive, is_test); auto& dev_ctx = context.template device_context(); - int r = xpu::pooling_forward( - dev_ctx.x_context(), input, output, index_data, pool_type, c, in_h, - in_w, pad_left, pad_right, pad_up, pad_down, win_h, win_w, stride_h, - stride_w, out_h, out_w); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "The pool2d XPU API return wrong value[%d], please check " - "where Baidu Kunlun Card is properly installed.", - r)); + int r = xpu::Error_t::SUCCESS; + if (pooling_type == "max") { + r = xpu::max_pool2d(dev_ctx.x_context(), input, output, index_data, n, c, + in_h, in_w, ksize, strides, paddings, true); + } else if (pooling_type == "avg") { + r = xpu::avg_pool2d(dev_ctx.x_context(), input, output, n, c, in_h, in_w, + ksize, strides, paddings, !exclusive, true); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported pooling type for kunlun ", pooling_type)); + } + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The pool2d XPU API return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); } }; + template class PoolGradXPUKernel : public framework::OpKernel { public: @@ -126,47 +122,33 @@ class PoolGradXPUKernel : public framework::OpKernel { if (!in_x_grad) { return; } - const int c = in_x->dims()[0] * in_x->dims()[1]; + const int n = in_x->dims()[0]; + const int c = in_x->dims()[1]; const int in_h = in_x->dims()[2]; const int in_w = in_x->dims()[3]; - const int out_h = out->dims()[2]; - const int out_w = out->dims()[3]; - const int win_h = ksize[0]; - const int win_w = ksize[1]; - const int stride_h = strides[0]; - const int stride_w = strides[1]; - const int pad_up = paddings[0]; - const int pad_down = paddings[0]; - const int pad_left = paddings[1]; - const int pad_right = paddings[1]; const float* input = in_x->data(); const float* output = out->data(); const float* output_grad = out_grad->data(); in_x_grad->mutable_data(context.GetPlace()); float* input_grad = in_x_grad->data(); - xpu::Pooling_t pool_type = XPUPoolingType(pooling_type, exclusive, false); auto& dev_ctx = context.template device_context(); - // Need to init memory in the first place - const int zero = 0; - int r = - xpu::memset(dev_ctx.x_context(), reinterpret_cast(input_grad), - zero, in_x_grad->numel() * sizeof(float)); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "The Pool2d XPU OP return wrong value[%d], please check " - "where Baidu Kunlun Card is properly installed.", - r)); - r = xpu::pooling_backward(dev_ctx.x_context(), input, output, index_data, - output_grad, input_grad, pool_type, c, in_h, in_w, - pad_left, pad_right, pad_up, pad_down, win_h, - win_w, stride_h, stride_w, out_h, out_w); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "The Pool2d XPU OP return wrong value[%d], please check " - "where Baidu Kunlun Card is properly installed.", - r)); + int r = xpu::Error_t::SUCCESS; + if (pooling_type == "max") { + r = xpu::max_pool2d_grad(dev_ctx.x_context(), input, output, index_data, + output_grad, input_grad, n, c, in_h, in_w, ksize, + strides, paddings, true); + } else if (pooling_type == "avg") { + r = xpu::avg_pool2d_grad(dev_ctx.x_context(), input, output, output_grad, + input_grad, n, c, in_h, in_w, ksize, strides, + paddings, !exclusive, true); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported pooling type for kunlun ", pooling_type)); + } + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The Pool2dGrad XPU OP return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); } }; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index fb94768984fcfb4b886e4805f8328fe76a7b3625..d9e9443e75292d8062a324501618c48754136db4 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -172,16 +172,7 @@ Place CPUDeviceContext::GetPlace() const { return place_; } #ifdef PADDLE_WITH_XPU XPUDeviceContext::XPUDeviceContext() { context_ = xpu::create_context(); } -XPUDeviceContext::~XPUDeviceContext() { - xpu::destroy_context(context_); - void* l3ptr = nullptr; - int l3_size = 13.5 * 1024 * 1024; - xpu_malloc(static_cast(&l3ptr), l3_size, XPU_MEM_L3); - if (l3ptr != nullptr) { - context_->_l3_mgr.set(l3ptr, l3_size); - std::cout << "set l3 size " << l3_size << std::endl; - } -} +XPUDeviceContext::~XPUDeviceContext() {} XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) { int dev_id = -1; diff --git a/python/paddle/fluid/tests/unittests/xpu/test_pool2d_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_pool2d_op_xpu.py index 7f20c83aacb1f51feb256a1de01d22382079bedc..bebb5c762649145cab666633ac91371ab679f551 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_pool2d_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_pool2d_op_xpu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,16 +13,20 @@ # limitations under the License. from __future__ import print_function +from __future__ import division import sys sys.path.append("..") -import paddle.fluid.core as core import unittest import numpy as np -from op_test import OpTest -import paddle + +import paddle.fluid.core as core +from op_test_xpu import XPUOpTest import paddle.fluid as fluid from paddle.fluid import Program, program_guard +import paddle + +paddle.enable_static() def max_pool2D_forward_naive(x, @@ -241,7 +245,7 @@ def pool2D_forward_naive(x, return out -class TestPool2D_Op(OpTest): +class TestPool2D_Op(XPUOpTest): def setUp(self): self.op_type = "pool2d" self.use_cudnn = False @@ -265,7 +269,7 @@ class TestPool2D_Op(OpTest): input, self.ksize, self.strides, self.paddings, self.global_pool, self.ceil_mode, self.exclusive, self.adaptive, self.data_format, self.pool_type, self.padding_algorithm).astype(self.dtype) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} + self.inputs = {'X': XPUOpTest.np_dtype_to_fluid_dtype(input)} self.attrs = { 'strides': self.strides, @@ -284,18 +288,20 @@ class TestPool2D_Op(OpTest): self.outputs = {'Out': output} + def has_xpu(self): + return core.is_compiled_with_xpu() + def test_check_output(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) + if self.has_xpu(): + place = core.XPUPlace(0) self.check_output_with_place(place) + return def test_check_grad(self): - if paddle.is_compiled_with_xpu(): - paddle.enable_static() - place = paddle.XPUPlace(0) - self.check_grad_with_place( - place, set(['X']), 'Out', max_relative_error=0.07) + if self.has_xpu(): + place = core.XPUPlace(0) + self.check_grad_with_place(place, set(['X']), 'Out') + return def init_data_format(self): self.data_format = "NCHW" @@ -315,7 +321,7 @@ class TestPool2D_Op(OpTest): self.use_cudnn = False def init_data_type(self): - self.dtype = np.float64 + self.dtype = np.float32 def init_pool_type(self): self.pool_type = "avg" @@ -334,5 +340,134 @@ class TestPool2D_Op(OpTest): self.adaptive = False +class TestCase1(TestPool2D_Op): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + + def init_paddings(self): + self.paddings = [0, 0] + + def init_pool_type(self): + self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + + def init_global_pool(self): + self.global_pool = False + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase2(TestPool2D_Op): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + + def init_paddings(self): + self.paddings = [1, 1] + + def init_pool_type(self): + self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + + def init_global_pool(self): + self.global_pool = False + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase3(TestPool2D_Op): + def init_pool_type(self): + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + + +class TestCase4(TestCase1): + def init_pool_type(self): + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + + +class TestCase5(TestCase2): + def init_pool_type(self): + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + + +class TestPool2D_AsyPadding(TestPool2D_Op): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 0, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 5, 5] + + +class TestCase1_AsyPadding(TestCase1): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 0, 1, 0] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase2_AsyPadding(TestCase2): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 2, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase3_AsyPadding(TestCase3): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 0, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 5, 5] + + +class TestCase4_AsyPadding(TestCase4): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 0, 1, 0] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase5_AsyPadding((TestCase5)): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [2, 2, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestAvgInclude_AsyPadding(TestCase2): + def init_exclusive(self): + self.exclusive = False + + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 2, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + if __name__ == '__main__': unittest.main()