未验证 提交 c6757bd3 编写于 作者: J jiangfan06 提交者: GitHub

[XPU] Add xpu plugin for reduce ops (#56389)

上级 f8cba26d
......@@ -75,6 +75,35 @@ DLL_EXPORT int fast_layer_norm(Context* ctx,
float eps,
const float* scale,
const float* bias);
template <typename T>
DLL_EXPORT int fast_reduce_sum(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims);
template <typename T>
DLL_EXPORT int fast_reduce_mean(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims);
template <typename T>
DLL_EXPORT int fast_reduce_max(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims);
template <typename T>
DLL_EXPORT int fast_reduce_min(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims);
} // namespace plugin
} // namespace api
} // namespace xpu
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
__device__ float do_sum_align16(float* lmptr, int size) {
__simd__ float sum_buf[16];
float32x16_t vsum = vset_zero();
for (int i = 0; i < size; i += 16) {
float32x16_t v0 = vload_lm_float32x16(lmptr + i);
vsum = vvadd_float32x16(vsum, v0);
}
vstore_lm_float32x16(sum_buf, vsum);
mfence_lm();
float sum = 0.0f;
for (int i = 0; i < 16; i++) {
sum = sum + sum_buf[i];
}
return sum;
}
__device__ float do_sum(float* lmptr, int size) {
float sum = 0.0f;
for (int i = 0; i < size; i++) {
sum += lmptr[i];
}
return sum;
}
__device__ float do_max_align16(float* lmptr, int size) {
__simd__ float max_buf[16];
float32x16_t vmax = vload_lm_float32x16(lmptr);
for (int i = 16; i < size; i += 16) {
float32x16_t v0 = vload_lm_float32x16(lmptr + i);
vmax = vvmax_float32x16(vmax, v0);
}
vstore_lm_float32x16(max_buf, vmax);
mfence_lm();
float max_val = max_buf[0];
for (int i = 1; i < 16; i++) {
max_val = fmax(max_val, max_buf[i]);
}
return max_val;
}
__device__ float do_max(float* lmptr, int size) {
float max_val = lmptr[0];
for (int i = 1; i < size; i++) {
max_val = fmax(max_val, lmptr[i]);
}
return max_val;
}
__device__ float do_min_align16(float* lmptr, int size) {
__simd__ float min_buf[16];
float32x16_t vmin = vload_lm_float32x16(lmptr);
for (int i = 16; i < size; i += 16) {
float32x16_t v0 = vload_lm_float32x16(lmptr + i);
vmin = vvmin_float32x16(vmin, v0);
}
vstore_lm_float32x16(min_buf, vmin);
mfence_lm();
float min_val = min_buf[0];
for (int i = 1; i < 16; i++) {
min_val = fmin(min_val, min_buf[i]);
}
return min_val;
}
__device__ float do_min(float* lmptr, int size) {
float min_val = lmptr[0];
for (int i = 1; i < size; i++) {
min_val = fmin(min_val, lmptr[i]);
}
return min_val;
}
template <typename T>
__global__ void fast_reduce_sum_tiny(const T* x, T* y, int m, int t) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int64_t max_tt = 832;
const int64_t buffer_len = max_tt * 4 / sizeof(float);
int mstart = 0;
int mend = 0;
__simd__ float xlm[buffer_len];
__simd__ float ylm[buffer_len];
int block_cnt = buffer_len / t;
partition(tid, nthreads, m, 1, &mstart, &mend);
for (int i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * t, block_cnt * t);
GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
if (t % 16 == 0 && t >= 32) {
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_sum_align16(xlm + j, t);
}
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
} else {
primitive_cast<T, float>((T*)xlm, xlm, readlen);
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_sum(xlm + j, t);
}
primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
}
}
return;
}
template <typename T>
__global__ void fast_reduce_mean_tiny(const T* x, T* y, int m, int t) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int64_t max_tt = 832;
const int64_t buffer_len = max_tt * 4 / sizeof(float);
int mstart = 0;
int mend = 0;
__simd__ float xlm[buffer_len];
__simd__ float ylm[buffer_len];
int block_cnt = buffer_len / t;
partition(tid, nthreads, m, 1, &mstart, &mend);
for (int i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * t, block_cnt * t);
GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
if (t % 16 == 0 && t >= 32) {
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_sum_align16(xlm + j, t) / t;
}
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
} else {
primitive_cast<T, float>((T*)xlm, xlm, readlen);
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_sum(xlm + j, t) / t;
}
primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
}
}
return;
}
template <typename T>
__global__ void fast_reduce_max_tiny(const T* x, T* y, int m, int t) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int64_t max_tt = 832;
const int64_t buffer_len = max_tt * 4 / sizeof(float);
int mstart = 0;
int mend = 0;
__simd__ float xlm[buffer_len];
__simd__ float ylm[buffer_len];
int block_cnt = buffer_len / t;
partition(tid, nthreads, m, 1, &mstart, &mend);
for (int i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * t, block_cnt * t);
GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
if (t % 16 == 0 && t >= 32) {
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_max_align16(xlm + j, t);
}
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
} else {
primitive_cast<T, float>((T*)xlm, xlm, readlen);
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_max(xlm + j, t);
}
primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
}
}
return;
}
template <typename T>
__global__ void fast_reduce_min_tiny(const T* x, T* y, int m, int t) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int64_t max_tt = 832;
const int64_t buffer_len = max_tt * 4 / sizeof(float);
int mstart = 0;
int mend = 0;
__simd__ float xlm[buffer_len];
__simd__ float ylm[buffer_len];
int block_cnt = buffer_len / t;
partition(tid, nthreads, m, 1, &mstart, &mend);
for (int i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * t, block_cnt * t);
GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
if (t % 16 == 0 && t >= 32) {
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_min_align16(xlm + j, t);
}
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
} else {
primitive_cast<T, float>((T*)xlm, xlm, readlen);
for (int j = 0; j < readlen; j += t) {
ylm[j / t] = do_min(xlm + j, t);
}
primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
}
}
return;
}
#define _XPU_DEF__FAST_REDUCE_SUM_TINY_(DTYPE) \
template __global__ void fast_reduce_sum_tiny<DTYPE>( \
const DTYPE* x, DTYPE* y, int m, int t);
_XPU_DEF__FAST_REDUCE_SUM_TINY_(float);
_XPU_DEF__FAST_REDUCE_SUM_TINY_(float16);
#define _XPU_DEF__FAST_REDUCE_MEAN_TINY_(DTYPE) \
template __global__ void fast_reduce_mean_tiny<DTYPE>( \
const DTYPE* x, DTYPE* y, int m, int t);
_XPU_DEF__FAST_REDUCE_MEAN_TINY_(float);
_XPU_DEF__FAST_REDUCE_MEAN_TINY_(float16);
#define _XPU_DEF__FAST_REDUCE_MAX_TINY_(DTYPE) \
template __global__ void fast_reduce_max_tiny<DTYPE>( \
const DTYPE* x, DTYPE* y, int m, int t);
_XPU_DEF__FAST_REDUCE_MAX_TINY_(float);
_XPU_DEF__FAST_REDUCE_MAX_TINY_(float16);
#define _XPU_DEF__FAST_REDUCE_MIN_TINY_(DTYPE) \
template __global__ void fast_reduce_min_tiny<DTYPE>( \
const DTYPE* x, DTYPE* y, int m, int t);
_XPU_DEF__FAST_REDUCE_MIN_TINY_(float);
_XPU_DEF__FAST_REDUCE_MIN_TINY_(float16);
} // namespace plugin
} // namespace xpu2
......@@ -26,8 +26,6 @@ template <typename T, typename TID>
__global__ void take_along_axis(const T* x,
const TID* indices,
T* y,
const int64_t* shape,
int64_t shape_size,
int64_t batch,
int64_t xlen,
int64_t ylen) {
......@@ -40,12 +38,6 @@ __global__ void take_along_axis(const T* x,
__simd__ char lm_y[sizeof(T)];
__simd__ char lm_idx[sizeof(TID)];
__shared__ int64_t sm_shape[512];
if (cid == 0) {
GM2SM(shape, sm_shape, shape_size * sizeof(int64_t));
}
sync_all();
for (int64_t i = tid; i < batch * ylen; i += nthreads) {
GM2LM(indices + i, lm_idx, sizeof(TID));
TID idx = ((TID*)lm_idx)[0];
......@@ -65,8 +57,6 @@ __global__ void take_along_axis(const T* x,
const DTYPE* x, \
const IDTYPE* indices, \
DTYPE* y, \
const int64_t* shape, \
int64_t shape_size, \
int64_t batch, \
int64_t xlen, \
int64_t ylen);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {
template <typename T>
__attribute__((global)) void fast_reduce_sum_tiny(const T* x,
T* y,
int m,
int t);
template <typename T>
__attribute__((global)) void fast_reduce_mean_tiny(const T* x,
T* y,
int m,
int t);
template <typename T>
__attribute__((global)) void fast_reduce_max_tiny(const T* x,
T* y,
int m,
int t);
template <typename T>
__attribute__((global)) void fast_reduce_min_tiny(const T* x,
T* y,
int m,
int t);
} // namespace plugin
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int xpu2_wrapper(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
int op_type) {
std::vector<int> rdims = {static_cast<int>(xshape.size() - 1)};
switch (op_type) {
case 0:
return reduce_sum<T>(ctx, x, y, xshape, rdims);
case 2:
return reduce_max<T>(ctx, x, y, xshape, rdims);
case 3:
return reduce_min<T>(ctx, x, y, xshape, rdims);
default:
return NOT_IMPLEMENT;
}
return SUCCESS;
}
template <>
int xpu2_wrapper<int8_t>(Context* ctx,
const int8_t* x,
int8_t* y,
const std::vector<int>& xshape,
int op_type) {
std::vector<int> rdims = {static_cast<int>(xshape.size() - 1)};
if (op_type == 0) {
return reduce_sum<int8_t>(ctx, x, y, xshape, rdims);
} else {
return NOT_IMPLEMENT;
}
return SUCCESS;
}
template <>
int xpu2_wrapper<float>(Context* ctx,
const float* x,
float* y,
const std::vector<int>& xshape,
int op_type) {
int t = xshape[xshape.size() - 1];
int xlen = vector_prod(xshape);
int m = xlen / t;
switch (op_type) {
case 0:
xpu2::plugin::fast_reduce_sum_tiny<float>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
break;
case 1:
xpu2::plugin::fast_reduce_mean_tiny<float>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
break;
case 2:
xpu2::plugin::fast_reduce_max_tiny<float>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
break;
case 3:
xpu2::plugin::fast_reduce_min_tiny<float>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
break;
default:
return NOT_IMPLEMENT;
}
return SUCCESS;
}
template <>
int xpu2_wrapper<float16>(Context* ctx,
const float16* x,
float16* y,
const std::vector<int>& xshape,
int op_type) {
int t = xshape[xshape.size() - 1];
int xlen = vector_prod(xshape);
int m = xlen / t;
switch (op_type) {
case 0:
xpu2::plugin::fast_reduce_sum_tiny<float16>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
break;
case 1:
xpu2::plugin::fast_reduce_mean_tiny<float16>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
break;
case 2:
xpu2::plugin::fast_reduce_max_tiny<float16>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
break;
case 3:
xpu2::plugin::fast_reduce_min_tiny<float16>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
break;
default:
return NOT_IMPLEMENT;
}
return SUCCESS;
}
template <typename T>
int fast_reduce_tiny(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims,
int op_type) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_reduce_tiny", T);
WRAPPER_DUMP_PARAM5(ctx, x, y, xshape, rdims, op_type);
WRAPPER_DUMP(ctx);
std::vector<int> yshape = xshape;
yshape[xshape.size() - 1] = 1;
int64_t lenx = -1;
int64_t leny = -1;
WRAPPER_CHECK_SHAPE(ctx, &lenx, xshape);
WRAPPER_CHECK_SHAPE(ctx, &leny, yshape);
WRAPPER_CHECK_PTR(ctx, T, lenx, x);
WRAPPER_CHECK_PTR(ctx, T, leny, y);
if (ctx->dev().type() == api::kXPU2) {
return xpu2_wrapper<T>(ctx, x, y, xshape, op_type);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template <typename T>
DLL_EXPORT int fast_reduce_sum(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims) {
if (rdims.size() == 1 && rdims[0] == xshape.size() - 1 &&
xshape[xshape.size() - 1] <= 832) {
return fast_reduce_tiny<T>(ctx, x, y, xshape, rdims, 0);
} else {
return reduce_sum<T>(ctx, x, y, xshape, rdims);
}
}
template <typename T>
DLL_EXPORT int fast_reduce_mean(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims) {
if (rdims.size() == 1 && rdims[0] == xshape.size() - 1 &&
xshape[xshape.size() - 1] <= 832) {
return fast_reduce_tiny<T>(ctx, x, y, xshape, rdims, 1);
} else {
return reduce_mean<T>(ctx, x, y, xshape, rdims);
}
}
template <typename T>
DLL_EXPORT int fast_reduce_max(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims) {
if (rdims.size() == 1 && rdims[0] == xshape.size() - 1 &&
xshape[xshape.size() - 1] <= 832) {
return fast_reduce_tiny<T>(ctx, x, y, xshape, rdims, 2);
} else {
return reduce_max<T>(ctx, x, y, xshape, rdims);
}
}
template <typename T>
DLL_EXPORT int fast_reduce_min(Context* ctx,
const T* x,
T* y,
const std::vector<int>& xshape,
const std::vector<int>& rdims) {
if (rdims.size() == 1 && rdims[0] == xshape.size() - 1 &&
xshape[xshape.size() - 1] <= 832) {
return fast_reduce_tiny<T>(ctx, x, y, xshape, rdims, 3);
} else {
return reduce_min<T>(ctx, x, y, xshape, rdims);
}
}
template int fast_reduce_sum(Context*,
const float*,
float*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_sum(Context*,
const float16*,
float16*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_sum(Context*,
const int*,
int*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_sum(Context*,
const int64_t*,
int64_t*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_sum(Context*,
const int8_t*,
int8_t*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_mean(Context*,
const float*,
float*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_mean(Context*,
const float16*,
float16*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_min(Context*,
const float*,
float*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_max(Context*,
const float*,
float*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_max(Context*,
const int*,
int*,
const std::vector<int>&,
const std::vector<int>&);
template int fast_reduce_max(Context*,
const int64_t*,
int64_t*,
const std::vector<int>&,
const std::vector<int>&);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
......@@ -25,8 +25,6 @@ template <typename T, typename TID>
__attribute__((global)) void take_along_axis(const T* x,
const TID* indices,
T* y,
const int64_t* shape,
int64_t shape_size,
int64_t batch,
int64_t xlen,
int64_t ylen);
......@@ -74,43 +72,21 @@ static int xpu2_wrapper(Context* ctx,
const std::vector<int64_t>& idxshape,
int64_t axis) {
int64_t m_idx = 1;
int64_t shape_new_size = idxshape.size() - 1;
std::vector<int64_t> shape_new = xshape;
for (int64_t i = 0; i < axis; i++) {
m_idx *= idxshape[i];
}
for (int64_t i = axis + 1; i < xshape.size(); i++) {
shape_new[i - 1] = xshape[i];
}
int64_t t_x = xshape[axis];
int64_t t_idx = idxshape[axis];
int64_t n_idx = vector_prod(idxshape) / m_idx / t_idx;
if (m_idx < 64 && n_idx == 1) {
api::ctx_guard RAII_GUARD(ctx);
int64_t* shape_xpu = RAII_GUARD.alloc_l3_or_gm<int64_t>(shape_new_size);
WRAPPER_ASSERT_WORKSPACE(ctx, shape_xpu);
int ret = do_host2device(
ctx, shape_new.data(), shape_xpu, (shape_new_size) * sizeof(int64_t));
WRAPPER_ASSERT_SUCCESS(ctx, ret);
using XPU_TID = typename XPUIndexType<TID>::type;
const XPU_TID* casted_index =
static_cast<const XPU_TID*>(static_cast<const void*>(index));
xpu2::plugin::take_along_axis<T, XPU_TID>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
x,
casted_index,
y,
reinterpret_cast<xpu2::int64_t*>(shape_xpu),
shape_new_size,
m_idx,
t_x,
t_idx);
x, casted_index, y, m_idx, t_x, t_idx);
} else {
return gather_element(ctx, x, index, y, xshape, idxshape, axis);
}
......
......@@ -25,6 +25,36 @@
namespace phi {
static void GetReduceDims(const DDim& xdims,
const std::vector<int64_t>& dims,
bool reduce_all,
std::vector<int>* reduce_dims) {
const auto& input_dim_size = xdims.size();
std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) {
true_dims.push_back(dims[i] + input_dim_size);
} else {
true_dims.push_back(dims[i]);
}
}
if (reduce_all) {
for (int i = 0; i < input_dim_size; ++i) {
reduce_dims->push_back(i);
}
} else {
std::set<int> dims_set(true_dims.begin(), true_dims.end());
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) != dims_set.end()) {
if (xdims[i] != 1) {
reduce_dims->push_back(i);
}
}
}
}
}
template <typename Context, typename T>
int XPUReduce(const Context& dev_ctx,
const DenseTensor& x,
......@@ -43,35 +73,15 @@ int XPUReduce(const Context& dev_ctx,
const auto* x_data = x.data<T>();
auto* y_data = out->data<T>();
const auto& input_dim_size = x.dims().size();
std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) {
true_dims.push_back(dims[i] + input_dim_size);
} else {
true_dims.push_back(dims[i]);
}
}
std::vector<int> reduce_dims;
std::vector<int> xdims((input_dim_size));
const auto& input_dim_size = x.dims().size();
std::vector<int> xdims(input_dim_size);
for (int i = 0; i < input_dim_size; ++i) {
xdims[i] = x.dims()[i];
}
if (reduce_all) {
for (int i = 0; i < input_dim_size; ++i) {
reduce_dims.push_back(i);
}
} else {
std::set<int> dims_set(true_dims.begin(), true_dims.end());
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) != dims_set.end()) {
if (x.dims()[i] != 1) {
reduce_dims.push_back(i);
}
}
}
}
std::vector<int> reduce_dims;
GetReduceDims(x.dims(), dims, reduce_all, &reduce_dims);
int r = xpu::SUCCESS;
if (reduce_dims.size() == 0) {
......@@ -119,33 +129,14 @@ void XPUReduce(const DeviceContext& dev_ctx,
reduce_all = recompute_reduce_all(x, dims, reduce_all);
const auto& input_dim_size = x.dims().size();
std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) {
true_dims.push_back(dims[i] + input_dim_size);
} else {
true_dims.push_back(dims[i]);
}
}
std::vector<int> reduce_dims;
std::vector<int> xdims((input_dim_size));
std::vector<int> xdims(input_dim_size);
for (int i = 0; i < input_dim_size; ++i) {
xdims[i] = x.dims()[i];
}
if (reduce_all) {
for (int i = 0; i < input_dim_size; ++i) {
reduce_dims.push_back(i);
}
} else {
std::set<int> dims_set(true_dims.begin(), true_dims.end());
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) != dims_set.end()) {
if (x.dims()[i] != 1) {
reduce_dims.push_back(i);
}
}
}
}
std::vector<int> reduce_dims;
GetReduceDims(x.dims(), dims, reduce_all, &reduce_dims);
// no need to cast dtype
if (out_dtype == phi::DataType::UNDEFINED || out_dtype == x.dtype()) {
// do reduce sum
......
......@@ -34,11 +34,20 @@ void MaxKernel(const Context& dev_ctx,
T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
#ifndef PADDLE_WITH_XPU_PLUGIN
return xpu::reduce_max<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
#else
return xpu::plugin::fast_reduce_max<XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
#endif
};
int r = XPUReduce<Context, T>(
......
......@@ -35,11 +35,20 @@ void MeanRawKernel(const Context& dev_ctx,
T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
#ifndef PADDLE_WITH_XPU_PLUGIN
return xpu::reduce_mean<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
#else
return xpu::plugin::fast_reduce_mean<XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
#endif
};
int r = XPUReduce<Context, T>(
......
......@@ -36,11 +36,20 @@ void MinRawKernel(const Context& dev_ctx,
T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
#ifndef PADDLE_WITH_XPU_PLUGIN
return xpu::reduce_min<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
#else
return xpu::plugin::fast_reduce_min<XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
#endif
};
int r = XPUReduce<Context, T>(
......
......@@ -28,12 +28,22 @@ struct SumFunctor {
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
using XPUType = typename XPUTypeTrait<X>::Type;
#ifndef PADDLE_WITH_XPU_PLUGIN
int r = xpu::reduce_sum<XPUType>(ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
#else
int r = xpu::plugin::fast_reduce_sum<XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x),
reinterpret_cast<XPUType*>(y),
xdims,
reduce_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_reduce_sum");
#endif
}
};
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册