From fa8abeec1645a45eefb7681bb1cf1dc7f2332b30 Mon Sep 17 00:00:00 2001
From: csy0225 <78470701+csy0225@users.noreply.github.com>
Date: Thu, 13 Apr 2023 10:15:55 +0800
Subject: [PATCH] [XPU] Fix instance_norm, conv2d_xpu, inplace optimizer bugs.
 (#52627)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../framework/ir/auto_mixed_precision_pass.cc |  9 +++++
 .../fluid/framework/ir/inplace_op_var_pass.cc |  6 +++
 paddle/phi/backends/xpu/xpu2_op_list.cc       |  3 +-
 paddle/phi/kernels/assign_kernel.cc           |  3 +-
 .../phi/kernels/xpu/instance_norm_kernel.cc   | 39 ++++++++++++++++---
 5 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
index 061e2432eed..fa570394d80 100644
--- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -620,6 +620,15 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
     if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
       return true;
     }
+  } else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
+    auto vecs = op_desc->Input("Bias");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("Scale");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
   }
 }
 
diff --git a/paddle/fluid/framework/ir/inplace_op_var_pass.cc b/paddle/fluid/framework/ir/inplace_op_var_pass.cc
index 0ccac637be3..5bbe980daab 100644
--- a/paddle/fluid/framework/ir/inplace_op_var_pass.cc
+++ b/paddle/fluid/framework/ir/inplace_op_var_pass.cc
@@ -36,6 +36,12 @@ bool InplaceOpVarPass::IsValidInplaceOp(
     if (var_node->Name() != x_name) continue;
     if (var_node->Var()->Persistable() || var_node->outputs.size() != 1)
       return false;
+    // The op type in front of in_var_node should not be feed.
+    for (auto* pre_op : var_node->inputs) {
+      if (pre_op->Op()->Type() == "feed") {
+        return false;
+      }
+    }
   }
 
   // in/out_var_node should be not used in multi graphs.
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 419bba4da74..ceab85cf551 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -416,7 +416,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
-    {"instance_norm", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"instance_norm",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"instance_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
     {"label_smooth", XPUKernelSet({phi::DataType::FLOAT32})},
diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc
index 09046ef4556..9aba3bcb78f 100644
--- a/paddle/phi/kernels/assign_kernel.cc
+++ b/paddle/phi/kernels/assign_kernel.cc
@@ -186,6 +186,5 @@ PD_REGISTER_KERNEL(assign_value,
                    int,
                    float,
                    double,
-                   int64_t,
-                   phi::dtype::float16) {}
+                   int64_t) {}
 #endif
diff --git a/paddle/phi/kernels/xpu/instance_norm_kernel.cc b/paddle/phi/kernels/xpu/instance_norm_kernel.cc
index 293397f66ee..1631d0ccbee 100644
--- a/paddle/phi/kernels/xpu/instance_norm_kernel.cc
+++ b/paddle/phi/kernels/xpu/instance_norm_kernel.cc
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/instance_norm_kernel.h"
 
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
@@ -37,9 +38,31 @@ void InstanceNormKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(y);
   dev_ctx.template Alloc<float>(saved_mean);
   dev_ctx.template Alloc<float>(saved_var);
-
+  // scale
   const auto scale_ptr = scale.get_ptr();
-  const auto bias_ptr = bias.get_ptr();
+  const float* scale_data_fp32 = nullptr;
+  DenseTensor scale_data;
+  if (scale_ptr == nullptr) {
+    scale_data.Resize({c});
+    dev_ctx.template Alloc<float>(&scale_data);
+    phi::funcs::set_constant(dev_ctx, &scale_data, static_cast<float>(1));
+    scale_data_fp32 = scale_data.data<float>();
+  } else {
+    // no need to cast
+    scale_data_fp32 = scale_ptr->data<float>();
+  }
+  // bias
+  const float* bias_data_fp32 = nullptr;
+  const auto* bias_ptr = bias.get_ptr();
+  DenseTensor bias_data;
+  if (bias_ptr == nullptr) {
+    bias_data.Resize({c});
+    dev_ctx.template Alloc<float>(&bias_data);
+    phi::funcs::set_constant(dev_ctx, &bias_data, static_cast<float>(0));
+    bias_data_fp32 = bias_data.data<float>();
+  } else {
+    bias_data_fp32 = bias_ptr->data<float>();
+  }
 
   int r = xpu::instance_norm(dev_ctx.x_context(),
                              reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -49,8 +72,8 @@ void InstanceNormKernel(const Context& dev_ctx,
                              h,
                              w,
                              epsilon,
-                             scale_ptr->data<float>(),
-                             bias_ptr->data<float>(),
+                             scale_data_fp32,
+                             bias_data_fp32,
                              saved_mean->data<float>(),
                              saved_var->data<float>(),
                              true);
@@ -60,5 +83,9 @@ void InstanceNormKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    instance_norm, XPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
+PD_REGISTER_KERNEL(instance_norm,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormKernel,
+                   float,
+                   phi::dtype::float16) {}
-- 
GitLab
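
Note on the instance_norm kernel change above: the old code dereferenced
scale_ptr->data<float>() and bias_ptr->data<float>() unconditionally, which
crashes when a model provides no Scale/Bias inputs. The patch instead
materializes channel-sized fp32 buffers filled with 1 (scale) and 0 (bias)
as defaults. The sketch below shows the same fallback pattern in isolation,
using only standard C++ (std::optional/std::vector) rather than Paddle's
DenseTensor API; the helper name DataOrDefault is invented for illustration.

#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

// Return a non-null fp32 pointer for an optional per-channel parameter.
// When the parameter is absent, fill `storage` with `default_value`
// (1.0f for scale, 0.0f for bias) and hand out that buffer instead,
// mirroring the scale_data/bias_data fallback in the patched kernel.
const float* DataOrDefault(const std::optional<std::vector<float>>& param,
                           std::size_t channels,
                           float default_value,
                           std::vector<float>* storage) {
  if (!param.has_value()) {
    storage->assign(channels, default_value);
    return storage->data();
  }
  return param->data();
}

int main() {
  const std::size_t c = 4;                  // number of channels
  std::optional<std::vector<float>> scale;  // absent, as when a model omits Scale
  std::optional<std::vector<float>> bias =
      std::vector<float>{0.1f, 0.2f, 0.3f, 0.4f};

  std::vector<float> scale_storage;
  std::vector<float> bias_storage;
  const float* scale_fp32 = DataOrDefault(scale, c, 1.0f, &scale_storage);
  const float* bias_fp32 = DataOrDefault(bias, c, 0.0f, &bias_storage);

  // Both pointers are now safe to pass unconditionally, the way the patch
  // passes scale_data_fp32/bias_data_fp32 into xpu::instance_norm(...).
  std::cout << scale_fp32[0] << " " << bias_fp32[0] << "\n";  // prints: 1 0.1
  return 0;
}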
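
Note on the inplace_op_var_pass change: a variable produced by a feed op is
bound to the executor's external input buffer, so renaming it in place would
break that binding; the added loop rejects such candidates. Below is a
minimal sketch of the guard with simplified Op/Var structs standing in for
Paddle's ir::Node (illustrative names only, not the real API).

#include <string>
#include <vector>

// Simplified stand-ins for Paddle's ir::Node; illustrative only.
struct Op {
  std::string type;
};

struct Var {
  std::string name;
  bool persistable = false;
  std::vector<const Op*> inputs;   // ops that produce this variable
  std::vector<const Op*> outputs;  // ops that consume this variable
};

// Mirrors InplaceOpVarPass::IsValidInplaceOp's per-variable checks: reject
// persistable or multiply-consumed variables, and (the new check in this
// patch) any variable whose producing op is a feed.
bool CanRenameInplace(const Var& in_var) {
  if (in_var.persistable || in_var.outputs.size() != 1) return false;
  for (const Op* pre_op : in_var.inputs) {
    if (pre_op->type == "feed") return false;
  }
  return true;
}

int main() {
  Op feed{"feed"};
  Op relu{"relu"};
  Var v{"x", false, {&feed}, {&relu}};  // produced by feed, consumed once
  return CanRenameInplace(v) ? 1 : 0;   // 0: rejected because producer is feed
}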