Unverified commit ba9a22db, authored by Yiqun Liu, committed by GitHub

Add bfloat16 support for several operators and APIs. (#52696)

* Cherry-pick the registration of bfloat16 for amp_kernel, pull request #45541.

* Cherry-pick the master_grad support of adamw, pull request #51141.

* add bf16 for some ops in static mode (#51582)

* Add bfloat16 support for some APIs in static mode.

* Fix codestyle.

* Revert the change of layer_function_generator.py.

---------
Co-authored-by: Shaojie WANG <wsjmessi@163.com>
Parent 95c3d613
......@@ -29,7 +29,7 @@
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T, typename MT>
template <typename T, typename TG, typename MT>
__global__ void AdamWKernelREG(MT beta1,
MT beta2,
MT epsilon,
......@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const TG* grad,
const T* param,
T* param_out,
const MT* master_param,
......@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
}
}
template <typename T, typename MT>
template <typename T, typename TG, typename MT>
__global__ void AdamWKernelMEM(MT beta1,
MT beta2,
MT epsilon,
......@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
const MT* moment2,
MT* moment2_out,
const MT* lr_,
const T* grad,
const TG* grad,
const T* param,
T* param_out,
const MT* master_param,
......@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
DenseTensor* master_param_outs) {
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
const auto grad_type = grad.dtype();
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
MPDType coeff_ = static_cast<MPDType>(coeff);
......@@ -191,8 +193,10 @@ void AdamwDenseKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
if (!use_global_beta_pow) {
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
}
return;
}
......@@ -233,25 +237,49 @@ void AdamwDenseKernel(const Context& dev_ctx,
if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
// Compute with betapow in REG
AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (grad_type == phi::DataType::FLOAT32)
AdamWKernelREG<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
else
AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
*beta1_pow.data<MPDType>(),
*beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (!use_global_beta_pow) {
// Cpu update
dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
......@@ -260,28 +288,50 @@ void AdamwDenseKernel(const Context& dev_ctx,
beta2_ * beta2_pow.data<MPDType>()[0];
}
} else {
AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (grad_type == phi::DataType::FLOAT32)
AdamWKernelMEM<T, float, MPDType>
<<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
else
AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
moment1.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment1_out),
moment2.data<MPDType>(),
dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<T>(),
param.data<T>(),
dev_ctx.template Alloc<T>(param_out),
master_in_data,
master_out_data,
param.numel());
if (!use_global_beta_pow) {
// Update with gpu
UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
beta1_pow.data<MPDType>(),
......@@ -300,9 +350,21 @@ PD_REGISTER_KERNEL(adamw,
phi::AdamwDenseKernel,
float,
double,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
}
kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
}
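The new `TG` template parameter and the `grad_type == phi::DataType::FLOAT32` branches above let AdamW consume a float32 gradient even when the parameter is float16/bfloat16 (the master-grad path from #51141), while the moments, beta powers and master parameter stay in float32; the registration block then pins those state outputs to FLOAT32 whenever the kernel dtype is FP16/BF16. A minimal NumPy sketch of the per-element update (not part of this diff; a sketch of the math only, with illustrative names and defaults):

```python
import numpy as np

def adamw_step(master_p, m1, m2, grad, lr, *, beta1=0.9, beta2=0.999,
               eps=1e-8, coeff=0.01, lr_ratio=1.0,
               beta1_pow=0.9, beta2_pow=0.999):
    """One AdamW update on float32 state. `grad` may arrive as float32
    (master grad) or in the parameter dtype; it is upcast here, mirroring
    the TG dispatch in AdamWKernelREG/AdamWKernelMEM."""
    g = grad.astype(np.float32)
    lr = lr * lr_ratio
    p = master_p * (1.0 - lr * coeff)                 # decoupled weight decay
    m1 = beta1 * m1 + (1.0 - beta1) * g
    m2 = beta2 * m2 + (1.0 - beta2) * g * g
    denom = np.sqrt(m2) / np.sqrt(1.0 - beta2_pow) + eps
    p -= (lr / (1.0 - beta1_pow)) * m1 / denom
    # param_out is this value cast back to fp16/bf16; the float32 copy is
    # written to master_param_out so no precision is lost between steps.
    return p, m1, m2

master = np.full(4, 0.5, dtype=np.float32)   # fp32 master copy of a bf16 param
m1 = np.zeros(4, dtype=np.float32)
m2 = np.zeros(4, dtype=np.float32)
grad = np.full(4, 0.1, dtype=np.float32)     # fp32 "master" gradient
master, m1, m2 = adamw_step(master, m1, m2, grad, lr=1e-3)
```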
......@@ -357,7 +357,8 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
phi::CheckFiniteAndUnscaleKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(update_loss_scaling,
GPU,
......@@ -365,6 +366,7 @@ PD_REGISTER_KERNEL(update_loss_scaling,
phi::UpdateLossScalingKernel,
float,
double,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
}
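Registering bfloat16 for `check_finite_and_unscale` and `update_loss_scaling` lets the existing dynamic loss-scaling machinery run over bf16 gradients unchanged. A rough NumPy sketch of what the two steps do (not part of this diff; the default constants are illustrative, not the kernels' attribute values):

```python
import numpy as np

def check_finite_and_unscale(grads, loss_scale):
    """Divide every gradient by the loss scale and flag any inf/nan."""
    unscaled, found_inf = [], False
    for g in grads:
        g = g.astype(np.float32) / loss_scale
        found_inf = found_inf or not np.all(np.isfinite(g))
        unscaled.append(g)
    return unscaled, found_inf

def update_loss_scaling(loss_scale, good_steps, found_inf,
                        incr_every_n_steps=1000, incr_ratio=2.0, decr_ratio=0.5):
    """Shrink the scale on overflow, grow it after a run of clean steps."""
    if found_inf:
        return max(loss_scale * decr_ratio, 1.0), 0
    good_steps += 1
    if good_steps >= incr_every_n_steps:
        return loss_scale * incr_ratio, 0
    return loss_scale, good_steps

grads = [np.array([1e4, np.inf], dtype=np.float32)]
grads, found_inf = check_finite_and_unscale(grads, loss_scale=2.0 ** 15)
scale, good = update_loss_scaling(2.0 ** 15, 10, found_inf)  # overflow: scale halves
```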
......@@ -55,6 +55,7 @@ PD_REGISTER_KERNEL(matmul_with_flatten_grad,
phi::MatmulWithFlattenGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
......
......@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(matmul_with_flatten,
phi::MatmulWithFlattenKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
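In the Python changes that follow, bfloat16 shows up in the dtype whitelists as `'uint16'`, because the static-graph dtype checks use the 16-bit storage type to stand in for bfloat16. A NumPy sketch of that storage convention (round-to-nearest-even float32 → bfloat16 bits), purely for illustration and not Paddle's implementation:

```python
import numpy as np

def float32_to_bfloat16_bits(x):
    """Keep the high 16 bits of each float32, rounding to nearest even."""
    u = np.asarray(x, dtype=np.float32).view(np.uint32).astype(np.uint64)
    rounding_bias = 0x7FFF + ((u >> 16) & 1)
    return ((u + rounding_bias) >> 16).astype(np.uint16)

def bfloat16_bits_to_float32(b):
    """Re-expand a stored uint16 bit pattern into a float32 value."""
    return (np.asarray(b, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 3.14159, -2.5e-3], dtype=np.float32)
bits = float32_to_bfloat16_bits(x)        # dtype uint16, as in the checks below
print(bfloat16_bits_to_float32(bits))     # values rounded to bfloat16 precision
```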
......@@ -20,15 +20,26 @@ import string
from six.moves import cStringIO
from ..proto import framework_pb2
from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from ..framework import (
OpProtoHolder,
Variable,
core,
convert_np_dtype_to_dtype_,
_non_static_mode,
in_dygraph_mode,
_in_legacy_dygraph,
)
from ..layer_helper import LayerHelper
from ..data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
from paddle import _C_ops, _legacy_C_ops
__all__ = [
'generate_layer_fn', 'generate_activation_fn', 'generate_inplace_fn',
'autodoc', 'templatedoc'
'generate_layer_fn',
'generate_activation_fn',
'generate_inplace_fn',
'autodoc',
'templatedoc',
]
......@@ -58,16 +69,16 @@ _two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
def escape_math(text):
#return _two_bang_pattern_.sub(
# return _two_bang_pattern_.sub(
# r'$$\1$$',
# _single_dollar_pattern_.sub(r':math:\n`\1`',
# _two_dollar_pattern_.sub(r"!!\1!!", text)))
return _two_dollar_pattern_.sub(r':math:`\1`', text)
def _generate_doc_string_(op_proto,
additional_args_lines=None,
skip_attrs_set=None):
def _generate_doc_string_(
op_proto, additional_args_lines=None, skip_attrs_set=None
):
"""
Generate docstring by OpProto
......@@ -147,23 +158,30 @@ def generate_layer_fn(op_type):
"""
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
not_intermediate_outputs = \
[output for output in op_proto.outputs if not output.intermediate]
intermediate_outputs = \
[output for output in op_proto.outputs if output.intermediate]
not_intermediate_outputs = [
output for output in op_proto.outputs if not output.intermediate
]
intermediate_outputs = [
output for output in op_proto.outputs if output.intermediate
]
if len(not_intermediate_outputs) != 1:
raise ValueError("Only one non intermediate output operator can be",
"automatically generated. {0}".format(op_type))
raise ValueError(
"Only one non intermediate output operator can be",
"automatically generated. {0}".format(op_type),
)
if not_intermediate_outputs[0].duplicable:
raise ValueError(
"Only non duplicable op can be automatically generated.")
"Only non duplicable op can be automatically generated."
)
for output in intermediate_outputs:
if output.duplicable:
raise ValueError("The op can be automatically generated only when ",
"all intermediate ops are not duplicable.")
raise ValueError(
"The op can be automatically generated only when ",
"all intermediate ops are not duplicable.",
)
o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs]
......@@ -188,14 +206,17 @@ def generate_layer_fn(op_type):
for each in val:
if not isinstance(each, Variable):
raise ValueError(
"input of {0} must be variable".format(op_type))
"input of {0} must be variable".format(op_type)
)
if dtype is None:
dtype = each.dtype
elif dtype != each.dtype:
raise ValueError(
"operator {0} must input same dtype. {1} vs {2}".format(
op_type, dtype, each.dtype))
op_type, dtype, each.dtype
)
)
if dtype is None:
arg_dtype = kwargs.get("dtype")
......@@ -227,8 +248,11 @@ def generate_layer_fn(op_type):
outputs = dict()
out = kwargs.pop(_convert_(o_name), [])
if out:
out_var = out[0] if (isinstance(out, list)
or isinstance(out, tuple)) else out
out_var = (
out[0]
if (isinstance(out, list) or isinstance(out, tuple))
else out
)
else:
out_var = helper.create_variable_for_type_inference(dtype=dtype)
outputs[o_name] = [out_var]
......@@ -236,10 +260,9 @@ def generate_layer_fn(op_type):
outputs[name] = [
helper.create_variable_for_type_inference(dtype=dtype)
]
helper.append_op(type=op_type,
inputs=inputs,
outputs=outputs,
attrs=kwargs)
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs
)
return helper.append_activation(out_var)
func.__name__ = op_type
......@@ -270,14 +293,26 @@ def generate_activation_fn(op_type):
return op(x)
if op_type not in ["abs", "exp", "square"]:
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
op_type)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'uint16'], op_type
)
else:
# abs exp square ops support dtype(int32, int64, float16, float32, float64)
check_variable_and_dtype(x, 'x', [
'int32', 'int64', 'float16', 'float32', 'float64', 'complex64',
'complex128'
], op_type)
check_variable_and_dtype(
x,
'x',
[
'int32',
'int64',
'float16',
'float32',
'float64',
'complex64',
'complex128',
'uint16',
],
op_type,
)
helper = LayerHelper(op_type, **locals())
......@@ -290,7 +325,8 @@ def generate_activation_fn(op_type):
op_proto,
additional_args_lines=[
"name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`."
])
],
)
return func
......@@ -310,24 +346,31 @@ def generate_inplace_fn(inplace_op_type):
op = getattr(_legacy_C_ops, inplace_op_type)
return op(x)
warnings.warn(
"In static mode, {}() is the same as {}() and does not perform inplace operation."
.format(inplace_op_type, origin_op_type))
"In static mode, {}() is the same as {}() and does not perform inplace operation.".format(
inplace_op_type, origin_op_type
)
)
return generate_activation_fn(origin_op_type)(x, name)
func.__name__ = inplace_op_type
func.__doc__ = """
Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_fluid_layers_{1}`.
""".format(origin_op_type, origin_op_type)
""".format(
origin_op_type, origin_op_type
)
return func
def autodoc(comment=""):
def __impl__(func):
func.__doc__ = _generate_doc_string_(
OpProtoHolder.instance().get_op_proto(func.__name__)) + comment
func.__doc__ = (
_generate_doc_string_(
OpProtoHolder.instance().get_op_proto(func.__name__)
)
+ comment
)
return func
return __impl__
......@@ -372,18 +415,21 @@ def templatedoc(op_type=None):
for each_input in op_proto.inputs:
input_name = _convert_(each_input.name)
args["{0}_comment".format(input_name)] = trim_ending_dot(
each_input.comment)
each_input.comment
)
args["{0}_type".format(input_name)] = "Variable"
for each_attr in op_proto.attrs:
input_name = _convert_(each_attr.name)
args["{0}_comment".format(input_name)] = trim_ending_dot(
each_attr.comment)
each_attr.comment
)
args["{0}_type".format(input_name)] = _type_to_str_(each_attr.type)
for each_opt in op_proto.outputs:
output_name = _convert_(each_opt.name)
args["{0}_comment".format(output_name)] = trim_ending_dot(
each_opt.comment)
each_opt.comment
)
args["{0}_type".format(output_name)] = "Variable"
func.__doc__ = tmpl.substitute(args)
return func
......@@ -393,7 +439,7 @@ def templatedoc(op_type=None):
def add_sample_code(func, sample_code):
"""
Append sample code for dynamically generated functions.
Append sample code for dynamically generated functions.
Args:
func: The function of the function to be append sample code to.
......
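The recurring pattern in these Python files is simply to append `'uint16'` to each operator's allowed-dtype list so bfloat16 variables pass the static-mode checks. A toy version of such a whitelist check (a hypothetical stand-in, not the real `check_variable_and_dtype`):

```python
# Hypothetical whitelist check, for illustration only.
ALLOWED = {
    'exp': ['float16', 'float32', 'float64', 'uint16'],  # 'uint16' stands for bfloat16
}

def check_dtype_for(op_type, dtype_str):
    allowed = ALLOWED.get(op_type, [])
    if dtype_str not in allowed:
        raise TypeError(
            "The data type of x in %s must be one of %s, but received %s."
            % (op_type, allowed, dtype_str))

check_dtype_for('exp', 'uint16')  # accepted after this change
```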
......@@ -5106,7 +5106,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
check_variable_and_dtype(
input,
'input',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'reduce_sum',
)
helper = LayerHelper('reduce_sum', **locals())
......
......@@ -21,14 +21,28 @@ import warnings
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..initializer import Initializer
from ..framework import _current_expected_place, convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode, _get_paddle_place
from ..framework import (
_current_expected_place,
convert_np_dtype_to_dtype_,
_non_static_mode,
_varbase_creator,
device_guard,
_in_legacy_dygraph,
in_dygraph_mode,
_get_paddle_place,
)
from ..framework import Variable
from ..initializer import Constant
from ..core import VarDesc
from .. import core
from .layer_function_generator import templatedoc
from . import utils
from ..data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from ..data_feeder import (
check_variable_and_dtype,
check_type,
check_dtype,
convert_dtype,
)
from paddle.utils import deprecated
from .utils import check_shape
......@@ -71,7 +85,7 @@ def create_tensor(dtype, name=None, persistable=False):
Args:
dtype(string|numpy.dtype): the data type of Tensor to be created, the
data type is bool, float16, float32, float64, int8, int16, int32 and int64.
name(string, optional): The default value is None. Normally there is no need for
name(string, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
persistable(bool): Set the persistable flag of the create tensor.
default value is False.
......@@ -85,24 +99,32 @@ def create_tensor(dtype, name=None, persistable=False):
import paddle.fluid as fluid
tensor = fluid.layers.create_tensor(dtype='float32')
"""
check_dtype(dtype, 'dtype', [
'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int32',
'int64'
], 'create_tensor')
check_dtype(
dtype,
'dtype',
[
'bool',
'float16',
'float32',
'float64',
'int8',
'int32',
'int32',
'int64',
],
'create_tensor',
)
helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(name=helper.name,
dtype=dtype,
persistable=persistable)
return helper.create_variable(
name=helper.name, dtype=dtype, persistable=persistable
)
def create_parameter(shape,
dtype,
name=None,
attr=None,
is_bias=False,
default_initializer=None):
def create_parameter(
shape, dtype, name=None, attr=None, is_bias=False, default_initializer=None
):
"""
:api_attr: Static Graph
:api_attr: Static Graph
This function creates a parameter. The parameter is a learnable variable, which can have
gradient, and can be optimized.
......@@ -134,31 +156,55 @@ def create_parameter(shape,
"""
check_type(shape, 'shape', (list, tuple, numpy.ndarray), 'create_parameter')
for item in shape:
check_type(item, 'item of shape',
(int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
numpy.int64), 'create_parameter')
check_dtype(dtype, 'dtype', [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8'
], 'create_parameter')
check_type(
item,
'item of shape',
(
int,
numpy.uint8,
numpy.int8,
numpy.int16,
numpy.int32,
numpy.int64,
),
'create_parameter',
)
check_dtype(
dtype,
'dtype',
[
'bool',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
],
'create_parameter',
)
check_type(attr, 'attr', (type(None), ParamAttr), 'create_parameter')
check_type(default_initializer, 'default_initializer',
(type(None), Initializer), 'create_parameter')
check_type(
default_initializer,
'default_initializer',
(type(None), Initializer),
'create_parameter',
)
helper = LayerHelper("create_parameter", **locals())
if attr is None:
attr = ParamAttr(name=name)
return helper.create_parameter(attr, shape, convert_dtype(dtype), is_bias,
default_initializer)
return helper.create_parameter(
attr, shape, convert_dtype(dtype), is_bias, default_initializer
)
def create_global_var(shape,
value,
dtype,
persistable=False,
force_cpu=False,
name=None):
def create_global_var(
shape, value, dtype, persistable=False, force_cpu=False, name=None
):
"""
This function creates a new tensor variable with value in the global block(block 0).
......@@ -185,35 +231,53 @@ def create_global_var(shape,
var = paddle.static.create_global_var(shape=[2,3], value=1.0, dtype='float32',
persistable=True, force_cpu=True, name='new_var')
"""
check_type(shape, 'shape', (list, tuple, numpy.ndarray),
'create_global_var')
check_type(
shape, 'shape', (list, tuple, numpy.ndarray), 'create_global_var'
)
for item in shape:
check_type(item, 'item of shape',
(int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32,
numpy.int64), 'create_global_var')
check_dtype(dtype, 'dtype', [
'bool',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
'uint16',
], 'create_global_var')
check_type(
item,
'item of shape',
(
int,
numpy.uint8,
numpy.int8,
numpy.int16,
numpy.int32,
numpy.int64,
),
'create_global_var',
)
check_dtype(
dtype,
'dtype',
[
'bool',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
'uint16',
],
'create_global_var',
)
helper = LayerHelper("global_var", **locals())
var = helper.create_global_variable(dtype=dtype,
shape=shape,
persistable=persistable,
name=name,
stop_gradient=True)
helper.set_variable_initializer(var,
initializer=Constant(value=float(value),
force_cpu=force_cpu))
var = helper.create_global_variable(
dtype=dtype,
shape=shape,
persistable=persistable,
name=name,
stop_gradient=True,
)
helper.set_variable_initializer(
var, initializer=Constant(value=float(value), force_cpu=force_cpu)
)
return var
......@@ -253,25 +317,50 @@ def cast(x, dtype):
out = _legacy_C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return out
check_variable_and_dtype(x, 'x', [
'bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64',
'uint8', 'uint16'
], 'cast')
check_dtype(dtype, 'dtype', [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32',
'int64', 'uint8', 'uint16'
], 'cast')
check_variable_and_dtype(
x,
'x',
[
'bool',
'float16',
'float32',
'float64',
'int16',
'int32',
'int64',
'uint8',
'uint16',
],
'cast',
)
check_dtype(
dtype,
'dtype',
[
'bool',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
'uint16',
],
'cast',
)
helper = LayerHelper('cast', **locals())
out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=x.stop_gradient)
helper.append_op(type='cast',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={
'in_dtype': x.dtype,
'out_dtype': out.dtype
})
dtype=dtype, stop_gradient=x.stop_gradient
)
helper.append_op(
type='cast',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={'in_dtype': x.dtype, 'out_dtype': out.dtype},
)
return out
......@@ -281,7 +370,7 @@ def concat(input, axis=0, name=None):
Args:
input(list|tuple|Tensor): ``input`` can be Tensor, Tensor list or Tensor tuple which is with data type
bool, float16, float32, float64, int32, int64. All the Tensors in ``input`` must have the same data type.
bool, float16, float32, float64, int32, int64. All the Tensors in ``input`` must have the same data type.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64.
The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way
......@@ -346,9 +435,11 @@ def concat(input, axis=0, name=None):
if not isinstance(input, Variable):
for id, x in enumerate(input):
check_variable_and_dtype(
x, 'input[' + str(id) + ']',
x,
'input[' + str(id) + ']',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'concat')
'concat',
)
if x.dtype != input[0].dtype:
raise TypeError(
"All the Tensors in the input must have the same data type."
......@@ -359,8 +450,11 @@ def concat(input, axis=0, name=None):
if isinstance(axis, Variable):
check_dtype(
axis.dtype, 'axis', ['int32', 'int64'], 'concat',
"The data type of axis must be int32 or int64 when axis is a Tensor"
axis.dtype,
'axis',
['int32', 'int64'],
'concat',
"The data type of axis must be int32 or int64 when axis is a Tensor",
)
helper = LayerHelper('concat', **locals())
......@@ -371,19 +465,17 @@ def concat(input, axis=0, name=None):
# This feature is supported for Dynamic-to-Static, because after transformed, the type of inputs[0]
# is LOD_TENSOR_ARRAY in some scenarios. And this feature can be used in static mode.
assert len(input) == 1, "If the elements of 'input' in concat are Variable(LoDTensorArray), " \
"number of the elements must be 1, but received %s." % len(input)
assert len(input) == 1, (
"If the elements of 'input' in concat are Variable(LoDTensorArray), "
"number of the elements must be 1, but received %s." % len(input)
)
out_index = helper.create_variable_for_type_inference(dtype="int32")
helper.append_op(type='tensor_array_to_tensor',
inputs={'X': input[0]},
outputs={
'Out': [out],
'OutIndex': [out_index]
},
attrs={
'axis': axis,
'use_stack': False
})
helper.append_op(
type='tensor_array_to_tensor',
inputs={'X': input[0]},
outputs={'Out': [out], 'OutIndex': [out_index]},
attrs={'axis': axis, 'use_stack': False},
)
else:
inputs = {'X': input}
attrs = {}
......@@ -391,10 +483,9 @@ def concat(input, axis=0, name=None):
axis.stop_gradient = True
attrs['axis'] = axis
helper.append_op(type='concat',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
helper.append_op(
type='concat', inputs=inputs, outputs={'Out': [out]}, attrs=attrs
)
return out
......@@ -480,33 +571,36 @@ def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False):
"""
if _non_static_mode():
assert isinstance(
input, list), "The 'input' in tensor_array_to_tensor must be list"
input, list
), "The 'input' in tensor_array_to_tensor must be list"
from .nn import stack, concat
from ..dygraph import to_variable
op = stack if use_stack else concat
res = op(input, axis=axis)
sizes = to_variable(
numpy.array(list(map(lambda x: int(x.shape[axis]), input))))
numpy.array(list(map(lambda x: int(x.shape[axis]), input)))
)
return res, sizes
check_type(input, 'input', (list, Variable), 'tensor_array_to_tensor')
if isinstance(input, list):
for i, input_x in enumerate(input):
check_type(input_x, 'input[' + str(i) + ']', Variable,
'tensor_array_to_tensor')
check_type(
input_x,
'input[' + str(i) + ']',
Variable,
'tensor_array_to_tensor',
)
helper = LayerHelper('tensor_array_to_tensor', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
out_index = helper.create_variable_for_type_inference(dtype="int32")
helper.append_op(type='tensor_array_to_tensor',
inputs={'X': input},
outputs={
'Out': [out],
'OutIndex': [out_index]
},
attrs={
'axis': axis,
'use_stack': use_stack
})
helper.append_op(
type='tensor_array_to_tensor',
inputs={'X': input},
outputs={'Out': [out], 'OutIndex': [out_index]},
attrs={'axis': axis, 'use_stack': use_stack},
)
return out, out_index
......@@ -563,25 +657,36 @@ def sums(input, out=None):
check_type(input, 'input', (Variable, tuple, list), 'sums')
if isinstance(input, list) or isinstance(input, tuple):
for input_section in input:
check_variable_and_dtype(input_section, "input", \
['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
check_variable_and_dtype(
input_section,
"input",
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'sums',
)
else:
check_variable_and_dtype(input, "input", \
['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
check_variable_and_dtype(
input,
"input",
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'sums',
)
helper = LayerHelper('sum', **locals())
if out is None:
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
dtype=helper.input_dtype()
)
else:
check_variable_and_dtype(out, "out",
['float32', 'float64', 'int32', 'int64'],
'sums')
helper.append_op(type='sum',
inputs={'X': input},
outputs={'Out': out},
attrs={'use_mkldnn': False})
check_variable_and_dtype(
out, "out", ['float32', 'float64', 'int32', 'int64'], 'sums'
)
helper.append_op(
type='sum',
inputs={'X': input},
outputs={'Out': out},
attrs={'use_mkldnn': False},
)
return out
......@@ -616,9 +721,12 @@ def assign(input, output=None):
result3 = paddle.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
"""
helper = LayerHelper('assign', **locals())
check_type(input, 'input',
(Variable, numpy.ndarray, list, tuple, float, int, bool),
'assign')
check_type(
input,
'input',
(Variable, numpy.ndarray, list, tuple, float, int, bool),
'assign',
)
is_inplace = True if output is not None else False
if numpy.isscalar(input) and not isinstance(input, str):
......@@ -644,16 +752,29 @@ def assign(input, output=None):
output = core.eager.Tensor()
_legacy_C_ops.assign(input, output)
else:
check_dtype(input.dtype, 'input', [
'float16', 'uint16', 'float32', 'float64', 'int32', 'int64',
'uint8', 'bool'
], 'assign', '(When the type of input in assign is Variable.)')
check_dtype(
input.dtype,
'input',
[
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'uint8',
'bool',
],
'assign',
'(When the type of input in assign is Variable.)',
)
if output is None:
output = helper.create_variable_for_type_inference(
dtype=input.dtype)
helper.append_op(type='assign',
inputs={'X': [input]},
outputs={'Out': [output]})
dtype=input.dtype
)
helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]}
)
elif isinstance(input, numpy.ndarray):
# Not support [var, var, ...] currently.
if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input):
......@@ -667,7 +788,8 @@ def assign(input, output=None):
warnings.warn(
"paddle.assign doesn't support float64 input now due "
"to current platform protobuf data limitation, we convert "
"it to float32")
"it to float32"
)
dtype = VarDesc.VarType.FP32
if dtype == VarDesc.VarType.BOOL:
value_name = "bool_values"
......@@ -685,31 +807,49 @@ def assign(input, output=None):
raise TypeError(
"When the type of 'input' in assign is numpy.ndarray, "
"the data type of 'input' must be bool, float32, int32 or int64, but "
"received %s." % convert_dtype(dtype))
"received %s." % convert_dtype(dtype)
)
if input.size > 1024 * 1024:
raise ValueError("The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it")
raise ValueError(
"The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it"
)
if in_dygraph_mode():
if output is None:
output = zeros(list(input.shape), dtype)
_C_ops.assign_value_(output, list(input.shape), dtype, values,
_current_expected_place())
_C_ops.assign_value_(
output,
list(input.shape),
dtype,
values,
_current_expected_place(),
)
elif _in_legacy_dygraph():
if output is None:
output = core.VarBase()
_legacy_C_ops.assign_value(output, 'shape', list(input.shape),
'dtype', dtype, value_name, values)
_legacy_C_ops.assign_value(
output,
'shape',
list(input.shape),
'dtype',
dtype,
value_name,
values,
)
else:
if output is None:
output = helper.create_variable_for_type_inference(
dtype=input.dtype)
helper.append_op(type='assign_value',
outputs={'Out': [output]},
attrs={
'dtype': dtype,
'shape': list(input.shape),
value_name: values
})
dtype=input.dtype
)
helper.append_op(
type='assign_value',
outputs={'Out': [output]},
attrs={
'dtype': dtype,
'shape': list(input.shape),
value_name: values,
},
)
if is_inplace and _non_static_mode():
output._bump_inplace_version()
......@@ -731,10 +871,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64.
dtype(np.dtype|str): Data type of the output Tensor which can
be float16, float32, float64, uint8, int16, int32, int64.
value(bool|float|int|Tensor): The constant value used to initialize
value(bool|float|int|Tensor): The constant value used to initialize
the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor.
force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
out(Tensor, optional): Optional output which can be any created
out(Tensor, optional): Optional output which can be any created
Tensor that meets the requirements to store the result of operation.
if ``out`` is None, a new Tensor will be create to store the result.
name(str, optional): The default value is None. Normally there is no need for user to set this
......@@ -759,7 +899,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
# attr shape is a Tensor.
shape = fluid.layers.fill_constant([2], "int32", 2) # shape=[2,2]
data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]]
# attr value is a Tensor.
val = fluid.layers.fill_constant([1], "float32", 2.0) # val=[2.0]
data5 = fluid.layers.fill_constant(shape=[2,1], value=val, dtype='float32') #data5=[[2.0],[2.0]]
......@@ -785,7 +925,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
shape = list(
map(
lambda x: x.numpy().flat[0]
if isinstance(x, Variable) else x, shape))
if isinstance(x, Variable)
else x,
shape,
)
)
break
if not isinstance(dtype, core.VarDesc.VarType):
......@@ -813,9 +957,19 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
else:
attrs['str_value'] = str(float(value.numpy().item(0)))
_legacy_C_ops.fill_constant(out, 'value', float(value), 'force_cpu',
force_cpu, 'dtype', out.dtype, 'str_value',
attrs['str_value'], 'shape', shape)
_legacy_C_ops.fill_constant(
out,
'value',
float(value),
'force_cpu',
force_cpu,
'dtype',
out.dtype,
'str_value',
attrs['str_value'],
'shape',
shape,
)
out.stop_gradient = True
return out
......@@ -827,43 +981,61 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
inputs['ValueTensor'] = value
check_shape(shape)
check_dtype(dtype, 'dtype', [
'bool', 'float16', 'float32', 'float64', 'uint8', 'int16', 'int32',
'int64', 'complex64', 'complex128'
], 'fill_constant')
check_dtype(
dtype,
'dtype',
[
'bool',
'float16',
'float32',
'float64',
'uint8',
'int16',
'int32',
'int64',
'complex64',
'complex128',
'uint16',
],
'fill_constant',
)
check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
if out is not None:
check_variable_and_dtype(out, 'out', [convert_dtype(dtype)],
'fill_constant')
check_variable_and_dtype(
out, 'out', [convert_dtype(dtype)], 'fill_constant'
)
helper = LayerHelper("fill_constant", **locals())
utils.get_shape_tensor_inputs(inputs=inputs,
attrs=attrs,
shape=shape,
op_type='fill_constant')
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant'
)
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
attrs['dtype'] = out.dtype
helper.append_op(type='fill_constant',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs,
stop_gradient=True)
helper.append_op(
type='fill_constant',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs,
stop_gradient=True,
)
out.stop_gradient = True
return out
@deprecated(since='1.8.0', update_to="paddle.fluid.layers.fill_constant")
@templatedoc()
def fill_constant_batch_size_like(input,
shape,
dtype,
value,
input_dim_idx=0,
output_dim_idx=0,
force_cpu=False):
def fill_constant_batch_size_like(
input,
shape,
dtype,
value,
input_dim_idx=0,
output_dim_idx=0,
force_cpu=False,
):
"""
This OP creates a Tesnor according the shape and dtype, and initializes the
Tensor with the constants provided in ``value``. When the input is LoDTensor
......@@ -877,7 +1049,7 @@ def fill_constant_batch_size_like(input,
according the input.
dtype(np.dtype|core.VarDesc.VarType|str): The data type of created Tensor which
can be float32, float64, int32, int64.
value(float|int): The constant value used to initialize the Tensor to be created.
value(float|int): The constant value used to initialize the Tensor to be created.
input_dim_idx(int): When the value is 0 and the input is LoDTensor, the output_dim_idx
dimension of the created Tensor is set to the batch_size value of input.
The default value is 0.
......@@ -905,8 +1077,9 @@ def fill_constant_batch_size_like(input,
place = _current_expected_place()
if force_cpu:
place = core.CPUPlace()
out = _C_ops.full_batch_size_like(input, shape, dtype, value,
input_dim_idx, output_dim_idx, place)
out = _C_ops.full_batch_size_like(
input, shape, dtype, value, input_dim_idx, output_dim_idx, place
)
out.stop_gradient = True
return out
......@@ -918,25 +1091,27 @@ def fill_constant_batch_size_like(input,
'value': float(value),
'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx,
'force_cpu': force_cpu
'force_cpu': force_cpu,
}
if convert_dtype(dtype) in ['int64', 'int32']:
attrs['str_value'] = str(int(value))
else:
attrs['str_value'] = str(float(value))
helper.append_op(type='fill_constant_batch_size_like',
inputs={'Input': input},
outputs={'Out': [out]},
attrs=attrs)
helper.append_op(
type='fill_constant_batch_size_like',
inputs={'Input': input},
outputs={'Out': [out]},
attrs=attrs,
)
out.stop_gradient = True
return out
def argmin(x, axis=0):
"""
:alias_main: paddle.argmin
:alias: paddle.argmin,paddle.tensor.argmin,paddle.tensor.search.argmin
:old_api: paddle.fluid.layers.argmin
:alias_main: paddle.argmin
:alias: paddle.argmin,paddle.tensor.argmin,paddle.tensor.search.argmin
:old_api: paddle.fluid.layers.argmin
**argmin**
......@@ -986,14 +1161,19 @@ def argmin(x, axis=0):
# [1 0 2]]
"""
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'],
'argmin')
x,
'x',
['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'],
'argmin',
)
helper = LayerHelper("arg_min", **locals())
out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op(type='arg_min',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
helper.append_op(
type='arg_min',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis},
)
out.stop_gradient = True
return out
......@@ -1048,23 +1228,28 @@ def argmax(x, axis=0):
# [0 3 1]]
"""
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'],
'argmax')
x,
'x',
['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'],
'argmax',
)
helper = LayerHelper("arg_max", **locals())
out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op(type='arg_max',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
helper.append_op(
type='arg_max',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis},
)
out.stop_gradient = True
return out
def argsort(input, axis=-1, descending=False, name=None):
"""
:alias_main: paddle.argsort
:alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort
:old_api: paddle.fluid.layers.argsort
:alias_main: paddle.argsort
:alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort
:old_api: paddle.fluid.layers.argsort
This OP sorts the input along the given axis, and returns sorted output
data Varibale and its corresponding index Variable with the same shape as
......@@ -1135,23 +1320,24 @@ def argsort(input, axis=-1, descending=False, name=None):
# [5. 7. 7. 9.]]]
"""
check_variable_and_dtype(
input, 'input',
['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'argsort')
input,
'input',
['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
'argsort',
)
helper = LayerHelper("argsort", **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype,
stop_gradient=True)
ids = helper.create_variable_for_type_inference(VarDesc.VarType.INT64,
stop_gradient=True)
helper.append_op(type='argsort',
inputs={'X': input},
outputs={
'Out': out,
'Indices': ids
},
attrs={
'axis': axis,
'descending': descending
})
out = helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True
)
ids = helper.create_variable_for_type_inference(
VarDesc.VarType.INT64, stop_gradient=True
)
helper.append_op(
type='argsort',
inputs={'X': input},
outputs={'Out': out, 'Indices': ids},
attrs={'axis': axis, 'descending': descending},
)
return out, ids
......@@ -1176,7 +1362,7 @@ def ones(shape, dtype, force_cpu=False):
import paddle.fluid as fluid
data0 = fluid.layers.ones(shape=[2, 4], dtype='float32') # [[1., 1., 1., 1.], [1., 1., 1., 1.]]
# shape is a Tensor
shape = fluid.layers.fill_constant(shape=[2], dtype='int32', value=2)
data1 = fluid.layers.ones(shape=shape, dtype='int32') #[[1, 1], [1, 1]]
......@@ -1207,7 +1393,7 @@ def zeros(shape, dtype, force_cpu=False, name=None):
import paddle.fluid as fluid
data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]]
# shape is a Tensor
shape = fluid.layers.fill_constant(shape=[2], dtype='int32', value=2)
data1 = fluid.layers.zeros(shape=shape, dtype='int32') #[[0, 0], [0, 0]]
......@@ -1217,9 +1403,9 @@ def zeros(shape, dtype, force_cpu=False, name=None):
def reverse(x, axis):
"""
:alias_main: paddle.reverse
:alias: paddle.reverse,paddle.tensor.reverse,paddle.tensor.manipulation.reverse
:old_api: paddle.fluid.layers.reverse
:alias_main: paddle.reverse
:alias: paddle.reverse,paddle.tensor.reverse,paddle.tensor.manipulation.reverse
:old_api: paddle.fluid.layers.reverse
The OP reverses the tensor :attr:`x` along the given :attr:`axis`.
......@@ -1277,9 +1463,9 @@ def reverse(x, axis):
reversed_tensor_array = fluid.layers.reverse(tensor_array, 0) # {[[3, 4, 5]], [[0, 1, 2]]}
"""
check_variable_and_dtype(x, 'x',
('float32', 'float64', 'int32', 'int64', 'uint8'),
'reverse')
check_variable_and_dtype(
x, 'x', ('float32', 'float64', 'int32', 'int64', 'uint8'), 'reverse'
)
check_type(axis, 'axis', (int, tuple, list, Variable), 'reverse')
if isinstance(axis, int):
axis = [axis]
......@@ -1287,10 +1473,12 @@ def reverse(x, axis):
return _C_ops.reverse(x, axis)
helper = LayerHelper("reverse", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='reverse',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
helper.append_op(
type='reverse',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis},
)
return out
......@@ -1306,13 +1494,12 @@ def save(x, file_path, overwrite=True):
error will be thrown.
"""
helper = LayerHelper("save", **locals())
helper.append_op(type="save",
inputs={"input": x},
outputs={},
args={
"file_path": file_path,
"overwrite": overwrite
})
helper.append_op(
type="save",
inputs={"input": x},
outputs={},
args={"file_path": file_path, "overwrite": overwrite},
)
def save_combine(x, file_path, overwrite=True):
......@@ -1344,13 +1531,12 @@ def save_combine(x, file_path, overwrite=True):
normed = fluid.layers.save_combine([v1, v2], file_path="output")
"""
helper = LayerHelper("save_combine", **locals())
helper.append_op(type="save_combine",
inputs={"input": x},
outputs={},
args={
"file_path": file_path,
"overwrite": overwrite
})
helper.append_op(
type="save_combine",
inputs={"input": x},
outputs={},
args={"file_path": file_path, "overwrite": overwrite},
)
def load_combine(out, file_path):
......@@ -1362,10 +1548,12 @@ def load_combine(out, file_path):
file_path(str): The path of the disk file.
"""
helper = LayerHelper("load_combine", **locals())
helper.append_op(type="load_combine",
inputs={},
output={"Out": out},
args={"file_path": file_path})
helper.append_op(
type="load_combine",
inputs={},
output={"Out": out},
args={"file_path": file_path},
)
def has_inf(x):
......@@ -1377,10 +1565,10 @@ def has_inf(x):
Returns:
Tensor: The tensor storing the output, only a bool value, indicating that whether there is infinity number in x or not.
Examples:
.. code-block:: python
import paddle
data = paddle.randn(shape=[4, 32, 32], dtype="float32")
res = paddle.fluid.layers.has_inf(data)
......@@ -1406,10 +1594,10 @@ def has_nan(x):
Returns:
Tensor: The tensor variable storing the output, only a bool value, indicating that whether there is NAN in x or not.
Examples:
.. code-block:: python
import paddle
data = paddle.randn(shape=[2,3], dtype="float32")
res = paddle.fluid.layers.has_nan(data)
......@@ -1449,8 +1637,9 @@ def isfinite(x):
print(y)
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"isfinite")
check_variable_and_dtype(
x, "x", ["float32", "float64", "int32", "int64"], "isfinite"
)
helper = LayerHelper("isfinite", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
......@@ -1485,7 +1674,7 @@ def range(start, end, step, dtype, name=None):
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Returns:
Tensor: A 1-D Tensor with values from the interval [``start``, ``end``)
taken with common difference ``step`` beginning from ``start``. Its
data type is set by ``dtype``.
......@@ -1508,8 +1697,11 @@ def range(start, end, step, dtype, name=None):
"""
out_shape = None
if not isinstance(start, Variable) and not isinstance(
end, Variable) and not isinstance(step, Variable):
if (
not isinstance(start, Variable)
and not isinstance(end, Variable)
and not isinstance(step, Variable)
):
out_shape = [int(math.ceil((end - start) / step))]
if not isinstance(dtype, core.VarDesc.VarType):
......@@ -1541,17 +1733,16 @@ def range(start, end, step, dtype, name=None):
out.stop_gradient = True
return out
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'],
'range/arange')
check_dtype(
dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'range/arange'
)
helper = LayerHelper('range', **locals())
out = helper.create_variable_for_type_inference(dtype, shape=out_shape)
helper.append_op(type='range',
inputs={
'Start': start,
'End': end,
'Step': step
},
outputs={'Out': out})
helper.append_op(
type='range',
inputs={'Start': start, 'End': end, 'Step': step},
outputs={'Out': out},
)
out.stop_gradient = True
if out_shape is not None:
out.desc.set_shape(out_shape)
......@@ -1606,58 +1797,76 @@ def linspace(start, stop, num, dtype=None, name=None):
with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num)
if in_dygraph_mode():
return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, dtype,
_current_expected_place())
return _C_ops.linspace(
tensor_start,
tensor_stop,
tensor_num,
dtype,
_current_expected_place(),
)
if _in_legacy_dygraph():
return _legacy_C_ops.linspace(tensor_start, tensor_stop, tensor_num,
'dtype', dtype)
return _legacy_C_ops.linspace(
tensor_start, tensor_stop, tensor_num, 'dtype', dtype
)
helper = LayerHelper("linspace", **locals())
start_dtype = convert_dtype(tensor_start.dtype)
stop_dtype = convert_dtype(tensor_stop.dtype)
out_dtype = convert_dtype(dtype)
if isinstance(start, Variable):
check_dtype(start.dtype, 'start',
['float32', 'float64', 'int32', 'int64'], 'linspace')
check_dtype(
start.dtype,
'start',
['float32', 'float64', 'int32', 'int64'],
'linspace',
)
else:
check_type(start, 'start', (int, float), 'linspace')
if isinstance(stop, Variable):
check_dtype(stop.dtype, 'stop',
['float32', 'float64', 'int32', 'int64'], 'linspace')
check_dtype(
stop.dtype,
'stop',
['float32', 'float64', 'int32', 'int64'],
'linspace',
)
else:
check_type(stop, 'stop', (int, float), 'linspace')
if isinstance(num, Variable):
check_dtype(num.dtype, 'num', ['int32'], 'linspace')
check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'],
'linspace')
if ((stop_dtype == "float64" or start_dtype == "float64")
and out_dtype in ["float32", "int32"]) or (
(stop_dtype == "int64" or start_dtype == "int64")
and out_dtype == "int32"):
check_dtype(
dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], 'linspace'
)
if (
(stop_dtype == "float64" or start_dtype == "float64")
and out_dtype in ["float32", "int32"]
) or (
(stop_dtype == "int64" or start_dtype == "int64")
and out_dtype == "int32"
):
raise ValueError(
"The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, "
"which may cause data type overflows. Please reset attr(dtype) of linspace."
.format(start_dtype, stop_dtype, dtype))
"which may cause data type overflows. Please reset attr(dtype) of linspace.".format(
start_dtype, stop_dtype, dtype
)
)
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(type='linspace',
inputs={
'Start': tensor_start,
'Stop': tensor_stop,
'Num': tensor_num
},
attrs={'dtype': dtype},
outputs={'Out': [out]})
helper.append_op(
type='linspace',
inputs={'Start': tensor_start, 'Stop': tensor_stop, 'Num': tensor_num},
attrs={'dtype': dtype},
outputs={'Out': [out]},
)
if isinstance(num, int):
out.desc.set_shape((num, ))
out.desc.set_shape((num,))
return out
def zeros_like(x, out=None):
"""
This OP creates a zeros tensor which has identical shape and dtype
This OP creates a zeros tensor which has identical shape and dtype
with `x`.
Args:
......@@ -1681,23 +1890,25 @@ def zeros_like(x, out=None):
data = fluid.layers.zeros_like(x) # [0.0, 0.0, 0.0]
"""
check_variable_and_dtype(x, "x",
['bool', 'float32', 'float64', 'int32', 'int64'],
'zeros_like')
check_variable_and_dtype(
x, "x", ['bool', 'float32', 'float64', 'int32', 'int64'], 'zeros_like'
)
helper = LayerHelper("zeros_like", **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'],
'zeros_like')
helper.append_op(type='fill_any_like',
inputs={'X': [x]},
attrs={
'value': 0,
"dtype": x.dtype
},
outputs={'Out': [out]})
out,
"out",
['bool', 'float32', 'float64', 'int32', 'int64'],
'zeros_like',
)
helper.append_op(
type='fill_any_like',
inputs={'X': [x]},
attrs={'value': 0, "dtype": x.dtype},
outputs={'Out': [out]},
)
out.stop_gradient = True
return out
......@@ -1734,8 +1945,12 @@ def diag(diagonal):
"""
check_type(diagonal, 'diagonal', (Variable, numpy.ndarray), 'diag')
check_dtype(diagonal.dtype, 'diagonal',
['float32', 'float64', 'int32', 'int64'], 'diag')
check_dtype(
diagonal.dtype,
'diagonal',
['float32', 'float64', 'int32', 'int64'],
'diag',
)
helper = LayerHelper("diag", **locals())
if not isinstance(diagonal, Variable):
......@@ -1743,21 +1958,19 @@ def diag(diagonal):
out = helper.create_variable_for_type_inference(dtype=diagonal.dtype)
helper.append_op(type='diag',
inputs={'Diagonal': [diagonal]},
outputs={'Out': [out]})
helper.append_op(
type='diag', inputs={'Diagonal': [diagonal]}, outputs={'Out': [out]}
)
out.stop_gradient = True
return out
def eye(num_rows,
num_columns=None,
batch_shape=None,
dtype='float32',
name=None):
def eye(
num_rows, num_columns=None, batch_shape=None, dtype='float32', name=None
):
"""
This function constructs a or a batch of 2-D tensor with ones on the diagonal and zeros elsewhere.
This function constructs a or a batch of 2-D tensor with ones on the diagonal and zeros elsewhere.
Args:
num_rows(int): the number of rows in each batch tensor.
......@@ -1808,25 +2021,33 @@ def eye(num_rows,
num_columns = num_rows
if in_dygraph_mode():
out = _C_ops.eye(num_rows, num_columns, dtype,
_current_expected_place())
out = _C_ops.eye(
num_rows, num_columns, dtype, _current_expected_place()
)
elif _in_legacy_dygraph():
out = _legacy_C_ops.eye('dtype', dtype, 'num_rows', num_rows,
'num_columns', num_columns)
out = _legacy_C_ops.eye(
'dtype', dtype, 'num_rows', num_rows, 'num_columns', num_columns
)
else:
helper = LayerHelper("eye", **locals())
check_dtype(dtype, 'dtype',
['float16', 'float32', 'float64', 'int32', 'int64'], 'eye')
check_dtype(
dtype,
'dtype',
['float16', 'float32', 'float64', 'int32', 'int64'],
'eye',
)
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(type='eye',
inputs={},
outputs={'Out': [out]},
attrs={
'num_rows': num_rows,
'num_columns': num_columns,
'dtype': dtype
},
stop_gradient=True)
helper.append_op(
type='eye',
inputs={},
outputs={'Out': [out]},
attrs={
'num_rows': num_rows,
'num_columns': num_columns,
'dtype': dtype,
},
stop_gradient=True,
)
if batch_shape is not None:
re_shape = [1] * len(batch_shape)
......@@ -1838,11 +2059,12 @@ def eye(num_rows,
if not isinstance(batch_shape, list):
raise TypeError("batch_shape should be a list")
for batch_val in (batch_shape):
for batch_val in batch_shape:
if batch_val <= 0:
raise TypeError("batch_shape should be a positive int list")
from .nn import reshape, expand
out = reshape(x=out, shape=re_shape)
out = expand(x=out, expand_times=expand_times)
......@@ -1854,7 +2076,7 @@ def ones_like(x, out=None):
"""
**ones_like**
This function creates a ones tensor which has identical shape and dtype
This function creates a ones tensor which has identical shape and dtype
with `x`.
Args:
......@@ -1873,25 +2095,31 @@ def ones_like(x, out=None):
data = fluid.layers.ones_like(x) # [1.0, 1.0, 1.0]
"""
check_variable_and_dtype(x, "x",
['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like')
check_variable_and_dtype(
x, "x", ['bool', 'float32', 'float64', 'int32', 'int64'], 'ones_like'
)
helper = LayerHelper("ones_like", **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like')
helper.append_op(type='fill_any_like',
inputs={'X': [x]},
attrs={'value': 1.0},
outputs={'Out': [out]})
out,
"out",
['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like',
)
helper.append_op(
type='fill_any_like',
inputs={'X': [x]},
attrs={'value': 1.0},
outputs={'Out': [out]},
)
return out
@deprecated(since="2.0.0", update_to="paddle.triu")
def triu(input, diagonal=0, name=None):
import paddle
return paddle.tensor.triu(x=input, diagonal=diagonal, name=name)
......@@ -411,7 +411,7 @@ def layer_norm(
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
check_variable_and_dtype(
x, 'input', ['float16', 'float32', 'float64'], 'LayerNorm'
x, 'input', ['float16', 'float32', 'float64', 'uint16'], 'LayerNorm'
)
inputs = dict()
......
......@@ -482,10 +482,10 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(original_op_type)
assert y is not None, 'y cannot be None in {}'.format(original_op_type)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'],
original_op_type)
check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'],
original_op_type)
axis = helper.kwargs.get('axis', -1)
......@@ -1537,10 +1537,10 @@ def add_n(inputs, name=None):
if len(inputs) > 0:
for input in inputs:
check_variable_and_dtype(input, "inputs", \
['float16', 'float32', 'float64', 'int32', 'int64'], 'add_n')
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n')
else:
check_variable_and_dtype(inputs, "inputs", \
['float16', 'float32', 'float64', 'int32', 'int64'], 'add_n')
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n')
out = helper.create_variable_for_type_inference(
......