未验证 提交 8489d4f7 编写于 作者: Q QingshuChen 提交者: GitHub

optimize batch_norm & pool op for kunlun (#30490)

上级 bd971922
......@@ -139,16 +139,14 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
auto* dscale_data = dscale->mutable_data<T>(ctx.GetPlace());
auto* dbias_data = dbias->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::batch_norm_backward(dev_ctx.x_context(), N, C, H, W, x_data,
dy_data, scale_data, saved_mean_data,
saved_inv_variance_data, dx_data,
dscale_data, dbias_data);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_infer_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
int r = xpu::batch_norm_grad<T>(dev_ctx.x_context(), x_data, dy_data,
dx_data, N, C, H, W, scale_data,
saved_mean_data, saved_inv_variance_data,
dscale_data, dbias_data, true);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(batch_norm_grad) return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
......
......@@ -30,6 +30,7 @@ xpu::Pooling_t XPUPoolingType(const std::string& pooltype, bool exclusive,
"Pool op only supports 2D and 3D input."));
}
}
template <typename DeviceContext, typename T>
class PoolXPUKernel : public framework::OpKernel<T> {
public:
......@@ -41,7 +42,6 @@ class PoolXPUKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
bool is_test = context.Attr<bool>("is_test");
bool adaptive = context.Attr<bool>("adaptive");
PADDLE_ENFORCE_EQ(
ksize.size(), 2,
......@@ -60,36 +60,32 @@ class PoolXPUKernel : public framework::OpKernel<T> {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
const int c = in_x->dims()[0] * in_x->dims()[1];
const int n = in_x->dims()[0];
const int c = in_x->dims()[1];
const int in_h = in_x->dims()[2];
const int in_w = in_x->dims()[3];
const int out_h = out->dims()[2];
const int out_w = out->dims()[3];
const int win_h = ksize[0];
const int win_w = ksize[1];
const int stride_h = strides[0];
const int stride_w = strides[1];
const int pad_up = paddings[0];
const int pad_down = paddings[0];
const int pad_left = paddings[1];
const int pad_right = paddings[1];
const float* input = in_x->data<float>();
out->mutable_data<T>(context.GetPlace());
float* output = out->data<float>();
xpu::Pooling_t pool_type = XPUPoolingType(pooling_type, exclusive, is_test);
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::pooling_forward<float, float>(
dev_ctx.x_context(), input, output, index_data, pool_type, c, in_h,
in_w, pad_left, pad_right, pad_up, pad_down, win_h, win_w, stride_h,
stride_w, out_h, out_w);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The pool2d XPU API return wrong value[%d], please check "
"where Baidu Kunlun Card is properly installed.",
r));
int r = xpu::Error_t::SUCCESS;
if (pooling_type == "max") {
r = xpu::max_pool2d(dev_ctx.x_context(), input, output, index_data, n, c,
in_h, in_w, ksize, strides, paddings, true);
} else if (pooling_type == "avg") {
r = xpu::avg_pool2d(dev_ctx.x_context(), input, output, n, c, in_h, in_w,
ksize, strides, paddings, !exclusive, true);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The pool2d XPU API return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
};
template <typename DeviceContext, typename T>
class PoolGradXPUKernel : public framework::OpKernel<T> {
public:
......@@ -126,47 +122,33 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
if (!in_x_grad) {
return;
}
const int c = in_x->dims()[0] * in_x->dims()[1];
const int n = in_x->dims()[0];
const int c = in_x->dims()[1];
const int in_h = in_x->dims()[2];
const int in_w = in_x->dims()[3];
const int out_h = out->dims()[2];
const int out_w = out->dims()[3];
const int win_h = ksize[0];
const int win_w = ksize[1];
const int stride_h = strides[0];
const int stride_w = strides[1];
const int pad_up = paddings[0];
const int pad_down = paddings[0];
const int pad_left = paddings[1];
const int pad_right = paddings[1];
const float* input = in_x->data<float>();
const float* output = out->data<float>();
const float* output_grad = out_grad->data<float>();
in_x_grad->mutable_data<T>(context.GetPlace());
float* input_grad = in_x_grad->data<float>();
xpu::Pooling_t pool_type = XPUPoolingType(pooling_type, exclusive, false);
auto& dev_ctx = context.template device_context<DeviceContext>();
// Need to init memory in the first place
const int zero = 0;
int r =
xpu::memset(dev_ctx.x_context(), reinterpret_cast<void**>(input_grad),
zero, in_x_grad->numel() * sizeof(float));
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The Pool2d XPU OP return wrong value[%d], please check "
"where Baidu Kunlun Card is properly installed.",
r));
r = xpu::pooling_backward(dev_ctx.x_context(), input, output, index_data,
output_grad, input_grad, pool_type, c, in_h, in_w,
pad_left, pad_right, pad_up, pad_down, win_h,
win_w, stride_h, stride_w, out_h, out_w);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The Pool2d XPU OP return wrong value[%d], please check "
"where Baidu Kunlun Card is properly installed.",
r));
int r = xpu::Error_t::SUCCESS;
if (pooling_type == "max") {
r = xpu::max_pool2d_grad(dev_ctx.x_context(), input, output, index_data,
output_grad, input_grad, n, c, in_h, in_w, ksize,
strides, paddings, true);
} else if (pooling_type == "avg") {
r = xpu::avg_pool2d_grad(dev_ctx.x_context(), input, output, output_grad,
input_grad, n, c, in_h, in_w, ksize, strides,
paddings, !exclusive, true);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The Pool2dGrad XPU OP return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
};
......
......@@ -172,16 +172,7 @@ Place CPUDeviceContext::GetPlace() const { return place_; }
#ifdef PADDLE_WITH_XPU
XPUDeviceContext::XPUDeviceContext() { context_ = xpu::create_context(); }
XPUDeviceContext::~XPUDeviceContext() {
xpu::destroy_context(context_);
void* l3ptr = nullptr;
int l3_size = 13.5 * 1024 * 1024;
xpu_malloc(static_cast<void**>(&l3ptr), l3_size, XPU_MEM_L3);
if (l3ptr != nullptr) {
context_->_l3_mgr.set(l3ptr, l3_size);
std::cout << "set l3 size " << l3_size << std::endl;
}
}
XPUDeviceContext::~XPUDeviceContext() {}
XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
int dev_id = -1;
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,16 +13,20 @@
# limitations under the License.
from __future__ import print_function
from __future__ import division
import sys
sys.path.append("..")
import paddle.fluid.core as core
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle
paddle.enable_static()
def max_pool2D_forward_naive(x,
......@@ -241,7 +245,7 @@ def pool2D_forward_naive(x,
return out
class TestPool2D_Op(OpTest):
class TestPool2D_Op(XPUOpTest):
def setUp(self):
self.op_type = "pool2d"
self.use_cudnn = False
......@@ -265,7 +269,7 @@ class TestPool2D_Op(OpTest):
input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive, self.adaptive, self.data_format,
self.pool_type, self.padding_algorithm).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.inputs = {'X': XPUOpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = {
'strides': self.strides,
......@@ -284,18 +288,20 @@ class TestPool2D_Op(OpTest):
self.outputs = {'Out': output}
def has_xpu(self):
return core.is_compiled_with_xpu()
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
if self.has_xpu():
place = core.XPUPlace(0)
self.check_output_with_place(place)
return
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, set(['X']), 'Out', max_relative_error=0.07)
if self.has_xpu():
place = core.XPUPlace(0)
self.check_grad_with_place(place, set(['X']), 'Out')
return
def init_data_format(self):
self.data_format = "NCHW"
......@@ -315,7 +321,7 @@ class TestPool2D_Op(OpTest):
self.use_cudnn = False
def init_data_type(self):
self.dtype = np.float64
self.dtype = np.float32
def init_pool_type(self):
self.pool_type = "avg"
......@@ -334,5 +340,134 @@ class TestPool2D_Op(OpTest):
self.adaptive = False
class TestCase1(TestPool2D_Op):
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
def init_paddings(self):
self.paddings = [0, 0]
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = False
def init_shape(self):
self.shape = [2, 3, 7, 7]
class TestCase2(TestPool2D_Op):
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
def init_paddings(self):
self.paddings = [1, 1]
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = False
def init_shape(self):
self.shape = [2, 3, 7, 7]
class TestCase3(TestPool2D_Op):
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
class TestCase4(TestCase1):
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
class TestCase5(TestCase2):
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
class TestPool2D_AsyPadding(TestPool2D_Op):
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 0, 1, 2]
def init_shape(self):
self.shape = [2, 3, 5, 5]
class TestCase1_AsyPadding(TestCase1):
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 0, 1, 0]
def init_shape(self):
self.shape = [2, 3, 7, 7]
class TestCase2_AsyPadding(TestCase2):
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 2, 1, 2]
def init_shape(self):
self.shape = [2, 3, 7, 7]
class TestCase3_AsyPadding(TestCase3):
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 0, 1, 2]
def init_shape(self):
self.shape = [2, 3, 5, 5]
class TestCase4_AsyPadding(TestCase4):
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 0, 1, 0]
def init_shape(self):
self.shape = [2, 3, 7, 7]
class TestCase5_AsyPadding((TestCase5)):
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [2, 2, 1, 2]
def init_shape(self):
self.shape = [2, 3, 7, 7]
class TestAvgInclude_AsyPadding(TestCase2):
def init_exclusive(self):
self.exclusive = False
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 2, 1, 2]
def init_shape(self):
self.shape = [2, 3, 7, 7]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册