PaddlePaddle / Paddle
Commit 4d4fb660: xpu support amp (#33809)

Authored on Jun 29, 2021 by taixiurong; committed via GitHub on Jun 29, 2021 (unverified signature).
Parent commit: 0d3de8d0
Showing 12 changed files with 143 additions and 99 deletions (+143 -99).
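In short, this commit wires Baidu Kunlun XPU devices into Paddle's automatic mixed precision (AMP) stack: the imperative AMP pass and the dygraph amp_guard/AmpScaler now accept XPU places alongside CUDA ones, float16 XPU kernels are registered for matmul and matmul_v2 through a new shared XPUTypeTrait helper, the FP16 op lists are built from an 'XPU' capability query when Paddle is compiled with XPU support, and the pinned XPU SDK packages are updated in cmake.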
Changed files:
  cmake/external/xpu.cmake                                      +5   -7
  paddle/fluid/imperative/amp_auto_cast.cc                      +4   -2
  paddle/fluid/operators/cast_op_xpu.cc                         +1   -14
  paddle/fluid/operators/matmul_op_xpu.cc                       +49  -32
  paddle/fluid/operators/matmul_v2_op_xpu.cc                    +47  -32
  paddle/fluid/operators/softmax_op_xpu.cc                      +2   -2
  paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc   +3   -2
  paddle/fluid/platform/xpu_header.h                            +14  -1
  paddle/fluid/pybind/pybind.cc                                 +3   -1
  python/paddle/fluid/contrib/mixed_precision/fp16_lists.py     +9   -2
  python/paddle/fluid/dygraph/amp/auto_cast.py                  +3   -2
  python/paddle/fluid/dygraph/amp/loss_scaler.py                +3   -2
cmake/external/xpu.cmake

@@ -27,19 +27,17 @@ ELSEIF(WITH_CENTOS)
   SET(XPU_XRE_DIR_NAME "xre-centos7_x86_64")
   SET(XPU_XDNN_DIR_NAME "xdnn-centos7_x86_64")
   SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")
 ELSE()
   SET(XPU_XRE_DIR_NAME "xre-ubuntu_x86_64")
   SET(XPU_XDNN_DIR_NAME "xdnn-ubuntu_x86_64")
   SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")
 ENDIF()

-SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-SET(XPU_XRE_URL "${XPU_BASE_URL}/20210625/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
-SET(XPU_XDNN_URL "${XPU_BASE_URL}/20210625/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
-SET(XPU_XCCL_URL "${XPU_BASE_URL}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
+IF(NOT XPU_BASE_URL)
+  SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev/20210527")
+ENDIF()
+SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
+SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
+SET(XPU_XCCL_URL "${XPU_BASE_URL}/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_PACK_DEPENCE_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/pack_paddle_depence.sh" CACHE STRING "" FORCE)
 SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
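Because the default is now guarded by IF(NOT XPU_BASE_URL), the SDK base URL can be overridden at configure time (for example by passing -DXPU_BASE_URL=<mirror> to cmake); the dated snapshot component (here 20210527) moves out of the individual package URLs and into the default base URL.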
paddle/fluid/imperative/amp_auto_cast.cc

@@ -33,7 +33,8 @@ AmpOperators::AmpOperators()
   for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) {
     bool supported = false;
     for (auto& kernel_type : it->second) {
-      if (platform::is_gpu_place(kernel_type.first.place_) &&
+      if ((platform::is_gpu_place(kernel_type.first.place_) ||
+           platform::is_xpu_place(kernel_type.first.place_)) &&
           kernel_type.first.data_type_ == fp16_dtype) {
         supported = true;
       }

@@ -91,7 +92,8 @@ inline std::string GetDtypeStr(
 inline bool NeedCast(const std::shared_ptr<VarBase>& var) {
   if (platform::is_gpu_place(var->Place()) ||
-      platform::is_cuda_pinned_place(var->Place())) {
+      platform::is_cuda_pinned_place(var->Place()) ||
+      platform::is_xpu_place(var->Place())) {
     // CudaPinndePlace is added for varbase created by dataloader
     if (var->DataType() == framework::proto::VarType::FP32 ||
         var->DataType() == framework::proto::VarType::FP16) {
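Both AMP gates now admit XPU: AmpOperators only treats an op as AMP-capable if a float16 kernel is registered for a GPU or XPU place, and NeedCast now also considers variables living on an XPU place, alongside CUDA and CUDA-pinned memory.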
paddle/fluid/operators/cast_op_xpu.cc

@@ -23,21 +23,9 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename T>
-class XPUFPTypeTrait {
- public:
-  using Type = T;
-};
-
-template <>
-class XPUFPTypeTrait<platform::float16> {
- public:
-  using Type = float16;
-};
-
 template <typename DeviceContext, typename InT>
 class CastXPUKernel : public framework::OpKernel<InT> {
-  using XPUInTDType = typename XPUFPTypeTrait<InT>::Type;
+  using XPUInTDType = typename XPUTypeTrait<InT>::Type;

  public:
   void Compute(const framework::ExecutionContext& context) const override {

@@ -49,7 +37,6 @@ class CastXPUKernel : public framework::OpKernel<InT> {
         context.Attr<int>("out_dtype"));
     auto* in_data = in->data<InT>();
-    // using XPUOutTDType = typename XPUFPTypeTrait<InT>::Type;
     auto numel = in->numel();
     auto& dev_ctx = context.template device_context<DeviceContext>();
     int r = -1;
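The file-local XPUFPTypeTrait helper is dropped; cast (like the matmul kernels below) now uses the shared XPUTypeTrait that this commit adds to paddle/fluid/platform/xpu_header.h.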
paddle/fluid/operators/matmul_op_xpu.cc

@@ -102,6 +102,7 @@ template <typename T, typename FCT>
 static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
                               bool trans_x, bool trans_y,
                               const paddle::framework::ExecutionContext &ctx) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   const auto &x_dims = x->dims();
   const auto &y_dims = y->dims();
   auto &dev_ctx =

@@ -162,34 +163,36 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
   int ldout = n;
   if (batch_size <= 1) {
     int r = 0;
-    r = xpu::fc_fusion<T, T, T, FCT>(
-        dev_ctx.x_context(), x->data<T>(), y->data<T>(), data_c, m, n, k,
-        mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-        ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
+    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType *>(x->data<T>()),
+        reinterpret_cast<const XPUType *>(y->data<T>()),
+        reinterpret_cast<XPUType *>(data_c), m, n, k, mat_dim_a.trans_,
+        mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, ldout, alpha,
+        0, nullptr, xpu::Activation_t::LINEAR);
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External(
                           "XPU fc_fusion kernel return wrong value[%d %s]", r,
                           XPUAPIErrorMsg[r]));
   } else {
     // batch matmul
-    int r = xpu::fc_batched<T, T, T, FCT>(
+    int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
         dev_ctx.x_context(),                        // Context* ctx,
         batch_size,                                 // int batch_size,
         mat_dim_a.trans_,                           // bool x_trans,
         mat_dim_b.trans_,                           // bool w_trans,
         m,                                          // int m,
         n,                                          // int n,
         k,                                          // int k,
         alpha,                                      // float alpha,
-        reinterpret_cast<const T *>(x->data<T>()),        // const TX* x,
+        reinterpret_cast<const XPUType *>(x->data<T>()),  // const TX* x,
         mat_dim_a.stride_,                          // int stride_a,
-        reinterpret_cast<const T *>(y->data<T>()),        // const TW* w,
+        reinterpret_cast<const XPUType *>(y->data<T>()),  // const TW* w,
         mat_dim_b.stride_,                          // int stride_b,
         0.0,                                        // float beta,
-        reinterpret_cast<T *>(data_c),              // TY* y,
+        reinterpret_cast<XPUType *>(data_c),        // TY* y,
         m * n,                                      // int stride_c,
         nullptr,                                    // const float* x_maxptr,
         nullptr);                                   // const float* w_maxptr
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External(

@@ -210,10 +213,14 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(context.GetPlace());
     bool trans_x = context.Attr<bool>("transpose_X");
     bool trans_y = context.Attr<bool>("transpose_Y");
-    if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) {
-      MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, context);
+    if (std::is_same<paddle::platform::float16, T>::value) {
+      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
     } else {
-      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
+      if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) {
+        MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, context);
+      } else {
+        MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
+      }
     }
   }
 };

@@ -224,6 +231,7 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 static framework::Tensor XPUFoldHeadAndLastDims(
     const DeviceContext &context, const framework::Tensor &input) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   auto in_dims = input.dims();
   if (in_dims.size() != 3) {
     return input;

@@ -236,8 +244,9 @@ static framework::Tensor XPUFoldHeadAndLastDims(
                                    static_cast<int>(in_dims[1]),
                                    static_cast<int>(in_dims[2])};
   std::vector<int> axis_host = {1, 0, 2};
-  int r = xpu::transpose(context.x_context(), input.data<T>(),
-                         output.data<T>(), in_shape_host, axis_host);
+  int r = xpu::transpose(
+      context.x_context(), reinterpret_cast<const XPUType *>(input.data<T>()),
+      reinterpret_cast<XPUType *>(output.data<T>()), in_shape_host, axis_host);
   PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                     platform::errors::External(
                         "XPU transpose kernel return wrong value[%d %s]", r,

@@ -280,10 +289,14 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
                   const framework::Tensor &b, bool trans_b,
                   framework::Tensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) {
-      MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, context);
+    if (std::is_same<paddle::platform::float16, T>::value) {
+      MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
     } else {
-      MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
+      if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) {
+        MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, context);
+      } else {
+        MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
+      }
     }
   }

@@ -370,10 +383,14 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 REGISTER_OP_XPU_KERNEL(
-    matmul, ops::MatMulXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    matmul, ops::MatMulXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::MatMulXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>);
 REGISTER_OP_XPU_KERNEL(
     matmul_grad,
-    ops::MatMulGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::MatMulGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::MatMulGradXPUKernel<paddle::platform::XPUDeviceContext,
+                             plat::float16>);
 #endif
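A quick note on the environment switches above, with a minimal sketch that is not part of the commit: the kernels read XPU_PADDLE_MAT_MUL_FCINT32 and XPU_PADDLE_MAT_MUL_GRAD_FCINT32 with std::getenv, so only the variables' presence matters, and after this change float16 inputs always take the int16 FC path regardless of the switch.

    # Minimal sketch: opt float32 matmuls into the int32 FC path on XPU.
    # Any value works; the C++ side only checks that the variable is set,
    # and it must be set before the op executes.
    import os

    os.environ["XPU_PADDLE_MAT_MUL_FCINT32"] = "1"       # forward matmul
    os.environ["XPU_PADDLE_MAT_MUL_GRAD_FCINT32"] = "1"  # matmul gradient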
paddle/fluid/operators/matmul_v2_op_xpu.cc

@@ -25,6 +25,7 @@ template <typename T, typename FCT>
 static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
                               bool trans_x, bool trans_y,
                               const paddle::framework::ExecutionContext& ctx) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   const auto& x_dims = x->dims();
   const auto& y_dims = y->dims();
   auto& dev_ctx =

@@ -75,9 +76,11 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
   int batch_size = mat_dim_a.batch_size_;
   if (batch_size <= 1) {
     int r = 0;
-    r = xpu::fc<T, T, T, FCT>(dev_ctx.x_context(), x->data<T>(), y->data<T>(),
-                              data_c, m, n, k, mat_dim_a.trans_,
-                              mat_dim_b.trans_, nullptr, nullptr, nullptr);
+    r = xpu::fc<XPUType, XPUType, XPUType, FCT>(
+        dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x->data<T>()),
+        reinterpret_cast<const XPUType*>(y->data<T>()),
+        reinterpret_cast<XPUType*>(data_c), m, n, k, mat_dim_a.trans_,
+        mat_dim_b.trans_, nullptr, nullptr, nullptr);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External(

@@ -87,24 +90,24 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
             r, XPUAPIErrorMsg[r], m, n, k, mat_dim_a.trans_, mat_dim_b.trans_));
   } else {
     // batch matmul
-    int r = xpu::fc_batched<T, T, T, FCT>(
+    int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
         dev_ctx.x_context(),                        // Context* ctx,
         batch_size,                                 // int batch_size,
         mat_dim_a.trans_,                           // bool x_trans,
         mat_dim_b.trans_,                           // bool w_trans,
         m,                                          // int m,
         n,                                          // int n,
         k,                                          // int k,
         1.0,                                        // float alpha,
-        reinterpret_cast<const T*>(x->data<T>()),         // const TX* x,
+        reinterpret_cast<const XPUType*>(x->data<T>()),   // const TX* x,
         mat_dim_a.stride_,                          // int stride_a,
-        reinterpret_cast<const T*>(y->data<T>()),         // const TW* w,
+        reinterpret_cast<const XPUType*>(y->data<T>()),   // const TW* w,
         mat_dim_b.stride_,                          // int stride_b,
         0.0,                                        // float beta,
-        reinterpret_cast<T*>(data_c),               // TY* y,
+        reinterpret_cast<XPUType*>(data_c),         // TY* y,
         m * n,                                      // int stride_c,
         nullptr,                                    // const float* x_maxptr,
         nullptr);                                   // const float* w_maxptr
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External(

@@ -123,10 +126,14 @@ class MatMulV2XPUKernel : public framework::OpKernel<T> {
     bool trans_x = ctx.Attr<bool>("trans_x");
     bool trans_y = ctx.Attr<bool>("trans_y");
     out->mutable_data<T>(ctx.GetPlace());
-    if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) {
-      MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, ctx);
+    if (std::is_same<paddle::platform::float16, T>::value) {
+      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
     } else {
-      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
+      if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) {
+        MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, ctx);
+      } else {
+        MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
+      }
     }
   }
 };

@@ -134,6 +141,7 @@ class MatMulV2XPUKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 static framework::Tensor XPUFoldHeadAndLastDims(
     const DeviceContext& context, const framework::Tensor& input) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   auto in_dims = input.dims();
   if (in_dims.size() != 3) {
     return input;

@@ -147,8 +155,9 @@ static framework::Tensor XPUFoldHeadAndLastDims(
                                    static_cast<int>(in_dims[2])};
   std::vector<int> axis_host = {1, 0, 2};
-  int r = xpu::transpose(context.x_context(), input.data<T>(),
-                         output.data<T>(), in_shape_host, axis_host);
+  int r = xpu::transpose(
+      context.x_context(), reinterpret_cast<const XPUType*>(input.data<T>()),
+      reinterpret_cast<XPUType*>(output.data<T>()), in_shape_host, axis_host);
   PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                     platform::errors::External(
                         "XPU transpose kernel return wrong value[%d %s]", r,

@@ -166,10 +175,14 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
                   const framework::Tensor& b, bool trans_b,
                   framework::Tensor* out) const {
     out->mutable_data<T>(ctx.GetPlace());
-    if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) {
-      MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, ctx);
+    if (std::is_same<paddle::platform::float16, T>::value) {
+      MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
     } else {
-      MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
+      if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) {
+        MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, ctx);
+      } else {
+        MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
+      }
     }
   }

@@ -261,8 +274,10 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

-REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel<float>);
-REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel<float>);
+REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel<float>,
+                       ops::MatMulV2XPUKernel<plat::float16>);
+REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel<float>,
+                       ops::MatMulV2XPUGradKernel<plat::float16>);
 #endif
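matmul_v2 follows the same selection policy through the analogous XPU_PADDLE_MAT_MUL_V2_FCINT32 and XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32 variables, and likewise pins float16 inputs to the int16 FC path.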
paddle/fluid/operators/softmax_op_xpu.cc

@@ -47,8 +47,8 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
     int len = x->numel();
     T* clip_x_data =
         clip_x.mutable_data<T>(context.GetPlace(), len * sizeof(T));
-    r = xpu::clip(dev_ctx.x_context(), x->data<float>(), clip_x_data, len,
-                  -1e30, 1e30);
+    r = xpu::clip_v2(dev_ctx.x_context(), x->data<float>(), clip_x_data, len,
+                     static_cast<float>(-1e20), static_cast<float>(1e20));
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External("XPU API(clip) return wrong "
                                                  "value[%d %s]",
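The softmax kernel switches from xpu::clip to xpu::clip_v2, and the clamp range is tightened from ±1e30 to ±1e20, now passed as explicit float casts; the same substitution appears in softmax_with_cross_entropy_op_xpu.cc below.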
paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc

@@ -54,8 +54,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
     int len = logits->numel();
     T* clip_logits_data =
         clip_logits.mutable_data<T>(context.GetPlace(), len * sizeof(T));
-    r = xpu::clip(dev_ctx.x_context(), logits->data<float>(),
-                  clip_logits_data, len, -1e30, 1e30);
+    r = xpu::clip_v2(dev_ctx.x_context(), logits->data<float>(),
+                     clip_logits_data, len, static_cast<float>(-1e20),
+                     static_cast<float>(1e20));
     PADDLE_ENFORCE_EQ(
         r, xpu::Error_t::SUCCESS,
         platform::errors::External("XPU kernel error. clip "
paddle/fluid/platform/xpu_header.h

@@ -1,4 +1,4 @@
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.

@@ -20,6 +20,7 @@
 #include <unordered_map>

 #include "paddle/fluid/platform/errors.h"
+#include "paddle/fluid/platform/float16.h"
 #include "xpu/api.h"
 #include "xpu/refactor/fusion.h"
 #include "xpu/refactor/math.h"

@@ -58,4 +59,16 @@ static std::map<int, std::string> XPUAPIErrorMsg = {
     {xpu::Error_t::RUNTIME_ERROR, "xpu api runtime error"},
     {xpu::Error_t::NO_ENOUGH_WORKSPACE, "xpu api no enough workspace"}};
+
+template <typename T>
+class XPUTypeTrait {
+ public:
+  using Type = T;
+};
+
+template <>
+class XPUTypeTrait<paddle::platform::float16> {
+ public:
+  using Type = float16;
+};
 #endif
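XPUTypeTrait is the shared replacement for the XPUFPTypeTrait that cast_op_xpu.cc previously defined locally: a generic T maps to itself, while paddle::platform::float16 maps to the XPU SDK's own float16 type. This mapping is why the matmul kernels above reinterpret_cast their tensor buffers before calling xpu:: APIs.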
paddle/fluid/pybind/pybind.cc

@@ -225,7 +225,9 @@ OpSupportedInfos(const std::string &place,
                  [](unsigned char c) { return std::toupper(c); });
   using fn_type = std::add_pointer<bool(const platform::Place &)>::type;
   std::unordered_map<std::string, fn_type> is_target_place{
-      {"GPU", &platform::is_gpu_place}, {"CPU", &platform::is_cpu_place},
+      {"GPU", &platform::is_gpu_place},
+      {"CPU", &platform::is_cpu_place},
+      {"XPU", &platform::is_xpu_place},
   };
   PADDLE_ENFORCE_NE(
       is_target_place.count(query_place), 0,
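With the "XPU" entry in is_target_place, core.op_supported_infos (the Python binding of OpSupportedInfos) can now be queried for XPU capability; the fp16_lists.py change below relies on this.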
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py

@@ -14,6 +14,7 @@
 import copy
 from ... import core
+import paddle.fluid as fluid

 __all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]

@@ -152,8 +153,14 @@ gray_list = {
 # The set of ops that don't support fp16 calculation
 # lookup_table fp16 is slower than fp32, though fp16 is supported.
-_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-    'GPU', core.VarDesc.VarType.FP16)
+_sys_unsupported_fp16_list = []
+if fluid.is_compiled_with_xpu():
+    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
+        'XPU', core.VarDesc.VarType.FP16)
+else:
+    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
+        'GPU', core.VarDesc.VarType.FP16)
+
 unsupported_fp16_list = {'lookup_table',
                          'lookup_table_v2'} | _sys_unsupported_fp16_list
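A minimal sketch (not in the commit) of the capability query this module now performs; core.op_supported_infos returns three collections, of which only the unsupported set is used here:

    import paddle.fluid as fluid
    from paddle.fluid import core

    # Ask the framework which ops lack an FP16 kernel on the compiled-in device.
    place = 'XPU' if fluid.is_compiled_with_xpu() else 'GPU'
    _, _, unsupported = core.op_supported_infos(place, core.VarDesc.VarType.FP16)
    print(sorted(unsupported)[:10])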
python/paddle/fluid/dygraph/amp/auto_cast.py

@@ -130,9 +130,10 @@ def amp_guard(enable=True, custom_white_list=None, custom_black_list=None):
         raise ValueError(
             "current_tracer is None, maybe it is not in imperative mode.")

-    if enable and not tracer._expected_place.is_gpu_place():
+    if enable and not (tracer._expected_place.is_gpu_place() or
+                       tracer._expected_place.is_xpu_place()):
         warnings.warn(
-            'amp_guard can only be enabled on CUDAPlace, current place is %s, so it makes no effect.'
+            'amp_guard can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.'
             % tracer._expected_place)
         enable = False
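A minimal usage sketch, adapted from the amp_guard docstring pattern of this era; fluid.XPUPlace(0) assumes an XPU device is visible, which is exactly the case this change enables:

    import numpy as np
    import paddle.fluid as fluid

    data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
    with fluid.dygraph.guard(fluid.XPUPlace(0)):  # previously only CUDAPlace took effect
        conv2d = fluid.dygraph.Conv2D(3, 2, 3)
        x = fluid.dygraph.to_variable(data)
        with fluid.dygraph.amp_guard(True):
            y = conv2d(x)  # runs in float16 where the white list allows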
python/paddle/fluid/dygraph/amp/loss_scaler.py

@@ -90,9 +90,10 @@ class AmpScaler(object):
             raise ValueError(
                 "current_tracer is None, maybe it is not in imperative mode.")

-        if enable and not tracer._expected_place.is_gpu_place():
+        if enable and not (tracer._expected_place.is_gpu_place() or
+                           tracer._expected_place.is_xpu_place()):
             warnings.warn(
-                'AmpScaler can only be enabled on CUDAPlace, current place is %s, so it makes no effect.'
+                'AmpScaler can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.'
                 % tracer._expected_place)
             enable = False
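And a corresponding loss-scaling sketch following the AmpScaler docstring pattern, again assuming an XPU device:

    import numpy as np
    import paddle.fluid as fluid

    data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
    with fluid.dygraph.guard(fluid.XPUPlace(0)):
        model = fluid.dygraph.Conv2D(3, 2, 3)
        optimizer = fluid.optimizer.SGDOptimizer(
            learning_rate=0.01, parameter_list=model.parameters())
        scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
        x = fluid.dygraph.to_variable(data)
        with fluid.dygraph.amp_guard(True):
            loss = fluid.layers.reduce_mean(model(x))
        scaled = scaler.scale(loss)         # scale the loss before backward
        scaled.backward()
        scaler.minimize(optimizer, scaled)  # unscale gradients and step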