diff --git a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
index c7fe7e74b4193256b0116ef76fba9439dac5f288..c063ebf949b917f6504f9eac2c2815fc8b17e888 100644
--- a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
+++ b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
@@ -75,6 +75,35 @@ DLL_EXPORT int fast_layer_norm(Context* ctx,
                                float eps,
                                const float* scale,
                                const float* bias);
+
+template <typename T>
+DLL_EXPORT int fast_reduce_sum(Context* ctx,
+                               const T* x,
+                               T* y,
+                               const std::vector<int>& xshape,
+                               const std::vector<int>& rdims);
+
+template <typename T>
+DLL_EXPORT int fast_reduce_mean(Context* ctx,
+                                const T* x,
+                                T* y,
+                                const std::vector<int>& xshape,
+                                const std::vector<int>& rdims);
+
+template <typename T>
+DLL_EXPORT int fast_reduce_max(Context* ctx,
+                               const T* x,
+                               T* y,
+                               const std::vector<int>& xshape,
+                               const std::vector<int>& rdims);
+
+template <typename T>
+DLL_EXPORT int fast_reduce_min(Context* ctx,
+                               const T* x,
+                               T* y,
+                               const std::vector<int>& xshape,
+                               const std::vector<int>& rdims);
+
 }  // namespace plugin
 }  // namespace api
 }  // namespace xpu
diff --git a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_reduce.xpu b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_reduce.xpu
new file mode 100644
index 0000000000000000000000000000000000000000..19878718e928b04961e237d9f269c10bd0f8467f
--- /dev/null
+++ b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_reduce.xpu
@@ -0,0 +1,262 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/*
+ * copyright (C) 2022 KUNLUNXIN, Inc
+ */
+
+#include "xpu/kernel/cluster.h"
+#include "xpu/kernel/cluster_partition.h"
+#include "xpu/kernel/cluster_primitive.h"
+
+namespace xpu2 {
+namespace plugin {
+
+__device__ float do_sum_align16(float* lmptr, int size) {
+  __simd__ float sum_buf[16];
+  float32x16_t vsum = vset_zero();
+  for (int i = 0; i < size; i += 16) {
+    float32x16_t v0 = vload_lm_float32x16(lmptr + i);
+    vsum = vvadd_float32x16(vsum, v0);
+  }
+  vstore_lm_float32x16(sum_buf, vsum);
+  mfence_lm();
+  float sum = 0.0f;
+  for (int i = 0; i < 16; i++) {
+    sum = sum + sum_buf[i];
+  }
+  return sum;
+}
+
+__device__ float do_sum(float* lmptr, int size) {
+  float sum = 0.0f;
+  for (int i = 0; i < size; i++) {
+    sum += lmptr[i];
+  }
+  return sum;
+}
+
+__device__ float do_max_align16(float* lmptr, int size) {
+  __simd__ float max_buf[16];
+  float32x16_t vmax = vload_lm_float32x16(lmptr);
+  for (int i = 16; i < size; i += 16) {
+    float32x16_t v0 = vload_lm_float32x16(lmptr + i);
+    vmax = vvmax_float32x16(vmax, v0);
+  }
+  vstore_lm_float32x16(max_buf, vmax);
+  mfence_lm();
+  float max_val = max_buf[0];
+  for (int i = 1; i < 16; i++) {
+    max_val = fmax(max_val, max_buf[i]);
+  }
+  return max_val;
+}
+
+__device__ float do_max(float* lmptr, int size) {
+  float max_val = lmptr[0];
+  for (int i = 1; i < size; i++) {
+    max_val = fmax(max_val, lmptr[i]);
+  }
+  return max_val;
+}
+
+__device__ float do_min_align16(float* lmptr, int size) {
+  __simd__ float min_buf[16];
+  float32x16_t vmin = vload_lm_float32x16(lmptr);
+  for (int i = 16; i < size; i += 16) {
+    float32x16_t v0 = vload_lm_float32x16(lmptr + i);
+    vmin = vvmin_float32x16(vmin, v0);
+  }
+  vstore_lm_float32x16(min_buf, vmin);
+  mfence_lm();
+  float min_val = min_buf[0];
+  for (int i = 1; i < 16; i++) {
+    min_val = fmin(min_val, min_buf[i]);
+  }
+  return min_val;
+}
+
+__device__ float do_min(float* lmptr, int size) {
+  float min_val = lmptr[0];
+  for (int i = 1; i < size; i++) {
+    min_val = fmin(min_val, lmptr[i]);
+  }
+  return min_val;
+}
+
+template <typename T>
+__global__ void fast_reduce_sum_tiny(const T* x, T* y, int m, int t) {
+  int cid = core_id();
+  const int ncores = core_num();
+  int tid = cid * cluster_num() + cluster_id();
+  int nthreads = cluster_num() * ncores;
+
+  const int64_t max_tt = 832;
+  const int64_t buffer_len = max_tt * 4 / sizeof(float);
+  int mstart = 0;
+  int mend = 0;
+  __simd__ float xlm[buffer_len];
+  __simd__ float ylm[buffer_len];
+  int block_cnt = buffer_len / t;
+  partition(tid, nthreads, m, 1, &mstart, &mend);
+  for (int i = mstart; i < mend; i += block_cnt) {
+    int readlen = min((mend - i) * t, block_cnt * t);
+    GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
+    if (t % 16 == 0 && t >= 32) {
+      for (int j = 0; j < readlen; j += t) {
+        ylm[j / t] = do_sum_align16(xlm + j, t);
+      }
+      LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
+    } else {
+      primitive_cast<T, float>((T*)xlm, xlm, readlen);
+      for (int j = 0; j < readlen; j += t) {
+        ylm[j / t] = do_sum(xlm + j, t);
+      }
+      primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
+      LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
+    }
+  }
+  return;
+}
+
+template <typename T>
+__global__ void fast_reduce_mean_tiny(const T* x, T* y, int m, int t) {
+  int cid = core_id();
+  const int ncores = core_num();
+  int tid = cid * cluster_num() + cluster_id();
+  int nthreads = cluster_num() * ncores;
+
+  const int64_t max_tt = 832;
+  const int64_t buffer_len = max_tt * 4 / sizeof(float);
+  int mstart = 0;
+  int mend = 0;
+  __simd__ float xlm[buffer_len];
+  __simd__ float ylm[buffer_len];
+  int block_cnt = buffer_len / t;
+  partition(tid, nthreads, m, 1, &mstart, &mend);
+  for (int i = mstart; i < mend; i += block_cnt) {
+    int readlen = min((mend - i) * t, block_cnt * t);
+    GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
+    if (t % 16 == 0 && t >= 32) {
+      for (int j = 0; j < readlen; j += t) {
+        ylm[j / t] = do_sum_align16(xlm + j, t) / t;
+      }
+      LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
+    } else {
+      primitive_cast<T, float>((T*)xlm, xlm, readlen);
+      for (int j = 0; j < readlen; j += t) {
+        ylm[j / t] = do_sum(xlm + j, t) / t;
+      }
+      primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
+      LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
+    }
+  }
+  return;
+}
+
+template <typename T>
+__global__ void fast_reduce_max_tiny(const T* x, T* y, int m, int t) {
+  int cid = core_id();
+  const int ncores = core_num();
+  int tid = cid * cluster_num() + cluster_id();
+  int nthreads = cluster_num() * ncores;
+
+  const int64_t max_tt = 832;
+  const int64_t buffer_len = max_tt * 4 / sizeof(float);
+  int mstart = 0;
+  int mend = 0;
+  __simd__ float xlm[buffer_len];
+  __simd__ float ylm[buffer_len];
+  int block_cnt = buffer_len / t;
+  partition(tid, nthreads, m, 1, &mstart, &mend);
+  for (int i = mstart; i < mend; i += block_cnt) {
+    int readlen = min((mend - i) * t, block_cnt * t);
+    GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
+    if (t % 16 == 0 && t >= 32) {
+      for (int j = 0; j < readlen; j += t) {
+        ylm[j / t] = do_max_align16(xlm + j, t);
+      }
+      LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
+    } else {
+      primitive_cast<T, float>((T*)xlm, xlm, readlen);
+      for (int j = 0; j < readlen; j += t) {
+        ylm[j / t] = do_max(xlm + j, t);
+      }
+      primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
+      LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
+    }
+  }
+  return;
+}
+
+template <typename T>
+__global__ void fast_reduce_min_tiny(const T* x, T* y, int m, int t) {
+  int cid = core_id();
+  const int ncores = core_num();
+  int tid = cid * cluster_num() + cluster_id();
+  int nthreads = cluster_num() * ncores;
+
+  const int64_t max_tt = 832;
+  const int64_t buffer_len = max_tt * 4 / sizeof(float);
+  int mstart = 0;
+  int mend = 0;
+  __simd__ float xlm[buffer_len];
+  __simd__ float ylm[buffer_len];
+  int block_cnt = buffer_len / t;
+  partition(tid, nthreads, m, 1, &mstart, &mend);
+  for (int i = mstart; i < mend; i += block_cnt) {
+    int readlen = min((mend - i) * t, block_cnt * t);
+    GM2LM(x + i * t, (T*)xlm, readlen * sizeof(T));
+    if (t % 16 == 0 && t >= 32) {
+      for (int j = 0; j < readlen; j += t) {
+        ylm[j / t] = do_min_align16(xlm + j, t);
+      }
+      LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
+    } else {
+      primitive_cast<T, float>((T*)xlm, xlm, readlen);
+      for (int j = 0; j < readlen; j += t) {
+        ylm[j / t] = do_min(xlm + j, t);
+      }
+      primitive_cast<float, T>(ylm, (T*)ylm, readlen / t);
+      LM2GM((T*)ylm, y + i, readlen / t * sizeof(T));
+    }
+  }
+  return;
+}
+
+#define _XPU_DEF__FAST_REDUCE_SUM_TINY_(DTYPE)           \
+  template __global__ void fast_reduce_sum_tiny<DTYPE>(  \
+      const DTYPE* x, DTYPE* y, int m, int t);
+_XPU_DEF__FAST_REDUCE_SUM_TINY_(float);
+_XPU_DEF__FAST_REDUCE_SUM_TINY_(float16);
+
+#define _XPU_DEF__FAST_REDUCE_MEAN_TINY_(DTYPE)          \
+  template __global__ void fast_reduce_mean_tiny<DTYPE>( \
+      const DTYPE* x, DTYPE* y, int m, int t);
+_XPU_DEF__FAST_REDUCE_MEAN_TINY_(float);
+_XPU_DEF__FAST_REDUCE_MEAN_TINY_(float16);
+
+#define _XPU_DEF__FAST_REDUCE_MAX_TINY_(DTYPE)           \
+  template __global__ void fast_reduce_max_tiny<DTYPE>(  \
+      const DTYPE* x, DTYPE* y, int m, int t);
+_XPU_DEF__FAST_REDUCE_MAX_TINY_(float);
+_XPU_DEF__FAST_REDUCE_MAX_TINY_(float16);
+
+#define _XPU_DEF__FAST_REDUCE_MIN_TINY_(DTYPE)           \
+  template __global__ void fast_reduce_min_tiny<DTYPE>(  \
+      const DTYPE* x, DTYPE* y, int m, int t);
+_XPU_DEF__FAST_REDUCE_MIN_TINY_(float);
+_XPU_DEF__FAST_REDUCE_MIN_TINY_(float16);
+
+}  // namespace plugin
+}  // namespace xpu2
diff --git a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
index abaa38a0284cf7d69cad553cca3e5da21c318c97..f607b66ba0ac2335e9e10060e761599bb33b6be0 100644
--- a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
+++ b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/take_along_axis.xpu
@@ -26,8 +26,6 @@ template <typename T, typename TID>
 __global__ void take_along_axis(const T* x,
                                 const TID* indices,
                                 T* y,
-                                const int64_t* shape,
-                                int64_t shape_size,
                                 int64_t batch,
                                 int64_t xlen,
                                 int64_t ylen) {
@@ -40,12 +38,6 @@ __global__ void take_along_axis(const T* x,
   __simd__ char lm_y[sizeof(T)];
   __simd__ char lm_idx[sizeof(TID)];
 
-  __shared__ int64_t sm_shape[512];
-  if (cid == 0) {
-    GM2SM(shape, sm_shape, shape_size * sizeof(int64_t));
-  }
-  sync_all();
-
   for (int64_t i = tid; i < batch * ylen; i += nthreads) {
     GM2LM(indices + i, lm_idx, sizeof(TID));
     TID idx = ((TID*)lm_idx)[0];
@@ -65,8 +57,6 @@ __global__ void take_along_axis(const T* x,
       const DTYPE* x,                                     \
       const IDTYPE* indices,                              \
       DTYPE* y,                                           \
-      const int64_t* shape,                               \
-      int64_t shape_size,                                 \
       int64_t batch,                                      \
       int64_t xlen,                                       \
       int64_t ylen);
diff --git a/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_reduce.cpp b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_reduce.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7149f654c499612760babd3cf3d575fc3ea00694
--- /dev/null
+++ b/paddle/phi/kernels/xpu/plugin/src/wrapper/fast_reduce.cpp
@@ -0,0 +1,291 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/*
+ * copyright (C) 2022 KUNLUNXIN, Inc
+ */
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu2 {
+namespace plugin {
+template <typename T>
+__attribute__((global)) void fast_reduce_sum_tiny(const T* x,
+                                                  T* y,
+                                                  int m,
+                                                  int t);
+template <typename T>
+__attribute__((global)) void fast_reduce_mean_tiny(const T* x,
+                                                   T* y,
+                                                   int m,
+                                                   int t);
+template <typename T>
+__attribute__((global)) void fast_reduce_max_tiny(const T* x,
+                                                  T* y,
+                                                  int m,
+                                                  int t);
+template <typename T>
+__attribute__((global)) void fast_reduce_min_tiny(const T* x,
+                                                  T* y,
+                                                  int m,
+                                                  int t);
+}  // namespace plugin
+}  // namespace xpu2
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+template <typename T>
+static int xpu2_wrapper(Context* ctx,
+                        const T* x,
+                        T* y,
+                        const std::vector<int>& xshape,
+                        int op_type) {
+  std::vector<int> rdims = {static_cast<int>(xshape.size() - 1)};
+  switch (op_type) {
+    case 0:
+      return reduce_sum<T>(ctx, x, y, xshape, rdims);
+    case 2:
+      return reduce_max<T>(ctx, x, y, xshape, rdims);
+    case 3:
+      return reduce_min<T>(ctx, x, y, xshape, rdims);
+    default:
+      return NOT_IMPLEMENT;
+  }
+  return SUCCESS;
+}
+
+template <>
+int xpu2_wrapper(Context* ctx,
+                 const int8_t* x,
+                 int8_t* y,
+                 const std::vector<int>& xshape,
+                 int op_type) {
+  std::vector<int> rdims = {static_cast<int>(xshape.size() - 1)};
+  if (op_type == 0) {
+    return reduce_sum<int8_t>(ctx, x, y, xshape, rdims);
+  } else {
+    return NOT_IMPLEMENT;
+  }
+  return SUCCESS;
+}
+
+template <>
+int xpu2_wrapper(Context* ctx,
+                 const float* x,
+                 float* y,
+                 const std::vector<int>& xshape,
+                 int op_type) {
+  int t = xshape[xshape.size() - 1];
+  int xlen = vector_prod(xshape);
+  int m = xlen / t;
+  switch (op_type) {
+    case 0:
+      xpu2::plugin::fast_reduce_sum_tiny<float>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
+      break;
+    case 1:
+      xpu2::plugin::fast_reduce_mean_tiny<float>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
+      break;
+    case 2:
+      xpu2::plugin::fast_reduce_max_tiny<float>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
+      break;
+    case 3:
+      xpu2::plugin::fast_reduce_min_tiny<float>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
+      break;
+    default:
+      return NOT_IMPLEMENT;
+  }
+  return SUCCESS;
+}
+
+template <>
+int xpu2_wrapper(Context* ctx,
+                 const float16* x,
+                 float16* y,
+                 const std::vector<int>& xshape,
+                 int op_type) {
+  int t = xshape[xshape.size() - 1];
+  int xlen = vector_prod(xshape);
+  int m = xlen / t;
+  switch (op_type) {
+    case 0:
+      xpu2::plugin::fast_reduce_sum_tiny<float16>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
+      break;
+    case 1:
+      xpu2::plugin::fast_reduce_mean_tiny<float16>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
+      break;
+    case 2:
+      xpu2::plugin::fast_reduce_max_tiny<float16>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
+      break;
+    case 3:
+      xpu2::plugin::fast_reduce_min_tiny<float16>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, m, t);
+      break;
+    default:
+      return NOT_IMPLEMENT;
+  }
+  return SUCCESS;
+}
+
+template <typename T>
+int fast_reduce_tiny(Context* ctx,
+                     const T* x,
+                     T* y,
+                     const std::vector<int>& xshape,
+                     const std::vector<int>& rdims,
+                     int op_type) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_reduce_tiny", T);
+  WRAPPER_DUMP_PARAM5(ctx, x, y, xshape, rdims, op_type);
+  WRAPPER_DUMP(ctx);
+  std::vector<int> yshape = xshape;
+  yshape[xshape.size() - 1] = 1;
+  int64_t lenx = -1;
+  int64_t leny = -1;
+  WRAPPER_CHECK_SHAPE(ctx, &lenx, xshape);
+  WRAPPER_CHECK_SHAPE(ctx, &leny, yshape);
+  WRAPPER_CHECK_PTR(ctx, T, lenx, x);
+  WRAPPER_CHECK_PTR(ctx, T, leny, y);
+
+  if (ctx->dev().type() == api::kXPU2) {
+    return xpu2_wrapper<T>(ctx, x, y, xshape, op_type);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+template <typename T>
+DLL_EXPORT int fast_reduce_sum(Context* ctx,
+                               const T* x,
+                               T* y,
+                               const std::vector<int>& xshape,
+                               const std::vector<int>& rdims) {
+  if (rdims.size() == 1 && rdims[0] == xshape.size() - 1 &&
+      xshape[xshape.size() - 1] <= 832) {
+    return fast_reduce_tiny<T>(ctx, x, y, xshape, rdims, 0);
+  } else {
+    return reduce_sum<T>(ctx, x, y, xshape, rdims);
+  }
+}
+
+template <typename T>
+DLL_EXPORT int fast_reduce_mean(Context* ctx,
+                                const T* x,
+                                T* y,
+                                const std::vector<int>& xshape,
+                                const std::vector<int>& rdims) {
+  if (rdims.size() == 1 && rdims[0] == xshape.size() - 1 &&
+      xshape[xshape.size() - 1] <= 832) {
+    return fast_reduce_tiny<T>(ctx, x, y, xshape, rdims, 1);
+  } else {
+    return reduce_mean<T>(ctx, x, y, xshape, rdims);
+  }
+}
+
+template <typename T>
+DLL_EXPORT int fast_reduce_max(Context* ctx,
+                               const T* x,
+                               T* y,
+                               const std::vector<int>& xshape,
+                               const std::vector<int>& rdims) {
+  if (rdims.size() == 1 && rdims[0] == xshape.size() - 1 &&
+      xshape[xshape.size() - 1] <= 832) {
+    return fast_reduce_tiny<T>(ctx, x, y, xshape, rdims, 2);
+  } else {
+    return reduce_max<T>(ctx, x, y, xshape, rdims);
+  }
+}
+
+template <typename T>
+DLL_EXPORT int fast_reduce_min(Context* ctx,
+                               const T* x,
+                               T* y,
+                               const std::vector<int>& xshape,
+                               const std::vector<int>& rdims) {
+  if (rdims.size() == 1 && rdims[0] == xshape.size() - 1 &&
+      xshape[xshape.size() - 1] <= 832) {
+    return fast_reduce_tiny<T>(ctx, x, y, xshape, rdims, 3);
+  } else {
+    return reduce_min<T>(ctx, x, y, xshape, rdims);
+  }
+}
+
+template int fast_reduce_sum(Context*,
+                             const float*,
+                             float*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+template int fast_reduce_sum(Context*,
+                             const float16*,
+                             float16*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+template int fast_reduce_sum(Context*,
+                             const int*,
+                             int*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+template int fast_reduce_sum(Context*,
+                             const int64_t*,
+                             int64_t*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+template int fast_reduce_sum(Context*,
+                             const int8_t*,
+                             int8_t*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+template int fast_reduce_mean(Context*,
+                              const float*,
+                              float*,
+                              const std::vector<int>&,
+                              const std::vector<int>&);
+template int fast_reduce_mean(Context*,
+                              const float16*,
+                              float16*,
+                              const std::vector<int>&,
+                              const std::vector<int>&);
+template int fast_reduce_min(Context*,
+                             const float*,
+                             float*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+template int fast_reduce_max(Context*,
+                             const float*,
+                             float*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+template int fast_reduce_max(Context*,
+                             const int*,
+                             int*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+template int fast_reduce_max(Context*,
+                             const int64_t*,
+                             int64_t*,
+                             const std::vector<int>&,
+                             const std::vector<int>&);
+
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/paddle/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp b/paddle/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
index 55df45d4131b403049da036951ef5d24724d3b85..f615cc0ec602b9653c3906a1dcf10b67ed747db8 100644
--- a/paddle/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
+++ b/paddle/phi/kernels/xpu/plugin/src/wrapper/take_along_axis.cpp
@@ -25,8 +25,6 @@ template <typename T, typename TID>
 __attribute__((global)) void take_along_axis(const T* x,
                                              const TID* indices,
                                              T* y,
-                                             const int64_t* shape,
-                                             int64_t shape_size,
                                              int64_t batch,
                                              int64_t xlen,
                                              int64_t ylen);
@@ -74,43 +72,21 @@ static int xpu2_wrapper(Context* ctx,
                         const std::vector<int64_t>& idxshape,
                         int64_t axis) {
   int64_t m_idx = 1;
-  int64_t shape_new_size = idxshape.size() - 1;
-  std::vector<int64_t> shape_new = xshape;
-
   for (int64_t i = 0; i < axis; i++) {
     m_idx *= idxshape[i];
   }
-
-  for (int64_t i = axis + 1; i < xshape.size(); i++) {
-    shape_new[i - 1] = xshape[i];
-  }
-
   int64_t t_x = xshape[axis];
   int64_t t_idx = idxshape[axis];
   int64_t n_idx = vector_prod(idxshape) / m_idx / t_idx;
   if (m_idx < 64 && n_idx == 1) {
-    api::ctx_guard RAII_GUARD(ctx);
-    int64_t* shape_xpu = RAII_GUARD.alloc_l3_or_gm<int64_t>(shape_new_size);
-    WRAPPER_ASSERT_WORKSPACE(ctx, shape_xpu);
-    int ret = do_host2device(
-        ctx, shape_new.data(), shape_xpu, (shape_new_size) * sizeof(int64_t));
-    WRAPPER_ASSERT_SUCCESS(ctx, ret);
-
     using XPU_TID = typename XPUIndexType<TID>::type;
     const XPU_TID* casted_index =
         static_cast<const XPU_TID*>(static_cast<const void*>(index));
     xpu2::plugin::take_along_axis<T, XPU_TID>
         <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
-            x,
-            casted_index,
-            y,
-            reinterpret_cast<const int64_t*>(shape_xpu),
-            shape_new_size,
-            m_idx,
-            t_x,
-            t_idx);
+            x, casted_index, y, m_idx, t_x, t_idx);
   } else {
     return gather_element(ctx, x, index, y, xshape, idxshape, axis);
   }
diff --git a/paddle/phi/kernels/xpu/reduce.h b/paddle/phi/kernels/xpu/reduce.h
index 1507024265e23d68dadd77ae52c78f429bad6b74..9cf57ea2cce717966749a2bea3f5772e49a66265 100644
--- a/paddle/phi/kernels/xpu/reduce.h
+++ b/paddle/phi/kernels/xpu/reduce.h
@@ -25,6 +25,36 @@
 
 namespace phi {
 
+static void GetReduceDims(const DDim& xdims,
+                          const std::vector<int64_t>& dims,
+                          bool reduce_all,
+                          std::vector<int>* reduce_dims) {
+  const auto& input_dim_size = xdims.size();
+  std::vector<int> true_dims;
+  for (size_t i = 0; i < dims.size(); ++i) {
+    if (dims[i] < 0) {
+      true_dims.push_back(dims[i] + input_dim_size);
+    } else {
+      true_dims.push_back(dims[i]);
+    }
+  }
+
+  if (reduce_all) {
+    for (int i = 0; i < input_dim_size; ++i) {
+      reduce_dims->push_back(i);
+    }
+  } else {
+    std::set<int> dims_set(true_dims.begin(), true_dims.end());
+    for (auto i = 0; i < input_dim_size; i++) {
+      if (dims_set.find(i) != dims_set.end()) {
+        if (xdims[i] != 1) {
+          reduce_dims->push_back(i);
+        }
+      }
+    }
+  }
+}
+
 template <typename Context, typename T>
 int XPUReduce(const Context& dev_ctx,
               const DenseTensor& x,
@@ -43,35 +73,15 @@ int XPUReduce(const Context& dev_ctx,
 
   const auto* x_data = x.data<T>();
   auto* y_data = out->data<T>();
-  const auto& input_dim_size = x.dims().size();
-  std::vector<int> true_dims;
-  for (size_t i = 0; i < dims.size(); ++i) {
-    if (dims[i] < 0) {
-      true_dims.push_back(dims[i] + input_dim_size);
-    } else {
-      true_dims.push_back(dims[i]);
-    }
-  }
-  std::vector<int> reduce_dims;
-  std::vector<int> xdims((input_dim_size));
+  const auto& input_dim_size = x.dims().size();
+  std::vector<int> xdims(input_dim_size);
   for (int i = 0; i < input_dim_size; ++i) {
     xdims[i] = x.dims()[i];
   }
-  if (reduce_all) {
-    for (int i = 0; i < input_dim_size; ++i) {
-      reduce_dims.push_back(i);
-    }
-  } else {
-    std::set<int> dims_set(true_dims.begin(), true_dims.end());
-    for (auto i = 0; i < input_dim_size; i++) {
-      if (dims_set.find(i) != dims_set.end()) {
-        if (x.dims()[i] != 1) {
-          reduce_dims.push_back(i);
-        }
-      }
-    }
-  }
+
+  std::vector<int> reduce_dims;
+  GetReduceDims(x.dims(), dims, reduce_all, &reduce_dims);
 
   int r = xpu::SUCCESS;
   if (reduce_dims.size() == 0) {
@@ -119,33 +129,14 @@ void XPUReduce(const DeviceContext& dev_ctx,
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
 
   const auto& input_dim_size = x.dims().size();
-  std::vector<int> true_dims;
-  for (size_t i = 0; i < dims.size(); ++i) {
-    if (dims[i] < 0) {
-      true_dims.push_back(dims[i] + input_dim_size);
-    } else {
-      true_dims.push_back(dims[i]);
-    }
-  }
-  std::vector<int> reduce_dims;
-  std::vector<int> xdims((input_dim_size));
+  std::vector<int> xdims(input_dim_size);
   for (int i = 0; i < input_dim_size; ++i) {
     xdims[i] = x.dims()[i];
   }
-  if (reduce_all) {
-    for (int i = 0; i < input_dim_size; ++i) {
-      reduce_dims.push_back(i);
-    }
-  } else {
-    std::set<int> dims_set(true_dims.begin(), true_dims.end());
-    for (auto i = 0; i < input_dim_size; i++) {
-      if (dims_set.find(i) != dims_set.end()) {
-        if (x.dims()[i] != 1) {
-          reduce_dims.push_back(i);
-        }
-      }
-    }
-  }
+
+  std::vector<int> reduce_dims;
+  GetReduceDims(x.dims(), dims, reduce_all, &reduce_dims);
+
   // no need to cast dtype
   if (out_dtype == phi::DataType::UNDEFINED || out_dtype == x.dtype()) {
     // do reduce sum
diff --git a/paddle/phi/kernels/xpu/reduce_max_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_kernel.cc
index 1bc56a3990ea2ccd9efde471b83db6195ed22d52..8842f86b0c9fb35b577866fe060a40190ca24d13 100644
--- a/paddle/phi/kernels/xpu/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_max_kernel.cc
@@ -34,11 +34,20 @@ void MaxKernel(const Context& dev_ctx,
               T* y,
               const std::vector<int>& xdims,
               const std::vector<int>& reduce_dims) {
+#ifndef PADDLE_WITH_XPU_PLUGIN
     return xpu::reduce_max<XPUType>(ctx,
                                     reinterpret_cast<const XPUType*>(x),
                                     reinterpret_cast<XPUType*>(y),
                                     xdims,
                                     reduce_dims);
+#else
+    return xpu::plugin::fast_reduce_max<XPUType>(
+        ctx,
+        reinterpret_cast<const XPUType*>(x),
+        reinterpret_cast<XPUType*>(y),
+        xdims,
+        reduce_dims);
+#endif
   };
 
   int r = XPUReduce<Context, T>(
diff --git a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
index cb0bfb6218a889cb15232a8f54171058a64c123a..7a340d45d1fe4765bd81b4d01db0bdd0c705bc7f 100644
--- a/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_mean_kernel.cc
@@ -35,11 +35,20 @@ void MeanRawKernel(const Context& dev_ctx,
               T* y,
               const std::vector<int>& xdims,
               const std::vector<int>& reduce_dims) {
+#ifndef PADDLE_WITH_XPU_PLUGIN
     return xpu::reduce_mean<XPUType>(ctx,
                                      reinterpret_cast<const XPUType*>(x),
                                      reinterpret_cast<XPUType*>(y),
                                      xdims,
                                      reduce_dims);
+#else
+    return xpu::plugin::fast_reduce_mean<XPUType>(
+        ctx,
+        reinterpret_cast<const XPUType*>(x),
+        reinterpret_cast<XPUType*>(y),
+        xdims,
+        reduce_dims);
+#endif
   };
 
   int r = XPUReduce<Context, T>(
diff --git a/paddle/phi/kernels/xpu/reduce_min_kernel.cc b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
index 6c63615f8ec16df584d717726bc3c63f22f03691..721a579945ceb3b0870faf64cc7c48f60cc5ccec 100644
--- a/paddle/phi/kernels/xpu/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_min_kernel.cc
@@ -36,11 +36,20 @@ void MinRawKernel(const Context& dev_ctx,
               T* y,
               const std::vector<int>& xdims,
               const std::vector<int>& reduce_dims) {
+#ifndef PADDLE_WITH_XPU_PLUGIN
     return xpu::reduce_min<XPUType>(ctx,
                                     reinterpret_cast<const XPUType*>(x),
                                     reinterpret_cast<XPUType*>(y),
                                     xdims,
                                     reduce_dims);
+#else
+    return xpu::plugin::fast_reduce_min<XPUType>(
+        ctx,
+        reinterpret_cast<const XPUType*>(x),
+        reinterpret_cast<XPUType*>(y),
+        xdims,
+        reduce_dims);
+#endif
   };
 
   int r = XPUReduce<Context, T>(
diff --git a/paddle/phi/kernels/xpu/reduce_util.h b/paddle/phi/kernels/xpu/reduce_util.h
index cd624cc1ef1f0da56c1386d0a870ee08fe138bb7..3c0d35d641fd9416430b31d47ec07a8eb50d6e65 100644
--- a/paddle/phi/kernels/xpu/reduce_util.h
+++ b/paddle/phi/kernels/xpu/reduce_util.h
@@ -28,12 +28,22 @@ struct SumFunctor {
                   const std::vector<int>& xdims,
                   const std::vector<int>& reduce_dims) {
     using XPUType = typename XPUTypeTrait<T>::Type;
+#ifndef PADDLE_WITH_XPU_PLUGIN
     int r = xpu::reduce_sum<XPUType>(ctx,
                                      reinterpret_cast<const XPUType*>(x),
                                      reinterpret_cast<XPUType*>(y),
                                      xdims,
                                      reduce_dims);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+#else
+    int r = xpu::plugin::fast_reduce_sum<XPUType>(
+        ctx,
+        reinterpret_cast<const XPUType*>(x),
+        reinterpret_cast<XPUType*>(y),
+        xdims,
+        reduce_dims);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_reduce_sum");
+#endif
   }
 };
 }  // namespace phi