Unverified commit c446ab7b, authored by zhangyikun02, committed by GitHub

bugfix for conv_op_xpu in NHWC data_format and update xpu.cmake, test=kunlun (#44296)

Parent: b7287d2b
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220708")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220712")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
@@ -19,7 +19,7 @@ endif()
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   set(XPU_XDNN_BASE_URL_WITHOUT_DATE
       "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220708")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220712")
 else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_op.h"
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
 #ifdef PADDLE_WITH_XPU
 namespace paddle {
 namespace operators {
@@ -71,9 +72,26 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());
     auto &dev_ctx = context.template device_context<DeviceContext>();
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    XPUT *filter_data_tmp;
+    const XPUT *filter_data_ptr = filter_data;
+    if (data_format == "NHWC") {
+      filter_data_tmp = RAII_GUARD.alloc<XPUT>(filter.numel());
+      PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp);
+      std::vector<int> filter_shape = phi::vectorize<int>(filter.dims());
+      int r = xpu::transpose<XPUT>(dev_ctx.x_context(),
+                                   filter_data,
+                                   filter_data_tmp,
+                                   filter_shape,
+                                   {0, 2, 3, 1});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+      filter_data_ptr = reinterpret_cast<const XPUT *>(filter_data_tmp);
+    }
     int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
                                                    input_data,
-                                                   filter_data,
+                                                   filter_data_ptr,
                                                    output_data,
                                                    batch_size,
                                                    img_c,
@@ -89,11 +107,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
                                                    nullptr,
                                                    nullptr,
                                                    is_nchw);
-    PADDLE_ENFORCE_EQ(
-        r,
-        XPU_SUCCESS,
-        platform::errors::External(
-            "XPU conv kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
   }
 };
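In short, the forward-path change works like this: Paddle stores conv2d filters in FCHW order regardless of `data_format`, so when the input is NHWC the kernel now transposes the filter into an FHWC scratch buffer before calling `xpu::conv2d`. A minimal numpy sketch of that permutation (numpy and the concrete shape are illustrative only, not part of the commit):

```python
import numpy as np

# Paddle keeps conv2d filters as [F, C, H, W] regardless of data_format.
filter_fchw = np.random.rand(8, 3, 5, 5).astype(np.float32)

# {0, 2, 3, 1} -- the permutation handed to xpu::transpose above --
# reorders the filter to [F, H, W, C] for the NHWC code path.
filter_fhwc = filter_fchw.transpose(0, 2, 3, 1)
assert filter_fhwc.shape == (8, 5, 5, 3)
```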
@@ -134,6 +148,7 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     framework::DDim filter_data_dims =
         phi::slice_ddim(filter.dims(), 2, filter.dims().size());
     std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
+    std::vector<int> filter_shape = phi::vectorize<int>(filter.dims());
     UpdatePaddingAndDilation(
         &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
@@ -165,12 +180,35 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
       filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
     }
     auto &dev_ctx = context.template device_context<DeviceContext>();
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    XPUT *filter_data_tmp;
+    XPUT *filter_grad_data_tmp;
+    const XPUT *filter_data_ptr = filter_data;
+    XPUT *filter_grad_data_ptr = filter_grad_data;
+    if (data_format == "NHWC") {
+      filter_data_tmp = RAII_GUARD.alloc<XPUT>(filter.numel());
+      PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp);
+      int r = xpu::transpose<XPUT>(dev_ctx.x_context(),
+                                   filter_data,
+                                   filter_data_tmp,
+                                   filter_shape,
+                                   {0, 2, 3, 1});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+      filter_data_ptr = reinterpret_cast<const XPUT *>(filter_data_tmp);
+      if (filter_grad_data != nullptr) {
+        filter_grad_data_tmp = RAII_GUARD.alloc<XPUT>(filter.numel());
+        PADDLE_ENFORCE_XDNN_NOT_NULL(filter_grad_data_tmp);
+        filter_grad_data_ptr = filter_grad_data_tmp;
+      }
+    }
     int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
                                                         input_data,
-                                                        filter_data,
+                                                        filter_data_ptr,
                                                         output_grad_data,
                                                         input_grad_data,
-                                                        filter_grad_data,
+                                                        filter_grad_data_ptr,
                                                         batch_size,
                                                         img_c,
                                                         img_h,
@@ -187,11 +225,18 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
                                                         nullptr,
                                                         nullptr,
                                                         is_nchw);
-    PADDLE_ENFORCE_EQ(
-        r,
-        XPU_SUCCESS,
-        platform::errors::External(
-            "XPU conv kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
+    if ((filter_grad_data_ptr != nullptr) && (data_format == "NHWC")) {
+      std::vector<int> filter_shape_fhwc = {
+          filter_shape[0], filter_shape[2], filter_shape[3], filter_shape[1]};
+      int r = xpu::transpose<XPUT>(dev_ctx.x_context(),
+                                   filter_grad_data_ptr,
+                                   filter_grad_data,
+                                   filter_shape_fhwc,
+                                   {0, 3, 1, 2});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+    }
   }
 };
 }  // namespace operators
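The gradient kernel mirrors the forward fix: `xpu::conv2d_grad` writes the filter gradient in FHWC order into a scratch buffer, and the trailing transpose applies the inverse permutation {0, 3, 1, 2} to return it to the FCHW layout Paddle expects. A quick illustrative check (numpy, not part of the commit) that the two permutations are mutual inverses:

```python
import numpy as np

fwd = (0, 2, 3, 1)  # FCHW -> FHWC, applied to the filter before the conv calls
bwd = (0, 3, 1, 2)  # FHWC -> FCHW, applied to the filter gradient afterwards

# Composing the two permutations gives the identity, so the round trip
# restores the original layout exactly.
assert [fwd[i] for i in bwd] == [0, 1, 2, 3]

grad_fchw = np.random.rand(8, 3, 5, 5)
assert np.array_equal(grad_fchw.transpose(fwd).transpose(bwd), grad_fchw)
```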
@@ -498,10 +498,41 @@ class XPUTestConv2DOp_v2(XPUOpTestWrapper):
             self.padding_algorithm = "EXPLICIT"
+
+
+class XPUTestConv2DOp_NHWC(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'conv2d'
+        self.use_dynamic_create_class = False
+
+    class TestConv2DOp_AsyPadding_NHWC(
+            XPUTestConv2DOp_v2.TestConv2DOp_AsyPadding):
+
+        def init_data_format(self):
+            self.data_format = "NHWC"
+
+        def init_test_case_2(self):
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+
+    class TestWithPad_AsyPadding_NHWC(
+            XPUTestConv2DOp_v2.TestWithPad_AsyPadding):
+
+        def init_data_format(self):
+            self.data_format = "NHWC"
+
+        def init_test_case_2(self):
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+
+
 support_types = get_xpu_op_support_types('conv2d')
 for stype in ['float32']:
     create_test_class(globals(), XPUTestConv2DOp, stype)
     create_test_class(globals(), XPUTestConv2DOp_v2, stype)
+    create_test_class(globals(),
+                      XPUTestConv2DOp_NHWC,
+                      stype,
+                      ignore_deivce_version=[core.XPUVersion.XPU1])
 #---------- test SAME VALID -----------
 #create_test_padding_SAME_class(TestConv2DOp_AsyPadding)
@@ -512,9 +543,5 @@ for stype in ['float32']:
 #create_test_padding_VALID_class(TestWithPad_AsyPadding)
 #create_test_padding_VALID_class(TestWithStride_AsyPadding)
-# ------------ test channel last ---------
-#create_test_channel_last_class(TestConv2DOp_AsyPadding)
-#create_test_channel_last_class(TestWithPad_AsyPadding)
 if __name__ == '__main__':
     unittest.main()
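For context, the new NHWC cases simply subclass the existing asymmetric-padding tests and override only the data format and input layout; `create_test_class` then materializes one concrete unittest per supported dtype, skipping XPU1 devices. A simplified, hand-rolled sketch of that dynamic-registration pattern (an illustration of the idea, not Paddle's actual `create_test_class` implementation):

```python
import inspect


def register_xpu_tests(scope, wrapper_cls, dtype):
    """Materialize each Test* class nested in an op-test wrapper."""
    for name, cls in inspect.getmembers(wrapper_cls, inspect.isclass):
        if not name.startswith('Test'):
            continue
        # Derive one concrete test class per (case, dtype) pair and
        # publish it in the caller's globals so unittest discovers it.
        concrete = type('{0}_{1}'.format(name, dtype), (cls,),
                        {'in_type': dtype})
        scope[concrete.__name__] = concrete
```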