Unverified commit 0a59825e authored by wz1qqx, committed by GitHub

[XPU] Optimize fp16 xpu models (#53523)

Parent 186f5e0f
......@@ -34,6 +34,25 @@ class Scope;
} // namespace framework
} // namespace paddle
namespace {
template <typename T1, typename T2>
void ConvertTensorType(phi::DenseTensor* tensor) {
phi::DenseTensor tmp_tensor;
tmp_tensor.set_type(phi::CppTypeToDataType<T2>::Type());
tmp_tensor.Resize(tensor->dims());
auto* tmp_data = tmp_tensor.mutable_data<T2>(paddle::platform::CPUPlace());
auto* data = tensor->mutable_data<T1>(paddle::platform::CPUPlace());
for (int i = 0; i < tensor->numel(); i++) {
tmp_data[i] = static_cast<T2>(data[i]);
}
tensor->clear();
paddle::framework::TensorCopySync(
tmp_tensor, paddle::platform::CPUPlace(), tensor);
}
} // namespace
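The helper above round-trips a persistable tensor through a CPU-side temporary so that fp16 weights can be recomputed in full fp32 precision. A minimal usage sketch follows; RecomputeFilterInFp32 is a hypothetical wrapper, not part of this pass, and assumes the filter is already CPU-resident as in the fuse pass below:

// Hypothetical wrapper (illustration only): convert an fp16 filter to fp32,
// mutate it, then restore the original dtype.
void RecomputeFilterInFp32(phi::DenseTensor* filter_t) {
  const bool was_fp16 = filter_t->dtype() == phi::DataType::FLOAT16;
  if (was_fp16) {
    // fp16 -> fp32 so the BN fold runs in full precision.
    ConvertTensorType<phi::dtype::float16, float>(filter_t);
  }
  // ... fold BN scale/bias into the fp32 filter here ...
  if (was_fp16) {
    // fp32 -> fp16: restore the model's original weight dtype.
    ConvertTensorType<float, phi::dtype::float16>(filter_t);
  }
}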
namespace paddle {
namespace framework {
namespace ir {
......@@ -157,15 +176,23 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
if (with_bn_) {
ew_bias_add_out->assert_is_op_input("batch_norm", "X");
bn_bias = pattern->NewNode(bn_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Bias")
->assert_has_n_outputs(1);
bn_mean = pattern->NewNode(bn_mean_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Mean")
->assert_has_n_outputs(1);
bn_scale = pattern->NewNode(bn_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Scale")
->assert_has_n_outputs(1);
bn_var = pattern->NewNode(bn_var_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Variance")
->assert_has_n_outputs(1);
bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm");
......@@ -420,13 +447,17 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
// recompute bias and weight for conv2d_xpu op
auto* filter_t =
scope->FindVar(conv_filter->Name())->GetMutable<phi::DenseTensor>();
// conv_filter fp16 --> fp32
auto tensor_type = filter_t->dtype();
if (tensor_type == phi::DataType::FLOAT16) {
ConvertTensorType<float16, float>(filter_t);
}
auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_conv_bias;
bool has_branch = with_branch_x || with_branch_y;
// Create conv_fusion_bias (conv bias) variable
Node* fusion_bias_node = nullptr;
if (has_bias) {
if (ew_bias_add != nullptr) {
if (with_conv_bias) {
auto* ew_bias_add_y_t = scope->FindVar(ew_bias_add_y->Name())
->GetMutable<phi::DenseTensor>();
auto ew_bias_add_y_dims = ew_bias_add_y_t->dims();
......@@ -439,7 +470,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
filter_dims[0]));
PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node);
}
if (bn != nullptr) {
if (with_bn) {
auto bn_bias_t =
scope->Var(bn_bias->Name())->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(filter_dims[0],
......@@ -469,7 +500,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
auto filter_len = filter_t->numel();
auto filter_stride = filter_len / mean_len;
float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon"));
if (fusion_bias_node == nullptr) { // prev node is conv
if (!with_conv_bias) { // prev node is conv
PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node);
}
auto fusion_bias_t = scope->Var(fusion_bias_node->Name())
......@@ -477,10 +508,10 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
float* fusion_bias_ptr =
fusion_bias_t->mutable_data<float>(paddle::platform::CPUPlace());
// recompute bias and weights
if (ew_bias_add == nullptr) {
if (!with_conv_bias) { // prev node is conv
for (int i = 0; i < mean_len; ++i) {
bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon);
fusion_bias_ptr[i] += (0.f - bn_mean_ptr[i]) * bn_scale_ptr[i];
fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i];
for (int j = 0; j < filter_stride; j++) {
filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i];
}
......@@ -488,21 +519,25 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
} else {
for (int i = 0; i < mean_len; ++i) {
bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon);
bn_bias_ptr[i] +=
fusion_bias_ptr[i] =
bn_bias_ptr[i] +
(fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i];
for (int j = 0; j < filter_stride; j++) {
filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i];
}
}
memcpy(fusion_bias_ptr, bn_bias_ptr, mean_len * sizeof(float));
}
}
}
if (tensor_type == phi::DataType::FLOAT16) {
ConvertTensorType<float, float16>(filter_t);
}
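For reference, the two loops above fold BatchNorm into the conv per output channel i: scale_i = gamma_i / sqrt(var_i + eps), bias'_i = beta_i + (bias_i - mean_i) * scale_i (with bias_i = 0 when there is no conv bias), and every filter element of channel i is multiplied by scale_i. A standalone CPU restatement on raw fp32 arrays (a sketch; FoldBnChannel is illustrative, not part of the pass):

#include <cmath>  // std::sqrt

// Fold one BatchNorm channel into a conv filter/bias pair (fp32 arrays).
void FoldBnChannel(float* filter, float* bias, const float* gamma,
                   const float* beta, const float* mean, const float* var,
                   float eps, int c, int filter_stride) {
  const float s = gamma[c] / std::sqrt(var[c] + eps);
  bias[c] = beta[c] + (bias[c] - mean[c]) * s;  // bias[c] == 0.f without conv bias
  for (int j = 0; j < filter_stride; ++j) {
    filter[c * filter_stride + j] *= s;
  }
}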
// filter max
Node* filter_int16 = nullptr;
Node* filter_max = nullptr;
PrepareWeight<int16_t>(
graph, scope, block, conv_filter, &filter_int16, &filter_max, false);
bool has_branch = with_branch_x || with_branch_y;
// output && output max
std::string conv2d_xpu_out_name;
if (!act_type.empty()) {
......
......@@ -5,14 +5,14 @@
# otherwise the operator only could be used in static mode.
- op : conv2d_xpu
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
output : Tensor(out), Tensor(out_max)
infer_meta :
func : Conv2dXPUInferMeta
kernel :
func : conv2d_xpu
data_type : x
optional : bias, branch, x_max
optional : bias, branch, branch_max, x_max
- op : embedding_with_eltwise_add_xpu
args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx)
......
......@@ -58,7 +58,8 @@ XPUOpMap& get_kl2_ops() {
{"atan_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"batch_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bmm_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......@@ -401,7 +402,8 @@ XPUOpMap& get_kl2_ops() {
{"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"hard_swish_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})},
......@@ -438,7 +440,8 @@ XPUOpMap& get_kl2_ops() {
{"layer_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"less_equal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
......@@ -554,7 +557,8 @@ XPUOpMap& get_kl2_ops() {
{"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
......@@ -646,7 +650,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16})},
{"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"sigmoid",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"sign", XPUKernelSet({phi::DataType::FLOAT32})},
{"slice_grad",
......@@ -676,7 +681,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"square_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......@@ -733,7 +738,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT16,
phi::DataType::INT32})},
{"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"tanh_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
......@@ -41,6 +41,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& filter_max,
const MetaTensor& bias,
const MetaTensor& branch,
const MetaTensor& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
......
......@@ -28,6 +28,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& filter_max,
const MetaTensor& bias,
const MetaTensor& branch,
const MetaTensor& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
......
......@@ -63,7 +63,8 @@ PD_REGISTER_KERNEL(swish,
#if defined PADDLE_WITH_XPU
PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {}
PD_REGISTER_KERNEL(
swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
#endif
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -106,6 +106,10 @@ PD_REGISTER_KERNEL(batch_norm_infer,
phi::dtype::float16) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(
batch_norm_infer, XPU, ALL_LAYOUT, phi::BatchNormInferKernel, float) {}
PD_REGISTER_KERNEL(batch_norm_infer,
XPU,
ALL_LAYOUT,
phi::BatchNormInferKernel,
float,
phi::dtype::float16) {}
#endif
......@@ -27,6 +27,7 @@ void Conv2dXPUKernel(const Context& ctx,
const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch,
const paddle::optional<DenseTensor>& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
......@@ -69,10 +70,12 @@ void Conv2dXPUKernel(const Context& ctx,
branch.get_ptr() == nullptr
? nullptr
: reinterpret_cast<const XPUType*>(branch.get_ptr()->data<T>());
const float* branch_max_data = branch_max.get_ptr() == nullptr
? nullptr
: branch_max.get_ptr()->data<float>();
const float* bias_data =
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
if (act_type == xpu::Activation_t::LEAKY_RELU) {
act.leaky_alpha = act_param;
......@@ -102,7 +105,7 @@ void Conv2dXPUKernel(const Context& ctx,
/* const float* bias */ bias_data,
/* const TY* branch */ branch_data,
/* const baidu::xpu::api::Activation_t& act */ act,
/* const float* branch_maxptr */ nullptr,
/* const float* branch_maxptr */ branch_max_data,
/* const float* scale */ nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu");
}
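The new optional branch_max input carries the quantization max of the residual branch, which was previously hard-wired to nullptr; when present, it lets xdnn rescale the branch before the fused add. Conceptually it is a per-tensor absolute max; a CPU sketch of that quantity follows (illustrative only — the real input is an XPU-side max buffer whose layout is device-specific):

#include <cmath>    // std::fabs
#include <cstdint>  // int64_t

// Per-tensor absolute max of a branch input (conceptual CPU equivalent).
float BranchAbsMax(const float* data, int64_t n) {
  float m = 0.f;
  for (int64_t i = 0; i < n; ++i) {
    const float v = std::fabs(data[i]);
    if (v > m) m = v;
  }
  return m;
}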
......
......@@ -195,6 +195,13 @@ void PowKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& factor,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
float pow_factor = factor.to<float>();
const T* x_data = x.data<T>();
......@@ -534,9 +541,28 @@ PD_REGISTER_KERNEL(
relu, XPU, ALL_LAYOUT, phi::ReluKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
silu, XPU, ALL_LAYOUT, phi::SiluKernel, float, phi::dtype::float16) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
PD_REGISTER_KERNEL(
sigmoid, XPU, ALL_LAYOUT, phi::SigmoidKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(swish_raw,
XPU,
ALL_LAYOUT,
phi::SwishRawKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(hard_sigmoid,
XPU,
ALL_LAYOUT,
phi::HardSigmoidKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(leaky_relu,
XPU,
ALL_LAYOUT,
phi::LeakyReluKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
sqrt, XPU, ALL_LAYOUT, phi::SqrtKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {}
......@@ -547,18 +573,21 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
// PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
// PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
// PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
// PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
// PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
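The registrations above rely on the PD_REGISTER_ACTIVATION_KERNEL convenience macro, which only registers a float32 kernel; ops that gained fp16 support in this change are therefore pulled out of the macro and registered explicitly. From the macro definition above, the expansion is:

// PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
// expands to:
//   PD_REGISTER_KERNEL(exp, XPU, ALL_LAYOUT, phi::ExpKernel, float) {}
// fp16-capable ops instead call PD_REGISTER_KERNEL directly and append
// phi::dtype::float16 to the dtype list, as done for sigmoid/sqrt/swish_raw.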
......@@ -39,6 +39,7 @@ void BatchNormKernel(const Context& dev_ctx,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space) {
using XPUType = typename XPUTypeTrait<T>::Type;
bool test_mode = is_test && (!trainable_statistics);
bool global_stats = test_mode || use_global_stats;
const auto data_layout = phi::StringToDataLayout(data_layout_str);
......@@ -68,12 +69,12 @@ void BatchNormKernel(const Context& dev_ctx,
W = W * D;
const auto* x_data = x.data<T>();
const auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
const auto* scale_data = scale.data<float>();
const auto* bias_data = bias.data<float>();
// alloc memory
auto* y_data = dev_ctx.template Alloc<T>(y);
auto* y_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(y));
dev_ctx.template Alloc<float>(mean_out);
dev_ctx.template Alloc<float>(variance_out);
dev_ctx.template Alloc<float>(saved_mean);
......@@ -95,43 +96,48 @@ void BatchNormKernel(const Context& dev_ctx,
auto* saved_mean_data = saved_mean->data<float>();
auto* saved_variance_data = saved_variance->data<float>();
int r = xpu::batch_norm<T>(dev_ctx.x_context(),
x_data,
y_data,
N,
C,
H,
W,
epsilon,
momentum,
scale_data,
bias_data,
saved_mean_data,
saved_variance_data,
mean_out_data,
variance_out_data,
is_nchw);
int r = xpu::batch_norm<XPUType>(dev_ctx.x_context(),
x_data,
y_data,
N,
C,
H,
W,
epsilon,
momentum,
scale_data,
bias_data,
saved_mean_data,
saved_variance_data,
mean_out_data,
variance_out_data,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm");
} else {
const auto* mean_data = mean.data<float>();
const auto* variance_data = variance.data<float>();
int r = xpu::batch_norm_infer(dev_ctx.x_context(),
x_data,
y_data,
N,
C,
H,
W,
epsilon,
scale_data,
bias_data,
mean_data,
variance_data,
is_nchw);
int r = xpu::batch_norm_infer<XPUType>(dev_ctx.x_context(),
x_data,
y_data,
N,
C,
H,
W,
epsilon,
scale_data,
bias_data,
mean_data,
variance_data,
is_nchw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer");
}
}
} // namespace phi
PD_REGISTER_KERNEL(batch_norm, XPU, ALL_LAYOUT, phi::BatchNormKernel, float) {}
PD_REGISTER_KERNEL(batch_norm,
XPU,
ALL_LAYOUT,
phi::BatchNormKernel,
float,
phi::dtype::float16) {}
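The fp16 path above follows the standard phi-on-XPU pattern: phi::dtype::float16 (a host wrapper type) is reinterpreted as the XPU device half type via XPUTypeTrait before calling into xdnn. A minimal sketch of that pattern (CastAndLaunch is a hypothetical helper; the actual xdnn call is elided):

// Hypothetical helper showing the XPUTypeTrait cast pattern (sketch only).
template <typename T, typename Context>
void CastAndLaunch(const Context& dev_ctx, const phi::DenseTensor& x,
                   phi::DenseTensor* y) {
  using XPUType = typename XPUTypeTrait<T>::Type;  // float16 -> XPU half
  const auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
  auto* y_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(y));
  // ... e.g. an xpu::* op instantiated on XPUType, with its return code
  // checked via PADDLE_ENFORCE_XDNN_SUCCESS as in the kernels above.
}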
......@@ -50,4 +50,6 @@ void MeanRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(mean_raw, XPU, ALL_LAYOUT, phi::MeanRawKernel, float) {}
PD_REGISTER_KERNEL(
mean_raw, XPU, ALL_LAYOUT, phi::MeanRawKernel, float, phi::dtype::float16) {
}