Unverified · Commit ba9a22db · Authored by Yiqun Liu · Committed by GitHub

Add bfloat16 support for several operators and apis. (#52696)

* Cherry-pick the register of bfloat16 for amp_kernel, pull request #45541.

* Cherry-pick the master_grad support of adamw, pull request #51141.

* add bf16 for some ops in static mode (#51582)

* Add bfloat16 support for some api in static mode.

* Fix codestyle.

* Revert the change of layer_function_generator.py.

---------
Co-authored-by: Shaojie WANG <wsjmessi@163.com>
Parent 95c3d613
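Usage sketch (not part of this diff): a minimal dynamic-graph bf16 training loop that exercises the kernels registered below. It assumes this branch exposes dtype='bfloat16' on paddle.amp.auto_cast/paddle.amp.decorate and multi_precision on paddle.optimizer.AdamW.

    import paddle

    model = paddle.nn.Linear(16, 16)
    opt = paddle.optimizer.AdamW(
        learning_rate=1e-3,
        parameters=model.parameters(),
        multi_precision=True,  # keep fp32 master weights next to bf16 params
    )
    model, opt = paddle.amp.decorate(
        models=model, optimizers=opt, level='O2', dtype='bfloat16'
    )

    x = paddle.randn([4, 16])
    with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
        loss = model(x).mean()   # forward runs in bf16 where kernels are registered
    loss.backward()              # grads arrive as bf16, or fp32 when master grads are used
    opt.step()                   # adamw kernel updates the fp32 master weights
    opt.clear_grad()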
@@ -29,7 +29,7 @@
 #include "paddle/phi/kernels/funcs/for_range.h"
 namespace phi {
 
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelREG(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -42,7 +42,7 @@ __global__ void AdamWKernelREG(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -78,7 +78,7 @@ __global__ void AdamWKernelREG(MT beta1,
   }
 }
 
-template <typename T, typename MT>
+template <typename T, typename TG, typename MT>
 __global__ void AdamWKernelMEM(MT beta1,
                                MT beta2,
                                MT epsilon,
@@ -91,7 +91,7 @@ __global__ void AdamWKernelMEM(MT beta1,
                                const MT* moment2,
                                MT* moment2_out,
                                const MT* lr_,
-                               const T* grad,
+                               const TG* grad,
                                const T* param,
                                T* param_out,
                                const MT* master_param,
@@ -167,6 +167,8 @@ void AdamwDenseKernel(const Context& dev_ctx,
                       DenseTensor* master_param_outs) {
   using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  const auto grad_type = grad.dtype();
+
   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
 
   MPDType coeff_ = static_cast<MPDType>(coeff);
@@ -191,8 +193,10 @@ void AdamwDenseKernel(const Context& dev_ctx,
     phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
     phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
     phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
-    phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
-    phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    if (!use_global_beta_pow) {
+      phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
+      phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    }
     return;
   }
@@ -233,25 +237,49 @@ void AdamwDenseKernel(const Context& dev_ctx,
   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
     // Compute with betapow in REG
-    AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        *beta1_pow.data<MPDType>(),
-        *beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelREG<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              *beta1_pow.data<MPDType>(),
+              *beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelREG<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          *beta1_pow.data<MPDType>(),
+          *beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Cpu update
       dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
@@ -260,28 +288,50 @@ void AdamwDenseKernel(const Context& dev_ctx,
           beta2_ * beta2_pow.data<MPDType>()[0];
     }
   } else {
-    AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-        beta1_,
-        beta2_,
-        epsilon_,
-        coeff_,
-        lr_ratio_,
-        beta1_pow.data<MPDType>(),
-        beta2_pow.data<MPDType>(),
-        moment1.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment1_out),
-        moment2.data<MPDType>(),
-        dev_ctx.template Alloc<MPDType>(moment2_out),
-        learning_rate.data<MPDType>(),
-        grad.data<T>(),
-        param.data<T>(),
-        dev_ctx.template Alloc<T>(param_out),
-        master_in_data,
-        master_out_data,
-        param.numel());
+    if (grad_type == phi::DataType::FLOAT32)
+      AdamWKernelMEM<T, float, MPDType>
+          <<<blocks, threads, 0, dev_ctx.stream()>>>(
+              beta1_,
+              beta2_,
+              epsilon_,
+              coeff_,
+              lr_ratio_,
+              beta1_pow.data<MPDType>(),
+              beta2_pow.data<MPDType>(),
+              moment1.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment1_out),
+              moment2.data<MPDType>(),
+              dev_ctx.template Alloc<MPDType>(moment2_out),
+              learning_rate.data<MPDType>(),
+              grad.data<float>(),
+              param.data<T>(),
+              dev_ctx.template Alloc<T>(param_out),
+              master_in_data,
+              master_out_data,
+              param.numel());
+    else
+      AdamWKernelMEM<T, T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow.data<MPDType>(),
+          beta2_pow.data<MPDType>(),
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<T>(),
+          param.data<T>(),
+          dev_ctx.template Alloc<T>(param_out),
+          master_in_data,
+          master_out_data,
+          param.numel());
     if (!use_global_beta_pow) {
       // Update with gpu
-      UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
+      UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
           beta1_,
           beta2_,
           beta1_pow.data<MPDType>(),
@@ -300,9 +350,21 @@ PD_REGISTER_KERNEL(adamw,
                    phi::AdamwDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
+
+  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
+      kernel_key.dtype() == phi::DataType::BFLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+  }
+  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
 }
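Note on the TG template parameter added above: with bf16 parameters the incoming gradient can be either bf16 (TG == T) or fp32 when a master gradient is used (TG == float), so the kernel is now dispatched on grad.dtype(), and the registration keeps moments, beta pows, and master weights in fp32. The sketch below is a plain-Python reference of the standard decoupled-weight-decay (AdamW) step that the kernel evaluates in the compute type MT; it illustrates the formula and the dtype layout, it is not a copy of the kernel body.

    def adamw_step_ref(p32, g, m, v, lr, beta1, beta2, eps, coeff,
                       beta1_pow, beta2_pow, lr_ratio=1.0):
        """Reference AdamW update in fp32, mirroring the MT math path."""
        g32 = float(g)                    # grad may arrive as bf16 (TG=T) or fp32 (TG=float)
        lr_t = lr * lr_ratio
        p32 *= 1.0 - lr_t * coeff         # decoupled weight decay on the master weight
        m = beta1 * m + (1.0 - beta1) * g32
        v = beta2 * v + (1.0 - beta2) * g32 * g32
        denom = (v ** 0.5) / ((1.0 - beta2_pow) ** 0.5) + eps
        p32 -= (lr_t / (1.0 - beta1_pow)) * (m / denom)
        return p32, m, v                  # param_out is p32 cast back to T (bf16)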
@@ -357,7 +357,8 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
                    phi::CheckFiniteAndUnscaleKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(update_loss_scaling,
                    GPU,
@@ -365,6 +366,7 @@ PD_REGISTER_KERNEL(update_loss_scaling,
                    phi::UpdateLossScalingKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
 }
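The two registrations above back paddle.amp.GradScaler's dynamic loss scaling. bf16 normally has enough range to skip scaling, but registering them lets existing scaler-based recipes run unchanged under bf16. A minimal sketch, assuming paddle.amp.GradScaler and auto_cast(dtype='bfloat16') on this branch:

    import paddle

    model = paddle.nn.Linear(16, 1)
    opt = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=2.0 ** 16)

    x = paddle.randn([8, 16])
    with paddle.amp.auto_cast(dtype='bfloat16'):
        loss = model(x).mean()
    scaler.scale(loss).backward()  # check_finite_and_unscale / update_loss_scaling
    scaler.step(opt)               # now have bf16 registrations behind this path
    scaler.update()
    opt.clear_grad()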
@@ -55,6 +55,7 @@ PD_REGISTER_KERNEL(matmul_with_flatten_grad,
                    phi::MatmulWithFlattenGradKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
 
 PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
......
@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(matmul_with_flatten,
                    phi::MatmulWithFlattenKernel,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
@@ -20,15 +20,26 @@ import string
 from six.moves import cStringIO
 from ..proto import framework_pb2
-from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
+from ..framework import (
+    OpProtoHolder,
+    Variable,
+    core,
+    convert_np_dtype_to_dtype_,
+    _non_static_mode,
+    in_dygraph_mode,
+    _in_legacy_dygraph,
+)
 from ..layer_helper import LayerHelper
 from ..data_feeder import check_variable_and_dtype
 from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
 from paddle import _C_ops, _legacy_C_ops
 
 __all__ = [
-    'generate_layer_fn', 'generate_activation_fn', 'generate_inplace_fn',
-    'autodoc', 'templatedoc'
+    'generate_layer_fn',
+    'generate_activation_fn',
+    'generate_inplace_fn',
+    'autodoc',
+    'templatedoc',
 ]
...@@ -58,16 +69,16 @@ _two_bang_pattern_ = re.compile(r"!!([^!]+)!!") ...@@ -58,16 +69,16 @@ _two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
def escape_math(text): def escape_math(text):
#return _two_bang_pattern_.sub( # return _two_bang_pattern_.sub(
# r'$$\1$$', # r'$$\1$$',
# _single_dollar_pattern_.sub(r':math:\n`\1`', # _single_dollar_pattern_.sub(r':math:\n`\1`',
# _two_dollar_pattern_.sub(r"!!\1!!", text))) # _two_dollar_pattern_.sub(r"!!\1!!", text)))
return _two_dollar_pattern_.sub(r':math:`\1`', text) return _two_dollar_pattern_.sub(r':math:`\1`', text)
def _generate_doc_string_(op_proto, def _generate_doc_string_(
additional_args_lines=None, op_proto, additional_args_lines=None, skip_attrs_set=None
skip_attrs_set=None): ):
""" """
Generate docstring by OpProto Generate docstring by OpProto
...@@ -147,23 +158,30 @@ def generate_layer_fn(op_type): ...@@ -147,23 +158,30 @@ def generate_layer_fn(op_type):
""" """
op_proto = OpProtoHolder.instance().get_op_proto(op_type) op_proto = OpProtoHolder.instance().get_op_proto(op_type)
not_intermediate_outputs = \ not_intermediate_outputs = [
[output for output in op_proto.outputs if not output.intermediate] output for output in op_proto.outputs if not output.intermediate
intermediate_outputs = \ ]
[output for output in op_proto.outputs if output.intermediate] intermediate_outputs = [
output for output in op_proto.outputs if output.intermediate
]
if len(not_intermediate_outputs) != 1: if len(not_intermediate_outputs) != 1:
raise ValueError("Only one non intermediate output operator can be", raise ValueError(
"automatically generated. {0}".format(op_type)) "Only one non intermediate output operator can be",
"automatically generated. {0}".format(op_type),
)
if not_intermediate_outputs[0].duplicable: if not_intermediate_outputs[0].duplicable:
raise ValueError( raise ValueError(
"Only non duplicable op can be automatically generated.") "Only non duplicable op can be automatically generated."
)
for output in intermediate_outputs: for output in intermediate_outputs:
if output.duplicable: if output.duplicable:
raise ValueError("The op can be automatically generated only when ", raise ValueError(
"all intermediate ops are not duplicable.") "The op can be automatically generated only when ",
"all intermediate ops are not duplicable.",
)
o_name = not_intermediate_outputs[0].name o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs] intermediate_output_names = [output.name for output in intermediate_outputs]
...@@ -188,14 +206,17 @@ def generate_layer_fn(op_type): ...@@ -188,14 +206,17 @@ def generate_layer_fn(op_type):
for each in val: for each in val:
if not isinstance(each, Variable): if not isinstance(each, Variable):
raise ValueError( raise ValueError(
"input of {0} must be variable".format(op_type)) "input of {0} must be variable".format(op_type)
)
if dtype is None: if dtype is None:
dtype = each.dtype dtype = each.dtype
elif dtype != each.dtype: elif dtype != each.dtype:
raise ValueError( raise ValueError(
"operator {0} must input same dtype. {1} vs {2}".format( "operator {0} must input same dtype. {1} vs {2}".format(
op_type, dtype, each.dtype)) op_type, dtype, each.dtype
)
)
if dtype is None: if dtype is None:
arg_dtype = kwargs.get("dtype") arg_dtype = kwargs.get("dtype")
...@@ -227,8 +248,11 @@ def generate_layer_fn(op_type): ...@@ -227,8 +248,11 @@ def generate_layer_fn(op_type):
outputs = dict() outputs = dict()
out = kwargs.pop(_convert_(o_name), []) out = kwargs.pop(_convert_(o_name), [])
if out: if out:
out_var = out[0] if (isinstance(out, list) out_var = (
or isinstance(out, tuple)) else out out[0]
if (isinstance(out, list) or isinstance(out, tuple))
else out
)
else: else:
out_var = helper.create_variable_for_type_inference(dtype=dtype) out_var = helper.create_variable_for_type_inference(dtype=dtype)
outputs[o_name] = [out_var] outputs[o_name] = [out_var]
...@@ -236,10 +260,9 @@ def generate_layer_fn(op_type): ...@@ -236,10 +260,9 @@ def generate_layer_fn(op_type):
outputs[name] = [ outputs[name] = [
helper.create_variable_for_type_inference(dtype=dtype) helper.create_variable_for_type_inference(dtype=dtype)
] ]
helper.append_op(type=op_type, helper.append_op(
inputs=inputs, type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs
outputs=outputs, )
attrs=kwargs)
return helper.append_activation(out_var) return helper.append_activation(out_var)
func.__name__ = op_type func.__name__ = op_type
@@ -270,14 +293,26 @@ def generate_activation_fn(op_type):
             return op(x)
 
     if op_type not in ["abs", "exp", "square"]:
-        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
-                                 op_type)
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], op_type
+        )
     else:
         # abs exp square ops support dtype(int32, int64, float16, float32, float64)
-        check_variable_and_dtype(x, 'x', [
-            'int32', 'int64', 'float16', 'float32', 'float64', 'complex64',
-            'complex128'
-        ], op_type)
+        check_variable_and_dtype(
+            x,
+            'x',
+            [
+                'int32',
+                'int64',
+                'float16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+                'uint16',
+            ],
+            op_type,
+        )
 
     helper = LayerHelper(op_type, **locals())
@@ -290,7 +325,8 @@ def generate_activation_fn(op_type):
         op_proto,
         additional_args_lines=[
             "name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`."
-        ])
+        ],
+    )
 
     return func
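'uint16' in the check lists above is how static-graph dtype checks spell bfloat16 (the BF16 proto type as reported by convert_dtype). A hedged static-mode sketch of what the relaxed checks allow, assuming paddle.static.data accepts dtype='bfloat16' on this branch and that the op used is one generated through generate_activation_fn:

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # bf16 Variables report their dtype as 'uint16' in these checks
        x = paddle.static.data(name='x', shape=[4, 16], dtype='bfloat16')
        y = fluid.layers.sqrt(x)  # a generated activation fn; now passes the dtype check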
...@@ -310,24 +346,31 @@ def generate_inplace_fn(inplace_op_type): ...@@ -310,24 +346,31 @@ def generate_inplace_fn(inplace_op_type):
op = getattr(_legacy_C_ops, inplace_op_type) op = getattr(_legacy_C_ops, inplace_op_type)
return op(x) return op(x)
warnings.warn( warnings.warn(
"In static mode, {}() is the same as {}() and does not perform inplace operation." "In static mode, {}() is the same as {}() and does not perform inplace operation.".format(
.format(inplace_op_type, origin_op_type)) inplace_op_type, origin_op_type
)
)
return generate_activation_fn(origin_op_type)(x, name) return generate_activation_fn(origin_op_type)(x, name)
func.__name__ = inplace_op_type func.__name__ = inplace_op_type
func.__doc__ = """ func.__doc__ = """
Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``. Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_fluid_layers_{1}`. Please refer to :ref:`api_fluid_layers_{1}`.
""".format(origin_op_type, origin_op_type) """.format(
origin_op_type, origin_op_type
)
return func return func
def autodoc(comment=""): def autodoc(comment=""):
def __impl__(func): def __impl__(func):
func.__doc__ = _generate_doc_string_( func.__doc__ = (
OpProtoHolder.instance().get_op_proto(func.__name__)) + comment _generate_doc_string_(
OpProtoHolder.instance().get_op_proto(func.__name__)
)
+ comment
)
return func return func
return __impl__ return __impl__
...@@ -372,18 +415,21 @@ def templatedoc(op_type=None): ...@@ -372,18 +415,21 @@ def templatedoc(op_type=None):
for each_input in op_proto.inputs: for each_input in op_proto.inputs:
input_name = _convert_(each_input.name) input_name = _convert_(each_input.name)
args["{0}_comment".format(input_name)] = trim_ending_dot( args["{0}_comment".format(input_name)] = trim_ending_dot(
each_input.comment) each_input.comment
)
args["{0}_type".format(input_name)] = "Variable" args["{0}_type".format(input_name)] = "Variable"
for each_attr in op_proto.attrs: for each_attr in op_proto.attrs:
input_name = _convert_(each_attr.name) input_name = _convert_(each_attr.name)
args["{0}_comment".format(input_name)] = trim_ending_dot( args["{0}_comment".format(input_name)] = trim_ending_dot(
each_attr.comment) each_attr.comment
)
args["{0}_type".format(input_name)] = _type_to_str_(each_attr.type) args["{0}_type".format(input_name)] = _type_to_str_(each_attr.type)
for each_opt in op_proto.outputs: for each_opt in op_proto.outputs:
output_name = _convert_(each_opt.name) output_name = _convert_(each_opt.name)
args["{0}_comment".format(output_name)] = trim_ending_dot( args["{0}_comment".format(output_name)] = trim_ending_dot(
each_opt.comment) each_opt.comment
)
args["{0}_type".format(output_name)] = "Variable" args["{0}_type".format(output_name)] = "Variable"
func.__doc__ = tmpl.substitute(args) func.__doc__ = tmpl.substitute(args)
return func return func
...@@ -393,7 +439,7 @@ def templatedoc(op_type=None): ...@@ -393,7 +439,7 @@ def templatedoc(op_type=None):
def add_sample_code(func, sample_code): def add_sample_code(func, sample_code):
""" """
Append sample code for dynamically generated functions. Append sample code for dynamically generated functions.
Args: Args:
func: The function of the function to be append sample code to. func: The function of the function to be append sample code to.
......
@@ -5106,7 +5106,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
     check_variable_and_dtype(
         input,
         'input',
-        ['float16', 'float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         'reduce_sum',
     )
     helper = LayerHelper('reduce_sum', **locals())
......
...@@ -21,14 +21,28 @@ import warnings ...@@ -21,14 +21,28 @@ import warnings
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from ..initializer import Initializer from ..initializer import Initializer
from ..framework import _current_expected_place, convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode, _get_paddle_place from ..framework import (
_current_expected_place,
convert_np_dtype_to_dtype_,
_non_static_mode,
_varbase_creator,
device_guard,
_in_legacy_dygraph,
in_dygraph_mode,
_get_paddle_place,
)
from ..framework import Variable from ..framework import Variable
from ..initializer import Constant from ..initializer import Constant
from ..core import VarDesc from ..core import VarDesc
from .. import core from .. import core
from .layer_function_generator import templatedoc from .layer_function_generator import templatedoc
from . import utils from . import utils
from ..data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from ..data_feeder import (
check_variable_and_dtype,
check_type,
check_dtype,
convert_dtype,
)
from paddle.utils import deprecated from paddle.utils import deprecated
from .utils import check_shape from .utils import check_shape
...@@ -71,7 +85,7 @@ def create_tensor(dtype, name=None, persistable=False): ...@@ -71,7 +85,7 @@ def create_tensor(dtype, name=None, persistable=False):
Args: Args:
dtype(string|numpy.dtype): the data type of Tensor to be created, the dtype(string|numpy.dtype): the data type of Tensor to be created, the
data type is bool, float16, float32, float64, int8, int16, int32 and int64. data type is bool, float16, float32, float64, int8, int16, int32 and int64.
name(string, optional): The default value is None. Normally there is no need for name(string, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name` user to set this property. For more information, please refer to :ref:`api_guide_Name`
persistable(bool): Set the persistable flag of the create tensor. persistable(bool): Set the persistable flag of the create tensor.
default value is False. default value is False.
...@@ -85,24 +99,32 @@ def create_tensor(dtype, name=None, persistable=False): ...@@ -85,24 +99,32 @@ def create_tensor(dtype, name=None, persistable=False):
import paddle.fluid as fluid import paddle.fluid as fluid
tensor = fluid.layers.create_tensor(dtype='float32') tensor = fluid.layers.create_tensor(dtype='float32')
""" """
check_dtype(dtype, 'dtype', [ check_dtype(
'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int32', dtype,
'int64' 'dtype',
], 'create_tensor') [
'bool',
'float16',
'float32',
'float64',
'int8',
'int32',
'int32',
'int64',
],
'create_tensor',
)
helper = LayerHelper("create_tensor", **locals()) helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(name=helper.name, return helper.create_variable(
dtype=dtype, name=helper.name, dtype=dtype, persistable=persistable
persistable=persistable) )
def create_parameter(shape, def create_parameter(
dtype, shape, dtype, name=None, attr=None, is_bias=False, default_initializer=None
name=None, ):
attr=None,
is_bias=False,
default_initializer=None):
""" """
:api_attr: Static Graph :api_attr: Static Graph
This function creates a parameter. The parameter is a learnable variable, which can have This function creates a parameter. The parameter is a learnable variable, which can have
gradient, and can be optimized. gradient, and can be optimized.
...@@ -134,31 +156,55 @@ def create_parameter(shape, ...@@ -134,31 +156,55 @@ def create_parameter(shape,
""" """
check_type(shape, 'shape', (list, tuple, numpy.ndarray), 'create_parameter') check_type(shape, 'shape', (list, tuple, numpy.ndarray), 'create_parameter')
for item in shape: for item in shape:
check_type(item, 'item of shape', check_type(
(int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32, item,
numpy.int64), 'create_parameter') 'item of shape',
(
check_dtype(dtype, 'dtype', [ int,
'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32', numpy.uint8,
'int64', 'uint8' numpy.int8,
], 'create_parameter') numpy.int16,
numpy.int32,
numpy.int64,
),
'create_parameter',
)
check_dtype(
dtype,
'dtype',
[
'bool',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
],
'create_parameter',
)
check_type(attr, 'attr', (type(None), ParamAttr), 'create_parameter') check_type(attr, 'attr', (type(None), ParamAttr), 'create_parameter')
check_type(default_initializer, 'default_initializer', check_type(
(type(None), Initializer), 'create_parameter') default_initializer,
'default_initializer',
(type(None), Initializer),
'create_parameter',
)
helper = LayerHelper("create_parameter", **locals()) helper = LayerHelper("create_parameter", **locals())
if attr is None: if attr is None:
attr = ParamAttr(name=name) attr = ParamAttr(name=name)
return helper.create_parameter(attr, shape, convert_dtype(dtype), is_bias, return helper.create_parameter(
default_initializer) attr, shape, convert_dtype(dtype), is_bias, default_initializer
)
def create_global_var(shape, def create_global_var(
value, shape, value, dtype, persistable=False, force_cpu=False, name=None
dtype, ):
persistable=False,
force_cpu=False,
name=None):
""" """
This function creates a new tensor variable with value in the global block(block 0). This function creates a new tensor variable with value in the global block(block 0).
...@@ -185,35 +231,53 @@ def create_global_var(shape, ...@@ -185,35 +231,53 @@ def create_global_var(shape,
var = paddle.static.create_global_var(shape=[2,3], value=1.0, dtype='float32', var = paddle.static.create_global_var(shape=[2,3], value=1.0, dtype='float32',
persistable=True, force_cpu=True, name='new_var') persistable=True, force_cpu=True, name='new_var')
""" """
check_type(shape, 'shape', (list, tuple, numpy.ndarray), check_type(
'create_global_var') shape, 'shape', (list, tuple, numpy.ndarray), 'create_global_var'
)
for item in shape: for item in shape:
check_type(item, 'item of shape', check_type(
(int, numpy.uint8, numpy.int8, numpy.int16, numpy.int32, item,
numpy.int64), 'create_global_var') 'item of shape',
(
check_dtype(dtype, 'dtype', [ int,
'bool', numpy.uint8,
'float16', numpy.int8,
'float32', numpy.int16,
'float64', numpy.int32,
'int8', numpy.int64,
'int16', ),
'int32', 'create_global_var',
'int64', )
'uint8',
'uint16', check_dtype(
], 'create_global_var') dtype,
'dtype',
[
'bool',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
'uint16',
],
'create_global_var',
)
helper = LayerHelper("global_var", **locals()) helper = LayerHelper("global_var", **locals())
var = helper.create_global_variable(dtype=dtype, var = helper.create_global_variable(
shape=shape, dtype=dtype,
persistable=persistable, shape=shape,
name=name, persistable=persistable,
stop_gradient=True) name=name,
helper.set_variable_initializer(var, stop_gradient=True,
initializer=Constant(value=float(value), )
force_cpu=force_cpu)) helper.set_variable_initializer(
var, initializer=Constant(value=float(value), force_cpu=force_cpu)
)
return var return var
...@@ -253,25 +317,50 @@ def cast(x, dtype): ...@@ -253,25 +317,50 @@ def cast(x, dtype):
out = _legacy_C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype) out = _legacy_C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return out return out
check_variable_and_dtype(x, 'x', [ check_variable_and_dtype(
'bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64', x,
'uint8', 'uint16' 'x',
], 'cast') [
check_dtype(dtype, 'dtype', [ 'bool',
'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32', 'float16',
'int64', 'uint8', 'uint16' 'float32',
], 'cast') 'float64',
'int16',
'int32',
'int64',
'uint8',
'uint16',
],
'cast',
)
check_dtype(
dtype,
'dtype',
[
'bool',
'float16',
'float32',
'float64',
'int8',
'int16',
'int32',
'int64',
'uint8',
'uint16',
],
'cast',
)
helper = LayerHelper('cast', **locals()) helper = LayerHelper('cast', **locals())
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=x.stop_gradient) dtype=dtype, stop_gradient=x.stop_gradient
helper.append_op(type='cast', )
inputs={'X': [x]}, helper.append_op(
outputs={'Out': [out]}, type='cast',
attrs={ inputs={'X': [x]},
'in_dtype': x.dtype, outputs={'Out': [out]},
'out_dtype': out.dtype attrs={'in_dtype': x.dtype, 'out_dtype': out.dtype},
}) )
return out return out
...@@ -281,7 +370,7 @@ def concat(input, axis=0, name=None): ...@@ -281,7 +370,7 @@ def concat(input, axis=0, name=None):
Args: Args:
input(list|tuple|Tensor): ``input`` can be Tensor, Tensor list or Tensor tuple which is with data type input(list|tuple|Tensor): ``input`` can be Tensor, Tensor list or Tensor tuple which is with data type
bool, float16, float32, float64, int32, int64. All the Tensors in ``input`` must have the same data type. bool, float16, float32, float64, int32, int64. All the Tensors in ``input`` must have the same data type.
axis(int|Tensor, optional): Specify the axis to operate on the input Tensors. axis(int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64. It's a scalar with data type int or a Tensor with shape [1] and data type int32 or int64.
The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way The effective range is [-R, R), where R is Rank(x). When ``axis < 0``, it works the same way
...@@ -346,9 +435,11 @@ def concat(input, axis=0, name=None): ...@@ -346,9 +435,11 @@ def concat(input, axis=0, name=None):
if not isinstance(input, Variable): if not isinstance(input, Variable):
for id, x in enumerate(input): for id, x in enumerate(input):
check_variable_and_dtype( check_variable_and_dtype(
x, 'input[' + str(id) + ']', x,
'input[' + str(id) + ']',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'concat') 'concat',
)
if x.dtype != input[0].dtype: if x.dtype != input[0].dtype:
raise TypeError( raise TypeError(
"All the Tensors in the input must have the same data type." "All the Tensors in the input must have the same data type."
...@@ -359,8 +450,11 @@ def concat(input, axis=0, name=None): ...@@ -359,8 +450,11 @@ def concat(input, axis=0, name=None):
if isinstance(axis, Variable): if isinstance(axis, Variable):
check_dtype( check_dtype(
axis.dtype, 'axis', ['int32', 'int64'], 'concat', axis.dtype,
"The data type of axis must be int32 or int64 when axis is a Tensor" 'axis',
['int32', 'int64'],
'concat',
"The data type of axis must be int32 or int64 when axis is a Tensor",
) )
helper = LayerHelper('concat', **locals()) helper = LayerHelper('concat', **locals())
...@@ -371,19 +465,17 @@ def concat(input, axis=0, name=None): ...@@ -371,19 +465,17 @@ def concat(input, axis=0, name=None):
# This feature is supported for Dynamic-to-Static, because after transformed, the type of inputs[0] # This feature is supported for Dynamic-to-Static, because after transformed, the type of inputs[0]
# is LOD_TENSOR_ARRAY in some scenarios. And this feature can be used in static mode. # is LOD_TENSOR_ARRAY in some scenarios. And this feature can be used in static mode.
assert len(input) == 1, "If the elements of 'input' in concat are Variable(LoDTensorArray), " \ assert len(input) == 1, (
"number of the elements must be 1, but received %s." % len(input) "If the elements of 'input' in concat are Variable(LoDTensorArray), "
"number of the elements must be 1, but received %s." % len(input)
)
out_index = helper.create_variable_for_type_inference(dtype="int32") out_index = helper.create_variable_for_type_inference(dtype="int32")
helper.append_op(type='tensor_array_to_tensor', helper.append_op(
inputs={'X': input[0]}, type='tensor_array_to_tensor',
outputs={ inputs={'X': input[0]},
'Out': [out], outputs={'Out': [out], 'OutIndex': [out_index]},
'OutIndex': [out_index] attrs={'axis': axis, 'use_stack': False},
}, )
attrs={
'axis': axis,
'use_stack': False
})
else: else:
inputs = {'X': input} inputs = {'X': input}
attrs = {} attrs = {}
...@@ -391,10 +483,9 @@ def concat(input, axis=0, name=None): ...@@ -391,10 +483,9 @@ def concat(input, axis=0, name=None):
axis.stop_gradient = True axis.stop_gradient = True
attrs['axis'] = axis attrs['axis'] = axis
helper.append_op(type='concat', helper.append_op(
inputs=inputs, type='concat', inputs=inputs, outputs={'Out': [out]}, attrs=attrs
outputs={'Out': [out]}, )
attrs=attrs)
return out return out
...@@ -480,33 +571,36 @@ def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False): ...@@ -480,33 +571,36 @@ def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False):
""" """
if _non_static_mode(): if _non_static_mode():
assert isinstance( assert isinstance(
input, list), "The 'input' in tensor_array_to_tensor must be list" input, list
), "The 'input' in tensor_array_to_tensor must be list"
from .nn import stack, concat from .nn import stack, concat
from ..dygraph import to_variable from ..dygraph import to_variable
op = stack if use_stack else concat op = stack if use_stack else concat
res = op(input, axis=axis) res = op(input, axis=axis)
sizes = to_variable( sizes = to_variable(
numpy.array(list(map(lambda x: int(x.shape[axis]), input)))) numpy.array(list(map(lambda x: int(x.shape[axis]), input)))
)
return res, sizes return res, sizes
check_type(input, 'input', (list, Variable), 'tensor_array_to_tensor') check_type(input, 'input', (list, Variable), 'tensor_array_to_tensor')
if isinstance(input, list): if isinstance(input, list):
for i, input_x in enumerate(input): for i, input_x in enumerate(input):
check_type(input_x, 'input[' + str(i) + ']', Variable, check_type(
'tensor_array_to_tensor') input_x,
'input[' + str(i) + ']',
Variable,
'tensor_array_to_tensor',
)
helper = LayerHelper('tensor_array_to_tensor', **locals()) helper = LayerHelper('tensor_array_to_tensor', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
out_index = helper.create_variable_for_type_inference(dtype="int32") out_index = helper.create_variable_for_type_inference(dtype="int32")
helper.append_op(type='tensor_array_to_tensor', helper.append_op(
inputs={'X': input}, type='tensor_array_to_tensor',
outputs={ inputs={'X': input},
'Out': [out], outputs={'Out': [out], 'OutIndex': [out_index]},
'OutIndex': [out_index] attrs={'axis': axis, 'use_stack': use_stack},
}, )
attrs={
'axis': axis,
'use_stack': use_stack
})
return out, out_index return out, out_index
@@ -563,25 +657,36 @@ def sums(input, out=None):
     check_type(input, 'input', (Variable, tuple, list), 'sums')
     if isinstance(input, list) or isinstance(input, tuple):
         for input_section in input:
-            check_variable_and_dtype(input_section, "input", \
-                ['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
+            check_variable_and_dtype(
+                input_section,
+                "input",
+                ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
+                'sums',
+            )
     else:
-        check_variable_and_dtype(input, "input", \
-            ['float16', 'float32', 'float64', 'int32', 'int64'], 'sums')
+        check_variable_and_dtype(
+            input,
+            "input",
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
+            'sums',
+        )
 
     helper = LayerHelper('sum', **locals())
     if out is None:
         out = helper.create_variable_for_type_inference(
-            dtype=helper.input_dtype())
+            dtype=helper.input_dtype()
+        )
     else:
-        check_variable_and_dtype(out, "out",
-                                 ['float32', 'float64', 'int32', 'int64'],
-                                 'sums')
-    helper.append_op(type='sum',
-                     inputs={'X': input},
-                     outputs={'Out': out},
-                     attrs={'use_mkldnn': False})
+        check_variable_and_dtype(
+            out, "out", ['float32', 'float64', 'int32', 'int64'], 'sums'
+        )
+    helper.append_op(
+        type='sum',
+        inputs={'X': input},
+        outputs={'Out': out},
+        attrs={'use_mkldnn': False},
+    )
     return out
...@@ -616,9 +721,12 @@ def assign(input, output=None): ...@@ -616,9 +721,12 @@ def assign(input, output=None):
result3 = paddle.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]] result3 = paddle.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
""" """
helper = LayerHelper('assign', **locals()) helper = LayerHelper('assign', **locals())
check_type(input, 'input', check_type(
(Variable, numpy.ndarray, list, tuple, float, int, bool), input,
'assign') 'input',
(Variable, numpy.ndarray, list, tuple, float, int, bool),
'assign',
)
is_inplace = True if output is not None else False is_inplace = True if output is not None else False
if numpy.isscalar(input) and not isinstance(input, str): if numpy.isscalar(input) and not isinstance(input, str):
...@@ -644,16 +752,29 @@ def assign(input, output=None): ...@@ -644,16 +752,29 @@ def assign(input, output=None):
output = core.eager.Tensor() output = core.eager.Tensor()
_legacy_C_ops.assign(input, output) _legacy_C_ops.assign(input, output)
else: else:
check_dtype(input.dtype, 'input', [ check_dtype(
'float16', 'uint16', 'float32', 'float64', 'int32', 'int64', input.dtype,
'uint8', 'bool' 'input',
], 'assign', '(When the type of input in assign is Variable.)') [
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'uint8',
'bool',
],
'assign',
'(When the type of input in assign is Variable.)',
)
if output is None: if output is None:
output = helper.create_variable_for_type_inference( output = helper.create_variable_for_type_inference(
dtype=input.dtype) dtype=input.dtype
helper.append_op(type='assign', )
inputs={'X': [input]}, helper.append_op(
outputs={'Out': [output]}) type='assign', inputs={'X': [input]}, outputs={'Out': [output]}
)
elif isinstance(input, numpy.ndarray): elif isinstance(input, numpy.ndarray):
# Not support [var, var, ...] currently. # Not support [var, var, ...] currently.
if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input): if len(input.shape) > 0 and any(isinstance(x, Variable) for x in input):
...@@ -667,7 +788,8 @@ def assign(input, output=None): ...@@ -667,7 +788,8 @@ def assign(input, output=None):
warnings.warn( warnings.warn(
"paddle.assign doesn't support float64 input now due " "paddle.assign doesn't support float64 input now due "
"to current platform protobuf data limitation, we convert " "to current platform protobuf data limitation, we convert "
"it to float32") "it to float32"
)
dtype = VarDesc.VarType.FP32 dtype = VarDesc.VarType.FP32
if dtype == VarDesc.VarType.BOOL: if dtype == VarDesc.VarType.BOOL:
value_name = "bool_values" value_name = "bool_values"
...@@ -685,31 +807,49 @@ def assign(input, output=None): ...@@ -685,31 +807,49 @@ def assign(input, output=None):
raise TypeError( raise TypeError(
"When the type of 'input' in assign is numpy.ndarray, " "When the type of 'input' in assign is numpy.ndarray, "
"the data type of 'input' must be bool, float32, int32 or int64, but " "the data type of 'input' must be bool, float32, int32 or int64, but "
"received %s." % convert_dtype(dtype)) "received %s." % convert_dtype(dtype)
)
if input.size > 1024 * 1024: if input.size > 1024 * 1024:
raise ValueError("The size of input is too big. Please consider " raise ValueError(
"saving it to file and 'load_op' to load it") "The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it"
)
if in_dygraph_mode(): if in_dygraph_mode():
if output is None: if output is None:
output = zeros(list(input.shape), dtype) output = zeros(list(input.shape), dtype)
_C_ops.assign_value_(output, list(input.shape), dtype, values, _C_ops.assign_value_(
_current_expected_place()) output,
list(input.shape),
dtype,
values,
_current_expected_place(),
)
elif _in_legacy_dygraph(): elif _in_legacy_dygraph():
if output is None: if output is None:
output = core.VarBase() output = core.VarBase()
_legacy_C_ops.assign_value(output, 'shape', list(input.shape), _legacy_C_ops.assign_value(
'dtype', dtype, value_name, values) output,
'shape',
list(input.shape),
'dtype',
dtype,
value_name,
values,
)
else: else:
if output is None: if output is None:
output = helper.create_variable_for_type_inference( output = helper.create_variable_for_type_inference(
dtype=input.dtype) dtype=input.dtype
helper.append_op(type='assign_value', )
outputs={'Out': [output]}, helper.append_op(
attrs={ type='assign_value',
'dtype': dtype, outputs={'Out': [output]},
'shape': list(input.shape), attrs={
value_name: values 'dtype': dtype,
}) 'shape': list(input.shape),
value_name: values,
},
)
if is_inplace and _non_static_mode(): if is_inplace and _non_static_mode():
output._bump_inplace_version() output._bump_inplace_version()
...@@ -731,10 +871,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -731,10 +871,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64. If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64.
dtype(np.dtype|str): Data type of the output Tensor which can dtype(np.dtype|str): Data type of the output Tensor which can
be float16, float32, float64, uint8, int16, int32, int64. be float16, float32, float64, uint8, int16, int32, int64.
value(bool|float|int|Tensor): The constant value used to initialize value(bool|float|int|Tensor): The constant value used to initialize
the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor. the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor.
force_cpu(bool, optional): data should be on CPU if it's true, default value is False. force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
out(Tensor, optional): Optional output which can be any created out(Tensor, optional): Optional output which can be any created
Tensor that meets the requirements to store the result of operation. Tensor that meets the requirements to store the result of operation.
if ``out`` is None, a new Tensor will be create to store the result. if ``out`` is None, a new Tensor will be create to store the result.
name(str, optional): The default value is None. Normally there is no need for user to set this name(str, optional): The default value is None. Normally there is no need for user to set this
...@@ -759,7 +899,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -759,7 +899,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
# attr shape is a Tensor. # attr shape is a Tensor.
shape = fluid.layers.fill_constant([2], "int32", 2) # shape=[2,2] shape = fluid.layers.fill_constant([2], "int32", 2) # shape=[2,2]
data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]] data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]]
# attr value is a Tensor. # attr value is a Tensor.
val = fluid.layers.fill_constant([1], "float32", 2.0) # val=[2.0] val = fluid.layers.fill_constant([1], "float32", 2.0) # val=[2.0]
data5 = fluid.layers.fill_constant(shape=[2,1], value=val, dtype='float32') #data5=[[2.0],[2.0]] data5 = fluid.layers.fill_constant(shape=[2,1], value=val, dtype='float32') #data5=[[2.0],[2.0]]
...@@ -785,7 +925,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -785,7 +925,11 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
shape = list( shape = list(
map( map(
lambda x: x.numpy().flat[0] lambda x: x.numpy().flat[0]
if isinstance(x, Variable) else x, shape)) if isinstance(x, Variable)
else x,
shape,
)
)
break break
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
...@@ -813,9 +957,19 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -813,9 +957,19 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
else: else:
attrs['str_value'] = str(float(value.numpy().item(0))) attrs['str_value'] = str(float(value.numpy().item(0)))
_legacy_C_ops.fill_constant(out, 'value', float(value), 'force_cpu', _legacy_C_ops.fill_constant(
force_cpu, 'dtype', out.dtype, 'str_value', out,
attrs['str_value'], 'shape', shape) 'value',
float(value),
'force_cpu',
force_cpu,
'dtype',
out.dtype,
'str_value',
attrs['str_value'],
'shape',
shape,
)
out.stop_gradient = True out.stop_gradient = True
return out return out
@@ -827,43 +981,61 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
             inputs['ValueTensor'] = value
 
     check_shape(shape)
-    check_dtype(dtype, 'dtype', [
-        'bool', 'float16', 'float32', 'float64', 'uint8', 'int16', 'int32',
-        'int64', 'complex64', 'complex128'
-    ], 'fill_constant')
+    check_dtype(
+        dtype,
+        'dtype',
+        [
+            'bool',
+            'float16',
+            'float32',
+            'float64',
+            'uint8',
+            'int16',
+            'int32',
+            'int64',
+            'complex64',
+            'complex128',
+            'uint16',
+        ],
+        'fill_constant',
+    )
     check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
 
     if out is not None:
-        check_variable_and_dtype(out, 'out', [convert_dtype(dtype)],
-                                 'fill_constant')
+        check_variable_and_dtype(
+            out, 'out', [convert_dtype(dtype)], 'fill_constant'
+        )
 
     helper = LayerHelper("fill_constant", **locals())
-    utils.get_shape_tensor_inputs(inputs=inputs,
-                                  attrs=attrs,
-                                  shape=shape,
-                                  op_type='fill_constant')
+    utils.get_shape_tensor_inputs(
+        inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant'
+    )
 
     if out is None:
         out = helper.create_variable_for_type_inference(dtype=dtype)
     attrs['dtype'] = out.dtype
-    helper.append_op(type='fill_constant',
-                     inputs=inputs,
-                     outputs={'Out': [out]},
-                     attrs=attrs,
-                     stop_gradient=True)
+    helper.append_op(
+        type='fill_constant',
+        inputs=inputs,
+        outputs={'Out': [out]},
+        attrs=attrs,
+        stop_gradient=True,
+    )
 
     out.stop_gradient = True
     return out
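With 'uint16' accepted by fill_constant's dtype check, bf16 constants can be created directly inside a static program. A minimal sketch (assuming dtype='bfloat16' is converted to the BF16 proto type, which convert_dtype reports as 'uint16', and that paddle.full routes to fill_constant in static mode):

    import paddle

    paddle.enable_static()
    with paddle.static.program_guard(paddle.static.Program()):
        ones_bf16 = paddle.full(shape=[2, 3], fill_value=1.0, dtype='bfloat16')
        # or, through the fluid layer touched above:
        # ones_bf16 = paddle.fluid.layers.fill_constant([2, 3], 'bfloat16', 1.0)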
@deprecated(since='1.8.0', update_to="paddle.fluid.layers.fill_constant") @deprecated(since='1.8.0', update_to="paddle.fluid.layers.fill_constant")
@templatedoc() @templatedoc()
def fill_constant_batch_size_like(input, def fill_constant_batch_size_like(
shape, input,
dtype, shape,
value, dtype,
input_dim_idx=0, value,
output_dim_idx=0, input_dim_idx=0,
force_cpu=False): output_dim_idx=0,
force_cpu=False,
):
""" """
This OP creates a Tesnor according the shape and dtype, and initializes the This OP creates a Tesnor according the shape and dtype, and initializes the
Tensor with the constants provided in ``value``. When the input is LoDTensor Tensor with the constants provided in ``value``. When the input is LoDTensor
...@@ -877,7 +1049,7 @@ def fill_constant_batch_size_like(input, ...@@ -877,7 +1049,7 @@ def fill_constant_batch_size_like(input,
according the input. according the input.
dtype(np.dtype|core.VarDesc.VarType|str): The data type of created Tensor which dtype(np.dtype|core.VarDesc.VarType|str): The data type of created Tensor which
can be float32, float64, int32, int64. can be float32, float64, int32, int64.
value(float|int): The constant value used to initialize the Tensor to be created. value(float|int): The constant value used to initialize the Tensor to be created.
input_dim_idx(int): When the value is 0 and the input is LoDTensor, the output_dim_idx input_dim_idx(int): When the value is 0 and the input is LoDTensor, the output_dim_idx
dimension of the created Tensor is set to the batch_size value of input. dimension of the created Tensor is set to the batch_size value of input.
The default value is 0. The default value is 0.
...@@ -905,8 +1077,9 @@ def fill_constant_batch_size_like(input, ...@@ -905,8 +1077,9 @@ def fill_constant_batch_size_like(input,
place = _current_expected_place() place = _current_expected_place()
if force_cpu: if force_cpu:
place = core.CPUPlace() place = core.CPUPlace()
out = _C_ops.full_batch_size_like(input, shape, dtype, value, out = _C_ops.full_batch_size_like(
input_dim_idx, output_dim_idx, place) input, shape, dtype, value, input_dim_idx, output_dim_idx, place
)
out.stop_gradient = True out.stop_gradient = True
return out return out
...@@ -918,25 +1091,27 @@ def fill_constant_batch_size_like(input, ...@@ -918,25 +1091,27 @@ def fill_constant_batch_size_like(input,
'value': float(value), 'value': float(value),
'input_dim_idx': input_dim_idx, 'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx, 'output_dim_idx': output_dim_idx,
'force_cpu': force_cpu 'force_cpu': force_cpu,
} }
if convert_dtype(dtype) in ['int64', 'int32']: if convert_dtype(dtype) in ['int64', 'int32']:
attrs['str_value'] = str(int(value)) attrs['str_value'] = str(int(value))
else: else:
attrs['str_value'] = str(float(value)) attrs['str_value'] = str(float(value))
helper.append_op(type='fill_constant_batch_size_like', helper.append_op(
inputs={'Input': input}, type='fill_constant_batch_size_like',
outputs={'Out': [out]}, inputs={'Input': input},
attrs=attrs) outputs={'Out': [out]},
attrs=attrs,
)
out.stop_gradient = True out.stop_gradient = True
return out return out
def argmin(x, axis=0): def argmin(x, axis=0):
""" """
:alias_main: paddle.argmin :alias_main: paddle.argmin
:alias: paddle.argmin,paddle.tensor.argmin,paddle.tensor.search.argmin :alias: paddle.argmin,paddle.tensor.argmin,paddle.tensor.search.argmin
:old_api: paddle.fluid.layers.argmin :old_api: paddle.fluid.layers.argmin
**argmin** **argmin**
...@@ -986,14 +1161,19 @@ def argmin(x, axis=0): ...@@ -986,14 +1161,19 @@ def argmin(x, axis=0):
# [1 0 2]] # [1 0 2]]
""" """
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'], x,
'argmin') 'x',
['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'],
'argmin',
)
helper = LayerHelper("arg_min", **locals()) helper = LayerHelper("arg_min", **locals())
out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64) out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op(type='arg_min', helper.append_op(
inputs={'X': x}, type='arg_min',
outputs={'Out': [out]}, inputs={'X': x},
attrs={'axis': axis}) outputs={'Out': [out]},
attrs={'axis': axis},
)
out.stop_gradient = True out.stop_gradient = True
return out return out
...@@ -1048,23 +1228,28 @@ def argmax(x, axis=0): ...@@ -1048,23 +1228,28 @@ def argmax(x, axis=0):
# [0 3 1]] # [0 3 1]]
""" """
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'], x,
'argmax') 'x',
['float32', 'float64', 'uint8', 'int16', 'int32', 'int64'],
'argmax',
)
helper = LayerHelper("arg_max", **locals()) helper = LayerHelper("arg_max", **locals())
out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64) out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op(type='arg_max', helper.append_op(
inputs={'X': x}, type='arg_max',
outputs={'Out': [out]}, inputs={'X': x},
attrs={'axis': axis}) outputs={'Out': [out]},
attrs={'axis': axis},
)
out.stop_gradient = True out.stop_gradient = True
return out return out
def argsort(input, axis=-1, descending=False, name=None): def argsort(input, axis=-1, descending=False, name=None):
""" """
:alias_main: paddle.argsort :alias_main: paddle.argsort
:alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort :alias: paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort
:old_api: paddle.fluid.layers.argsort :old_api: paddle.fluid.layers.argsort
This OP sorts the input along the given axis, and returns sorted output This OP sorts the input along the given axis, and returns sorted output
data Varibale and its corresponding index Variable with the same shape as data Varibale and its corresponding index Variable with the same shape as
...@@ -1135,23 +1320,24 @@ def argsort(input, axis=-1, descending=False, name=None): ...@@ -1135,23 +1320,24 @@ def argsort(input, axis=-1, descending=False, name=None):
# [5. 7. 7. 9.]]] # [5. 7. 7. 9.]]]
""" """
check_variable_and_dtype( check_variable_and_dtype(
input, 'input', input,
['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], 'argsort') 'input',
['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'],
'argsort',
)
helper = LayerHelper("argsort", **locals()) helper = LayerHelper("argsort", **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype, out = helper.create_variable_for_type_inference(
stop_gradient=True) dtype=input.dtype, stop_gradient=True
ids = helper.create_variable_for_type_inference(VarDesc.VarType.INT64, )
stop_gradient=True) ids = helper.create_variable_for_type_inference(
helper.append_op(type='argsort', VarDesc.VarType.INT64, stop_gradient=True
inputs={'X': input}, )
outputs={ helper.append_op(
'Out': out, type='argsort',
'Indices': ids inputs={'X': input},
}, outputs={'Out': out, 'Indices': ids},
attrs={ attrs={'axis': axis, 'descending': descending},
'axis': axis, )
'descending': descending
})
return out, ids return out, ids
...@@ -1176,7 +1362,7 @@ def ones(shape, dtype, force_cpu=False): ...@@ -1176,7 +1362,7 @@ def ones(shape, dtype, force_cpu=False):
import paddle.fluid as fluid import paddle.fluid as fluid
data0 = fluid.layers.ones(shape=[2, 4], dtype='float32') # [[1., 1., 1., 1.], [1., 1., 1., 1.]] data0 = fluid.layers.ones(shape=[2, 4], dtype='float32') # [[1., 1., 1., 1.], [1., 1., 1., 1.]]
# shape is a Tensor # shape is a Tensor
shape = fluid.layers.fill_constant(shape=[2], dtype='int32', value=2) shape = fluid.layers.fill_constant(shape=[2], dtype='int32', value=2)
data1 = fluid.layers.ones(shape=shape, dtype='int32') #[[1, 1], [1, 1]] data1 = fluid.layers.ones(shape=shape, dtype='int32') #[[1, 1], [1, 1]]
...@@ -1207,7 +1393,7 @@ def zeros(shape, dtype, force_cpu=False, name=None): ...@@ -1207,7 +1393,7 @@ def zeros(shape, dtype, force_cpu=False, name=None):
import paddle.fluid as fluid import paddle.fluid as fluid
data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]] data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]]
# shape is a Tensor # shape is a Tensor
shape = fluid.layers.fill_constant(shape=[2], dtype='int32', value=2) shape = fluid.layers.fill_constant(shape=[2], dtype='int32', value=2)
data1 = fluid.layers.zeros(shape=shape, dtype='int32') #[[0, 0], [0, 0]] data1 = fluid.layers.zeros(shape=shape, dtype='int32') #[[0, 0], [0, 0]]
...@@ -1217,9 +1403,9 @@ def zeros(shape, dtype, force_cpu=False, name=None): ...@@ -1217,9 +1403,9 @@ def zeros(shape, dtype, force_cpu=False, name=None):
def reverse(x, axis): def reverse(x, axis):
""" """
:alias_main: paddle.reverse :alias_main: paddle.reverse
:alias: paddle.reverse,paddle.tensor.reverse,paddle.tensor.manipulation.reverse :alias: paddle.reverse,paddle.tensor.reverse,paddle.tensor.manipulation.reverse
:old_api: paddle.fluid.layers.reverse :old_api: paddle.fluid.layers.reverse
The OP reverses the tensor :attr:`x` along the given :attr:`axis`. The OP reverses the tensor :attr:`x` along the given :attr:`axis`.
...@@ -1277,9 +1463,9 @@ def reverse(x, axis): ...@@ -1277,9 +1463,9 @@ def reverse(x, axis):
reversed_tensor_array = fluid.layers.reverse(tensor_array, 0) # {[[3, 4, 5]], [[0, 1, 2]]} reversed_tensor_array = fluid.layers.reverse(tensor_array, 0) # {[[3, 4, 5]], [[0, 1, 2]]}
""" """
check_variable_and_dtype(x, 'x', check_variable_and_dtype(
('float32', 'float64', 'int32', 'int64', 'uint8'), x, 'x', ('float32', 'float64', 'int32', 'int64', 'uint8'), 'reverse'
'reverse') )
check_type(axis, 'axis', (int, tuple, list, Variable), 'reverse') check_type(axis, 'axis', (int, tuple, list, Variable), 'reverse')
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis] axis = [axis]
...@@ -1287,10 +1473,12 @@ def reverse(x, axis): ...@@ -1287,10 +1473,12 @@ def reverse(x, axis):
return _C_ops.reverse(x, axis) return _C_ops.reverse(x, axis)
helper = LayerHelper("reverse", **locals()) helper = LayerHelper("reverse", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='reverse', helper.append_op(
inputs={'X': x}, type='reverse',
outputs={'Out': [out]}, inputs={'X': x},
attrs={'axis': axis}) outputs={'Out': [out]},
attrs={'axis': axis},
)
return out return out
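A matching sketch for reverse, under the same legacy static-graph assumptions:

.. code-block:: python

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    data = fluid.layers.assign(np.array([[0, 1, 2], [3, 4, 5]], dtype='float32'))
    rev = fluid.layers.reverse(data, axis=0)   # flip along axis 0

    exe = fluid.Executor(fluid.CPUPlace())
    print(exe.run(fetch_list=[rev])[0])        # -> [[3. 4. 5.], [0. 1. 2.]]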
...@@ -1306,13 +1494,12 @@ def save(x, file_path, overwrite=True): ...@@ -1306,13 +1494,12 @@ def save(x, file_path, overwrite=True):
error will be thrown. error will be thrown.
""" """
helper = LayerHelper("save", **locals()) helper = LayerHelper("save", **locals())
helper.append_op(type="save", helper.append_op(
inputs={"input": x}, type="save",
outputs={}, inputs={"input": x},
args={ outputs={},
"file_path": file_path, args={"file_path": file_path, "overwrite": overwrite},
"overwrite": overwrite )
})
def save_combine(x, file_path, overwrite=True): def save_combine(x, file_path, overwrite=True):
...@@ -1344,13 +1531,12 @@ def save_combine(x, file_path, overwrite=True): ...@@ -1344,13 +1531,12 @@ def save_combine(x, file_path, overwrite=True):
normed = fluid.layers.save_combine([v1, v2], file_path="output") normed = fluid.layers.save_combine([v1, v2], file_path="output")
""" """
helper = LayerHelper("save_combine", **locals()) helper = LayerHelper("save_combine", **locals())
helper.append_op(type="save_combine", helper.append_op(
inputs={"input": x}, type="save_combine",
outputs={}, inputs={"input": x},
args={ outputs={},
"file_path": file_path, args={"file_path": file_path, "overwrite": overwrite},
"overwrite": overwrite )
})
def load_combine(out, file_path): def load_combine(out, file_path):
...@@ -1362,10 +1548,12 @@ def load_combine(out, file_path): ...@@ -1362,10 +1548,12 @@ def load_combine(out, file_path):
file_path(str): The path of the disk file. file_path(str): The path of the disk file.
""" """
helper = LayerHelper("load_combine", **locals()) helper = LayerHelper("load_combine", **locals())
helper.append_op(type="load_combine", helper.append_op(
inputs={}, type="load_combine",
output={"Out": out}, inputs={},
args={"file_path": file_path}) output={"Out": out},
args={"file_path": file_path},
)
def has_inf(x): def has_inf(x):
...@@ -1377,10 +1565,10 @@ def has_inf(x): ...@@ -1377,10 +1565,10 @@ def has_inf(x):
Returns: Returns:
Tensor: The tensor storing the output, only a bool value, indicating whether there is any infinity value in x or not. Tensor: The tensor storing the output, only a bool value, indicating whether there is any infinity value in x or not.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
data = paddle.randn(shape=[4, 32, 32], dtype="float32") data = paddle.randn(shape=[4, 32, 32], dtype="float32")
res = paddle.fluid.layers.has_inf(data) res = paddle.fluid.layers.has_inf(data)
...@@ -1406,10 +1594,10 @@ def has_nan(x): ...@@ -1406,10 +1594,10 @@ def has_nan(x):
Returns: Returns:
Tensor: The tensor variable storing the output, only a bool value, indicating whether there is any NAN value in x or not. Tensor: The tensor variable storing the output, only a bool value, indicating whether there is any NAN value in x or not.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
data = paddle.randn(shape=[2,3], dtype="float32") data = paddle.randn(shape=[2,3], dtype="float32")
res = paddle.fluid.layers.has_nan(data) res = paddle.fluid.layers.has_nan(data)
...@@ -1449,8 +1637,9 @@ def isfinite(x): ...@@ -1449,8 +1637,9 @@ def isfinite(x):
print(y) print(y)
""" """
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"isfinite") x, "x", ["float32", "float64", "int32", "int64"], "isfinite"
)
helper = LayerHelper("isfinite", **locals()) helper = LayerHelper("isfinite", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
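A short sketch of isfinite, which reduces the whole tensor to a single boolean (assuming the legacy API is still exposed):

.. code-block:: python

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    x = fluid.data(name='x', shape=[3], dtype='float32')
    y = fluid.layers.isfinite(x)   # single bool: True only if every element is finite

    exe = fluid.Executor(fluid.CPUPlace())
    res = exe.run(feed={'x': np.array([1.0, float('inf'), 2.0], dtype='float32')},
                  fetch_list=[y])
    # res[0] -> [False] because of the inf entry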
...@@ -1485,7 +1674,7 @@ def range(start, end, step, dtype, name=None): ...@@ -1485,7 +1674,7 @@ def range(start, end, step, dtype, name=None):
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
Returns: Returns:
Tensor: A 1-D Tensor with values from the interval [``start``, ``end``) Tensor: A 1-D Tensor with values from the interval [``start``, ``end``)
taken with common difference ``step`` beginning from ``start``. Its taken with common difference ``step`` beginning from ``start``. Its
data type is set by ``dtype``. data type is set by ``dtype``.
...@@ -1508,8 +1697,11 @@ def range(start, end, step, dtype, name=None): ...@@ -1508,8 +1697,11 @@ def range(start, end, step, dtype, name=None):
""" """
out_shape = None out_shape = None
if not isinstance(start, Variable) and not isinstance( if (
end, Variable) and not isinstance(step, Variable): not isinstance(start, Variable)
and not isinstance(end, Variable)
and not isinstance(step, Variable)
):
out_shape = [int(math.ceil((end - start) / step))] out_shape = [int(math.ceil((end - start) / step))]
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
...@@ -1541,17 +1733,16 @@ def range(start, end, step, dtype, name=None): ...@@ -1541,17 +1733,16 @@ def range(start, end, step, dtype, name=None):
out.stop_gradient = True out.stop_gradient = True
return out return out
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], check_dtype(
'range/arange') dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'range/arange'
)
helper = LayerHelper('range', **locals()) helper = LayerHelper('range', **locals())
out = helper.create_variable_for_type_inference(dtype, shape=out_shape) out = helper.create_variable_for_type_inference(dtype, shape=out_shape)
helper.append_op(type='range', helper.append_op(
inputs={ type='range',
'Start': start, inputs={'Start': start, 'End': end, 'Step': step},
'End': end, outputs={'Out': out},
'Step': step )
},
outputs={'Out': out})
out.stop_gradient = True out.stop_gradient = True
if out_shape is not None: if out_shape is not None:
out.desc.set_shape(out_shape) out.desc.set_shape(out_shape)
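The static shape computed above is simply ceil((end - start) / step); a quick sketch, assuming fluid.layers.range is still exported:

.. code-block:: python

    import math
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    seq = fluid.layers.range(0, 10, 2, 'int32')   # ceil((10 - 0) / 2) = 5 elements

    exe = fluid.Executor(fluid.CPUPlace())
    print(exe.run(fetch_list=[seq])[0])           # -> [0 2 4 6 8]
    assert math.ceil((10 - 0) / 2) == 5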
...@@ -1606,58 +1797,76 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -1606,58 +1797,76 @@ def linspace(start, stop, num, dtype=None, name=None):
with device_guard("cpu"): with device_guard("cpu"):
tensor_num = fill_constant([1], 'int32', num) tensor_num = fill_constant([1], 'int32', num)
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.linspace(tensor_start, tensor_stop, tensor_num, dtype, return _C_ops.linspace(
_current_expected_place()) tensor_start,
tensor_stop,
tensor_num,
dtype,
_current_expected_place(),
)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.linspace(tensor_start, tensor_stop, tensor_num, return _legacy_C_ops.linspace(
'dtype', dtype) tensor_start, tensor_stop, tensor_num, 'dtype', dtype
)
helper = LayerHelper("linspace", **locals()) helper = LayerHelper("linspace", **locals())
start_dtype = convert_dtype(tensor_start.dtype) start_dtype = convert_dtype(tensor_start.dtype)
stop_dtype = convert_dtype(tensor_stop.dtype) stop_dtype = convert_dtype(tensor_stop.dtype)
out_dtype = convert_dtype(dtype) out_dtype = convert_dtype(dtype)
if isinstance(start, Variable): if isinstance(start, Variable):
check_dtype(start.dtype, 'start', check_dtype(
['float32', 'float64', 'int32', 'int64'], 'linspace') start.dtype,
'start',
['float32', 'float64', 'int32', 'int64'],
'linspace',
)
else: else:
check_type(start, 'start', (int, float), 'linspace') check_type(start, 'start', (int, float), 'linspace')
if isinstance(stop, Variable): if isinstance(stop, Variable):
check_dtype(stop.dtype, 'stop', check_dtype(
['float32', 'float64', 'int32', 'int64'], 'linspace') stop.dtype,
'stop',
['float32', 'float64', 'int32', 'int64'],
'linspace',
)
else: else:
check_type(stop, 'stop', (int, float), 'linspace') check_type(stop, 'stop', (int, float), 'linspace')
if isinstance(num, Variable): if isinstance(num, Variable):
check_dtype(num.dtype, 'num', ['int32'], 'linspace') check_dtype(num.dtype, 'num', ['int32'], 'linspace')
check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], check_dtype(
'linspace') dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], 'linspace'
if ((stop_dtype == "float64" or start_dtype == "float64") )
and out_dtype in ["float32", "int32"]) or ( if (
(stop_dtype == "int64" or start_dtype == "int64") (stop_dtype == "float64" or start_dtype == "float64")
and out_dtype == "int32"): and out_dtype in ["float32", "int32"]
) or (
(stop_dtype == "int64" or start_dtype == "int64")
and out_dtype == "int32"
):
raise ValueError( raise ValueError(
"The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, " "The dtype of start/stop is {}/{} but the attr(dtype) of linspace is {}, "
"which may cause data type overflows. Please reset attr(dtype) of linspace." "which may cause data type overflows. Please reset attr(dtype) of linspace.".format(
.format(start_dtype, stop_dtype, dtype)) start_dtype, stop_dtype, dtype
)
)
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(type='linspace', helper.append_op(
inputs={ type='linspace',
'Start': tensor_start, inputs={'Start': tensor_start, 'Stop': tensor_stop, 'Num': tensor_num},
'Stop': tensor_stop, attrs={'dtype': dtype},
'Num': tensor_num outputs={'Out': [out]},
}, )
attrs={'dtype': dtype},
outputs={'Out': [out]})
if isinstance(num, int): if isinstance(num, int):
out.desc.set_shape((num, )) out.desc.set_shape((num,))
return out return out
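And a companion sketch for linspace; note that the guard above rejects, for example, float64 endpoints combined with a float32 or int32 output dtype:

.. code-block:: python

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    pts = fluid.layers.linspace(0.0, 1.0, 5, dtype='float32')   # evenly spaced, endpoints included

    exe = fluid.Executor(fluid.CPUPlace())
    print(exe.run(fetch_list=[pts])[0])   # -> [0.   0.25 0.5  0.75 1.  ]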
def zeros_like(x, out=None): def zeros_like(x, out=None):
""" """
This OP creates a zeros tensor which has identical shape and dtype This OP creates a zeros tensor which has identical shape and dtype
with `x`. with `x`.
Args: Args:
...@@ -1681,23 +1890,25 @@ def zeros_like(x, out=None): ...@@ -1681,23 +1890,25 @@ def zeros_like(x, out=None):
data = fluid.layers.zeros_like(x) # [0.0, 0.0, 0.0] data = fluid.layers.zeros_like(x) # [0.0, 0.0, 0.0]
""" """
check_variable_and_dtype(x, "x", check_variable_and_dtype(
['bool', 'float32', 'float64', 'int32', 'int64'], x, "x", ['bool', 'float32', 'float64', 'int32', 'int64'], 'zeros_like'
'zeros_like') )
helper = LayerHelper("zeros_like", **locals()) helper = LayerHelper("zeros_like", **locals())
if out is None: if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
check_variable_and_dtype( check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'], out,
'zeros_like') "out",
helper.append_op(type='fill_any_like', ['bool', 'float32', 'float64', 'int32', 'int64'],
inputs={'X': [x]}, 'zeros_like',
attrs={ )
'value': 0, helper.append_op(
"dtype": x.dtype type='fill_any_like',
}, inputs={'X': [x]},
outputs={'Out': [out]}) attrs={'value': 0, "dtype": x.dtype},
outputs={'Out': [out]},
)
out.stop_gradient = True out.stop_gradient = True
return out return out
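A minimal sketch of zeros_like in the same legacy style:

.. code-block:: python

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    x = fluid.layers.fill_constant(shape=[3], dtype='float32', value=2.0)
    z = fluid.layers.zeros_like(x)   # same shape and dtype as x, filled with zeros

    exe = fluid.Executor(fluid.CPUPlace())
    print(exe.run(fetch_list=[z])[0])   # -> [0. 0. 0.]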
...@@ -1734,8 +1945,12 @@ def diag(diagonal): ...@@ -1734,8 +1945,12 @@ def diag(diagonal):
""" """
check_type(diagonal, 'diagonal', (Variable, numpy.ndarray), 'diag') check_type(diagonal, 'diagonal', (Variable, numpy.ndarray), 'diag')
check_dtype(diagonal.dtype, 'diagonal', check_dtype(
['float32', 'float64', 'int32', 'int64'], 'diag') diagonal.dtype,
'diagonal',
['float32', 'float64', 'int32', 'int64'],
'diag',
)
helper = LayerHelper("diag", **locals()) helper = LayerHelper("diag", **locals())
if not isinstance(diagonal, Variable): if not isinstance(diagonal, Variable):
...@@ -1743,21 +1958,19 @@ def diag(diagonal): ...@@ -1743,21 +1958,19 @@ def diag(diagonal):
out = helper.create_variable_for_type_inference(dtype=diagonal.dtype) out = helper.create_variable_for_type_inference(dtype=diagonal.dtype)
helper.append_op(type='diag', helper.append_op(
inputs={'Diagonal': [diagonal]}, type='diag', inputs={'Diagonal': [diagonal]}, outputs={'Out': [out]}
outputs={'Out': [out]}) )
out.stop_gradient = True out.stop_gradient = True
return out return out
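diag accepts either a Variable or a plain numpy array for the diagonal; a small sketch:

.. code-block:: python

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    d = fluid.layers.diag(np.array([1, 2, 3], dtype='int32'))   # 3x3 matrix with [1, 2, 3] on the diagonal

    exe = fluid.Executor(fluid.CPUPlace())
    print(exe.run(fetch_list=[d])[0])
    # [[1 0 0]
    #  [0 2 0]
    #  [0 0 3]]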
def eye(num_rows, def eye(
num_columns=None, num_rows, num_columns=None, batch_shape=None, dtype='float32', name=None
batch_shape=None, ):
dtype='float32',
name=None):
""" """
This function constructs a 2-D tensor, or a batch of 2-D tensors, with ones on the diagonal and zeros elsewhere. This function constructs a 2-D tensor, or a batch of 2-D tensors, with ones on the diagonal and zeros elsewhere.
Args: Args:
num_rows(int): the number of rows in each batch tensor. num_rows(int): the number of rows in each batch tensor.
...@@ -1808,25 +2021,33 @@ def eye(num_rows, ...@@ -1808,25 +2021,33 @@ def eye(num_rows,
num_columns = num_rows num_columns = num_rows
if in_dygraph_mode(): if in_dygraph_mode():
out = _C_ops.eye(num_rows, num_columns, dtype, out = _C_ops.eye(
_current_expected_place()) num_rows, num_columns, dtype, _current_expected_place()
)
elif _in_legacy_dygraph(): elif _in_legacy_dygraph():
out = _legacy_C_ops.eye('dtype', dtype, 'num_rows', num_rows, out = _legacy_C_ops.eye(
'num_columns', num_columns) 'dtype', dtype, 'num_rows', num_rows, 'num_columns', num_columns
)
else: else:
helper = LayerHelper("eye", **locals()) helper = LayerHelper("eye", **locals())
check_dtype(dtype, 'dtype', check_dtype(
['float16', 'float32', 'float64', 'int32', 'int64'], 'eye') dtype,
'dtype',
['float16', 'float32', 'float64', 'int32', 'int64'],
'eye',
)
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(type='eye', helper.append_op(
inputs={}, type='eye',
outputs={'Out': [out]}, inputs={},
attrs={ outputs={'Out': [out]},
'num_rows': num_rows, attrs={
'num_columns': num_columns, 'num_rows': num_rows,
'dtype': dtype 'num_columns': num_columns,
}, 'dtype': dtype,
stop_gradient=True) },
stop_gradient=True,
)
if batch_shape is not None: if batch_shape is not None:
re_shape = [1] * len(batch_shape) re_shape = [1] * len(batch_shape)
...@@ -1838,11 +2059,12 @@ def eye(num_rows, ...@@ -1838,11 +2059,12 @@ def eye(num_rows,
if not isinstance(batch_shape, list): if not isinstance(batch_shape, list):
raise TypeError("batch_shape should be a list") raise TypeError("batch_shape should be a list")
for batch_val in (batch_shape): for batch_val in batch_shape:
if batch_val <= 0: if batch_val <= 0:
raise TypeError("batch_shape should be a positive int list") raise TypeError("batch_shape should be a positive int list")
from .nn import reshape, expand from .nn import reshape, expand
out = reshape(x=out, shape=re_shape) out = reshape(x=out, shape=re_shape)
out = expand(x=out, expand_times=expand_times) out = expand(x=out, expand_times=expand_times)
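With batch_shape, the identity produced above is reshaped and tiled across the leading batch dimensions; for example:

.. code-block:: python

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    single = fluid.layers.eye(2, 3, dtype='float32')    # shape [2, 3]
    batched = fluid.layers.eye(2, batch_shape=[4])      # shape [4, 2, 2]: four stacked 2x2 identities

    exe = fluid.Executor(fluid.CPUPlace())
    a, b = exe.run(fetch_list=[single, batched])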
...@@ -1854,7 +2076,7 @@ def ones_like(x, out=None): ...@@ -1854,7 +2076,7 @@ def ones_like(x, out=None):
""" """
**ones_like** **ones_like**
This function creates a ones tensor which has identical shape and dtype This function creates a ones tensor which has identical shape and dtype
with `x`. with `x`.
Args: Args:
...@@ -1873,25 +2095,31 @@ def ones_like(x, out=None): ...@@ -1873,25 +2095,31 @@ def ones_like(x, out=None):
data = fluid.layers.ones_like(x) # [1.0, 1.0, 1.0] data = fluid.layers.ones_like(x) # [1.0, 1.0, 1.0]
""" """
check_variable_and_dtype(x, "x", check_variable_and_dtype(
['bool', 'float32', 'float64', 'int32', 'int64'], x, "x", ['bool', 'float32', 'float64', 'int32', 'int64'], 'ones_like'
'ones_like') )
helper = LayerHelper("ones_like", **locals()) helper = LayerHelper("ones_like", **locals())
if out is None: if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
check_variable_and_dtype( check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'], out,
'ones_like') "out",
helper.append_op(type='fill_any_like', ['bool', 'float32', 'float64', 'int32', 'int64'],
inputs={'X': [x]}, 'ones_like',
attrs={'value': 1.0}, )
outputs={'Out': [out]}) helper.append_op(
type='fill_any_like',
inputs={'X': [x]},
attrs={'value': 1.0},
outputs={'Out': [out]},
)
return out return out
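And the symmetric sketch for ones_like:

.. code-block:: python

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    x = fluid.layers.fill_constant(shape=[2, 2], dtype='float32', value=7.0)
    ones = fluid.layers.ones_like(x)   # [[1., 1.], [1., 1.]], dtype follows x

    exe = fluid.Executor(fluid.CPUPlace())
    print(exe.run(fetch_list=[ones])[0])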
@deprecated(since="2.0.0", update_to="paddle.triu") @deprecated(since="2.0.0", update_to="paddle.triu")
def triu(input, diagonal=0, name=None): def triu(input, diagonal=0, name=None):
import paddle import paddle
return paddle.tensor.triu(x=input, diagonal=diagonal, name=name) return paddle.tensor.triu(x=input, diagonal=diagonal, name=name)
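Since this deprecated wrapper only forwards to paddle.tensor.triu, new code can call the modern API directly; a brief dygraph sketch:

.. code-block:: python

    import paddle

    x = paddle.to_tensor([[1., 2., 3.],
                          [4., 5., 6.],
                          [7., 8., 9.]])
    print(paddle.triu(x))               # keeps the main diagonal and everything above it
    print(paddle.triu(x, diagonal=1))   # also zeros out the main diagonal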
...@@ -411,7 +411,7 @@ def layer_norm( ...@@ -411,7 +411,7 @@ def layer_norm(
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None) return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
check_variable_and_dtype( check_variable_and_dtype(
x, 'input', ['float16', 'float32', 'float64'], 'LayerNorm' x, 'input', ['float16', 'float32', 'float64', 'uint16'], 'LayerNorm'
) )
inputs = dict() inputs = dict()
......
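'uint16' is how bfloat16 variables are reported by convert_dtype, so this change lets the static-mode layer_norm check accept bf16 inputs. A rough, untested sketch of what the relaxed check now permits, assuming paddle.cast and paddle.nn.functional.layer_norm handle 'bfloat16' in this build:

.. code-block:: python

    import paddle

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name='x', shape=[4, 32], dtype='float32')
        x_bf16 = paddle.cast(x, 'bfloat16')   # stored internally as uint16
        # the 'input' dtype check shown above now tolerates 'uint16'
        y = paddle.nn.functional.layer_norm(x_bf16, normalized_shape=[32])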
...@@ -482,10 +482,10 @@ def _elementwise_op(helper): ...@@ -482,10 +482,10 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(original_op_type) assert x is not None, 'x cannot be None in {}'.format(original_op_type)
assert y is not None, 'y cannot be None in {}'.format(original_op_type) assert y is not None, 'y cannot be None in {}'.format(original_op_type)
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'],
original_op_type) original_op_type)
check_variable_and_dtype( check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'],
original_op_type) original_op_type)
axis = helper.kwargs.get('axis', -1) axis = helper.kwargs.get('axis', -1)
...@@ -1537,10 +1537,10 @@ def add_n(inputs, name=None): ...@@ -1537,10 +1537,10 @@ def add_n(inputs, name=None):
if len(inputs) > 0: if len(inputs) > 0:
for input in inputs: for input in inputs:
check_variable_and_dtype(input, "inputs", \ check_variable_and_dtype(input, "inputs", \
['float16', 'float32', 'float64', 'int32', 'int64'], 'add_n') ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n')
else: else:
check_variable_and_dtype(inputs, "inputs", \ check_variable_and_dtype(inputs, "inputs", \
['float16', 'float32', 'float64', 'int32', 'int64'], 'add_n') ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], 'add_n')
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
......
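The same 'uint16' (bfloat16) entry is added to the elementwise and add_n dtype lists; a similarly hedged static-mode sketch, again assuming paddle.cast accepts 'bfloat16' here:

.. code-block:: python

    import paddle

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        a = paddle.cast(paddle.static.data('a', [8, 8], 'float32'), 'bfloat16')
        b = paddle.cast(paddle.static.data('b', [8, 8], 'float32'), 'bfloat16')
        s = paddle.add(a, b)           # exercises the elementwise dtype check shown above
        total = paddle.add_n([s, a])   # add_n's check also lists 'uint16' now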