Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c6757bd3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c6757bd3
编写于
8月 21, 2023
作者:
J
jiangfan06
提交者:
GitHub
8月 21, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] Add xpu plugin for reduce ops (#56389)
上级
f8cba26d
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
660 addition
and
84 deletion
+660
-84
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
+29
-0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_reduce.xpu
.../kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_reduce.xpu
+262
-0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
...nels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
+0
-10
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_reduce.cpp
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_reduce.cpp
+291
-0
paddle/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
...le/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
+1
-25
paddle/phi/kernels/xpu/reduce.h
paddle/phi/kernels/xpu/reduce.h
+40
-49
paddle/phi/kernels/xpu/reduce_max_kernel.cc
paddle/phi/kernels/xpu/reduce_max_kernel.cc
+9
-0
paddle/phi/kernels/xpu/reduce_mean_kernel.cc
paddle/phi/kernels/xpu/reduce_mean_kernel.cc
+9
-0
paddle/phi/kernels/xpu/reduce_min_kernel.cc
paddle/phi/kernels/xpu/reduce_min_kernel.cc
+9
-0
paddle/phi/kernels/xpu/reduce_util.h
paddle/phi/kernels/xpu/reduce_util.h
+10
-0
未找到文件。
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
浏览文件 @
c6757bd3
...
@@ -75,6 +75,35 @@ DLL_EXPORT int fast_layer_norm(Context* ctx,
...
@@ -75,6 +75,35 @@ DLL_EXPORT int fast_layer_norm(Context* ctx,
float
eps
,
float
eps
,
const
float
*
scale
,
const
float
*
scale
,
const
float
*
bias
);
const
float
*
bias
);
template
<
typename
T
>
DLL_EXPORT
int
fast_reduce_sum
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
);
template
<
typename
T
>
DLL_EXPORT
int
fast_reduce_mean
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
);
template
<
typename
T
>
DLL_EXPORT
int
fast_reduce_max
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
);
template
<
typename
T
>
DLL_EXPORT
int
fast_reduce_min
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
);
}
// namespace plugin
}
// namespace plugin
}
// namespace api
}
// namespace api
}
// namespace xpu
}
// namespace xpu
...
...
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_reduce.xpu
0 → 100644
浏览文件 @
c6757bd3
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
__device__ float do_sum_align16(float* lmptr, int size) {
__simd__ float sum_buf[16];
float32x16_t vsum = vset_zero();
for (int i = 0; i < size; i += 16) {
float32x16_t v0 = vload_lm_float32x16(lmptr + i);
vsum = vvadd_float32x16(vsum, v0);
}
vstore_lm_float32x16(sum_buf, vsum);
mfence_lm();
float sum = 0.0f;
for (int i = 0; i < 16; i++) {
sum = sum + sum_buf[i];
}
return sum;
}
__device__ float do_sum(float* lmptr, int size) {
float sum = 0.0f;
for (int i = 0; i < size; i++) {
sum += lmptr[i];
}
return sum;
}
__device__ float do_max_align16(float* lmptr, int size) {
__simd__ float max_buf[16];
float32x16_t vmax = vload_lm_float32x16(lmptr);
for (int i = 16; i < size; i += 16) {
float32x16_t v0 = vload_lm_float32x16(lmptr + i);
vmax = vvmax_float32x16(vmax, v0);
}
vstore_lm_float32x16(max_buf, vmax);
mfence_lm();
float max_val = max_buf[0];
for (int i = 1; i < 16; i++) {
max_val = fmax(max_val, max_buf[i]);
}
return max_val;
}
__device__ float do_max(float* lmptr, int size) {
float max_val = lmptr[0];
for (int i = 1; i < size; i++) {
max_val = fmax(max_val, lmptr[i]);
}
return max_val;
}
__device__ float do_min_align16(float* lmptr, int size) {
__simd__ float min_buf[16];
float32x16_t vmin = vload_lm_float32x16(lmptr);
for (int i = 16; i < size; i += 16) {
float32x16_t v0 = vload_lm_float32x16(lmptr + i);
vmin = vvmin_float32x16(vmin, v0);
}
vstore_lm_float32x16(min_buf, vmin);
mfence_lm();
float min_val = min_buf[0];
for (int i = 1; i < 16; i++) {
min_val = fmin(min_val, min_buf[i]);
}
return min_val;
}
__device__ float do_min(float* lmptr, int size) {
float min_val = lmptr[0];
for (int i = 1; i < size; i++) {
min_val = fmin(min_val, lmptr[i]);
}
return min_val;
}
template <typename T>
__global__ void fast_reduce_sum_tiny(const T* x, T* y, int m, int t) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int64_t max_tt = 832;
const int64_t buffer_len = max_tt * 4 / sizeof(float);
int mstart = 0;
int mend = 0;
__simd__ float xlm[buffer_len];
__simd__ float ylm[buffer_len];
int block_cnt = buffer_len / t;
partition(tid, nthreads, m, 1, &mstart, &mend);
for (int i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * t, block_cnt * t);
GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
if (t % 16 == 0 && t >= 32) {
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_sum_align16(xlm + j, t);
}
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
} else {
primitive_cast<T, float>((T*)xlm, xlm, readlen);
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_sum(xlm + j, t);
}
primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
}
}
return;
}
template <typename T>
__global__ void fast_reduce_mean_tiny(const T* x, T* y, int m, int t) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int64_t max_tt = 832;
const int64_t buffer_len = max_tt * 4 / sizeof(float);
int mstart = 0;
int mend = 0;
__simd__ float xlm[buffer_len];
__simd__ float ylm[buffer_len];
int block_cnt = buffer_len / t;
partition(tid, nthreads, m, 1, &mstart, &mend);
for (int i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * t, block_cnt * t);
GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
if (t % 16 == 0 && t >= 32) {
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_sum_align16(xlm + j, t) / t;
}
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
} else {
primitive_cast<T, float>((T*)xlm, xlm, readlen);
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_sum(xlm + j, t) / t;
}
primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
}
}
return;
}
template <typename T>
__global__ void fast_reduce_max_tiny(const T* x, T* y, int m, int t) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int64_t max_tt = 832;
const int64_t buffer_len = max_tt * 4 / sizeof(float);
int mstart = 0;
int mend = 0;
__simd__ float xlm[buffer_len];
__simd__ float ylm[buffer_len];
int block_cnt = buffer_len / t;
partition(tid, nthreads, m, 1, &mstart, &mend);
for (int i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * t, block_cnt * t);
GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
if (t % 16 == 0 && t >= 32) {
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_max_align16(xlm + j, t);
}
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
} else {
primitive_cast<T, float>((T*)xlm, xlm, readlen);
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_max(xlm + j, t);
}
primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
}
}
return;
}
template <typename T>
__global__ void fast_reduce_min_tiny(const T* x, T* y, int m, int t) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int64_t max_tt = 832;
const int64_t buffer_len = max_tt * 4 / sizeof(float);
int mstart = 0;
int mend = 0;
__simd__ float xlm[buffer_len];
__simd__ float ylm[buffer_len];
int block_cnt = buffer_len / t;
partition(tid, nthreads, m, 1, &mstart, &mend);
for (int i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * t, block_cnt * t);
GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
if (t % 16 == 0 && t >= 32) {
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_min_align16(xlm + j, t);
}
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
} else {
primitive_cast<T, float>((T*)xlm, xlm, readlen);
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_min(xlm + j, t);
}
primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
}
}
return;
}
#define _XPU_DEF__FAST_REDUCE_SUM_TINY_(DTYPE) \
template __global__ void fast_reduce_sum_tiny<DTYPE>( \
const DTYPE* x, DTYPE* y, int m, int t);
_XPU_DEF__FAST_REDUCE_SUM_TINY_(float);
_XPU_DEF__FAST_REDUCE_SUM_TINY_(float16);
#define _XPU_DEF__FAST_REDUCE_MEAN_TINY_(DTYPE) \
template __global__ void fast_reduce_mean_tiny<DTYPE>( \
const DTYPE* x, DTYPE* y, int m, int t);
_XPU_DEF__FAST_REDUCE_MEAN_TINY_(float);
_XPU_DEF__FAST_REDUCE_MEAN_TINY_(float16);
#define _XPU_DEF__FAST_REDUCE_MAX_TINY_(DTYPE) \
template __global__ void fast_reduce_max_tiny<DTYPE>( \
const DTYPE* x, DTYPE* y, int m, int t);
_XPU_DEF__FAST_REDUCE_MAX_TINY_(float);
_XPU_DEF__FAST_REDUCE_MAX_TINY_(float16);
#define _XPU_DEF__FAST_REDUCE_MIN_TINY_(DTYPE) \
template __global__ void fast_reduce_min_tiny<DTYPE>( \
const DTYPE* x, DTYPE* y, int m, int t);
_XPU_DEF__FAST_REDUCE_MIN_TINY_(float);
_XPU_DEF__FAST_REDUCE_MIN_TINY_(float16);
} // namespace plugin
} // namespace xpu2
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
浏览文件 @
c6757bd3
...
@@ -26,8 +26,6 @@ template <typename T, typename TID>
...
@@ -26,8 +26,6 @@ template <typename T, typename TID>
__global__ void take_along_axis(const T* x,
__global__ void take_along_axis(const T* x,
const TID* indices,
const TID* indices,
T* y,
T* y,
const int64_t* shape,
int64_t shape_size,
int64_t batch,
int64_t batch,
int64_t xlen,
int64_t xlen,
int64_t ylen) {
int64_t ylen) {
...
@@ -40,12 +38,6 @@ __global__ void take_along_axis(const T* x,
...
@@ -40,12 +38,6 @@ __global__ void take_along_axis(const T* x,
__simd__ char lm_y[sizeof(T)];
__simd__ char lm_y[sizeof(T)];
__simd__ char lm_idx[sizeof(TID)];
__simd__ char lm_idx[sizeof(TID)];
__shared__ int64_t sm_shape[512];
if (cid == 0) {
GM2SM(shape, sm_shape, shape_size * sizeof(int64_t));
}
sync_all();
for (int64_t i = tid; i < batch * ylen; i += nthreads) {
for (int64_t i = tid; i < batch * ylen; i += nthreads) {
GM2LM(indices + i, lm_idx, sizeof(TID));
GM2LM(indices + i, lm_idx, sizeof(TID));
TID idx = ((TID*)lm_idx)[0];
TID idx = ((TID*)lm_idx)[0];
...
@@ -65,8 +57,6 @@ __global__ void take_along_axis(const T* x,
...
@@ -65,8 +57,6 @@ __global__ void take_along_axis(const T* x,
const DTYPE* x, \
const DTYPE* x, \
const IDTYPE* indices, \
const IDTYPE* indices, \
DTYPE* y, \
DTYPE* y, \
const int64_t* shape, \
int64_t shape_size, \
int64_t batch, \
int64_t batch, \
int64_t xlen, \
int64_t xlen, \
int64_t ylen);
int64_t ylen);
...
...
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_reduce.cpp
0 → 100644
浏览文件 @
c6757bd3
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace
xpu2
{
namespace
plugin
{
template
<
typename
T
>
__attribute__
((
global
))
void
fast_reduce_sum_tiny
(
const
T
*
x
,
T
*
y
,
int
m
,
int
t
);
template
<
typename
T
>
__attribute__
((
global
))
void
fast_reduce_mean_tiny
(
const
T
*
x
,
T
*
y
,
int
m
,
int
t
);
template
<
typename
T
>
__attribute__
((
global
))
void
fast_reduce_max_tiny
(
const
T
*
x
,
T
*
y
,
int
m
,
int
t
);
template
<
typename
T
>
__attribute__
((
global
))
void
fast_reduce_min_tiny
(
const
T
*
x
,
T
*
y
,
int
m
,
int
t
);
}
// namespace plugin
}
// namespace xpu2
namespace
baidu
{
namespace
xpu
{
namespace
api
{
namespace
plugin
{
template
<
typename
T
>
static
int
xpu2_wrapper
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
int
op_type
)
{
std
::
vector
<
int
>
rdims
=
{
static_cast
<
int
>
(
xshape
.
size
()
-
1
)};
switch
(
op_type
)
{
case
0
:
return
reduce_sum
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
);
case
2
:
return
reduce_max
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
);
case
3
:
return
reduce_min
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
);
default:
return
NOT_IMPLEMENT
;
}
return
SUCCESS
;
}
template
<
>
int
xpu2_wrapper
<
int8_t
>
(
Context
*
ctx
,
const
int8_t
*
x
,
int8_t
*
y
,
const
std
::
vector
<
int
>&
xshape
,
int
op_type
)
{
std
::
vector
<
int
>
rdims
=
{
static_cast
<
int
>
(
xshape
.
size
()
-
1
)};
if
(
op_type
==
0
)
{
return
reduce_sum
<
int8_t
>
(
ctx
,
x
,
y
,
xshape
,
rdims
);
}
else
{
return
NOT_IMPLEMENT
;
}
return
SUCCESS
;
}
template
<
>
int
xpu2_wrapper
<
float
>
(
Context
*
ctx
,
const
float
*
x
,
float
*
y
,
const
std
::
vector
<
int
>&
xshape
,
int
op_type
)
{
int
t
=
xshape
[
xshape
.
size
()
-
1
];
int
xlen
=
vector_prod
(
xshape
);
int
m
=
xlen
/
t
;
switch
(
op_type
)
{
case
0
:
xpu2
::
plugin
::
fast_reduce_sum_tiny
<
float
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
y
,
m
,
t
);
break
;
case
1
:
xpu2
::
plugin
::
fast_reduce_mean_tiny
<
float
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
y
,
m
,
t
);
break
;
case
2
:
xpu2
::
plugin
::
fast_reduce_max_tiny
<
float
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
y
,
m
,
t
);
break
;
case
3
:
xpu2
::
plugin
::
fast_reduce_min_tiny
<
float
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
y
,
m
,
t
);
break
;
default:
return
NOT_IMPLEMENT
;
}
return
SUCCESS
;
}
template
<
>
int
xpu2_wrapper
<
float16
>
(
Context
*
ctx
,
const
float16
*
x
,
float16
*
y
,
const
std
::
vector
<
int
>&
xshape
,
int
op_type
)
{
int
t
=
xshape
[
xshape
.
size
()
-
1
];
int
xlen
=
vector_prod
(
xshape
);
int
m
=
xlen
/
t
;
switch
(
op_type
)
{
case
0
:
xpu2
::
plugin
::
fast_reduce_sum_tiny
<
float16
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
y
,
m
,
t
);
break
;
case
1
:
xpu2
::
plugin
::
fast_reduce_mean_tiny
<
float16
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
y
,
m
,
t
);
break
;
case
2
:
xpu2
::
plugin
::
fast_reduce_max_tiny
<
float16
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
y
,
m
,
t
);
break
;
case
3
:
xpu2
::
plugin
::
fast_reduce_min_tiny
<
float16
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
y
,
m
,
t
);
break
;
default:
return
NOT_IMPLEMENT
;
}
return
SUCCESS
;
}
template
<
typename
T
>
int
fast_reduce_tiny
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
,
int
op_type
)
{
WRAPPER_CHECK_CTX
(
ctx
);
WRAPPER_DUMP_FUNCTION_T1
(
ctx
,
"fast_reduce_tiny"
,
T
);
WRAPPER_DUMP_PARAM5
(
ctx
,
x
,
y
,
xshape
,
rdims
,
op_type
);
WRAPPER_DUMP
(
ctx
);
std
::
vector
<
int
>
yshape
=
xshape
;
yshape
[
xshape
.
size
()
-
1
]
=
1
;
int64_t
lenx
=
-
1
;
int64_t
leny
=
-
1
;
WRAPPER_CHECK_SHAPE
(
ctx
,
&
lenx
,
xshape
);
WRAPPER_CHECK_SHAPE
(
ctx
,
&
leny
,
yshape
);
WRAPPER_CHECK_PTR
(
ctx
,
T
,
lenx
,
x
);
WRAPPER_CHECK_PTR
(
ctx
,
T
,
leny
,
y
);
if
(
ctx
->
dev
().
type
()
==
api
::
kXPU2
)
{
return
xpu2_wrapper
<
T
>
(
ctx
,
x
,
y
,
xshape
,
op_type
);
}
WRAPPER_UNIMPLEMENTED
(
ctx
);
}
template
<
typename
T
>
DLL_EXPORT
int
fast_reduce_sum
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
)
{
if
(
rdims
.
size
()
==
1
&&
rdims
[
0
]
==
xshape
.
size
()
-
1
&&
xshape
[
xshape
.
size
()
-
1
]
<=
832
)
{
return
fast_reduce_tiny
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
,
0
);
}
else
{
return
reduce_sum
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
);
}
}
template
<
typename
T
>
DLL_EXPORT
int
fast_reduce_mean
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
)
{
if
(
rdims
.
size
()
==
1
&&
rdims
[
0
]
==
xshape
.
size
()
-
1
&&
xshape
[
xshape
.
size
()
-
1
]
<=
832
)
{
return
fast_reduce_tiny
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
,
1
);
}
else
{
return
reduce_mean
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
);
}
}
template
<
typename
T
>
DLL_EXPORT
int
fast_reduce_max
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
)
{
if
(
rdims
.
size
()
==
1
&&
rdims
[
0
]
==
xshape
.
size
()
-
1
&&
xshape
[
xshape
.
size
()
-
1
]
<=
832
)
{
return
fast_reduce_tiny
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
,
2
);
}
else
{
return
reduce_max
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
);
}
}
template
<
typename
T
>
DLL_EXPORT
int
fast_reduce_min
(
Context
*
ctx
,
const
T
*
x
,
T
*
y
,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
)
{
if
(
rdims
.
size
()
==
1
&&
rdims
[
0
]
==
xshape
.
size
()
-
1
&&
xshape
[
xshape
.
size
()
-
1
]
<=
832
)
{
return
fast_reduce_tiny
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
,
3
);
}
else
{
return
reduce_min
<
T
>
(
ctx
,
x
,
y
,
xshape
,
rdims
);
}
}
template
int
fast_reduce_sum
(
Context
*
,
const
float
*
,
float
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_sum
(
Context
*
,
const
float16
*
,
float16
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_sum
(
Context
*
,
const
int
*
,
int
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_sum
(
Context
*
,
const
int64_t
*
,
int64_t
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_sum
(
Context
*
,
const
int8_t
*
,
int8_t
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_mean
(
Context
*
,
const
float
*
,
float
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_mean
(
Context
*
,
const
float16
*
,
float16
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_min
(
Context
*
,
const
float
*
,
float
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_max
(
Context
*
,
const
float
*
,
float
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_max
(
Context
*
,
const
int
*
,
int
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
template
int
fast_reduce_max
(
Context
*
,
const
int64_t
*
,
int64_t
*
,
const
std
::
vector
<
int
>
&
,
const
std
::
vector
<
int
>&
);
}
// namespace plugin
}
// namespace api
}
// namespace xpu
}
// namespace baidu
paddle/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
浏览文件 @
c6757bd3
...
@@ -25,8 +25,6 @@ template <typename T, typename TID>
...
@@ -25,8 +25,6 @@ template <typename T, typename TID>
__attribute__
((
global
))
void
take_along_axis
(
const
T
*
x
,
__attribute__
((
global
))
void
take_along_axis
(
const
T
*
x
,
const
TID
*
indices
,
const
TID
*
indices
,
T
*
y
,
T
*
y
,
const
int64_t
*
shape
,
int64_t
shape_size
,
int64_t
batch
,
int64_t
batch
,
int64_t
xlen
,
int64_t
xlen
,
int64_t
ylen
);
int64_t
ylen
);
...
@@ -74,43 +72,21 @@ static int xpu2_wrapper(Context* ctx,
...
@@ -74,43 +72,21 @@ static int xpu2_wrapper(Context* ctx,
const
std
::
vector
<
int64_t
>&
idxshape
,
const
std
::
vector
<
int64_t
>&
idxshape
,
int64_t
axis
)
{
int64_t
axis
)
{
int64_t
m_idx
=
1
;
int64_t
m_idx
=
1
;
int64_t
shape_new_size
=
idxshape
.
size
()
-
1
;
std
::
vector
<
int64_t
>
shape_new
=
xshape
;
for
(
int64_t
i
=
0
;
i
<
axis
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
axis
;
i
++
)
{
m_idx
*=
idxshape
[
i
];
m_idx
*=
idxshape
[
i
];
}
}
for
(
int64_t
i
=
axis
+
1
;
i
<
xshape
.
size
();
i
++
)
{
shape_new
[
i
-
1
]
=
xshape
[
i
];
}
int64_t
t_x
=
xshape
[
axis
];
int64_t
t_x
=
xshape
[
axis
];
int64_t
t_idx
=
idxshape
[
axis
];
int64_t
t_idx
=
idxshape
[
axis
];
int64_t
n_idx
=
vector_prod
(
idxshape
)
/
m_idx
/
t_idx
;
int64_t
n_idx
=
vector_prod
(
idxshape
)
/
m_idx
/
t_idx
;
if
(
m_idx
<
64
&&
n_idx
==
1
)
{
if
(
m_idx
<
64
&&
n_idx
==
1
)
{
api
::
ctx_guard
RAII_GUARD
(
ctx
);
int64_t
*
shape_xpu
=
RAII_GUARD
.
alloc_l3_or_gm
<
int64_t
>
(
shape_new_size
);
WRAPPER_ASSERT_WORKSPACE
(
ctx
,
shape_xpu
);
int
ret
=
do_host2device
(
ctx
,
shape_new
.
data
(),
shape_xpu
,
(
shape_new_size
)
*
sizeof
(
int64_t
));
WRAPPER_ASSERT_SUCCESS
(
ctx
,
ret
);
using
XPU_TID
=
typename
XPUIndexType
<
TID
>::
type
;
using
XPU_TID
=
typename
XPUIndexType
<
TID
>::
type
;
const
XPU_TID
*
casted_index
=
const
XPU_TID
*
casted_index
=
static_cast
<
const
XPU_TID
*>
(
static_cast
<
const
void
*>
(
index
));
static_cast
<
const
XPU_TID
*>
(
static_cast
<
const
void
*>
(
index
));
xpu2
::
plugin
::
take_along_axis
<
T
,
XPU_TID
>
xpu2
::
plugin
::
take_along_axis
<
T
,
XPU_TID
>
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
<<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
x
,
x
,
casted_index
,
y
,
m_idx
,
t_x
,
t_idx
);
casted_index
,
y
,
reinterpret_cast
<
xpu2
::
int64_t
*>
(
shape_xpu
),
shape_new_size
,
m_idx
,
t_x
,
t_idx
);
}
else
{
}
else
{
return
gather_element
(
ctx
,
x
,
index
,
y
,
xshape
,
idxshape
,
axis
);
return
gather_element
(
ctx
,
x
,
index
,
y
,
xshape
,
idxshape
,
axis
);
}
}
...
...
paddle/phi/kernels/xpu/reduce.h
浏览文件 @
c6757bd3
...
@@ -25,6 +25,36 @@
...
@@ -25,6 +25,36 @@
namespace
phi
{
namespace
phi
{
static
void
GetReduceDims
(
const
DDim
&
xdims
,
const
std
::
vector
<
int64_t
>&
dims
,
bool
reduce_all
,
std
::
vector
<
int
>*
reduce_dims
)
{
const
auto
&
input_dim_size
=
xdims
.
size
();
std
::
vector
<
int
>
true_dims
;
for
(
size_t
i
=
0
;
i
<
dims
.
size
();
++
i
)
{
if
(
dims
[
i
]
<
0
)
{
true_dims
.
push_back
(
dims
[
i
]
+
input_dim_size
);
}
else
{
true_dims
.
push_back
(
dims
[
i
]);
}
}
if
(
reduce_all
)
{
for
(
int
i
=
0
;
i
<
input_dim_size
;
++
i
)
{
reduce_dims
->
push_back
(
i
);
}
}
else
{
std
::
set
<
int
>
dims_set
(
true_dims
.
begin
(),
true_dims
.
end
());
for
(
auto
i
=
0
;
i
<
input_dim_size
;
i
++
)
{
if
(
dims_set
.
find
(
i
)
!=
dims_set
.
end
())
{
if
(
xdims
[
i
]
!=
1
)
{
reduce_dims
->
push_back
(
i
);
}
}
}
}
}
template
<
typename
Context
,
typename
T
>
template
<
typename
Context
,
typename
T
>
int
XPUReduce
(
const
Context
&
dev_ctx
,
int
XPUReduce
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
...
@@ -43,35 +73,15 @@ int XPUReduce(const Context& dev_ctx,
...
@@ -43,35 +73,15 @@ int XPUReduce(const Context& dev_ctx,
const
auto
*
x_data
=
x
.
data
<
T
>
();
const
auto
*
x_data
=
x
.
data
<
T
>
();
auto
*
y_data
=
out
->
data
<
T
>
();
auto
*
y_data
=
out
->
data
<
T
>
();
const
auto
&
input_dim_size
=
x
.
dims
().
size
();
std
::
vector
<
int
>
true_dims
;
for
(
size_t
i
=
0
;
i
<
dims
.
size
();
++
i
)
{
if
(
dims
[
i
]
<
0
)
{
true_dims
.
push_back
(
dims
[
i
]
+
input_dim_size
);
}
else
{
true_dims
.
push_back
(
dims
[
i
]);
}
}
std
::
vector
<
int
>
reduce_dims
;
const
auto
&
input_dim_size
=
x
.
dims
().
size
()
;
std
::
vector
<
int
>
xdims
(
(
input_dim_size
)
);
std
::
vector
<
int
>
xdims
(
input_dim_size
);
for
(
int
i
=
0
;
i
<
input_dim_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
input_dim_size
;
++
i
)
{
xdims
[
i
]
=
x
.
dims
()[
i
];
xdims
[
i
]
=
x
.
dims
()[
i
];
}
}
if
(
reduce_all
)
{
for
(
int
i
=
0
;
i
<
input_dim_size
;
++
i
)
{
std
::
vector
<
int
>
reduce_dims
;
reduce_dims
.
push_back
(
i
);
GetReduceDims
(
x
.
dims
(),
dims
,
reduce_all
,
&
reduce_dims
);
}
}
else
{
std
::
set
<
int
>
dims_set
(
true_dims
.
begin
(),
true_dims
.
end
());
for
(
auto
i
=
0
;
i
<
input_dim_size
;
i
++
)
{
if
(
dims_set
.
find
(
i
)
!=
dims_set
.
end
())
{
if
(
x
.
dims
()[
i
]
!=
1
)
{
reduce_dims
.
push_back
(
i
);
}
}
}
}
int
r
=
xpu
::
SUCCESS
;
int
r
=
xpu
::
SUCCESS
;
if
(
reduce_dims
.
size
()
==
0
)
{
if
(
reduce_dims
.
size
()
==
0
)
{
...
@@ -119,33 +129,14 @@ void XPUReduce(const DeviceContext& dev_ctx,
...
@@ -119,33 +129,14 @@ void XPUReduce(const DeviceContext& dev_ctx,
reduce_all
=
recompute_reduce_all
(
x
,
dims
,
reduce_all
);
reduce_all
=
recompute_reduce_all
(
x
,
dims
,
reduce_all
);
const
auto
&
input_dim_size
=
x
.
dims
().
size
();
const
auto
&
input_dim_size
=
x
.
dims
().
size
();
std
::
vector
<
int
>
true_dims
;
std
::
vector
<
int
>
xdims
(
input_dim_size
);
for
(
size_t
i
=
0
;
i
<
dims
.
size
();
++
i
)
{
if
(
dims
[
i
]
<
0
)
{
true_dims
.
push_back
(
dims
[
i
]
+
input_dim_size
);
}
else
{
true_dims
.
push_back
(
dims
[
i
]);
}
}
std
::
vector
<
int
>
reduce_dims
;
std
::
vector
<
int
>
xdims
((
input_dim_size
));
for
(
int
i
=
0
;
i
<
input_dim_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
input_dim_size
;
++
i
)
{
xdims
[
i
]
=
x
.
dims
()[
i
];
xdims
[
i
]
=
x
.
dims
()[
i
];
}
}
if
(
reduce_all
)
{
for
(
int
i
=
0
;
i
<
input_dim_size
;
++
i
)
{
std
::
vector
<
int
>
reduce_dims
;
reduce_dims
.
push_back
(
i
);
GetReduceDims
(
x
.
dims
(),
dims
,
reduce_all
,
&
reduce_dims
);
}
}
else
{
std
::
set
<
int
>
dims_set
(
true_dims
.
begin
(),
true_dims
.
end
());
for
(
auto
i
=
0
;
i
<
input_dim_size
;
i
++
)
{
if
(
dims_set
.
find
(
i
)
!=
dims_set
.
end
())
{
if
(
x
.
dims
()[
i
]
!=
1
)
{
reduce_dims
.
push_back
(
i
);
}
}
}
}
// no need to cast dtype
// no need to cast dtype
if
(
out_dtype
==
phi
::
DataType
::
UNDEFINED
||
out_dtype
==
x
.
dtype
())
{
if
(
out_dtype
==
phi
::
DataType
::
UNDEFINED
||
out_dtype
==
x
.
dtype
())
{
// do reduce sum
// do reduce sum
...
...
paddle/phi/kernels/xpu/reduce_max_kernel.cc
浏览文件 @
c6757bd3
...
@@ -34,11 +34,20 @@ void MaxKernel(const Context& dev_ctx,
...
@@ -34,11 +34,20 @@ void MaxKernel(const Context& dev_ctx,
T
*
y
,
T
*
y
,
const
std
::
vector
<
int
>&
xdims
,
const
std
::
vector
<
int
>&
xdims
,
const
std
::
vector
<
int
>&
reduce_dims
)
{
const
std
::
vector
<
int
>&
reduce_dims
)
{
#ifndef PADDLE_WITH_XPU_PLUGIN
return
xpu
::
reduce_max
<
XPUType
>
(
ctx
,
return
xpu
::
reduce_max
<
XPUType
>
(
ctx
,
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
XPUType
*>
(
y
),
reinterpret_cast
<
XPUType
*>
(
y
),
xdims
,
xdims
,
reduce_dims
);
reduce_dims
);
#else
return
xpu
::
plugin
::
fast_reduce_max
<
XPUType
>
(
ctx
,
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
XPUType
*>
(
y
),
xdims
,
reduce_dims
);
#endif
};
};
int
r
=
XPUReduce
<
Context
,
T
>
(
int
r
=
XPUReduce
<
Context
,
T
>
(
...
...
paddle/phi/kernels/xpu/reduce_mean_kernel.cc
浏览文件 @
c6757bd3
...
@@ -35,11 +35,20 @@ void MeanRawKernel(const Context& dev_ctx,
...
@@ -35,11 +35,20 @@ void MeanRawKernel(const Context& dev_ctx,
T
*
y
,
T
*
y
,
const
std
::
vector
<
int
>&
xdims
,
const
std
::
vector
<
int
>&
xdims
,
const
std
::
vector
<
int
>&
reduce_dims
)
{
const
std
::
vector
<
int
>&
reduce_dims
)
{
#ifndef PADDLE_WITH_XPU_PLUGIN
return
xpu
::
reduce_mean
<
XPUType
>
(
ctx
,
return
xpu
::
reduce_mean
<
XPUType
>
(
ctx
,
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
XPUType
*>
(
y
),
reinterpret_cast
<
XPUType
*>
(
y
),
xdims
,
xdims
,
reduce_dims
);
reduce_dims
);
#else
return
xpu
::
plugin
::
fast_reduce_mean
<
XPUType
>
(
ctx
,
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
XPUType
*>
(
y
),
xdims
,
reduce_dims
);
#endif
};
};
int
r
=
XPUReduce
<
Context
,
T
>
(
int
r
=
XPUReduce
<
Context
,
T
>
(
...
...
paddle/phi/kernels/xpu/reduce_min_kernel.cc
浏览文件 @
c6757bd3
...
@@ -36,11 +36,20 @@ void MinRawKernel(const Context& dev_ctx,
...
@@ -36,11 +36,20 @@ void MinRawKernel(const Context& dev_ctx,
T
*
y
,
T
*
y
,
const
std
::
vector
<
int
>&
xdims
,
const
std
::
vector
<
int
>&
xdims
,
const
std
::
vector
<
int
>&
reduce_dims
)
{
const
std
::
vector
<
int
>&
reduce_dims
)
{
#ifndef PADDLE_WITH_XPU_PLUGIN
return
xpu
::
reduce_min
<
XPUType
>
(
ctx
,
return
xpu
::
reduce_min
<
XPUType
>
(
ctx
,
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
XPUType
*>
(
y
),
reinterpret_cast
<
XPUType
*>
(
y
),
xdims
,
xdims
,
reduce_dims
);
reduce_dims
);
#else
return
xpu
::
plugin
::
fast_reduce_min
<
XPUType
>
(
ctx
,
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
XPUType
*>
(
y
),
xdims
,
reduce_dims
);
#endif
};
};
int
r
=
XPUReduce
<
Context
,
T
>
(
int
r
=
XPUReduce
<
Context
,
T
>
(
...
...
paddle/phi/kernels/xpu/reduce_util.h
浏览文件 @
c6757bd3
...
@@ -28,12 +28,22 @@ struct SumFunctor {
...
@@ -28,12 +28,22 @@ struct SumFunctor {
const
std
::
vector
<
int
>&
xdims
,
const
std
::
vector
<
int
>&
xdims
,
const
std
::
vector
<
int
>&
reduce_dims
)
{
const
std
::
vector
<
int
>&
reduce_dims
)
{
using
XPUType
=
typename
XPUTypeTrait
<
X
>::
Type
;
using
XPUType
=
typename
XPUTypeTrait
<
X
>::
Type
;
#ifndef PADDLE_WITH_XPU_PLUGIN
int
r
=
xpu
::
reduce_sum
<
XPUType
>
(
ctx
,
int
r
=
xpu
::
reduce_sum
<
XPUType
>
(
ctx
,
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
XPUType
*>
(
y
),
reinterpret_cast
<
XPUType
*>
(
y
),
xdims
,
xdims
,
reduce_dims
);
reduce_dims
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"reduce_sum"
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"reduce_sum"
);
#else
int
r
=
xpu
::
plugin
::
fast_reduce_sum
<
XPUType
>
(
ctx
,
reinterpret_cast
<
const
XPUType
*>
(
x
),
reinterpret_cast
<
XPUType
*>
(
y
),
xdims
,
reduce_dims
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"fast_reduce_sum"
);
#endif
}
}
};
};
}
// namespace phi
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录