Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c446ab7b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c446ab7b
编写于
7月 14, 2022
作者:
Z
zhangyikun02
提交者:
GitHub
7月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bugfix for conv_op_xpu in NHWC data_formate and update xpu.cmake, test=kunlun (#44296)
上级
b7287d2b
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
91 addition
and
19 deletion
+91
-19
cmake/external/xpu.cmake
cmake/external/xpu.cmake
+2
-2
paddle/fluid/operators/conv_op_xpu.cc
paddle/fluid/operators/conv_op_xpu.cc
+58
-13
python/paddle/fluid/tests/unittests/xpu/test_conv2d_op_xpu.py
...on/paddle/fluid/tests/unittests/xpu/test_conv2d_op_xpu.py
+31
-4
未找到文件。
cmake/external/xpu.cmake
浏览文件 @
c446ab7b
...
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
...
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
if
(
NOT DEFINED XPU_BASE_URL
)
if
(
NOT DEFINED XPU_BASE_URL
)
set
(
XPU_BASE_URL_WITHOUT_DATE
set
(
XPU_BASE_URL_WITHOUT_DATE
"https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev"
)
"https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev"
)
set
(
XPU_BASE_URL
"
${
XPU_BASE_URL_WITHOUT_DATE
}
/202207
08
"
)
set
(
XPU_BASE_URL
"
${
XPU_BASE_URL_WITHOUT_DATE
}
/202207
12
"
)
else
()
else
()
set
(
XPU_BASE_URL
"
${
XPU_BASE_URL
}
"
)
set
(
XPU_BASE_URL
"
${
XPU_BASE_URL
}
"
)
endif
()
endif
()
...
@@ -19,7 +19,7 @@ endif()
...
@@ -19,7 +19,7 @@ endif()
if
(
NOT DEFINED XPU_XDNN_BASE_URL
)
if
(
NOT DEFINED XPU_XDNN_BASE_URL
)
set
(
XPU_XDNN_BASE_URL_WITHOUT_DATE
set
(
XPU_XDNN_BASE_URL_WITHOUT_DATE
"https://klx-sdk-release-public.su.bcebos.com/xdnn/dev"
)
"https://klx-sdk-release-public.su.bcebos.com/xdnn/dev"
)
set
(
XPU_XDNN_BASE_URL
"
${
XPU_XDNN_BASE_URL_WITHOUT_DATE
}
/202207
08
"
)
set
(
XPU_XDNN_BASE_URL
"
${
XPU_XDNN_BASE_URL_WITHOUT_DATE
}
/202207
12
"
)
else
()
else
()
set
(
XPU_XDNN_BASE_URL
"
${
XPU_XDNN_BASE_URL
}
"
)
set
(
XPU_XDNN_BASE_URL
"
${
XPU_XDNN_BASE_URL
}
"
)
endif
()
endif
()
...
...
paddle/fluid/operators/conv_op_xpu.cc
浏览文件 @
c446ab7b
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#ifdef PADDLE_WITH_XPU
#ifdef PADDLE_WITH_XPU
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -71,9 +72,26 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
...
@@ -71,9 +72,26 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
XPUT
*
output_data
=
reinterpret_cast
<
XPUT
*>
(
output
->
data
<
T
>
());
XPUT
*
output_data
=
reinterpret_cast
<
XPUT
*>
(
output
->
data
<
T
>
());
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
XPUT
*
filter_data_tmp
;
const
XPUT
*
filter_data_ptr
=
filter_data
;
if
(
data_format
==
"NHWC"
)
{
filter_data_tmp
=
RAII_GUARD
.
alloc
<
XPUT
>
(
filter
.
numel
());
PADDLE_ENFORCE_XDNN_NOT_NULL
(
filter_data_tmp
);
std
::
vector
<
int
>
filter_shape
=
phi
::
vectorize
<
int
>
(
filter
.
dims
());
int
r
=
xpu
::
transpose
<
XPUT
>
(
dev_ctx
.
x_context
(),
filter_data
,
filter_data_tmp
,
filter_shape
,
{
0
,
2
,
3
,
1
});
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"transpose"
);
filter_data_ptr
=
reinterpret_cast
<
const
XPUT
*>
(
filter_data_tmp
);
}
int
r
=
xpu
::
conv2d
<
XPUT
,
XPUT
,
XPUT
,
int16_t
>
(
dev_ctx
.
x_context
(),
int
r
=
xpu
::
conv2d
<
XPUT
,
XPUT
,
XPUT
,
int16_t
>
(
dev_ctx
.
x_context
(),
input_data
,
input_data
,
filter_data
,
filter_data
_ptr
,
output_data
,
output_data
,
batch_size
,
batch_size
,
img_c
,
img_c
,
...
@@ -89,11 +107,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
...
@@ -89,11 +107,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
nullptr
,
nullptr
,
nullptr
,
nullptr
,
is_nchw
);
is_nchw
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"conv2d"
);
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU conv kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
}
};
};
...
@@ -134,6 +148,7 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
...
@@ -134,6 +148,7 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
framework
::
DDim
filter_data_dims
=
framework
::
DDim
filter_data_dims
=
phi
::
slice_ddim
(
filter
.
dims
(),
2
,
filter
.
dims
().
size
());
phi
::
slice_ddim
(
filter
.
dims
(),
2
,
filter
.
dims
().
size
());
std
::
vector
<
int
>
ksize
=
phi
::
vectorize
<
int
>
(
filter_data_dims
);
std
::
vector
<
int
>
ksize
=
phi
::
vectorize
<
int
>
(
filter_data_dims
);
std
::
vector
<
int
>
filter_shape
=
phi
::
vectorize
<
int
>
(
filter
.
dims
());
UpdatePaddingAndDilation
(
UpdatePaddingAndDilation
(
&
paddings
,
&
dilations
,
padding_algorithm
,
in_data_dims
,
strides
,
ksize
);
&
paddings
,
&
dilations
,
padding_algorithm
,
in_data_dims
,
strides
,
ksize
);
...
@@ -165,12 +180,35 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
...
@@ -165,12 +180,35 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
filter_grad_data
=
reinterpret_cast
<
XPUT
*>
(
filter_grad
->
data
<
T
>
());
filter_grad_data
=
reinterpret_cast
<
XPUT
*>
(
filter_grad
->
data
<
T
>
());
}
}
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
XPUT
*
filter_data_tmp
;
XPUT
*
filter_grad_data_tmp
;
const
XPUT
*
filter_data_ptr
=
filter_data
;
XPUT
*
filter_grad_data_ptr
=
filter_grad_data
;
if
(
data_format
==
"NHWC"
)
{
filter_data_tmp
=
RAII_GUARD
.
alloc
<
XPUT
>
(
filter
.
numel
());
PADDLE_ENFORCE_XDNN_NOT_NULL
(
filter_data_tmp
);
int
r
=
xpu
::
transpose
<
XPUT
>
(
dev_ctx
.
x_context
(),
filter_data
,
filter_data_tmp
,
filter_shape
,
{
0
,
2
,
3
,
1
});
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"transpose"
);
filter_data_ptr
=
reinterpret_cast
<
const
XPUT
*>
(
filter_data_tmp
);
if
(
filter_grad_data
!=
nullptr
)
{
filter_grad_data_tmp
=
RAII_GUARD
.
alloc
<
XPUT
>
(
filter
.
numel
());
PADDLE_ENFORCE_XDNN_NOT_NULL
(
filter_grad_data_tmp
);
filter_grad_data_ptr
=
filter_grad_data_tmp
;
}
}
int
r
=
xpu
::
conv2d_grad
<
XPUT
,
XPUT
,
XPUT
,
int16_t
>
(
dev_ctx
.
x_context
(),
int
r
=
xpu
::
conv2d_grad
<
XPUT
,
XPUT
,
XPUT
,
int16_t
>
(
dev_ctx
.
x_context
(),
input_data
,
input_data
,
filter_data
,
filter_data
_ptr
,
output_grad_data
,
output_grad_data
,
input_grad_data
,
input_grad_data
,
filter_grad_data
,
filter_grad_data
_ptr
,
batch_size
,
batch_size
,
img_c
,
img_c
,
img_h
,
img_h
,
...
@@ -187,11 +225,18 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
...
@@ -187,11 +225,18 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
nullptr
,
nullptr
,
nullptr
,
nullptr
,
is_nchw
);
is_nchw
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"conv2d_grad"
);
r
,
XPU_SUCCESS
,
if
((
filter_grad_data_ptr
!=
nullptr
)
&&
(
data_format
==
"NHWC"
))
{
platform
::
errors
::
External
(
std
::
vector
<
int
>
filter_shape_fhwc
=
{
"XPU conv kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
filter_shape
[
0
],
filter_shape
[
2
],
filter_shape
[
3
],
filter_shape
[
1
]};
int
r
=
xpu
::
transpose
<
XPUT
>
(
dev_ctx
.
x_context
(),
filter_grad_data_ptr
,
filter_grad_data
,
filter_shape_fhwc
,
{
0
,
3
,
1
,
2
});
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"transpose"
);
}
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
python/paddle/fluid/tests/unittests/xpu/test_conv2d_op_xpu.py
浏览文件 @
c446ab7b
...
@@ -498,10 +498,41 @@ class XPUTestConv2DOp_v2(XPUOpTestWrapper):
...
@@ -498,10 +498,41 @@ class XPUTestConv2DOp_v2(XPUOpTestWrapper):
self
.
padding_algorithm
=
"EXPLICIT"
self
.
padding_algorithm
=
"EXPLICIT"
class
XPUTestConv2DOp_NHWC
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'conv2d'
self
.
use_dynamic_create_class
=
False
class
TestConv2DOp_AsyPadding_NHWC
(
XPUTestConv2DOp_v2
.
TestConv2DOp_AsyPadding
):
def
init_data_format
(
self
):
self
.
data_format
=
"NHWC"
def
init_test_case_2
(
self
):
N
,
C
,
H
,
W
=
self
.
input_size
self
.
input_size
=
[
N
,
H
,
W
,
C
]
class
TestWithPad_AsyPadding_NHWC
(
XPUTestConv2DOp_v2
.
TestWithPad_AsyPadding
):
def
init_data_format
(
self
):
self
.
data_format
=
"NHWC"
def
init_test_case_2
(
self
):
N
,
C
,
H
,
W
=
self
.
input_size
self
.
input_size
=
[
N
,
H
,
W
,
C
]
support_types
=
get_xpu_op_support_types
(
'conv2d'
)
support_types
=
get_xpu_op_support_types
(
'conv2d'
)
for
stype
in
[
'float32'
]:
for
stype
in
[
'float32'
]:
create_test_class
(
globals
(),
XPUTestConv2DOp
,
stype
)
create_test_class
(
globals
(),
XPUTestConv2DOp
,
stype
)
create_test_class
(
globals
(),
XPUTestConv2DOp_v2
,
stype
)
create_test_class
(
globals
(),
XPUTestConv2DOp_v2
,
stype
)
create_test_class
(
globals
(),
XPUTestConv2DOp_NHWC
,
stype
,
ignore_deivce_version
=
[
core
.
XPUVersion
.
XPU1
])
#---------- test SAME VALID -----------
#---------- test SAME VALID -----------
#create_test_padding_SAME_class(TestConv2DOp_AsyPadding)
#create_test_padding_SAME_class(TestConv2DOp_AsyPadding)
...
@@ -512,9 +543,5 @@ for stype in ['float32']:
...
@@ -512,9 +543,5 @@ for stype in ['float32']:
#create_test_padding_VALID_class(TestWithPad_AsyPadding)
#create_test_padding_VALID_class(TestWithPad_AsyPadding)
#create_test_padding_VALID_class(TestWithStride_AsyPadding)
#create_test_padding_VALID_class(TestWithStride_AsyPadding)
# ------------ test channel last ---------
#create_test_channel_last_class(TestConv2DOp_AsyPadding)
#create_test_channel_last_class(TestWithPad_AsyPadding)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录