Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e512aa9a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e512aa9a
编写于
11月 02, 2021
作者:
Q
QingshuChen
提交者:
GitHub
11月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support different precision in kunlun (#36836)
* support different precision in kunlun * minor * minor * minor
上级
5c4c55f9
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
126 addition
and
16 deletion
+126
-16
cmake/external/xpu.cmake
cmake/external/xpu.cmake
+2
-1
paddle/fluid/operators/matmul_op_xpu.cc
paddle/fluid/operators/matmul_op_xpu.cc
+12
-9
paddle/fluid/operators/matmul_v2_op_xpu.cc
paddle/fluid/operators/matmul_v2_op_xpu.cc
+15
-5
paddle/fluid/operators/xpu_api_wrapper.h
paddle/fluid/operators/xpu_api_wrapper.h
+53
-0
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+5
-1
paddle/fluid/platform/xpu/xpu2_op_list.h
paddle/fluid/platform/xpu/xpu2_op_list.h
+39
-0
未找到文件。
cmake/external/xpu.cmake
浏览文件 @
e512aa9a
...
...
@@ -35,7 +35,8 @@ ELSE ()
ENDIF
()
SET
(
XPU_BASE_URL_WITHOUT_DATE
"https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev"
)
SET
(
XPU_BASE_URL
"
${
XPU_BASE_URL_WITHOUT_DATE
}
/20211020"
)
SET
(
XPU_BASE_URL
"
${
XPU_BASE_URL_WITHOUT_DATE
}
/20211029"
)
#SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211020")
SET
(
XPU_XRE_URL
"
${
XPU_BASE_URL
}
/
${
XPU_XRE_DIR_NAME
}
.tar.gz"
CACHE STRING
""
FORCE
)
SET
(
XPU_XDNN_URL
"
${
XPU_BASE_URL
}
/
${
XPU_XDNN_DIR_NAME
}
.tar.gz"
CACHE STRING
""
FORCE
)
SET
(
XPU_XCCL_URL
"
${
XPU_BASE_URL_WITHOUT_DATE
}
/20210623/
${
XPU_XCCL_DIR_NAME
}
.tar.gz"
CACHE STRING
""
FORCE
)
...
...
paddle/fluid/operators/matmul_op_xpu.cc
浏览文件 @
e512aa9a
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/xpu_api_wrapper.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -151,28 +152,26 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
x_dims
.
to_str
().
c_str
(),
y_dims
.
to_str
().
c_str
()));
float
alpha
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"alpha"
));
T
*
data_c
=
out
->
data
<
T
>
();
int
m
=
mat_dim_a
.
height_
;
int
n
=
mat_dim_b
.
width_
;
int
k
=
mat_dim_a
.
width_
;
int
batch_size
=
mat_dim_a
.
batch_size_
;
int
ldx
=
mat_dim_a
.
trans_
?
m
:
k
;
int
ldy
=
mat_dim_b
.
trans_
?
k
:
n
;
int
ldout
=
n
;
if
(
batch_size
<=
1
)
{
int
r
=
0
;
r
=
xpu
::
fc_fusion
<
XPUType
,
XPUType
,
XPUType
,
FCT
>
(
r
=
xpu
_fc_wrapper
<
XPUType
,
FCT
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
->
data
<
T
>
()),
reinterpret_cast
<
const
XPUType
*>
(
y
->
data
<
T
>
()),
reinterpret_cast
<
XPUType
*>
(
data_c
),
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
,
ldx
,
ldy
,
ldout
,
alpha
,
0
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU fc_fusion
kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU fc
kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
else
{
// batch matmul
int
r
=
xpu
::
fc_batched
<
XPUType
,
XPUType
,
XPUType
,
FCT
>
(
...
...
@@ -216,8 +215,10 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
if
(
std
::
is_same
<
paddle
::
platform
::
float16
,
T
>::
value
)
{
MatMulXPUFunction
<
T
,
int16_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
context
);
}
else
{
if
(
std
::
getenv
(
"XPU_PADDLE_
MAT_MUL_FC
INT32"
)
!=
nullptr
)
{
if
(
std
::
getenv
(
"XPU_PADDLE_
FC_
INT32"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
int32_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
context
);
}
else
if
(
std
::
getenv
(
"XPU_PADDLE_FC_LOCAL_INT16"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
float
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
context
);
}
else
{
MatMulXPUFunction
<
T
,
int16_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
context
);
}
...
...
@@ -292,8 +293,10 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
if
(
std
::
is_same
<
paddle
::
platform
::
float16
,
T
>::
value
)
{
MatMulXPUFunction
<
T
,
int16_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
context
);
}
else
{
if
(
std
::
getenv
(
"XPU_PADDLE_
MAT_MUL_GRAD_FC
INT32"
)
!=
nullptr
)
{
if
(
std
::
getenv
(
"XPU_PADDLE_
FC_
INT32"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
int32_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
context
);
}
else
if
(
std
::
getenv
(
"XPU_PADDLE_FC_LOCAL_INT16"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
float
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
context
);
}
else
{
MatMulXPUFunction
<
T
,
int16_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
context
);
}
...
...
paddle/fluid/operators/matmul_v2_op_xpu.cc
浏览文件 @
e512aa9a
...
...
@@ -18,6 +18,8 @@
#include <string>
#include <vector>
#include "paddle/fluid/operators/xpu_api_wrapper.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -74,17 +76,21 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
int
n
=
mat_dim_b
.
width_
;
int
k
=
mat_dim_a
.
width_
;
int
batch_size
=
mat_dim_a
.
batch_size_
;
int
ldx
=
mat_dim_a
.
trans_
?
m
:
k
;
int
ldy
=
mat_dim_b
.
trans_
?
k
:
n
;
int
ldout
=
n
;
if
(
batch_size
<=
1
)
{
int
r
=
0
;
r
=
xpu
::
fc
<
XPUType
,
XPUType
,
XPUType
,
FCT
>
(
r
=
xpu
_fc_wrapper
<
XPUType
,
FCT
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
->
data
<
T
>
()),
reinterpret_cast
<
const
XPUType
*>
(
y
->
data
<
T
>
()),
reinterpret_cast
<
XPUType
*>
(
data_c
),
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
);
mat_dim_b
.
trans_
,
nullptr
,
nullptr
,
nullptr
,
ldx
,
ldy
,
ldout
,
1.0
,
0
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU fc
_fusion
kernel return wrong value[%d %s] , m = %d, n = "
"XPU fc kernel return wrong value[%d %s] , m = %d, n = "
"%d, "
"k = %d, a_tr = %d, b_tr = %d"
,
r
,
XPUAPIErrorMsg
[
r
],
m
,
n
,
k
,
mat_dim_a
.
trans_
,
mat_dim_b
.
trans_
));
...
...
@@ -129,8 +135,10 @@ class MatMulV2XPUKernel : public framework::OpKernel<T> {
if
(
std
::
is_same
<
paddle
::
platform
::
float16
,
T
>::
value
)
{
MatMulXPUFunction
<
T
,
int16_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
ctx
);
}
else
{
if
(
std
::
getenv
(
"XPU_PADDLE_
MAT_MUL_V2_FC
INT32"
)
!=
nullptr
)
{
if
(
std
::
getenv
(
"XPU_PADDLE_
FC_
INT32"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
int32_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
ctx
);
}
else
if
(
std
::
getenv
(
"XPU_PADDLE_FC_LOCAL_INT16"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
float
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
ctx
);
}
else
{
MatMulXPUFunction
<
T
,
int16_t
>
(
x
,
y
,
out
,
trans_x
,
trans_y
,
ctx
);
}
...
...
@@ -178,8 +186,10 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
if
(
std
::
is_same
<
paddle
::
platform
::
float16
,
T
>::
value
)
{
MatMulXPUFunction
<
T
,
int16_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
ctx
);
}
else
{
if
(
std
::
getenv
(
"XPU_PADDLE_
MAT_MUL_GRAD_V2_FC
INT32"
)
!=
nullptr
)
{
if
(
std
::
getenv
(
"XPU_PADDLE_
FC_
INT32"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
int32_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
ctx
);
}
else
if
(
std
::
getenv
(
"XPU_PADDLE_FC_LOCAL_INT16"
)
!=
nullptr
)
{
MatMulXPUFunction
<
T
,
float
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
ctx
);
}
else
{
MatMulXPUFunction
<
T
,
int16_t
>
(
&
a
,
&
b
,
out
,
trans_a
,
trans_b
,
ctx
);
}
...
...
paddle/fluid/operators/xpu_api_wrapper.h
0 → 100644
浏览文件 @
e512aa9a
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU
#include <vector>
namespace
paddle
{
namespace
operators
{
template
<
typename
XPUType
,
typename
FCT
>
int
xpu_fc_wrapper
(
xpu
::
Context
*
ctx
,
const
XPUType
*
x
,
const
XPUType
*
w
,
XPUType
*
y
,
int
m
,
int
n
,
int
k
,
bool
x_trans
,
bool
w_trans
,
const
float
*
x_maxptr
,
const
float
*
w_maxptr
,
float
*
y_maxptr
,
int
ldx
,
int
ldw
,
int
ldy
,
float
alpha
,
float
beta
,
const
float
*
bias
,
const
xpu
::
Activation_t
&
act
)
{
int
r
=
0
;
if
(
x_trans
&&
std
::
getenv
(
"XPU_PADDLE_FC_TRANS_A"
)
!=
nullptr
&&
std
::
is_same
<
float
,
XPUType
>::
value
)
{
XPUType
*
l3_addr
=
nullptr
;
xpu
::
ctx_guard
RAII_GUARD
(
ctx
);
l3_addr
=
RAII_GUARD
.
alloc_l3_or_gm
<
XPUType
>
(
m
*
k
);
if
(
l3_addr
==
nullptr
)
return
XPUERR_NOMEM
;
std
::
vector
<
int
>
shape
=
{
k
,
m
};
std
::
vector
<
int
>
axis
=
{
1
,
0
};
r
=
xpu
::
transpose
<
XPUType
>
(
ctx
,
x
,
l3_addr
,
shape
,
axis
);
if
(
r
!=
XPU_SUCCESS
)
return
r
;
r
=
xpu
::
fc_fusion
<
XPUType
,
XPUType
,
XPUType
,
FCT
>
(
ctx
,
l3_addr
,
w
,
y
,
m
,
n
,
k
,
false
,
w_trans
,
x_maxptr
,
w_maxptr
,
y_maxptr
,
k
,
ldw
,
ldy
,
alpha
,
beta
,
bias
,
act
);
if
(
r
!=
XPU_SUCCESS
)
return
r
;
}
else
{
r
=
xpu
::
fc_fusion
<
XPUType
,
XPUType
,
XPUType
,
FCT
>
(
ctx
,
x
,
w
,
y
,
m
,
n
,
k
,
x_trans
,
w_trans
,
x_maxptr
,
w_maxptr
,
y_maxptr
,
ldx
,
ldw
,
ldy
,
alpha
,
beta
,
bias
,
act
);
}
return
r
;
}
}
// namespace operators
}
// namespace paddle
#endif
paddle/fluid/platform/device_context.cc
浏览文件 @
e512aa9a
...
...
@@ -222,9 +222,13 @@ XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
context_
=
xpu
::
create_context
();
const
int
MAX_XPU_NUM
=
16
;
const
int
l3_size
=
13.5
*
1024
*
1024
;
static
void
*
l3ptrs
[
MAX_XPU_NUM
]
=
{
nullptr
};
int
l3_size
=
13.5
*
1024
*
1024
;
if
(
std
::
getenv
(
"XPU_PADDLE_L3_SIZE"
)
!=
nullptr
)
{
l3_size
=
atoi
(
std
::
getenv
(
"XPU_PADDLE_L3_SIZE"
));
}
auto
selected_xpus
=
GetXPUSelectedDevices
();
for
(
unsigned
int
i
=
0
;
i
<
selected_xpus
.
size
();
i
++
)
{
if
(
place
.
device
==
selected_xpus
[
i
])
{
...
...
paddle/fluid/platform/xpu/xpu2_op_list.h
浏览文件 @
e512aa9a
...
...
@@ -90,6 +90,12 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType
(
vartype
::
FP16
,
XPUPlace
())})},
{
"adam"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"adamw"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"reduce_sum"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"reduce_sum_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"softmax"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"softmax_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"softmax_with_cross_entropy"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"softmax_with_cross_entropy_grad"
,
...
...
@@ -171,6 +177,39 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType
(
vartype
::
INT32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
INT8
,
XPUPlace
()),
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"matmul_v2"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"matmul_v2_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"matmul"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"matmul_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"relu"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"relu_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"assign_value"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"dropout"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"dropout_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"elementwise_div"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"elementwise_div_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"range"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
INT64
,
XPUPlace
())})},
{
"reshape2"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"reshape2_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"shape"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
INT64
,
XPUPlace
())})},
{
"one_hot_v2"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
INT64
,
XPUPlace
())})},
{
"layer_norm"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"layer_norm_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"lookup_table_v2"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"lookup_table_v2_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"scale"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"flatten_contiguous_range"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
INT64
,
XPUPlace
()),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录