diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec old mode 100644 new mode 100755 index 5746c7d0885eec18a27d86264b164f97e9dd6267..dfa480b6a365b115f794cad941b811004348cf0a --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -257,7 +257,7 @@ paddle.fluid.layers.uniform_random_batch_size_like (ArgSpec(args=['input', 'shap paddle.fluid.layers.gaussian_random (ArgSpec(args=['shape', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0, 'float32')), ('document', 'dd4ddb66c78a2564e5d1e0e345d8286f')) paddle.fluid.layers.sampling_id (ArgSpec(args=['x', 'min', 'max', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0.0, 1.0, 0, 'float32')), ('document', '2490492db3b41af9144bb1539e4e9116')) paddle.fluid.layers.gaussian_random_batch_size_like (ArgSpec(args=['input', 'shape', 'input_dim_idx', 'output_dim_idx', 'mean', 'std', 'seed', 'dtype'], varargs=None, keywords=None, defaults=(0, 0, 0.0, 1.0, 0, 'float32')), ('document', '2aed0f546f220364fb1da724a3176f74')) -paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'f4b60847cb0f1ae00823ba6fb1b11310')) +paddle.fluid.layers.sum (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '42c43fc74347bfe9528850aa7f59b2b2')) paddle.fluid.layers.slice (ArgSpec(args=['input', 'axes', 'starts', 'ends'], varargs=None, keywords=None, defaults=None), ('document', '8c622791994a0d657d8c6c9cefa5bf34')) paddle.fluid.layers.strided_slice (ArgSpec(args=['input', 'axes', 'starts', 'ends', 'strides'], varargs=None, keywords=None, defaults=None), ('document', '33b8dfd6708443ae93f1a0016ff6a5ef')) paddle.fluid.layers.shape (ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None), ('document', '39534cccdb8e727e287316c7c42e6663')) @@ -288,7 +288,7 @@ paddle.fluid.layers.get_tensor_from_selected_rows (ArgSpec(args=['x', 'name'], v paddle.fluid.layers.lstm (ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)), ('document', '5193cf1113f9d8d8f682ee5a5fc8b391')) paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '276a1213dd431228cefa33c3146df34a')) paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', 'd5945431cdcae3cda21914db5bbf383e')) -paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '8404e472ac12b4a30a505d3d3a3e5fdb')) +paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '231f91231430f5dae2b757df22317c67')) paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '9bf0cc6b0717010b8ceec5dc2541d566')) paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '454c7ea8c73313dd41513929d7526303')) paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', 'b0e07aa41caae04b07a8e8217cc96020')) @@ -347,7 +347,7 @@ paddle.fluid.layers.Switch.__init__ (ArgSpec(args=['self', 'name'], varargs=None paddle.fluid.layers.Switch.case (ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.Switch.default (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.increment (ArgSpec(args=['x', 'value', 'in_place'], varargs=None, keywords=None, defaults=(1.0, True)), ('document', 'f88b5787bb80ae6b8bf513a70dabbdc1')) -paddle.fluid.layers.array_write (ArgSpec(args=['x', 'i', 'array'], varargs=None, keywords=None, defaults=(None,)), ('document', '3f913b5069ad40bd85d89b33e4aa5939')) +paddle.fluid.layers.array_write (ArgSpec(args=['x', 'i', 'array'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd357f71a280bf06aab4c79de9bd4facf')) paddle.fluid.layers.create_array (ArgSpec(args=['dtype'], varargs=None, keywords=None, defaults=None), ('document', '556de793fdf24d515f3fc91260e2c048')) paddle.fluid.layers.less_than (ArgSpec(args=['x', 'y', 'force_cpu', 'cond'], varargs=None, keywords=None, defaults=(None, None)), ('document', '329bdde01cba69463b08b8c13015560a')) paddle.fluid.layers.less_equal (ArgSpec(args=['x', 'y', 'cond'], varargs=None, keywords=None, defaults=(None,)), ('document', '04e5623dd39b4437b9b08e0ce11071ca')) @@ -355,8 +355,8 @@ paddle.fluid.layers.greater_than (ArgSpec(args=['x', 'y', 'cond'], varargs=None, paddle.fluid.layers.greater_equal (ArgSpec(args=['x', 'y', 'cond'], varargs=None, keywords=None, defaults=(None,)), ('document', '44bdacd11299d72c0a52d2181e7ae6ca')) paddle.fluid.layers.equal (ArgSpec(args=['x', 'y', 'cond'], varargs=None, keywords=None, defaults=(None,)), ('document', '781eac1f980916c68623659f639e2b8c')) paddle.fluid.layers.not_equal (ArgSpec(args=['x', 'y', 'cond'], varargs=None, keywords=None, defaults=(None,)), ('document', '8b76aaac4ba7cf9111750b9c2c9418cb')) -paddle.fluid.layers.array_read (ArgSpec(args=['array', 'i'], varargs=None, keywords=None, defaults=None), ('document', 'caf0d94349cdc28e1bda3b8a19411ac0')) -paddle.fluid.layers.array_length (ArgSpec(args=['array'], varargs=None, keywords=None, defaults=None), ('document', '6f24a9b872027634ad758ea2826c9727')) +paddle.fluid.layers.array_read (ArgSpec(args=['array', 'i'], varargs=None, keywords=None, defaults=None), ('document', 'b75c821cc1d22355c3c17e7bdf509510')) +paddle.fluid.layers.array_length (ArgSpec(args=['array'], varargs=None, keywords=None, defaults=None), ('document', 'c90d305395eb44e6dc772fab24ff2ef5')) paddle.fluid.layers.IfElse ('paddle.fluid.layers.control_flow.IfElse', ('document', '720054043e55273224682fdb6b9ad13b')) paddle.fluid.layers.IfElse.__init__ (ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.IfElse.false_block (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc old mode 100644 new mode 100755 index 54281b1927d158bdf9876035c238dffacc053a40..27b66250d7397fba496683dc11221fca614788c1 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -175,19 +175,21 @@ class SumOp : public framework::OperatorWithKernel { class SumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(vector) The input tensors of sum operator.") + AddInput("X", + "A Varaible list. The shape and data type of the list elements" + "should be consistent. Variable can be multi-dimensional Tensor" + "or LoDTensor, and data types can be: float32, float64, int32, " + "int64.") .AsDuplicable(); - AddOutput("Out", "(Tensor) The output tensor of sum operator."); + AddOutput("Out", + "the sum of input :code:`x`. its shape and data types are " + "consistent with :code:`x`."); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); - AddComment(R"DOC( -Sum operator. - -This operators sums the input tensors. All the inputs can carry the -LoD (Level of Details) information. However, the output only shares -the LoD information with the first input. -)DOC"); + AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor + of the input. If the input is LoDTensor, the output only + shares LoD information with the first input.)DOC"); } }; diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py old mode 100644 new mode 100755 index e635b802aaaf1b01ae9b6f9dd555f2ddafbfa309..fca5a55bffb82612d01ca31e35cc16d9a9951409 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -935,31 +935,53 @@ def increment(x, value=1.0, in_place=True): def array_write(x, i, array=None): """ - This function writes the given input variable to the specified position - indicating by the arrary index to an output LOD_TENSOR_ARRAY. If the - output LOD_TENSOR_ARRAY is not given(None), a new one will be created and - returned. + This OP writes the input ``x`` into the i-th position of the ``array`` + :ref:`api_fluid_LoDTensorArray` and returns the modified array. + If ``array`` is none, a new LoDTensorArray will be created and returned. + This OP is often used together with :ref:`api_fluid_layers_array_read` OP. Args: - x (Variable|list): The input tensor from which the data will be read. - i (Variable|list): The index of the output LOD_TENSOR_ARRAY, pointing to - the position to which the input tensor will be - written. - array (Variable|list): The output LOD_TENSOR_ARRAY to which the input - tensor will be written. If this parameter is - NONE, a new LOD_TENSOR_ARRAY will be created and - returned. + x (Variable): The input data to be written into array. It's multi-dimensional + Tensor or LoDTensor. Data type: float32, float64, int32, int64. + i (Variable): 1-D Tensor with shape [1], which represents the position into which + ``x`` is written. Data type: int64. + array (LoDTensorArray, optional): The LoDTensorArray into which ``x`` is written. + The default value is None, when a new LoDTensorArray will be created and returned + as a result. Returns: - Variable: The output LOD_TENSOR_ARRAY where the input tensor is written. + Variable: The input ``array`` after ``x`` is written into. Examples: .. code-block:: python - import paddle.fluid as fluid - tmp = fluid.layers.zeros(shape=[10], dtype='int32') - i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10) - arr = fluid.layers.array_write(tmp, i=i) + import paddle.fluid as fluid + tmp = fluid.layers.fill_constant(shape=[3, 2], dtype='int64', value=5) + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10) + # Write tmp into the position of arr with subscript 10 and return arr. + arr = fluid.layers.array_write(tmp, i=i) + + # Now, arr is a LoDTensorArray with length 11. We can use array_read OP to read + # the data at subscript 10 and print it out. + item = fluid.layers.array_read(arr, i=i) + input = fluid.layers.Print(item, message="The content of i-th LoDTensor:") + main_program = fluid.default_main_program() + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(main_program) + + # The printed result is: + # 1570533133 The content of i-th LoDTensor: The place is:CPUPlace + # Tensor[array_read_0.tmp_0] + # shape: [3,2,] + # dtype: l + # data: 5,5,5,5,5,5, + + # the output is 2-D Tensor with shape [3,2], which is tmp above. + # dtype is the corresponding C++ data type, which may vary in different environments. + # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t, + # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, + # and '__int64' on Windows. They both represent 64-bit integer variables. + """ helper = LayerHelper('array_write', **locals()) if array is None: @@ -1265,38 +1287,66 @@ def not_equal(x, y, cond=None): def array_read(array, i): """ - This function performs the operation to read the data in as an - LOD_TENSOR_ARRAY. - - .. code-block:: text - - Given: - - array = [0.6, 0.1, 0.3, 0.1] - - And: - - i = 2 - - Then: - - output = 0.3 + This OP is used to read data at the specified position from the input array + :ref:`api_fluid_LoDTensorArray` . ``array`` is the input array and ``i`` + is the specified read position. This OP is often used together with + :ref:`api_fluid_layers_array_write` OP. + + Case 1: + :: + Input: + The shape of first three tensors are [1], and that of the last one is [1,2]: + array = ([0.6], [0.1], [0.3], [0.4, 0.2]) + And: + i = [3] + + Output: + output = [0.4, 0.2] Args: - array (Variable|list): The input tensor that store data to be read. - i (Variable|list): The index of the data to be read from input array. + array (LoDTensorArray): The input LoDTensorArray. + i (Variable): 1-D Tensor, whose shape is [1] and dtype is int64. It represents the + specified read position of ``array``. Returns: - Variable: The tensor type variable that has the data written to it. + Variable: The LoDTensor or Tensor that is read at the specified position of ``array``. Examples: .. code-block:: python - import paddle.fluid as fluid - array = fluid.layers.create_array(dtype='float32') - i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10) - item = fluid.layers.array_read(array, i) + # First we're going to create a LoDTensorArray, then we're going to write the Tensor into + # the specified position, and finally we're going to read the Tensor at that position. + import paddle.fluid as fluid + arr = fluid.layers.create_array(dtype='float32') + tmp = fluid.layers.fill_constant(shape=[3, 2], dtype='int64', value=5) + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10) + # tmp is the Tensor with shape [3,2], and if we write it into the position with subscript 10 + # of the empty-array: arr, then the length of arr becomes 11. + arr = fluid.layers.array_write(tmp, i, array=arr) + # Read the data of the position with subscript 10. + item = fluid.layers.array_read(arr, i) + + # You can print out the data via executor. + input = fluid.layers.Print(item, message="The LoDTensor of the i-th position:") + main_program = fluid.default_main_program() + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(main_program) + + # The printed result is: + + # 1569588169 The LoDTensor of the i-th position: The place is:CPUPlace + # Tensor[array_read_0.tmp_0] + # shape: [3,2,] + # dtype: l + # data: 5,5,5,5,5,5, + + # the output is 2-D Tensor with shape [3,2]. + # dtype is the corresponding C++ data type, which may vary in different environments. + # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t, + # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, + # and '__int64' on Windows. They both represent 64-bit integer variables. """ + helper = LayerHelper('array_read', **locals()) if not isinstance( array, @@ -1350,29 +1400,48 @@ def shrink_memory(x, i, table): def array_length(array): """ - **Get the Length of Input LoDTensorArray** - - This function performs the operation to find the length of the input - LOD_TENSOR_ARRAY. - - Related API: array_read, array_write, While. + This OP is used to get the length of the input array :ref:`api_fluid_LoDTensorArray` . + It can be used together with :ref:`api_fluid_layers_array_read` , :ref:`api_fluid_layers_array_write` , + :ref:`api_fluid_layers_While` OP to traverse, read and wirte LoDTensorArray. Args: - array (LOD_TENSOR_ARRAY): The input array that will be used - to compute the length. + array (LoDTensorArray): The input array that will be used to compute the length. Returns: - Variable: The length of the input LoDTensorArray. + Variable: 1-D Tensor with shape [1], which is the length of array. Datatype: int64. Examples: .. code-block:: python - import paddle.fluid as fluid - tmp = fluid.layers.zeros(shape=[10], dtype='int32') - i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10) - arr = fluid.layers.array_write(tmp, i=i) - arr_len = fluid.layers.array_length(arr) + import paddle.fluid as fluid + tmp = fluid.layers.zeros(shape=[10], dtype='int32') + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10) + # tmp is 1-D Tensor with shape [10]. We write tmp into arr on subscript 10, + # then the length of arr becomes 11. + arr = fluid.layers.array_write(tmp, i=i) + # return the length of arr + arr_len = fluid.layers.array_length(arr) + + # You can use executor to print out the length of LoDTensorArray. + input = fluid.layers.Print(arr_len, message="The length of LoDTensorArray:") + main_program = fluid.default_main_program() + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(main_program) + + # The printed result is: + # 1569576542 The length of LoDTensorArray: The place is:CPUPlace + # Tensor[array_length_0.tmp_0] + # shape: [1,] + # dtype: l + # data: 11, + + # 1-D Tensor with shape [1], whose value is 11. It means that the length of LoDTensorArray + # is 11. + # dtype is the corresponding C++ data type, which may vary in different environments. + # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t, + # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, + # and '__int64' on Windows. They both represent 64-bit integer variables. """ helper = LayerHelper('array_length', **locals()) tmp = helper.create_variable_for_type_inference(dtype='int64') diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1a21f9b51245d339e9f7b85357162730f0968245..7cfbffcefb45baec96cc614a81140f450c1b4c00 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12488,21 +12488,69 @@ def gaussian_random_batch_size_like(input, def sum(x): """ ${comment} + + Case 1: + :: + Input: + Input. Shape = [2, 3] + Input = [[1, 2, 3], + [4, 5, 6]] + + Output: + The output. Shape = [2, 3] + Output = [[1, 2, 3], + [4, 5, 6]] + + Case 2: + :: + Input: + First input: + Input1. Shape = [2, 3] + Input1 = [[1, 2, 3], + [4, 5, 6]] + + The second input: + Input2. Shape = [2, 3] + Input2 = [[7, 8, 9], + [10, 11, 12]] + + Output: + The output. Shape = [2, 3] + Output = [[8, 10, 12], + [14, 16, 18]] Args: - x (Variable): ${x_comment} + x (Variable|list(Variable)): ${x_comment} Returns: - out (Variable): ${out_comment} + Variable: ${out_comment} Examples: .. code-block:: python import paddle.fluid as fluid - import paddle.fluid.layers as layers - input0 = layers.data(name="input0", shape=[13, 11], dtype='float32') - input1 = layers.data(name="input1", shape=[13, 11], dtype='float32') - out = layers.sum([input0,input1]) + + input0 = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=5) + input1 = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=3) + sum = fluid.layers.sum([input0, input1]) + + # You can print out 'sum' via executor. + out = fluid.layers.Print(sum, message="the sum of input0 and input1: ") + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_main_program()) + + # The printed result is: + # 1570701754 the sum of input0 and input1: The place is:CPUPlace + # Tensor[sum_0.tmp_0] + # shape: [2,3,] + # dtype: l + # data: 8,8,8,8,8,8, + + # the sum of input0 and input1 is 2-D Tensor with shape [2,3]. + # dtype is the corresponding C++ data type, which may vary in different environments. + # Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t, + # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, + # and '__int64' on Windows. They both represent 64-bit integer variables. """ helper = LayerHelper('sum', **locals()) @@ -15095,85 +15143,90 @@ class PyFuncRegistry(object): @templatedoc() def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): """ - PyFunc Operator. - - User can use :code:`py_func` to register operators in Python side. - The inputs of :code:`func` is :code:`LoDTensor` and outputs can be - numpy array or :code:`LoDTensor`. Paddle would call the registered - :code:`func` in forward part, and call :code:`backward_func` in - backward part (if :code:`backward_func` is not None). + This API is used to register customized OP to Fluid. The forward function + of the registered OP is ``func`` and the backward function of that is + ``backward_func``. Paddle will call ``func`` at forward runtime and call + ``backward_func`` at backward runtime(if ``backward_func`` is not None). + ``x`` is the input of ``func``, whose type must be LoDTensor; ``out`` is + the output of ``func``, whose type can be either LoDTensor or NumPy array. - User should set the right data type and shape of :code:`out` before - calling this function. However, data types and shapes of gradients of - :code:`out` and :code:`x` would be inferred automatically. + The input of the backward function ``backward_func`` is ``x``, ``out`` and + the gradient of ``out``. If some variables of ``out`` have no gradient, the + relevant input variable of ``backward_func`` is None. If some variables of + ``x`` do not have a gradient, the user should return None in ``backward_func``. - Input orders of :code:`backward_func` would be: forward inputs - :code:`x`, forward outputs :code:`out` and backward input gradients of - :code:`out`. If some variables of :code:`out` have no gradient, the input - tensor would be None in Python side. If some variables of :code:`in` have - no gradient, users should return None. + The data type and shape of ``out`` should also be set correctly before this + API is called, and the data type and shape of the gradient of ``out`` and + ``x`` will be inferred automatically. - This function can also be used to debug the running network. User can - add a :code:`py_func` operator without output, and print input - :code:`x` inside :code:`func`. + This API can also be used to debug the neural network by setting the ``func`` + as a function that only print variables. Args: - func (callable): forward Python function. - x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`. - out (Variable|list(Variable)|tuple(Variable)): outputs of :code:`func`. - Paddle cannot infer shapes and data types of :code:`out`. Users - should create :code:`out` beforehand. - backward_func (callable|None): backward Python function. - None means no backward. Default None. - skip_vars_in_backward_input (Variable|list(Variable)|tuple(Variable)): - Variables that are not needed in :code:`backward_func` inputs. - These variables must be any of :code:`x` and :code:`out`. - If set, these vars would not be inputs of :code:`backward_func`, - Only useful when :code:`backward_func` is not None. Default None. - - Returns: - out (Variable|list(Variable)|tuple(Variable)): input :code:`out` + func (callable): The forward function of the registered OP. When the network + is running, the forward output ``out`` will be calculated according to this + function and the forward input ``x``. + x (Variable): The input of the forward function ``func``, its type can be + Variable | tuple[Variable] | list[Variale], in which Variable is LoDTensor. + out (Variable): The output of the forward function ``func``, its type can be + Variable | tuple[Variable] | list[Variale], in which Variable can be either + LoDTensor or NumPy array. Since Paddle cannot automatically infer the shape + and data type of ``out``, ``out`` must be created in advance. + backward_func (callable, optional): The backward function of the registered OP. + Its default value is None, which means there is no reverse calculation. If + it is not None, ``backward_func`` is called to calculate the gradient of + ``x`` when the network is at backward runtime. + skip_vars_in_backward_input (Variable, optional): It's used to limit the input + variable list of ``backward_func``, and it can be single Variable, tuple[Variable] + or list[Variable]. It must belong to either ``x`` or ``out``. The default + value is None, which means that no variables need to be removed from ``x`` + and ``out``. If it is not None, these variables will not be the input of + ``backward_func``. This parameter is only useful when ``backward_func`` is + not None. + + Returns: + Variable: The output ``out`` of the forward function ``func``. Examples: + .. code-block:: python - >>> import paddle.fluid as fluid - >>> import six - >>> - >>> def create_tmp_var(name, dtype, shape): - >>> return fluid.default_main_program().current_block().create_var( - >>> name=name, dtype=dtype, shape=shape) - >>> - >>> # tanh activation has been provided by Paddle C++ op - >>> # Here, we only use tanh to be an example to show the usage - >>> # of py_func - >>> def tanh(x): - >>> return np.tanh(x) - >>> - >>> # forward input x is skipped - >>> def tanh_grad(y, dy): - >>> return np.array(dy) * (1 - np.square(np.array(y))) - >>> - >>> def debug_func(x): - >>> print(x) - >>> - >>> def simple_net(img, label): - >>> hidden = img - >>> for idx in six.moves.range(4): - >>> hidden = fluid.layers.fc(hidden, size=200) - >>> new_hidden = create_tmp_var(name='hidden_{}'.format(idx), - >>> dtype=hidden.dtype, shape=hidden.shape) - >>> - >>> # user-defined layers with forward and backward - >>> hidden = fluid.layers.py_func(func=tanh, x=hidden, - >>> out=new_hidden, backward_func=tanh_grad, - >>> skip_vars_in_backward_input=hidden) - >>> - >>> # user-defined debug layers to print variables - >>> fluid.layers.py_func(func=debug_func, x=hidden, out=None) - >>> - >>> prediction = fluid.layers.fc(hidden, size=10, act='softmax') - >>> loss = fluid.layers.cross_entropy(input=prediction, label=label) - >>> return fluid.layers.mean(loss) + import paddle.fluid as fluid + import six + + def create_tmp_var(name, dtype, shape): + return fluid.default_main_program().current_block().create_var( + name=name, dtype=dtype, shape=shape) + + # Tanh activation function provided by Paddle C++ op + # Here, tanh is used as an example to show how to use py_func + def tanh(x): + return np.tanh(x) + + # Skip forward input x + def tanh_grad(y, dy): + return np.array(dy) * (1 - np.square(np.array(y))) + + def debug_func(x): + print(x) + + def simple_net(img, label): + hidden = img + for idx in six.moves.range(4): + hidden = fluid.layers.fc(hidden, size=200) + new_hidden = create_tmp_var(name='hidden_{}'.format(idx), + dtype=hidden.dtype, shape=hidden.shape) + + # User-defined forward and backward + hidden = fluid.layers.py_func(func=tanh, x=hidden, + out=new_hidden, backward_func=tanh_grad, + skip_vars_in_backward_input=hidden) + + # User-defined debugging layer, which can print out variable details + fluid.layers.py_func(func=debug_func, x=hidden, out=None) + + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + return fluid.layers.mean(loss) """ helper = LayerHelper('py_func', **locals()) if x is None: