未验证 提交 d3c9db75 编写于 作者: Z zhongpu 提交者: GitHub

copy api from paddle to paddle.fluid (#24164)

* copy api from paddle to paddle.fluid, test=develop

* fix optest, test=develop
上级 a4519a5d
......@@ -17,6 +17,8 @@ from __future__ import print_function
from six.moves import reduce
from .. import core
from ..layers import utils
from ..layers import square
from ..layers import cross_entropy
from ..layers import nn as F
from .. import dygraph_utils
from . import layers
......@@ -35,7 +37,8 @@ __all__ = [
'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding',
'GRUUnit', 'InstanceNorm', 'LayerNorm', 'NCE', 'PRelu',
'BilinearTensorProduct', 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm',
'SpectralNorm', 'TreeConv'
'SpectralNorm', 'TreeConv', 'CrossEntropyLoss', 'MSELoss', 'L1Loss',
'NLLLoss', 'BCELoss'
]
......@@ -3122,3 +3125,560 @@ class TreeConv(layers.Layer):
else:
pre_activation = out
return self._helper.append_activation(pre_activation, act=self._act)
class CrossEntropyLoss(layers.Layer):
"""
This operator implements the cross entropy loss function. This OP combines `softmax`,
`cross_entropy`, and `reduce_sum`/`reduce_mean` together.
It is useful when training a classification problem with `C` classes.
If provided, the optional argument `weight` should be a 1D Variable assigning
weight to each of the classes.
For predictions label, and target label, the loss is calculated as follows.
.. math::
loss_j = -\\text{input[class]} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right), j = 1,..., K
If weight is not `None`:
.. math::
loss_j = \\text{weight[class]}(-\\text{input[class]} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K
Parameters:
input (Variable): Input tensor, the data type is float32,
float64, int32, int64.
label (Variable): Label tensor, the data type is float32,
float64, int32, int64.
weight (Variable, optional): Weight tensor, a manual rescaling weight given
to each class. It has the same dimensions as class number and the data type
is float32, float64, int32, int64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
Default is ``'mean'``.
Returns:
The tensor variable storing the cross_entropy_loss of input and label.
Return type: Variable.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
input = fluid.layers.data(name='input', shape=[5, 100], dtype='float32')
label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64')
weight = fluid.layers.data(name='weight', shape=[100], dtype='float32')
ce_loss = fluid.dygraph.CrossEntropyLoss(weight=weight, reduction='mean')
output = ce_loss(input,label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.random.random([5, 100]).astype("float32")
label_data = np.array([[1], [9], [40], [50], [90]]).astype("int64")
weight_data = np.random.random([100]).astype("float32")
output = exe.run(fluid.default_main_program(),
feed={"input": input_data, "label": label_data,"weight": weight_data},
fetch_list=[output],
return_numpy=True)
print(output)
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
weight = dg.to_variable(weight_data)
ce_loss = fluid.dygraph.CrossEntropyLoss(weight=weight, reduction='mean')
output = ce_loss(input, label)
print(output.numpy())
"""
def __init__(self, weight=None, reduction='mean'):
super(CrossEntropyLoss, self).__init__()
self.weight = weight
self.reduction = reduction
def forward(self, input, label):
check_variable_and_dtype(input, 'input',
['float32', 'float64', 'int32', 'int64'],
'cross_entropy_loss')
check_variable_and_dtype(label, 'label',
['float32', 'float64', 'int32', 'int64'],
'cross_entropy_loss')
if self.reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or 'none',"
" but received %s, which is not allowed." % self.reduction)
softmax_out = F.softmax(input)
if self.weight is not None:
if isinstance(self.weight, Variable):
softmax_out = F.elementwise_pow(
softmax_out, self.weight, axis=-1)
else:
raise ValueError(
"The weight' is not a Variable, please convert to Variable.")
out = cross_entropy(softmax_out, label)
if self.reduction == 'sum':
return F.reduce_sum(out)
elif self.reduction == 'mean':
return F.reduce_mean(out)
else:
return out
class MSELoss(layers.Layer):
"""
**Mean Square Error Loss**
Computes the mean square error (squared L2 norm) of given input and label.
If :attr:`reduction` is set to ``'none'``, loss is calculated as:
.. math::
Out = (input - label)^2
If :attr:`reduction` is set to ``'mean'``, loss is calculated as:
.. math::
Out = \operatorname{mean}((input - label)^2)
If :attr:`reduction` is set to ``'sum'``, loss is calculated as:
.. math::
Out = \operatorname{sum}((input - label)^2)
where `input` and `label` are `float32` tensors of same shape.
Parameters:
input (Variable): Input tensor, the data type is float32,
label (Variable): Label tensor, the data type is float32,
reduction (string, optional): The reduction method for the output,
could be 'none' | 'mean' | 'sum'.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
Default is ``'mean'``.
Returns:
The tensor variable storing the MSE loss of input and label.
Return type:
Variable.
Examples:
.. code-block:: python
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
mse_loss = fluid.dygraph.MSELoss()
input = fluid.data(name="input", shape=[1])
label = fluid.data(name="label", shape=[1])
place = fluid.CPUPlace()
input_data = np.array([1.5]).astype("float32")
label_data = np.array([1.7]).astype("float32")
# declarative mode
output = mse_loss(input,label)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
output_data = exe.run(
fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data)
# [array([0.04000002], dtype=float32)]
# imperative mode
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
output = mse_loss(input, label)
print(output.numpy())
# [0.04000002]
"""
def __init__(self, reduction='mean'):
super(MSELoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'MSELoss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
self.reduction = reduction
def forward(self, input, label):
if not in_dygraph_mode():
check_variable_and_dtype(input, 'input', ['float32'], 'MSELoss')
check_variable_and_dtype(label, 'label', ['float32'], 'MSELoss')
square_out = square(F.elementwise_sub(input, label))
if self.reduction == 'none':
return square_out
reduce_op = 'reduce_mean'
if self.reduction == 'sum':
reduce_op = 'reduce_sum'
return getattr(F, reduce_op)(square_out)
class L1Loss(layers.Layer):
"""
This interface is used to construct a callable object of the ``L1Loss`` class.
The L1Loss layer calculates the L1 Loss of input predictions and target
labels as follows.
If :attr:`reduction` set to ``'none'``, the unreduced loss is:
.. math::
Out = |input - label|
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
.. math::
Out = MEAN(|input - label|)
If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:
.. math::
Out = SUM(|input - label|)
The shape of input predictions and target labels are [N, *], where N is batch_size and `*`
means any number of additional dimensions.
If :attr:`reduction` is ``'none'``, the shape of output loss is [N, *], the same as input.
If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1], which means the output is a scalar.
Parameters:
reduction (str, optional): Indicate the reduction to apply to the loss,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned.
If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned.
Default is ``'mean'``.
Returns:
A callable object of L1Loss.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="input", shape=[1])
label = fluid.data(name="label", shape=[1])
l1_loss = fluid.dygraph.L1Loss(reduction='mean')
output = l1_loss(input,label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([1.5]).astype("float32")
label_data = np.array([1.7]).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data) # [array([0.2], dtype=float32)]
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
l1_loss = fluid.dygraph.L1Loss(reduction='mean')
output = l1_loss(input,label)
print(output.numpy()) # [0.2]
"""
def __init__(self, reduction='mean'):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
super(L1Loss, self).__init__()
self.reduction = reduction
def forward(self, input, label):
check_variable_and_dtype(
input, 'input', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
check_variable_and_dtype(
label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
unreduced = F.elementwise_sub(input, label, act='abs')
if self.reduction == 'sum':
return F.reduce_sum(unreduced)
elif self.reduction == 'mean':
return F.reduce_mean(unreduced)
else:
return unreduced
class BCELoss(layers.Layer):
"""
This interface is used to construct a callable object of the ``BCELoss`` class.
The BCELoss layer measures the binary_cross_entropy loss between input predictions
and target labels. The binary_cross_entropy loss can be described as:
If :attr:`weight` is set, the loss is:
.. math::
Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`weight` is None, the loss is:
.. math::
Out = -1 * (label * log(input) + (1 - label) * log(1 - input))
If :attr:`reduction` set to ``'none'``, the unreduced loss is:
.. math::
Out = Out
If :attr:`reduction` set to ``'mean'``, the reduced mean loss is:
.. math::
Out = MEAN(Out)
If :attr:`reduction` set to ``'sum'``, the reduced sum loss is:
.. math::
Out = SUM(Out)
Note that the input predictions always be the output of sigmoid, and the target labels
should be numbers between 0 and 1.
The shape of input predictions and target labels are [N, *], where N is batch_size and `*`
means any number of additional dimensions. If ``reduction`` is ``'none'``, the shape of
output is scalar, else the shape of output is same as input.
Parameters:
weight (Variable, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Variable of size nbatch and the data type
is float32, float64. Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
Returns:
A callable object of BCELoss.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="input", shape=[3, 1], dtype='float32')
label = fluid.data(name="label", shape=[3, 1], dtype='float32')
bce_loss = fluid.dygraph.BCELoss()
output = bce_loss(input, label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
input_data = np.array([0.5, 0.6, 0.7]).astype("float32")
label_data = np.array([1.0, 0.0, 1.0]).astype("float32")
output_data = exe.run(fluid.default_main_program(),
feed={"input":input_data, "label":label_data},
fetch_list=[output],
return_numpy=True)
print(output_data) # [array([0.65537095], dtype=float32)]
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_data)
label = dg.to_variable(label_data)
output = bce_loss(input, label)
print(output.numpy()) # [0.65537095]
"""
def __init__(self, weight=None, reduction='mean'):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
super(BCELoss, self).__init__()
self.weight = weight
self.reduction = reduction
def forward(self, input, label):
dtype = self._helper.input_dtype(input)
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'bce_loss')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'bce_loss')
out = self._helper.create_variable_for_type_inference(dtype=input.dtype)
self._helper.append_op(
type='bce_loss',
inputs={
'X': [input],
'Label': [label],
},
outputs={'Out': [out]})
if self.weight is not None:
if isinstance(self.weight, Variable):
w = self.weight
out = F.elementwise_mul(out, w, axis=-1)
else:
raise ValueError(
"The weight is not a Variable, please convert to Variable.")
if self.reduction == 'sum':
return F.reduce_sum(out)
elif self.reduction == 'mean':
return F.reduce_mean(out)
else:
return out
class NLLLoss(layers.Layer):
"""
This op accepts input and target label and returns negative log likelihood
cross error. It is useful to train a classification problem with C classes.
The input for the loss is epected to contain log-probabilities of
each classes. It hs to be a Tensor of size either (batch_size, C) or
(batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case.
The label for the loss should be a class index in the range [0, C-1]
where C is the number of classes. If ignore_index is specified, the
specified target value does not contribute to the input gradient.
If the optional argument `weight` is provided, it should be a 1D Tensor
assigning weight to each of the classed. This is particularly useful
when you have an unbalanced training set.
The loss is calculated as follows.
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\\top, \quad
l_n = - w_{y_n} x_{n,y_n}, \quad
w_{c} = \\text{weight}[c] \cdot \mathbb{1}\{c \\not= \\text{ignore\\_index}\},
where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then
.. math::
\ell(x, y) = \\begin{cases}
\\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n}} l_n, &
\\text{if reduction} = \\text{'mean';}\\\\
\\sum_{n=1}^N l_n, &
\\text{if reduction} = \\text{'sum'.}
\\end{cases}
Parameters:
input (Variable): Input tensor, the data type is float32, float64.
label (Variable): Label tensor, the data type is int64_t.
weight (Variable, optional): Weight tensor, a manual rescaling weight given
to each class. If given, it has to be a Tensor of size `C`. Otherwise,
it treated as if having all ones. the data type is
float32, float64, Default is ``'None'``.
reduction (str, optional): Indicate how to average the loss,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
Default is ``'mean'``.
ignore_index (int64, optional): Specifies a target value that is ignored
and does not contribute to the input gradient.
Returns:
The tensor variable storing the nll_loss.
Return type: Variable.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
import numpy as np
input_np = np.random.random(size=(10, 10)).astype(np.float32)
label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64)
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(name='input', shape=[10, 10], dtype='float32')
label = fluid.data(name='label', shape=[10], dtype='int64')
nll_loss = fluid.dygraph.NLLLoss()
res = nll_loss(input, label)
exe = fluid.Executor(place)
static_result = exe.run(
prog,
feed={"input": input_np,
"label": label_np},
fetch_list=[res])
print(static_result)
# imperative mode
import paddle.fluid.dygraph as dg
with dg.guard(place) as g:
input = dg.to_variable(input_np)
label = dg.to_variable(label_np)
output = nll_loss(input, label)
print(output.numpy())
"""
def __init__(self, weight=None, reduction='mean', ignore_index=-100):
super(NLLLoss, self).__init__()
self.weight = weight
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, input, label):
dtype = self._helper.input_dtype(input)
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'nll_loss')
check_variable_and_dtype(label, 'label', ['int64'], 'nll_loss')
if self.reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in nll_loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % self.reduction)
x_shape = list(input.shape)
n = x_shape[0]
c = x_shape[1]
x_dims = len(x_shape)
if x_dims < 2:
raise ValueError('Expected 2 or more dimensions (got {})'.format(
x_dims))
if x_dims != 2 and x_dims != 4:
input = F.reshape(input, shape=[n, c, 1, -1])
label = F.reshape(label, shape=[n, 1, -1])
out_shape = [n] + x_shape[2:]
inputs = {'X': input, 'Label': label}
attrs = {'reduction': self.reduction, 'ignore_index': self.ignore_index}
if self.weight is not None:
if isinstance(self.weight, Variable):
inputs['Weight'] = self.weight
out = self._helper.create_variable_for_type_inference(dtype=input.dtype)
total_weight = self._helper.create_variable_for_type_inference(
dtype=input.dtype)
outputs = {'Out': out, 'Total_weight': total_weight}
self._helper.append_op(
type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs)
if x_dims != 2 and x_dims != 4 and self.reduction == 'none':
out = F.reshape(out, shape=out_shape)
return out
......@@ -26,11 +26,11 @@ import six
import paddle
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant, NumpyArrayInitializer
from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program
from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program, device_guard, _varbase_creator
from .. import dygraph_utils
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
from .tensor import concat, assign, fill_constant, zeros, tensor_array_to_tensor
from .tensor import concat, assign, fill_constant, zeros, tensor_array_to_tensor, cast
from . import utils
from .. import unique_name
from functools import reduce
......@@ -39,153 +39,38 @@ from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, c
import paddle
__all__ = [
'fc',
'embedding',
'linear_chain_crf',
'crf_decoding',
'cos_sim',
'chunk_eval',
'conv2d',
'conv3d',
'softmax',
'pool2d',
'pool3d',
'adaptive_pool2d',
'adaptive_pool3d',
'batch_norm',
'inplace_abn',
'instance_norm',
'data_norm',
'conv2d_transpose',
'conv3d_transpose',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'reduce_prod',
'reduce_all',
'reduce_any',
'dropout',
'split',
'ctc_greedy_decoder',
'l2_normalize',
'matmul',
'topk',
'transpose',
'im2sequence',
'row_conv',
'multiplex',
'layer_norm',
'group_norm',
'spectral_norm',
'smooth_l1',
'one_hot',
'autoincreased_step_counter',
'reshape',
'squeeze',
'unsqueeze',
'lod_reset',
'lod_append',
'lrn',
'pad',
'pad_constant_like',
'label_smooth',
'roi_pool',
'roi_align',
'dice_loss',
'image_resize',
'image_resize_short',
'resize_bilinear',
'resize_trilinear',
'resize_nearest',
'gather',
'gather_nd',
'scatter',
'scatter_nd_add',
'scatter_nd',
'random_crop',
'mean_iou',
'relu',
'selu',
'log',
'crop',
'crop_tensor',
'elu',
'relu6',
'pow',
'stanh',
'hard_sigmoid',
'swish',
'prelu',
'brelu',
'leaky_relu',
'soft_relu',
'flatten',
'stack',
'pad2d',
'unstack',
'unique',
'unique_with_counts',
'expand',
'expand_as',
'scale',
'elementwise_add',
'elementwise_div',
'elementwise_sub',
'elementwise_mul',
'elementwise_max',
'elementwise_min',
'elementwise_pow',
'elementwise_mod',
'elementwise_floordiv',
'uniform_random_batch_size_like',
'gaussian_random',
'sampling_id',
'gaussian_random_batch_size_like',
'sum',
'slice',
'strided_slice',
'shape',
'rank',
'size',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
'clip',
'clip_by_norm',
'mean',
'mul',
'maxout',
'space_to_depth',
'affine_grid',
'affine_channel',
'similarity_focus',
'hash',
'grid_sampler',
'log_loss',
'add_position_encoding',
'bilinear_tensor_product',
'merge_selected_rows',
'get_tensor_from_selected_rows',
'shuffle_channel',
'temporal_shift',
'py_func',
'psroi_pool',
'prroi_pool',
'pixel_shuffle',
'fsp_matrix',
'continuous_value_model',
'where',
'sign',
'deformable_conv',
'unfold',
'deformable_roi_pooling',
'filter_by_instag',
'shard_index',
'hard_swish',
'gather_tree',
'uniform_random',
'fc', 'embedding', 'linear_chain_crf', 'crf_decoding', 'cos_sim',
'chunk_eval', 'conv2d', 'conv3d', 'softmax', 'pool2d', 'pool3d',
'adaptive_pool2d', 'adaptive_pool3d', 'batch_norm', 'inplace_abn',
'instance_norm', 'data_norm', 'conv2d_transpose', 'conv3d_transpose',
'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', 'reduce_prod',
'reduce_all', 'reduce_any', 'dropout', 'split', 'ctc_greedy_decoder',
'l2_normalize', 'matmul', 'topk', 'transpose', 'im2sequence', 'row_conv',
'multiplex', 'layer_norm', 'group_norm', 'spectral_norm', 'smooth_l1',
'one_hot', 'autoincreased_step_counter', 'reshape', 'squeeze', 'unsqueeze',
'lod_reset', 'lod_append', 'lrn', 'pad', 'pad_constant_like',
'label_smooth', 'roi_pool', 'roi_align', 'dice_loss', 'image_resize',
'image_resize_short', 'resize_bilinear', 'resize_trilinear',
'resize_nearest', 'gather', 'gather_nd', 'scatter', 'scatter_nd_add',
'scatter_nd', 'random_crop', 'mean_iou', 'relu', 'selu', 'log', 'crop',
'crop_tensor', 'elu', 'relu6', 'pow', 'stanh', 'hard_sigmoid', 'swish',
'prelu', 'brelu', 'leaky_relu', 'soft_relu', 'flatten', 'stack', 'pad2d',
'unstack', 'unique', 'unique_with_counts', 'expand', 'expand_as', 'scale',
'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul',
'elementwise_max', 'elementwise_min', 'elementwise_pow', 'elementwise_mod',
'elementwise_floordiv', 'uniform_random_batch_size_like', 'gaussian_random',
'sampling_id', 'gaussian_random_batch_size_like', 'sum', 'slice',
'strided_slice', 'shape', 'rank', 'size', 'logical_and', 'logical_or',
'logical_xor', 'logical_not', 'clip', 'clip_by_norm', 'mean', 'mul',
'maxout', 'space_to_depth', 'affine_grid', 'affine_channel',
'similarity_focus', 'hash', 'grid_sampler', 'log_loss',
'add_position_encoding', 'bilinear_tensor_product', 'merge_selected_rows',
'get_tensor_from_selected_rows', 'shuffle_channel', 'temporal_shift',
'py_func', 'psroi_pool', 'prroi_pool', 'pixel_shuffle', 'fsp_matrix',
'continuous_value_model', 'where', 'sign', 'deformable_conv', 'unfold',
'deformable_roi_pooling', 'filter_by_instag', 'shard_index', 'hard_swish',
'gather_tree', 'uniform_random', 'randint', 'randn', 'randperm', 'allclose',
'elementwise_equal', 'flip', 'roll', 'log_softmax'
]
......@@ -14312,3 +14197,681 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
outputs={"Out": out})
return helper.append_activation(out)
def randint(low,
high=None,
shape=None,
out=None,
dtype=None,
device=None,
stop_gradient=False,
seed=0,
name=None):
"""
This function returns a Tensor filled with random integers from the "discrete uniform" distribution of the
specified data type in the interval [low, high). If high is None (the default), then results are from [0, low).
Args:
low (int): The lower bound on the range of random values to generate, the low is included in the range.
(unless high=None, in which case this parameter is one above the highest such integer).
high (int, optional): The upper bound on the range of random values to generate, the high is excluded
in the range. Default None(see above for behavior if high=None).
shape (list|tuple|Variable, optional): The shape of the output Tensor, if the shape is a list or tuple,
its elements can be an integer
or a Tensor with the shape [1], and the type of the Tensor must be int32 or int64.
If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor must be
int32 or int64. Default is None, in which case the shape is [1].
out(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result.
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output Tensor
which can be int32, int64, if dytpe is `None`, the data
type of created Tensor is `int64`
device(str, optional): This parameter specifies that the Tensor is created
on the GPU or CPU.
stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable,
default value is False.
seed (int, optional): Random seed used for permute samples. If seed is
equal to 0, it means use a seed generated by the system. Note that
if seed is not 0, this operator will always generate the same random
permutation every time. Default: 0.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: A Tensor of the specified shape filled with random integers.
Raises:
TypeError: Randint's low must less then high.
Examples:
.. code-block:: python
import paddle.fluid as fluid
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
result_1 = fluid.layers.randint(low=-5, high=5, shape=[3, 4], dtype="int64")
# example 2:
# attr shape is a list which contains tensor Variable.
dim_1 = fluid.layers.fill_constant([1],"int64",3)
dim_2 = fluid.layers.fill_constant([1],"int32",5)
result_2 = fluid.layers.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32")
# example 3:
# attr shape is a Variable, the data type must be int64 or int32.
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
result_3 = fluid.layers.randint(low=-5, high=5, shape=var_shape, dtype="int32")
var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
result_4 = fluid.layers.randint(low=-5, high=5, shape=var_shape_int32, dtype="int64")
# example 4:
# Input only one parameter
# low=0, high=10, shape=[1], dtype='int64'
result_4 = fluid.layers.randint(10)
"""
def get_new_shape_tensor(list_shape):
new_shape_tensor = []
for dim in list_shape:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_shape_tensor.append(dim)
else:
assert isinstance(dim, int) or isinstance(dim, long)
temp_out = helper.create_variable_for_type_inference('int64')
fill_constant([1], 'int64', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
return new_shape_tensor
def get_attr_shape(list_shape):
unk_dim_idx = -1
attrs_shape = []
for dim_idx, dim_size in enumerate(list_shape):
if isinstance(dim_size, Variable):
attrs_shape.append(-1)
else:
attrs_shape.append(dim_size)
assert dim_size > 0, (
"Each dimension size given in shape must not be negative "
"except one unknown dimension.")
return attrs_shape
if dtype is None:
dtype = 'int64'
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
inputs = dict()
attrs = dict()
if shape is None:
shape = [1]
assert len(shape) > 0, ("The size of argument(shape) can't be zero.")
helper = LayerHelper("randint", **locals())
if in_dygraph_mode():
attrs['shape'] = shape
else:
if isinstance(shape, Variable):
shape.stop_gradient = True
inputs["ShapeTensor"] = shape
elif isinstance(shape, (list, tuple)):
assert len(shape) > 0, (
"The size of argument(shape) can't be zero.")
if utils._contain_var(shape):
inputs['ShapeTensorList'] = get_new_shape_tensor(shape)
else:
attrs["shape"] = get_attr_shape(shape)
check_type(shape, 'shape', (list, tuple, Variable), 'randint')
if high is None:
high = low
low = 0
attrs['low'] = low
attrs['high'] = high
attrs['seed'] = seed
if (low >= high):
raise ValueError(
"randint's low must less then high, but received low = {0}, "
"high = {1}".format(low, high))
if out is None:
if name is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
out = helper.create_variable(
name=name, dtype=dtype, persistable=False)
else:
check_dtype(dtype, 'dtype',
convert_dtype(out.dtype), 'randint',
"(The dtype in randint must be the same with out's dtype.)")
attrs['dtype'] = out.dtype
out.stop_gradient = stop_gradient
if device is None:
helper.append_op(
type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs)
else:
with device_guard(device):
helper.append_op(
type='randint',
inputs=inputs,
outputs={'Out': out},
attrs=attrs)
return out
def randn(shape,
out=None,
dtype=None,
device=None,
stop_gradient=True,
name=None):
"""
This function returns a tensor filled with random numbers from a normal
distribution with mean 0 and variance 1 (also called the standard normal
distribution).
Args:
shape(list|tuple): Shape of the generated random tensor.
out(Variable, optional): Optional output which can be any created Variable
that meets the requirements to store the result of operation. If the
out is `None`, a new Variable wiil be returned to store the result.
Default is None.
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output
tensor, which can be float32, float64. if dtype is `None` , the data
type of output tensor is `float32` .
Default is None.
device(str, optional): Specific the output variable to be saved in cpu
or gpu memory. Supported None, 'cpu', 'gpu'. If it is None, the output
variable will be automatically assigned devices.
Default: None.
stop_gradient(bool, optional): Indicating if we stop gradient from current(out)
Variable. Default is True.
name(str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Default is None.
Returns:
Random tensor whose data is drawn from a Gaussian distribution,
dtype: flaot32 or float64 as specified.
Return type:
Variable
Raises:
TypeError: If the type of `shape` is not list or tuple.
TypeError: If the data type of `dtype` is not float32 or float64.
ValueError: If the length of `shape` is not bigger than 0.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
data = fluid.layers.randn([2, 4])
place = fluid.CPUPlace()
exe = fluid.Executor(place)
res, = exe.run(fluid.default_main_program(), feed={}, fetch_list=[data])
print(res)
# [[-1.4187592 0.7368311 -0.53748125 -0.0146909 ]
# [-0.66294265 -1.3090698 0.1898754 -0.14065823]]
.. code-block:: python
# imperative mode
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
place = fluid.CPUPlace()
with dg.guard(place) as g:
x = fluid.layers.randn([2, 4])
x_np = x.numpy()
print(x_np)
# [[ 1.5149173 -0.26234224 -0.592486 1.4523455 ]
# [ 0.04581212 -0.85345626 1.1687907 -0.02512913]]
"""
helper = LayerHelper("randn", **locals())
check_type(shape, 'shape', (list, tuple), 'randn')
assert len(shape) > 0, ("The size of argument(shape) can't be zero.")
if dtype is None:
dtype = 'float32'
check_dtype(dtype, 'create data type', ['float32', 'float64'], 'randn')
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
check_variable_and_dtype(out, 'out', [dtype], 'randn')
out.stop_gradient = stop_gradient
dtype = convert_np_dtype_to_dtype_(dtype)
seed = np.random.randint(0, 100)
with device_guard(device):
helper.append_op(
type='gaussian_random',
outputs={'Out': out},
attrs={
'shape': shape,
'mean': 0.0,
'std': 1.0,
'seed': seed,
'dtype': dtype,
'use_mkldnn': False
})
return out
@templatedoc()
def randperm(n,
out=None,
dtype="int64",
device=None,
stop_gradient=True,
seed=0):
"""
${comment}
Args:
n (int): The upper bound (exclusive), and it should be greater than 0.
out (Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
If out is None, a new Varibale will be create to store the result.
Default: None.
dtype (np.dtype|core.VarDesc.VarType|str, optional): The type of the
output Tensor. Supported data types: int64, int32. Default: int32.
device (str, optional): Specific the output variable to be saved in cpu
or gpu memory. Supported None, 'cpu', 'gpu'. If it is None, the output
variable will be automatically assigned devices.
Default: None.
stop_gradient (bool, optional): Whether grad should record operations
on the returned tensor. Default: True.
seed (int, optional): Random seed used for permute samples. If seed is
equal to 0, it means use a seed generated by the system. Note that
if seed is not 0, this operator will always generate the same random
permutation every time. Default: 0.
Returns:
${out_comment}.
Return Type:
${out_type}
Examples:
.. code-block:: python
import paddle.fluid as fluid
num = 6
is_use_gpu = False
data_1 = fluid.layers.randperm(num)
fluid.layers.Print(data_1)
data_2 = fluid.layers.randperm(num, dtype="int32", seed=1)
fluid.layers.Print(data_2)
data_3 = fluid.layers.randperm(num, stop_gradient=False, device="cpu")
fluid.layers.Print(data_3)
fluid.layers.randperm(num, out=data_3)
fluid.layers.Print(data_3)
place = fluid.CUDAPlace(0) if is_use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exe.run()
"""
if n < 1:
raise ValueError("The input n should be greater than 0 in randperm op.")
check_dtype(dtype, 'dtype', ['int64', 'int32'], 'randperm')
dtype = convert_dtype(dtype)
if device not in [None, 'cpu', 'gpu']:
raise ValueError("The input device should in [None, 'cpu', 'gpu'].")
check_type(stop_gradient, 'stop_gradient', bool, 'randperm')
helper = LayerHelper("randperm", **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
check_variable_and_dtype(out, 'out', [dtype], 'randperm')
if stop_gradient:
out.stop_gradient = True
inputs = dict()
outputs = {'Out': [out]}
attrs = {'n': n, 'dtype': out.dtype, 'seed': seed}
with device_guard(device):
helper.append_op(
type='randperm', inputs=inputs, outputs=outputs, attrs=attrs)
return out
@templatedoc()
def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
"""
${comment}
Args:
input(inputtype):{input_comment}.
other(othertype):{other_comment}.
rtol(rtoltype,optional):{rtol_comment}.
atol(atoltype,optional):{atol_comment}.
equal_nan(equalnantype,optional):{equal_nan_comment}.
name(STR, optional): The default value is None.
Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
Returns:
${out_comment}.
Return Type:
${out_type}
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
use_cuda = fluid.core.is_compiled_with_cuda()
a = fluid.data(name="a", shape=[2], dtype='float32')
b = fluid.data(name="b", shape=[2], dtype='float32')
result = fluid.layers.allclose(a, b, rtol=1e-05, atol=1e-08,
equal_nan=False, name="ignore_nan")
result_nan = fluid.layers.allclose(a, b, rtol=1e-05, atol=1e-08,
equal_nan=True, name="equal_nan")
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
x = np.array([10000., 1e-07]).astype("float32")
y = np.array([10000.1, 1e-08]).astype("float32")
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y},
fetch_list=[result, result_nan])
print(result_v, result_nan_v)
# Output: (array([False]), array([False]))
x = np.array([10000., 1e-08]).astype("float32")
y = np.array([10000.1, 1e-09]).astype("float32")
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y},
fetch_list=[result, result_nan])
print(result_v, result_nan_v)
# Output: (array([ True]), array([ True]))
x = np.array([1.0, float('nan')]).astype("float32")
y = np.array([1.0, float('nan')]).astype("float32")
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y},
fetch_list=[result, result_nan])
print(result_v, result_nan_v)
# Output: (array([False]), array([ True]))
"""
check_type(rtol, 'rtol', float, 'allclose')
check_type(atol, 'atol', float, 'allclose')
check_type(equal_nan, 'equal_nan', bool, 'allclose')
helper = LayerHelper("allclose", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
inputs = {'Input': input, 'Other': other}
outputs = {'Out': out}
attrs = {'rtol': rtol, 'atol': atol, 'equal_nan': equal_nan}
helper.append_op(
type='allclose', inputs=inputs, outputs=outputs, attrs=attrs)
return out
def elementwise_equal(x, y, name=None):
"""
This layer returns the truth value of :math:`x == y` elementwise.
Args:
x(Variable): Tensor, data type is float32, float64, int32, int64.
y(Variable): Tensor, data type is float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: output Tensor, it's shape is the same as the input's Tensor,
and the data type is bool. The result of this op is stop_gradient.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
out1 = fluid.layers.elementwise_equal(x=label, y=limit) #out1=[True, False]
"""
helper = LayerHelper("elementwise_equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True
helper.append_op(
type='equal',
inputs={'X': [x],
'Y': [y]},
outputs={'Out': [out]},
attrs={'force_cpu': False})
return out
def flip(input, dims, name=None):
"""
Reverse the order of a n-D tensor along given axis in dims.
Args:
input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
should be float32, float64, int32, int64, bool.
dims (list): The axis to flip on.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="x", shape=[-1, 2, 2], dtype='float32')
output = fluid.layers.flip(input, dims=[0, 1])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
img = np.arange(12).reshape((3,2,2)).astype(np.float32)
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]]
"""
helper = LayerHelper("flip", **locals())
check_type(input, 'X', (Variable), 'flip')
dtype = helper.input_dtype()
check_dtype(dtype, 'X',
['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
'flip')
check_type(dims, 'dims', (list, tuple), 'flip')
assert len(dims) > 0, 'len(dims) must be greater than 0.'
if name is None:
out = helper.create_variable_for_type_inference(dtype)
else:
out = helper.create_variable(name=name, dtype=dtype, persistable=False)
helper.append_op(
type="flip",
inputs={"X": input},
outputs={"Out": out},
attrs={"dims": dims})
return out
def roll(input, shifts, dims=None):
"""
Roll the `input` tensor along the given dimension(s). Elements that are shifted beyond
the last position are re-introduced at the first position. If a dimension is not specified,
the tensor will be flattened before rolling and then restored to the original shape.
Args:
input (Variable): The input tensor variable.
shifts (int|list|tuple): The number of places by which the elements
of the `input` tensor are shifted.
dims (int|list|tuple|None): Dimentions along which to roll.
Returns:
Variable: A Tensor with same data type as `input`.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
data = np.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(data)
out_z1 = fluid.layers.roll(x, shifts=1)
print(out_z1.numpy())
#[[9. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
out_z2 = fluid.layers.roll(x, shifts=1, dims=0)
print(out_z2.numpy())
#[[7. 8. 9.]
# [1. 2. 3.]
# [4. 5. 6.]]
"""
helper = LayerHelper("roll", **locals())
origin_shape = input.shape
if type(shifts) == int:
shifts = [shifts]
if type(dims) == int:
dims = [dims]
if dims:
check_type(dims, 'dims', (list, tuple), 'roll')
check_type(shifts, 'shifts', (list, tuple), 'roll')
if in_dygraph_mode():
if dims is None:
input = core.ops.reshape(input, 'shape', [-1, 1])
dims = [0]
out = core.ops.roll(input, 'dims', dims, 'shifts', shifts)
return core.ops.reshape(out, 'shape', origin_shape)
out = helper.create_variable_for_type_inference(input.dtype)
if dims is None:
input = reshape(input, shape=[-1, 1])
dims = [0]
helper.append_op(
type='roll',
inputs={'X': input},
outputs={'Out': out},
attrs={'dims': dims,
'shifts': shifts})
out = reshape(out, shape=origin_shape, inplace=True)
return out
def log_softmax(input, axis=None, dtype=None, name=None):
"""
This operator implements the log_softmax layer. The calculation process is as follows:
.. math::
Out[i, j] = log(softmax(x))
= log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])})
Parameters:
input (Variable): The input variable. A multi-dimension Tensor with type float32, or float64.
axis (int, optional): The index of dimension to perform softmax calculations, it should be in
range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None.
None and -1 means the last dimension.
dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified,
the input tensor is casted to dtype before the operation is performed. This is useful for
preventing data type overflows. Default: None. Supported dtype: float32 or float64
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: ``Tensor`` indicates the output of softmax. The data type and shape are the same as ``input``.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
data = np.array([[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]]).astype('float32')
with fluid.dygraph.guard():
data = fluid.dygraph.to_variable(data)
res = fluid.layers.log_softmax(data, -1)
# [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948]
# [ -2.1270514 -9.127051 -0.12705144 -11.127051 ]
# [-16.313261 -17.313261 -1.3132617 -0.31326184]]
# [[ -3.0518122 -6.051812 -7.051812 -0.051812 ]
# [-12.313267 -1.3132664 -0.3132665 -15.313267 ]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]]
"""
axis = -1 if axis is None else axis
dtype = convert_np_dtype_to_dtype_(dtype) if dtype is not None else dtype
if in_dygraph_mode():
outs_cast = input if dtype is None \
else core.ops.cast(input, 'in_dtype', input.dtype, 'out_dtype', dtype)
outs_softmax = core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn',
False)
return core.ops.log(outs_softmax)
if dtype is None:
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64'], 'log_softmax')
helper = LayerHelper("log_softmax", **locals())
outs_cast = input
if dtype is not None:
outs_cast = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='cast',
inputs={'X': input},
outputs={'Out': outs_cast},
attrs={'in_dtype': input.dtype,
'out_dtype': dtype})
outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype)
helper.append_op(
type='softmax',
inputs={'X': outs_cast},
outputs={'Out': outs_softmax},
attrs={'axis': axis,
'use_cudnn': False})
outs_log = helper.create_variable_for_type_inference(outs_softmax.dtype)
helper.append_op(
type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log})
return outs_log
......@@ -18,8 +18,8 @@ from six.moves import reduce
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..initializer import Initializer
from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator
from ..framework import Variable
from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard, OpProtoHolder
from ..framework import Variable, in_dygraph_mode
from ..initializer import Constant
from ..core import VarDesc
from .. import core
......@@ -30,32 +30,12 @@ import numpy
import warnings
__all__ = [
'create_tensor',
'create_parameter',
'create_global_var',
'cast',
'tensor_array_to_tensor',
'concat',
'sums',
'assign',
'fill_constant_batch_size_like',
'fill_constant',
'argmin',
'argmax',
'argsort',
'ones',
'zeros',
'reverse',
'has_inf',
'has_nan',
'isfinite',
'range',
'linspace',
'zeros_like',
'ones_like',
'diag',
'eye',
'kron',
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
'tensor_array_to_tensor', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax',
'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite',
'range', 'linspace', 'zeros_like', 'ones_like', 'diag', 'eye', 'kron',
'full_like', 'arange', 'full', 'tril', 'triu'
]
......@@ -1587,6 +1567,412 @@ def ones_like(x, out=None):
return out
def full_like(input,
fill_value,
out=None,
dtype=None,
device=None,
stop_gradient=True,
name=None):
"""
**full_like**
This function creates a tensor filled with `fill_value` which has identical shape and dtype
with `input`.
Args:
input(Variable): The input tensor which specifies shape and dtype.
fill_value: The value to fill the tensor with. Data type can be bool, float32, float64, int32, int64. Default value is 0.
out(Variable): The output tensor.
Returns:
out(Variable): The tensor variable storing the output.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = fluid.layers.full_like(input, 2.0)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
img=np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
res = exe.run(fluid.default_main_program(), feed={'input':img}, fetch_list=[output])
print(res) # [array([[2., 2., 2.], [2., 2., 2.]], dtype=float32)]
"""
helper = LayerHelper("full_like", **locals())
if dtype is None:
dtype = 'float32'
check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'int32', 'int64'], 'full_like')
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='fill_any_like',
inputs={'X': [input]},
attrs={'value': fill_value},
outputs={'Out': [out]})
out.stop_gradient = stop_gradient
return out
def arange(start, end, step=1, dtype=None, name=None):
"""
Return evenly spaced values within a given interval.
Values are generated within the half-open interval [start, stop) (in other words,
the interval including start but excluding stop).
Parameters:
start(float32 | float64 | int32 | int64 | Variable): Start of interval. The interval includes this value.
when start is Variable, it is a 1-D Tensor with shape [1].
end(float32 | float64 | int32 | int64 | Variable): End of interval. The interval does not include this
value, except in some cases where step is not an integer
and floating point round-off affects the length of out. When end is Variable,
it is a 1-D Tensor with shape [1].
step(float32 | float64 | int32 | int64 | Variable): Spacing between values. For any output out, this is the
distance between two adjacent values, out[i+1] - out[i].
dtype(str|core.VarDesc.VarType): the data type of the output tensor, can be float32, float64, int32, int64.
Returns: a 1-D Tensor which is evenly spaced values within a given interval. Its data type is set by dtype.
Return type: Variable
examples:
.. code-block:: python
import paddle.fluid as fluid
# expected out put: [0, 2, 4, 6, 8]
data = fluid.layers.arange(0, 10, 2, 'int32')
#dygraph mode
import paddle.fluid as fluid
with fluid.dygraph.guard():
x = fluid.layers.arange(0, 6, 2)
# x: [0, 2, 4]
# x dtype: float32
"""
helper = LayerHelper("range", **locals())
if dtype is None:
dtype = 'float32'
check_dtype(dtype, 'create data type',
['float32', 'float64', 'int32', 'int64'], 'range')
dtype = convert_dtype(dtype)
if not isinstance(start, Variable):
start = fill_constant([1], dtype, start)
if not isinstance(end, Variable):
end = fill_constant([1], dtype, end)
if not isinstance(step, Variable):
step = fill_constant([1], dtype, step)
out = helper.create_variable_for_type_inference(dtype=start.dtype)
helper.append_op(
type='range',
inputs={'Start': start,
'End': end,
'Step': step},
outputs={'Out': [out]})
out.stop_gradient = True
return out
def full(shape,
fill_value,
out=None,
dtype=None,
device=None,
stop_gradient=True,
name=None):
"""
This Op return a Tensor with the `fill_value` which size is same as `shape`
Args:
shape(list|tuple|Variable): Shape of the Tensor to be created.
The data type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Variable, it should be an 1-D Tensor .
fill_value(bool|float16|float32|float64|int32|int64|Variable): The constant value
used to initialize the Tensor to be created. If fill_value is an Variable, it must be an 1-D Tensor.
out(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result.
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output tensor
which can be float16, float32, float64, int32, int64, if dytpe is `None`, the data
type of created tensor is `float32`
device(str, optional): On which device to run this Op. The :attr:`device` must be
None, 'cpu' or 'gpu'. If :attr:`device` is None, the device that the user set in
the paddle program will be chosen. Default value is None.
stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable,
default value is True.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: Tensor which is created according to shape and dtype.
Raises:
TypeError: The `dtype` must be one of None, bool, float16, float32, float64, int32 and int64.
TypeError: The `out` must be a Variable.
TypeError: The `shape` must be one of Variable, list tuple.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data1 = fluid.layers.full(shape=[2,1], fill_value=0, dtype='int64') # data1=[[0],[0]]
data2 = fluid.layers.full(shape=[2,1], fill_value=5, dtype='int64', device='gpu') # data2=[[5],[5]]
# attr shape is a list which contains Variable Tensor.
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
data3 = fluid.layers.full(shape=[1, positive_2], dtype='float32', fill_value=1.5) # data3=[1.5, 1.5]
# attr shape is an Variable Tensor.
shape = fluid.layers.fill_constant([1,2], "int32", 2) # shape=[2,2]
data4 = fluid.layers.full(shape=shape, dtype='bool', fill_value=True) # data4=[[True,True],[True,True]]
# attr value is an Variable Tensor.
val = fluid.layers.fill_constant([1], "float32", 2.0) # val=[2.0]
data5 = fluid.layers.full(shape=[2,1], fill_value=val, dtype='float32') #data5=[[2.0],[2.0]]
"""
helper = LayerHelper("full", **locals())
if dtype is None:
dtype = 'float32'
check_dtype(dtype, 'create data type',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'full')
check_type(shape, 'shape', (Variable, list, tuple), 'full')
if out is not None:
check_type(shape, 'out', (Variable), 'full')
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
out.stop_gradient = stop_gradient
with device_guard(device):
out = fill_constant(shape=shape, dtype=dtype, value=fill_value, out=out)
return out
def _tril_triu_op(helper):
"""Base op of tril_op and triu_op
"""
op_type = helper.layer_type
x = helper.kwargs.get('input', None)
assert x is not None, 'x cannot be None in {}'.format(op_type)
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
op_type)
if len(x.shape) < 2:
raise ValueError("input shape in {} must be at least 2-D".format(
op_type))
diagonal = helper.kwargs.get('diagonal', 0)
if not isinstance(diagonal, (int, )):
raise TypeError("diagonal in {} must be a python Int".format(op_type))
name = helper.kwargs.get('name', None)
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="tril_triu",
inputs={"X": x},
attrs={
"diagonal": diagonal,
"lower": True if op_type == 'tril' else False,
},
outputs={"Out": out}, )
return out
def tril(input, diagonal=0, name=None):
"""
This op returns the lower triangular part of a matrix (2-D tensor) or batch
of matrices :attr:`input`, the other elements of the result tensor are set
to 0. The lower triangular part of the matrix is defined as the elements
on and below the diagonal.
Args:
input (Variable): The input variable which is a Tensor.
Support data types: ``float64``, ``float32``, ``int32``, ``int64``.
diagonal (int, optional): The diagonal to consider, default value is 0.
If :attr:`diagonal` = 0, all elements on and below the main diagonal are
retained. A positive value includes just as many diagonals above the main
diagonal, and similarly a negative value excludes just as many diagonals below
the main diagonal. The main diagonal are the set of indices
:math:`\{(i, i)\}` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where
:math:`d_{1}, d_{2}` are the dimensions of the matrix.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: Tensor, results of lower triangular operation by the specified diagonal of input tensor,
it's data type is the same as input's Tensor.
Raises:
TypeError: diagonal is not a int type.
ValueError: dimension of :attr:`input` is less than 2.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
data = np.arange(1, 13, dtype="int64").reshape(3,-1)
# array([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
exe = fluid.Executor(fluid.CPUPlace())
# example 1, default diagonal
tril = fluid.layers.tril(x)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[tril], return_numpy=True)
# array([[ 1, 0, 0, 0],
# [ 5, 6, 0, 0],
# [ 9, 10, 11, 0]])
.. code-block:: python
# example 2, positive diagonal value
import paddle.fluid as fluid
import numpy as np
data = np.arange(1, 13, dtype="int64").reshape(3,-1)
x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
exe = fluid.Executor(fluid.CPUPlace())
tril = fluid.layers.tril(x, diagonal=2)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[tril], return_numpy=True)
# array([[ 1, 2, 3, 0],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
.. code-block:: python
# example 3, negative diagonal value
import paddle.fluid as fluid
import numpy as np
data = np.arange(1, 13, dtype="int64").reshape(3,-1)
x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
exe = fluid.Executor(fluid.CPUPlace())
tril = fluid.layers.tril(x, diagonal=-1)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[tril], return_numpy=True)
# array([[ 0, 0, 0, 0],
# [ 5, 0, 0, 0],
# [ 9, 10, 0, 0]])
"""
return _tril_triu_op(LayerHelper('tril', **locals()))
def triu(input, diagonal=0, name=None):
"""
This op returns the upper triangular part of a matrix (2-D tensor) or batch of matrices
:attr:`input`, the other elements of the result tensor are set to 0.
The upper triangular part of the matrix is defined as the elements on and
above the diagonal.
Args:
input (Variable): The input variable which is a Tensor.
Support data types: ``float64``, ``float32``, ``int32``, ``int64``.
diagonal (int, optional): The diagonal to consider, default value is 0.
If :attr:`diagonal` = 0, all elements on and above the main diagonal are
retained. A positive value excludes just as many diagonals above the main
diagonal, and similarly a negative value includes just as many diagonals below
the main diagonal. The main diagonal are the set of indices
:math:`\{(i, i)\}` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where
:math:`d_{1}, d_{2}` are the dimensions of the matrix.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: Tensor, results of upper triangular operation by the specified diagonal of input tensor,
it's data type is the same as input's Tensor.
Raises:
TypeError: diagonal is not a int type.
ValueError: dimension of :attr:`input` is less than 2.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
data = np.arange(1, 13, dtype="int64").reshape(3,-1)
# array([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
exe = fluid.Executor(fluid.CPUPlace())
# example 1, default diagonal
import paddle.fluid as fluid
triu = fluid.layers.triu(x)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[triu], return_numpy=True)
# array([[ 1, 2, 3, 4],
# [ 0, 6, 7, 8],
# [ 0, 0, 11, 12]])
.. code-block:: python
# example 2, positive diagonal value
import paddle.fluid as fluid
import numpy as np
data = np.arange(1, 13, dtype="int64").reshape(3,-1)
x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
exe = fluid.Executor(fluid.CPUPlace())
triu = fluid.layers.triu(x, diagonal=2)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[triu], return_numpy=True)
# array([[0, 0, 3, 4],
# [0, 0, 0, 8],
# [0, 0, 0, 0]])
.. code-block:: python
# example 3, negative diagonal value
import paddle.fluid as fluid
import numpy as np
data = np.arange(1, 13, dtype="int64").reshape(3,-1)
x = fluid.data(shape=(-1, 4), dtype='int64', name='x')
exe = fluid.Executor(fluid.CPUPlace())
triu = fluid.layers.triu(x, diagonal=-1)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
fetch_list=[triu], return_numpy=True)
# array([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 0, 10, 11, 12]])
"""
return _tril_triu_op(LayerHelper('triu', **locals()))
@templatedoc(op_type="kron")
def kron(x, y, out=None, name=None):
"""${comment}
......
......@@ -23,9 +23,9 @@ class TestAllcloseLayer(unittest.TestCase):
a = fluid.data(name="a", shape=[2], dtype='float32')
b = fluid.data(name="b", shape=[2], dtype='float32')
result = paddle.allclose(
result = fluid.layers.allclose(
a, b, rtol=1e-05, atol=1e-08, equal_nan=False, name="ignore_nan")
result_nan = paddle.allclose(
result_nan = fluid.layers.allclose(
a, b, rtol=1e-05, atol=1e-08, equal_nan=True, name="equal_nan")
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
......@@ -82,7 +82,7 @@ class TestAllcloseLayer(unittest.TestCase):
with fluid.dygraph.guard():
x_v_1 = fluid.dygraph.to_variable(x_1)
y_v_1 = fluid.dygraph.to_variable(y_1)
ret_1 = paddle.allclose(
ret_1 = fluid.layers.allclose(
x_v_1,
y_v_1,
rtol=1e-05,
......@@ -90,7 +90,7 @@ class TestAllcloseLayer(unittest.TestCase):
equal_nan=False,
name='test_1')
self.assertEqual(ret_1.numpy()[0], False)
ret_1 = paddle.allclose(
ret_1 = fluid.layers.allclose(
x_v_1,
y_v_1,
rtol=1e-05,
......@@ -100,7 +100,7 @@ class TestAllcloseLayer(unittest.TestCase):
self.assertEqual(ret_1.numpy()[0], False)
x_v_2 = fluid.dygraph.to_variable(x_2)
y_v_2 = fluid.dygraph.to_variable(y_2)
ret_2 = paddle.allclose(
ret_2 = fluid.layers.allclose(
x_v_2,
y_v_2,
rtol=1e-05,
......@@ -108,7 +108,7 @@ class TestAllcloseLayer(unittest.TestCase):
equal_nan=False,
name='test_3')
self.assertEqual(ret_2.numpy()[0], True)
ret_2 = paddle.allclose(
ret_2 = fluid.layers.allclose(
x_v_2,
y_v_2,
rtol=1e-05,
......@@ -118,7 +118,7 @@ class TestAllcloseLayer(unittest.TestCase):
self.assertEqual(ret_2.numpy()[0], True)
x_v_3 = fluid.dygraph.to_variable(x_3)
y_v_3 = fluid.dygraph.to_variable(y_3)
ret_3 = paddle.allclose(
ret_3 = fluid.layers.allclose(
x_v_3,
y_v_3,
rtol=1e-05,
......@@ -126,7 +126,7 @@ class TestAllcloseLayer(unittest.TestCase):
equal_nan=False,
name='test_5')
self.assertEqual(ret_3.numpy()[0], False)
ret_3 = paddle.allclose(
ret_3 = fluid.layers.allclose(
x_v_3,
y_v_3,
rtol=1e-05,
......
......@@ -14,7 +14,6 @@
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
......@@ -71,7 +70,7 @@ class TestInt32ArangeOpCase2(TestArangeOp):
class TestArangeAPI(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program()):
data = paddle.arange(0, 5, 1)
data = fluid.layers.arange(0, 5, 1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result, = exe.run(fetch_list=[data])
......@@ -79,7 +78,7 @@ class TestArangeAPI(unittest.TestCase):
self.assertEqual((result == expected_data).all(), True)
with fluid.program_guard(fluid.Program()):
data = paddle.arange(0.0, 5.0, 1.0, 'int32')
data = fluid.layers.arange(0.0, 5.0, 1.0, 'int32')
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result, = exe.run(fetch_list=[data])
......
......@@ -36,7 +36,7 @@ class TestBCELoss(unittest.TestCase):
name='input', shape=[None, 30], dtype='float64')
label = fluid.data(
name='label', shape=[None, 30], dtype='float64')
bce_loss = paddle.nn.loss.BCELoss(reduction=red)
bce_loss = fluid.dygraph.BCELoss(reduction=red)
res = bce_loss(input, label)
exe = fluid.Executor(place)
......@@ -47,7 +47,7 @@ class TestBCELoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
bce_loss = paddle.nn.loss.BCELoss(reduction=red)
bce_loss = fluid.dygraph.BCELoss(reduction=red)
dy_res = bce_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -80,7 +80,7 @@ class TestBCELoss(unittest.TestCase):
name='label', shape=[None, 3, 4, 10], dtype='float64')
weight = fluid.data(
name='weight', shape=[3, 4, 10], dtype='float64')
bce_loss = paddle.nn.loss.BCELoss(weight=weight)
bce_loss = fluid.dygraph.BCELoss(weight=weight)
res = bce_loss(input, label)
exe = fluid.Executor(place)
......@@ -93,7 +93,7 @@ class TestBCELoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
bce_loss = paddle.nn.loss.BCELoss(
bce_loss = fluid.dygraph.BCELoss(
weight=fluid.dygraph.to_variable(weight_np))
dy_res = bce_loss(
fluid.dygraph.to_variable(input_np),
......
......@@ -82,7 +82,7 @@ class API_TestElementwise_Equal(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()):
label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
out = paddle.elementwise_equal(x=label, y=limit)
out = fluid.layers.elementwise_equal(x=label, y=limit)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
res, = exe.run(fetch_list=[out])
......@@ -91,7 +91,7 @@ class API_TestElementwise_Equal(unittest.TestCase):
with fluid.program_guard(fluid.Program(), fluid.Program()):
label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
limit = fluid.layers.assign(np.array([3, 3], dtype="int32"))
out = paddle.elementwise_equal(x=label, y=limit)
out = fluid.layers.elementwise_equal(x=label, y=limit)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
res, = exe.run(fetch_list=[out])
......
......@@ -35,7 +35,7 @@ class CrossEntropyLoss(unittest.TestCase):
label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64')
weight = fluid.layers.data(
name='weight', shape=[100], dtype='float32')
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight)
cross_entropy_loss = fluid.dygraph.CrossEntropyLoss(weight=weight)
ret = cross_entropy_loss(input, label)
exe = fluid.Executor(place)
......@@ -48,7 +48,7 @@ class CrossEntropyLoss(unittest.TestCase):
fetch_list=[ret])
self.assertIsNotNone(static_ret)
with fluid.dygraph.guard():
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
cross_entropy_loss = fluid.dygraph.CrossEntropyLoss(
weight=fluid.dygraph.to_variable(weight_np))
dy_ret = cross_entropy_loss(
fluid.dygraph.to_variable(input_np),
......@@ -71,7 +71,7 @@ class CrossEntropyLoss(unittest.TestCase):
label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64')
weight = fluid.layers.data(
name='weight', shape=[100], dtype='float32')
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
cross_entropy_loss = fluid.dygraph.CrossEntropyLoss(
weight=weight, reduction='sum')
ret = cross_entropy_loss(input, label)
......@@ -85,7 +85,7 @@ class CrossEntropyLoss(unittest.TestCase):
fetch_list=[ret])
self.assertIsNotNone(static_ret)
with fluid.dygraph.guard():
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
cross_entropy_loss = fluid.dygraph.CrossEntropyLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='sum')
dy_ret = cross_entropy_loss(
fluid.dygraph.to_variable(input_np),
......@@ -108,7 +108,7 @@ class CrossEntropyLoss(unittest.TestCase):
label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64')
weight = fluid.layers.data(
name='weight', shape=[100], dtype='float32')
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
cross_entropy_loss = fluid.dygraph.CrossEntropyLoss(
weight=weight, reduction='none')
ret = cross_entropy_loss(input, label)
......@@ -122,7 +122,7 @@ class CrossEntropyLoss(unittest.TestCase):
fetch_list=[ret])
self.assertIsNotNone(static_ret)
with fluid.dygraph.guard():
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
cross_entropy_loss = fluid.dygraph.CrossEntropyLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='none')
dy_ret = cross_entropy_loss(
fluid.dygraph.to_variable(input_np),
......
......@@ -106,7 +106,7 @@ class TestFillAnyLikeOp_attr_out(unittest.TestCase):
with fluid.program_guard(train_program, startup_program):
fill_value = 2.0
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.full_like(input, fill_value)
output = fluid.layers.full_like(input, fill_value)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
......@@ -132,20 +132,20 @@ class TestFillAnyLikeOpError(unittest.TestCase):
#for ci coverage
input_data = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.full_like(input_data, 2.0)
output = fluid.layers.full_like(input_data, 2.0)
def test_input_dtype():
paddle.full_like
fluid.layers.full_like
self.assertRaises(
ValueError,
paddle.full_like,
fluid.layers.full_like,
input=input_data,
fill_value=2,
dtype='uint4')
self.assertRaises(
TypeError,
paddle.full_like,
fluid.layers.full_like,
input=input_data,
fill_value=2,
dtype='int16')
......
......@@ -16,7 +16,6 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
......@@ -32,7 +31,7 @@ class TestFlipOp_API(unittest.TestCase):
with fluid.program_guard(train_program, startup_program):
dims = [0]
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.flip(input, dims)
output = fluid.layers.flip(input, dims)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
......@@ -52,7 +51,7 @@ class TestFlipOp_API(unittest.TestCase):
img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
with fluid.dygraph.guard():
inputs = fluid.dygraph.to_variable(img)
ret = paddle.flip(inputs, [0])
ret = fluid.layers.flip(inputs, [0])
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
self.assertTrue(
(ret.numpy() == out_ref).all(),
......
......@@ -21,7 +21,6 @@ from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle
from paddle.fluid import compiler, Program, program_guard
......@@ -37,39 +36,39 @@ class TestFullAPI(unittest.TestCase):
shape_tensor_int64 = fluid.data(
name="shape_tensor_int64", shape=[2], dtype="int64")
out_1 = paddle.full(
out_1 = fluid.layers.full(
shape=[1, 2], dtype="float32", fill_value=1.1, device='gpu')
out_2 = paddle.full(
out_2 = fluid.layers.full(
shape=[1, positive_2_int32],
dtype="float32",
fill_value=1.1,
device='cpu')
out_3 = paddle.full(
out_3 = fluid.layers.full(
shape=[1, positive_2_int64],
dtype="float32",
fill_value=1.1,
device='gpu')
out_4 = paddle.full(
out_4 = fluid.layers.full(
shape=shape_tensor_int32,
dtype="float32",
fill_value=1.2,
out=out_3)
out_5 = paddle.full(
out_5 = fluid.layers.full(
shape=shape_tensor_int64,
dtype="float32",
fill_value=1.1,
device='gpu',
stop_gradient=False)
out_6 = paddle.full(
out_6 = fluid.layers.full(
shape=shape_tensor_int64, dtype=np.float32, fill_value=1.1)
val = fluid.layers.fill_constant(shape=[1], dtype=np.float32, value=1.1)
out_7 = paddle.full(
out_7 = fluid.layers.full(
shape=shape_tensor_int64, dtype=np.float32, fill_value=val)
exe = fluid.Executor(place=fluid.CPUPlace())
......@@ -97,17 +96,21 @@ class TestFullOpError(unittest.TestCase):
x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16")
x2 = np.random.randn(1, 2).astype('int32')
self.assertRaises(
ValueError, paddle.full, shape=[1], fill_value=5, dtype='uint4')
ValueError,
fluid.layers.full,
shape=[1],
fill_value=5,
dtype='uint4')
self.assertRaises(
TypeError,
paddle.full,
fluid.layers.full,
shape=[1],
fill_value=5,
dtype='int32',
out=x2)
self.assertRaises(
TypeError,
paddle.full,
fluid.layers.full,
shape=[1],
fill_value=5,
dtype='int16',
......@@ -118,17 +121,21 @@ class TestFullOpError(unittest.TestCase):
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
self.assertRaises(
TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint8')
TypeError,
fluid.layers.full,
shape=[1],
fill_value=5,
dtype='uint8')
# The argument shape's type of full_op must be list, tuple or Variable.
def test_shape_type():
paddle.full(shape=1, dtype="float32", fill_value=1)
fluid.layers.full(shape=1, dtype="float32", fill_value=1)
self.assertRaises(TypeError, test_shape_type)
# The argument shape's size of full_op must not be 0.
def test_shape_size():
paddle.full(shape=[], dtype="float32", fill_value=1)
fluid.layers.full(shape=[], dtype="float32", fill_value=1)
self.assertRaises(AssertionError, test_shape_size)
......@@ -136,14 +143,15 @@ class TestFullOpError(unittest.TestCase):
def test_shape_tensor_dtype():
shape = fluid.data(
name="shape_tensor", shape=[2], dtype="float32")
paddle.full(shape=shape, dtype="float32", fill_value=1)
fluid.layers.full(shape=shape, dtype="float32", fill_value=1)
self.assertRaises(TypeError, test_shape_tensor_dtype)
def test_shape_tensor_list_dtype():
shape = fluid.data(
name="shape_tensor_list", shape=[1], dtype="bool")
paddle.full(shape=[shape, 2], dtype="float32", fill_value=1)
fluid.layers.full(
shape=[shape, 2], dtype="float32", fill_value=1)
self.assertRaises(TypeError, test_shape_tensor_list_dtype)
......
......@@ -33,7 +33,7 @@ class TestL1Loss(unittest.TestCase):
name='input', shape=[10, 1], dtype='float32')
label = fluid.layers.data(
name='label', shape=[10, 1], dtype='float32')
l1_loss = paddle.nn.loss.L1Loss()
l1_loss = fluid.dygraph.L1Loss()
ret = l1_loss(input, label)
exe = fluid.Executor(place)
......@@ -44,7 +44,7 @@ class TestL1Loss(unittest.TestCase):
fetch_list=[ret])
with fluid.dygraph.guard():
l1_loss = paddle.nn.loss.L1Loss()
l1_loss = fluid.dygraph.L1Loss()
dy_ret = l1_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -68,7 +68,7 @@ class TestL1Loss(unittest.TestCase):
name='input', shape=[10, 10, 5], dtype='float32')
label = fluid.layers.data(
name='label', shape=[10, 10, 5], dtype='float32')
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
l1_loss = fluid.dygraph.L1Loss(reduction='sum')
ret = l1_loss(input, label)
exe = fluid.Executor(place)
......@@ -79,7 +79,7 @@ class TestL1Loss(unittest.TestCase):
fetch_list=[ret])
with fluid.dygraph.guard():
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
l1_loss = fluid.dygraph.L1Loss(reduction='sum')
dy_ret = l1_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -103,7 +103,7 @@ class TestL1Loss(unittest.TestCase):
name='input', shape=[10, 5], dtype='float32')
label = fluid.layers.data(
name='label', shape=[10, 5], dtype='float32')
l1_loss = paddle.nn.loss.L1Loss(reduction='none')
l1_loss = fluid.dygraph.L1Loss(reduction='none')
ret = l1_loss(input, label)
exe = fluid.Executor(place)
......@@ -114,7 +114,7 @@ class TestL1Loss(unittest.TestCase):
fetch_list=[ret])
with fluid.dygraph.guard():
l1_loss = paddle.nn.loss.L1Loss(reduction='none')
l1_loss = fluid.dygraph.L1Loss(reduction='none')
dy_ret = l1_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......
......@@ -17,7 +17,6 @@ import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn as nn
import paddle.nn.functional as functional
def stable_softmax(x):
......@@ -84,14 +83,14 @@ class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase):
mylogsoftmax = nn.LogSoftmax(axis)
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
y = functional.log_softmax(x, axis, dtype)
y = fluid.layers.log_softmax(x, axis, dtype)
exe = fluid.Executor(place)
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], ref_out))
with fluid.dygraph.guard(place):
x = fluid.dygraph.to_variable(self.x)
y = functional.log_softmax(x, axis, dtype)
y = fluid.layers.log_softmax(x, axis, dtype)
self.assertTrue(np.allclose(y.numpy(), ref_out))
def test_check_api(self):
......
......@@ -18,8 +18,8 @@ import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
import paddle
from paddle.fluid import compiler, Program, program_guard, core
import paddle
class TestMeshgridOp(OpTest):
......@@ -79,7 +79,7 @@ class TestMeshgridOp3(unittest.TestCase):
out_2 = np.broadcast_to(out_2, [100, 200])
exe = fluid.Executor(place=fluid.CPUPlace())
grid_x, grid_y = paddle.tensor.meshgrid([x, y])
grid_x, grid_y = paddle.meshgrid([x, y])
res_1, res_2 = exe.run(fluid.default_main_program(),
feed={'x': input_1,
'y': input_2},
......@@ -95,7 +95,7 @@ class TestMeshgridOp4(unittest.TestCase):
def test_input_type():
x = fluid.data(shape=[200], dtype='float32', name='x2')
paddle.tensor.meshgrid(x)
paddle.meshgrid(x)
self.assertRaises(TypeError, test_input_type)
......@@ -108,7 +108,7 @@ class TestMeshgridOp5(unittest.TestCase):
with fluid.dygraph.guard():
tensor_3 = fluid.dygraph.to_variable(input_3)
tensor_4 = fluid.dygraph.to_variable(input_4)
res_3, res_4 = paddle.tensor.meshgrid([tensor_3, tensor_4])
res_3, res_4 = paddle.meshgrid([tensor_3, tensor_4])
assert np.array_equal(res_3.shape, [100, 200])
assert np.array_equal(res_4.shape, [100, 200])
......
......@@ -78,7 +78,7 @@ class TestNNMseLoss(unittest.TestCase):
name='input', shape=dim, dtype='float32')
label = fluid.layers.data(
name='label', shape=dim, dtype='float32')
mse_loss = paddle.nn.loss.MSELoss()
mse_loss = fluid.dygraph.MSELoss()
ret = mse_loss(input, label)
exe = fluid.Executor(place)
......@@ -89,7 +89,7 @@ class TestNNMseLoss(unittest.TestCase):
fetch_list=[ret])
with fluid.dygraph.guard():
mse_loss = paddle.nn.loss.MSELoss()
mse_loss = fluid.dygraph.MSELoss()
dy_ret = mse_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -115,7 +115,7 @@ class TestNNMseLoss(unittest.TestCase):
name='input', shape=dim, dtype='float32')
label = fluid.layers.data(
name='label', shape=dim, dtype='float32')
mse_loss = paddle.nn.loss.MSELoss(reduction='sum')
mse_loss = fluid.dygraph.MSELoss(reduction='sum')
ret = mse_loss(input, label)
exe = fluid.Executor(place)
......@@ -126,7 +126,7 @@ class TestNNMseLoss(unittest.TestCase):
fetch_list=[ret])
with fluid.dygraph.guard():
mse_loss = paddle.nn.loss.MSELoss(reduction='sum')
mse_loss = fluid.dygraph.MSELoss(reduction='sum')
dy_ret = mse_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -152,7 +152,7 @@ class TestNNMseLoss(unittest.TestCase):
name='input', shape=dim, dtype='float32')
label = fluid.layers.data(
name='label', shape=dim, dtype='float32')
mse_loss = paddle.nn.loss.MSELoss(reduction='none')
mse_loss = fluid.dygraph.MSELoss(reduction='none')
ret = mse_loss(input, label)
exe = fluid.Executor(place)
......@@ -163,7 +163,7 @@ class TestNNMseLoss(unittest.TestCase):
fetch_list=[ret])
with fluid.dygraph.guard():
mse_loss = paddle.nn.loss.MSELoss(reduction='none')
mse_loss = fluid.dygraph.MSELoss(reduction='none')
dy_ret = mse_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......
......@@ -82,7 +82,7 @@ class TestNLLLoss(unittest.TestCase):
with fluid.program_guard(prog, startup_prog):
input = fluid.data(name='input', shape=[10, 10], dtype='float64')
label = fluid.data(name='label', shape=[10], dtype='int64')
nll_loss = paddle.nn.loss.NLLLoss()
nll_loss = fluid.dygraph.NLLLoss()
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -93,7 +93,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss()
nll_loss = fluid.dygraph.NLLLoss()
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -115,7 +115,7 @@ class TestNLLLoss(unittest.TestCase):
with fluid.program_guard(prog, startup_prog):
input = fluid.data(name='input', shape=[10, 10], dtype='float64')
label = fluid.data(name='label', shape=[10], dtype='int64')
nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
nll_loss = fluid.dygraph.NLLLoss(reduction='sum')
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -126,7 +126,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
nll_loss = fluid.dygraph.NLLLoss(reduction='sum')
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -150,7 +150,7 @@ class TestNLLLoss(unittest.TestCase):
input = fluid.data(name='input', shape=[10, 10], dtype='float64')
label = fluid.data(name='label', shape=[10], dtype='int64')
weight = fluid.data(name='weight', shape=[10], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight)
nll_loss = fluid.dygraph.NLLLoss(weight=weight)
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -163,7 +163,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np))
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -188,7 +188,7 @@ class TestNLLLoss(unittest.TestCase):
input = fluid.data(name='input', shape=[10, 10], dtype='float64')
label = fluid.data(name='label', shape=[10], dtype='int64')
weight = fluid.data(name='weight', shape=[10], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum')
nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='sum')
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -201,7 +201,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='sum')
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -225,7 +225,7 @@ class TestNLLLoss(unittest.TestCase):
input = fluid.data(name='input', shape=[10, 10], dtype='float64')
label = fluid.data(name='label', shape=[10], dtype='int64')
weight = fluid.data(name='weight', shape=[10], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight)
nll_loss = fluid.dygraph.NLLLoss(weight=weight)
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -238,7 +238,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np))
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -261,7 +261,7 @@ class TestNLLLoss(unittest.TestCase):
input = fluid.data(name='input', shape=[10, 10], dtype='float64')
label = fluid.data(name='label', shape=[10], dtype='int64')
weight = fluid.data(name='weight', shape=[10], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none')
nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='none')
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -274,7 +274,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='none')
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -299,7 +299,7 @@ class TestNLLLoss(unittest.TestCase):
input = fluid.data(
name='input', shape=[5, 3, 5, 5], dtype='float64')
label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
nll_loss = paddle.nn.loss.NLLLoss()
nll_loss = fluid.dygraph.NLLLoss()
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -310,7 +310,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss()
nll_loss = fluid.dygraph.NLLLoss()
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -334,7 +334,7 @@ class TestNLLLoss(unittest.TestCase):
input = fluid.data(
name='input', shape=[5, 3, 5, 5], dtype='float64')
label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
nll_loss = fluid.dygraph.NLLLoss(reduction='sum')
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -345,7 +345,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(reduction='sum')
nll_loss = fluid.dygraph.NLLLoss(reduction='sum')
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -372,7 +372,7 @@ class TestNLLLoss(unittest.TestCase):
label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
weight = fluid.data(name='weight', shape=[3], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight)
nll_loss = fluid.dygraph.NLLLoss(weight=weight)
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -385,7 +385,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np))
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -411,7 +411,7 @@ class TestNLLLoss(unittest.TestCase):
label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
weight = fluid.data(name='weight', shape=[3], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight)
nll_loss = fluid.dygraph.NLLLoss(weight=weight)
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -424,7 +424,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np))
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -452,7 +452,7 @@ class TestNLLLoss(unittest.TestCase):
label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64')
weight = fluid.data(name='weight', shape=[3], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum')
nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='sum')
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -465,7 +465,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='sum')
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -491,7 +491,7 @@ class TestNLLLoss(unittest.TestCase):
input = fluid.data(
name='input', shape=[5, 3, 5, 5, 5], dtype='float64')
label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64')
nll_loss = paddle.nn.loss.NLLLoss()
nll_loss = fluid.dygraph.NLLLoss()
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -502,7 +502,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss()
nll_loss = fluid.dygraph.NLLLoss()
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
......@@ -533,7 +533,7 @@ class TestNLLLoss(unittest.TestCase):
name='input', shape=[5, 3, 5, 5, 5], dtype='float64')
label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64')
weight = fluid.data(name='weight', shape=[3], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight)
nll_loss = fluid.dygraph.NLLLoss(weight=weight)
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -546,7 +546,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np))
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -579,7 +579,7 @@ class TestNLLLoss(unittest.TestCase):
name='input', shape=[5, 3, 5, 5, 5], dtype='float64')
label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64')
weight = fluid.data(name='weight', shape=[3], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum')
nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='sum')
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -592,7 +592,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='sum')
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -628,7 +628,7 @@ class TestNLLLoss(unittest.TestCase):
name='input', shape=[5, 3, 5, 5, 5], dtype='float64')
label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64')
weight = fluid.data(name='weight', shape=[3], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none')
nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='none')
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -641,7 +641,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='none')
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......@@ -676,7 +676,7 @@ class TestNLLLoss(unittest.TestCase):
name='input', shape=[5, 3, 5, 5, 5], dtype='float64')
label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64')
weight = fluid.data(name='weight', shape=[3], dtype='float64')
nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none')
nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='none')
res = nll_loss(input, label)
exe = fluid.Executor(place)
......@@ -689,7 +689,7 @@ class TestNLLLoss(unittest.TestCase):
fetch_list=[res])
with fluid.dygraph.guard():
nll_loss = paddle.nn.loss.NLLLoss(
nll_loss = fluid.dygraph.NLLLoss(
weight=fluid.dygraph.to_variable(weight_np), reduction='none')
dy_res = nll_loss(
fluid.dygraph.to_variable(input_np),
......
......@@ -22,7 +22,6 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle
def output_hist(out):
......@@ -62,17 +61,18 @@ class TestRandintOpError(unittest.TestCase):
def test_shape():
shape = np.array([2, 3])
paddle.randint(5, shape=shape, dtype='int32')
fluid.layers.randint(5, shape=shape, dtype='int32')
self.assertRaises(TypeError, test_shape)
def test_dtype():
paddle.randint(5, shape=[32, 32], dtype='float32')
fluid.layers.randint(5, shape=[32, 32], dtype='float32')
self.assertRaises(TypeError, test_dtype)
def test_low_high():
paddle.randint(low=5, high=5, shape=[32, 32], dtype='int32')
fluid.layers.randint(
low=5, high=5, shape=[32, 32], dtype='int32')
self.assertRaises(ValueError, test_low_high)
......@@ -131,21 +131,21 @@ class TestRandintAPI(unittest.TestCase):
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
# results are from [0, 5).
output1 = paddle.randint(5)
output1 = fluid.layers.randint(5)
# shape is a list and dtype is 'int32'
output2 = paddle.randint(
output2 = fluid.layers.randint(
low=-100, high=100, shape=[64, 64], dtype='int32')
# shape is a tuple and dtype is 'int64'
output3 = paddle.randint(
output3 = fluid.layers.randint(
low=-100, high=100, shape=(32, 32, 3), dtype='int64')
# shape is a tensorlist and dtype is 'float32'
dim_1 = fluid.layers.fill_constant([1], "int64", 32)
dim_2 = fluid.layers.fill_constant([1], "int32", 50)
output4 = paddle.randint(
output4 = fluid.layers.randint(
low=-100, high=100, shape=[dim_1, 5], dtype='int32')
# shape is a tensor and dtype is 'float64'
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
output5 = paddle.randint(
output5 = fluid.layers.randint(
low=1, high=1000, shape=var_shape, dtype='int64')
place = fluid.CPUPlace()
......@@ -163,7 +163,7 @@ class TestRandintAPI(unittest.TestCase):
class TestRandintDygraphMode(unittest.TestCase):
def test_check_output(self):
with fluid.dygraph.guard():
x = paddle.randint(10, shape=[10], dtype="int32")
x = fluid.layers.randint(10, shape=[10], dtype="int32")
x_np = x.numpy()
for i in range(10):
self.assertTrue((x_np[i] >= 0 and x_np[i] < 10))
......
......@@ -16,7 +16,6 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
......@@ -24,14 +23,16 @@ from paddle.fluid import Program, program_guard
class TestRandnOp(unittest.TestCase):
def test_api(self):
x1 = paddle.randn(shape=[1000, 784], dtype='float32')
x2 = paddle.randn(shape=[1000, 784], dtype='float64')
x1 = fluid.layers.randn(shape=[1000, 784], dtype='float32')
x2 = fluid.layers.randn(shape=[1000, 784], dtype='float64')
x3 = fluid.layers.fill_constant(
shape=[1000, 784], dtype='float32', value=0)
paddle.randn(shape=[1000, 784], out=x3, dtype='float32')
x4 = paddle.randn(shape=[1000, 784], dtype='float32', device='cpu')
x5 = paddle.randn(shape=[1000, 784], dtype='float32', device='gpu')
x6 = paddle.randn(
fluid.layers.randn(shape=[1000, 784], out=x3, dtype='float32')
x4 = fluid.layers.randn(
shape=[1000, 784], dtype='float32', device='cpu')
x5 = fluid.layers.randn(
shape=[1000, 784], dtype='float32', device='gpu')
x6 = fluid.layers.randn(
shape=[1000, 784],
dtype='float32',
device='gpu',
......@@ -64,43 +65,43 @@ class TestRandnOpError(unittest.TestCase):
# The argument shape's size of randn_op should not be 0.
def test_shape_size():
out = paddle.randn(shape=[])
out = fluid.layers.randn(shape=[])
self.assertRaises(AssertionError, test_shape_size)
# The argument shape's type of randn_op should be list or tuple.
def test_shape_type():
out = paddle.randn(shape=1)
out = fluid.layers.randn(shape=1)
self.assertRaises(TypeError, test_shape_type)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_float16():
out = paddle.randn(shape=[1, 2], dtype='float16')
out = fluid.layers.randn(shape=[1, 2], dtype='float16')
self.assertRaises(TypeError, test_dtype_float16)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_int32():
out = paddle.randn(shape=[1, 2], dtype='int32')
out = fluid.layers.randn(shape=[1, 2], dtype='int32')
self.assertRaises(TypeError, test_dtype_int32)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_int64():
out = paddle.randn(shape=[1, 2], dtype='int64')
out = fluid.layers.randn(shape=[1, 2], dtype='int64')
self.assertRaises(TypeError, test_dtype_int64)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_uint8():
out = paddle.randn(shape=[1, 2], dtype='uint8')
out = fluid.layers.randn(shape=[1, 2], dtype='uint8')
self.assertRaises(TypeError, test_dtype_uint8)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_bool():
out = paddle.randn(shape=[1, 2], dtype='bool')
out = fluid.layers.randn(shape=[1, 2], dtype='bool')
self.assertRaises(TypeError, test_dtype_bool)
......
......@@ -15,7 +15,6 @@
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
......@@ -120,12 +119,12 @@ class TestRandpermOpError(unittest.TestCase):
def test_Variable():
out = np.arange(10)
paddle.randperm(n=10, out=out)
fluid.layers.randperm(n=10, out=out)
self.assertRaises(TypeError, test_Variable)
def test_value():
paddle.randperm(n=-3)
fluid.layers.randperm(n=-3)
self.assertRaises(ValueError, test_value)
......@@ -139,9 +138,9 @@ class TestRandpermOp_attr_out(unittest.TestCase):
with fluid.program_guard(train_program, startup_program):
n = 10
data_1 = fluid.layers.fill_constant([n], "int64", 3)
paddle.randperm(n=n, out=data_1)
fluid.layers.randperm(n=n, out=data_1)
data_2 = paddle.randperm(n=n, dtype="int32", device="cpu")
data_2 = fluid.layers.randperm(n=n, dtype="int32", device="cpu")
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
......@@ -160,12 +159,12 @@ class TestRandpermDygraphMode(unittest.TestCase):
def test_check_output(self):
with fluid.dygraph.guard():
n = 10
data_1 = paddle.randperm(n, dtype="int64")
data_1 = fluid.layers.randperm(n, dtype="int64")
data_1_np = data_1.numpy()
self.assertTrue(
check_randperm_out(n, data_1_np), msg=error_msg(data_1_np))
data_2 = paddle.randperm(n, dtype="int32", device="cpu")
data_2 = fluid.layers.randperm(n, dtype="int32", device="cpu")
data_2_np = data_2.numpy()
self.assertTrue(
check_randperm_out(n, data_2_np), msg=error_msg(data_2_np))
......
......@@ -15,7 +15,6 @@
from __future__ import print_function
import unittest
import paddle
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
......@@ -66,7 +65,7 @@ class TestRollAPI(unittest.TestCase):
# case 1:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
z = paddle.roll(x, shifts=1)
z = fluid.layers.roll(x, shifts=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x},
fetch_list=[z.name],
......@@ -78,7 +77,7 @@ class TestRollAPI(unittest.TestCase):
# case 2:
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
z = paddle.roll(x, shifts=1, dims=0)
z = fluid.layers.roll(x, shifts=1, dims=0)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x},
fetch_list=[z.name],
......@@ -92,7 +91,7 @@ class TestRollAPI(unittest.TestCase):
# case 1:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
z = paddle.roll(x, shifts=1)
z = fluid.layers.roll(x, shifts=1)
np_z = z.numpy()
expect_out = np.array([[9.0, 1.0, 2.0], [3.0, 4.0, 5.0],
[6.0, 7.0, 8.0]])
......@@ -101,7 +100,7 @@ class TestRollAPI(unittest.TestCase):
# case 2:
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
z = paddle.roll(x, shifts=1, dims=0)
z = fluid.layers.roll(x, shifts=1, dims=0)
np_z = z.numpy()
expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
......
......@@ -17,7 +17,6 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle.tensor as tensor
class TrilTriuOpDefaultTest(OpTest):
......@@ -71,7 +70,7 @@ def case_generator(op_type, Xshape, diagonal, expected):
data = fluid.data(shape=Xshape, dtype='float64', name=cls_name)
with self.assertRaisesRegexp(
eval(expected.split(':')[-1]), errmsg[expected]):
getattr(tensor, op_type)(input=data, diagonal=diagonal)
getattr(fluid.layers, op_type)(input=data, diagonal=diagonal)
class SuccessCase(TrilTriuOpDefaultTest):
def initTestCase(self):
......@@ -122,7 +121,7 @@ class TestTrilTriuOpAPI(unittest.TestCase):
def test_api(self):
data = np.random.random([1, 9, 9, 4]).astype('float32')
x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x')
tril_out, triu_out = tensor.tril(x), tensor.triu(x)
tril_out, triu_out = fluid.layers.tril(x), fluid.layers.triu(x)
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册