Unverified commit c446ab7b, authored by zhangyikun02, committed by GitHub

bugfix for conv_op_xpu in NHWC data_format and update xpu.cmake, test=kunlun (#44296)

Parent: b7287d2b
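Context for the changes below, inferred from the diff itself (hedged; not part of the original commit message): Paddle keeps conv filters in FCHW order, while the XDNN conv2d/conv2d_grad calls on the NHWC path appear to expect the filter in FHWC order. The kernels therefore transpose the filter with perm {0, 2, 3, 1} before the call and transpose the filter gradient back with the inverse perm {0, 3, 1, 2}. A minimal NumPy sketch of this layout round trip, with hypothetical shapes:

import numpy as np

# Hypothetical FCHW filter: F=8 output channels, C=3 input channels, 5x5 kernel.
f_fchw = np.random.rand(8, 3, 5, 5)
f_fhwc = f_fchw.transpose(0, 2, 3, 1)  # perm {0, 2, 3, 1}: FCHW -> FHWC
f_back = f_fhwc.transpose(0, 3, 1, 2)  # perm {0, 3, 1, 2}: FHWC -> FCHW
assert np.array_equal(f_back, f_fchw)  # the two perms are exact inverses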
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220708")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220712")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
@@ -19,7 +19,7 @@ endif()
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   set(XPU_XDNN_BASE_URL_WITHOUT_DATE
       "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220708")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220712")
 else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
...
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_op.h"
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
 #ifdef PADDLE_WITH_XPU
 namespace paddle {
 namespace operators {
@@ -71,9 +72,26 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
     XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());
     auto &dev_ctx = context.template device_context<DeviceContext>();
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    XPUT *filter_data_tmp;
+    const XPUT *filter_data_ptr = filter_data;
+    if (data_format == "NHWC") {
+      filter_data_tmp = RAII_GUARD.alloc<XPUT>(filter.numel());
+      PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp);
+      std::vector<int> filter_shape = phi::vectorize<int>(filter.dims());
+      int r = xpu::transpose<XPUT>(dev_ctx.x_context(),
+                                   filter_data,
+                                   filter_data_tmp,
+                                   filter_shape,
+                                   {0, 2, 3, 1});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+      filter_data_ptr = reinterpret_cast<const XPUT *>(filter_data_tmp);
+    }
     int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
                                                    input_data,
-                                                   filter_data,
+                                                   filter_data_ptr,
                                                    output_data,
                                                    batch_size,
                                                    img_c,
@@ -89,11 +107,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
                                                    nullptr,
                                                    nullptr,
                                                    is_nchw);
-    PADDLE_ENFORCE_EQ(
-        r,
-        XPU_SUCCESS,
-        platform::errors::External(
-            "XPU conv kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
   }
 };
@@ -134,6 +148,7 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
     framework::DDim filter_data_dims =
         phi::slice_ddim(filter.dims(), 2, filter.dims().size());
     std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
+    std::vector<int> filter_shape = phi::vectorize<int>(filter.dims());
     UpdatePaddingAndDilation(
         &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
@@ -165,12 +180,35 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
       filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
     }
     auto &dev_ctx = context.template device_context<DeviceContext>();
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    XPUT *filter_data_tmp;
+    XPUT *filter_grad_data_tmp;
+    const XPUT *filter_data_ptr = filter_data;
+    XPUT *filter_grad_data_ptr = filter_grad_data;
+    if (data_format == "NHWC") {
+      filter_data_tmp = RAII_GUARD.alloc<XPUT>(filter.numel());
+      PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp);
+      int r = xpu::transpose<XPUT>(dev_ctx.x_context(),
+                                   filter_data,
+                                   filter_data_tmp,
+                                   filter_shape,
+                                   {0, 2, 3, 1});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+      filter_data_ptr = reinterpret_cast<const XPUT *>(filter_data_tmp);
+      if (filter_grad_data != nullptr) {
+        filter_grad_data_tmp = RAII_GUARD.alloc<XPUT>(filter.numel());
+        PADDLE_ENFORCE_XDNN_NOT_NULL(filter_grad_data_tmp);
+        filter_grad_data_ptr = filter_grad_data_tmp;
+      }
+    }
     int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
                                                         input_data,
-                                                        filter_data,
+                                                        filter_data_ptr,
                                                         output_grad_data,
                                                         input_grad_data,
-                                                        filter_grad_data,
+                                                        filter_grad_data_ptr,
                                                         batch_size,
                                                         img_c,
                                                         img_h,
@@ -187,11 +225,18 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
                                                         nullptr,
                                                         nullptr,
                                                         is_nchw);
-    PADDLE_ENFORCE_EQ(
-        r,
-        XPU_SUCCESS,
-        platform::errors::External(
-            "XPU conv kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
+
+    if ((filter_grad_data_ptr != nullptr) && (data_format == "NHWC")) {
+      std::vector<int> filter_shape_fhwc = {
+          filter_shape[0], filter_shape[2], filter_shape[3], filter_shape[1]};
+      int r = xpu::transpose<XPUT>(dev_ctx.x_context(),
+                                   filter_grad_data_ptr,
+                                   filter_grad_data,
+                                   filter_shape_fhwc,
+                                   {0, 3, 1, 2});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+    }
   }
 };
 }  // namespace operators
...
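A note on the backward hunk above (our reading of the diff, hedged): when data_format == "NHWC" and a filter gradient is requested, conv2d_grad writes it into the scratch buffer filter_grad_data_tmp in FHWC layout, and the trailing transpose with filter_shape_fhwc = {F, H, W, C} and perm {0, 3, 1, 2} copies it back into the user-visible FCHW filter_grad tensor. A short NumPy sketch with a hypothetical shape:

import numpy as np

# Gradient as produced on the NHWC path (FHWC), e.g. F=8, H=5, W=5, C=3.
g_fhwc = np.random.rand(8, 5, 5, 3)
g_fchw = g_fhwc.transpose(0, 3, 1, 2)  # perm {0, 3, 1, 2}: FHWC -> FCHW
assert g_fchw.shape == (8, 3, 5, 5)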
@@ -498,10 +498,41 @@ class XPUTestConv2DOp_v2(XPUOpTestWrapper):
         self.padding_algorithm = "EXPLICIT"
 
+class XPUTestConv2DOp_NHWC(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'conv2d'
+        self.use_dynamic_create_class = False
+
+    class TestConv2DOp_AsyPadding_NHWC(
+            XPUTestConv2DOp_v2.TestConv2DOp_AsyPadding):
+
+        def init_data_format(self):
+            self.data_format = "NHWC"
+
+        def init_test_case_2(self):
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+
+    class TestWithPad_AsyPadding_NHWC(
+            XPUTestConv2DOp_v2.TestWithPad_AsyPadding):
+
+        def init_data_format(self):
+            self.data_format = "NHWC"
+
+        def init_test_case_2(self):
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+
 support_types = get_xpu_op_support_types('conv2d')
 for stype in ['float32']:
     create_test_class(globals(), XPUTestConv2DOp, stype)
     create_test_class(globals(), XPUTestConv2DOp_v2, stype)
+    create_test_class(globals(),
+                      XPUTestConv2DOp_NHWC,
+                      stype,
+                      ignore_deivce_version=[core.XPUVersion.XPU1])
 
 #---------- test SAME VALID -----------
 #create_test_padding_SAME_class(TestConv2DOp_AsyPadding)
@@ -512,9 +543,5 @@ for stype in ['float32']:
 #create_test_padding_VALID_class(TestWithPad_AsyPadding)
 #create_test_padding_VALID_class(TestWithStride_AsyPadding)
 
-# ------------ test channel last ---------
-#create_test_channel_last_class(TestConv2DOp_AsyPadding)
-#create_test_channel_last_class(TestWithPad_AsyPadding)
-
 if __name__ == '__main__':
     unittest.main()