diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
index 061e2432eed1a3597a9e567e439f978b37dbf01b..fa570394d80f3bda8e3ceb4c4f1a71ec764851a0 100644
--- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -620,6 +620,15 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
     if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
       return true;
     }
+  } else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
+    auto vecs = op_desc->Input("Bias");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("Scale");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
   }
 }
 
diff --git a/paddle/fluid/framework/ir/inplace_op_var_pass.cc b/paddle/fluid/framework/ir/inplace_op_var_pass.cc
index 0ccac637be36c741f5b5729b71794635cf0e83ac..5bbe980daaba7ef7d525b83e60d10c91fcea0829 100644
--- a/paddle/fluid/framework/ir/inplace_op_var_pass.cc
+++ b/paddle/fluid/framework/ir/inplace_op_var_pass.cc
@@ -36,6 +36,12 @@ bool InplaceOpVarPass::IsValidInplaceOp(
     if (var_node->Name() != x_name) continue;
     if (var_node->Var()->Persistable() || var_node->outputs.size() != 1)
       return false;
+    // The op producing in_var_node must not be feed.
+    for (auto* pre_op : var_node->inputs) {
+      if (pre_op->Op()->Type() == "feed") {
+        return false;
+      }
+    }
   }
 
   // in/out_var_node should be not used in multi graphs.
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 419bba4da7417acf7c80fae7ec3fab09aed783b7..ceab85cf551d6bfe597480e7a92e65a557b04b9a 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -416,7 +416,8 @@ XPUOpMap& get_kl2_ops() {
       XPUKernelSet({phi::DataType::FLOAT32,
                     phi::DataType::INT32,
                     phi::DataType::INT64})},
-    {"instance_norm", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"instance_norm",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"instance_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
     {"label_smooth", XPUKernelSet({phi::DataType::FLOAT32})},
diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc
index 09046ef45565e3d1744805e13838be89a5be59eb..9aba3bcb78faf1639add0cd8bc301ce5a602862e 100644
--- a/paddle/phi/kernels/assign_kernel.cc
+++ b/paddle/phi/kernels/assign_kernel.cc
@@ -186,6 +186,5 @@ PD_REGISTER_KERNEL(assign_value,
                    int,
                    float,
                    double,
-                   int64_t,
-                   phi::dtype::float16) {}
+                   int64_t) {}
 #endif
diff --git a/paddle/phi/kernels/xpu/instance_norm_kernel.cc b/paddle/phi/kernels/xpu/instance_norm_kernel.cc
index 293397f66ee37def5f5953e3faccc39b9f0dc9b8..1631d0ccbeed8de53272cfd319a34c001d3ac934 100644
--- a/paddle/phi/kernels/xpu/instance_norm_kernel.cc
+++ b/paddle/phi/kernels/xpu/instance_norm_kernel.cc
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/instance_norm_kernel.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -37,9 +38,31 @@ void InstanceNormKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(y);
   dev_ctx.template Alloc<float>(saved_mean);
   dev_ctx.template Alloc<float>(saved_var);
-
+  // scale
   const auto scale_ptr = scale.get_ptr();
-  const auto bias_ptr = bias.get_ptr();
+  const float* scale_data_fp32 = nullptr;
+  DenseTensor scale_data;
+  if (scale_ptr == nullptr) {
+    scale_data.Resize({c});
+    dev_ctx.template Alloc<float>(&scale_data);
+    phi::funcs::set_constant(dev_ctx, &scale_data, static_cast<float>(1));
+    scale_data_fp32 = scale_data.data<float>();
+  } else {
+    // no need to cast
+    scale_data_fp32 = scale_ptr->data<float>();
+  }
+  // bias
+  const float* bias_data_fp32 = nullptr;
+  const auto* bias_ptr = bias.get_ptr();
+  DenseTensor bias_data;
+  if (bias_ptr == nullptr) {
+    bias_data.Resize({c});
+    dev_ctx.template Alloc<float>(&bias_data);
+    phi::funcs::set_constant(dev_ctx, &bias_data, static_cast<float>(0));
+    bias_data_fp32 = bias_data.data<float>();
+  } else {
+    bias_data_fp32 = bias_ptr->data<float>();
+  }
 
   int r = xpu::instance_norm(dev_ctx.x_context(),
                              reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -49,8 +72,8 @@ void InstanceNormKernel(const Context& dev_ctx,
                              h,
                              w,
                              epsilon,
-                             scale_ptr->data<float>(),
-                             bias_ptr->data<float>(),
+                             scale_data_fp32,
+                             bias_data_fp32,
                              saved_mean->data<float>(),
                              saved_var->data<float>(),
                              true);
@@ -60,5 +83,9 @@
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    instance_norm, XPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
+PD_REGISTER_KERNEL(instance_norm,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormKernel,
+                   float,
+                   phi::dtype::float16) {}
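
For context on the kernel change above: instance normalization over an NCHW tensor normalizes each (n, c) plane by its own mean and variance computed over H*W, then applies per-channel scale and bias. The XPU kernel keeps scale and bias in fp32 for both dtypes and materializes defaults of 1 and 0 via phi::funcs::set_constant when the optional Scale/Bias inputs are absent. Below is a minimal standalone CPU sketch of that contract, with no Paddle dependencies; the name instance_norm_ref and its vector-based API are hypothetical, not Paddle code.

#include <cmath>
#include <cstdio>
#include <vector>

// Reference instance_norm over an NCHW buffer. scale/bias are per-channel
// fp32 and fall back to 1/0 when empty, mirroring the kernel's defaults.
void instance_norm_ref(const std::vector<float>& x,
                       std::vector<float>* y,
                       int n,
                       int c,
                       int h,
                       int w,
                       float epsilon,
                       std::vector<float> scale = {},
                       std::vector<float> bias = {}) {
  if (scale.empty()) scale.assign(c, 1.0f);  // default Scale = 1
  if (bias.empty()) bias.assign(c, 0.0f);    // default Bias = 0
  const int plane = h * w;
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < c; ++j) {
      const float* xp = x.data() + (i * c + j) * plane;
      float* yp = y->data() + (i * c + j) * plane;
      // Statistics are computed per (n, c) instance over the H*W plane.
      float mean = 0.0f;
      for (int k = 0; k < plane; ++k) mean += xp[k];
      mean /= plane;
      float var = 0.0f;
      for (int k = 0; k < plane; ++k) var += (xp[k] - mean) * (xp[k] - mean);
      var /= plane;
      const float inv_std = 1.0f / std::sqrt(var + epsilon);
      for (int k = 0; k < plane; ++k) {
        yp[k] = (xp[k] - mean) * inv_std * scale[j] + bias[j];
      }
    }
  }
}

int main() {
  const int n = 1, c = 2, h = 2, w = 2;
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> y(x.size());
  instance_norm_ref(x, &y, n, c, h, w, 1e-5f);  // no Scale/Bias inputs
  for (float v : y) std::printf("%.4f ", v);  // each plane normalized to ~N(0, 1)
  std::printf("\n");
  return 0;
}

Keeping Scale/Bias in fp32 also explains the InputVarsNotConvert addition above: the auto-mixed-precision pass leaves instance_norm's Bias and Scale inputs unconverted, so the kernel can keep reading them with data<float>() under the new float16 registration.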