Unverified commit fa8abeec, authored by csy0225, committed by GitHub

[XPU] Fix instance_norm、conv2d_xpu、inplace optimizer bugs. (#52627)

Parent 710b664d
@@ -620,6 +620,15 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
       if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
         return true;
       }
+    } else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
+      auto vecs = op_desc->Input("Bias");
+      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+        return true;
+      }
+      vecs = op_desc->Input("Scale");
+      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+        return true;
+      }
     }
   }
...
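Note: InputVarsNotConvert is what auto_mixed_precision_pass consults to decide which input variables of an op must stay in FP32 when the op itself is converted to FP16; the new branch adds instance_norm's Bias and Scale inputs to that list, matching the XPU kernel below, which always reads scale/bias as float. A minimal sketch of the check pattern, using a hypothetical container in place of Paddle's OpDesc API:

// Hypothetical stand-in for OpDesc::Input(): slot name -> argument names.
#include <algorithm>
#include <map>
#include <string>
#include <vector>

using OpInputs = std::map<std::string, std::vector<std::string>>;

// Returns true if var_name feeds a slot that must remain FP32.
bool InputMustStayFp32(const std::string& op_type,
                       const OpInputs& inputs,
                       const std::string& var_name) {
  if (op_type == "instance_norm") {
    for (const char* slot : {"Bias", "Scale"}) {
      auto it = inputs.find(slot);
      if (it == inputs.end()) continue;
      const auto& vecs = it->second;
      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
        return true;
      }
    }
  }
  return false;
}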
@@ -36,6 +36,12 @@ bool InplaceOpVarPass::IsValidInplaceOp(
     if (var_node->Name() != x_name) continue;
     if (var_node->Var()->Persistable() || var_node->outputs.size() != 1)
       return false;
+    // The op type in front of in_var_node should not be feed.
+    for (auto* pre_op : var_node->inputs) {
+      if (pre_op->Op()->Type() == "feed") {
+        return false;
+      }
+    }
   }
   // in/out_var_node should be not used in multi graphs.
...
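Note: the added loop makes IsValidInplaceOp reject a candidate when the variable to be reused is produced by a feed op, presumably because feed outputs are externally provided input tensors that the inplace rewrite must not rename or overwrite. A self-contained sketch of the predecessor check, with a hypothetical Node struct standing in for Paddle's ir::Node:

#include <string>
#include <vector>

struct Node {
  std::string op_type;        // type of the op this node represents
  std::vector<Node*> inputs;  // ops that produce this variable
};

// True if any producer of the variable is a feed op.
bool ProducedByFeed(const Node& var_node) {
  for (const Node* pre_op : var_node.inputs) {
    if (pre_op->op_type == "feed") return true;
  }
  return false;
}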
@@ -416,7 +416,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
-    {"instance_norm", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"instance_norm",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"instance_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
     {"label_smooth", XPUKernelSet({phi::DataType::FLOAT32})},
...
@@ -186,6 +186,5 @@ PD_REGISTER_KERNEL(assign_value,
                    int,
                    float,
                    double,
-                   int64_t,
-                   phi::dtype::float16) {}
+                   int64_t) {}
 #endif
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/instance_norm_kernel.h"

 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_function.h"

 namespace phi {
@@ -37,9 +38,31 @@ void InstanceNormKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(y);
   dev_ctx.template Alloc<float>(saved_mean);
   dev_ctx.template Alloc<float>(saved_var);

+  // scale
   const auto scale_ptr = scale.get_ptr();
-  const auto bias_ptr = bias.get_ptr();
+  const float* scale_data_fp32 = nullptr;
+  DenseTensor scale_data;
+  if (scale_ptr == nullptr) {
+    scale_data.Resize({c});
+    dev_ctx.template Alloc<float>(&scale_data);
+    phi::funcs::set_constant(dev_ctx, &scale_data, static_cast<float>(1));
+    scale_data_fp32 = scale_data.data<float>();
+  } else {
+    // no need to cast
+    scale_data_fp32 = scale_ptr->data<float>();
+  }
+
+  // bias
+  const float* bias_data_fp32 = nullptr;
+  const auto* bias_ptr = bias.get_ptr();
+  DenseTensor bias_data;
+  if (bias_ptr == nullptr) {
+    bias_data.Resize({c});
+    dev_ctx.template Alloc<float>(&bias_data);
+    phi::funcs::set_constant(dev_ctx, &bias_data, static_cast<float>(0));
+    bias_data_fp32 = bias_data.data<float>();
+  } else {
+    bias_data_fp32 = bias_ptr->data<float>();
+  }

   int r = xpu::instance_norm(dev_ctx.x_context(),
                              reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -49,8 +72,8 @@ void InstanceNormKernel(const Context& dev_ctx,
                              h,
                              w,
                              epsilon,
-                             scale_ptr->data<float>(),
-                             bias_ptr->data<float>(),
+                             scale_data_fp32,
+                             bias_data_fp32,
                              saved_mean->data<float>(),
                              saved_var->data<float>(),
                              true);
@@ -60,5 +83,9 @@ void InstanceNormKernel(const Context& dev_ctx,
 }  // namespace phi

-PD_REGISTER_KERNEL(
-    instance_norm, XPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
+PD_REGISTER_KERNEL(instance_norm,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormKernel,
+                   float,
+                   phi::dtype::float16) {}
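Note: instance normalization normalizes each (n, c) plane of an NCHW tensor with its own mean and variance and then applies a per-channel affine transform; when Scale or Bias is absent the affine step degenerates to scale 1 and bias 0, which is exactly what the nullptr branches above materialize with set_constant. An illustrative CPU reference (not the XPU kernel):

#include <cmath>
#include <vector>

// y[n][c][h][w] = (x - mean(n,c)) / sqrt(var(n,c) + epsilon) * scale[c] + bias[c]
// scale/bias may be nullptr, in which case they default to 1 and 0.
void InstanceNormRef(const std::vector<float>& x, std::vector<float>* y,
                     int n, int c, int h, int w, float epsilon,
                     const float* scale, const float* bias) {
  const int hw = h * w;
  for (int ni = 0; ni < n; ++ni) {
    for (int ci = 0; ci < c; ++ci) {
      const float* xp = x.data() + (ni * c + ci) * hw;
      float* yp = y->data() + (ni * c + ci) * hw;
      float mean = 0.f, var = 0.f;
      for (int i = 0; i < hw; ++i) mean += xp[i];
      mean /= hw;
      for (int i = 0; i < hw; ++i) var += (xp[i] - mean) * (xp[i] - mean);
      var /= hw;
      const float s = scale ? scale[ci] : 1.f;
      const float b = bias ? bias[ci] : 0.f;
      const float inv_std = 1.f / std::sqrt(var + epsilon);
      for (int i = 0; i < hw; ++i) yp[i] = (xp[i] - mean) * inv_std * s + b;
    }
  }
}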