Unverified commit fa8abeec authored by csy0225, committed by GitHub

[XPU] Fix instance_norm, conv2d_xpu, inplace optimizer bugs. (#52627)

Parent 710b664d
@@ -620,6 +620,15 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
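// The XPU instance_norm kernel reads Bias and Scale as float32, so report
// them as inputs the mixed-precision pass must not convert.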
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
}
......
@@ -36,6 +36,12 @@ bool InplaceOpVarPass::IsValidInplaceOp(
if (var_node->Name() != x_name) continue;
if (var_node->Var()->Persistable() || var_node->outputs.size() != 1)
return false;
// The op that produces in_var_node should not be feed.
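// A feed output holds data provided from outside the graph, so it is not safe to reuse it in place.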
for (auto* pre_op : var_node->inputs) {
if (pre_op->Op()->Type() == "feed") {
return false;
}
}
}
// in/out_var_node should not be used across multiple graphs.
......
@@ -416,7 +416,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"instance_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"instance_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"instance_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
{"label_smooth", XPUKernelSet({phi::DataType::FLOAT32})},
......
@@ -186,6 +186,5 @@ PD_REGISTER_KERNEL(assign_value,
int,
float,
double,
- int64_t,
- phi::dtype::float16) {}
int64_t) {}
#endif
@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/instance_norm_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
@@ -37,9 +38,31 @@ void InstanceNormKernel(const Context& dev_ctx,
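// y keeps the input dtype T; saved_mean and saved_var are always allocated as float32.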
dev_ctx.template Alloc<T>(y);
dev_ctx.template Alloc<float>(saved_mean);
dev_ctx.template Alloc<float>(saved_var);
// scale
const auto scale_ptr = scale.get_ptr();
- const auto bias_ptr = bias.get_ptr();
const float* scale_data_fp32 = nullptr;
DenseTensor scale_data;
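// If Scale is absent, use an all-ones float32 tensor of length c as the default.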
if (scale_ptr == nullptr) {
scale_data.Resize({c});
dev_ctx.template Alloc<float>(&scale_data);
phi::funcs::set_constant(dev_ctx, &scale_data, static_cast<float>(1));
scale_data_fp32 = scale_data.data<float>();
} else {
// No cast needed: Scale is already stored as float32.
scale_data_fp32 = scale_ptr->data<float>();
}
// bias
const float* bias_data_fp32 = nullptr;
const auto* bias_ptr = bias.get_ptr();
DenseTensor bias_data;
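// If Bias is absent, use an all-zeros float32 tensor of length c as the default.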
if (bias_ptr == nullptr) {
bias_data.Resize({c});
dev_ctx.template Alloc<float>(&bias_data);
phi::funcs::set_constant(dev_ctx, &bias_data, static_cast<float>(0));
bias_data_fp32 = bias_data.data<float>();
} else {
bias_data_fp32 = bias_ptr->data<float>();
}
int r = xpu::instance_norm(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -49,8 +72,8 @@ void InstanceNormKernel(const Context& dev_ctx,
h,
w,
epsilon,
- scale_ptr->data<float>(),
- bias_ptr->data<float>(),
scale_data_fp32,
bias_data_fp32,
saved_mean->data<float>(),
saved_var->data<float>(),
true);
@@ -60,5 +83,9 @@ void InstanceNormKernel(const Context& dev_ctx,
} // namespace phi
- PD_REGISTER_KERNEL(
- instance_norm, XPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
PD_REGISTER_KERNEL(instance_norm,
XPU,
ALL_LAYOUT,
phi::InstanceNormKernel,
float,
phi::dtype::float16) {}