From d3c9db75a823173336401bb3de0db56cee3327fe Mon Sep 17 00:00:00 2001 From: zhongpu <2013000149@qq.com> Date: Mon, 27 Apr 2020 13:49:44 +0800 Subject: [PATCH] copy api from paddle to paddle.fluid (#24164) * copy api from paddle to paddle.fluid, test=develop * fix optest, test=develop --- python/paddle/fluid/dygraph/nn.py | 562 +++++++++++- python/paddle/fluid/layers/nn.py | 861 +++++++++++++++--- python/paddle/fluid/layers/tensor.py | 442 ++++++++- .../tests/unittests/test_allclose_layer.py | 16 +- .../fluid/tests/unittests/test_arange.py | 5 +- .../fluid/tests/unittests/test_bce_loss.py | 8 +- .../fluid/tests/unittests/test_compare_op.py | 4 +- .../unittests/test_cross_entropy_loss.py | 12 +- .../tests/unittests/test_fill_any_like_op.py | 10 +- .../paddle/fluid/tests/unittests/test_flip.py | 5 +- .../fluid/tests/unittests/test_full_op.py | 40 +- .../fluid/tests/unittests/test_l1_loss.py | 12 +- .../fluid/tests/unittests/test_log_softmax.py | 5 +- .../fluid/tests/unittests/test_meshgrid_op.py | 8 +- .../fluid/tests/unittests/test_mse_loss.py | 12 +- .../fluid/tests/unittests/test_nll_loss.py | 64 +- .../fluid/tests/unittests/test_randint_op.py | 20 +- .../fluid/tests/unittests/test_randn_op.py | 29 +- .../fluid/tests/unittests/test_randperm_op.py | 13 +- .../fluid/tests/unittests/test_roll_op.py | 9 +- .../tests/unittests/test_tril_triu_op.py | 5 +- 21 files changed, 1827 insertions(+), 315 deletions(-) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index e9139156a14..a644eaf368d 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -17,6 +17,8 @@ from __future__ import print_function from six.moves import reduce from .. import core from ..layers import utils +from ..layers import square +from ..layers import cross_entropy from ..layers import nn as F from .. import dygraph_utils from . import layers @@ -35,7 +37,8 @@ __all__ = [ 'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding', 'GRUUnit', 'InstanceNorm', 'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct', 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm', - 'SpectralNorm', 'TreeConv' + 'SpectralNorm', 'TreeConv', 'CrossEntropyLoss', 'MSELoss', 'L1Loss', + 'NLLLoss', 'BCELoss' ] @@ -3122,3 +3125,560 @@ class TreeConv(layers.Layer): else: pre_activation = out return self._helper.append_activation(pre_activation, act=self._act) + + +class CrossEntropyLoss(layers.Layer): + """ + This operator implements the cross entropy loss function. This OP combines `softmax`, + `cross_entropy`, and `reduce_sum`/`reduce_mean` together. + + It is useful when training a classification problem with `C` classes. + If provided, the optional argument `weight` should be a 1D Variable assigning + weight to each of the classes. + + For predictions label, and target label, the loss is calculated as follows. + .. math:: + + loss_j = -\\text{input[class]} + + \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right), j = 1,..., K + + If weight is not `None`: + .. math:: + + loss_j = \\text{weight[class]}(-\\text{input[class]} + + \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{input}_i)\\right)), j = 1,..., K + + Parameters: + input (Variable): Input tensor, the data type is float32, + float64, int32, int64. + label (Variable): Label tensor, the data type is float32, + float64, int32, int64. + weight (Variable, optional): Weight tensor, a manual rescaling weight given + to each class. It has the same dimensions as class number and the data type + is float32, float64, int32, int64. Default is ``'None'``. + reduction (str, optional): Indicate how to average the loss by batch_size, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned. + Default is ``'mean'``. + Returns: + The tensor variable storing the cross_entropy_loss of input and label. + Return type: Variable. + Examples: + .. code-block:: python + + # declarative mode + import paddle.fluid as fluid + import numpy as np + + input = fluid.layers.data(name='input', shape=[5, 100], dtype='float32') + label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64') + weight = fluid.layers.data(name='weight', shape=[100], dtype='float32') + ce_loss = fluid.dygraph.CrossEntropyLoss(weight=weight, reduction='mean') + output = ce_loss(input,label) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + input_data = np.random.random([5, 100]).astype("float32") + label_data = np.array([[1], [9], [40], [50], [90]]).astype("int64") + weight_data = np.random.random([100]).astype("float32") + output = exe.run(fluid.default_main_program(), + feed={"input": input_data, "label": label_data,"weight": weight_data}, + fetch_list=[output], + return_numpy=True) + print(output) + + # imperative mode + import paddle.fluid.dygraph as dg + with dg.guard(place) as g: + input = dg.to_variable(input_data) + label = dg.to_variable(label_data) + weight = dg.to_variable(weight_data) + ce_loss = fluid.dygraph.CrossEntropyLoss(weight=weight, reduction='mean') + output = ce_loss(input, label) + print(output.numpy()) + """ + + def __init__(self, weight=None, reduction='mean'): + super(CrossEntropyLoss, self).__init__() + self.weight = weight + self.reduction = reduction + + def forward(self, input, label): + check_variable_and_dtype(input, 'input', + ['float32', 'float64', 'int32', 'int64'], + 'cross_entropy_loss') + check_variable_and_dtype(label, 'label', + ['float32', 'float64', 'int32', 'int64'], + 'cross_entropy_loss') + + if self.reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in cross_entropy_loss should be 'sum', 'mean' or 'none'," + " but received %s, which is not allowed." % self.reduction) + + softmax_out = F.softmax(input) + if self.weight is not None: + if isinstance(self.weight, Variable): + softmax_out = F.elementwise_pow( + softmax_out, self.weight, axis=-1) + else: + raise ValueError( + "The weight' is not a Variable, please convert to Variable.") + + out = cross_entropy(softmax_out, label) + + if self.reduction == 'sum': + return F.reduce_sum(out) + elif self.reduction == 'mean': + return F.reduce_mean(out) + else: + return out + + +class MSELoss(layers.Layer): + """ + **Mean Square Error Loss** + Computes the mean square error (squared L2 norm) of given input and label. + + If :attr:`reduction` is set to ``'none'``, loss is calculated as: + + .. math:: + Out = (input - label)^2 + + If :attr:`reduction` is set to ``'mean'``, loss is calculated as: + + .. math:: + Out = \operatorname{mean}((input - label)^2) + + If :attr:`reduction` is set to ``'sum'``, loss is calculated as: + + .. math:: + Out = \operatorname{sum}((input - label)^2) + + where `input` and `label` are `float32` tensors of same shape. + + Parameters: + input (Variable): Input tensor, the data type is float32, + label (Variable): Label tensor, the data type is float32, + reduction (string, optional): The reduction method for the output, + could be 'none' | 'mean' | 'sum'. + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. + If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned. + Default is ``'mean'``. + + Returns: + The tensor variable storing the MSE loss of input and label. + + Return type: + Variable. + + Examples: + .. code-block:: python + + import numpy as np + from paddle import fluid + import paddle.fluid.dygraph as dg + + mse_loss = fluid.dygraph.MSELoss() + input = fluid.data(name="input", shape=[1]) + label = fluid.data(name="label", shape=[1]) + place = fluid.CPUPlace() + input_data = np.array([1.5]).astype("float32") + label_data = np.array([1.7]).astype("float32") + + # declarative mode + output = mse_loss(input,label) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + output_data = exe.run( + fluid.default_main_program(), + feed={"input":input_data, "label":label_data}, + fetch_list=[output], + return_numpy=True) + print(output_data) + # [array([0.04000002], dtype=float32)] + + # imperative mode + with dg.guard(place) as g: + input = dg.to_variable(input_data) + label = dg.to_variable(label_data) + output = mse_loss(input, label) + print(output.numpy()) + # [0.04000002] + """ + + def __init__(self, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "'reduction' in 'MSELoss' should be 'sum', 'mean' or 'none', " + "but received {}.".format(reduction)) + self.reduction = reduction + + def forward(self, input, label): + if not in_dygraph_mode(): + check_variable_and_dtype(input, 'input', ['float32'], 'MSELoss') + check_variable_and_dtype(label, 'label', ['float32'], 'MSELoss') + + square_out = square(F.elementwise_sub(input, label)) + if self.reduction == 'none': + return square_out + + reduce_op = 'reduce_mean' + if self.reduction == 'sum': + reduce_op = 'reduce_sum' + + return getattr(F, reduce_op)(square_out) + + +class L1Loss(layers.Layer): + """ + This interface is used to construct a callable object of the ``L1Loss`` class. + The L1Loss layer calculates the L1 Loss of input predictions and target + labels as follows. + + If :attr:`reduction` set to ``'none'``, the unreduced loss is: + .. math:: + Out = |input - label| + If :attr:`reduction` set to ``'mean'``, the reduced mean loss is: + .. math:: + Out = MEAN(|input - label|) + If :attr:`reduction` set to ``'sum'``, the reduced sum loss is: + .. math:: + Out = SUM(|input - label|) + + The shape of input predictions and target labels are [N, *], where N is batch_size and `*` + means any number of additional dimensions. + If :attr:`reduction` is ``'none'``, the shape of output loss is [N, *], the same as input. + If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1], which means the output is a scalar. + + Parameters: + reduction (str, optional): Indicate the reduction to apply to the loss, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. + If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. + Default is ``'mean'``. + Returns: + A callable object of L1Loss. + Examples: + .. code-block:: python + # declarative mode + import paddle.fluid as fluid + import numpy as np + input = fluid.data(name="input", shape=[1]) + label = fluid.data(name="label", shape=[1]) + l1_loss = fluid.dygraph.L1Loss(reduction='mean') + output = l1_loss(input,label) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + input_data = np.array([1.5]).astype("float32") + label_data = np.array([1.7]).astype("float32") + output_data = exe.run(fluid.default_main_program(), + feed={"input":input_data, "label":label_data}, + fetch_list=[output], + return_numpy=True) + + print(output_data) # [array([0.2], dtype=float32)] + + # imperative mode + import paddle.fluid.dygraph as dg + with dg.guard(place) as g: + input = dg.to_variable(input_data) + label = dg.to_variable(label_data) + l1_loss = fluid.dygraph.L1Loss(reduction='mean') + output = l1_loss(input,label) + print(output.numpy()) # [0.2] + """ + + def __init__(self, reduction='mean'): + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % reduction) + super(L1Loss, self).__init__() + self.reduction = reduction + + def forward(self, input, label): + check_variable_and_dtype( + input, 'input', ['float32', 'float64', 'int32', 'int64'], 'l1_loss') + check_variable_and_dtype( + label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss') + + unreduced = F.elementwise_sub(input, label, act='abs') + + if self.reduction == 'sum': + return F.reduce_sum(unreduced) + elif self.reduction == 'mean': + return F.reduce_mean(unreduced) + else: + return unreduced + + +class BCELoss(layers.Layer): + """ + This interface is used to construct a callable object of the ``BCELoss`` class. + The BCELoss layer measures the binary_cross_entropy loss between input predictions + and target labels. The binary_cross_entropy loss can be described as: + + If :attr:`weight` is set, the loss is: + + .. math:: + Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input)) + If :attr:`weight` is None, the loss is: + + .. math:: + Out = -1 * (label * log(input) + (1 - label) * log(1 - input)) + + If :attr:`reduction` set to ``'none'``, the unreduced loss is: + + .. math:: + Out = Out + If :attr:`reduction` set to ``'mean'``, the reduced mean loss is: + + .. math:: + Out = MEAN(Out) + If :attr:`reduction` set to ``'sum'``, the reduced sum loss is: + + .. math:: + Out = SUM(Out) + + Note that the input predictions always be the output of sigmoid, and the target labels + should be numbers between 0 and 1. + + The shape of input predictions and target labels are [N, *], where N is batch_size and `*` + means any number of additional dimensions. If ``reduction`` is ``'none'``, the shape of + output is scalar, else the shape of output is same as input. + + Parameters: + weight (Variable, optional): A manual rescaling weight given to the loss of each + batch element. If given, has to be a Variable of size nbatch and the data type + is float32, float64. Default is ``'None'``. + reduction (str, optional): Indicate how to average the loss by batch_size, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default is ``'mean'``. + + Returns: + A callable object of BCELoss. + + Examples: + .. code-block:: python + + # declarative mode + import paddle.fluid as fluid + import numpy as np + input = fluid.data(name="input", shape=[3, 1], dtype='float32') + label = fluid.data(name="label", shape=[3, 1], dtype='float32') + bce_loss = fluid.dygraph.BCELoss() + output = bce_loss(input, label) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + input_data = np.array([0.5, 0.6, 0.7]).astype("float32") + label_data = np.array([1.0, 0.0, 1.0]).astype("float32") + output_data = exe.run(fluid.default_main_program(), + feed={"input":input_data, "label":label_data}, + fetch_list=[output], + return_numpy=True) + + print(output_data) # [array([0.65537095], dtype=float32)] + + # imperative mode + import paddle.fluid.dygraph as dg + with dg.guard(place) as g: + input = dg.to_variable(input_data) + label = dg.to_variable(label_data) + output = bce_loss(input, label) + print(output.numpy()) # [0.65537095] + """ + + def __init__(self, weight=None, reduction='mean'): + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % reduction) + + super(BCELoss, self).__init__() + self.weight = weight + self.reduction = reduction + + def forward(self, input, label): + dtype = self._helper.input_dtype(input) + + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'bce_loss') + check_variable_and_dtype(label, 'label', ['float32', 'float64'], + 'bce_loss') + + out = self._helper.create_variable_for_type_inference(dtype=input.dtype) + self._helper.append_op( + type='bce_loss', + inputs={ + 'X': [input], + 'Label': [label], + }, + outputs={'Out': [out]}) + + if self.weight is not None: + if isinstance(self.weight, Variable): + w = self.weight + out = F.elementwise_mul(out, w, axis=-1) + else: + raise ValueError( + "The weight is not a Variable, please convert to Variable.") + + if self.reduction == 'sum': + return F.reduce_sum(out) + elif self.reduction == 'mean': + return F.reduce_mean(out) + else: + return out + + +class NLLLoss(layers.Layer): + """ + This op accepts input and target label and returns negative log likelihood + cross error. It is useful to train a classification problem with C classes. + + The input for the loss is epected to contain log-probabilities of + each classes. It hs to be a Tensor of size either (batch_size, C) or + (batch_size, C, d1, d2, ..., dK) with K >= 1 for the K-dimensional case. + The label for the loss should be a class index in the range [0, C-1] + where C is the number of classes. If ignore_index is specified, the + specified target value does not contribute to the input gradient. + + If the optional argument `weight` is provided, it should be a 1D Tensor + assigning weight to each of the classed. This is particularly useful + when you have an unbalanced training set. + + The loss is calculated as follows. + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\\top, \quad + l_n = - w_{y_n} x_{n,y_n}, \quad + w_{c} = \\text{weight}[c] \cdot \mathbb{1}\{c \\not= \\text{ignore\\_index}\}, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \\begin{cases} + \\sum_{n=1}^N \\frac{1}{\\sum_{n=1}^N w_{y_n}} l_n, & + \\text{if reduction} = \\text{'mean';}\\\\ + \\sum_{n=1}^N l_n, & + \\text{if reduction} = \\text{'sum'.} + \\end{cases} + + Parameters: + input (Variable): Input tensor, the data type is float32, float64. + label (Variable): Label tensor, the data type is int64_t. + weight (Variable, optional): Weight tensor, a manual rescaling weight given + to each class. If given, it has to be a Tensor of size `C`. Otherwise, + it treated as if having all ones. the data type is + float32, float64, Default is ``'None'``. + reduction (str, optional): Indicate how to average the loss, + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + Default is ``'mean'``. + ignore_index (int64, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. + + Returns: + The tensor variable storing the nll_loss. + + Return type: Variable. + + Examples: + + .. code-block:: python + + # declarative mode + import paddle.fluid as fluid + import numpy as np + + input_np = np.random.random(size=(10, 10)).astype(np.float32) + label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64) + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data(name='input', shape=[10, 10], dtype='float32') + label = fluid.data(name='label', shape=[10], dtype='int64') + nll_loss = fluid.dygraph.NLLLoss() + res = nll_loss(input, label) + + exe = fluid.Executor(place) + static_result = exe.run( + prog, + feed={"input": input_np, + "label": label_np}, + fetch_list=[res]) + print(static_result) + + # imperative mode + import paddle.fluid.dygraph as dg + with dg.guard(place) as g: + input = dg.to_variable(input_np) + label = dg.to_variable(label_np) + output = nll_loss(input, label) + print(output.numpy()) + """ + + def __init__(self, weight=None, reduction='mean', ignore_index=-100): + super(NLLLoss, self).__init__() + self.weight = weight + self.reduction = reduction + self.ignore_index = ignore_index + + def forward(self, input, label): + dtype = self._helper.input_dtype(input) + + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'nll_loss') + check_variable_and_dtype(label, 'label', ['int64'], 'nll_loss') + + if self.reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in nll_loss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % self.reduction) + + x_shape = list(input.shape) + n = x_shape[0] + c = x_shape[1] + x_dims = len(x_shape) + if x_dims < 2: + raise ValueError('Expected 2 or more dimensions (got {})'.format( + x_dims)) + if x_dims != 2 and x_dims != 4: + input = F.reshape(input, shape=[n, c, 1, -1]) + label = F.reshape(label, shape=[n, 1, -1]) + out_shape = [n] + x_shape[2:] + + inputs = {'X': input, 'Label': label} + attrs = {'reduction': self.reduction, 'ignore_index': self.ignore_index} + + if self.weight is not None: + if isinstance(self.weight, Variable): + inputs['Weight'] = self.weight + + out = self._helper.create_variable_for_type_inference(dtype=input.dtype) + total_weight = self._helper.create_variable_for_type_inference( + dtype=input.dtype) + outputs = {'Out': out, 'Total_weight': total_weight} + + self._helper.append_op( + type='nll_loss', inputs=inputs, outputs=outputs, attrs=attrs) + if x_dims != 2 and x_dims != 4 and self.reduction == 'none': + out = F.reshape(out, shape=out_shape) + + return out diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 45a3de64bc0..5b28f141d3d 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -26,11 +26,11 @@ import six import paddle from ..layer_helper import LayerHelper from ..initializer import Normal, Constant, NumpyArrayInitializer -from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program +from ..framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program, device_guard, _varbase_creator from .. import dygraph_utils from ..param_attr import ParamAttr from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_ -from .tensor import concat, assign, fill_constant, zeros, tensor_array_to_tensor +from .tensor import concat, assign, fill_constant, zeros, tensor_array_to_tensor, cast from . import utils from .. import unique_name from functools import reduce @@ -39,153 +39,38 @@ from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, c import paddle __all__ = [ - 'fc', - 'embedding', - 'linear_chain_crf', - 'crf_decoding', - 'cos_sim', - 'chunk_eval', - 'conv2d', - 'conv3d', - 'softmax', - 'pool2d', - 'pool3d', - 'adaptive_pool2d', - 'adaptive_pool3d', - 'batch_norm', - 'inplace_abn', - 'instance_norm', - 'data_norm', - 'conv2d_transpose', - 'conv3d_transpose', - 'reduce_sum', - 'reduce_mean', - 'reduce_max', - 'reduce_min', - 'reduce_prod', - 'reduce_all', - 'reduce_any', - 'dropout', - 'split', - 'ctc_greedy_decoder', - 'l2_normalize', - 'matmul', - 'topk', - 'transpose', - 'im2sequence', - 'row_conv', - 'multiplex', - 'layer_norm', - 'group_norm', - 'spectral_norm', - 'smooth_l1', - 'one_hot', - 'autoincreased_step_counter', - 'reshape', - 'squeeze', - 'unsqueeze', - 'lod_reset', - 'lod_append', - 'lrn', - 'pad', - 'pad_constant_like', - 'label_smooth', - 'roi_pool', - 'roi_align', - 'dice_loss', - 'image_resize', - 'image_resize_short', - 'resize_bilinear', - 'resize_trilinear', - 'resize_nearest', - 'gather', - 'gather_nd', - 'scatter', - 'scatter_nd_add', - 'scatter_nd', - 'random_crop', - 'mean_iou', - 'relu', - 'selu', - 'log', - 'crop', - 'crop_tensor', - 'elu', - 'relu6', - 'pow', - 'stanh', - 'hard_sigmoid', - 'swish', - 'prelu', - 'brelu', - 'leaky_relu', - 'soft_relu', - 'flatten', - 'stack', - 'pad2d', - 'unstack', - 'unique', - 'unique_with_counts', - 'expand', - 'expand_as', - 'scale', - 'elementwise_add', - 'elementwise_div', - 'elementwise_sub', - 'elementwise_mul', - 'elementwise_max', - 'elementwise_min', - 'elementwise_pow', - 'elementwise_mod', - 'elementwise_floordiv', - 'uniform_random_batch_size_like', - 'gaussian_random', - 'sampling_id', - 'gaussian_random_batch_size_like', - 'sum', - 'slice', - 'strided_slice', - 'shape', - 'rank', - 'size', - 'logical_and', - 'logical_or', - 'logical_xor', - 'logical_not', - 'clip', - 'clip_by_norm', - 'mean', - 'mul', - 'maxout', - 'space_to_depth', - 'affine_grid', - 'affine_channel', - 'similarity_focus', - 'hash', - 'grid_sampler', - 'log_loss', - 'add_position_encoding', - 'bilinear_tensor_product', - 'merge_selected_rows', - 'get_tensor_from_selected_rows', - 'shuffle_channel', - 'temporal_shift', - 'py_func', - 'psroi_pool', - 'prroi_pool', - 'pixel_shuffle', - 'fsp_matrix', - 'continuous_value_model', - 'where', - 'sign', - 'deformable_conv', - 'unfold', - 'deformable_roi_pooling', - 'filter_by_instag', - 'shard_index', - 'hard_swish', - 'gather_tree', - 'uniform_random', + 'fc', 'embedding', 'linear_chain_crf', 'crf_decoding', 'cos_sim', + 'chunk_eval', 'conv2d', 'conv3d', 'softmax', 'pool2d', 'pool3d', + 'adaptive_pool2d', 'adaptive_pool3d', 'batch_norm', 'inplace_abn', + 'instance_norm', 'data_norm', 'conv2d_transpose', 'conv3d_transpose', + 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', 'reduce_prod', + 'reduce_all', 'reduce_any', 'dropout', 'split', 'ctc_greedy_decoder', + 'l2_normalize', 'matmul', 'topk', 'transpose', 'im2sequence', 'row_conv', + 'multiplex', 'layer_norm', 'group_norm', 'spectral_norm', 'smooth_l1', + 'one_hot', 'autoincreased_step_counter', 'reshape', 'squeeze', 'unsqueeze', + 'lod_reset', 'lod_append', 'lrn', 'pad', 'pad_constant_like', + 'label_smooth', 'roi_pool', 'roi_align', 'dice_loss', 'image_resize', + 'image_resize_short', 'resize_bilinear', 'resize_trilinear', + 'resize_nearest', 'gather', 'gather_nd', 'scatter', 'scatter_nd_add', + 'scatter_nd', 'random_crop', 'mean_iou', 'relu', 'selu', 'log', 'crop', + 'crop_tensor', 'elu', 'relu6', 'pow', 'stanh', 'hard_sigmoid', 'swish', + 'prelu', 'brelu', 'leaky_relu', 'soft_relu', 'flatten', 'stack', 'pad2d', + 'unstack', 'unique', 'unique_with_counts', 'expand', 'expand_as', 'scale', + 'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul', + 'elementwise_max', 'elementwise_min', 'elementwise_pow', 'elementwise_mod', + 'elementwise_floordiv', 'uniform_random_batch_size_like', 'gaussian_random', + 'sampling_id', 'gaussian_random_batch_size_like', 'sum', 'slice', + 'strided_slice', 'shape', 'rank', 'size', 'logical_and', 'logical_or', + 'logical_xor', 'logical_not', 'clip', 'clip_by_norm', 'mean', 'mul', + 'maxout', 'space_to_depth', 'affine_grid', 'affine_channel', + 'similarity_focus', 'hash', 'grid_sampler', 'log_loss', + 'add_position_encoding', 'bilinear_tensor_product', 'merge_selected_rows', + 'get_tensor_from_selected_rows', 'shuffle_channel', 'temporal_shift', + 'py_func', 'psroi_pool', 'prroi_pool', 'pixel_shuffle', 'fsp_matrix', + 'continuous_value_model', 'where', 'sign', 'deformable_conv', 'unfold', + 'deformable_roi_pooling', 'filter_by_instag', 'shard_index', 'hard_swish', + 'gather_tree', 'uniform_random', 'randint', 'randn', 'randperm', 'allclose', + 'elementwise_equal', 'flip', 'roll', 'log_softmax' ] @@ -14312,3 +14197,681 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): outputs={"Out": out}) return helper.append_activation(out) + + +def randint(low, + high=None, + shape=None, + out=None, + dtype=None, + device=None, + stop_gradient=False, + seed=0, + name=None): + """ + This function returns a Tensor filled with random integers from the "discrete uniform" distribution of the + specified data type in the interval [low, high). If high is None (the default), then results are from [0, low). + + Args: + low (int): The lower bound on the range of random values to generate, the low is included in the range. + (unless high=None, in which case this parameter is one above the highest such integer). + high (int, optional): The upper bound on the range of random values to generate, the high is excluded + in the range. Default None(see above for behavior if high=None). + shape (list|tuple|Variable, optional): The shape of the output Tensor, if the shape is a list or tuple, + its elements can be an integer + or a Tensor with the shape [1], and the type of the Tensor must be int32 or int64. + If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor must be + int32 or int64. Default is None, in which case the shape is [1]. + out(Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + if out is None, a new Varibale will be create to store the result. + dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output Tensor + which can be int32, int64, if dytpe is `None`, the data + type of created Tensor is `int64` + device(str, optional): This parameter specifies that the Tensor is created + on the GPU or CPU. + stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable, + default value is False. + seed (int, optional): Random seed used for permute samples. If seed is + equal to 0, it means use a seed generated by the system. Note that + if seed is not 0, this operator will always generate the same random + permutation every time. Default: 0. + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Variable: A Tensor of the specified shape filled with random integers. + + Raises: + TypeError: Randint's low must less then high. + + Examples: + .. code-block:: python + import paddle.fluid as fluid + + # example 1: + # attr shape is a list which doesn't contain tensor Variable. + result_1 = fluid.layers.randint(low=-5, high=5, shape=[3, 4], dtype="int64") + + # example 2: + # attr shape is a list which contains tensor Variable. + dim_1 = fluid.layers.fill_constant([1],"int64",3) + dim_2 = fluid.layers.fill_constant([1],"int32",5) + result_2 = fluid.layers.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32") + + # example 3: + # attr shape is a Variable, the data type must be int64 or int32. + var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64") + result_3 = fluid.layers.randint(low=-5, high=5, shape=var_shape, dtype="int32") + var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32") + result_4 = fluid.layers.randint(low=-5, high=5, shape=var_shape_int32, dtype="int64") + + # example 4: + # Input only one parameter + # low=0, high=10, shape=[1], dtype='int64' + result_4 = fluid.layers.randint(10) + """ + + def get_new_shape_tensor(list_shape): + new_shape_tensor = [] + for dim in list_shape: + if isinstance(dim, Variable): + dim.stop_gradient = True + new_shape_tensor.append(dim) + else: + assert isinstance(dim, int) or isinstance(dim, long) + temp_out = helper.create_variable_for_type_inference('int64') + fill_constant([1], 'int64', dim, force_cpu=True, out=temp_out) + new_shape_tensor.append(temp_out) + return new_shape_tensor + + def get_attr_shape(list_shape): + unk_dim_idx = -1 + attrs_shape = [] + for dim_idx, dim_size in enumerate(list_shape): + if isinstance(dim_size, Variable): + attrs_shape.append(-1) + else: + attrs_shape.append(dim_size) + assert dim_size > 0, ( + "Each dimension size given in shape must not be negative " + "except one unknown dimension.") + return attrs_shape + + if dtype is None: + dtype = 'int64' + check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint') + + inputs = dict() + attrs = dict() + + if shape is None: + shape = [1] + assert len(shape) > 0, ("The size of argument(shape) can't be zero.") + + helper = LayerHelper("randint", **locals()) + + if in_dygraph_mode(): + attrs['shape'] = shape + else: + if isinstance(shape, Variable): + shape.stop_gradient = True + inputs["ShapeTensor"] = shape + elif isinstance(shape, (list, tuple)): + assert len(shape) > 0, ( + "The size of argument(shape) can't be zero.") + if utils._contain_var(shape): + inputs['ShapeTensorList'] = get_new_shape_tensor(shape) + else: + attrs["shape"] = get_attr_shape(shape) + check_type(shape, 'shape', (list, tuple, Variable), 'randint') + + if high is None: + high = low + low = 0 + attrs['low'] = low + attrs['high'] = high + attrs['seed'] = seed + if (low >= high): + raise ValueError( + "randint's low must less then high, but received low = {0}, " + "high = {1}".format(low, high)) + + if out is None: + if name is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + else: + out = helper.create_variable( + name=name, dtype=dtype, persistable=False) + else: + check_dtype(dtype, 'dtype', + convert_dtype(out.dtype), 'randint', + "(The dtype in randint must be the same with out's dtype.)") + attrs['dtype'] = out.dtype + out.stop_gradient = stop_gradient + + if device is None: + helper.append_op( + type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs) + else: + with device_guard(device): + helper.append_op( + type='randint', + inputs=inputs, + outputs={'Out': out}, + attrs=attrs) + return out + + +def randn(shape, + out=None, + dtype=None, + device=None, + stop_gradient=True, + name=None): + """ + This function returns a tensor filled with random numbers from a normal + distribution with mean 0 and variance 1 (also called the standard normal + distribution). + + Args: + shape(list|tuple): Shape of the generated random tensor. + out(Variable, optional): Optional output which can be any created Variable + that meets the requirements to store the result of operation. If the + out is `None`, a new Variable wiil be returned to store the result. + Default is None. + dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output + tensor, which can be float32, float64. if dtype is `None` , the data + type of output tensor is `float32` . + Default is None. + device(str, optional): Specific the output variable to be saved in cpu + or gpu memory. Supported None, 'cpu', 'gpu'. If it is None, the output + variable will be automatically assigned devices. + Default: None. + stop_gradient(bool, optional): Indicating if we stop gradient from current(out) + Variable. Default is True. + name(str, optional): Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + Default is None. + + Returns: + Random tensor whose data is drawn from a Gaussian distribution, + dtype: flaot32 or float64 as specified. + + Return type: + Variable + + Raises: + TypeError: If the type of `shape` is not list or tuple. + TypeError: If the data type of `dtype` is not float32 or float64. + ValueError: If the length of `shape` is not bigger than 0. + + Examples: + .. code-block:: python + + # declarative mode + import paddle.fluid as fluid + + data = fluid.layers.randn([2, 4]) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + res, = exe.run(fluid.default_main_program(), feed={}, fetch_list=[data]) + print(res) + # [[-1.4187592 0.7368311 -0.53748125 -0.0146909 ] + # [-0.66294265 -1.3090698 0.1898754 -0.14065823]] + + .. code-block:: python + + # imperative mode + import paddle.fluid as fluid + import paddle.fluid.dygraph as dg + + place = fluid.CPUPlace() + with dg.guard(place) as g: + x = fluid.layers.randn([2, 4]) + x_np = x.numpy() + print(x_np) + # [[ 1.5149173 -0.26234224 -0.592486 1.4523455 ] + # [ 0.04581212 -0.85345626 1.1687907 -0.02512913]] + """ + helper = LayerHelper("randn", **locals()) + check_type(shape, 'shape', (list, tuple), 'randn') + assert len(shape) > 0, ("The size of argument(shape) can't be zero.") + + if dtype is None: + dtype = 'float32' + + check_dtype(dtype, 'create data type', ['float32', 'float64'], 'randn') + + if out is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + else: + check_variable_and_dtype(out, 'out', [dtype], 'randn') + + out.stop_gradient = stop_gradient + + dtype = convert_np_dtype_to_dtype_(dtype) + seed = np.random.randint(0, 100) + + with device_guard(device): + helper.append_op( + type='gaussian_random', + outputs={'Out': out}, + attrs={ + 'shape': shape, + 'mean': 0.0, + 'std': 1.0, + 'seed': seed, + 'dtype': dtype, + 'use_mkldnn': False + }) + return out + + +@templatedoc() +def randperm(n, + out=None, + dtype="int64", + device=None, + stop_gradient=True, + seed=0): + """ + ${comment} + + Args: + n (int): The upper bound (exclusive), and it should be greater than 0. + out (Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + If out is None, a new Varibale will be create to store the result. + Default: None. + dtype (np.dtype|core.VarDesc.VarType|str, optional): The type of the + output Tensor. Supported data types: int64, int32. Default: int32. + device (str, optional): Specific the output variable to be saved in cpu + or gpu memory. Supported None, 'cpu', 'gpu'. If it is None, the output + variable will be automatically assigned devices. + Default: None. + stop_gradient (bool, optional): Whether grad should record operations + on the returned tensor. Default: True. + seed (int, optional): Random seed used for permute samples. If seed is + equal to 0, it means use a seed generated by the system. Note that + if seed is not 0, this operator will always generate the same random + permutation every time. Default: 0. + + Returns: + ${out_comment}. + + Return Type: + ${out_type} + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + num = 6 + is_use_gpu = False + + data_1 = fluid.layers.randperm(num) + fluid.layers.Print(data_1) + + data_2 = fluid.layers.randperm(num, dtype="int32", seed=1) + fluid.layers.Print(data_2) + + data_3 = fluid.layers.randperm(num, stop_gradient=False, device="cpu") + fluid.layers.Print(data_3) + + fluid.layers.randperm(num, out=data_3) + fluid.layers.Print(data_3) + + place = fluid.CUDAPlace(0) if is_use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + exe.run() + + """ + + if n < 1: + raise ValueError("The input n should be greater than 0 in randperm op.") + check_dtype(dtype, 'dtype', ['int64', 'int32'], 'randperm') + dtype = convert_dtype(dtype) + if device not in [None, 'cpu', 'gpu']: + raise ValueError("The input device should in [None, 'cpu', 'gpu'].") + check_type(stop_gradient, 'stop_gradient', bool, 'randperm') + + helper = LayerHelper("randperm", **locals()) + if out is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + else: + check_variable_and_dtype(out, 'out', [dtype], 'randperm') + if stop_gradient: + out.stop_gradient = True + inputs = dict() + outputs = {'Out': [out]} + attrs = {'n': n, 'dtype': out.dtype, 'seed': seed} + with device_guard(device): + helper.append_op( + type='randperm', inputs=inputs, outputs=outputs, attrs=attrs) + return out + + +@templatedoc() +def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): + """ + ${comment} + + Args: + input(inputtype):{input_comment}. + other(othertype):{other_comment}. + rtol(rtoltype,optional):{rtol_comment}. + atol(atoltype,optional):{atol_comment}. + equal_nan(equalnantype,optional):{equal_nan_comment}. + name(STR, optional): The default value is None. + Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + ${out_comment}. + + Return Type: + ${out_type} + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + use_cuda = fluid.core.is_compiled_with_cuda() + + a = fluid.data(name="a", shape=[2], dtype='float32') + b = fluid.data(name="b", shape=[2], dtype='float32') + + result = fluid.layers.allclose(a, b, rtol=1e-05, atol=1e-08, + equal_nan=False, name="ignore_nan") + result_nan = fluid.layers.allclose(a, b, rtol=1e-05, atol=1e-08, + equal_nan=True, name="equal_nan") + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + x = np.array([10000., 1e-07]).astype("float32") + y = np.array([10000.1, 1e-08]).astype("float32") + result_v, result_nan_v = exe.run( + feed={'a': x, 'b': y}, + fetch_list=[result, result_nan]) + print(result_v, result_nan_v) + # Output: (array([False]), array([False])) + + x = np.array([10000., 1e-08]).astype("float32") + y = np.array([10000.1, 1e-09]).astype("float32") + result_v, result_nan_v = exe.run( + feed={'a': x, 'b': y}, + fetch_list=[result, result_nan]) + print(result_v, result_nan_v) + # Output: (array([ True]), array([ True])) + + x = np.array([1.0, float('nan')]).astype("float32") + y = np.array([1.0, float('nan')]).astype("float32") + result_v, result_nan_v = exe.run( + feed={'a': x, 'b': y}, + fetch_list=[result, result_nan]) + print(result_v, result_nan_v) + # Output: (array([False]), array([ True])) + """ + + check_type(rtol, 'rtol', float, 'allclose') + check_type(atol, 'atol', float, 'allclose') + check_type(equal_nan, 'equal_nan', bool, 'allclose') + + helper = LayerHelper("allclose", **locals()) + out = helper.create_variable_for_type_inference(dtype='bool') + + inputs = {'Input': input, 'Other': other} + outputs = {'Out': out} + attrs = {'rtol': rtol, 'atol': atol, 'equal_nan': equal_nan} + helper.append_op( + type='allclose', inputs=inputs, outputs=outputs, attrs=attrs) + + return out + + +def elementwise_equal(x, y, name=None): + """ + This layer returns the truth value of :math:`x == y` elementwise. + + Args: + x(Variable): Tensor, data type is float32, float64, int32, int64. + y(Variable): Tensor, data type is float32, float64, int32, int64. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Variable: output Tensor, it's shape is the same as the input's Tensor, + and the data type is bool. The result of this op is stop_gradient. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + label = fluid.layers.assign(np.array([3, 3], dtype="int32")) + limit = fluid.layers.assign(np.array([3, 2], dtype="int32")) + out1 = fluid.layers.elementwise_equal(x=label, y=limit) #out1=[True, False] + """ + helper = LayerHelper("elementwise_equal", **locals()) + out = helper.create_variable_for_type_inference(dtype='bool') + out.stop_gradient = True + + helper.append_op( + type='equal', + inputs={'X': [x], + 'Y': [y]}, + outputs={'Out': [out]}, + attrs={'force_cpu': False}) + return out + + +def flip(input, dims, name=None): + """ + + Reverse the order of a n-D tensor along given axis in dims. + + Args: + input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor + should be float32, float64, int32, int64, bool. + dims (list): The axis to flip on. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + input = fluid.data(name="x", shape=[-1, 2, 2], dtype='float32') + output = fluid.layers.flip(input, dims=[0, 1]) + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + img = np.arange(12).reshape((3,2,2)).astype(np.float32) + res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output]) + print(res) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]] + """ + helper = LayerHelper("flip", **locals()) + check_type(input, 'X', (Variable), 'flip') + dtype = helper.input_dtype() + check_dtype(dtype, 'X', + ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], + 'flip') + check_type(dims, 'dims', (list, tuple), 'flip') + assert len(dims) > 0, 'len(dims) must be greater than 0.' + if name is None: + out = helper.create_variable_for_type_inference(dtype) + else: + out = helper.create_variable(name=name, dtype=dtype, persistable=False) + + helper.append_op( + type="flip", + inputs={"X": input}, + outputs={"Out": out}, + attrs={"dims": dims}) + return out + + +def roll(input, shifts, dims=None): + """ + Roll the `input` tensor along the given dimension(s). Elements that are shifted beyond + the last position are re-introduced at the first position. If a dimension is not specified, + the tensor will be flattened before rolling and then restored to the original shape. + + Args: + input (Variable): The input tensor variable. + shifts (int|list|tuple): The number of places by which the elements + of the `input` tensor are shifted. + dims (int|list|tuple|None): Dimentions along which to roll. + + Returns: + Variable: A Tensor with same data type as `input`. + + Examples: + .. code-block:: python + import numpy as np + import paddle.fluid as fluid + + data = np.array([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0]]) + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(data) + out_z1 = fluid.layers.roll(x, shifts=1) + print(out_z1.numpy()) + #[[9. 1. 2.] + # [3. 4. 5.] + # [6. 7. 8.]] + out_z2 = fluid.layers.roll(x, shifts=1, dims=0) + print(out_z2.numpy()) + #[[7. 8. 9.] + # [1. 2. 3.] + # [4. 5. 6.]] + """ + helper = LayerHelper("roll", **locals()) + origin_shape = input.shape + if type(shifts) == int: + shifts = [shifts] + if type(dims) == int: + dims = [dims] + + if dims: + check_type(dims, 'dims', (list, tuple), 'roll') + check_type(shifts, 'shifts', (list, tuple), 'roll') + + if in_dygraph_mode(): + if dims is None: + input = core.ops.reshape(input, 'shape', [-1, 1]) + dims = [0] + out = core.ops.roll(input, 'dims', dims, 'shifts', shifts) + return core.ops.reshape(out, 'shape', origin_shape) + + out = helper.create_variable_for_type_inference(input.dtype) + + if dims is None: + input = reshape(input, shape=[-1, 1]) + dims = [0] + + helper.append_op( + type='roll', + inputs={'X': input}, + outputs={'Out': out}, + attrs={'dims': dims, + 'shifts': shifts}) + out = reshape(out, shape=origin_shape, inplace=True) + return out + + +def log_softmax(input, axis=None, dtype=None, name=None): + """ + This operator implements the log_softmax layer. The calculation process is as follows: + + .. math:: + + Out[i, j] = log(softmax(x)) + = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) + + Parameters: + input (Variable): The input variable. A multi-dimension Tensor with type float32, or float64. + axis (int, optional): The index of dimension to perform softmax calculations, it should be in + range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None. + None and -1 means the last dimension. + dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified, + the input tensor is casted to dtype before the operation is performed. This is useful for + preventing data type overflows. Default: None. Supported dtype: float32 or float64 + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Variable: ``Tensor`` indicates the output of softmax. The data type and shape are the same as ``input``. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import numpy as np + + data = np.array([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]]).astype('float32') + with fluid.dygraph.guard(): + data = fluid.dygraph.to_variable(data) + res = fluid.layers.log_softmax(data, -1) + # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] + # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] + # [-16.313261 -17.313261 -1.3132617 -0.31326184]] + # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] + # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] + # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] + """ + + axis = -1 if axis is None else axis + dtype = convert_np_dtype_to_dtype_(dtype) if dtype is not None else dtype + + if in_dygraph_mode(): + outs_cast = input if dtype is None \ + else core.ops.cast(input, 'in_dtype', input.dtype, 'out_dtype', dtype) + outs_softmax = core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', + False) + return core.ops.log(outs_softmax) + + if dtype is None: + check_variable_and_dtype( + input, 'input', ['float16', 'float32', 'float64'], 'log_softmax') + + helper = LayerHelper("log_softmax", **locals()) + outs_cast = input + if dtype is not None: + outs_cast = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='cast', + inputs={'X': input}, + outputs={'Out': outs_cast}, + attrs={'in_dtype': input.dtype, + 'out_dtype': dtype}) + + outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype) + helper.append_op( + type='softmax', + inputs={'X': outs_cast}, + outputs={'Out': outs_softmax}, + attrs={'axis': axis, + 'use_cudnn': False}) + + outs_log = helper.create_variable_for_type_inference(outs_softmax.dtype) + helper.append_op( + type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log}) + + return outs_log diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index f1d47ecc4b9..6ddd4dc04a8 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -18,8 +18,8 @@ from six.moves import reduce from ..layer_helper import LayerHelper from ..param_attr import ParamAttr from ..initializer import Initializer -from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator -from ..framework import Variable +from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard, OpProtoHolder +from ..framework import Variable, in_dygraph_mode from ..initializer import Constant from ..core import VarDesc from .. import core @@ -30,32 +30,12 @@ import numpy import warnings __all__ = [ - 'create_tensor', - 'create_parameter', - 'create_global_var', - 'cast', - 'tensor_array_to_tensor', - 'concat', - 'sums', - 'assign', - 'fill_constant_batch_size_like', - 'fill_constant', - 'argmin', - 'argmax', - 'argsort', - 'ones', - 'zeros', - 'reverse', - 'has_inf', - 'has_nan', - 'isfinite', - 'range', - 'linspace', - 'zeros_like', - 'ones_like', - 'diag', - 'eye', - 'kron', + 'create_tensor', 'create_parameter', 'create_global_var', 'cast', + 'tensor_array_to_tensor', 'concat', 'sums', 'assign', + 'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax', + 'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite', + 'range', 'linspace', 'zeros_like', 'ones_like', 'diag', 'eye', 'kron', + 'full_like', 'arange', 'full', 'tril', 'triu' ] @@ -1587,6 +1567,412 @@ def ones_like(x, out=None): return out +def full_like(input, + fill_value, + out=None, + dtype=None, + device=None, + stop_gradient=True, + name=None): + """ + **full_like** + This function creates a tensor filled with `fill_value` which has identical shape and dtype + with `input`. + Args: + input(Variable): The input tensor which specifies shape and dtype. + fill_value: The value to fill the tensor with. Data type can be bool, float32, float64, int32, int64. Default value is 0. + out(Variable): The output tensor. + Returns: + out(Variable): The tensor variable storing the output. + Examples: + .. code-block:: python + import paddle.fluid as fluid + import numpy as np + + input = fluid.data(name='input', dtype='float32', shape=[2, 3]) + output = fluid.layers.full_like(input, 2.0) + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + img=np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) + res = exe.run(fluid.default_main_program(), feed={'input':img}, fetch_list=[output]) + print(res) # [array([[2., 2., 2.], [2., 2., 2.]], dtype=float32)] + """ + helper = LayerHelper("full_like", **locals()) + + if dtype is None: + dtype = 'float32' + + check_dtype(dtype, 'dtype', + ['bool', 'float16', 'float32', 'int32', 'int64'], 'full_like') + + if out is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + helper.append_op( + type='fill_any_like', + inputs={'X': [input]}, + attrs={'value': fill_value}, + outputs={'Out': [out]}) + out.stop_gradient = stop_gradient + + return out + + +def arange(start, end, step=1, dtype=None, name=None): + """ + Return evenly spaced values within a given interval. + Values are generated within the half-open interval [start, stop) (in other words, + the interval including start but excluding stop). + Parameters: + start(float32 | float64 | int32 | int64 | Variable): Start of interval. The interval includes this value. + when start is Variable, it is a 1-D Tensor with shape [1]. + end(float32 | float64 | int32 | int64 | Variable): End of interval. The interval does not include this + value, except in some cases where step is not an integer + and floating point round-off affects the length of out. When end is Variable, + it is a 1-D Tensor with shape [1]. + step(float32 | float64 | int32 | int64 | Variable): Spacing between values. For any output out, this is the + distance between two adjacent values, out[i+1] - out[i]. + dtype(str|core.VarDesc.VarType): the data type of the output tensor, can be float32, float64, int32, int64. + Returns: a 1-D Tensor which is evenly spaced values within a given interval. Its data type is set by dtype. + + Return type: Variable + examples: + .. code-block:: python + import paddle.fluid as fluid + # expected out put: [0, 2, 4, 6, 8] + data = fluid.layers.arange(0, 10, 2, 'int32') + #dygraph mode + import paddle.fluid as fluid + with fluid.dygraph.guard(): + x = fluid.layers.arange(0, 6, 2) + # x: [0, 2, 4] + # x dtype: float32 + + """ + helper = LayerHelper("range", **locals()) + + if dtype is None: + dtype = 'float32' + + check_dtype(dtype, 'create data type', + ['float32', 'float64', 'int32', 'int64'], 'range') + + dtype = convert_dtype(dtype) + if not isinstance(start, Variable): + start = fill_constant([1], dtype, start) + + if not isinstance(end, Variable): + end = fill_constant([1], dtype, end) + + if not isinstance(step, Variable): + step = fill_constant([1], dtype, step) + + out = helper.create_variable_for_type_inference(dtype=start.dtype) + + helper.append_op( + type='range', + inputs={'Start': start, + 'End': end, + 'Step': step}, + outputs={'Out': [out]}) + out.stop_gradient = True + return out + + +def full(shape, + fill_value, + out=None, + dtype=None, + device=None, + stop_gradient=True, + name=None): + """ + This Op return a Tensor with the `fill_value` which size is same as `shape` + + Args: + shape(list|tuple|Variable): Shape of the Tensor to be created. + The data type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple, + the elements of it should be integers or Tensors with shape [1]. + If ``shape`` is an Variable, it should be an 1-D Tensor . + fill_value(bool|float16|float32|float64|int32|int64|Variable): The constant value + used to initialize the Tensor to be created. If fill_value is an Variable, it must be an 1-D Tensor. + out(Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + if out is None, a new Varibale will be create to store the result. + dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output tensor + which can be float16, float32, float64, int32, int64, if dytpe is `None`, the data + type of created tensor is `float32` + device(str, optional): On which device to run this Op. The :attr:`device` must be + None, 'cpu' or 'gpu'. If :attr:`device` is None, the device that the user set in + the paddle program will be chosen. Default value is None. + stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable, + default value is True. + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Variable: Tensor which is created according to shape and dtype. + + Raises: + TypeError: The `dtype` must be one of None, bool, float16, float32, float64, int32 and int64. + TypeError: The `out` must be a Variable. + TypeError: The `shape` must be one of Variable, list tuple. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + data1 = fluid.layers.full(shape=[2,1], fill_value=0, dtype='int64') # data1=[[0],[0]] + data2 = fluid.layers.full(shape=[2,1], fill_value=5, dtype='int64', device='gpu') # data2=[[5],[5]] + + # attr shape is a list which contains Variable Tensor. + positive_2 = fluid.layers.fill_constant([1], "int32", 2) + data3 = fluid.layers.full(shape=[1, positive_2], dtype='float32', fill_value=1.5) # data3=[1.5, 1.5] + + # attr shape is an Variable Tensor. + shape = fluid.layers.fill_constant([1,2], "int32", 2) # shape=[2,2] + data4 = fluid.layers.full(shape=shape, dtype='bool', fill_value=True) # data4=[[True,True],[True,True]] + + # attr value is an Variable Tensor. + val = fluid.layers.fill_constant([1], "float32", 2.0) # val=[2.0] + data5 = fluid.layers.full(shape=[2,1], fill_value=val, dtype='float32') #data5=[[2.0],[2.0]] + """ + + helper = LayerHelper("full", **locals()) + + if dtype is None: + dtype = 'float32' + + check_dtype(dtype, 'create data type', + ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'full') + check_type(shape, 'shape', (Variable, list, tuple), 'full') + if out is not None: + check_type(shape, 'out', (Variable), 'full') + + if out is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + + out.stop_gradient = stop_gradient + + with device_guard(device): + out = fill_constant(shape=shape, dtype=dtype, value=fill_value, out=out) + + return out + + +def _tril_triu_op(helper): + """Base op of tril_op and triu_op + """ + op_type = helper.layer_type + x = helper.kwargs.get('input', None) + + assert x is not None, 'x cannot be None in {}'.format(op_type) + check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], + op_type) + if len(x.shape) < 2: + raise ValueError("input shape in {} must be at least 2-D".format( + op_type)) + diagonal = helper.kwargs.get('diagonal', 0) + if not isinstance(diagonal, (int, )): + raise TypeError("diagonal in {} must be a python Int".format(op_type)) + name = helper.kwargs.get('name', None) + + if name is None: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + out = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type="tril_triu", + inputs={"X": x}, + attrs={ + "diagonal": diagonal, + "lower": True if op_type == 'tril' else False, + }, + outputs={"Out": out}, ) + + return out + + +def tril(input, diagonal=0, name=None): + """ + This op returns the lower triangular part of a matrix (2-D tensor) or batch + of matrices :attr:`input`, the other elements of the result tensor are set + to 0. The lower triangular part of the matrix is defined as the elements + on and below the diagonal. + + Args: + input (Variable): The input variable which is a Tensor. + Support data types: ``float64``, ``float32``, ``int32``, ``int64``. + diagonal (int, optional): The diagonal to consider, default value is 0. + If :attr:`diagonal` = 0, all elements on and below the main diagonal are + retained. A positive value includes just as many diagonals above the main + diagonal, and similarly a negative value excludes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\{(i, i)\}` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Variable: Tensor, results of lower triangular operation by the specified diagonal of input tensor, + it's data type is the same as input's Tensor. + + Raises: + TypeError: diagonal is not a int type. + ValueError: dimension of :attr:`input` is less than 2. + + Examples: + .. code-block:: python + + import numpy as np + import paddle.fluid as fluid + + data = np.arange(1, 13, dtype="int64").reshape(3,-1) + # array([[ 1, 2, 3, 4], + # [ 5, 6, 7, 8], + # [ 9, 10, 11, 12]]) + x = fluid.data(shape=(-1, 4), dtype='int64', name='x') + exe = fluid.Executor(fluid.CPUPlace()) + + # example 1, default diagonal + tril = fluid.layers.tril(x) + tril_out, = exe.run(fluid.default_main_program(), feed={"x": data}, + fetch_list=[tril], return_numpy=True) + # array([[ 1, 0, 0, 0], + # [ 5, 6, 0, 0], + # [ 9, 10, 11, 0]]) + + .. code-block:: python + + # example 2, positive diagonal value + import paddle.fluid as fluid + import numpy as np + + data = np.arange(1, 13, dtype="int64").reshape(3,-1) + x = fluid.data(shape=(-1, 4), dtype='int64', name='x') + exe = fluid.Executor(fluid.CPUPlace()) + + tril = fluid.layers.tril(x, diagonal=2) + tril_out, = exe.run(fluid.default_main_program(), feed={"x": data}, + fetch_list=[tril], return_numpy=True) + # array([[ 1, 2, 3, 0], + # [ 5, 6, 7, 8], + # [ 9, 10, 11, 12]]) + + .. code-block:: python + + # example 3, negative diagonal value + import paddle.fluid as fluid + import numpy as np + + data = np.arange(1, 13, dtype="int64").reshape(3,-1) + x = fluid.data(shape=(-1, 4), dtype='int64', name='x') + exe = fluid.Executor(fluid.CPUPlace()) + + tril = fluid.layers.tril(x, diagonal=-1) + tril_out, = exe.run(fluid.default_main_program(), feed={"x": data}, + fetch_list=[tril], return_numpy=True) + # array([[ 0, 0, 0, 0], + # [ 5, 0, 0, 0], + # [ 9, 10, 0, 0]]) + + """ + + return _tril_triu_op(LayerHelper('tril', **locals())) + + +def triu(input, diagonal=0, name=None): + """ + This op returns the upper triangular part of a matrix (2-D tensor) or batch of matrices + :attr:`input`, the other elements of the result tensor are set to 0. + The upper triangular part of the matrix is defined as the elements on and + above the diagonal. + + Args: + input (Variable): The input variable which is a Tensor. + Support data types: ``float64``, ``float32``, ``int32``, ``int64``. + diagonal (int, optional): The diagonal to consider, default value is 0. + If :attr:`diagonal` = 0, all elements on and above the main diagonal are + retained. A positive value excludes just as many diagonals above the main + diagonal, and similarly a negative value includes just as many diagonals below + the main diagonal. The main diagonal are the set of indices + :math:`\{(i, i)\}` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]` where + :math:`d_{1}, d_{2}` are the dimensions of the matrix. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Variable: Tensor, results of upper triangular operation by the specified diagonal of input tensor, + it's data type is the same as input's Tensor. + + Raises: + TypeError: diagonal is not a int type. + ValueError: dimension of :attr:`input` is less than 2. + + Examples: + .. code-block:: python + + import numpy as np + import paddle.fluid as fluid + + data = np.arange(1, 13, dtype="int64").reshape(3,-1) + # array([[ 1, 2, 3, 4], + # [ 5, 6, 7, 8], + # [ 9, 10, 11, 12]]) + x = fluid.data(shape=(-1, 4), dtype='int64', name='x') + exe = fluid.Executor(fluid.CPUPlace()) + + # example 1, default diagonal + import paddle.fluid as fluid + triu = fluid.layers.triu(x) + triu_out, = exe.run(fluid.default_main_program(), feed={"x": data}, + fetch_list=[triu], return_numpy=True) + # array([[ 1, 2, 3, 4], + # [ 0, 6, 7, 8], + # [ 0, 0, 11, 12]]) + + .. code-block:: python + + # example 2, positive diagonal value + import paddle.fluid as fluid + import numpy as np + + data = np.arange(1, 13, dtype="int64").reshape(3,-1) + x = fluid.data(shape=(-1, 4), dtype='int64', name='x') + exe = fluid.Executor(fluid.CPUPlace()) + + triu = fluid.layers.triu(x, diagonal=2) + triu_out, = exe.run(fluid.default_main_program(), feed={"x": data}, + fetch_list=[triu], return_numpy=True) + # array([[0, 0, 3, 4], + # [0, 0, 0, 8], + # [0, 0, 0, 0]]) + + .. code-block:: python + + # example 3, negative diagonal value + import paddle.fluid as fluid + import numpy as np + + data = np.arange(1, 13, dtype="int64").reshape(3,-1) + x = fluid.data(shape=(-1, 4), dtype='int64', name='x') + exe = fluid.Executor(fluid.CPUPlace()) + + triu = fluid.layers.triu(x, diagonal=-1) + triu_out, = exe.run(fluid.default_main_program(), feed={"x": data}, + fetch_list=[triu], return_numpy=True) + # array([[ 1, 2, 3, 4], + # [ 5, 6, 7, 8], + # [ 0, 10, 11, 12]]) + + """ + + return _tril_triu_op(LayerHelper('triu', **locals())) + + @templatedoc(op_type="kron") def kron(x, y, out=None, name=None): """${comment} diff --git a/python/paddle/fluid/tests/unittests/test_allclose_layer.py b/python/paddle/fluid/tests/unittests/test_allclose_layer.py index 60fd157d2e7..c97a2492ab0 100644 --- a/python/paddle/fluid/tests/unittests/test_allclose_layer.py +++ b/python/paddle/fluid/tests/unittests/test_allclose_layer.py @@ -23,9 +23,9 @@ class TestAllcloseLayer(unittest.TestCase): a = fluid.data(name="a", shape=[2], dtype='float32') b = fluid.data(name="b", shape=[2], dtype='float32') - result = paddle.allclose( + result = fluid.layers.allclose( a, b, rtol=1e-05, atol=1e-08, equal_nan=False, name="ignore_nan") - result_nan = paddle.allclose( + result_nan = fluid.layers.allclose( a, b, rtol=1e-05, atol=1e-08, equal_nan=True, name="equal_nan") place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() @@ -82,7 +82,7 @@ class TestAllcloseLayer(unittest.TestCase): with fluid.dygraph.guard(): x_v_1 = fluid.dygraph.to_variable(x_1) y_v_1 = fluid.dygraph.to_variable(y_1) - ret_1 = paddle.allclose( + ret_1 = fluid.layers.allclose( x_v_1, y_v_1, rtol=1e-05, @@ -90,7 +90,7 @@ class TestAllcloseLayer(unittest.TestCase): equal_nan=False, name='test_1') self.assertEqual(ret_1.numpy()[0], False) - ret_1 = paddle.allclose( + ret_1 = fluid.layers.allclose( x_v_1, y_v_1, rtol=1e-05, @@ -100,7 +100,7 @@ class TestAllcloseLayer(unittest.TestCase): self.assertEqual(ret_1.numpy()[0], False) x_v_2 = fluid.dygraph.to_variable(x_2) y_v_2 = fluid.dygraph.to_variable(y_2) - ret_2 = paddle.allclose( + ret_2 = fluid.layers.allclose( x_v_2, y_v_2, rtol=1e-05, @@ -108,7 +108,7 @@ class TestAllcloseLayer(unittest.TestCase): equal_nan=False, name='test_3') self.assertEqual(ret_2.numpy()[0], True) - ret_2 = paddle.allclose( + ret_2 = fluid.layers.allclose( x_v_2, y_v_2, rtol=1e-05, @@ -118,7 +118,7 @@ class TestAllcloseLayer(unittest.TestCase): self.assertEqual(ret_2.numpy()[0], True) x_v_3 = fluid.dygraph.to_variable(x_3) y_v_3 = fluid.dygraph.to_variable(y_3) - ret_3 = paddle.allclose( + ret_3 = fluid.layers.allclose( x_v_3, y_v_3, rtol=1e-05, @@ -126,7 +126,7 @@ class TestAllcloseLayer(unittest.TestCase): equal_nan=False, name='test_5') self.assertEqual(ret_3.numpy()[0], False) - ret_3 = paddle.allclose( + ret_3 = fluid.layers.allclose( x_v_3, y_v_3, rtol=1e-05, diff --git a/python/paddle/fluid/tests/unittests/test_arange.py b/python/paddle/fluid/tests/unittests/test_arange.py index d715744b02a..6c0d1a4b252 100644 --- a/python/paddle/fluid/tests/unittests/test_arange.py +++ b/python/paddle/fluid/tests/unittests/test_arange.py @@ -14,7 +14,6 @@ from __future__ import print_function -import paddle import paddle.fluid as fluid import unittest import numpy as np @@ -71,7 +70,7 @@ class TestInt32ArangeOpCase2(TestArangeOp): class TestArangeAPI(unittest.TestCase): def test_out(self): with fluid.program_guard(fluid.Program()): - data = paddle.arange(0, 5, 1) + data = fluid.layers.arange(0, 5, 1) place = fluid.CPUPlace() exe = fluid.Executor(place) result, = exe.run(fetch_list=[data]) @@ -79,7 +78,7 @@ class TestArangeAPI(unittest.TestCase): self.assertEqual((result == expected_data).all(), True) with fluid.program_guard(fluid.Program()): - data = paddle.arange(0.0, 5.0, 1.0, 'int32') + data = fluid.layers.arange(0.0, 5.0, 1.0, 'int32') place = fluid.CPUPlace() exe = fluid.Executor(place) result, = exe.run(fetch_list=[data]) diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py index 21571e09810..b0961af780b 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py @@ -36,7 +36,7 @@ class TestBCELoss(unittest.TestCase): name='input', shape=[None, 30], dtype='float64') label = fluid.data( name='label', shape=[None, 30], dtype='float64') - bce_loss = paddle.nn.loss.BCELoss(reduction=red) + bce_loss = fluid.dygraph.BCELoss(reduction=red) res = bce_loss(input, label) exe = fluid.Executor(place) @@ -47,7 +47,7 @@ class TestBCELoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - bce_loss = paddle.nn.loss.BCELoss(reduction=red) + bce_loss = fluid.dygraph.BCELoss(reduction=red) dy_res = bce_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -80,7 +80,7 @@ class TestBCELoss(unittest.TestCase): name='label', shape=[None, 3, 4, 10], dtype='float64') weight = fluid.data( name='weight', shape=[3, 4, 10], dtype='float64') - bce_loss = paddle.nn.loss.BCELoss(weight=weight) + bce_loss = fluid.dygraph.BCELoss(weight=weight) res = bce_loss(input, label) exe = fluid.Executor(place) @@ -93,7 +93,7 @@ class TestBCELoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - bce_loss = paddle.nn.loss.BCELoss( + bce_loss = fluid.dygraph.BCELoss( weight=fluid.dygraph.to_variable(weight_np)) dy_res = bce_loss( fluid.dygraph.to_variable(input_np), diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index 9d4a9082b54..a75eb73d6fe 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -82,7 +82,7 @@ class API_TestElementwise_Equal(unittest.TestCase): with fluid.program_guard(fluid.Program(), fluid.Program()): label = fluid.layers.assign(np.array([3, 3], dtype="int32")) limit = fluid.layers.assign(np.array([3, 2], dtype="int32")) - out = paddle.elementwise_equal(x=label, y=limit) + out = fluid.layers.elementwise_equal(x=label, y=limit) place = fluid.CPUPlace() exe = fluid.Executor(place) res, = exe.run(fetch_list=[out]) @@ -91,7 +91,7 @@ class API_TestElementwise_Equal(unittest.TestCase): with fluid.program_guard(fluid.Program(), fluid.Program()): label = fluid.layers.assign(np.array([3, 3], dtype="int32")) limit = fluid.layers.assign(np.array([3, 3], dtype="int32")) - out = paddle.elementwise_equal(x=label, y=limit) + out = fluid.layers.elementwise_equal(x=label, y=limit) place = fluid.CPUPlace() exe = fluid.Executor(place) res, = exe.run(fetch_list=[out]) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index eeed59f5a6c..5b02fdb4d24 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -35,7 +35,7 @@ class CrossEntropyLoss(unittest.TestCase): label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64') weight = fluid.layers.data( name='weight', shape=[100], dtype='float32') - cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(weight=weight) + cross_entropy_loss = fluid.dygraph.CrossEntropyLoss(weight=weight) ret = cross_entropy_loss(input, label) exe = fluid.Executor(place) @@ -48,7 +48,7 @@ class CrossEntropyLoss(unittest.TestCase): fetch_list=[ret]) self.assertIsNotNone(static_ret) with fluid.dygraph.guard(): - cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + cross_entropy_loss = fluid.dygraph.CrossEntropyLoss( weight=fluid.dygraph.to_variable(weight_np)) dy_ret = cross_entropy_loss( fluid.dygraph.to_variable(input_np), @@ -71,7 +71,7 @@ class CrossEntropyLoss(unittest.TestCase): label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64') weight = fluid.layers.data( name='weight', shape=[100], dtype='float32') - cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + cross_entropy_loss = fluid.dygraph.CrossEntropyLoss( weight=weight, reduction='sum') ret = cross_entropy_loss(input, label) @@ -85,7 +85,7 @@ class CrossEntropyLoss(unittest.TestCase): fetch_list=[ret]) self.assertIsNotNone(static_ret) with fluid.dygraph.guard(): - cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + cross_entropy_loss = fluid.dygraph.CrossEntropyLoss( weight=fluid.dygraph.to_variable(weight_np), reduction='sum') dy_ret = cross_entropy_loss( fluid.dygraph.to_variable(input_np), @@ -108,7 +108,7 @@ class CrossEntropyLoss(unittest.TestCase): label = fluid.layers.data(name='label', shape=[5, 1], dtype='int64') weight = fluid.layers.data( name='weight', shape=[100], dtype='float32') - cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + cross_entropy_loss = fluid.dygraph.CrossEntropyLoss( weight=weight, reduction='none') ret = cross_entropy_loss(input, label) @@ -122,7 +122,7 @@ class CrossEntropyLoss(unittest.TestCase): fetch_list=[ret]) self.assertIsNotNone(static_ret) with fluid.dygraph.guard(): - cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + cross_entropy_loss = fluid.dygraph.CrossEntropyLoss( weight=fluid.dygraph.to_variable(weight_np), reduction='none') dy_ret = cross_entropy_loss( fluid.dygraph.to_variable(input_np), diff --git a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py index cd902b7e00e..777d38cb25e 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py @@ -106,7 +106,7 @@ class TestFillAnyLikeOp_attr_out(unittest.TestCase): with fluid.program_guard(train_program, startup_program): fill_value = 2.0 input = fluid.data(name='input', dtype='float32', shape=[2, 3]) - output = paddle.full_like(input, fill_value) + output = fluid.layers.full_like(input, fill_value) place = fluid.CPUPlace() if fluid.core.is_compiled_with_cuda(): @@ -132,20 +132,20 @@ class TestFillAnyLikeOpError(unittest.TestCase): #for ci coverage input_data = fluid.data(name='input', dtype='float32', shape=[2, 3]) - output = paddle.full_like(input_data, 2.0) + output = fluid.layers.full_like(input_data, 2.0) def test_input_dtype(): - paddle.full_like + fluid.layers.full_like self.assertRaises( ValueError, - paddle.full_like, + fluid.layers.full_like, input=input_data, fill_value=2, dtype='uint4') self.assertRaises( TypeError, - paddle.full_like, + fluid.layers.full_like, input=input_data, fill_value=2, dtype='int16') diff --git a/python/paddle/fluid/tests/unittests/test_flip.py b/python/paddle/fluid/tests/unittests/test_flip.py index 77e416e5e6a..20d5b59d75a 100644 --- a/python/paddle/fluid/tests/unittests/test_flip.py +++ b/python/paddle/fluid/tests/unittests/test_flip.py @@ -16,7 +16,6 @@ from __future__ import print_function import unittest import numpy as np -import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid import Program, program_guard @@ -32,7 +31,7 @@ class TestFlipOp_API(unittest.TestCase): with fluid.program_guard(train_program, startup_program): dims = [0] input = fluid.data(name='input', dtype='float32', shape=[2, 3]) - output = paddle.flip(input, dims) + output = fluid.layers.flip(input, dims) place = fluid.CPUPlace() if fluid.core.is_compiled_with_cuda(): place = fluid.CUDAPlace(0) @@ -52,7 +51,7 @@ class TestFlipOp_API(unittest.TestCase): img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) with fluid.dygraph.guard(): inputs = fluid.dygraph.to_variable(img) - ret = paddle.flip(inputs, [0]) + ret = fluid.layers.flip(inputs, [0]) out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32) self.assertTrue( (ret.numpy() == out_ref).all(), diff --git a/python/paddle/fluid/tests/unittests/test_full_op.py b/python/paddle/fluid/tests/unittests/test_full_op.py index 29d5be1ea42..7ea8a958d45 100644 --- a/python/paddle/fluid/tests/unittests/test_full_op.py +++ b/python/paddle/fluid/tests/unittests/test_full_op.py @@ -21,7 +21,6 @@ from op_test import OpTest import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid -import paddle from paddle.fluid import compiler, Program, program_guard @@ -37,39 +36,39 @@ class TestFullAPI(unittest.TestCase): shape_tensor_int64 = fluid.data( name="shape_tensor_int64", shape=[2], dtype="int64") - out_1 = paddle.full( + out_1 = fluid.layers.full( shape=[1, 2], dtype="float32", fill_value=1.1, device='gpu') - out_2 = paddle.full( + out_2 = fluid.layers.full( shape=[1, positive_2_int32], dtype="float32", fill_value=1.1, device='cpu') - out_3 = paddle.full( + out_3 = fluid.layers.full( shape=[1, positive_2_int64], dtype="float32", fill_value=1.1, device='gpu') - out_4 = paddle.full( + out_4 = fluid.layers.full( shape=shape_tensor_int32, dtype="float32", fill_value=1.2, out=out_3) - out_5 = paddle.full( + out_5 = fluid.layers.full( shape=shape_tensor_int64, dtype="float32", fill_value=1.1, device='gpu', stop_gradient=False) - out_6 = paddle.full( + out_6 = fluid.layers.full( shape=shape_tensor_int64, dtype=np.float32, fill_value=1.1) val = fluid.layers.fill_constant(shape=[1], dtype=np.float32, value=1.1) - out_7 = paddle.full( + out_7 = fluid.layers.full( shape=shape_tensor_int64, dtype=np.float32, fill_value=val) exe = fluid.Executor(place=fluid.CPUPlace()) @@ -97,17 +96,21 @@ class TestFullOpError(unittest.TestCase): x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16") x2 = np.random.randn(1, 2).astype('int32') self.assertRaises( - ValueError, paddle.full, shape=[1], fill_value=5, dtype='uint4') + ValueError, + fluid.layers.full, + shape=[1], + fill_value=5, + dtype='uint4') self.assertRaises( TypeError, - paddle.full, + fluid.layers.full, shape=[1], fill_value=5, dtype='int32', out=x2) self.assertRaises( TypeError, - paddle.full, + fluid.layers.full, shape=[1], fill_value=5, dtype='int16', @@ -118,17 +121,21 @@ class TestFullOpError(unittest.TestCase): x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32") self.assertRaises( - TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint8') + TypeError, + fluid.layers.full, + shape=[1], + fill_value=5, + dtype='uint8') # The argument shape's type of full_op must be list, tuple or Variable. def test_shape_type(): - paddle.full(shape=1, dtype="float32", fill_value=1) + fluid.layers.full(shape=1, dtype="float32", fill_value=1) self.assertRaises(TypeError, test_shape_type) # The argument shape's size of full_op must not be 0. def test_shape_size(): - paddle.full(shape=[], dtype="float32", fill_value=1) + fluid.layers.full(shape=[], dtype="float32", fill_value=1) self.assertRaises(AssertionError, test_shape_size) @@ -136,14 +143,15 @@ class TestFullOpError(unittest.TestCase): def test_shape_tensor_dtype(): shape = fluid.data( name="shape_tensor", shape=[2], dtype="float32") - paddle.full(shape=shape, dtype="float32", fill_value=1) + fluid.layers.full(shape=shape, dtype="float32", fill_value=1) self.assertRaises(TypeError, test_shape_tensor_dtype) def test_shape_tensor_list_dtype(): shape = fluid.data( name="shape_tensor_list", shape=[1], dtype="bool") - paddle.full(shape=[shape, 2], dtype="float32", fill_value=1) + fluid.layers.full( + shape=[shape, 2], dtype="float32", fill_value=1) self.assertRaises(TypeError, test_shape_tensor_list_dtype) diff --git a/python/paddle/fluid/tests/unittests/test_l1_loss.py b/python/paddle/fluid/tests/unittests/test_l1_loss.py index d7e801a666f..5c9eeb21c7b 100644 --- a/python/paddle/fluid/tests/unittests/test_l1_loss.py +++ b/python/paddle/fluid/tests/unittests/test_l1_loss.py @@ -33,7 +33,7 @@ class TestL1Loss(unittest.TestCase): name='input', shape=[10, 1], dtype='float32') label = fluid.layers.data( name='label', shape=[10, 1], dtype='float32') - l1_loss = paddle.nn.loss.L1Loss() + l1_loss = fluid.dygraph.L1Loss() ret = l1_loss(input, label) exe = fluid.Executor(place) @@ -44,7 +44,7 @@ class TestL1Loss(unittest.TestCase): fetch_list=[ret]) with fluid.dygraph.guard(): - l1_loss = paddle.nn.loss.L1Loss() + l1_loss = fluid.dygraph.L1Loss() dy_ret = l1_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -68,7 +68,7 @@ class TestL1Loss(unittest.TestCase): name='input', shape=[10, 10, 5], dtype='float32') label = fluid.layers.data( name='label', shape=[10, 10, 5], dtype='float32') - l1_loss = paddle.nn.loss.L1Loss(reduction='sum') + l1_loss = fluid.dygraph.L1Loss(reduction='sum') ret = l1_loss(input, label) exe = fluid.Executor(place) @@ -79,7 +79,7 @@ class TestL1Loss(unittest.TestCase): fetch_list=[ret]) with fluid.dygraph.guard(): - l1_loss = paddle.nn.loss.L1Loss(reduction='sum') + l1_loss = fluid.dygraph.L1Loss(reduction='sum') dy_ret = l1_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -103,7 +103,7 @@ class TestL1Loss(unittest.TestCase): name='input', shape=[10, 5], dtype='float32') label = fluid.layers.data( name='label', shape=[10, 5], dtype='float32') - l1_loss = paddle.nn.loss.L1Loss(reduction='none') + l1_loss = fluid.dygraph.L1Loss(reduction='none') ret = l1_loss(input, label) exe = fluid.Executor(place) @@ -114,7 +114,7 @@ class TestL1Loss(unittest.TestCase): fetch_list=[ret]) with fluid.dygraph.guard(): - l1_loss = paddle.nn.loss.L1Loss(reduction='none') + l1_loss = fluid.dygraph.L1Loss(reduction='none') dy_ret = l1_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) diff --git a/python/paddle/fluid/tests/unittests/test_log_softmax.py b/python/paddle/fluid/tests/unittests/test_log_softmax.py index 2b77624734d..5117208f225 100644 --- a/python/paddle/fluid/tests/unittests/test_log_softmax.py +++ b/python/paddle/fluid/tests/unittests/test_log_softmax.py @@ -17,7 +17,6 @@ import numpy as np import paddle.fluid as fluid import paddle.fluid.core as core import paddle.nn as nn -import paddle.nn.functional as functional def stable_softmax(x): @@ -84,14 +83,14 @@ class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase): mylogsoftmax = nn.LogSoftmax(axis) with fluid.program_guard(main_program): x = fluid.data(name='x', shape=self.x_shape) - y = functional.log_softmax(x, axis, dtype) + y = fluid.layers.log_softmax(x, axis, dtype) exe = fluid.Executor(place) out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], ref_out)) with fluid.dygraph.guard(place): x = fluid.dygraph.to_variable(self.x) - y = functional.log_softmax(x, axis, dtype) + y = fluid.layers.log_softmax(x, axis, dtype) self.assertTrue(np.allclose(y.numpy(), ref_out)) def test_check_api(self): diff --git a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py index eea1ca3282c..17d9109d83a 100644 --- a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py +++ b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py @@ -18,8 +18,8 @@ import unittest import numpy as np from op_test import OpTest, skip_check_grad_ci import paddle.fluid as fluid -import paddle from paddle.fluid import compiler, Program, program_guard, core +import paddle class TestMeshgridOp(OpTest): @@ -79,7 +79,7 @@ class TestMeshgridOp3(unittest.TestCase): out_2 = np.broadcast_to(out_2, [100, 200]) exe = fluid.Executor(place=fluid.CPUPlace()) - grid_x, grid_y = paddle.tensor.meshgrid([x, y]) + grid_x, grid_y = paddle.meshgrid([x, y]) res_1, res_2 = exe.run(fluid.default_main_program(), feed={'x': input_1, 'y': input_2}, @@ -95,7 +95,7 @@ class TestMeshgridOp4(unittest.TestCase): def test_input_type(): x = fluid.data(shape=[200], dtype='float32', name='x2') - paddle.tensor.meshgrid(x) + paddle.meshgrid(x) self.assertRaises(TypeError, test_input_type) @@ -108,7 +108,7 @@ class TestMeshgridOp5(unittest.TestCase): with fluid.dygraph.guard(): tensor_3 = fluid.dygraph.to_variable(input_3) tensor_4 = fluid.dygraph.to_variable(input_4) - res_3, res_4 = paddle.tensor.meshgrid([tensor_3, tensor_4]) + res_3, res_4 = paddle.meshgrid([tensor_3, tensor_4]) assert np.array_equal(res_3.shape, [100, 200]) assert np.array_equal(res_4.shape, [100, 200]) diff --git a/python/paddle/fluid/tests/unittests/test_mse_loss.py b/python/paddle/fluid/tests/unittests/test_mse_loss.py index 89052396cf9..1f93bd2c688 100644 --- a/python/paddle/fluid/tests/unittests/test_mse_loss.py +++ b/python/paddle/fluid/tests/unittests/test_mse_loss.py @@ -78,7 +78,7 @@ class TestNNMseLoss(unittest.TestCase): name='input', shape=dim, dtype='float32') label = fluid.layers.data( name='label', shape=dim, dtype='float32') - mse_loss = paddle.nn.loss.MSELoss() + mse_loss = fluid.dygraph.MSELoss() ret = mse_loss(input, label) exe = fluid.Executor(place) @@ -89,7 +89,7 @@ class TestNNMseLoss(unittest.TestCase): fetch_list=[ret]) with fluid.dygraph.guard(): - mse_loss = paddle.nn.loss.MSELoss() + mse_loss = fluid.dygraph.MSELoss() dy_ret = mse_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -115,7 +115,7 @@ class TestNNMseLoss(unittest.TestCase): name='input', shape=dim, dtype='float32') label = fluid.layers.data( name='label', shape=dim, dtype='float32') - mse_loss = paddle.nn.loss.MSELoss(reduction='sum') + mse_loss = fluid.dygraph.MSELoss(reduction='sum') ret = mse_loss(input, label) exe = fluid.Executor(place) @@ -126,7 +126,7 @@ class TestNNMseLoss(unittest.TestCase): fetch_list=[ret]) with fluid.dygraph.guard(): - mse_loss = paddle.nn.loss.MSELoss(reduction='sum') + mse_loss = fluid.dygraph.MSELoss(reduction='sum') dy_ret = mse_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -152,7 +152,7 @@ class TestNNMseLoss(unittest.TestCase): name='input', shape=dim, dtype='float32') label = fluid.layers.data( name='label', shape=dim, dtype='float32') - mse_loss = paddle.nn.loss.MSELoss(reduction='none') + mse_loss = fluid.dygraph.MSELoss(reduction='none') ret = mse_loss(input, label) exe = fluid.Executor(place) @@ -163,7 +163,7 @@ class TestNNMseLoss(unittest.TestCase): fetch_list=[ret]) with fluid.dygraph.guard(): - mse_loss = paddle.nn.loss.MSELoss(reduction='none') + mse_loss = fluid.dygraph.MSELoss(reduction='none') dy_ret = mse_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) diff --git a/python/paddle/fluid/tests/unittests/test_nll_loss.py b/python/paddle/fluid/tests/unittests/test_nll_loss.py index b14e3a15d97..36c137e3100 100644 --- a/python/paddle/fluid/tests/unittests/test_nll_loss.py +++ b/python/paddle/fluid/tests/unittests/test_nll_loss.py @@ -82,7 +82,7 @@ class TestNLLLoss(unittest.TestCase): with fluid.program_guard(prog, startup_prog): input = fluid.data(name='input', shape=[10, 10], dtype='float64') label = fluid.data(name='label', shape=[10], dtype='int64') - nll_loss = paddle.nn.loss.NLLLoss() + nll_loss = fluid.dygraph.NLLLoss() res = nll_loss(input, label) exe = fluid.Executor(place) @@ -93,7 +93,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss() + nll_loss = fluid.dygraph.NLLLoss() dy_res = nll_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -115,7 +115,7 @@ class TestNLLLoss(unittest.TestCase): with fluid.program_guard(prog, startup_prog): input = fluid.data(name='input', shape=[10, 10], dtype='float64') label = fluid.data(name='label', shape=[10], dtype='int64') - nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') + nll_loss = fluid.dygraph.NLLLoss(reduction='sum') res = nll_loss(input, label) exe = fluid.Executor(place) @@ -126,7 +126,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') + nll_loss = fluid.dygraph.NLLLoss(reduction='sum') dy_res = nll_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -150,7 +150,7 @@ class TestNLLLoss(unittest.TestCase): input = fluid.data(name='input', shape=[10, 10], dtype='float64') label = fluid.data(name='label', shape=[10], dtype='int64') weight = fluid.data(name='weight', shape=[10], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + nll_loss = fluid.dygraph.NLLLoss(weight=weight) res = nll_loss(input, label) exe = fluid.Executor(place) @@ -163,7 +163,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np)) dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -188,7 +188,7 @@ class TestNLLLoss(unittest.TestCase): input = fluid.data(name='input', shape=[10, 10], dtype='float64') label = fluid.data(name='label', shape=[10], dtype='int64') weight = fluid.data(name='weight', shape=[10], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum') + nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='sum') res = nll_loss(input, label) exe = fluid.Executor(place) @@ -201,7 +201,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np), reduction='sum') dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -225,7 +225,7 @@ class TestNLLLoss(unittest.TestCase): input = fluid.data(name='input', shape=[10, 10], dtype='float64') label = fluid.data(name='label', shape=[10], dtype='int64') weight = fluid.data(name='weight', shape=[10], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + nll_loss = fluid.dygraph.NLLLoss(weight=weight) res = nll_loss(input, label) exe = fluid.Executor(place) @@ -238,7 +238,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np)) dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -261,7 +261,7 @@ class TestNLLLoss(unittest.TestCase): input = fluid.data(name='input', shape=[10, 10], dtype='float64') label = fluid.data(name='label', shape=[10], dtype='int64') weight = fluid.data(name='weight', shape=[10], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none') + nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='none') res = nll_loss(input, label) exe = fluid.Executor(place) @@ -274,7 +274,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np), reduction='none') dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -299,7 +299,7 @@ class TestNLLLoss(unittest.TestCase): input = fluid.data( name='input', shape=[5, 3, 5, 5], dtype='float64') label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') - nll_loss = paddle.nn.loss.NLLLoss() + nll_loss = fluid.dygraph.NLLLoss() res = nll_loss(input, label) exe = fluid.Executor(place) @@ -310,7 +310,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss() + nll_loss = fluid.dygraph.NLLLoss() dy_res = nll_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -334,7 +334,7 @@ class TestNLLLoss(unittest.TestCase): input = fluid.data( name='input', shape=[5, 3, 5, 5], dtype='float64') label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') - nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') + nll_loss = fluid.dygraph.NLLLoss(reduction='sum') res = nll_loss(input, label) exe = fluid.Executor(place) @@ -345,7 +345,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss(reduction='sum') + nll_loss = fluid.dygraph.NLLLoss(reduction='sum') dy_res = nll_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -372,7 +372,7 @@ class TestNLLLoss(unittest.TestCase): label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') weight = fluid.data(name='weight', shape=[3], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + nll_loss = fluid.dygraph.NLLLoss(weight=weight) res = nll_loss(input, label) exe = fluid.Executor(place) @@ -385,7 +385,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np)) dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -411,7 +411,7 @@ class TestNLLLoss(unittest.TestCase): label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') weight = fluid.data(name='weight', shape=[3], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + nll_loss = fluid.dygraph.NLLLoss(weight=weight) res = nll_loss(input, label) exe = fluid.Executor(place) @@ -424,7 +424,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np)) dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -452,7 +452,7 @@ class TestNLLLoss(unittest.TestCase): label = fluid.data(name='label', shape=[5, 5, 5], dtype='int64') weight = fluid.data(name='weight', shape=[3], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum') + nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='sum') res = nll_loss(input, label) exe = fluid.Executor(place) @@ -465,7 +465,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np), reduction='sum') dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -491,7 +491,7 @@ class TestNLLLoss(unittest.TestCase): input = fluid.data( name='input', shape=[5, 3, 5, 5, 5], dtype='float64') label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') - nll_loss = paddle.nn.loss.NLLLoss() + nll_loss = fluid.dygraph.NLLLoss() res = nll_loss(input, label) exe = fluid.Executor(place) @@ -502,7 +502,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss() + nll_loss = fluid.dygraph.NLLLoss() dy_res = nll_loss( fluid.dygraph.to_variable(input_np), fluid.dygraph.to_variable(label_np)) @@ -533,7 +533,7 @@ class TestNLLLoss(unittest.TestCase): name='input', shape=[5, 3, 5, 5, 5], dtype='float64') label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') weight = fluid.data(name='weight', shape=[3], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight) + nll_loss = fluid.dygraph.NLLLoss(weight=weight) res = nll_loss(input, label) exe = fluid.Executor(place) @@ -546,7 +546,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np)) dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -579,7 +579,7 @@ class TestNLLLoss(unittest.TestCase): name='input', shape=[5, 3, 5, 5, 5], dtype='float64') label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') weight = fluid.data(name='weight', shape=[3], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='sum') + nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='sum') res = nll_loss(input, label) exe = fluid.Executor(place) @@ -592,7 +592,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np), reduction='sum') dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -628,7 +628,7 @@ class TestNLLLoss(unittest.TestCase): name='input', shape=[5, 3, 5, 5, 5], dtype='float64') label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') weight = fluid.data(name='weight', shape=[3], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none') + nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='none') res = nll_loss(input, label) exe = fluid.Executor(place) @@ -641,7 +641,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np), reduction='none') dy_res = nll_loss( fluid.dygraph.to_variable(input_np), @@ -676,7 +676,7 @@ class TestNLLLoss(unittest.TestCase): name='input', shape=[5, 3, 5, 5, 5], dtype='float64') label = fluid.data(name='label', shape=[5, 5, 5, 5], dtype='int64') weight = fluid.data(name='weight', shape=[3], dtype='float64') - nll_loss = paddle.nn.loss.NLLLoss(weight=weight, reduction='none') + nll_loss = fluid.dygraph.NLLLoss(weight=weight, reduction='none') res = nll_loss(input, label) exe = fluid.Executor(place) @@ -689,7 +689,7 @@ class TestNLLLoss(unittest.TestCase): fetch_list=[res]) with fluid.dygraph.guard(): - nll_loss = paddle.nn.loss.NLLLoss( + nll_loss = fluid.dygraph.NLLLoss( weight=fluid.dygraph.to_variable(weight_np), reduction='none') dy_res = nll_loss( fluid.dygraph.to_variable(input_np), diff --git a/python/paddle/fluid/tests/unittests/test_randint_op.py b/python/paddle/fluid/tests/unittests/test_randint_op.py index 40c9480a2c9..b896baa8bd1 100644 --- a/python/paddle/fluid/tests/unittests/test_randint_op.py +++ b/python/paddle/fluid/tests/unittests/test_randint_op.py @@ -22,7 +22,6 @@ import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid from paddle.fluid import Program, program_guard -import paddle def output_hist(out): @@ -62,17 +61,18 @@ class TestRandintOpError(unittest.TestCase): def test_shape(): shape = np.array([2, 3]) - paddle.randint(5, shape=shape, dtype='int32') + fluid.layers.randint(5, shape=shape, dtype='int32') self.assertRaises(TypeError, test_shape) def test_dtype(): - paddle.randint(5, shape=[32, 32], dtype='float32') + fluid.layers.randint(5, shape=[32, 32], dtype='float32') self.assertRaises(TypeError, test_dtype) def test_low_high(): - paddle.randint(low=5, high=5, shape=[32, 32], dtype='int32') + fluid.layers.randint( + low=5, high=5, shape=[32, 32], dtype='int32') self.assertRaises(ValueError, test_low_high) @@ -131,21 +131,21 @@ class TestRandintAPI(unittest.TestCase): train_program = fluid.Program() with fluid.program_guard(train_program, startup_program): # results are from [0, 5). - output1 = paddle.randint(5) + output1 = fluid.layers.randint(5) # shape is a list and dtype is 'int32' - output2 = paddle.randint( + output2 = fluid.layers.randint( low=-100, high=100, shape=[64, 64], dtype='int32') # shape is a tuple and dtype is 'int64' - output3 = paddle.randint( + output3 = fluid.layers.randint( low=-100, high=100, shape=(32, 32, 3), dtype='int64') # shape is a tensorlist and dtype is 'float32' dim_1 = fluid.layers.fill_constant([1], "int64", 32) dim_2 = fluid.layers.fill_constant([1], "int32", 50) - output4 = paddle.randint( + output4 = fluid.layers.randint( low=-100, high=100, shape=[dim_1, 5], dtype='int32') # shape is a tensor and dtype is 'float64' var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64") - output5 = paddle.randint( + output5 = fluid.layers.randint( low=1, high=1000, shape=var_shape, dtype='int64') place = fluid.CPUPlace() @@ -163,7 +163,7 @@ class TestRandintAPI(unittest.TestCase): class TestRandintDygraphMode(unittest.TestCase): def test_check_output(self): with fluid.dygraph.guard(): - x = paddle.randint(10, shape=[10], dtype="int32") + x = fluid.layers.randint(10, shape=[10], dtype="int32") x_np = x.numpy() for i in range(10): self.assertTrue((x_np[i] >= 0 and x_np[i] < 10)) diff --git a/python/paddle/fluid/tests/unittests/test_randn_op.py b/python/paddle/fluid/tests/unittests/test_randn_op.py index 808e5a08fd6..3f552b93f1e 100644 --- a/python/paddle/fluid/tests/unittests/test_randn_op.py +++ b/python/paddle/fluid/tests/unittests/test_randn_op.py @@ -16,7 +16,6 @@ from __future__ import print_function import unittest import numpy as np -import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid import Program, program_guard @@ -24,14 +23,16 @@ from paddle.fluid import Program, program_guard class TestRandnOp(unittest.TestCase): def test_api(self): - x1 = paddle.randn(shape=[1000, 784], dtype='float32') - x2 = paddle.randn(shape=[1000, 784], dtype='float64') + x1 = fluid.layers.randn(shape=[1000, 784], dtype='float32') + x2 = fluid.layers.randn(shape=[1000, 784], dtype='float64') x3 = fluid.layers.fill_constant( shape=[1000, 784], dtype='float32', value=0) - paddle.randn(shape=[1000, 784], out=x3, dtype='float32') - x4 = paddle.randn(shape=[1000, 784], dtype='float32', device='cpu') - x5 = paddle.randn(shape=[1000, 784], dtype='float32', device='gpu') - x6 = paddle.randn( + fluid.layers.randn(shape=[1000, 784], out=x3, dtype='float32') + x4 = fluid.layers.randn( + shape=[1000, 784], dtype='float32', device='cpu') + x5 = fluid.layers.randn( + shape=[1000, 784], dtype='float32', device='gpu') + x6 = fluid.layers.randn( shape=[1000, 784], dtype='float32', device='gpu', @@ -64,43 +65,43 @@ class TestRandnOpError(unittest.TestCase): # The argument shape's size of randn_op should not be 0. def test_shape_size(): - out = paddle.randn(shape=[]) + out = fluid.layers.randn(shape=[]) self.assertRaises(AssertionError, test_shape_size) # The argument shape's type of randn_op should be list or tuple. def test_shape_type(): - out = paddle.randn(shape=1) + out = fluid.layers.randn(shape=1) self.assertRaises(TypeError, test_shape_type) # The argument dtype of randn_op should be float32 or float64. def test_dtype_float16(): - out = paddle.randn(shape=[1, 2], dtype='float16') + out = fluid.layers.randn(shape=[1, 2], dtype='float16') self.assertRaises(TypeError, test_dtype_float16) # The argument dtype of randn_op should be float32 or float64. def test_dtype_int32(): - out = paddle.randn(shape=[1, 2], dtype='int32') + out = fluid.layers.randn(shape=[1, 2], dtype='int32') self.assertRaises(TypeError, test_dtype_int32) # The argument dtype of randn_op should be float32 or float64. def test_dtype_int64(): - out = paddle.randn(shape=[1, 2], dtype='int64') + out = fluid.layers.randn(shape=[1, 2], dtype='int64') self.assertRaises(TypeError, test_dtype_int64) # The argument dtype of randn_op should be float32 or float64. def test_dtype_uint8(): - out = paddle.randn(shape=[1, 2], dtype='uint8') + out = fluid.layers.randn(shape=[1, 2], dtype='uint8') self.assertRaises(TypeError, test_dtype_uint8) # The argument dtype of randn_op should be float32 or float64. def test_dtype_bool(): - out = paddle.randn(shape=[1, 2], dtype='bool') + out = fluid.layers.randn(shape=[1, 2], dtype='bool') self.assertRaises(TypeError, test_dtype_bool) diff --git a/python/paddle/fluid/tests/unittests/test_randperm_op.py b/python/paddle/fluid/tests/unittests/test_randperm_op.py index 2fbdc83f3ab..23756cf3c65 100644 --- a/python/paddle/fluid/tests/unittests/test_randperm_op.py +++ b/python/paddle/fluid/tests/unittests/test_randperm_op.py @@ -15,7 +15,6 @@ import unittest import numpy as np from op_test import OpTest -import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid.op import Operator @@ -120,12 +119,12 @@ class TestRandpermOpError(unittest.TestCase): def test_Variable(): out = np.arange(10) - paddle.randperm(n=10, out=out) + fluid.layers.randperm(n=10, out=out) self.assertRaises(TypeError, test_Variable) def test_value(): - paddle.randperm(n=-3) + fluid.layers.randperm(n=-3) self.assertRaises(ValueError, test_value) @@ -139,9 +138,9 @@ class TestRandpermOp_attr_out(unittest.TestCase): with fluid.program_guard(train_program, startup_program): n = 10 data_1 = fluid.layers.fill_constant([n], "int64", 3) - paddle.randperm(n=n, out=data_1) + fluid.layers.randperm(n=n, out=data_1) - data_2 = paddle.randperm(n=n, dtype="int32", device="cpu") + data_2 = fluid.layers.randperm(n=n, dtype="int32", device="cpu") place = fluid.CPUPlace() if fluid.core.is_compiled_with_cuda(): @@ -160,12 +159,12 @@ class TestRandpermDygraphMode(unittest.TestCase): def test_check_output(self): with fluid.dygraph.guard(): n = 10 - data_1 = paddle.randperm(n, dtype="int64") + data_1 = fluid.layers.randperm(n, dtype="int64") data_1_np = data_1.numpy() self.assertTrue( check_randperm_out(n, data_1_np), msg=error_msg(data_1_np)) - data_2 = paddle.randperm(n, dtype="int32", device="cpu") + data_2 = fluid.layers.randperm(n, dtype="int32", device="cpu") data_2_np = data_2.numpy() self.assertTrue( check_randperm_out(n, data_2_np), msg=error_msg(data_2_np)) diff --git a/python/paddle/fluid/tests/unittests/test_roll_op.py b/python/paddle/fluid/tests/unittests/test_roll_op.py index d05fc45928f..169e1057b14 100644 --- a/python/paddle/fluid/tests/unittests/test_roll_op.py +++ b/python/paddle/fluid/tests/unittests/test_roll_op.py @@ -15,7 +15,6 @@ from __future__ import print_function import unittest -import paddle import numpy as np import paddle.fluid.core as core from op_test import OpTest @@ -66,7 +65,7 @@ class TestRollAPI(unittest.TestCase): # case 1: with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[-1, 3]) - z = paddle.roll(x, shifts=1) + z = fluid.layers.roll(x, shifts=1) exe = fluid.Executor(fluid.CPUPlace()) res, = exe.run(feed={'x': self.data_x}, fetch_list=[z.name], @@ -78,7 +77,7 @@ class TestRollAPI(unittest.TestCase): # case 2: with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[-1, 3]) - z = paddle.roll(x, shifts=1, dims=0) + z = fluid.layers.roll(x, shifts=1, dims=0) exe = fluid.Executor(fluid.CPUPlace()) res, = exe.run(feed={'x': self.data_x}, fetch_list=[z.name], @@ -92,7 +91,7 @@ class TestRollAPI(unittest.TestCase): # case 1: with fluid.dygraph.guard(): x = fluid.dygraph.to_variable(self.data_x) - z = paddle.roll(x, shifts=1) + z = fluid.layers.roll(x, shifts=1) np_z = z.numpy() expect_out = np.array([[9.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]) @@ -101,7 +100,7 @@ class TestRollAPI(unittest.TestCase): # case 2: with fluid.dygraph.guard(): x = fluid.dygraph.to_variable(self.data_x) - z = paddle.roll(x, shifts=1, dims=0) + z = fluid.layers.roll(x, shifts=1, dims=0) np_z = z.numpy() expect_out = np.array([[7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) diff --git a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py index 84dc7bd8a96..a9c3d0422cd 100644 --- a/python/paddle/fluid/tests/unittests/test_tril_triu_op.py +++ b/python/paddle/fluid/tests/unittests/test_tril_triu_op.py @@ -17,7 +17,6 @@ import unittest import numpy as np from op_test import OpTest import paddle.fluid as fluid -import paddle.tensor as tensor class TrilTriuOpDefaultTest(OpTest): @@ -71,7 +70,7 @@ def case_generator(op_type, Xshape, diagonal, expected): data = fluid.data(shape=Xshape, dtype='float64', name=cls_name) with self.assertRaisesRegexp( eval(expected.split(':')[-1]), errmsg[expected]): - getattr(tensor, op_type)(input=data, diagonal=diagonal) + getattr(fluid.layers, op_type)(input=data, diagonal=diagonal) class SuccessCase(TrilTriuOpDefaultTest): def initTestCase(self): @@ -122,7 +121,7 @@ class TestTrilTriuOpAPI(unittest.TestCase): def test_api(self): data = np.random.random([1, 9, 9, 4]).astype('float32') x = fluid.data(shape=[1, 9, -1, 4], dtype='float32', name='x') - tril_out, triu_out = tensor.tril(x), tensor.triu(x) + tril_out, triu_out = fluid.layers.tril(x), fluid.layers.triu(x) place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( ) else fluid.CPUPlace() -- GitLab