Unverified commit ba9a22db, authored by Yiqun Liu, committed by GitHub

Add bfloat16 support for several operators and APIs. (#52696)

* Cherry-pick the bfloat16 registration for amp_kernel, pull request #45541.

* Cherry-pick the master_grad support of adamw, pull request #51141.

* add bf16 for some ops in static mode (#51582)

* Add bfloat16 support for some APIs in static mode.

* Fix codestyle.

* Revert the change of layer_function_generator.py.

---------
Co-authored-by: Shaojie WANG <wsjmessi@163.com>
Parent 95c3d613
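
For context, a minimal usage sketch of what this commit enables on GPU: bfloat16 autocast together with an AdamW optimizer that keeps float32 master weights. This is an illustration only, not code from the diff; it assumes a Paddle build that includes these kernels, and it uses the existing public APIs paddle.amp.auto_cast(dtype='bfloat16') and the multi_precision flag of paddle.optimizer.AdamW.

# Hypothetical bf16 training sketch (dynamic graph); not part of this commit.
import paddle

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=model.parameters(),
    multi_precision=True,  # keep fp32 master weights/moments for low-precision params
)

x = paddle.randn([4, 16])
with paddle.amp.auto_cast(dtype='bfloat16'):  # bf16 autocast region
    loss = paddle.mean(model(x))
loss.backward()
opt.step()        # AdamW update; bf16 parameters/gradients reach the kernels changed below
opt.clear_grad()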
@@ -29,7 +29,7 @@
 #include "paddle/phi/kernels/funcs/for_range.h"
 namespace phi {
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelREG(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
   }
 }
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelMEM(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
                       DenseTensor* master_param_outs) {
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  const auto grad_type = grad.dtype();
   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
   MPDType coeff_ = static_cast<MPDType>(coeff);
@@ -191,8 +193,10 @@ void AdamwDenseKernel(const Context& dev_ctx,
     phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
     phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
     phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
-    phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
-    phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    if (!use_global_beta_pow) {
+      phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
+      phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    }
     return;
   }
@@ -233,25 +237,49 @@ void AdamwDenseKernel(const Context& dev_ctx,
   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
     // Compute with betapow in REG
-    AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        *beta1_pow.data<MPDType>(),
-        *beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelREG<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              *beta1_pow.data<MPDType>(),
+              *beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          *beta1_pow.data<MPDType>(),
+          *beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Cpu update
       dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -260,28 +288,50 @@ void AdamwDenseKernel(const Context& dev_ctx,
           beta2_ * beta2_pow.data<MPDType>()[0];
     }
   } else {
-    AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        beta1_pow.data<MPDType>(),
-        beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelMEM<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              beta1_pow.data<MPDType>(),
+              beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow.data<MPDType>(),
+          beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Update with gpu
-      UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
+      UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
          beta1_,
          beta2_,
          beta1_pow.data<MPDType>(),
@@ -300,9 +350,21 @@ PD_REGISTER_KERNEL(adamw,
                    phi::AdamwDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
+      kernel_key.dtype() == phi::DataType::BFLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+  }
+  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
 }
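
To summarize the kernel change above: AdamWKernelREG/AdamWKernelMEM now carry a separate gradient type TG, so a low-precision (float16/bfloat16) parameter of type T can be updated either from a gradient of the same type (TG = T) or from a float32 master gradient (TG = float), selected at runtime via grad_type, while the moments, beta powers, and the master parameter copy stay in the fp32 compute type MPDType. Below is a minimal NumPy sketch of the per-element update the kernel performs; the adamw_step name is mine and the T/TG casts are omitted, so treat it as an illustration rather than the kernel itself.

# Sketch of one AdamW step in fp32, mirroring the arithmetic in AdamWKernelREG/MEM.
import numpy as np

def adamw_step(p, g, mom1, mom2, lr, beta1, beta2, epsilon,
               coeff, lr_ratio, beta1_pow, beta2_pow):
    lr = lr * lr_ratio
    p = p * (1.0 - lr * coeff)                        # decoupled weight decay
    mom1 = beta1 * mom1 + (1.0 - beta1) * g           # first moment
    mom2 = beta2 * mom2 + (1.0 - beta2) * g * g       # second moment
    denom = np.sqrt(mom2) / np.sqrt(1.0 - beta2_pow) + epsilon
    p = p - lr * (mom1 / (1.0 - beta1_pow)) / denom   # bias-corrected update
    return p, mom1, mom2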
@@ -357,7 +357,8 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
                    phi::CheckFiniteAndUnscaleKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}

 PD_REGISTER_KERNEL(update_loss_scaling,
                    GPU,
@@ -365,6 +366,7 @@ PD_REGISTER_KERNEL(update_loss_scaling,
                    phi::UpdateLossScalingKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
 }
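
check_finite_and_unscale and update_loss_scaling are the ops behind paddle.amp.GradScaler, so registering bfloat16 here lets dynamic loss scaling also run over bf16 gradients (bf16's wide exponent range often makes scaling unnecessary, but it is now available). A hedged usage sketch with the standard public API, not taken from this diff:

# Dynamic loss scaling sketch; assumes a Paddle build with the bf16 kernels above.
import paddle

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(parameters=model.parameters(), multi_precision=True)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024.0)

x = paddle.randn([4, 16])
with paddle.amp.auto_cast(dtype='bfloat16'):
    loss = paddle.mean(model(x))
scaler.scale(loss).backward()  # scale the loss before backprop
scaler.step(opt)               # unscale grads; skip the step if inf/nan is found
scaler.update()                # adjust the loss scale (update_loss_scaling)
opt.clear_grad()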
@@ -55,6 +55,7 @@ PD_REGISTER_KERNEL(matmul_with_flatten_grad,
                    phi::MatmulWithFlattenGradKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}

 PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(matmul_with_flatten,
                    phi::MatmulWithFlattenKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
@@ -20,15 +20,26 @@ import string
 from six.moves import cStringIO
 from ..proto import framework_pb2
-from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
+from ..framework import (
+    OpProtoHolder,
+    Variable,
+    core,
+    convert_np_dtype_to_dtype_,
+    _non_static_mode,
+    in_dygraph_mode,
+    _in_legacy_dygraph,
+)
 from ..layer_helper import LayerHelper
 from ..data_feeder import check_variable_and_dtype
 from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
 from paddle import _C_ops, _legacy_C_ops

 __all__ = [
-    'generate_layer_fn', 'generate_activation_fn', 'generate_inplace_fn',
-    'autodoc', 'templatedoc'
+    'generate_layer_fn',
+    'generate_activation_fn',
+    'generate_inplace_fn',
+    'autodoc',
+    'templatedoc',
 ]
@@ -58,16 +69,16 @@ _two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
 def escape_math(text):
-    #return _two_bang_pattern_.sub(
+    # return _two_bang_pattern_.sub(
     #    r'$$\1$$',
    #    _single_dollar_pattern_.sub(r':math:\n`\1`',
    #                                _two_dollar_pattern_.sub(r"!!\1!!", text)))
     return _two_dollar_pattern_.sub(r':math:`\1`', text)

-def _generate_doc_string_(op_proto,
-                          additional_args_lines=None,
-                          skip_attrs_set=None):
+def _generate_doc_string_(
+    op_proto, additional_args_lines=None, skip_attrs_set=None
+):
     """
     Generate docstring by OpProto
@@ -147,23 +158,30 @@ def generate_layer_fn(op_type):
     """
     op_proto = OpProtoHolder.instance().get_op_proto(op_type)
-    not_intermediate_outputs = \
-        [output for output in op_proto.outputs if not output.intermediate]
-    intermediate_outputs = \
-        [output for output in op_proto.outputs if output.intermediate]
+    not_intermediate_outputs = [
+        output for output in op_proto.outputs if not output.intermediate
+    ]
+    intermediate_outputs = [
+        output for output in op_proto.outputs if output.intermediate
+    ]

     if len(not_intermediate_outputs) != 1:
-        raise ValueError("Only one non intermediate output operator can be",
-                         "automatically generated. {0}".format(op_type))
+        raise ValueError(
+            "Only one non intermediate output operator can be",
+            "automatically generated. {0}".format(op_type),
+        )

     if not_intermediate_outputs[0].duplicable:
         raise ValueError(
-            "Only non duplicable op can be automatically generated.")
+            "Only non duplicable op can be automatically generated."
+        )

     for output in intermediate_outputs:
         if output.duplicable:
-            raise ValueError("The op can be automatically generated only when ",
-                             "all intermediate ops are not duplicable.")
+            raise ValueError(
+                "The op can be automatically generated only when ",
+                "all intermediate ops are not duplicable.",
+            )

     o_name = not_intermediate_outputs[0].name
     intermediate_output_names = [output.name for output in intermediate_outputs]
@@ -188,14 +206,17 @@ def generate_layer_fn(op_type):
             for each in val:
                 if not isinstance(each, Variable):
                     raise ValueError(
-                        "input of {0} must be variable".format(op_type))
+                        "input of {0} must be variable".format(op_type)
+                    )

                 if dtype is None:
                     dtype = each.dtype
                 elif dtype != each.dtype:
                     raise ValueError(
                         "operator {0} must input same dtype. {1} vs {2}".format(
-                            op_type, dtype, each.dtype))
+                            op_type, dtype, each.dtype
+                        )
+                    )

         if dtype is None:
             arg_dtype = kwargs.get("dtype")
@@ -227,8 +248,11 @@ def generate_layer_fn(op_type):
         outputs = dict()
         out = kwargs.pop(_convert_(o_name), [])
         if out:
-            out_var = out[0] if (isinstance(out, list)
-                                 or isinstance(out, tuple)) else out
+            out_var = (
+                out[0]
+                if (isinstance(out, list) or isinstance(out, tuple))
+                else out
+            )
         else:
             out_var = helper.create_variable_for_type_inference(dtype=dtype)
         outputs[o_name] = [out_var]
@@ -236,10 +260,9 @@ def generate_layer_fn(op_type):
             outputs[name] = [
                 helper.create_variable_for_type_inference(dtype=dtype)
             ]
-        helper.append_op(type=op_type,
-                         inputs=inputs,
-                         outputs=outputs,
-                         attrs=kwargs)
+        helper.append_op(
+            type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs
+        )
         return helper.append_activation(out_var)

     func.__name__ = op_type
@@ -270,14 +293,26 @@ def generate_activation_fn(op_type):
            return op(x)

        if op_type not in ["abs", "exp", "square"]:
-            check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
-                                     op_type)
+            check_variable_and_dtype(
+                x, 'x', ['float16', 'float32', 'float64', 'uint16'], op_type
+            )
        else:
            # abs exp square ops support dtype(int32, int64, float16, float32, float64)
-            check_variable_and_dtype(x, 'x', [
-                'int32', 'int64', 'float16', 'float32', 'float64', 'complex64',
-                'complex128'
-            ], op_type)
+            check_variable_and_dtype(
+                x,
+                'x',
+                [
+                    'int32',
+                    'int64',
+                    'float16',
+                    'float32',
+                    'float64',
+                    'complex64',
+                    'complex128',
+                    'uint16',
+                ],
+                op_type,
+            )

        helper = LayerHelper(op_type, **locals())
@@ -290,7 +325,8 @@ def generate_activation_fn(op_type):
         op_proto,
         additional_args_lines=[
             "name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`."
-        ])
+        ],
+    )

     return func
@@ -310,24 +346,31 @@ def generate_inplace_fn(inplace_op_type):
            op = getattr(_legacy_C_ops, inplace_op_type)
            return op(x)
        warnings.warn(
-            "In static mode, {}() is the same as {}() and does not perform inplace operation."
-            .format(inplace_op_type, origin_op_type))
+            "In static mode, {}() is the same as {}() and does not perform inplace operation.".format(
+                inplace_op_type, origin_op_type
+            )
+        )
        return generate_activation_fn(origin_op_type)(x, name)

     func.__name__ = inplace_op_type
     func.__doc__ = """
Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_fluid_layers_{1}`.
-""".format(origin_op_type, origin_op_type)
+""".format(
+        origin_op_type, origin_op_type
+    )

     return func

 def autodoc(comment=""):
     def __impl__(func):
-        func.__doc__ = _generate_doc_string_(
-            OpProtoHolder.instance().get_op_proto(func.__name__)) + comment
+        func.__doc__ = (
+            _generate_doc_string_(
+                OpProtoHolder.instance().get_op_proto(func.__name__)
+            )
+            + comment
+        )
         return func

     return __impl__
@@ -372,18 +415,21 @@ def templatedoc(op_type=None):
         for each_input in op_proto.inputs:
             input_name = _convert_(each_input.name)
             args["{0}_comment".format(input_name)] = trim_ending_dot(
-                each_input.comment)
+                each_input.comment
+            )
             args["{0}_type".format(input_name)] = "Variable"
         for each_attr in op_proto.attrs:
             input_name = _convert_(each_attr.name)
             args["{0}_comment".format(input_name)] = trim_ending_dot(
-                each_attr.comment)
+                each_attr.comment
+            )
             args["{0}_type".format(input_name)] = _type_to_str_(each_attr.type)
         for each_opt in op_proto.outputs:
             output_name = _convert_(each_opt.name)
             args["{0}_comment".format(output_name)] = trim_ending_dot(
-                each_opt.comment)
+                each_opt.comment
+            )
             args["{0}_type".format(output_name)] = "Variable"
         func.__doc__ = tmpl.substitute(args)
         return func
@@ -393,7 +439,7 @@ def templatedoc(op_type=None):
 def add_sample_code(func, sample_code):
     """
     Append sample code for dynamically generated functions.

     Args:
         func: The function of the function to be append sample code to.
@@ -5106,7 +5106,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
     check_variable_and_dtype(
         input,
         'input',
-        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         'reduce_sum',
     )
     helper = LayerHelper('reduce_sum', **locals())
@@ -411,7 +411,7 @@ def layer_norm(
         return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)

     check_variable_and_dtype(
-        x, 'input', ['float16', 'float32', 'float64'], 'LayerNorm'
+        x, 'input', ['float16', 'float32', 'float64', 'uint16'], 'LayerNorm'
     )

     inputs = dict()
@@ -482,10 +482,10 @@ def _elementwise_op(helper):
     assert x is not None, 'x cannot be None in {}'.format(original_op_type)
     assert y is not None, 'y cannot be None in {}'.format(original_op_type)
     check_variable_and_dtype(
-        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
+        x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'],
         original_op_type)
     check_variable_and_dtype(
-        y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
+        y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'],
         original_op_type)

     axis = helper.kwargs.get('axis', -1)
@@ -1537,10 +1537,10 @@ def add_n(inputs, name=None):
     if len(inputs) > 0:
         for input in inputs:
             check_variable_and_dtype(input, "inputs", \
-                ['float16', 'float32', 'float64', 'int32', 'int64'], 'add_n')
+                ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n')
     else:
         check_variable_and_dtype(inputs, "inputs", \
-            ['float16', 'float32', 'float64', 'int32', 'int64'], 'add_n')
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n')

     out = helper.create_variable_for_type_inference(
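
The 'uint16' entries added to these dtype whitelists are Paddle's numpy-proxy spelling of bfloat16: numpy has no native bf16 type, so BF16 variables report 'uint16' in these checks. Below is a hypothetical static-mode sketch of what this relaxation allows; it assumes the add path routes through the patched _elementwise_op helper and that the surrounding program handles bf16 data feeding, so treat it purely as an illustration.

# Hypothetical static-graph sketch; dtype spellings may vary across Paddle versions.
import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    # 'uint16' is the proxy dtype Paddle maps to bfloat16 variables.
    x = paddle.static.data(name='x', shape=[None, 16], dtype='uint16')
    y = paddle.static.data(name='y', shape=[None, 16], dtype='uint16')
    out = paddle.add(x, y)  # passes the relaxed elementwise dtype check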