未验证 提交 07e788f1 编写于 作者: H hong19860320 提交者: GitHub

[XPU] Add fast_where fusion op and XPU micro kernel (#55628)

上级 744e1eaf
......@@ -280,6 +280,7 @@ if(WITH_XPU)
pass_library(matmul_weight_trans_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -599,4 +600,8 @@ if(WITH_XPU)
test_reshape2_matmul_xpu_fuse_pass
SRCS xpu/reshape2_matmul_xpu_fuse_pass_test.cc
DEPS reshape2_matmul_xpu_fuse_pass)
cc_test(
test_fast_where_xpu_fuse_pass
SRCS xpu/fast_where_xpu_fuse_pass_test.cc
DEPS fast_where_xpu_fuse_pass)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
#define APPLY_PASS \
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); \
auto pass = PassRegistry::Instance().Get("fast_where_xpu_fuse_pass"); \
pass->Apply(graph.get());
#define VERIFY_GRAPH(x, y) \
auto num_op_nodes = GetNumOpNodes(graph); \
PADDLE_ENFORCE_EQ( \
num_op_nodes, \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only one op node, but %d op nodes found.", \
num_op_nodes)); \
auto fast_where_xpu_op_nodes = GetOpNodes(graph, "fast_where_xpu"); \
PADDLE_ENFORCE_EQ(fast_where_xpu_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a fast_where_xpu op node, " \
"but %d op nodes found.", \
fast_where_xpu_op_nodes.size())); \
const auto& x_name = fast_where_xpu_op_nodes[0]->Op()->Input("x")[0]; \
PADDLE_ENFORCE_EQ(x_name, \
#x, \
platform::errors::PreconditionNotMet( \
"The input 'x' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#x, \
x_name)); \
const auto& y_name = fast_where_xpu_op_nodes[0]->Op()->Input("y")[0]; \
PADDLE_ENFORCE_EQ(y_name, \
#y, \
platform::errors::PreconditionNotMet( \
"The input 'y' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#y, \
y_name));
TEST(FastWhereXPUFusePass, one_case0) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(x, scale_out);
mul0_out->SetShape({20, 7});
auto* mul1_out = layers.elementwise_mul(y, cast_out);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(y, x)
}
TEST(FastWhereXPUFusePass, one_case1) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(x, cast_out);
mul0_out->SetShape({20, 7});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(y, scale_out);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(x, y)
}
TEST(FastWhereXPUFusePass, one_case2) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(scale_out, x);
mul0_out->SetShape({20, 7});
auto* mul1_out = layers.elementwise_mul(cast_out, y);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(y, x)
}
TEST(FastWhereXPUFusePass, one_case3) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(cast_out, x);
mul0_out->SetShape({20, 7});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(scale_out, y);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(x, y)
}
TEST(FastWhereXPUFusePass, one_case4) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(scale_out, x);
mul0_out->SetShape({20, 7});
auto* mul1_out = layers.elementwise_mul(y, cast_out);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(y, x)
}
TEST(FastWhereXPUFusePass, one_case5) {
Layers layers;
auto* condition =
layers.data("condition", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
auto* cast_out = layers.cast(condition, 0, 5);
cast_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(cast_out, x);
mul0_out->SetShape({20, 7});
auto* scale_out = layers.scale(cast_out, -1.0f, 1.0f, true);
scale_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(y, scale_out);
mul1_out->SetShape({20, 7});
auto* add_out = layers.elementwise_add(mul0_out, mul1_out);
add_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(x, y)
}
#undef VERIFY_GRAPH
#define VERIFY_GRAPH(logical_op, x, y) \
auto num_op_nodes = GetNumOpNodes(graph); \
PADDLE_ENFORCE_EQ( \
num_op_nodes, \
2, \
platform::errors::PreconditionNotMet( \
"The graph contains only two op nodes, but %d op nodes found.", \
num_op_nodes)); \
auto logical_op_nodes = GetOpNodes(graph, #logical_op); \
PADDLE_ENFORCE_EQ( \
logical_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a '%s' op node, but %d op nodes found.", \
#logical_op, \
logical_op_nodes.size())); \
auto fast_where_xpu_op_nodes = GetOpNodes(graph, "fast_where_xpu"); \
PADDLE_ENFORCE_EQ(fast_where_xpu_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a fast_where_xpu op node, " \
"but %d op nodes found.", \
fast_where_xpu_op_nodes.size())); \
const auto& x_name = fast_where_xpu_op_nodes[0]->Op()->Input("x")[0]; \
PADDLE_ENFORCE_EQ(x_name, \
#x, \
platform::errors::PreconditionNotMet( \
"The input 'x' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#x, \
x_name)); \
const auto& y_name = fast_where_xpu_op_nodes[0]->Op()->Input("y")[0]; \
PADDLE_ENFORCE_EQ(y_name, \
#y, \
platform::errors::PreconditionNotMet( \
"The input 'y' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#y, \
y_name));
TEST(FastWhereXPUFusePass, cascade_case0) {
Layers layers;
auto* condition0 =
layers.data("condition0", {20, 1}, false, proto::VarType::BOOL);
auto* condition1 =
layers.data("condition1", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
// fast_where_xpu0
auto* cast0_out = layers.cast(condition0, 0, 5);
cast0_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(cast0_out, x);
mul0_out->SetShape({20, 7});
auto* scale0_out = layers.scale(cast0_out, -1.0f, 1.0f, true);
scale0_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(scale0_out, y);
mul1_out->SetShape({20, 7});
auto* add0_out = layers.elementwise_add(mul0_out, mul1_out);
add0_out->SetShape({20, 7});
// fast_where_xpu1
auto* cast1_out = layers.cast(condition1, 0, 5);
cast1_out->SetShape({20, 1});
auto* mul2_out = layers.elementwise_mul(cast1_out, x);
mul2_out->SetShape({20, 7});
auto* scale1_out = layers.scale(cast1_out, -1.0f, 1.0f, true);
scale1_out->SetShape({20, 1});
auto* mul3_out = layers.elementwise_mul(scale1_out, add0_out);
mul3_out->SetShape({20, 7});
auto* add1_out = layers.elementwise_add(mul2_out, mul3_out);
add1_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(logical_or, x, y)
}
TEST(FastWhereXPUFusePass, cascade_case1) {
Layers layers;
auto* condition0 =
layers.data("condition0", {20, 1}, false, proto::VarType::BOOL);
auto* condition1 =
layers.data("condition1", {20, 1}, false, proto::VarType::BOOL);
auto* x = layers.data("x", {20, 7});
auto* y = layers.data("y", {20, 7});
// fast_where_xpu0
auto* cast0_out = layers.cast(condition0, 0, 5);
cast0_out->SetShape({20, 1});
auto* mul0_out = layers.elementwise_mul(cast0_out, x);
mul0_out->SetShape({20, 7});
auto* scale0_out = layers.scale(cast0_out, -1.0f, 1.0f, true);
scale0_out->SetShape({20, 1});
auto* mul1_out = layers.elementwise_mul(scale0_out, y);
mul1_out->SetShape({20, 7});
auto* add0_out = layers.elementwise_add(mul0_out, mul1_out);
add0_out->SetShape({20, 7});
// fast_where_xpu1
auto* cast1_out = layers.cast(condition1, 0, 5);
cast1_out->SetShape({20, 1});
auto* mul2_out = layers.elementwise_mul(cast1_out, add0_out);
mul2_out->SetShape({20, 7});
auto* scale1_out = layers.scale(cast1_out, -1.0f, 1.0f, true);
scale1_out->SetShape({20, 1});
auto* mul3_out = layers.elementwise_mul(scale1_out, y);
mul3_out->SetShape({20, 7});
auto* add1_out = layers.elementwise_add(mul2_out, mul3_out);
add1_out->SetShape({20, 7});
APPLY_PASS
VERIFY_GRAPH(logical_and, x, y)
}
#undef APPLY_PASS
#undef VERIFY_GRAPH
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fast_where_xpu_fuse_pass);
......@@ -545,6 +545,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"add_activation_xpu_fuse_pass",
"add_layernorm_xpu_fuse_pass",
"yolo_box_xpu_fuse_pass",
"fast_where_xpu_fuse_pass",
"link_xpu_op_max_pass",
"delete_isolated_node_pass",
// "auto_mixed_precision_pass",
......
......@@ -53,6 +53,15 @@
data_type: tables
optional : mask, seq_lod, max_seq_len
- op : fast_where_xpu
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : FastWhereXPUInferMeta
kernel :
func : fast_where_xpu
data_type : x
- op : fc_xpu
args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype)
output : Tensor(out), Tensor(out_max)
......
......@@ -295,6 +295,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"fast_where_xpu",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"fc_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fill",
......
......@@ -721,4 +721,12 @@ void Conv2dTransposeXPUInferMeta(const MetaTensor& x,
out_max);
}
void FastWhereXPUInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(x.dtype());
}
} // namespace phi
......@@ -175,4 +175,10 @@ void Conv2dTransposeXPUInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out,
MetaTensor* out_max);
void FastWhereXPUInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FastWhereXPUKernel(const Context& ctx,
const DenseTensor& condition,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto* condition_data = condition.data<bool>();
auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
auto* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
auto condition_dims = phi::vectorize<int>(condition.dims());
auto x_dims = phi::vectorize<int>(x.dims());
auto y_dims = phi::vectorize<int>(y.dims());
PADDLE_ENFORCE_EQ(
x_dims,
y_dims,
errors::PreconditionNotMet(
"The dimensions of inputs should be equal, but x_dims=[",
x.dims(),
"] and y_dims=[",
y.dims(),
"]"));
#ifndef PADDLE_WITH_XPU_PLUGIN
LOG(WARNING)
<< "Add -DWITH_XPU_PLUGIN=ON to build xpu::plugin::fast_where(), or use "
"xpu::select() instead, which leads low performance.";
int r = xpu::select<XPUType>(ctx.x_context(),
condition_data,
x_data,
y_data,
out_data,
condition_dims,
x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "select");
#else
xpu::ctx_guard RAII_GUARD(ctx.x_context());
if (condition_dims != x_dims) {
bool* temp_data = RAII_GUARD.alloc_l3_or_gm<bool>(x.numel());
int r = xpu::broadcast<bool>(
ctx.x_context(), condition_data, temp_data, condition_dims, x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
condition_data = temp_data;
}
int r = xpu::plugin::fast_where<XPUType>(
ctx.x_context(), condition_data, x_data, y_data, out_data, x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_where");
#endif
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fast_where_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::FastWhereXPUKernel,
float,
phi::dtype::float16,
int) {}
......@@ -154,7 +154,7 @@ macro(
${kernel_path} -D ${xpu_n_macro} --target=${TARGET_ARCH} ${HOST_XPU_FLAGS}
--basename ${kernel_name} -fno-builtin --xpu-arch=${xpu_n} -fPIC
-Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast -Werror -mllvm
--xpu-inline-cost -mllvm --xpu-inline-hot-call
--xpu-inline-cost -mllvm --xpu-inline-hot-call -I${XDNN_INC_DIR}
-I${CMAKE_CURRENT_SOURCE_DIR}/include -I${CMAKE_CURRENT_SOURCE_DIR}/src
-I${CMAKE_CURRENT_SOURCE_DIR}/src/kernel
-I${CMAKE_CURRENT_SOURCE_DIR}/src/kernel/include ${arg_rule}
......
......@@ -24,6 +24,13 @@ namespace api {
namespace plugin {
DLL_EXPORT int add2(Context* ctx, const float* x, float* y, int len);
template <typename T>
DLL_EXPORT int fast_where(Context* ctx,
const bool* condition,
const T* x,
const T* y,
T* out,
int64_t len);
} // namespace plugin
} // namespace api
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
#define CALC_MASK(offset) \
mask |= static_cast<int>(condition[i + offset]) << offset;
static __device__ inline void do_select_16(const int8_t* condition,
const int16_t* x,
int16_t* y,
int len) {
int len_rounddown32 = rounddown32(len);
for (int i = 0; i < len_rounddown32; i += 32) {
int mask = condition[i];
CALC_MASK(1)
CALC_MASK(2)
CALC_MASK(3)
CALC_MASK(4)
CALC_MASK(5)
CALC_MASK(6)
CALC_MASK(7)
CALC_MASK(8)
CALC_MASK(9)
CALC_MASK(10)
CALC_MASK(11)
CALC_MASK(12)
CALC_MASK(13)
CALC_MASK(14)
CALC_MASK(15)
CALC_MASK(16)
CALC_MASK(17)
CALC_MASK(18)
CALC_MASK(19)
CALC_MASK(20)
CALC_MASK(21)
CALC_MASK(22)
CALC_MASK(23)
CALC_MASK(24)
CALC_MASK(25)
CALC_MASK(26)
CALC_MASK(27)
CALC_MASK(28)
CALC_MASK(29)
CALC_MASK(30)
CALC_MASK(31)
vstore_lm_int16x32_mh(y + i, vload_lm_int16x32(x + i), mask);
}
for (int i = len_rounddown32; i < len; i++) {
y[i] = condition[i] ? x[i] : y[i];
}
mfence_lm();
}
static __device__ inline void do_select_32(const int8_t* condition,
const int32_t* x,
int32_t* y,
int len) {
int len_rounddown16 = rounddown16(len);
for (int i = 0; i < len_rounddown16; i += 16) {
int mask = condition[i];
CALC_MASK(1)
CALC_MASK(2)
CALC_MASK(3)
CALC_MASK(4)
CALC_MASK(5)
CALC_MASK(6)
CALC_MASK(7)
CALC_MASK(8)
CALC_MASK(9)
CALC_MASK(10)
CALC_MASK(11)
CALC_MASK(12)
CALC_MASK(13)
CALC_MASK(14)
CALC_MASK(15)
vstore_lm_int32x16_mh(y + i, vload_lm_int32x16(x + i), mask);
}
for (int i = len_rounddown16; i < len; i++) {
y[i] = condition[i] ? x[i] : y[i];
}
mfence_lm();
}
template <typename T>
static __device__ void do_select(const int8_t* condition,
const T* x,
T* y,
int len) {}
template <>
__device__ void do_select<float16>(const int8_t* condition,
const float16* x,
float16* y,
int len) {
do_select_16(condition,
reinterpret_cast<const int16_t*>(x),
reinterpret_cast<int16_t*>(y),
len);
}
template <>
__device__ void do_select<float>(const int8_t* condition,
const float* x,
float* y,
int len) {
do_select_32(condition,
reinterpret_cast<const int32_t*>(x),
reinterpret_cast<int32_t*>(y),
len);
}
template <>
__device__ void do_select<int16_t>(const int8_t* condition,
const int16_t* x,
int16_t* y,
int len) {
do_select_16(condition, x, y, len);
}
template <>
__device__ void do_select<int32_t>(const int8_t* condition,
const int32_t* x,
int32_t* y,
int len) {
do_select_32(condition, x, y, len);
}
template <typename T>
__global__ void fast_where(
const int8_t* condition, const T* x, const T* y, T* z, int64_t len) {
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
#ifdef __XPU3__
const int buf_len = 1536 / sizeof(T);
#else
const int buf_len = 512 / sizeof(T);
#endif
__simd__ int8_t local_condition[buf_len];
__simd__ T local_x[buf_len];
__simd__ T local_y[buf_len];
int loop = 0;
for (int64_t i = tid * buf_len; i < len; i += nthreads * buf_len) {
int read_len = min(static_cast<int64_t>(buf_len), len - i);
GM2LM_ASYNC(condition + i, local_condition, read_len * sizeof(int8_t));
GM2LM_ASYNC(x + i, local_x, read_len * sizeof(T));
GM2LM(y + i, local_y, read_len * sizeof(T));
do_select<T>(local_condition, local_x, local_y, read_len);
LM2GM_ASYNC(local_y, z + i, read_len * sizeof(T));
mfence();
#ifndef __XPU3__
loop++;
if ((loop & 0xF) == 0) {
sync_all();
}
#endif
}
}
#define _XPU_DEF__FAST_WHERE_(DTYPE) \
template __global__ void fast_where<DTYPE>(const int8_t* condition, \
const DTYPE* x, \
const DTYPE* y, \
DTYPE* z, \
int64_t len);
_XPU_DEF__FAST_WHERE_(float16);
_XPU_DEF__FAST_WHERE_(float);
_XPU_DEF__FAST_WHERE_(int16_t);
_XPU_DEF__FAST_WHERE_(int32_t);
} // namespace plugin
} // namespace xpu2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {
template <typename T>
__attribute__((global)) void fast_where(
const int8_t* condition, const T* x, const T* y, T* z, int64_t len);
}
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper(Context* ctx,
const bool* condition,
const T* x,
const T* y,
T* z,
int64_t len) {
for (int64_t i = 0; i < len; i++) {
z[i] = condition[i] ? x[i] : y[i];
}
return SUCCESS;
}
template <>
int cpu_wrapper<float16>(Context* ctx,
const bool* condition,
const float16* x,
const float16* y,
float16* z,
int64_t len) {
std::vector<float> x_fp32(len);
std::vector<float> y_fp32(len);
std::vector<float> z_fp32(len);
int ret = cast<float16, float>(ctx, x, x_fp32.data(), len);
ret = cast<float16, float>(ctx, y, y_fp32.data(), len);
ret = cpu_wrapper<float>(
ctx, condition, x_fp32.data(), y_fp32.data(), z_fp32.data(), len);
ret = cast<float, float16>(ctx, z_fp32.data(), z, len);
WRAPPER_ASSERT_SUCCESS(ctx, ret);
return ret;
}
template <typename T>
static int xpu2_wrapper(Context* ctx,
const bool* condition,
const T* x,
const T* y,
T* z,
int64_t len) {
xpu2::plugin::fast_where<T><<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
reinterpret_cast<const int8_t*>(condition), x, y, z, len);
return SUCCESS;
}
template <typename T>
int fast_where(Context* ctx,
const bool* condition,
const T* x,
const T* y,
T* z,
int64_t len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_where", float);
WRAPPER_DUMP_PARAM5(ctx, condition, x, y, z, len);
WRAPPER_DUMP(ctx);
WRAPPER_ASSERT_GT(ctx, len, 0);
WRAPPER_CHECK_2PTRS(ctx, T, len, x, y);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(ctx, condition, x, y, z, len);
}
if (ctx->dev().type() == api::kXPU2) {
return xpu2_wrapper<T>(ctx, condition, x, y, z, len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int fast_where(Context*,
const bool* condition,
const float*,
const float*,
float*,
int64_t);
template int fast_where(Context*,
const bool* condition,
const float16*,
const float16*,
float16*,
int64_t);
template int fast_where(Context*,
const bool* condition,
const int16_t*,
const int16_t*,
int16_t*,
int64_t);
template int fast_where(Context*,
const bool* condition,
const int32_t*,
const int32_t*,
int32_t*,
int64_t);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册