Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1323e5e7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1323e5e7
编写于
1月 19, 2021
作者:
T
taixiurong
提交者:
GitHub
1月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Pd2.0 (#30532)
* support transformer v2.0 * fix range op crash in dygraph xpu place
上级
4875b972
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
838 addition
and
544 deletion
+838
-544
cmake/external/xpu.cmake
cmake/external/xpu.cmake
+1
-1
paddle/fluid/operators/layer_norm_op_xpu.cc
paddle/fluid/operators/layer_norm_op_xpu.cc
+13
-16
paddle/fluid/operators/matmul_op_xpu.cc
paddle/fluid/operators/matmul_op_xpu.cc
+90
-139
paddle/fluid/operators/matmul_v2_op_xpu.cc
paddle/fluid/operators/matmul_v2_op_xpu.cc
+159
-274
paddle/fluid/operators/one_hot_op_xpu.cc
paddle/fluid/operators/one_hot_op_xpu.cc
+1
-1
paddle/fluid/operators/one_hot_v2_op_xpu.cc
paddle/fluid/operators/one_hot_v2_op_xpu.cc
+70
-0
paddle/fluid/operators/range_op_xpu.cc
paddle/fluid/operators/range_op_xpu.cc
+69
-0
paddle/fluid/operators/scale_op_xpu.cc
paddle/fluid/operators/scale_op_xpu.cc
+7
-4
paddle/fluid/operators/softmax_op_xpu.cc
paddle/fluid/operators/softmax_op_xpu.cc
+15
-2
python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
...paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
+141
-107
python/paddle/fluid/tests/unittests/xpu/test_one_hot_v2_op_xpu.py
...addle/fluid/tests/unittests/xpu/test_one_hot_v2_op_xpu.py
+196
-0
python/paddle/fluid/tests/unittests/xpu/test_range_xpu.py
python/paddle/fluid/tests/unittests/xpu/test_range_xpu.py
+76
-0
未找到文件。
cmake/external/xpu.cmake
浏览文件 @
1323e5e7
...
@@ -10,7 +10,7 @@ if (WITH_AARCH64)
...
@@ -10,7 +10,7 @@ if (WITH_AARCH64)
elseif
(
WITH_SUNWAY
)
elseif
(
WITH_SUNWAY
)
SET
(
XPU_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2020_1227.tar.gz"
CACHE STRING
""
FORCE
)
SET
(
XPU_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2020_1227.tar.gz"
CACHE STRING
""
FORCE
)
else
()
else
()
SET
(
XPU_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_01
05
.tar.gz"
CACHE STRING
""
FORCE
)
SET
(
XPU_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_01
_13
.tar.gz"
CACHE STRING
""
FORCE
)
endif
()
endif
()
SET
(
XPU_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/xpu"
)
SET
(
XPU_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/xpu"
)
...
...
paddle/fluid/operators/layer_norm_op_xpu.cc
浏览文件 @
1323e5e7
...
@@ -45,15 +45,13 @@ class LayerNormXPUKernel : public framework::OpKernel<T> {
...
@@ -45,15 +45,13 @@ class LayerNormXPUKernel : public framework::OpKernel<T> {
auto
*
mean_data
=
mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
mean_data
=
mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
variance_data
=
variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
variance_data
=
variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
int
r
=
xpu
::
layer_norm
(
dev_ctx
.
x_context
(),
left
,
right
,
x_data
,
y_data
,
int
r
=
xpu
::
layer_norm
(
dev_ctx
.
x_context
(),
x_data
,
y_data
,
left
,
right
,
scale_data
,
bias_data
,
epsilon
,
mean_data
,
epsilon
,
scale_data
,
bias_data
,
mean_data
,
variance_data
,
false
);
variance_data
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
platform
::
errors
::
External
(
"XPU API(layer_norm) return wrong "
"XPU layer_norm kernel return wrong value[%d %s]"
,
r
,
"value[%d], please check whether Baidu "
XPUAPIErrorMsg
[
r
]));
"Kunlun Card is properly installed."
,
r
));
}
}
};
};
...
@@ -87,15 +85,14 @@ class LayerNormGradXPUKernel : public framework::OpKernel<T> {
...
@@ -87,15 +85,14 @@ class LayerNormGradXPUKernel : public framework::OpKernel<T> {
auto
*
dx_data
=
auto
*
dx_data
=
(
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
(
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
int
r
=
xpu
::
layer_norm_
backward
(
int
r
=
xpu
::
layer_norm_
grad
(
dev_ctx
.
x_context
(),
x_data
,
dy_data
,
dx_data
,
dev_ctx
.
x_context
(),
left
,
right
,
x_data
,
scale_data
,
variance
_data
,
left
,
right
,
epsilon
,
scale_data
,
mean
_data
,
mean_data
,
dy_data
,
dx_data
,
dscale_data
,
dbias_data
,
epsilon
);
variance_data
,
dscale_data
,
dbias_data
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API(layer_norm_backward) return wrong "
platform
::
errors
::
External
(
"value[%d], please check whether Baidu "
"XPU layer_norm_grad kernel return wrong value[%d %s]"
,
r
,
"Kunlun Card is properly installed."
,
XPUAPIErrorMsg
[
r
]));
r
));
}
}
};
};
...
...
paddle/fluid/operators/matmul_op_xpu.cc
浏览文件 @
1323e5e7
...
@@ -24,6 +24,8 @@ limitations under the License. */
...
@@ -24,6 +24,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
framework
::
Tensor
;
static
framework
::
DDim
RowMatrixFromVector
(
const
framework
::
DDim
&
x_dim
)
{
static
framework
::
DDim
RowMatrixFromVector
(
const
framework
::
DDim
&
x_dim
)
{
if
(
x_dim
.
size
()
>
1
)
{
if
(
x_dim
.
size
()
>
1
)
{
return
x_dim
;
return
x_dim
;
...
@@ -97,26 +99,23 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
...
@@ -97,26 +99,23 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
ReshapeTensorIntoMatrixSequence
(
y
,
mat_dim_y
);
ReshapeTensorIntoMatrixSequence
(
y
,
mat_dim_y
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
,
typename
FC
T
>
class
MatMulXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
static
void
MatMulXPUFunction
(
const
Tensor
*
x
,
const
Tensor
*
y
,
Tensor
*
out
,
public:
bool
trans_x
,
bool
trans_y
,
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
{
auto
*
x
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
auto
&
x_dims
=
x
->
dims
(
);
auto
*
y
=
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
const
auto
&
y_dims
=
y
->
dims
(
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
&
dev_ctx
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
()
);
ctx
.
template
device_context
<
paddle
::
platform
::
XPUDeviceContext
>(
);
auto
mat_dim_a
=
math
::
CreateMatrixDescriptor
(
auto
mat_dim_a
=
RowMatrixFromVector
(
x
->
dims
()),
0
,
context
.
Attr
<
bool
>
(
"transpose_X"
)
);
math
::
CreateMatrixDescriptor
(
RowMatrixFromVector
(
x_dims
),
0
,
trans_x
);
auto
mat_dim_b
=
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
ColumnMatrixFromVector
(
y
->
dims
()),
0
,
math
::
CreateMatrixDescriptor
(
ColumnMatrixFromVector
(
y_dims
),
0
,
trans_y
);
context
.
Attr
<
bool
>
(
"transpose_Y"
));
const
auto
&
x_dims
=
x
->
dims
();
const
auto
&
y_dims
=
y
->
dims
();
if
(
x_dims
.
size
()
==
3
&&
y_dims
.
size
()
<=
2
)
{
if
(
x_dims
.
size
()
==
3
&&
y_dims
.
size
()
<=
2
)
{
// if transpose_X is true, the transpose cost much time
// if transpose_X is true, the transpose cost much time
if
(
!
context
.
Attr
<
bool
>
(
"transpose_X"
)
)
{
if
(
!
trans_x
)
{
mat_dim_a
.
height_
*=
mat_dim_a
.
batch_size_
;
mat_dim_a
.
height_
*=
mat_dim_a
.
batch_size_
;
mat_dim_a
.
batch_size_
=
0
;
mat_dim_a
.
batch_size_
=
0
;
}
else
{
}
else
{
...
@@ -124,7 +123,6 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
...
@@ -124,7 +123,6 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
mat_dim_b
.
height_
=
mat_dim_b
.
height_
/
mat_dim_b
.
batch_size_
;
mat_dim_b
.
height_
=
mat_dim_b
.
height_
/
mat_dim_b
.
batch_size_
;
}
}
}
}
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
mat_dim_a
.
width_
,
mat_dim_b
.
height_
,
mat_dim_a
.
width_
,
mat_dim_b
.
height_
,
platform
::
errors
::
InvalidArgument
(
"Shape mistake in matmul_op, the "
platform
::
errors
::
InvalidArgument
(
"Shape mistake in matmul_op, the "
...
@@ -139,9 +137,9 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
...
@@ -139,9 +137,9 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
"tensor batch_size:%d, second "
"tensor batch_size:%d, second "
"tensor batch_size:%d"
,
"tensor batch_size:%d"
,
mat_dim_a
.
batch_size_
,
mat_dim_b
.
batch_size_
));
mat_dim_a
.
batch_size_
,
mat_dim_b
.
batch_size_
));
T
alpha
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"alpha"
));
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
T
alpha
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"alpha"
));
float
*
data_c
=
out
->
data
<
T
>
();
float
*
data_c
=
out
->
data
<
T
>
();
int
m
=
mat_dim_a
.
height_
;
int
m
=
mat_dim_a
.
height_
;
int
n
=
mat_dim_b
.
width_
;
int
n
=
mat_dim_b
.
width_
;
...
@@ -150,11 +148,12 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
...
@@ -150,11 +148,12 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
int
ldy
=
mat_dim_b
.
trans_
?
k
:
n
;
int
ldy
=
mat_dim_b
.
trans_
?
k
:
n
;
int
ldout
=
n
;
int
ldout
=
n
;
int
batch_size
=
mat_dim_a
.
batch_size_
;
int
batch_size
=
mat_dim_a
.
batch_size_
;
if
(
batch_size
==
0
||
batch_size
==
1
)
{
int
r
=
xpu
::
fc_fusion
<
float
,
float
,
float
,
int16_t
>
(
if
(
batch_size
==
0
)
{
int
r
=
xpu
::
fc_fusion
<
float
,
float
,
float
,
FCT
>
(
dev_ctx
.
x_context
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
data_c
,
m
,
n
,
k
,
dev_ctx
.
x_context
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
data_c
,
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
,
ldx
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
,
ldx
,
ldy
,
ldy
,
ldout
,
alpha
,
0
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
ldout
,
alpha
,
0
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
platform
::
errors
::
External
(
"XPU fc_fusion kernel return wrong value[%d %s]"
,
r
,
"XPU fc_fusion kernel return wrong value[%d %s]"
,
r
,
...
@@ -168,16 +167,33 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
...
@@ -168,16 +167,33 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
const
float
*
x_data
=
x
->
data
<
T
>
()
+
x_stride
*
i
;
const
float
*
x_data
=
x
->
data
<
T
>
()
+
x_stride
*
i
;
const
float
*
y_data
=
y
->
data
<
T
>
()
+
y_stride
*
i
;
const
float
*
y_data
=
y
->
data
<
T
>
()
+
y_stride
*
i
;
float
*
out_data
=
data_c
+
out_stride
*
i
;
float
*
out_data
=
data_c
+
out_stride
*
i
;
int
r
=
xpu
::
fc_fusion
<
float
,
float
,
float
,
int16_t
>
(
int
r
=
xpu
::
fc_fusion
<
float
,
float
,
float
,
FCT
>
(
dev_ctx
.
x_context
(),
x_data
,
y_data
,
out_data
,
m
,
n
,
k
,
dev_ctx
.
x_context
(),
x_data
,
y_data
,
out_data
,
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
,
ldx
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
,
ldx
,
ldy
,
ldout
,
alpha
,
0
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
ldy
,
ldout
,
alpha
,
0
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
platform
::
errors
::
External
(
"XPU fc_fusion kernel return wrong value[%d %s]"
,
"XPU fc_fusion kernel return wrong value[%d %s]"
,
r
,
r
,
XPUAPIErrorMsg
[
r
]));
XPUAPIErrorMsg
[
r
]));
}
}
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
MatMulXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
bool
trans_x
=
context
.
Attr
<
bool
>
(
"transpose_X"
);
bool
trans_y
=
context
.
Attr
<
bool
>
(
"transpose_Y"
);
if
(
std
::
getenv
(
"XPU_PADDLE_MAT_MUL_FCINT32"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
int32_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
context
);
}
else
{
MatMulXPUFunction
<
T
,
int16_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
context
);
}
}
}
};
};
...
@@ -244,75 +260,10 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
...
@@ -244,75 +260,10 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
const
framework
::
Tensor
&
b
,
bool
trans_b
,
const
framework
::
Tensor
&
b
,
bool
trans_b
,
framework
::
Tensor
*
out
)
const
{
framework
::
Tensor
*
out
)
const
{
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
mat_dim_a
=
math
::
CreateMatrixDescriptor
(
a
.
dims
(),
0
,
trans_a
);
if
(
std
::
getenv
(
"XPU_PADDLE_MAT_MUL_GRAD_FCINT32"
)
!=
nullptr
)
{
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
b
.
dims
(),
0
,
trans_b
);
MatMulXPUFunction
<
T
,
int32_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
context
);
const
auto
&
a_dims
=
a
.
dims
();
const
auto
&
b_dims
=
b
.
dims
();
if
(
a_dims
.
size
()
==
3
&&
b_dims
.
size
()
<=
2
)
{
// if transpose_X is true, the transpose cost much time
if
(
!
context
.
Attr
<
bool
>
(
"transpose_X"
))
{
mat_dim_a
.
height_
*=
mat_dim_a
.
batch_size_
;
mat_dim_a
.
batch_size_
=
0
;
}
else
{
}
else
{
mat_dim_b
.
batch_size_
=
mat_dim_a
.
batch_size_
;
MatMulXPUFunction
<
T
,
int16_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
context
);
mat_dim_b
.
height_
=
mat_dim_b
.
height_
/
mat_dim_b
.
batch_size_
;
}
}
PADDLE_ENFORCE_EQ
(
mat_dim_a
.
width_
,
mat_dim_b
.
height_
,
platform
::
errors
::
InvalidArgument
(
"Shape mistake in matmul_grad_op, the "
"first tensor width must be same as second tensor "
"height, but received "
"width:%d, height:%d"
,
mat_dim_a
.
width_
,
mat_dim_b
.
height_
));
PADDLE_ENFORCE_EQ
(
mat_dim_a
.
batch_size_
,
mat_dim_b
.
batch_size_
,
platform
::
errors
::
InvalidArgument
(
"Shape mistake in matmul_grad_op, the two input"
"tensor batch_size must be same, but received first "
"tensor batch_size:%d, second "
"tensor batch_size:%d"
,
mat_dim_a
.
batch_size_
,
mat_dim_b
.
batch_size_
));
T
alpha
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"alpha"
));
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
float
*
data_c
=
out
->
data
<
T
>
();
int
m
=
mat_dim_a
.
height_
;
int
n
=
mat_dim_b
.
width_
;
int
k
=
mat_dim_a
.
width_
;
int
ldx
=
mat_dim_a
.
trans_
?
m
:
k
;
int
ldy
=
mat_dim_b
.
trans_
?
k
:
n
;
int
ldout
=
n
;
int
batch_size
=
mat_dim_a
.
batch_size_
;
if
(
batch_size
==
0
||
batch_size
==
1
)
{
int
r
=
xpu
::
fc_fusion
<
float
,
float
,
float
,
int16_t
>
(
dev_ctx
.
x_context
(),
a
.
data
<
T
>
(),
b
.
data
<
T
>
(),
data_c
,
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
,
ldx
,
ldy
,
ldout
,
alpha
,
0
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU fc_fusion kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
else
{
// batch matmul
int
x_stride
=
mat_dim_a
.
stride_
;
int
y_stride
=
mat_dim_b
.
stride_
;
int
out_stride
=
m
*
n
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
const
float
*
x_data
=
a
.
data
<
T
>
()
+
x_stride
*
i
;
const
float
*
y_data
=
b
.
data
<
T
>
()
+
y_stride
*
i
;
float
*
out_data
=
data_c
+
out_stride
*
i
;
int
r
=
xpu
::
fc_fusion
<
float
,
float
,
float
,
int16_t
>
(
dev_ctx
.
x_context
(),
x_data
,
y_data
,
out_data
,
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
,
ldx
,
ldy
,
ldout
,
alpha
,
0
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU fc_fusion kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
}
}
}
}
...
...
paddle/fluid/operators/matmul_v2_op_xpu.cc
浏览文件 @
1323e5e7
...
@@ -21,211 +21,141 @@
...
@@ -21,211 +21,141 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
,
typename
FCT
>
void
MatMulXPUFunction
(
const
Tensor
*
X
,
const
Tensor
*
Y
,
static
void
MatMulXPUFunction
(
const
Tensor
*
x
,
const
Tensor
*
y
,
Tensor
*
out
,
const
std
::
vector
<
std
::
int64_t
>&
x_dims
,
const
std
::
vector
<
std
::
int64_t
>&
y_dims
,
Tensor
*
Out
,
bool
trans_x
,
bool
trans_y
,
bool
trans_x
,
bool
trans_y
,
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
{
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
{
const
int
x_ndim
=
x_dims
.
size
();
const
auto
&
x_dims
=
x
->
dims
();
const
int
y_ndim
=
y_dims
.
size
();
const
auto
&
y_dims
=
y
->
dims
();
auto
&
dev_ctx
=
auto
&
dev_ctx
=
ctx
.
template
device_context
<
paddle
::
platform
::
XPUDeviceContext
>();
ctx
.
template
device_context
<
paddle
::
platform
::
XPUDeviceContext
>();
// currently only support x_ndim == y_dim and non-broadcast case
auto
mat_dim_a
=
PADDLE_ENFORCE_EQ
(
x_ndim
,
y_ndim
,
platform
::
errors
::
InvalidArgument
(
math
::
CreateMatrixDescriptor
(
RowMatrixFromVector
(
x_dims
),
0
,
trans_x
);
"Shape mistake in matmul_v2_op"
));
auto
mat_dim_b
=
for
(
int
i
=
0
;
i
<
x_ndim
-
2
;
i
++
)
{
math
::
CreateMatrixDescriptor
(
ColumnMatrixFromVector
(
y_dims
),
0
,
trans_y
);
PADDLE_ENFORCE_EQ
(
x_dims
.
data
()[
i
],
y_dims
.
data
()[
i
],
platform
::
errors
::
InvalidArgument
(
"Shape mistake in matmul_v2_op"
));
}
int
ret
=
0
;
if
(
x_ndim
==
1
&&
y_ndim
==
1
)
{
PADDLE_ENFORCE_EQ
(
X
->
numel
(),
Y
->
numel
(),
platform
::
errors
::
InvalidArgument
(
"X's numbers is not equal to Y's numbers,"
"when X/Y's dims =1"
));
VLOG
(
3
)
<<
"MatMul's case 1"
;
Out
->
Resize
({
1
});
Out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ret
=
baidu
::
xpu
::
api
::
fc_int16
(
dev_ctx
.
x_context
(),
false
,
false
,
1
,
1
,
X
->
numel
(),
1.0
f
,
X
->
data
<
T
>
(),
Y
->
data
<
T
>
(),
0.0
f
,
Out
->
data
<
T
>
());
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API return wrong value[%d] in matmul_v2, please check whether "
"Baidu Kunlun Card is properly installed."
,
ret
));
return
;
}
if
(
x_ndim
==
1
)
{
const
int
N
=
X
->
numel
();
if
(
trans_y
)
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
1
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
));
}
else
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
2
],
N
,
platform
::
errors
::
InvalidArgument
(
"Input(Y) has error dim."
));
}
std
::
vector
<
std
::
int64_t
>
out_dims
(
y_ndim
-
1
);
if
(
trans_y
)
{
std
::
copy_n
(
y_dims
.
cbegin
(),
y_ndim
-
1
,
out_dims
.
begin
());
}
else
{
std
::
copy_n
(
y_dims
.
cbegin
(),
y_ndim
-
2
,
out_dims
.
begin
());
out_dims
.
back
()
=
y_dims
.
back
();
}
Out
->
Resize
(
framework
::
make_ddim
(
out_dims
));
Out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
trans_y
)
{
const
int
M
=
Y
->
numel
()
/
N
;
VLOG
(
3
)
<<
"MatMul's case 2"
;
ret
=
baidu
::
xpu
::
api
::
fc_int16
(
dev_ctx
.
x_context
(),
false
,
true
,
1
,
M
,
N
,
1.0
f
,
X
->
data
<
T
>
(),
Y
->
data
<
T
>
(),
0.0
f
,
Out
->
data
<
T
>
());
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API return wrong value[%d] in "
"matmul_v2, please check whether "
"Baidu Kunlun Card is properly installed."
,
ret
));
}
else
{
const
int
M
=
y_dims
[
y_ndim
-
1
];
const
int
batch_size
=
Y
->
numel
()
/
(
M
*
N
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
ret
=
baidu
::
xpu
::
api
::
fc_int16
(
dev_ctx
.
x_context
(),
false
,
false
,
1
,
M
,
N
,
1.0
f
,
X
->
data
<
T
>
(),
Y
->
data
<
T
>
()
+
i
*
M
*
N
,
0.0
f
,
Out
->
data
<
T
>
()
+
i
*
M
);
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API return wrong value[%d] in matmul_v2, "
"please check whether "
"Baidu Kunlun Card is properly installed."
,
ret
));
}
}
return
;
}
if
(
y_ndim
==
1
)
{
if
(
x_dims
.
size
()
==
3
&&
y_dims
.
size
()
<=
2
)
{
const
int
N
=
Y
->
numel
();
// if transpose_X is true, the transpose cost much time
if
(
trans_x
)
{
if
(
!
trans_x
)
{
PADDLE_ENFORCE_EQ
(
mat_dim_a
.
height_
*=
mat_dim_a
.
batch_size_
;
x_dims
[
x_ndim
-
2
],
N
,
mat_dim_a
.
batch_size_
=
0
;
platform
::
errors
::
InvalidArgument
(
"Input(X) has error dim."
));
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
mat_dim_b
.
batch_size_
=
mat_dim_a
.
batch_size_
;
x_dims
[
x_ndim
-
1
],
N
,
mat_dim_b
.
height_
=
mat_dim_b
.
height_
/
mat_dim_b
.
batch_size_
;
platform
::
errors
::
InvalidArgument
(
"Input(X) has error dim."
));
}
}
std
::
vector
<
std
::
int64_t
>
out_dims
(
x_ndim
-
1
);
if
(
trans_x
)
{
std
::
copy_n
(
x_dims
.
cbegin
(),
x_ndim
-
2
,
out_dims
.
begin
());
out_dims
.
back
()
=
x_dims
.
back
();
}
else
{
std
::
copy_n
(
x_dims
.
cbegin
(),
x_ndim
-
1
,
out_dims
.
begin
());
}
}
Out
->
Resize
(
framework
::
make_ddim
(
out_dims
));
Out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
trans_x
)
{
if
(
mat_dim_a
.
width_
==
mat_dim_b
.
height_
)
{
const
int
M
=
x_dims
[
x_ndim
-
1
];
if
(
mat_dim_a
.
batch_size_
==
0
&&
mat_dim_b
.
batch_size_
==
1
)
{
const
int
batch_size
=
X
->
numel
()
/
(
M
*
N
);
mat_dim_a
.
batch_size_
=
mat_dim_b
.
batch_size_
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
ret
=
baidu
::
xpu
::
api
::
fc_int16
(
dev_ctx
.
x_context
(),
true
,
false
,
M
,
1
,
N
,
1.0
f
,
X
->
data
<
T
>
()
+
i
*
M
*
N
,
Y
->
data
<
T
>
(),
0.0
f
,
Out
->
data
<
T
>
()
+
i
*
M
);
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API return wrong value[%d] in matmul_v2, "
"please check whether "
"Baidu Kunlun Card is properly installed."
,
ret
));
}
}
}
else
{
if
(
mat_dim_a
.
batch_size_
==
1
&&
mat_dim_b
.
batch_size_
==
0
)
{
const
int
M
=
X
->
numel
()
/
N
;
mat_dim_a
.
batch_size_
=
mat_dim_b
.
batch_size_
=
0
;
VLOG
(
3
)
<<
"MatMul's case 7"
;
ret
=
baidu
::
xpu
::
api
::
fc_int16
(
dev_ctx
.
x_context
(),
false
,
false
,
M
,
1
,
N
,
1.0
f
,
X
->
data
<
T
>
(),
Y
->
data
<
T
>
(),
0.0
f
,
Out
->
data
<
T
>
());
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API return wrong value[%d] in "
"matmul_v2, please check whether "
"Baidu Kunlun Card is properly installed."
,
ret
));
}
}
return
;
}
}
const
int
M
=
trans_x
?
x_dims
[
x_ndim
-
1
]
:
x_dims
[
x_ndim
-
2
];
PADDLE_ENFORCE_EQ
(
mat_dim_a
.
width_
,
mat_dim_b
.
height_
,
const
int
K
=
trans_x
?
x_dims
[
x_ndim
-
2
]
:
x_dims
[
x_ndim
-
1
];
platform
::
errors
::
InvalidArgument
(
if
(
trans_y
)
{
"Shape mistake in matmul_v2_op xdims = %s ydims = %s"
,
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
1
],
K
,
platform
::
errors
::
InvalidArgument
(
x_dims
.
to_str
(),
y_dims
.
to_str
()));
"Input(X) has error dim."
));
PADDLE_ENFORCE_EQ
(
mat_dim_a
.
batch_size_
,
mat_dim_b
.
batch_size_
,
platform
::
errors
::
InvalidArgument
(
"Shape mistake in matmul_v2_op xdims = %s ydims = %s"
,
x_dims
.
to_str
(),
y_dims
.
to_str
()));
float
*
data_c
=
out
->
data
<
T
>
();
int
m
=
mat_dim_a
.
height_
;
int
n
=
mat_dim_b
.
width_
;
int
k
=
mat_dim_a
.
width_
;
int
batch_size
=
mat_dim_a
.
batch_size_
;
if
(
batch_size
==
0
)
{
int
r
=
xpu
::
fc
<
float
,
float
,
float
,
FCT
>
(
dev_ctx
.
x_context
(),
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
data_c
,
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU fc_fusion kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
y_dims
[
y_ndim
-
2
],
K
,
platform
::
errors
::
InvalidArgument
(
// batch matmul
"Input(X) has error dim."
));
int
x_stride
=
mat_dim_a
.
stride_
;
int
y_stride
=
mat_dim_b
.
stride_
;
int
out_stride
=
m
*
n
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
const
float
*
x_data
=
x
->
data
<
T
>
()
+
x_stride
*
i
;
const
float
*
y_data
=
y
->
data
<
T
>
()
+
y_stride
*
i
;
float
*
out_data
=
data_c
+
out_stride
*
i
;
int
r
=
xpu
::
fc
<
float
,
float
,
float
,
FCT
>
(
dev_ctx
.
x_context
(),
x_data
,
y_data
,
out_data
,
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU fc_fusion kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
}
const
int
N
=
trans_y
?
y_dims
[
y_ndim
-
2
]
:
y_dims
[
y_ndim
-
1
];
const
int
ndim
=
(
std
::
max
)(
x_ndim
,
y_ndim
);
std
::
vector
<
std
::
int64_t
>
out_broadcast_dims
(
ndim
);
int
batch_size
=
1
;
for
(
int
i
=
0
;
i
<
ndim
-
2
;
i
++
)
{
PADDLE_ENFORCE_EQ
(
x_dims
.
data
()[
i
],
y_dims
.
data
()[
i
],
platform
::
errors
::
InvalidArgument
(
"Shape mistake in matmul_v2_op"
));
out_broadcast_dims
[
i
]
=
x_dims
.
data
()[
i
];
batch_size
*=
x_dims
.
data
()[
i
];
}
}
out_broadcast_dims
[
ndim
-
2
]
=
M
;
out_broadcast_dims
[
ndim
-
1
]
=
N
;
Out
->
Resize
(
framework
::
make_ddim
(
out_broadcast_dims
));
Out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ret
=
baidu
::
xpu
::
api
::
batched_gemm_int16
(
dev_ctx
.
x_context
(),
trans_x
,
trans_y
,
batch_size
,
M
,
N
,
K
,
1.0
f
,
X
->
data
<
T
>
(),
Y
->
data
<
T
>
(),
Out
->
data
<
T
>
(),
nullptr
,
nullptr
);
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API return wrong value[%d] in matmul_v2, please check whether "
"Baidu Kunlun Card is properly installed."
,
ret
));
}
}
template
<
typename
T
>
template
<
typename
T
>
class
MatMulV2XPUKernel
:
public
framework
::
OpKernel
<
T
>
{
class
MatMulV2XPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
Y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
O
ut
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
o
ut
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
bool
trans_x
=
ctx
.
Attr
<
bool
>
(
"trans_x"
);
bool
trans_x
=
ctx
.
Attr
<
bool
>
(
"trans_x"
);
bool
trans_y
=
ctx
.
Attr
<
bool
>
(
"trans_y"
);
bool
trans_y
=
ctx
.
Attr
<
bool
>
(
"trans_y"
);
MatMulXPUFunction
<
T
>
(
X
,
Y
,
vectorize
(
X
->
dims
()),
vectorize
(
Y
->
dims
()),
Out
,
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
trans_x
,
trans_y
,
ctx
);
if
(
std
::
getenv
(
"XPU_PADDLE_MAT_MUL_V2_FCINT32"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
int32_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
ctx
);
}
else
{
MatMulXPUFunction
<
T
,
int16_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
ctx
);
}
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
static
framework
::
Tensor
XPUFoldHeadAndLastDims
(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
)
{
auto
in_dims
=
input
.
dims
();
if
(
in_dims
.
size
()
!=
3
)
{
return
input
;
}
framework
::
Tensor
output
;
output
.
Resize
({
in_dims
[
1
],
in_dims
[
0
],
in_dims
[
2
]});
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
std
::
vector
<
int
>
in_shape_host
=
{
static_cast
<
int
>
(
in_dims
[
0
]),
static_cast
<
int
>
(
in_dims
[
1
]),
static_cast
<
int
>
(
in_dims
[
2
])};
std
::
vector
<
int
>
axis_host
=
{
1
,
0
,
2
};
int
r
=
xpu
::
transpose
(
context
.
x_context
(),
input
.
data
<
T
>
(),
output
.
data
<
T
>
(),
in_shape_host
.
data
(),
axis_host
.
data
(),
/*ndims=*/
3
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU transpose kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
output
.
Resize
({
in_dims
[
1
],
in_dims
[
0
]
*
in_dims
[
2
]});
return
output
;
}
template
<
typename
T
>
template
<
typename
T
>
class
MatMulV2XPUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
MatMulV2XPUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
MatMul
(
const
framework
::
ExecutionContext
&
c
ontext
,
void
MatMul
(
const
framework
::
ExecutionContext
&
c
tx
,
const
framework
::
Tensor
&
a
,
bool
trans_a
,
const
framework
::
Tensor
&
a
,
bool
trans_a
,
const
framework
::
Tensor
&
b
,
bool
trans_b
,
const
framework
::
Tensor
&
b
,
bool
trans_b
,
framework
::
Tensor
*
out
)
const
{
framework
::
Tensor
*
out
)
const
{
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
MatMulXPUFunction
<
T
>
(
&
a
,
&
b
,
vectorize
(
a
.
dims
()),
vectorize
(
b
.
dims
()),
out
,
if
(
std
::
getenv
(
"XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32"
)
!=
nullptr
)
{
trans_a
,
trans_b
,
context
);
MatMulXPUFunction
<
T
,
int32_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
ctx
);
}
else
{
MatMulXPUFunction
<
T
,
int16_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
ctx
);
}
}
}
void
CalcInputGrad
(
const
framework
::
ExecutionContext
&
context
,
void
CalcInputGrad
(
const
framework
::
ExecutionContext
&
context
,
...
@@ -239,79 +169,33 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
...
@@ -239,79 +169,33 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
if
(
!
need_combine
)
{
if
(
!
need_combine
)
{
MatMul
(
context
,
a
,
trans_a
,
b
,
trans_b
,
out
);
MatMul
(
context
,
a
,
trans_a
,
b
,
trans_b
,
out
);
}
else
{
}
else
{
// currently not support this case
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
bool
transpose_x
=
ctx
.
Attr
<
bool
>
(
"trans_x"
);
bool
transpose_y
=
ctx
.
Attr
<
bool
>
(
"trans_y"
);
auto
x
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
y
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
dout
=
*
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
// get dims
std
::
vector
<
std
::
int64_t
>
x_dims
=
vectorize
(
x
.
dims
());
std
::
vector
<
std
::
int64_t
>
y_dims
=
vectorize
(
y
.
dims
());
std
::
vector
<
std
::
int64_t
>
dout_dims
=
vectorize
(
dout
.
dims
());
int
x_ndim
=
x_dims
.
size
();
int
y_ndim
=
y_dims
.
size
();
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
&
dev_ctx
=
auto
&
dev_ctx
=
ctx
.
template
device_context
<
paddle
::
platform
::
XPUDeviceContext
>();
context
.
template
device_context
<
paddle
::
platform
::
XPUDeviceContext
>();
// Case1 : x's or y's dim = 1
MatMul
(
int
ret
=
0
;
context
,
if
(
x_ndim
==
1
&&
y_ndim
==
1
)
{
is_fold_init_dims_a
if
(
dx
)
{
?
FoldInitDims
(
a
)
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
:
XPUFoldHeadAndLastDims
<
paddle
::
platform
::
XPUDeviceContext
,
T
>
(
ret
=
baidu
::
xpu
::
api
::
fc_int16
(
dev_ctx
.
x_context
(),
false
,
false
,
dev_ctx
,
a
),
dx
->
numel
(),
1
,
1
,
1.0
f
,
y
.
data
<
T
>
(),
trans_a
,
dout
.
data
<
T
>
(),
0.0
f
,
dx
->
data
<
T
>
());
is_fold_init_dims_b
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
?
FoldInitDims
(
b
)
platform
::
errors
::
External
(
:
XPUFoldHeadAndLastDims
<
paddle
::
platform
::
XPUDeviceContext
,
T
>
(
"XPU API return wrong value[%d] in "
dev_ctx
,
b
),
"matmul_v2_grad, please check whether "
trans_b
,
out
);
"Baidu Kunlun Card is properly installed."
,
}
ret
));
}
}
if
(
dy
)
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
bool
transpose_x
=
context
.
Attr
<
bool
>
(
"trans_x"
);
ret
=
baidu
::
xpu
::
api
::
fc_int16
(
dev_ctx
.
x_context
(),
false
,
false
,
bool
transpose_y
=
context
.
Attr
<
bool
>
(
"trans_y"
);
dy
->
numel
(),
1
,
1
,
1.0
f
,
x
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
0.0
f
,
dy
->
data
<
T
>
());
auto
x
=
*
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
auto
y
=
*
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
platform
::
errors
::
External
(
auto
dout
=
"XPU API return wrong value[%d] in "
*
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
"matmul_v2_grad, please check whether "
auto
*
dx
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
"Baidu Kunlun Card is properly installed."
,
auto
*
dy
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
ret
));
}
return
;
}
bool
is_broadcast
=
true
;
if
(
x_ndim
<=
2
||
y_ndim
<=
2
)
{
is_broadcast
=
false
;
}
else
if
(
x_ndim
!=
y_ndim
)
{
is_broadcast
=
true
;
}
else
{
is_broadcast
=
!
std
::
equal
(
x_dims
.
cbegin
(),
x_dims
.
cbegin
()
+
x_ndim
-
2
,
y_dims
.
cbegin
());
}
// currently only support non-broadcast case
PADDLE_ENFORCE_EQ
(
is_broadcast
,
false
,
platform
::
errors
::
InvalidArgument
(
"Shape mistake in matmul_v2_op"
));
// Case2: no broadcast or no batch size, it aims to speed and it is same as
// matmul in old version.
if
(
!
is_broadcast
)
{
ReshapeXYOutIntoMatrixSequence
(
&
x
,
&
y
,
&
dout
,
transpose_x
,
transpose_y
);
ReshapeXYOutIntoMatrixSequence
(
&
x
,
&
y
,
&
dout
,
transpose_x
,
transpose_y
);
framework
::
DDim
dx_dims
;
framework
::
DDim
dx_dims
;
if
(
dx
)
{
if
(
dx
)
{
...
@@ -328,18 +212,19 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
...
@@ -328,18 +212,19 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
dy
->
Resize
(
y
.
dims
());
dy
->
Resize
(
y
.
dims
());
}
}
}
}
if
(
transpose_x
&&
transpose_y
)
{
if
(
transpose_x
&&
transpose_y
)
{
CalcInputGrad
(
ctx
,
y
,
true
,
true
,
dout
,
true
,
false
,
dx
);
CalcInputGrad
(
context
,
y
,
true
,
true
,
dout
,
true
,
false
,
dx
);
CalcInputGrad
(
ctx
,
dout
,
true
,
true
,
x
,
true
,
false
,
dy
);
CalcInputGrad
(
context
,
dout
,
true
,
true
,
x
,
true
,
false
,
dy
);
}
else
if
(
transpose_x
)
{
}
else
if
(
transpose_x
)
{
CalcInputGrad
(
ctx
,
y
,
false
,
false
,
dout
,
true
,
false
,
dx
);
CalcInputGrad
(
context
,
y
,
false
,
false
,
dout
,
true
,
false
,
dx
);
CalcInputGrad
(
ctx
,
x
,
false
,
false
,
dout
,
false
,
true
,
dy
);
CalcInputGrad
(
context
,
x
,
false
,
false
,
dout
,
false
,
true
,
dy
);
}
else
if
(
transpose_y
)
{
}
else
if
(
transpose_y
)
{
CalcInputGrad
(
ctx
,
dout
,
false
,
false
,
y
,
false
,
true
,
dx
);
CalcInputGrad
(
context
,
dout
,
false
,
false
,
y
,
false
,
true
,
dx
);
CalcInputGrad
(
ctx
,
dout
,
true
,
true
,
x
,
false
,
true
,
dy
);
CalcInputGrad
(
context
,
dout
,
true
,
true
,
x
,
false
,
true
,
dy
);
}
else
{
}
else
{
CalcInputGrad
(
ctx
,
dout
,
false
,
false
,
y
,
true
,
false
,
dx
);
CalcInputGrad
(
context
,
dout
,
false
,
false
,
y
,
true
,
false
,
dx
);
CalcInputGrad
(
ctx
,
x
,
true
,
true
,
dout
,
false
,
true
,
dy
);
CalcInputGrad
(
context
,
x
,
true
,
true
,
dout
,
false
,
true
,
dy
);
}
}
if
(
dx
)
{
if
(
dx
)
{
...
@@ -347,13 +232,13 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
...
@@ -347,13 +232,13 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
dx
->
Resize
(
dx_dims
);
dx
->
Resize
(
dx_dims
);
}
}
}
}
if
(
dy
)
{
if
(
dy
)
{
if
(
dy_dims
!=
y
.
dims
())
{
if
(
dy_dims
!=
y
.
dims
())
{
dy
->
Resize
(
dy_dims
);
dy
->
Resize
(
dy_dims
);
}
}
}
}
}
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/one_hot_op_xpu.cc
浏览文件 @
1323e5e7
...
@@ -35,7 +35,7 @@ class OneHotXPUKernel : public framework::OpKernel<T> {
...
@@ -35,7 +35,7 @@ class OneHotXPUKernel : public framework::OpKernel<T> {
if
(
context
.
HasInput
(
"depth_tensor"
))
{
if
(
context
.
HasInput
(
"depth_tensor"
))
{
auto
*
depth_tensor
=
context
.
Input
<
Tensor
>
(
"depth_tensor"
);
auto
*
depth_tensor
=
context
.
Input
<
Tensor
>
(
"depth_tensor"
);
auto
*
depth_data
=
depth_tensor
->
data
<
int32_t
>
();
auto
*
depth_data
=
depth_tensor
->
data
<
int32_t
>
();
if
(
depth_tensor
->
place
()
==
platform
::
XPUPlace
(
))
{
if
(
platform
::
is_xpu_place
(
depth_tensor
->
place
()
))
{
xpu_memcpy
(
static_cast
<
void
*>
(
&
depth
),
xpu_memcpy
(
static_cast
<
void
*>
(
&
depth
),
static_cast
<
const
void
*>
(
depth_data
),
sizeof
(
int32_t
),
static_cast
<
const
void
*>
(
depth_data
),
sizeof
(
int32_t
),
XPU_DEVICE_TO_HOST
);
XPU_DEVICE_TO_HOST
);
...
...
paddle/fluid/operators/one_hot_v2_op_xpu.cc
0 → 100644
浏览文件 @
1323e5e7
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/one_hot_op.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
OneHotV2XPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int
depth
=
context
.
Attr
<
int
>
(
"depth"
);
if
(
context
.
HasInput
(
"depth_tensor"
))
{
auto
*
depth_tensor
=
context
.
Input
<
Tensor
>
(
"depth_tensor"
);
auto
*
depth_data
=
depth_tensor
->
data
<
int32_t
>
();
if
(
platform
::
is_xpu_place
(
depth_tensor
->
place
()))
{
xpu_memcpy
(
static_cast
<
void
*>
(
&
depth
),
static_cast
<
const
void
*>
(
depth_data
),
sizeof
(
int32_t
),
XPU_DEVICE_TO_HOST
);
}
else
{
depth
=
depth_data
[
0
];
}
auto
out_dims
=
out
->
dims
();
out_dims
[
out_dims
.
size
()
-
1
]
=
depth
;
out
->
Resize
(
out_dims
);
}
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
int
len
=
in
->
numel
();
int
ret
=
xpu
::
one_hot
<
T
>
(
dev_ctx
.
x_context
(),
in
->
data
<
T
>
(),
out
->
mutable_data
<
float
>
(
context
.
GetPlace
()),
len
,
depth
,
1.0
,
0.0
);
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU one_hot kernel return wrong value[%d %s]"
,
ret
,
XPUAPIErrorMsg
[
ret
]));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
one_hot_v2
,
ops
::
OneHotV2XPUKernel
<
paddle
::
platform
::
XPUDeviceContext
,
int
>
,
ops
::
OneHotV2XPUKernel
<
paddle
::
platform
::
XPUDeviceContext
,
int64_t
>
);
#endif
paddle/fluid/operators/range_op_xpu.cc
0 → 100644
浏览文件 @
1323e5e7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/range_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
XPURangeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
start_t
=
context
.
Input
<
framework
::
Tensor
>
(
"Start"
);
auto
*
end_t
=
context
.
Input
<
framework
::
Tensor
>
(
"End"
);
auto
*
step_t
=
context
.
Input
<
framework
::
Tensor
>
(
"Step"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
framework
::
Tensor
n
;
framework
::
TensorCopy
(
*
start_t
,
platform
::
CPUPlace
(),
&
n
);
T
start
=
n
.
data
<
T
>
()[
0
];
framework
::
TensorCopy
(
*
end_t
,
platform
::
CPUPlace
(),
&
n
);
T
end
=
n
.
data
<
T
>
()[
0
];
framework
::
TensorCopy
(
*
step_t
,
platform
::
CPUPlace
(),
&
n
);
T
step
=
n
.
data
<
T
>
()[
0
];
int64_t
size
=
0
;
GetSize
(
start
,
end
,
step
,
&
size
);
out
->
Resize
(
framework
::
make_ddim
({
size
}));
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
Tensor
out_cpu
;
T
*
out_cpu_data_ptr
=
out_cpu
.
mutable_data
<
T
>
(
platform
::
CPUPlace
(),
out
->
numel
()
*
sizeof
(
T
));
T
value
=
start
;
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
out_cpu_data_ptr
[
i
]
=
value
;
value
+=
step
;
}
int
ret
=
xpu_memcpy
(
out_data
,
out_cpu_data_ptr
,
out
->
numel
()
*
sizeof
(
T
),
XPUMemcpyKind
::
XPU_HOST_TO_DEVICE
);
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU xpu_memcpy return wrong "
"value[%d %s]"
,
ret
,
XPUAPIErrorMsg
[
ret
]));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
range
,
ops
::
XPURangeKernel
<
int
>
,
ops
::
XPURangeKernel
<
int64_t
>
,
ops
::
XPURangeKernel
<
float
>
,
ops
::
XPURangeKernel
<
double
>
);
#endif // PADDLE_WITH_XPU
paddle/fluid/operators/scale_op_xpu.cc
浏览文件 @
1323e5e7
...
@@ -46,10 +46,13 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
...
@@ -46,10 +46,13 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
in
->
dims
().
to_str
().
c_str
(),
in
->
dims
().
to_str
().
c_str
(),
out
->
dims
().
to_str
().
c_str
()));
out
->
dims
().
to_str
().
c_str
()));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
int
r
=
xpu
::
scale
(
dev_ctx
.
x_context
(),
in
->
numel
(),
scale
,
bias
,
int
r
=
bias_after_scale
,
in
->
data
<
float
>
(),
out
->
data
<
float
>
());
xpu
::
scale
(
dev_ctx
.
x_context
(),
in
->
data
<
float
>
(),
out
->
data
<
float
>
(),
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
in
->
numel
(),
bias_after_scale
,
scale
,
bias
);
platform
::
errors
::
Fatal
(
"XPU scale kernel error!"
));
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU scale kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
}
};
};
...
...
paddle/fluid/operators/softmax_op_xpu.cc
浏览文件 @
1323e5e7
...
@@ -41,8 +41,21 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
...
@@ -41,8 +41,21 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
}
}
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
int
r
=
xpu
::
softmax
<
T
>
(
dev_ctx
.
x_context
(),
x
->
data
<
float
>
(),
out
->
data
<
float
>
(),
x_dims
,
axis
);
int
r
=
XPU_SUCCESS
;
Tensor
clip_x
;
int
len
=
x
->
numel
();
T
*
clip_x_data
=
clip_x
.
mutable_data
<
T
>
(
platform
::
XPUPlace
(),
len
*
sizeof
(
T
));
r
=
xpu
::
clip
(
dev_ctx
.
x_context
(),
x
->
data
<
float
>
(),
clip_x_data
,
len
,
-
1e30
,
1e30
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API(clip) return wrong "
"value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
r
=
xpu
::
softmax
<
T
>
(
dev_ctx
.
x_context
(),
clip_x_data
,
out
->
data
<
float
>
(),
x_dims
,
axis
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API(softmax2d_forward) return wrong "
platform
::
errors
::
External
(
"XPU API(softmax2d_forward) return wrong "
...
...
python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
浏览文件 @
1323e5e7
...
@@ -13,12 +13,11 @@
...
@@ -13,12 +13,11 @@
# limitations under the License.
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
import
sys
sys
.
path
.
append
(
".."
)
sys
.
path
.
append
(
".."
)
from
op_test
import
OpTest
import
unittest
import
numpy
as
np
from
op_test_xpu
import
XPUOpTest
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
import
paddle
import
paddle
...
@@ -57,9 +56,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
...
@@ -57,9 +56,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
return
Out
return
Out
@
unittest
.
skipIf
(
not
paddle
.
is_compiled_with_xpu
(),
class
TestMatMulV2Op
(
XPUOpTest
):
"core is not compiled with XPU"
)
class
TestMatMulV2Op
(
OpTest
):
"""
"""
case 1
case 1
"""
"""
...
@@ -74,10 +71,10 @@ class TestMatMulV2Op(OpTest):
...
@@ -74,10 +71,10 @@ class TestMatMulV2Op(OpTest):
self
.
dtype
=
"float32"
self
.
dtype
=
"float32"
def
setUp
(
self
):
def
setUp
(
self
):
self
.
use_xpu
=
True
self
.
init_kernel_type
()
self
.
init_kernel_type
()
self
.
config
()
self
.
config
()
self
.
op_type
=
"matmul_v2"
self
.
op_type
=
"matmul_v2"
self
.
use_xpu
=
True
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
dtype
)
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
dtype
)
y
=
np
.
random
.
random
(
self
.
y_shape
).
astype
(
self
.
dtype
)
y
=
np
.
random
.
random
(
self
.
y_shape
).
astype
(
self
.
dtype
)
# -0.1 ~ 0.1
# -0.1 ~ 0.1
...
@@ -94,31 +91,25 @@ class TestMatMulV2Op(OpTest):
...
@@ -94,31 +91,25 @@ class TestMatMulV2Op(OpTest):
def
test_check_output
(
self
):
def
test_check_output
(
self
):
place
=
paddle
.
XPUPlace
(
0
)
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
,
atol
=
0.01
)
self
.
check_output_with_place
(
place
)
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
place
=
paddle
.
XPUPlace
(
0
)
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_grad_with_place
(
self
.
check_grad_with_place
(
place
,
[
'X'
,
'Y'
],
'Out'
)
place
,
[
'X'
,
'Y'
],
'Out'
,
max_relative_error
=
0.1
)
'''
# class TestMatMuklOp2(TestMatMulV2Op):
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
# """
"core is not compiled with XPU")
# case 2
class TestMatMuklOp2(TestMatMulV2Op):
# """
"""
case 2
"""
def config(self):
#
def config(self):
self.x_shape = (100, )
#
self.x_shape = (100, )
self.y_shape = (1, 3, 2, 100)
#
self.y_shape = (1, 3, 2, 100)
self.trans_x = False
#
self.trans_x = False
self.trans_y = True
#
self.trans_y = True
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class
TestMatMuklOp3
(
TestMatMulV2Op
):
class
TestMatMuklOp3
(
TestMatMulV2Op
):
"""
"""
case 3
case 3
...
@@ -131,21 +122,18 @@ class TestMatMuklOp3(TestMatMulV2Op):
...
@@ -131,21 +122,18 @@ class TestMatMuklOp3(TestMatMulV2Op):
self
.
trans_y
=
False
self
.
trans_y
=
False
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
# class TestMatMuklOp4(TestMatMulV2Op):
"core is not compiled with XPU")
# """
class TestMatMuklOp4(TestMatMulV2Op):
# case 4
"""
# """
case 4
"""
# def config(self):
# self.x_shape = (100, )
# self.y_shape = (1, 2, 100, 2)
# self.trans_x = False
# self.trans_y = False
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 2, 100, 2)
self.trans_x = False
self.trans_y = False
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class
TestMatMuklOp5
(
TestMatMulV2Op
):
class
TestMatMuklOp5
(
TestMatMulV2Op
):
"""
"""
case 5
case 5
...
@@ -158,37 +146,29 @@ class TestMatMuklOp5(TestMatMulV2Op):
...
@@ -158,37 +146,29 @@ class TestMatMuklOp5(TestMatMulV2Op):
self
.
trans_y
=
False
self
.
trans_y
=
False
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
# class TestMatMuklOp6(TestMatMulV2Op):
"core is not compiled with XPU")
# """
class TestMatMuklOp6(TestMatMulV2Op):
# case 6
"""
# """
case 6
"""
def config(self):
self.x_shape = (1, 2, 100, 1)
self.y_shape = (100, )
self.trans_x = True
self.trans_y = False
# def config(self):
# self.x_shape = (1, 2, 102, 1)
# self.y_shape = (102, )
# self.trans_x = True
# self.trans_y = False
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
# class TestMatMuklOp7(TestMatMulV2Op):
"core is not compiled with XPU")
# """
class TestMatMuklOp7(TestMatMulV2Op):
# case 7
"""
# """
case 7
"""
def config(self):
# def config(self):
self.x_shape = (1, 2, 1, 100)
# self.x_shape = (1, 2, 1, 100)
self.y_shape = (100, )
# self.y_shape = (100, )
self.trans_x = False
# self.trans_x = False
self.trans_y = False
# self.trans_y = False
'''
@
unittest
.
skipIf
(
not
paddle
.
is_compiled_with_xpu
(),
"core is not compiled with XPU"
)
class
TestMatMuklOp8
(
TestMatMulV2Op
):
class
TestMatMuklOp8
(
TestMatMulV2Op
):
"""
"""
case 8
case 8
...
@@ -201,37 +181,97 @@ class TestMatMuklOp8(TestMatMulV2Op):
...
@@ -201,37 +181,97 @@ class TestMatMuklOp8(TestMatMulV2Op):
self
.
trans_y
=
False
self
.
trans_y
=
False
@
unittest
.
skipIf
(
not
paddle
.
is_compiled_with_xpu
(),
# class TestMatMuklOp9(TestMatMulV2Op):
"core is not compiled with XPU"
)
# """
# case 9
# """
# def config(self):
# self.x_shape = (1, 1, 1, 100)
# self.y_shape = (2, 1, 2, 100)
# self.trans_x = False
# self.trans_y = True
# class TestMatMuklOp10(TestMatMulV2Op):
# """
# case 10
# """
# def config(self):
# self.x_shape = (1, 1, 25, 4)
# self.y_shape = (1, 2, 4, 25)
# self.trans_x = False
# self.trans_y = False
# class TestMatMuklOp11(TestMatMulV2Op):
# """
# case 11
# """
# def config(self):
# self.x_shape = (2, 1, 2, 100)
# self.y_shape = (1, 1, 100, 2)
# self.trans_x = False
# self.trans_y = False
# class TestMatMuklOp12(TestMatMulV2Op):
# """
# case 12
# """
# def config(self):
# self.x_shape = (2, 1, 4, 25)
# self.y_shape = (1, 1, 4, 25)
# self.trans_x = True
# self.trans_y = False
class
TestMatMuklOp13
(
TestMatMulV2Op
):
class
TestMatMuklOp13
(
TestMatMulV2Op
):
"""
"""
case 13
case 13
"""
"""
def
config
(
self
):
def
config
(
self
):
self
.
x_shape
=
(
2
,
2
,
2
,
5
0
)
self
.
x_shape
=
(
2
,
2
,
10
,
1
0
)
self
.
y_shape
=
(
2
,
2
,
2
,
5
0
)
self
.
y_shape
=
(
2
,
2
,
10
,
1
0
)
self
.
trans_x
=
True
self
.
trans_x
=
True
self
.
trans_y
=
False
self
.
trans_y
=
False
'''
# class TestMatMuklOp14(TestMatMulV2Op):
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
# """
"core is not compiled with XPU")
# case 14_1
class TestMatMuklOp16(TestMatMulV2Op):
# """
"""
case 16 : to check the gradient for special case
"""
def config(self):
# def config(self):
self.x_shape = (100)
# self.x_shape = (3, 1, 6, 6)
self.y_shape = (1, 2, 2, 100, 2)
# self.y_shape = (1, 2, 6, 9)
self.trans_x = False
# self.trans_x = True
self.trans_y = False
# self.trans_y = False
# class TestMatMuklOp15(TestMatMulV2Op):
# """
# case 14_2
# """
# def config(self):
# self.x_shape = (3, 1, 6, 6)
# self.y_shape = (1, 2, 6, 9)
# self.trans_x = False
# self.trans_y = False
# class TestMatMuklOp16(TestMatMulV2Op):
# """
# case 16 : to check the gradient for special case
# """
# def config(self):
# self.x_shape = (100)
# self.y_shape = (1, 2, 2, 100, 2)
# self.trans_x = False
# self.trans_y = False
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class
TestMatMuklOp17
(
TestMatMulV2Op
):
class
TestMatMuklOp17
(
TestMatMulV2Op
):
"""
"""
case 17 : to check the gradient for special case
case 17 : to check the gradient for special case
...
@@ -242,36 +282,30 @@ class TestMatMuklOp17(TestMatMulV2Op):
...
@@ -242,36 +282,30 @@ class TestMatMuklOp17(TestMatMulV2Op):
self
.
y_shape
=
(
100
)
self
.
y_shape
=
(
100
)
self
.
trans_x
=
False
self
.
trans_x
=
False
self
.
trans_y
=
False
self
.
trans_y
=
False
'''
@
unittest
.
skipIf
(
not
paddle
.
is_compiled_with_xpu
(),
"core is not compiled with XPU"
)
class
TestMatMulV2API
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
places
=
[
fluid
.
CPUPlace
()]
self
.
places
.
append
(
fluid
.
XPUPlace
(
0
))
def
check_static_result
(
self
,
place
):
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
input_x
=
fluid
.
data
(
name
=
"input_x"
,
shape
=
[
4
,
3
],
dtype
=
"float32"
)
input_y
=
fluid
.
data
(
name
=
"input_y"
,
shape
=
[
3
,
4
],
dtype
=
"float32"
)
result
=
paddle
.
matmul
(
input_x
,
input_y
)
x_np
=
np
.
random
.
random
([
4
,
3
]).
astype
(
"float32"
)
# class TestMatMuklOpBroadcast1(TestMatMulV2Op):
y_np
=
np
.
random
.
random
([
3
,
4
]).
astype
(
"float32"
)
# """
# case 14_3
# """
exe
=
fluid
.
Executor
(
place
)
# def config(self):
fetches
=
exe
.
run
(
fluid
.
default_main_program
(),
# self.x_shape = (3, 1, 10, 10)
feed
=
{
"input_x"
:
x_np
,
# self.y_shape = (1, 2, 10, 10)
"input_y"
:
y_np
},
# self.trans_x = True
fetch_list
=
[
result
])
# self.trans_y = True
def
test_static
(
self
):
# class TestMatMuklOpBroadcast2(TestMatMulV2Op):
for
place
in
self
.
places
:
# """
self
.
check_static_result
(
place
=
place
)
# case 14_4
# """
# def config(self):
# self.x_shape = (3, 1, 10, 10)
# self.y_shape = (1, 2, 10, 10)
# self.trans_x = False
# self.trans_y = True
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_one_hot_v2_op_xpu.py
0 → 100644
浏览文件 @
1323e5e7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid.core
as
core
import
sys
sys
.
path
.
append
(
".."
)
from
op_test_xpu
import
XPUOpTest
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
import
time
paddle
.
enable_static
()
class
TestOneHotOp
(
XPUOpTest
):
def
setUp
(
self
):
self
.
use_xpu
=
True
self
.
op_type
=
'one_hot_v2'
depth
=
10
depth_np
=
np
.
array
(
10
).
astype
(
'int32'
)
# dimension = 12
x_lod
=
[[
4
,
1
,
3
,
3
]]
x
=
[
np
.
random
.
randint
(
0
,
depth
-
1
)
for
i
in
range
(
sum
(
x_lod
[
0
]))]
x
=
np
.
array
(
x
).
astype
(
'int32'
).
reshape
([
sum
(
x_lod
[
0
])])
out
=
np
.
zeros
(
shape
=
(
np
.
product
(
x
.
shape
),
depth
)).
astype
(
'float32'
)
for
i
in
range
(
np
.
product
(
x
.
shape
)):
out
[
i
,
x
[
i
]]
=
1.0
self
.
inputs
=
{
'X'
:
(
x
,
x_lod
),
'depth_tensor'
:
depth_np
}
self
.
attrs
=
{
'dtype'
:
int
(
core
.
VarDesc
.
VarType
.
FP32
)}
self
.
outputs
=
{
'Out'
:
(
out
,
x_lod
)}
def
test_check_output
(
self
):
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
,
check_dygraph
=
False
)
class
TestOneHotOp_attr
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
'one_hot_v2'
depth
=
10
dimension
=
12
x_lod
=
[[
4
,
1
,
3
,
3
]]
x
=
[
np
.
random
.
randint
(
0
,
depth
-
1
)
for
i
in
range
(
sum
(
x_lod
[
0
]))]
x
=
np
.
array
(
x
).
astype
(
'int32'
).
reshape
([
sum
(
x_lod
[
0
]),
1
])
out
=
np
.
zeros
(
shape
=
(
np
.
product
(
x
.
shape
[:
-
1
]),
1
,
depth
)).
astype
(
'float32'
)
for
i
in
range
(
np
.
product
(
x
.
shape
)):
out
[
i
,
0
,
x
[
i
]]
=
1.0
self
.
inputs
=
{
'X'
:
(
x
,
x_lod
)}
self
.
attrs
=
{
'dtype'
:
int
(
core
.
VarDesc
.
VarType
.
FP32
),
'depth'
:
depth
}
self
.
outputs
=
{
'Out'
:
(
out
,
x_lod
)}
def
test_check_output
(
self
):
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
,
check_dygraph
=
False
)
class
TestOneHotOp_default_dtype
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
'one_hot_v2'
depth
=
10
depth_np
=
np
.
array
(
10
).
astype
(
'int32'
)
dimension
=
12
x_lod
=
[[
4
,
1
,
3
,
3
]]
x
=
[
np
.
random
.
randint
(
0
,
depth
-
1
)
for
i
in
range
(
sum
(
x_lod
[
0
]))]
x
=
np
.
array
(
x
).
astype
(
'int32'
).
reshape
([
sum
(
x_lod
[
0
])])
out
=
np
.
zeros
(
shape
=
(
np
.
product
(
x
.
shape
),
depth
)).
astype
(
'float32'
)
for
i
in
range
(
np
.
product
(
x
.
shape
)):
out
[
i
,
x
[
i
]]
=
1.0
self
.
inputs
=
{
'X'
:
(
x
,
x_lod
),
'depth_tensor'
:
depth_np
}
self
.
attrs
=
{}
self
.
outputs
=
{
'Out'
:
(
out
,
x_lod
)}
def
test_check_output
(
self
):
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
,
check_dygraph
=
False
)
class
TestOneHotOp_default_dtype_attr
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
'one_hot_v2'
depth
=
10
dimension
=
12
x_lod
=
[[
4
,
1
,
3
,
3
]]
x
=
[
np
.
random
.
randint
(
0
,
depth
-
1
)
for
i
in
range
(
sum
(
x_lod
[
0
]))]
x
=
np
.
array
(
x
).
astype
(
'int32'
).
reshape
([
sum
(
x_lod
[
0
]),
1
])
out
=
np
.
zeros
(
shape
=
(
np
.
product
(
x
.
shape
[:
-
1
]),
1
,
depth
)).
astype
(
'float32'
)
for
i
in
range
(
np
.
product
(
x
.
shape
)):
out
[
i
,
0
,
x
[
i
]]
=
1.0
self
.
inputs
=
{
'X'
:
(
x
,
x_lod
)}
self
.
attrs
=
{
'depth'
:
depth
}
self
.
outputs
=
{
'Out'
:
(
out
,
x_lod
)}
def
test_check_output
(
self
):
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
,
check_dygraph
=
False
)
class
TestOneHotOp_out_of_range
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
'one_hot_v2'
depth
=
10
x_lod
=
[[
4
,
1
,
3
,
3
]]
x
=
[
np
.
random
.
choice
([
-
1
,
depth
])
for
i
in
range
(
sum
(
x_lod
[
0
]))]
x
=
np
.
array
(
x
).
astype
(
'int32'
).
reshape
([
sum
(
x_lod
[
0
])])
out
=
np
.
zeros
(
shape
=
(
np
.
product
(
x
.
shape
),
depth
)).
astype
(
'float32'
)
self
.
inputs
=
{
'X'
:
(
x
,
x_lod
)}
self
.
attrs
=
{
'depth'
:
depth
,
'allow_out_of_range'
:
True
}
self
.
outputs
=
{
'Out'
:
(
out
,
x_lod
)}
def
test_check_output
(
self
):
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
,
check_dygraph
=
False
)
class
TestOneHotOpApi
(
unittest
.
TestCase
):
def
test_api
(
self
):
depth
=
10
self
.
_run
(
depth
)
def
test_api_with_depthTensor
(
self
):
depth
=
fluid
.
layers
.
assign
(
input
=
np
.
array
([
10
],
dtype
=
np
.
int32
))
self
.
_run
(
depth
)
def
test_api_with_dygraph
(
self
):
depth
=
10
label
=
np
.
array
([
np
.
random
.
randint
(
0
,
depth
-
1
)
for
i
in
range
(
6
)]).
reshape
([
6
,
1
])
with
fluid
.
dygraph
.
guard
():
one_hot_label
=
fluid
.
one_hot
(
input
=
fluid
.
dygraph
.
to_variable
(
label
),
depth
=
depth
)
def
_run
(
self
,
depth
):
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
1
],
dtype
=
"int64"
)
one_hot_label
=
fluid
.
one_hot
(
input
=
label
,
depth
=
depth
)
place
=
fluid
.
XPUPlace
(
0
)
label_data
=
np
.
array
([
np
.
random
.
randint
(
0
,
10
-
1
)
for
i
in
range
(
6
)]).
reshape
([
6
,
1
])
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
ret
=
exe
.
run
(
feed
=
{
'label'
:
label_data
,
},
fetch_list
=
[
one_hot_label
],
return_numpy
=
False
)
class
BadInputTestOnehotV2
(
unittest
.
TestCase
):
def
test_error
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
def
test_bad_x
():
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
4
],
append_batch_size
=
False
,
dtype
=
"float32"
)
one_hot_label
=
fluid
.
one_hot
(
input
=
label
,
depth
=
4
)
self
.
assertRaises
(
TypeError
,
test_bad_x
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_range_xpu.py
0 → 100644
浏览文件 @
1323e5e7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
from
op_test_xpu
import
XPUOpTest
paddle
.
enable_static
()
class
TestRangeOp
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
"range"
self
.
init_config
()
self
.
inputs
=
{
'Start'
:
np
.
array
([
self
.
case
[
0
]]).
astype
(
self
.
dtype
),
'End'
:
np
.
array
([
self
.
case
[
1
]]).
astype
(
self
.
dtype
),
'Step'
:
np
.
array
([
self
.
case
[
2
]]).
astype
(
self
.
dtype
)
}
self
.
outputs
=
{
'Out'
:
np
.
arange
(
self
.
case
[
0
],
self
.
case
[
1
],
self
.
case
[
2
]).
astype
(
self
.
dtype
)
}
def
init_config
(
self
):
self
.
dtype
=
np
.
float32
self
.
case
=
(
0
,
1
,
0.2
)
def
test_check_output
(
self
):
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
,
check_dygraph
=
False
)
class
TestFloatRangeOpCase0
(
TestRangeOp
):
def
init_config
(
self
):
self
.
dtype
=
np
.
float32
self
.
case
=
(
0
,
5
,
1
)
class
TestInt32RangeOpCase0
(
TestRangeOp
):
def
init_config
(
self
):
self
.
dtype
=
np
.
int32
self
.
case
=
(
0
,
5
,
2
)
class
TestInt32RangeOpCase1
(
TestRangeOp
):
def
init_config
(
self
):
self
.
dtype
=
np
.
int32
self
.
case
=
(
10
,
1
,
-
2
)
class
TestInt32RangeOpCase2
(
TestRangeOp
):
def
init_config
(
self
):
self
.
dtype
=
np
.
int32
self
.
case
=
(
-
1
,
-
10
,
-
2
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录