未验证 提交 0e10f247 编写于 作者: Z zhiboniu 提交者: GitHub

fluid code transfer in nn.functional (#42808)

上级 77bae9a4
......@@ -336,28 +336,7 @@ def square_error_cost(input, label):
# [0.01, 0.01]
"""
if _non_static_mode():
minus_out = _C_ops.elementwise_sub(input, label)
square_out = _C_ops.square(minus_out)
return square_out
check_variable_and_dtype(input, "input", ['float32', 'float64'],
'square_error_cost')
check_variable_and_dtype(label, "label", ['float32', 'float64'],
'square_error_cost')
helper = LayerHelper('square_error_cost', **locals())
minus_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='elementwise_sub',
inputs={'X': [input],
'Y': [label]},
outputs={'Out': [minus_out]})
square_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='square', inputs={'X': [minus_out]},
outputs={'Out': [square_out]})
return square_out
return paddle.nn.functional.square_error_cost(input, label)
def edit_distance(input,
......@@ -433,45 +412,8 @@ def edit_distance(input,
# [4]
"""
check_variable_and_dtype(input, 'input', ['int64'], 'edit_distance')
check_variable_and_dtype(label, 'label', ['int64'], 'edit_distance')
helper = LayerHelper("edit_distance", **locals())
# remove some tokens from input and labels
if ignored_tokens is not None and len(ignored_tokens) > 0:
erased_input = helper.create_variable_for_type_inference(dtype="int64")
erased_label = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="sequence_erase",
inputs={"X": [input]},
outputs={"Out": [erased_input]},
attrs={"tokens": ignored_tokens})
input = erased_input
helper.append_op(
type="sequence_erase",
inputs={"X": [label]},
outputs={"Out": [erased_label]},
attrs={"tokens": ignored_tokens})
label = erased_label
this_inputs = {"Hyps": [input], "Refs": [label]}
if input_length is not None and label_length is not None:
this_inputs['HypsLength'] = [input_length]
this_inputs['RefsLength'] = [label_length]
# edit distance op
edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
sequence_num = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="edit_distance",
inputs=this_inputs,
outputs={"Out": [edit_distance_out],
"SequenceNum": [sequence_num]},
attrs={"normalized": normalized})
return edit_distance_out, sequence_num
return paddle.nn.functional.loss.edit_distance(
input, label, normalized, ignored_tokens, input_length, label_length)
def warpctc(input,
......@@ -1279,52 +1221,9 @@ def softmax_with_cross_entropy(logits,
out = paddle.nn.functional.softmax_with_cross_entropy(logits=x, label=label)
print(out)
"""
if _non_static_mode():
if core.is_compiled_with_npu():
softmax, backprop, loss = _C_ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
else:
if in_dygraph_mode():
softmax, loss = _C_ops.final_state_cross_entropy_with_softmax(
logits, label, soft_label, True, numeric_stable_mode,
ignore_index, axis)
if _in_legacy_dygraph():
softmax, loss = _C_ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
if not return_softmax:
return loss
else:
return loss, softmax
attrs = {
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode,
'axis': axis
}
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs = {'Softmax': softmax, 'Loss': loss}
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
backprop = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs['Backprop'] = backprop
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits,
'Label': label},
outputs=outputs,
attrs=attrs)
if return_softmax:
return loss, softmax
return loss
return paddle.nn.functional.loss.fluid_softmax_with_cross_entropy(
logits, label, soft_label, ignore_index, numeric_stable_mode,
return_softmax, axis)
def rank_loss(label, left, right, name=None):
......@@ -1733,33 +1632,7 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
print(npair_loss)
"""
check_variable_and_dtype(anchor, 'anchor', ['float32', 'float64'],
'npair_loss')
check_variable_and_dtype(positive, 'positive', ['float32', 'float64'],
'positive')
check_variable_and_dtype(labels, 'labels', ['float32', 'float64', 'int64'],
'labels')
Beta = 0.25
batch_size = labels.shape[0]
labels = nn.reshape(labels, shape=[batch_size, 1])
labels = paddle.tile(labels, repeat_times=[1, batch_size])
labels = equal(labels, nn.transpose(labels, perm=[1, 0])).astype('float32')
labels = labels / nn.reduce_sum(labels, dim=1, keep_dim=True)
l2loss = nn.reduce_mean(nn.reduce_sum(square(anchor), 1)) \
+ nn.reduce_mean(nn.reduce_sum(square(positive), 1))
l2loss = l2loss * Beta * l2_reg
similarity_matrix = paddle.matmul(
anchor, positive, transpose_x=False, transpose_y=True)
softmax_ce = softmax_with_cross_entropy(
logits=similarity_matrix, label=labels, soft_label=True)
cross_entropy = nn.reduce_sum(labels * softmax_ce, 0)
celoss = nn.reduce_mean(cross_entropy)
return l2loss + celoss
return paddle.nn.functional.npair_loss(anchor, positive, labels, l2_reg)
def mse_loss(input, label):
......
......@@ -7394,30 +7394,8 @@ def dice_loss(input, label, epsilon=0.00001, name=None):
predictions = F.softmax(x)
loss = F.dice_loss(input=predictions, label=label)
"""
assert input.dtype in (paddle.float32, paddle.float64)
assert label.dtype in (paddle.int32, paddle.int64)
assert len(input.shape) >= 2, \
"The rank of input should be greater than or equal to 2."
assert len(input.shape) == len(label.shape), (
"The rank of input and label should be equal, "
"but received input: %d, label: %d." %
(len(input.shape), len(label.shape)))
assert label.shape[-1] == 1, ("The last dimension of label should be 1, "
"but received %d." % label.shape[-1])
assert input.shape[:-1] == label.shape[:-1], (
"All dimensions should be equal except the last one.")
assert input.numel() > 0 and label.numel() > 0, \
"Any dimension of input and label cannot be equal to 0."
label = squeeze(label, [-1])
label = paddle.nn.functional.one_hot(label, input.shape[-1])
reduce_dim = list(range(1, len(input.shape)))
inse = reduce_sum(input * label, dim=reduce_dim)
dice_denominator = reduce_sum(
input, dim=reduce_dim) + reduce_sum(
label, dim=reduce_dim)
dice_score = 1 - inse * 2 / (dice_denominator + epsilon)
return reduce_mean(dice_score)
return paddle.nn.functional.dice_loss(
input, label, epsilon=epsilon, name=name)
def image_resize(input,
......@@ -13603,22 +13581,7 @@ def log_loss(input, label, epsilon=1e-4, name=None):
prob = paddle.randn((10,1))
cost = F.log_loss(input=prob, label=label)
"""
if in_dygraph_mode():
return _C_ops.final_state_log_loss(input, label, epsilon)
helper = LayerHelper('log_loss', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'log_loss')
check_variable_and_dtype(label, 'label', ['float32'], 'log_loss')
loss = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='log_loss',
inputs={'Predicted': [input],
'Labels': [label]},
outputs={'Loss': [loss]},
attrs={'epsilon': epsilon})
return loss
return paddle.nn.functional.log_loss(input, label, epsilon, name)
def add_position_encoding(input, alpha, beta, name=None):
......@@ -13922,33 +13885,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
input = paddle.randn([6, 4, 2, 2])
out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
"""
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
if _non_static_mode():
return _C_ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
shift_ratio, 'data_format', data_format)
helper = LayerHelper("temporal_shift", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'temporal_shift')
check_type(seg_num, 'seg_num', int, 'temporal_shift')
check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
if not isinstance(seg_num, int):
raise TypeError("seg_num must be int type.")
helper.append_op(
type="temporal_shift",
inputs={"X": x},
outputs={"Out": out},
attrs={
"seg_num": seg_num,
"shift_ratio": shift_ratio,
"data_format": data_format
})
return out
return paddle.nn.functional.temporal_shift(x, seg_num, shift_ratio, name,
data_format)
class PyFuncRegistry(object):
......@@ -15076,63 +15014,8 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
y = F.unfold(x, [3, 3], 1, 1, 1)
"""
helper = LayerHelper("unfold", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'unfold')
assert len(x.shape) == 4, \
"input should be the format of [N, C, H, W]"
if isinstance(kernel_sizes, int):
kernel_sizes = [kernel_sizes, kernel_sizes]
else:
assert isinstance(kernel_sizes, list) and (len(kernel_sizes) == 2), \
"kernel_sizes should either be an integer or a list of two integers"
if isinstance(strides, int):
strides = [strides, strides]
else:
assert isinstance(strides, list) and (len(strides) == 2), \
"strides should either be an integer or a list of two integers"
if isinstance(dilations, int):
dilations = [dilations, dilations]
else:
assert isinstance(dilations, list) and (len(dilations) == 2), \
"dilations should either be an integer or a list of two integers"
if isinstance(paddings, int):
paddings = [paddings] * 4
elif isinstance(paddings, list):
if len(paddings) == 2:
paddings = paddings * 2
elif len(paddings) == 4:
pass
else:
raise ValueError(
"paddings should either be an integer or a list of 2 or 4 integers"
)
else:
raise ValueError(
"Unexpected type of paddings, it should be either an integer or a list"
"of 2 or 4 integers")
if in_dygraph_mode():
return _C_ops.final_state_unfold(x, kernel_sizes, strides, paddings,
dilations)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="unfold",
inputs={"X": x},
outputs={"Y": out},
attrs={
"kernel_sizes": kernel_sizes,
"strides": strides,
"paddings": paddings,
"dilations": dilations
})
return out
return paddle.nn.functional.unfold(x, kernel_sizes, strides, paddings,
dilations, name)
def deformable_roi_pooling(input,
......@@ -15584,26 +15467,7 @@ def gather_tree(ids, parents):
# [[[2, 2], [1, 6]], [[3, 3], [6, 1]], [[0, 1], [9, 0]]]
"""
if in_dygraph_mode():
return _C_ops.final_state_gather_tree(ids, parents)
else:
if _in_legacy_dygraph():
return _C_ops.gather_tree(ids, parents)
else:
helper = LayerHelper('gather_tree', **locals())
check_variable_and_dtype(ids, 'ids', ['int32', 'int64'],
'gather_tree')
check_variable_and_dtype(parents, 'parents', ['int32', 'int64'],
'gather_tree')
out = helper.create_variable_for_type_inference(dtype=ids.dtype)
helper.append_op(
type="gather_tree",
inputs={"Ids": ids,
"Parents": parents},
outputs={"Out": out})
return out
return paddle.nn.functional.gather_tree(ids, parents)
@deprecated(since="2.0.0", update_to="paddle.uniform")
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import paddle
from .layer_function_generator import templatedoc
from ..framework import core, Variable, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph, convert_np_dtype_to_dtype_
from ..layer_helper import LayerHelper
......@@ -1382,35 +1383,7 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None):
"""
if in_dygraph_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if maxlen is not None:
if isinstance(maxlen, core.eager.Tensor):
attrs = ('out_dtype', dtype)
out = _C_ops.sequence_mask(x, maxlen, *attrs)
else:
attrs = ('out_dtype', dtype, 'maxlen', maxlen)
out = _C_ops.sequence_mask(x, None, *attrs)
out.stop_gradient = True
return out
helper = LayerHelper('sequence_mask', **locals())
out = helper.create_variable_for_type_inference(dtype=dtype)
inputs = {'X': [x]}
attrs = {'out_dtype': out.dtype}
if maxlen is not None:
if isinstance(maxlen, Variable):
inputs['MaxLenTensor'] = maxlen
else:
attrs['maxlen'] = maxlen
helper.append_op(
type='sequence_mask', inputs=inputs, outputs={'Y': out}, attrs=attrs)
out.stop_gradient = True
return out
return paddle.nn.functional.sequence_mask(x, maxlen, dtype, name)
@templatedoc()
......
......@@ -55,5 +55,6 @@ from ..fluid.framework import _dygraph_tracer # noqa: F401
from ..fluid.layer_helper import LayerHelper # noqa: F401
from ..fluid.framework import in_dygraph_mode # noqa: F401
from ..fluid.framework import _in_legacy_dygraph # noqa: F401
__all__ = []
......@@ -119,8 +119,8 @@ from .vision import pixel_unshuffle # noqa: F401
from .vision import channel_shuffle # noqa: F401
from .input import one_hot # noqa: F401
from .input import embedding # noqa: F401
from ...fluid.layers import gather_tree # noqa: F401
from ...fluid.layers import temporal_shift # noqa: F401
from .extension import gather_tree # noqa: F401
from .extension import temporal_shift # noqa: F401
from .sparse_attention import sparse_attention
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ...fluid.layers import sigmoid # noqa: F401
from ...tensor.ops import sigmoid # noqa: F401
from ...tensor.math import tanh # noqa: F401
from ...tensor.math import tanh_ # noqa: F401
......
......@@ -21,7 +21,6 @@ from ...tensor.creation import zeros
from paddle.static import Variable
from ...fluid import dygraph_utils
# TODO: define the common functions to build a neural network
from ...fluid.layers import unfold # noqa: F401
from ...tensor.manipulation import squeeze
from ...tensor.manipulation import unsqueeze
from ...tensor import clip
......@@ -31,8 +30,6 @@ from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
from ...fluid.framework import _varbase_creator, _in_legacy_dygraph, in_dygraph_mode, _non_static_mode
from ...fluid import dygraph_utils
from ...fluid import layers
from ...fluid.data_feeder import check_variable_and_dtype
from paddle import _C_ops
from paddle.framework import in_dynamic_mode
......@@ -44,6 +41,135 @@ from paddle.static import default_main_program
__all__ = []
def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
r"""
This op returns a col buffer of sliding local blocks of input x, also known
as im2col for batched 2D image tensors. For each block under the convolution filter,
all element will be rearranged as a column. While the convolution filter sliding over
the input feature map, a series of such columns will be formed.
For each input :math:`x` with shape [N, C, H, W], the output shape [N, Cout, Lout]
can be calculated as following.
.. math::
dkernel[0] &= dilations[0] \times (kernel\_sizes[0] - 1) + 1
dkernel[1] &= dilations[1] \times (kernel\_sizes[1] - 1) + 1
hout &= \frac{H + paddings[0] + paddings[2] - dkernel[0]}{strides[0]} + 1
wout &= \frac{W + paddings[1] + paddings[3] - dkernel[1]}{strides[1]} + 1
Cout &= C \times kernel\_sizes[0] \times kernel\_sizes[1]
Lout &= hout \times wout
Parameters:
x(Tensor): 4-D Tensor, input tensor of format [N, C, H, W],
data type can be float32 or float64
kernel_sizes(int|list): The size of convolution kernel, should be [k_h, k_w]
or an integer k treated as [k, k].
strides(int|list): The strides, should be [stride_h, stride_w]
or an integer stride treated as [sride, stride].
For default, strides will be [1, 1].
paddings(int|list): The paddings of each dimension, should be
[padding_top, padding_left, padding_bottom, padding_right]
or [padding_h, padding_w] or an integer padding.
If [padding_h, padding_w] was given, it will expanded to
[padding_h, padding_w, padding_h, padding_w]. If an integer
padding was given, [padding, padding, padding, padding] will
be used. For default, paddings will be [0, 0, 0, 0]
dilations(int|list): the dilations of convolution kernel, should be
[dilation_h, dilation_w], or an integer dilation treated as
[dilation, dilation]. For default, it will be [1, 1].
name(str, optional): The default value is None.
Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`
Returns:
The tensor corresponding to the sliding local blocks.
The output shape is [N, Cout, Lout] as decriabled above.
Cout is the total number of values within each block,
and Lout is the total number of such blocks.
The data type of output is the same as the input :math:`x`
Return Type:
Tensor
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.randn((100,3,224,224))
y = F.unfold(x, [3, 3], 1, 1, 1)
"""
helper = LayerHelper("unfold", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'unfold')
assert len(x.shape) == 4, \
"input should be the format of [N, C, H, W]"
if isinstance(kernel_sizes, int):
kernel_sizes = [kernel_sizes, kernel_sizes]
else:
assert isinstance(kernel_sizes, list) and (len(kernel_sizes) == 2), \
"kernel_sizes should either be an integer or a list of two integers"
if isinstance(strides, int):
strides = [strides, strides]
else:
assert isinstance(strides, list) and (len(strides) == 2), \
"strides should either be an integer or a list of two integers"
if isinstance(dilations, int):
dilations = [dilations, dilations]
else:
assert isinstance(dilations, list) and (len(dilations) == 2), \
"dilations should either be an integer or a list of two integers"
if isinstance(paddings, int):
paddings = [paddings] * 4
elif isinstance(paddings, list):
if len(paddings) == 2:
paddings = paddings * 2
elif len(paddings) == 4:
pass
else:
raise ValueError(
"paddings should either be an integer or a list of 2 or 4 integers"
)
else:
raise ValueError(
"Unexpected type of paddings, it should be either an integer or a list"
"of 2 or 4 integers")
if in_dygraph_mode():
return _C_ops.final_state_unfold(x, kernel_sizes, strides, paddings,
dilations)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="unfold",
inputs={"X": x},
outputs={"Y": out},
attrs={
"kernel_sizes": kernel_sizes,
"strides": strides,
"paddings": paddings,
"dilations": dilations
})
return out
def interpolate(x,
size=None,
scale_factor=None,
......@@ -1295,7 +1421,23 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
if mode == "constant" and isinstance(pad, (
list, tuple)) and len(pad) == x_dim * 2:
return layers.pad(x, pad, pad_value=value)
paddings = pad
pad_value = value
check_variable_and_dtype(x, 'x', [
'float16', 'float32', 'float64', 'int32', 'int64', 'complex64',
'complex128'
], "pad")
helper = LayerHelper('pad', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='pad',
inputs={'X': x},
outputs={'Out': out},
attrs={'paddings': paddings,
'pad_value': float(pad_value)})
return out
assert x_dim in [
3, 4, 5
......
......@@ -21,8 +21,12 @@ from ...static import Variable
from ...tensor.creation import assign
from ...fluid import dygraph_utils
from ...tensor.layer_function_generator import templatedoc
from ...fluid.layers.sequence_lod import sequence_mask #noqa: F401
from paddle import in_dynamic_mode
from paddle import _C_ops
from ...fluid.framework import _non_static_mode, _in_legacy_dygraph, in_dygraph_mode
from ...fluid.data_feeder import check_variable_and_dtype, check_type
from ...framework import core
from ...common_ops_import import convert_np_dtype_to_dtype_
__all__ = []
......@@ -140,3 +144,240 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1):
outputs={'Out': [out]})
out.stop_gradient = True
return out
def sequence_mask(x, maxlen=None, dtype='int64', name=None):
r"""
**SequenceMask Layer**
This layer outputs a mask according to the input :code:`x` and
:code:`maxlen` with data type of :code:`dtype`.
Supposing :code:`x` is a Tensor with shape [d_1, d_2, ..., d_n], the
:code:`y` is a mask with shape [d_1, d_2, ..., d_n, maxlen], where:
.. math::
y(i_1, i_2,..., i_n, j) = (j < x(i_1, i_2,..., i_n))
.. code-block:: text
Case:
Consider input:
x = [3, 1, 1, 0] max_len = 4
then we get out:
mask = [[1, 1, 1, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[0, 0, 0, 0]]
Args:
x (Variable): Input tensor of sequence_mask layer, \
whose elements are integers less than :code:`maxlen`. \
Tensor or LodTensor with shape [d_1, d_2, ..., d_n].
maxlen (int, optional): Maximum length of the sequence. If :code:`maxlen` \
is None, it would be replace with :math:`max(x)`.
dtype (np.dtype|paddle.dtype|str, optional): Data type of the output, \
``int64`` by default.
name(str, optional): For detailed information, please refer \
to :ref:`api_guide_Name`. Usually name is no need to set and \
None by default.
Returns: The output sequence mask. Tensor with shape [d_1, d_2, ..., d_n, maxlen] \
and data type of :code:`dtype`. The data type should be bool, float32, float64, int8, \
int32 or int64.
Return Type: Tensor
Examples:
.. code-block:: python
import paddle
lengths = paddle.to_tensor([10, 9, 8])
mask = paddle.nn.functional.sequence_mask(lengths)
print(mask.numpy())
# [[1 1 1 1 1 1 1 1 1 1]
# [1 1 1 1 1 1 1 1 1 0]
# [1 1 1 1 1 1 1 1 0 0]]
"""
if in_dygraph_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if maxlen is not None:
if isinstance(maxlen, core.eager.Tensor):
attrs = ('out_dtype', dtype)
out = _C_ops.sequence_mask(x, maxlen, *attrs)
else:
attrs = ('out_dtype', dtype, 'maxlen', maxlen)
out = _C_ops.sequence_mask(x, None, *attrs)
out.stop_gradient = True
return out
helper = LayerHelper('sequence_mask', **locals())
out = helper.create_variable_for_type_inference(dtype=dtype)
inputs = {'X': [x]}
attrs = {'out_dtype': out.dtype}
if maxlen is not None:
if isinstance(maxlen, Variable):
inputs['MaxLenTensor'] = maxlen
else:
attrs['maxlen'] = maxlen
helper.append_op(
type='sequence_mask', inputs=inputs, outputs={'Y': out}, attrs=attrs)
out.stop_gradient = True
return out
def gather_tree(ids, parents):
r"""
To be used after beam search. After beam search, we get selected ids at
each time step and the corresponding parents in the search tree. Both ids
and parents have the layout :attr:`[max_time, batch_size, beam_size]`. Then
:attr:`gather_tree` is used to backtrace from the last time step and
generate the full sequences by collecting selected ids.
Here is an example:
.. code-block:: text
Given:
ids = [[[2 2]
[6 1]]
[[3 9]
[6 1]]
[[0 1]
[9 0]]]
parents = [[[0 0]
[1 1]]
[[1 0]
[1 0]]
[[0 0]
[0 1]]]
Then:
gather_tree(ids, parents)
= [[[2 2]
[1 6]]
[[3 3]
[6 1]]
[[0 1]
[9 0]]]
Args:
ids(Tensor): A Tensor with shape :attr:`[length, batch_size, beam_size]`
and data type :attr:`int32` or :attr:`int64`. It contains the selected
ids of all time steps.
parents(Tensor): A Tensor with the same shape and data type as :attr:`ids`,
It contains the parents corresponding to selected ids when searching
among beams.
Returns:
A Tensor with the same shape and data type as :attr:`ids`. \
It contains the full sequences. The sequences are collected from \
:attr:`ids` by backtracing according to :attr:`parents`.
Examples:
.. code-block:: python
import paddle
ids = paddle.to_tensor([[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]])
parents = paddle.to_tensor([[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]])
final_sequences = paddle.nn.functional.gather_tree(ids, parents)
# [[[2, 2], [1, 6]], [[3, 3], [6, 1]], [[0, 1], [9, 0]]]
"""
if in_dygraph_mode():
return _C_ops.final_state_gather_tree(ids, parents)
else:
if _in_legacy_dygraph():
return _C_ops.gather_tree(ids, parents)
else:
helper = LayerHelper('gather_tree', **locals())
check_variable_and_dtype(ids, 'ids', ['int32', 'int64'],
'gather_tree')
check_variable_and_dtype(parents, 'parents', ['int32', 'int64'],
'gather_tree')
out = helper.create_variable_for_type_inference(dtype=ids.dtype)
helper.append_op(
type="gather_tree",
inputs={"Ids": ids,
"Parents": parents},
outputs={"Out": out})
return out
@templatedoc()
def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
"""
**Temporal Shift Operator**
${comment}
Args:
x(Tensor): ${x_comment}
seg_num(int): ${seg_num_comment}
shift_ratio(float): ${shift_ratio_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
data_format(str, optional): Data format that specifies the layout of input.
It can be "NCHW" or "NHWC". Default: "NCHW".
Returns:
out(Tensor): The temporal shifting result is a tensor with the
same shape and same data type as the input.
Raises:
TypeError: seg_num must be int type.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.randn([6, 4, 2, 2])
out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
"""
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
if _non_static_mode():
return _C_ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
shift_ratio, 'data_format', data_format)
helper = LayerHelper("temporal_shift", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'temporal_shift')
check_type(seg_num, 'seg_num', int, 'temporal_shift')
check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
if not isinstance(seg_num, int):
raise TypeError("seg_num must be int type.")
helper.append_op(
type="temporal_shift",
inputs={"X": x},
outputs={"Out": out},
attrs={
"seg_num": seg_num,
"shift_ratio": shift_ratio,
"data_format": data_format
})
return out
......@@ -21,15 +21,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from ...fluid.layers.nn import _elementwise_op_in_dygraph
from ...fluid.layers import dice_loss # noqa: F401
from ...fluid.layers import log_loss # noqa: F401
from ...fluid.layers import npair_loss # noqa: F401
from ...tensor.manipulation import reshape
from ...fluid.layers import softmax_with_cross_entropy as fluid_softmax_with_cross_entropy
from ...fluid.layers import square_error_cost # noqa: F401
from ...fluid.layers import edit_distance # noqa: F401
from ...fluid.layers import huber_loss
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import _varbase_creator
from ...static import Variable
......@@ -41,6 +33,518 @@ from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_
__all__ = []
def dice_loss(input, label, epsilon=0.00001, name=None):
r"""
Dice loss for comparing the similarity between the input predictions and the label.
This implementation is for binary classification, where the input is sigmoid
predictions of each pixel, usually used for segmentation task. The dice loss can
be defined as the following equation:
.. math::
dice\_loss &= 1 - \frac{2 * intersection\_area}{total\_area} \\
&= \frac{(total\_area - intersection\_area) - intersection\_area}{total\_area} \\
&= \frac{(union\_area - intersection\_area)}{total\_area}
Parameters:
input (Tensor): Tensor, rank>=2, shape is :math:`[N_1, N_2, ..., N_k, D]`, where :math:`N_1` is
the batch_size, :math:`D` is the number of categories. It is usually the output
predictions of sigmoid activation. The data type can be float32 or float64.
label (Tensor): Tensor, the groud truth with the same rank as input, shape is :math:`[N_1, N_2, ..., N_k, 1]`.
where :math:`N_1` is the batch_size. The data type can be int32 or int64.
epsilon (float): The epsilon will be added to the numerator and denominator.
If both input and label are empty, it makes sure dice is 1.
Default: 0.00001
name(str, optional): The default value is None.
Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`
Returns:
Tensor, which shape is [1], data type is the same as `input` .
Example:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.randn((3,224,224,2))
label = paddle.randint(high=2, shape=(3,224,224,1))
predictions = F.softmax(x)
loss = F.dice_loss(input=predictions, label=label)
"""
assert input.dtype in (paddle.float32, paddle.float64)
assert label.dtype in (paddle.int32, paddle.int64)
assert len(input.shape) >= 2, \
"The rank of input should be greater than or equal to 2."
assert len(input.shape) == len(label.shape), (
"The rank of input and label should be equal, "
"but received input: %d, label: %d." %
(len(input.shape), len(label.shape)))
assert label.shape[-1] == 1, ("The last dimension of label should be 1, "
"but received %d." % label.shape[-1])
assert input.shape[:-1] == label.shape[:-1], (
"All dimensions should be equal except the last one.")
assert input.numel() > 0 and label.numel() > 0, \
"Any dimension of input and label cannot be equal to 0."
label = paddle.squeeze(label, [-1])
label = paddle.nn.functional.one_hot(label, input.shape[-1])
reduce_dim = list(range(1, len(input.shape)))
inse = paddle.sum(input * label, axis=reduce_dim)
dice_denominator = paddle.sum(input, axis=reduce_dim) + paddle.sum(
label, axis=reduce_dim)
dice_score = 1 - inse * 2 / (dice_denominator + epsilon)
return paddle.mean(dice_score)
def log_loss(input, label, epsilon=1e-4, name=None):
r"""
**Negative Log Loss Layer**
This layer accepts input predictions and target label and returns the
negative log loss.
.. math::
Out = -label * \log{(input + \epsilon)}
- (1 - label) * \log{(1 - input + \epsilon)}
Args:
input (Tensor|list): A 2-D tensor with shape [N x 1], where N is the
batch size. This input is a probability computed
by the previous operator. Data type float32.
label (Tensor|list): The ground truth which is a 2-D tensor with
shape [N x 1], where N is the batch size.
Data type float32.
epsilon (float, optional): A small number for numerical stability. Default 1e-4.
name(str|None): For detailed information, please refer to
:ref:`api_guide_Name` . Usually name is no need to set and None by default.
Returns:
Tensor, which shape is [N x 1], data type is float32.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
label = paddle.randn((10,1))
prob = paddle.randn((10,1))
cost = F.log_loss(input=prob, label=label)
"""
if in_dygraph_mode():
return _C_ops.final_state_log_loss(input, label, epsilon)
helper = LayerHelper('log_loss', **locals())
check_variable_and_dtype(input, 'input', ['float32'], 'log_loss')
check_variable_and_dtype(label, 'label', ['float32'], 'log_loss')
loss = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='log_loss',
inputs={'Predicted': [input],
'Labels': [label]},
outputs={'Loss': [loss]},
attrs={'epsilon': epsilon})
return loss
def fluid_softmax_with_cross_entropy(logits,
label,
soft_label=False,
ignore_index=-100,
numeric_stable_mode=True,
return_softmax=False,
axis=-1):
r"""
This operator implements the cross entropy loss function with softmax. This function
combines the calculation of the softmax operation and the cross entropy loss function
to provide a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
When the attribute :attr:`soft_label` is set :attr:`False`, this operators
expects mutually exclusive hard labels, each sample in a batch is in exactly
one class with a probability of 1.0. Each sample in the batch will have a
single label.
The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
.. math::
loss_j = -\\text{logits}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logits}_i)\\right), j = 1,..., K
2) Soft label (each sample can have a distribution over all classes)
.. math::
loss_j = -\\sum_{i=0}^{K}\\text{label}_i
\\left(\\text{logits}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logits}_i)\\right)\\right), j = 1,...,K
3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated first by:
.. math::
max_j &= \\max_{i=0}^{K}{\\text{logits}_i}
log\\_max\\_sum_j &= \\log\\sum_{i=0}^{K}\\exp(logits_i - max_j)
softmax_j &= \\exp(logits_j - max_j - {log\\_max\\_sum}_j)
and then cross entropy loss is calculated by softmax and label.
Args:
logits (Tensor): A multi-dimension ``Tensor`` , and the data type is float32 or float64. The input tensor of unscaled log probabilities.
label (Tensor): The ground truth ``Tensor`` , data type is the same
as the ``logits`` . If :attr:`soft_label` is set to :attr:`True`,
Label is a ``Tensor`` in the same shape with :attr:`logits`.
If :attr:`soft_label` is set to :attr:`True`, Label is a ``Tensor``
in the same shape with :attr:`logits` expect shape in dimension :attr:`axis` as 1.
soft_label (bool, optional): A flag to indicate whether to interpretant the given
labels as soft labels. Default False.
ignore_index (int, optional): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if :attr:`soft_label` is set to :attr:`False`.
Default: kIgnoreIndex(-100).
numeric_stable_mode (bool, optional): A flag to indicate whether to use a more
numerically stable algorithm. Only valid
when :attr:`soft_label` is :attr:`False`
and GPU is used. When :attr:`soft_label`
is :attr:`True` or CPU is used, the
algorithm is always numerically stable.
Note that the speed may be slower when use
stable algorithm. Default: True.
return_softmax (bool, optional): A flag indicating whether to return the softmax
along with the cross entropy loss. Default: False.
axis (int, optional): The index of dimension to perform softmax calculations. It
should be in range :math:`[-1, rank - 1]`, while :math:`rank`
is the rank of input :attr:`logits`. Default: -1.
Returns:
``Tensor`` or Tuple of two ``Tensor`` : Return the cross entropy loss if \
`return_softmax` is False, otherwise the tuple \
(loss, softmax), softmax is in the same shape \
with input logits and cross entropy loss is in \
the same shape with input logits except shape \
in dimension :attr:`axis` as 1.
Examples:
.. code-block:: python
import paddle
import numpy as np
data = np.random.rand(128).astype("float32")
label = np.random.rand(1).astype("int64")
data = paddle.to_tensor(data)
label = paddle.to_tensor(label)
linear = paddle.nn.Linear(128, 100)
x = linear(data)
out = paddle.nn.functional.softmax_with_cross_entropy(logits=x, label=label)
print(out)
"""
if _non_static_mode():
if core.is_compiled_with_npu():
softmax, backprop, loss = _C_ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
else:
if in_dygraph_mode():
softmax, loss = _C_ops.final_state_cross_entropy_with_softmax(
logits, label, soft_label, True, numeric_stable_mode,
ignore_index, axis)
if _in_legacy_dygraph():
softmax, loss = _C_ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
if not return_softmax:
return loss
else:
return loss, softmax
attrs = {
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode,
'axis': axis
}
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs = {'Softmax': softmax, 'Loss': loss}
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
backprop = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs['Backprop'] = backprop
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits,
'Label': label},
outputs=outputs,
attrs=attrs)
if return_softmax:
return loss, softmax
return loss
def npair_loss(anchor, positive, labels, l2_reg=0.002):
"""
Npair loss requires paired data. Npair loss has two parts: the first part is L2
regularizer on the embedding vector; the second part is cross entropy loss which
takes the similarity matrix of anchor and positive as logits.
For more information, please refer to:
`Improved Deep Metric Learning with Multi class N pair Loss Objective <http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf>`_
Args:
anchor(Tensor): embedding vector for the anchor image. shape=[batch_size, embedding_dims],
the data type is float32 or float64.
positive(Tensor): embedding vector for the positive image. shape=[batch_size, embedding_dims],
the data type is float32 or float64.
labels(Tensor): 1-D tensor. shape=[batch_size], the data type is float32 or float64 or int64.
l2_reg(float32): L2 regularization term on embedding vector, default: 0.002.
Returns:
A Tensor representing the npair loss, the data type is the same as anchor, the shape is [1].
Examples:
.. code-block:: python
import paddle
DATATYPE = "float32"
anchor = paddle.rand(shape=(18, 6), dtype=DATATYPE)
positive = paddle.rand(shape=(18, 6), dtype=DATATYPE)
labels = paddle.rand(shape=(18,), dtype=DATATYPE)
npair_loss = paddle.nn.functional.npair_loss(anchor, positive, labels, l2_reg = 0.002)
print(npair_loss)
"""
check_variable_and_dtype(anchor, 'anchor', ['float32', 'float64'],
'npair_loss')
check_variable_and_dtype(positive, 'positive', ['float32', 'float64'],
'positive')
check_variable_and_dtype(labels, 'labels', ['float32', 'float64', 'int64'],
'labels')
Beta = 0.25
batch_size = labels.shape[0]
labels = paddle.reshape(labels, shape=[batch_size, 1])
labels = paddle.tile(labels, repeat_times=[1, batch_size])
labels = paddle.equal(
labels, paddle.transpose(
labels, perm=[1, 0])).astype('float32')
labels = labels / paddle.sum(labels, axis=1, keepdim=True)
l2loss = paddle.mean(paddle.sum(paddle.square(anchor), 1)) \
+ paddle.mean(paddle.sum(paddle.square(positive), 1))
l2loss = l2loss * Beta * l2_reg
similarity_matrix = paddle.matmul(
anchor, positive, transpose_x=False, transpose_y=True)
softmax_ce = fluid_softmax_with_cross_entropy(
logits=similarity_matrix, label=labels, soft_label=True)
cross_entropy = paddle.sum(labels * softmax_ce, 0)
celoss = paddle.mean(cross_entropy)
return l2loss + celoss
def square_error_cost(input, label):
r"""
This op accepts input predictions and target label and returns the
squared error cost.
For predictions label, and target label, the equation is:
.. math::
Out = (input - label)^2
Parameters:
input (Tensor): Input tensor, the data type should be float32.
label (Tensor): Label tensor, the data type should be float32.
Returns:
The tensor storing the element-wise squared error \
difference between input and label.
Return type: Tensor.
Examples:
.. code-block:: python
import paddle
input = paddle.to_tensor([1.1, 1.9])
label = paddle.to_tensor([1.0, 2.0])
output = paddle.nn.functional.square_error_cost(input, label)
print(output)
# [0.01, 0.01]
"""
if _non_static_mode():
minus_out = _C_ops.elementwise_sub(input, label)
square_out = _C_ops.square(minus_out)
return square_out
check_variable_and_dtype(input, "input", ['float32', 'float64'],
'square_error_cost')
check_variable_and_dtype(label, "label", ['float32', 'float64'],
'square_error_cost')
helper = LayerHelper('square_error_cost', **locals())
minus_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='elementwise_sub',
inputs={'X': [input],
'Y': [label]},
outputs={'Out': [minus_out]})
square_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='square', inputs={'X': [minus_out]},
outputs={'Out': [square_out]})
return square_out
def edit_distance(input,
label,
normalized=True,
ignored_tokens=None,
input_length=None,
label_length=None):
"""
This op computes the edit distances, also called Levenshtein distance, between a batch of
hypothesis strings and their references. It measures how dissimilar two strings are by counting
the minimum number of operations to transform one string into another.
The operations include insertion, deletion, and substitution.
For example, given hypothesis string A = "kitten" and reference
B = "sitting", A will be transformed into B
at least after two substitutions and one insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting"
So the edit distance between A and B is 3.
The input is a Tensor, the input_length and label_length should be supported.
The `batch_size` of labels should be same as `input`.
The output include the edit distance value between every pair of input and related label, and the number of sequence.
If Attr(normalized) is true,
the edit distance value will be divided by the length of label.
Parameters:
input(Tensor): The input tensor, its rank should be equal to 2 and its data type should be int64.
label(Tensor): The label tensor, its rank should be equal to 2 and its data type should be int64.
normalized(bool, default True): Indicated whether to normalize the edit distance.
ignored_tokens(list<int>, default None): Tokens that will be removed before
calculating edit distance.
input_length(Tensor): The length for each sequence in `input` if it's of Tensor type, it should have shape `(batch_size, )` and its data type should be int64.
label_length(Tensor): The length for each sequence in `label` if it's of Tensor type, it should have shape `(batch_size, )` and its data type should be int64.
NOTE: To be avoid unexpected result, the value of every elements in input_length and label_length should be equal to the value of the second dimension of input and label. For example, The input: [[1,2,3,4],[5,6,7,8],[9,10,11,12]], the shape of input is [3,4] and the input_length should be [4,4,4]
NOTE: This Api is different from fluid.metrics.EditDistance
Returns:
Tuple:
distance(Tensor): edit distance result, its data type is float32, and its shape is (batch_size, 1).
sequence_num(Tensor): sequence number, its data type is float32, and its shape is (1,).
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.to_tensor([[1,2,3],[4,5,6],[4,4,4],[1,1,1]], dtype='int64')
label = paddle.to_tensor([[1,3,4,1],[4,5,8,1],[7,7,7,1],[1,1,1,1]], dtype='int64')
input_len = paddle.to_tensor([3,3,3,3], dtype='int64')
label_len = paddle.to_tensor([4,4,4,4], dtype='int64')
distance, sequence_num = F.loss.edit_distance(input=input, label=label, input_length=input_len, label_length=label_len, normalized=False)
# print(distance)
# [[3.]
# [2.]
# [4.]
# [1.]]
# if set normalized to True
# [[0.75]
# [0.5 ]
# [1. ]
# [0.25]
#
# print(sequence_num)
# [4]
"""
check_variable_and_dtype(input, 'input', ['int64'], 'edit_distance')
check_variable_and_dtype(label, 'label', ['int64'], 'edit_distance')
helper = LayerHelper("edit_distance", **locals())
# remove some tokens from input and labels
if ignored_tokens is not None and len(ignored_tokens) > 0:
erased_input = helper.create_variable_for_type_inference(dtype="int64")
erased_label = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="sequence_erase",
inputs={"X": [input]},
outputs={"Out": [erased_input]},
attrs={"tokens": ignored_tokens})
input = erased_input
helper.append_op(
type="sequence_erase",
inputs={"X": [label]},
outputs={"Out": [erased_label]},
attrs={"tokens": ignored_tokens})
label = erased_label
this_inputs = {"Hyps": [input], "Refs": [label]}
if input_length is not None and label_length is not None:
this_inputs['HypsLength'] = [input_length]
this_inputs['RefsLength'] = [label_length]
# edit distance op
edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
sequence_num = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="edit_distance",
inputs=this_inputs,
outputs={"Out": [edit_distance_out],
"SequenceNum": [sequence_num]},
attrs={"normalized": normalized})
return edit_distance_out, sequence_num
def binary_cross_entropy(input, label, weight=None, reduction='mean',
name=None):
"""
......@@ -138,10 +642,10 @@ def binary_cross_entropy(input, label, weight=None, reduction='mean',
else:
return out
else:
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'binary_cross_entropy')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'binary_cross_entropy')
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'binary_cross_entropy')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'binary_cross_entropy')
sub_name = name if weight is None and reduction == 'none' else None
helper = LayerHelper("binary_cross_entropy", name=sub_name)
......@@ -288,12 +792,10 @@ def binary_cross_entropy_with_logits(logit,
else:
return out
fluid.data_feeder.check_variable_and_dtype(
logit, 'logit', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
check_variable_and_dtype(logit, 'logit', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
sigmoid_name = None
if reduction == 'none' and pos_weight is None and weight is None:
sigmoid_name = name
......@@ -303,18 +805,17 @@ def binary_cross_entropy_with_logits(logit,
one = paddle.full(shape=[1], fill_value=1.0, dtype=logit.dtype)
if pos_weight is not None:
fluid.data_feeder.check_variable_and_dtype(
pos_weight, 'pos_weight', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
check_variable_and_dtype(pos_weight, 'pos_weight',
['float32', 'float64'],
'binary_cross_entropy_with_logits')
log_weight = paddle.add(
paddle.multiply(label, paddle.subtract(pos_weight, one)), one)
pos_weight_name = name if reduction == 'none' and weight is None else None
out = paddle.multiply(out, log_weight, name=pos_weight_name)
if weight is not None:
fluid.data_feeder.check_variable_and_dtype(
weight, 'weight', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
'binary_cross_entropy_with_logits')
weight_name = name if reduction == 'none' else None
out = paddle.multiply(out, weight, name=weight_name)
......@@ -519,12 +1020,26 @@ def smooth_l1_loss(input, label, reduction='mean', delta=1.0, name=None):
output = paddle.nn.functional.smooth_l1_loss(input, label)
print(output)
"""
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'smooth_l1_loss')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'smooth_l1_loss')
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'smooth_l1_loss')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'smooth_l1_loss')
out = huber_loss(input=input, label=label, delta=delta)
if in_dygraph_mode():
out, residual = _C_ops.final_state_huber_loss(input, label, delta)
else:
helper = LayerHelper('huber_loss', **locals())
residual = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
helper.append_op(
type='huber_loss',
inputs={'X': input,
'Y': label},
outputs={'Out': out,
'Residual': residual},
attrs={'delta': delta})
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
......@@ -615,12 +1130,12 @@ def margin_ranking_loss(input,
return out
helper = LayerHelper("margin_ranking_loss", **locals())
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'margin_rank_loss')
fluid.data_feeder.check_variable_and_dtype(
other, 'other', ['float32', 'float64'], 'margin_rank_loss')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'margin_rank_loss')
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'margin_rank_loss')
check_variable_and_dtype(other, 'other', ['float32', 'float64'],
'margin_rank_loss')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'margin_rank_loss')
out = paddle.subtract(other, input)
out = paddle.multiply(out, label)
......@@ -738,9 +1253,9 @@ def l1_loss(input, label, reduction='mean', name=None):
else:
return unreduced
fluid.data_feeder.check_variable_and_dtype(
check_variable_and_dtype(
input, 'input', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
fluid.data_feeder.check_variable_and_dtype(
check_variable_and_dtype(
label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
if reduction == 'sum':
......@@ -847,10 +1362,8 @@ def nll_loss(input,
label = reshape(label, shape=[n, 1, -1])
out_shape = [n] + input_shape[2:]
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'nll_loss')
fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
'nll_loss')
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'nll_loss')
check_variable_and_dtype(label, 'label', ['int64'], 'nll_loss')
inputs = {'X': input, 'Label': label}
attrs = {'reduction': reduction, 'ignore_index': ignore_index}
if weight is not None:
......@@ -971,10 +1484,8 @@ def kl_div(input, label, reduction='mean', name=None):
helper = LayerHelper('kl_div', **locals())
fluid.data_feeder.check_variable_and_dtype(input, 'input',
['float32', 'float64'], 'kl_div')
fluid.data_feeder.check_variable_and_dtype(label, 'label',
['float32', 'float64'], 'kl_div')
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'kl_div')
check_variable_and_dtype(label, 'label', ['float32', 'float64'], 'kl_div')
fluid.data_feeder.check_type(reduction, 'reduction', str, 'kl_div')
loss = helper.create_variable_for_type_inference(dtype=input.dtype)
......@@ -1051,10 +1562,10 @@ def mse_loss(input, label, reduction='mean', name=None):
"but received {}.".format(reduction))
if not in_dynamic_mode():
paddle.fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'mse_loss')
paddle.fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'mse_loss')
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'mse_loss')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'mse_loss')
if reduction == 'none':
return paddle.square(paddle.subtract(input, label), name=name)
......@@ -1858,9 +2369,9 @@ def cross_entropy(input,
out = paddle.squeeze(out, axis=axis)
return out
fluid.data_feeder.check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'softmax_cross_entropy')
fluid.data_feeder.check_variable_and_dtype(
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'softmax_cross_entropy')
check_variable_and_dtype(
label, 'label',
['uint8', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
'softmax_cross_entropy')
......@@ -1887,8 +2398,8 @@ def cross_entropy(input,
attrs=attrs)
if weight is not None:
fluid.data_feeder.check_variable_and_dtype(
weight, 'weight', ['float32', 'float64'], 'softmax_cross_entropy')
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
'softmax_cross_entropy')
weight_name = name if reduction == 'none' else None
if soft_label == True:
# chajchaj:
......@@ -2050,9 +2561,8 @@ def sigmoid_focal_loss(logit,
% reduction)
if normalizer is not None:
fluid.data_feeder.check_variable_and_dtype(normalizer, 'normalizer',
['float32', 'float64'],
'sigmoid_focal_loss')
check_variable_and_dtype(normalizer, 'normalizer',
['float32', 'float64'], 'sigmoid_focal_loss')
normalizer_shape = list(normalizer.shape)
normalizer_dims = len(normalizer_shape)
if normalizer_dims > 1:
......@@ -2102,10 +2612,10 @@ def sigmoid_focal_loss(logit,
return loss
fluid.data_feeder.check_variable_and_dtype(
logit, 'logit', ['float32', 'float64'], 'sigmoid_focal_loss')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'sigmoid_focal_loss')
check_variable_and_dtype(logit, 'logit', ['float32', 'float64'],
'sigmoid_focal_loss')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'sigmoid_focal_loss')
bce_name = None
if reduction == 'none' and normalizer is None:
......
......@@ -21,7 +21,7 @@ import string
from six.moves import cStringIO
from ..static import Variable
from ..fluid.proto import framework_pb2
from ..framework import OpProtoHolder, core, convert_np_dtype_to_dtype_, _non_static_mode, in_dygraph_mode
from ..framework import OpProtoHolder, core, convert_np_dtype_to_dtype_, _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
from ..framework import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype
import paddle
......@@ -271,9 +271,10 @@ def generate_activation_fn(op_type):
op_type)
else:
# abs exp square ops support dtype(int32, int64, float16, float32, float64)
check_variable_and_dtype(
x, 'x', ['int32', 'int64', 'float16', 'float32', 'float64'],
op_type)
check_variable_and_dtype(x, 'x', [
'int32', 'int64', 'float16', 'float32', 'float64', 'complex64',
'complex128'
], op_type)
helper = LayerHelper(op_type, **locals())
......@@ -302,7 +303,7 @@ def generate_inplace_fn(inplace_op_type):
origin_op_type = inplace_op_type[:-1]
def func(x, name=None):
if paddle.in_dynamic_mode():
if _non_static_mode():
op = getattr(_C_ops, inplace_op_type)
return op(x)
warnings.warn(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册