Unverified commit ba9a22db, authored by Yiqun Liu, committed by GitHub

Add bfloat16 support for several operators and APIs. (#52696)

* Cherry-pick the bfloat16 kernel registration for amp_kernel, pull request #45541.

* Cherry-pick the master_grad support for adamw, pull request #51141.

* Add bf16 support for some ops in static mode (#51582)

* Add bfloat16 support for some APIs in static mode.

* Fix codestyle.

* Revert the change of layer_function_generator.py.

---------
Co-authored-by: Shaojie WANG <wsjmessi@163.com>
Parent 95c3d613
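Note on the AdamW change in the diff below: the first two commit-message bullets widen the AdamW CUDA kernel so that the gradient dtype (the new `TG` template parameter) can differ from the parameter storage dtype `T`, while all arithmetic stays in the master dtype `MT` (fp32). The NumPy sketch below is only an illustration of that numerical contract under textbook AdamW semantics; it is not the fused CUDA kernel, and `adamw_step` is a hypothetical helper name.

```python
import numpy as np

def adamw_step(master_p, m1, m2, grad, lr, t,
               beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update carried out entirely in fp32 (the MT dtype),
    whatever storage dtype the gradient arrives in (the TG dtype)."""
    g = grad.astype(np.float32)                      # TG -> MT cast
    master_p = master_p * (1.0 - lr * weight_decay)  # decoupled weight decay
    m1 = beta1 * m1 + (1.0 - beta1) * g
    m2 = beta2 * m2 + (1.0 - beta2) * g * g
    m1_hat = m1 / (1.0 - beta1 ** t)
    m2_hat = m2 / (1.0 - beta2 ** t)
    master_p = master_p - lr * m1_hat / (np.sqrt(m2_hat) + eps)
    return master_p, m1, m2

master_p = np.random.randn(4).astype(np.float32)  # master_param (MT)
m1 = np.zeros_like(master_p)                      # moment1 (MT)
m2 = np.zeros_like(master_p)                      # moment2 (MT)
# NumPy has no bfloat16, so fp16 stands in for the low-precision gradient (TG).
grad_lp = np.random.randn(4).astype(np.float16)
master_p, m1, m2 = adamw_step(master_p, m1, m2, grad_lp, lr=1e-3, t=1)
param_out = master_p.astype(np.float16)           # param_out in the storage dtype T
```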
@@ -29,7 +29,7 @@
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T, typename MT>
template <typename T, typename TG, typename MT>
__global__ void AdamWKernelREG(MT beta1,
MT beta2,
MT epsilon,
@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const TG* grad,
const T* param,
T* param_out,
const MT* master_param,
@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
}
}
template <typename T, typename MT>
template <typename T, typename TG, typename MT>
__global__ void AdamWKernelMEM(MT beta1,
MT beta2,
MT epsilon,
@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const TG* grad,
const T* param,
T* param_out,
const MT* master_param,
@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
DenseTensor* master_param_outs) {
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
const auto grad_type = grad.dtype();
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
MPDType coeff_ = static_cast<MPDType>(coeff);
@@ -191,8 +193,10 @@ void AdamwDenseKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
if (!use_global_beta_pow) {
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
}
return;
}
@@ -233,25 +237,49 @@ void AdamwDenseKernel(const Context& dev_ctx,
if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
// Compute with betapow in REG
AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (grad_type == phi::DataType::FLOAT32)
AdamWKernelREG<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
else
AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (!use_global_beta_pow) {
// Cpu update
dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -260,28 +288,50 @@ void AdamwDenseKernel(const Context& dev_ctx,
beta2_ * beta2_pow.data<MPDType>()[0];
}
} else {
AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (grad_type == phi::DataType::FLOAT32)
AdamWKernelMEM<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
else
AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (!use_global_beta_pow) {
// Update with gpu
UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
beta1_pow.data<MPDType>(),
@@ -300,9 +350,21 @@ PD_REGISTER_KERNEL(adamw,
phi::AdamwDenseKernel,
float,
double,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
}
kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
}
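
Note: the registration above keeps the moment, beta_pow, and master-parameter outputs in FLOAT32 whenever the parameter dtype is FLOAT16 or BFLOAT16, which is what `multi_precision` training relies on. A minimal dygraph sketch of how this is typically exercised follows; it assumes a Paddle build carrying this patch, and the `dtype='bfloat16'` arguments to `paddle.amp.decorate`/`auto_cast` are an assumption about that release rather than something shown in this diff.

```python
import paddle

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(learning_rate=1e-3,
                             parameters=model.parameters(),
                             weight_decay=0.01,
                             multi_precision=True)   # fp32 master weights/moments

# O2 keeps low-precision (here bf16) copies of the weights for forward/backward,
# while the optimizer state registered above stays in fp32.
model, opt = paddle.amp.decorate(models=model, optimizers=opt,
                                 level='O2', dtype='bfloat16')

x = paddle.randn([8, 16])
with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
    loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()
```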
@@ -357,7 +357,8 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
phi::CheckFiniteAndUnscaleKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(update_loss_scaling,
GPU,
@@ -365,6 +366,7 @@ PD_REGISTER_KERNEL(update_loss_scaling,
phi::UpdateLossScalingKernel,
float,
double,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
}
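
Note: `check_finite_and_unscale` and `update_loss_scaling` are the kernels behind the unscale/update steps of `paddle.amp.GradScaler`, so registering bf16 here lets the existing fp16 scaler code path accept bf16 gradients unchanged. bf16 shares fp32's exponent range, so loss scaling is rarely required for it; the sketch below (same hypothetical setup as the previous example) only shows where these two kernels are hit.

```python
import paddle

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(parameters=model.parameters(), multi_precision=True)
model, opt = paddle.amp.decorate(models=model, optimizers=opt,
                                 level='O2', dtype='bfloat16')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024.0)

x = paddle.randn([8, 16])
with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
    loss = model(x).mean()
scaler.scale(loss).backward()   # scaled bf16 gradients
scaler.step(opt)                # check_finite_and_unscale runs on the grads
scaler.update()                 # update_loss_scaling adjusts the scale factor
opt.clear_grad()
```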
@@ -55,6 +55,7 @@ PD_REGISTER_KERNEL(matmul_with_flatten_grad,
phi::MatmulWithFlattenGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
......
@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(matmul_with_flatten,
phi::MatmulWithFlattenKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
@@ -20,15 +20,26 @@ import string
from six.moves import cStringIO
from ..proto import framework_pb2
from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from ..framework import (
OpProtoHolder,
Variable,
core,
convert_np_dtype_to_dtype_,
_non_static_mode,
in_dygraph_mode,
_in_legacy_dygraph,
)
from ..layer_helper import LayerHelper
from ..data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle import _C_ops, _legacy_C_ops
__all__ = [
'generate_layer_fn', 'generate_activation_fn', 'generate_inplace_fn',
'autodoc', 'templatedoc'
'generate_layer_fn',
'generate_activation_fn',
'generate_inplace_fn',
'autodoc',
'templatedoc',
]
@@ -58,16 +69,16 @@ _two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
def escape_math(text):
#return _two_bang_pattern_.sub(
# return _two_bang_pattern_.sub(
# r'$$\1$$',
# _single_dollar_pattern_.sub(r':math:\n`\1`',
# _two_dollar_pattern_.sub(r"!!\1!!", text)))
return _two_dollar_pattern_.sub(r':math:`\1`', text)
def _generate_doc_string_(op_proto,
additional_args_lines=None,
skip_attrs_set=None):
def _generate_doc_string_(
op_proto, additional_args_lines=None, skip_attrs_set=None
):
"""
Generate docstring by OpProto
@@ -147,23 +158,30 @@ def generate_layer_fn(op_type):
"""
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
not_intermediate_outputs = \
[output for output in op_proto.outputs if not output.intermediate]
intermediate_outputs = \
[output for output in op_proto.outputs if output.intermediate]
not_intermediate_outputs = [
output for output in op_proto.outputs if not output.intermediate
]
intermediate_outputs = [
output for output in op_proto.outputs if output.intermediate
]
if len(not_intermediate_outputs) != 1:
raise ValueError("Only one non intermediate output operator can be",
"automatically generated. {0}".format(op_type))
raise ValueError(
"Only one non intermediate output operator can be",
"automatically generated. {0}".format(op_type),
)
if not_intermediate_outputs[0].duplicable:
raise ValueError(
"Only non duplicable op can be automatically generated.")
"Only non duplicable op can be automatically generated."
)
for output in intermediate_outputs:
if output.duplicable:
raise ValueError("The op can be automatically generated only when ",
"all intermediate ops are not duplicable.")
raise ValueError(
"The op can be automatically generated only when ",
"all intermediate ops are not duplicable.",
)
o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs]
@@ -188,14 +206,17 @@ def generate_layer_fn(op_type):
for each in val:
if not isinstance(each, Variable):
raise ValueError(
"input of {0} must be variable".format(op_type))
"input of {0} must be variable".format(op_type)
)
if dtype is None:
dtype = each.dtype
elif dtype != each.dtype:
raise ValueError(
"operator {0} must input same dtype. {1} vs {2}".format(
op_type, dtype, each.dtype))
op_type, dtype, each.dtype
)
)
if dtype is None:
arg_dtype = kwargs.get("dtype")
@@ -227,8 +248,11 @@ def generate_layer_fn(op_type):
outputs = dict()
out = kwargs.pop(_convert_(o_name), [])
if out:
out_var = out[0] if (isinstance(out, list)
or isinstance(out, tuple)) else out
out_var = (
out[0]
if (isinstance(out, list) or isinstance(out, tuple))
else out
)
else:
out_var = helper.create_variable_for_type_inference(dtype=dtype)
outputs[o_name] = [out_var]
@@ -236,10 +260,9 @@ def generate_layer_fn(op_type):
outputs[name] = [
helper.create_variable_for_type_inference(dtype=dtype)
]
helper.append_op(type=op_type,
inputs=inputs,
outputs=outputs,
attrs=kwargs)
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs
)
return helper.append_activation(out_var)
func.__name__ = op_type
@@ -270,14 +293,26 @@ def generate_activation_fn(op_type):
return op(x)
if op_type not in ["abs", "exp", "square"]:
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
op_type)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'uint16'], op_type
)
else:
# abs exp square ops support dtype(int32, int64, float16, float32, float64)
check_variable_and_dtype(x, 'x', [
'int32', 'int64', 'float16', 'float32', 'float64', 'complex64',
'complex128'
], op_type)
check_variable_and_dtype(
x,
'x',
[
'int32',
'int64',
'float16',
'float32',
'float64',
'complex64',
'complex128',
'uint16',
],
op_type,
)
helper = LayerHelper(op_type, **locals())
@@ -290,7 +325,8 @@ def generate_activation_fn(op_type):
op_proto,
additional_args_lines=[
"name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`."
])
],
)
return func
@@ -310,24 +346,31 @@ def generate_inplace_fn(inplace_op_type):
op = getattr(_legacy_C_ops, inplace_op_type)
return op(x)
warnings.warn(
"In static mode, {}() is the same as {}() and does not perform inplace operation."
.format(inplace_op_type, origin_op_type))
"In static mode, {}() is the same as {}() and does not perform inplace operation.".format(
inplace_op_type, origin_op_type
)
)
return generate_activation_fn(origin_op_type)(x, name)
func.__name__ = inplace_op_type
func.__doc__ = """
Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_fluid_layers_{1}`.
""".format(origin_op_type, origin_op_type)
""".format(
origin_op_type, origin_op_type
)
return func
def autodoc(comment=""):
def __impl__(func):
func.__doc__ = _generate_doc_string_(
OpProtoHolder.instance().get_op_proto(func.__name__)) + comment
func.__doc__ = (
_generate_doc_string_(
OpProtoHolder.instance().get_op_proto(func.__name__)
)
+ comment
)
return func
return __impl__
@@ -372,18 +415,21 @@ def templatedoc(op_type=None):
for each_input in op_proto.inputs:
input_name = _convert_(each_input.name)
args["{0}_comment".format(input_name)] = trim_ending_dot(
each_input.comment)
each_input.comment
)
args["{0}_type".format(input_name)] = "Variable"
for each_attr in op_proto.attrs:
input_name = _convert_(each_attr.name)
args["{0}_comment".format(input_name)] = trim_ending_dot(
each_attr.comment)
each_attr.comment
)
args["{0}_type".format(input_name)] = _type_to_str_(each_attr.type)
for each_opt in op_proto.outputs:
output_name = _convert_(each_opt.name)
args["{0}_comment".format(output_name)] = trim_ending_dot(
each_opt.comment)
each_opt.comment
)
args["{0}_type".format(output_name)] = "Variable"
func.__doc__ = tmpl.substitute(args)
return func
@@ -393,7 +439,7 @@ def templatedoc(op_type=None):
def add_sample_code(func, sample_code):
"""
Append sample code for dynamically generated functions.
Append sample code for dynamically generated functions.
Args:
func: The function of the function to be append sample code to.
......
@@ -5106,7 +5106,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
check_variable_and_dtype(
input,
'input',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'reduce_sum',
)
helper = LayerHelper('reduce_sum', **locals())
......
@@ -411,7 +411,7 @@ def layer_norm(
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
check_variable_and_dtype(
x, 'input', ['float16', 'float32', 'float64'], 'LayerNorm'
x, 'input', ['float16', 'float32', 'float64', 'uint16'], 'LayerNorm'
)
inputs = dict()
......
@@ -482,10 +482,10 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(original_op_type)
assert y is not None, 'y cannot be None in {}'.format(original_op_type)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'],
original_op_type)
check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'],
original_op_type)
axis = helper.kwargs.get('axis', -1)
@@ -1537,10 +1537,10 @@ def add_n(inputs, name=None):
if len(inputs) > 0:
for input in inputs:
check_variable_and_dtype(input, "inputs", \
['float16', 'float32', 'float64', 'int32', 'int64'], 'add_n')
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n')
else:
check_variable_and_dtype(inputs, "inputs", \
['float16', 'float32', 'float64', 'int32', 'int64'], 'add_n')
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n')
out = helper.create_variable_for_type_inference(
......
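
Note on the Python-side changes: 'uint16' is added to the static-mode dtype whitelists because bfloat16 variables surface as 'uint16' in these checks; NumPy has no native bfloat16, so Paddle carries the bf16 bit pattern (the upper 16 bits of an fp32 value) in uint16 storage. A NumPy sketch of that bit-level relationship, purely as an illustration of the encoding:

```python
import numpy as np

def fp32_to_bf16_bits(x):
    """Keep only the upper 16 bits of each fp32 value (the bf16 bit pattern),
    returned as uint16; plain truncation, ignoring rounding for simplicity."""
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_fp32(bits):
    """Expand uint16 bf16 bit patterns back to fp32 by zero-filling the mantissa."""
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 3.14159, -2.5e-3], dtype=np.float32)
bits = fp32_to_bf16_bits(x)
print(bits.dtype)               # uint16 -- the dtype string the checks now accept
print(bf16_bits_to_fp32(bits))  # values reduced to bf16 precision
```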