Unverified commit ac75a9a6, authored by Yuanle Liu, committed by GitHub

[Paddle Inference] fix mixed precision diff (#49475)

Parent 04aa80e6
@@ -78,11 +78,11 @@ inline bool VarNodeHasDtype(Node* var_node) {
          (type == VarType::VOCAB);
 }
 
-inline bool IsFloatType(VarType::Type type) {
+inline bool IsFP32AndFP64(VarType::Type type) {
   return (type == VarType::FP64) || (type == VarType::FP32);
 }
 
-inline bool IsHalfType(VarType::Type type) {
+inline bool IsFP16AndBFP16(VarType::Type type) {
   return (type == VarType::FP16) || (type == VarType::BF16);
 }
@@ -159,23 +159,14 @@ bool OpSupportPrecision(const std::string& op_type,
 // The set of ops that support fp16 calculation and are considered
 // numerically-dangerous, slower and whose effects may also be observed in
 // downstream ops.
+// ref to python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
 void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
   black_list_.insert({
       // numerically-dangerous
-      "acos",
-      "asin",
-      "cosh",
-      "tan",
       "exp",
-      "expm1",
       "square",
       "log",
-      "log2",
-      "log10",
-      "log1p",
-      "logsumexp",
       "mean",
-      "rsqrt",
       "sum",
       "cos_sim",
       "softmax_with_cross_entropy",
@@ -271,6 +262,9 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
   VLOG(4) << "InsertCastOp done";
   RestoreOpOriginType();
   VLOG(4) << "RestoreOpOriginType done";
+  LOG(INFO) << "The number of ops run at low precision ["
+            << op_run_low_precision_.size() << "/" << op_original_type_.size()
+            << "]";
 }
 
 void AutoMixedPrecisionPass::SetOpUniqueType() const {
@@ -314,11 +308,26 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
   for (const auto& nodes : all_op_nodes_) {
     for (auto* op_node : nodes) {
       auto op_type = op_node->Op()->Type();
+
+      if (op_node->Op()->HasAttr("in_dtype")) {
+        auto* var_node = op_node->inputs[0];
+        auto* real_var_node = real_vars_[var_node->Var()->Name()];
+        if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
+          op_node->Op()->SetAttr(
+              "in_dtype",
+              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
+          op_node->Op()->Flush();
+          VLOG(4) << "process op with in_dtype attr: " << op_type << " ( "
+                  << static_cast<int>(real_var_node->Var()->GetDataType())
+                  << " --->" << static_cast<int>(low_precision_) << " )";
+        }
+      }
+
       if (op_run_low_precision_.count(op_type) == 0) continue;
 
       if (op_node->Op()->HasAttr("dtype")) {
         auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
-        if (IsFloatType(static_cast<VarType::Type>(dtype))) {
+        if (IsFP32AndFP64(static_cast<VarType::Type>(dtype))) {
           op_node->Op()->SetAttr(
               "dtype",
               static_cast<int>(framework::TransToProtoVarType(low_precision_)));
@@ -326,10 +335,9 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
           VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
                   << " --->" << static_cast<int>(low_precision_) << " )";
         }
-      }
-      if (op_node->Op()->HasAttr("out_dtype")) {
+      } else if (op_node->Op()->HasAttr("out_dtype")) {
         auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
-        if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
+        if (IsFP32AndFP64(static_cast<VarType::Type>(out_dtype))) {
           op_node->Op()->SetAttr(
               "out_dtype",
               static_cast<int>(framework::TransToProtoVarType(low_precision_)));
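Taken together, the two hunks above make ProcessOpWithDtypeAttr keep dtype-carrying attributes in sync with the chosen precision: the new "in_dtype" branch follows the dtype the real input variable already has, while the "dtype"/"out_dtype" branches rewrite FP32/FP64 attribute values down to the target low precision for ops selected to run there. A minimal sketch of that rewrite rule, with a hypothetical enum standing in for the proto dtype ids (not the framework's types):

```cpp
#include <iostream>

// Hypothetical stand-ins for the proto dtype ids the pass manipulates.
enum class ProtoDtype { FP64, FP32, FP16, BF16 };

// Returns the dtype an op attribute should carry after the pass, assuming the
// op was selected to run at low precision: full-precision float attributes are
// rewritten to the target low precision, everything else is left untouched.
ProtoDtype RewriteDtypeAttr(ProtoDtype attr, ProtoDtype low_precision) {
  if (attr == ProtoDtype::FP32 || attr == ProtoDtype::FP64) {
    return low_precision;
  }
  return attr;
}

int main() {
  // e.g. a fill_constant-style op whose "dtype" attribute is FP32,
  // under a pass run targeting FP16.
  ProtoDtype rewritten = RewriteDtypeAttr(ProtoDtype::FP32, ProtoDtype::FP16);
  std::cout << (rewritten == ProtoDtype::FP16 ? "rewritten to FP16" : "kept")
            << "\n";
  return 0;
}
```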
@@ -358,37 +366,55 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
       if (op_node->Op()->HasAttr("dtype")) {
         auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
-        support_low_precision = support_low_precision &&
-                                IsFloatType(static_cast<VarType::Type>(dtype));
+        support_low_precision =
+            support_low_precision &&
+            IsFP32AndFP64(static_cast<VarType::Type>(dtype));
       } else if (op_node->Op()->HasAttr("out_dtype")) {
         auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
         support_low_precision =
             support_low_precision &&
-            IsFloatType(static_cast<VarType::Type>(out_dtype));
-      } else {
-        // if op's input var and output var is not dense tensor, the op should
-        // not run at low precision.
-        for (auto* in_var_node : op_node->inputs) {
-          CHECK_EQ(in_var_node->IsVar(), true);
-          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
-          if (real_in_var_node->Var()->Persistable()) continue;
-
-          support_low_precision =
-              support_low_precision &&
-              (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
-        }
-
-        for (auto* out_var_node : op_node->outputs) {
-          CHECK_EQ(out_var_node->IsVar(), true);
-          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
-          if (real_out_var_node->Var()->Persistable()) continue;
-
-          support_low_precision =
-              support_low_precision &&
-              (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
-        }
-      }
+            IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
+      }
+
+      // If scale op's "scale" and "bias" attr value exceed the range of fp16
+      // and bf16, it cannot run at low precision.
+      if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
+        auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
+        auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
+        if (low_precision_ == phi::DataType::FLOAT16) {
+          support_low_precision =
+              support_low_precision &&
+              phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
+              phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
+        } else if (low_precision_ == phi::DataType::BFLOAT16) {
+          support_low_precision =
+              support_low_precision &&
+              phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(scale)) &&
+              phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
+        }
+      }
+
+      // if op's input var and output var is not dense tensor, the op should
+      // not run at low precision.
+      for (auto* in_var_node : op_node->inputs) {
+        CHECK_EQ(in_var_node->IsVar(), true);
+        auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
+        if (real_in_var_node->Var()->Persistable()) continue;
+
+        support_low_precision =
+            support_low_precision &&
+            (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+      }
+
+      for (auto* out_var_node : op_node->outputs) {
+        CHECK_EQ(out_var_node->IsVar(), true);
+        auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
+        if (real_out_var_node->Var()->Persistable()) continue;
+
+        support_low_precision =
+            support_low_precision &&
+            (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+      }
 
       if (support_low_precision) {
         op_run_low_precision_.insert(op_type);
         VLOG(4) << "support precision: " << op_type << " run at low precision";
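The new scale-op guard refuses low precision whenever the "scale" or "bias" attribute would not survive the cast as a finite fp16/bf16 value. A rough standalone illustration of the fp16 case (the largest finite fp16 value is 65504); the helper below only approximates what phi::dtype::isfinite(static_cast<phi::dtype::float16>(v)) checks and is not the framework API:

```cpp
#include <cmath>
#include <iostream>

// Approximate check: a float survives conversion to fp16 as a finite value
// only if it is finite itself and its magnitude stays within fp16's largest
// finite value (65504). This ignores rounding at the very edge of the range.
bool FitsInFp16(float v) {
  constexpr float kFp16Max = 65504.0f;
  return std::isfinite(v) && std::fabs(v) <= kFp16Max;
}

int main() {
  std::cout << FitsInFp16(2.0f) << "\n";    // 1: the scale op may run in fp16
  std::cout << FitsInFp16(1.0e5f) << "\n";  // 0: overflows fp16, keep fp32
  return 0;
}
```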
@@ -438,7 +464,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
       }
 
       // when op_1 only support cpu kernel. if op_2's intput var is op_1's
-      // output var, then op_2 should not run half.
+      // output var, then op_2 should not run at low precision.
       if (GetOpOriginalType(op_type) != "feed" &&
           !GpuKernelSupportPrecision(GetOpOriginalType(op_type),
                                      phi::DataType::FLOAT32)) {
@@ -596,7 +622,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
         auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
         auto in_var_name = real_in_var_node->Var()->Name();
 
-        if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
+        if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
         if (!VarNodeHasDtype(real_in_var_node)) continue;
         if (InputVarsNotConvert(op_node, in_var_name)) continue;
@@ -615,7 +641,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
         auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
         auto out_var_name = real_out_var_node->Var()->Name();
 
-        if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
+        if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
         if (!VarNodeHasDtype(real_out_var_node)) continue;
         if (OutputVarsNotConvert(op_node, out_var_name)) continue;
@@ -655,7 +681,7 @@ void AutoMixedPrecisionPass::ConvertWeightsData() const {
   auto var_names = scope->LocalVarNames();
   for (const auto& var_name : var_names) {
     if (vars_convert_to_low_precision_.count(var_name)) {
-      VLOG(4) << var_name << "'s data type was convert to half";
+      VLOG(4) << var_name << "'s data type was convert to low precision";
       auto* var = scope->FindLocalVar(var_name);
       CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
@@ -682,16 +708,18 @@ void AutoMixedPrecisionPass::ConvertWeightsData() const {
           }
         }
       } else if (low_precision_ == phi::DataType::BFLOAT16) {
-        auto* half_data =
+        auto* low_precision_data =
             low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
                 phi::CPUPlace{});
         for (int64_t i = 0; i < origin_tensor->numel(); i++) {
           if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
             auto* origin_data = origin_tensor->data<double>();
-            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+            low_precision_data[i] =
+                static_cast<phi::dtype::bfloat16>(origin_data[i]);
           } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
             auto* origin_data = origin_tensor->data<float>();
-            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+            low_precision_data[i] =
+                static_cast<phi::dtype::bfloat16>(origin_data[i]);
           }
         }
       }
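Background for the conversion loop above: bfloat16 keeps float32's 8-bit exponent and truncates the mantissa to 7 bits, so the cast mostly costs precision rather than range. A simplified, framework-free sketch of the storage format (plain truncation; phi::dtype::bfloat16's real cast may round):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Simplified float32 -> bfloat16 conversion by truncation: keep only the
// upper 16 bits of the IEEE-754 float32 bit pattern. This is an illustration
// of the format, not phi's conversion routine.
uint16_t FloatToBfloat16Bits(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// Back to float32 by padding the low 16 bits with zeros.
float Bfloat16BitsToFloat(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float v;
  std::memcpy(&v, &bits, sizeof(v));
  return v;
}

int main() {
  float w = 0.123456789f;
  float roundtrip = Bfloat16BitsToFloat(FloatToBfloat16Bits(w));
  std::cout << w << " -> " << roundtrip << "\n";  // precision loss, same range
  return 0;
}
```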
@@ -731,25 +759,44 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
       VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
               << " with type " << in_var_type;
 
-      if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
-        DoInsertCastOp(subgraphes_[i],
-                       in_var_node,
-                       op_node,
-                       in_var_type,
-                       framework::TransToProtoVarType(low_precision_),
-                       block_desc,
-                       &suffix,
-                       &cache);
-      } else if (IsHalfType(in_var_type) &&
+      if (IsFP32AndFP64(in_var_type) &&
+          op_run_low_precision_.count(op_type)) {
+        auto to_type = framework::TransToProtoVarType(low_precision_);
+        auto* prev_op =
+            in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
+        if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
+          in_var_node->Var()->SetDataType(to_type);
+          prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+          prev_op->Op()->Flush();
+        } else {
+          DoInsertCastOp(subgraphes_[i],
+                         in_var_node,
+                         op_node,
+                         in_var_type,
+                         to_type,
+                         block_desc,
+                         &suffix,
+                         &cache);
+        }
+      } else if (IsFP16AndBFP16(in_var_type) &&
                  op_run_low_precision_.count(op_type) == 0) {
-        DoInsertCastOp(subgraphes_[i],
-                       in_var_node,
-                       op_node,
-                       in_var_type,
-                       VarType::FP32,
-                       block_desc,
-                       &suffix,
-                       &cache);
+        auto to_type = VarType::FP32;
+        auto* prev_op =
+            in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
+        if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
+          in_var_node->Var()->SetDataType(to_type);
+          prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+          prev_op->Op()->Flush();
+        } else {
+          DoInsertCastOp(subgraphes_[i],
+                         in_var_node,
+                         op_node,
+                         in_var_type,
+                         to_type,
+                         block_desc,
+                         &suffix,
+                         &cache);
+        }
       }
     }
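The reworked InsertCastOp branches above avoid stacking casts: when the variable's producer is already a cast op, the pass retargets that op's "out_dtype" (and the variable's dtype) instead of inserting a second cast through DoInsertCastOp. A small sketch of that decision, with hypothetical structs standing in for ir::Node/OpDesc:

```cpp
#include <iostream>
#include <string>

// Hypothetical, minimal stand-ins for the graph types the pass manipulates.
struct Op {
  std::string type;
  int out_dtype;  // illustrative dtype id
};

struct Var {
  int dtype;          // illustrative dtype id
  Op* producer;       // the op that writes this variable; nullptr if none
};

// If the producer is already a cast, fold the conversion into it; otherwise
// report that a fresh cast op would have to be inserted (DoInsertCastOp in
// the real pass).
bool FoldIntoExistingCast(Var* var, int to_dtype) {
  if (var->producer != nullptr && var->producer->type == "cast") {
    var->dtype = to_dtype;
    var->producer->out_dtype = to_dtype;
    return true;   // no extra cast needed
  }
  return false;    // caller inserts a new cast op
}

int main() {
  Op cast_op{"cast", /*out_dtype=*/0};  // illustrative ids only
  Var v{/*dtype=*/0, &cast_op};
  bool folded = FoldIntoExistingCast(&v, /*to_dtype=*/1);
  std::cout << (folded ? "retargeted existing cast" : "insert new cast")
            << "\n";
  return 0;
}
```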
...
@@ -164,7 +164,7 @@ TEST(Ernie_gpu_fp16_no_ir, compare_results) {
     }
     float *result = reinterpret_cast<float *>(output.data.data());
     for (size_t j = 0; j < outputs_size; ++j) {
-      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 8e-3);
     }
   }
 }
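The tightened thresholds here and in the hunks below are absolute tolerances: EXPECT_NEAR(val1, val2, abs_error) fails once |val1 - val2| exceeds abs_error, so dropping from 5e-2 to 8e-3 asserts that each fp16 output element now stays within 0.008 of the fp32 reference. A minimal reminder of the semantics (illustrative test, not part of this suite):

```cpp
#include <gtest/gtest.h>

TEST(ToleranceSemantics, AbsoluteNotRelative) {
  // Passes: |1.000 - 1.007| ~= 0.007 <= 8e-3.
  EXPECT_NEAR(1.000f, 1.007f, 8e-3);
  // Passes under the older, looser 5e-2 bound: |0.50 - 0.53| = 0.03 <= 0.05.
  EXPECT_NEAR(0.50f, 0.53f, 5e-2);
}
```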
@@ -175,8 +175,6 @@ TEST(Ernie_gpu_fp16_with_ir, compare_results) {
   config.SetModel(FLAGS_infer_model);
   config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
   config.SwitchIrOptim(true);
-  // The fc_fuse_pass has diff, which will be repaired later.
-  config.pass_builder()->DeletePass("fc_fuse_pass");
   // There is a problem with the model itself, which has nothing to do with
   // constant_folding_pass.
   config.pass_builder()->DeletePass("constant_folding_pass");
@@ -206,7 +204,7 @@ TEST(Ernie_gpu_fp16_with_ir, compare_results) {
     }
     float *result = reinterpret_cast<float *>(output.data.data());
     for (size_t j = 0; j < outputs_size; ++j) {
-      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 2e-2);
     }
   }
 }
@@ -243,7 +241,7 @@ TEST(Ernie_gpu_bf16_no_ir, compare_results) {
     }
     float *result = reinterpret_cast<float *>(output.data.data());
     for (size_t j = 0; j < outputs_size; ++j) {
-      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 1e-2);
     }
   }
 }
@@ -254,8 +252,6 @@ TEST(Ernie_gpu_bf16_with_ir, compare_results) {
   config.SetModel(FLAGS_infer_model);
   config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
   config.SwitchIrOptim(true);
-  // The fc_fuse_pass has diff, which will be repaired later.
-  config.pass_builder()->DeletePass("fc_fuse_pass");
   // There is a problem with the model itself, which has nothing to do with
   // constant_folding_pass.
   config.pass_builder()->DeletePass("constant_folding_pass");
@@ -285,7 +281,7 @@ TEST(Ernie_gpu_bf16_with_ir, compare_results) {
     }
     float *result = reinterpret_cast<float *>(output.data.data());
     for (size_t j = 0; j < outputs_size; ++j) {
-      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-3);
     }
   }
 }
...
@@ -223,13 +223,7 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data,
 
       // For layer_norm, reduce to calculate mean and std
       sum_i += static_cast<float>(tmp_3);
-#if defined(PADDLE_WITH_CUDA) && __CUDA_ARCH__ >= 530
-      square_sum_i += static_cast<float>(__hmul(tmp_3, tmp_3));
-#elif defined(PADDLE_WITH_CUDA)
       square_sum_i += static_cast<float>(tmp_3) * static_cast<float>(tmp_3);
-#else
-      square_sum_i += static_cast<float>(tmp_3 * tmp_3);
-#endif
     }
     auto pair = BlockReduce(temp_storage)
                     .Reduce(PairForLayerNorm<float>(sum_i, square_sum_i),
@@ -282,9 +276,9 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data,
       half tmp_0 = __hdiv(__hsub(save_ptr[save_index], mean_i), std_i);
       half tmp_1 = scale ? __hmul(scale[j], tmp_0) : tmp_0;
 #else
-      half tmp_0 = static_cast<float>(static_cast<float>(save_ptr[save_index]) +
-                                      static_cast<float>(mean_i) /
-                                      static_cast<float>(std_i));
+      half tmp_0 = static_cast<half>(static_cast<float>(save_ptr[save_index]) -
+                                     static_cast<float>(mean_i) /
+                                     static_cast<float>(std_i));
       half tmp_1 = scale ? static_cast<half>(static_cast<float>(scale[j]) *
                                              static_cast<float>(tmp_0))
                          : tmp_0;
...
@@ -164,7 +164,7 @@ __global__ void bias_relu_v4_half2(const int num,
       data_vec[unroll_idx] = __hmax2(__half2(0, 0), data_vec[unroll_idx]);
 #elif __CUDA_ARCH__ >= 530
       data_vec[unroll_idx] = __hmul2(
-          __hgt2(__half2(0, 0), data_vec[unroll_idx]), data_vec[unroll_idx]);
+          __hgt2(data_vec[unroll_idx], __half2(0, 0)), data_vec[unroll_idx]);
 #else
       data_vec[unroll_idx].x =
           static_cast<int>(static_cast<float>(data_vec[unroll_idx].x) > 0) *
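The operand swap above is the actual fix: __hgt2(a, b) yields 1.0 in each half lane where a > b, so the ReLU mask must come from (x > 0), not (0 > x), before it multiplies the data. The scalar equivalent of the corrected computation:

```cpp
#include <iostream>

// Scalar analogue of the fixed half2 code: relu(x) = (x > 0 ? 1 : 0) * x.
// The old operand order computed (0 > x ? 1 : 0) * x, which keeps negative
// values and zeroes out positive ones, the opposite of ReLU.
float ReluByMask(float x) {
  float mask = (x > 0.0f) ? 1.0f : 0.0f;
  return mask * x;
}

int main() {
  std::cout << ReluByMask(2.5f) << " " << ReluByMask(-3.0f) << "\n";  // 2.5 0
  return 0;
}
```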
...