From 570cf1d394011ffc3eb8a0e26814d93e16d45bad Mon Sep 17 00:00:00 2001 From: Li Fuchen Date: Fri, 24 Apr 2020 18:50:23 +0800 Subject: [PATCH] modified the example of diag_embed english doc, test=develop (#24012) (#24134) --- python/paddle/nn/functional/extension.py | 65 +++++++++++++++++++----- 1 file changed, 52 insertions(+), 13 deletions(-) diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 77687041083..112212bfab7 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -46,27 +46,68 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1): This OP creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by ``input``. By default, a 2D plane formed by the last two dimensions of the returned tensor will be selected. + The argument ``offset`` determines which diagonal is generated: + - If offset = 0, it is the main diagonal. - If offset > 0, it is above the main diagonal. - If offset < 0, it is below the main diagonal. + Args: input(Variable|numpy.ndarray): The input tensor. Must be at least 1-dimensional. The input data type should be float32, float64, int32, int64. offset(int, optional): Which diagonal to consider. Default: 0 (main diagonal). dim1(int, optional): The first dimension with respect to which to take diagonal. Default: -2. dim2(int, optional): The second dimension with respect to which to take diagonal. Default: -1. + Returns: Variable, the output data type is the same as input data type. + Examples: .. code-block:: python + import paddle.nn.functional as F import paddle.fluid.dygraph as dg import numpy as np diag_embed = np.random.randn(2, 3).astype('float32') + # [[ 0.7545889 , -0.25074545, 0.5929117 ], + # [-0.6097662 , -0.01753256, 0.619769 ]] with dg.guard(): data1 = F.diag_embed(diag_embed) - data2 = F.diag_embed(diag_embed, offset=1, dim1=0, dim2=2) + data1.numpy() + # [[[ 0.7545889 , 0. , 0. ], + # [ 0. , -0.25074545, 0. ], + # [ 0. , 0. , 0.5929117 ]], + + # [[-0.6097662 , 0. , 0. ], + # [ 0. , -0.01753256, 0. ], + # [ 0. , 0. , 0.619769 ]]] + + data2 = F.diag_embed(diag_embed, offset=-1, dim1=0, dim2=2) + data2.numpy() + # [[[ 0. , 0. , 0. , 0. ], + # [ 0.7545889 , 0. , 0. , 0. ], + # [ 0. , -0.25074545, 0. , 0. ], + # [ 0. , 0. , 0.5929117 , 0. ]], + # + # [[ 0. , 0. , 0. , 0. ], + # [-0.6097662 , 0. , 0. , 0. ], + # [ 0. , -0.01753256, 0. , 0. ], + # [ 0. , 0. , 0.619769 , 0. ]]] + + data3 = F.diag_embed(diag_embed, offset=1, dim1=0, dim2=2) + data3.numpy() + # [[[ 0. , 0.7545889 , 0. , 0. ], + # [ 0. , -0.6097662 , 0. , 0. ]], + # + # [[ 0. , 0. , -0.25074545, 0. ], + # [ 0. , 0. , -0.01753256, 0. ]], + # + # [[ 0. , 0. , 0. , 0.5929117 ], + # [ 0. , 0. , 0. , 0.619769 ]], + # + # [[ 0. , 0. , 0. , 0. ], + # [ 0. , 0. , 0. , 0. ]]] """ inputs = {'Input': [input]} attrs = {'offset': offset, 'dim1': dim1, 'dim2': dim2} @@ -80,26 +121,24 @@ def diag_embed(input, offset=0, dim1=-2, dim2=-1): 'diag_embed') input_shape = list(input.shape) - assert (len(input_shape) >= 1, \ + assert len(input_shape) >= 1, \ "Input must be at least 1-dimensional, " \ "But received Input's dimensional: %s.\n" % \ - len(input_shape)) + len(input_shape) - assert ( - np.abs(dim1) <= len(input_shape), - "Dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" - % (-(len(input_shape) + 1), len(input_shape), dim1)) + assert np.abs(dim1) <= len(input_shape), \ + "Dim1 is out of range (expected to be in range of [%d, %d], but got %d).\n" \ + % (-(len(input_shape) + 1), len(input_shape), dim1) - assert ( - np.abs(dim2) <= len(input_shape), - "Dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" - % (-(len(input_shape) + 1), len(input_shape), dim2)) + assert np.abs(dim2) <= len(input_shape), \ + "Dim2 is out of range (expected to be in range of [%d, %d], but got %d).\n" \ + % (-(len(input_shape) + 1), len(input_shape), dim2) dim1_ = dim1 if dim1 >= 0 else len(input_shape) + dim1 + 1 dim2_ = dim2 if dim2 >= 0 else len(input_shape) + dim2 + 1 - assert ( dim1_ != dim2_, + assert dim1_ != dim2_, \ "dim1 and dim2 cannot be the same dimension." \ - "But received dim1 = %d, dim2 = %d\n"%(dim1, dim2)) + "But received dim1 = %d, dim2 = %d\n"%(dim1, dim2) if not in_dygraph_mode(): __check_input(input, offset, dim1, dim2) -- GitLab