From 518f9d81411f16fe1cfd84ed8b3c2d4ea5c5232b Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Fri, 2 Dec 2022 11:03:05 +0800 Subject: [PATCH] move fluid.layer.py_func to paddle.static.nn.common.py_func (#48482) --- python/paddle/fluid/layers/nn.py | 322 ----------------- .../fluid/tests/unittests/test_py_func_op.py | 12 +- .../tests/unittests/test_rnn_decode_api.py | 2 +- python/paddle/static/__init__.py | 5 +- python/paddle/static/nn/__init__.py | 2 +- python/paddle/static/nn/common.py | 324 ++++++++++++++++++ 6 files changed, 335 insertions(+), 332 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c2599454c1..7ff74cd37c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -113,7 +113,6 @@ __all__ = [ 'merge_selected_rows', 'get_tensor_from_selected_rows', 'temporal_shift', - 'py_func', 'continuous_value_model', 'unfold', 'deformable_roi_pooling', @@ -6635,327 +6634,6 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"): ) -class PyFuncRegistry: - _register_funcs = [] - - def __init__(self, func): - if func is None or not callable(func): - raise TypeError('func must be a Python function') - - self._func = func - # find named args using reflection - args = inspect.getfullargspec(self._func) - if len(args[0]) == 0 and args[1] is None and args[2] is None: - # Function with no inputs - self._named_args = None - else: - self._named_args = args[0] - self._id = core._append_python_callable_object_and_return_id(self) - ''' - Why record self here? - - 1. For debug usage. Users can call - :code:`py_func.registered_func(idx)` method - to find the registered function corresponding - to :code:`idx`. - - 2. For increasing reference count of self. - It seems that to release Python object - whose reference count is 1 would cause - segmentation fault error in C++ side. - May be lack of Python GC in C++ side? - ''' - PyFuncRegistry._register_funcs.append(self) - - @classmethod - def registered_func(cls, idx): - return cls._register_funcs[idx]._func - - @classmethod - def registered_func_num(cls): - return len(cls._register_funcs) - - @property - def id(self): - return self._id - - def __call__(self, *args): - if self._named_args is None: - func_ret = self._func() - else: - kwargs = dict() - idx = 0 - for arg in self._named_args: - kwargs[arg] = args[idx] - idx += 1 - func_ret = self._func(*args[idx:], **kwargs) - - if not isinstance(func_ret, (list, tuple)): - func_ret = (func_ret,) - - ret = [] - for each_ret in func_ret: - if each_ret is None or isinstance(each_ret, core.LoDTensor): - ret.append(each_ret) - continue - - if not isinstance(each_ret, np.ndarray): - each_ret = np.array(each_ret) - - tensor = core.LoDTensor() - tensor.set(each_ret, core.CPUPlace()) - ret.append(tensor) - - return tuple(ret) - - -@static_only -@templatedoc() -def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): - """ - :api_attr: Static Graph - - This OP is used to register customized Python OP to Paddle. The design - principe of py_func is that Tensor and numpy array can be converted to each - other easily. So you can use Python and numpy API to register a python OP. - - The forward function of the registered OP is ``func`` and the backward function - of that is ``backward_func``. Paddle will call ``func`` at forward runtime and - call ``backward_func`` at backward runtime(if ``backward_func`` is not None). 
- ``x`` is the input of ``func``, whose type must be Tensor; ``out`` is - the output of ``func``, whose type can be either Tensor or numpy array. - - The input of the backward function ``backward_func`` is ``x``, ``out`` and - the gradient of ``out``. If ``out`` have no gradient, the relevant input of - ``backward_func`` is None. If ``x`` do not have a gradient, the user should - return None in ``backward_func``. - - The data type and shape of ``out`` should also be set correctly before this - API is called, and the data type and shape of the gradient of ``out`` and - ``x`` will be inferred automatically. - - This API can also be used to debug the neural network by setting the ``func`` - as a function that only print variables. - - Args: - func (callable): The forward function of the registered OP. When the network - is running, the forward output ``out`` will be calculated according to this - function and the forward input ``x``. In ``func`` , it's suggested that we - actively convert Tensor into a numpy array, so that we can use Python and - numpy API arbitrarily. If not, some operations of numpy may not be compatible. - x (Tensor|tuple(Tensor)|list[Tensor]): The input of the forward function ``func``. - It can be Tensor|tuple(Tensor)|list[Tensor]. In addition, Multiple Tensor - should be passed in the form of tuple(Tensor) or list[Tensor]. - out (T|tuple(T)|list[T]): The output of the forward function ``func``, it can be - T|tuple(T)|list[T], where T can be either Tensor or numpy array. Since Paddle - cannot automatically infer the shape and type of ``out``, you must create - ``out`` in advance. - backward_func (callable, optional): The backward function of the registered OP. - Its default value is None, which means there is no reverse calculation. If - it is not None, ``backward_func`` is called to calculate the gradient of - ``x`` when the network is at backward runtime. - skip_vars_in_backward_input (Tensor, optional): It's used to limit the input - list of ``backward_func``, and it can be Tensor|tuple(Tensor)|list[Tensor]. - It must belong to either ``x`` or ``out``. The default value is None, which means - that no tensors need to be removed from ``x`` and ``out``. If it is not None, - these tensors will not be the input of ``backward_func``. This parameter is only - useful when ``backward_func`` is not None. - - Returns: - Tensor|tuple(Tensor)|list[Tensor]: The output ``out`` of the forward function ``func``. - - Examples: - .. code-block:: python - - # example 1: - import paddle - import numpy as np - - paddle.enable_static() - - # Creates a forward function, Tensor can be input directly without - # being converted into numpy array. - def tanh(x): - return np.tanh(x) - - # Skip x in backward function and return the gradient of x - # Tensor must be actively converted to numpy array, otherwise, - # operations such as +/- can't be used. 
- def tanh_grad(y, dy): - return np.array(dy) * (1 - np.square(np.array(y))) - - # Creates a forward function for debugging running networks(print value) - def debug_func(x): - print(x) - - def create_tmp_var(name, dtype, shape): - return paddle.static.default_main_program().current_block().create_var( - name=name, dtype=dtype, shape=shape) - - def simple_net(img, label): - hidden = img - for idx in range(4): - hidden = paddle.static.nn.fc(hidden, size=200) - new_hidden = create_tmp_var(name='hidden_{}'.format(idx), - dtype=hidden.dtype, shape=hidden.shape) - - # User-defined forward and backward - hidden = paddle.static.py_func(func=tanh, x=hidden, - out=new_hidden, backward_func=tanh_grad, - skip_vars_in_backward_input=hidden) - - # User-defined debug functions that print out the input Tensor - paddle.static.py_func(func=debug_func, x=hidden, out=None) - - prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax') - ce_loss = paddle.nn.loss.CrossEntropyLoss() - return ce_loss(prediction, label) - - x = paddle.static.data(name='x', shape=[1,4], dtype='float32') - y = paddle.static.data(name='y', shape=[1], dtype='int64') - res = simple_net(x, y) - - exe = paddle.static.Executor(paddle.CPUPlace()) - exe.run(paddle.static.default_startup_program()) - input1 = np.random.random(size=[1,4]).astype('float32') - input2 = np.random.randint(1, 10, size=[1], dtype='int64') - out = exe.run(paddle.static.default_main_program(), - feed={'x':input1, 'y':input2}, - fetch_list=[res.name]) - print(out) - - .. code-block:: python - - # example 2: - # This example shows how to turn Tensor into numpy array and - # use numpy API to register an Python OP - import paddle - import numpy as np - - paddle.enable_static() - - def element_wise_add(x, y): - # Tensor must be actively converted to numpy array, otherwise, - # numpy.shape can't be used. 
- x = np.array(x) - y = np.array(y) - - if x.shape != y.shape: - raise AssertionError("the shape of inputs must be the same!") - - result = np.zeros(x.shape, dtype='int32') - for i in range(len(x)): - for j in range(len(x[0])): - result[i][j] = x[i][j] + y[i][j] - - return result - - def create_tmp_var(name, dtype, shape): - return paddle.static.default_main_program().current_block().create_var( - name=name, dtype=dtype, shape=shape) - - def py_func_demo(): - start_program = paddle.static.default_startup_program() - main_program = paddle.static.default_main_program() - - # Input of the forward function - x = paddle.static.data(name='x', shape=[2,3], dtype='int32') - y = paddle.static.data(name='y', shape=[2,3], dtype='int32') - - # Output of the forward function, name/dtype/shape must be specified - output = create_tmp_var('output','int32', [3,1]) - - # Multiple Variable should be passed in the form of tuple(Variale) or list[Variale] - paddle.static.py_func(func=element_wise_add, x=[x,y], out=output) - - exe=paddle.static.Executor(paddle.CPUPlace()) - exe.run(start_program) - - # Feed numpy array to main_program - input1 = np.random.randint(1, 10, size=[2,3], dtype='int32') - input2 = np.random.randint(1, 10, size=[2,3], dtype='int32') - out = exe.run(main_program, - feed={'x':input1, 'y':input2}, - fetch_list=[output.name]) - print("{0} + {1} = {2}".format(input1, input2, out)) - - py_func_demo() - - # Reference output: - # [[5, 9, 9] + [[7, 8, 4] = [array([[12, 17, 13] - # [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)] - """ - helper = LayerHelper('py_func', **locals()) - check_type(x, 'X', (list, tuple, Variable, type(None)), 'py_func') - if x is None: - x = [] - elif isinstance(x, Variable): - x = [x] - elif isinstance(x, tuple): - x = list(x) - elif not isinstance(x, (list, tuple, Variable)): - raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)') - check_type(out, 'Out', (list, tuple, Variable, type(None)), 'py_func') - if out is None: - out_list = [] - elif isinstance(out, Variable): - out_list = [out] - elif isinstance(out, tuple): - out_list = list(out) - elif isinstance(out, list): - out_list = out - else: - raise TypeError( - 'Output must be Variable/list(Variable)/tuple(Variable)' - ) - - fwd_func_id = PyFuncRegistry(func).id - bwd_func_id = ( - PyFuncRegistry(backward_func).id if backward_func is not None else -1 - ) - - for each_out in out_list: - if len(each_out.shape) == 0: - raise ValueError( - 'Output shapes of py_func op should be provided by users manually' - ) - - backward_skip_vars = set() - if backward_func is not None and skip_vars_in_backward_input is not None: - if isinstance(skip_vars_in_backward_input, Variable): - skip_vars_in_backward_input = [skip_vars_in_backward_input] - - fwd_in_out = [v.name for v in x] - fwd_in_out.extend([v.name for v in out_list]) - fwd_in_out = set(fwd_in_out) - backward_skip_vars = set() - for v in skip_vars_in_backward_input: - if not v.name in fwd_in_out: - raise ValueError( - 'Variable {} is not found in forward inputs and outputs'.format( - v.name - ) - ) - backward_skip_vars.add(v.name) - - helper.append_op( - type='py_func', - inputs={'X': x}, - outputs={'Out': out_list}, - attrs={ - 'forward_callable_id': fwd_func_id, - 'backward_callable_id': bwd_func_id, - 'backward_skip_vars': list(backward_skip_vars), - }, - ) - return out - - -# For debug usage -py_func.registered_func = PyFuncRegistry.registered_func -py_func.registered_func_num = PyFuncRegistry.registered_func_num - - def 
continuous_value_model(input, cvm, use_cvm=True): r""" diff --git a/python/paddle/fluid/tests/unittests/test_py_func_op.py b/python/paddle/fluid/tests/unittests/test_py_func_op.py index 51d7af4993..0f2f9ea1e3 100644 --- a/python/paddle/fluid/tests/unittests/test_py_func_op.py +++ b/python/paddle/fluid/tests/unittests/test_py_func_op.py @@ -94,7 +94,7 @@ def simple_fc_net(img, label, use_py_func_op): shape=hidden.shape, ) ) - hidden = fluid.layers.py_func( + hidden = paddle.static.py_func( func=tanh, x=hidden, out=new_hidden, @@ -111,7 +111,7 @@ def simple_fc_net(img, label, use_py_func_op): .current_block() .create_var(name='loss', dtype='float32', shape=[-1, 1]) ) - loss = fluid.layers.py_func( + loss = paddle.static.py_func( func=cross_entropy, x=[prediction, label], out=loss, @@ -124,11 +124,11 @@ def simple_fc_net(img, label, use_py_func_op): .current_block() .create_var(name='test_tmp_var', dtype='float32', shape=[1]) ) - fluid.layers.py_func( + paddle.static.py_func( func=dummy_func_with_no_input, x=None, out=dummy_var ) loss += dummy_var - fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None) + paddle.static.py_func(func=dummy_func_with_no_output, x=loss, out=None) loss_out = ( fluid.default_main_program() @@ -140,7 +140,7 @@ def simple_fc_net(img, label, use_py_func_op): .current_block() .create_var(dtype='float32', shape=[1]) ) - fluid.layers.py_func( + paddle.static.py_func( func=dummy_func_with_multi_input_output, x=(loss, dummy_var), out=(loss_out, dummy_var_out), @@ -149,7 +149,7 @@ def simple_fc_net(img, label, use_py_func_op): loss == loss_out and dummy_var == dummy_var_out ), "py_func failed with multi input and output" - fluid.layers.py_func( + paddle.static.py_func( func=dummy_func_with_multi_input_output, x=[loss, dummy_var], out=[loss_out, dummy_var_out], diff --git a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py index 410708a105..a557fb9df0 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_decode_api.py @@ -309,7 +309,7 @@ class PolicyGradient: """ update policy model self.model with policy gradient algorithm """ - self.reward = fluid.layers.py_func( + self.reward = paddle.static.py_func( func=reward_func, x=[action, length], out=reward ) neg_log_prob = layers.cross_entropy(act_prob, action) diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index f527b5a1c3..118fe0b58b 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -16,6 +16,9 @@ from . import amp # noqa: F401 from . import sparsity # noqa: F401 from . 
import nn # noqa: F401 + +from .nn.common import py_func # noqa: F401 + from .io import save_inference_model # noqa: F401 from .io import load_inference_model # noqa: F401 from .io import deserialize_persistables # noqa: F401 @@ -53,7 +56,6 @@ from ..fluid.framework import Variable # noqa: F401 from ..fluid.framework import ipu_shard_guard # noqa: F401 from ..fluid.framework import set_ipu_shard # noqa: F401 from ..fluid.layers.control_flow import Print # noqa: F401 -from ..fluid.layers.nn import py_func # noqa: F401 from ..fluid.parallel_executor import ParallelExecutor # noqa: F401 from ..fluid.param_attr import WeightNormParamAttr # noqa: F401 from ..fluid.optimizer import ExponentialMovingAverage # noqa: F401 @@ -61,7 +63,6 @@ from ..fluid.io import save # noqa: F401 from ..fluid.io import load # noqa: F401 from ..fluid.io import load_program_state # noqa: F401 from ..fluid.io import set_program_state # noqa: F401 - from ..fluid.io import load_vars # noqa: F401 from ..fluid.io import save_vars # noqa: F401 from ..fluid.io import batch # noqa: F401 diff --git a/python/paddle/static/nn/__init__.py b/python/paddle/static/nn/__init__.py index 449cd478a2..8e3048b21c 100755 --- a/python/paddle/static/nn/__init__.py +++ b/python/paddle/static/nn/__init__.py @@ -20,6 +20,7 @@ from .common import deform_conv2d # noqa: F401 from .common import conv3d # noqa: F401 from .common import conv2d_transpose # noqa: F401 from .common import conv3d_transpose # noqa: F401 +from .common import py_func # noqa: F401 from ...fluid.layers import batch_norm # noqa: F401 from ...fluid.layers import bilinear_tensor_product # noqa: F401 @@ -32,7 +33,6 @@ from ...fluid.layers import layer_norm # noqa: F401 from ...fluid.layers import multi_box_head # noqa: F401 from .loss import nce # noqa: F401 from .common import prelu # noqa: F401 -from ...fluid.layers import py_func # noqa: F401 from ...fluid.layers import row_conv # noqa: F401 from ...fluid.layers import spectral_norm # noqa: F401 from ...fluid.layers import switch_case # noqa: F401 diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 9fc2bbd975..a8dec018ff 100755 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect + +import numpy as np + import paddle from paddle.common_ops_import import ( LayerHelper, @@ -19,6 +23,7 @@ from paddle.common_ops_import import ( check_variable_and_dtype, utils, ) +from paddle.fluid import core from paddle.fluid.data_feeder import check_dtype from paddle.fluid.framework import Variable, _non_static_mode, static_only from paddle.fluid.initializer import Constant, Normal @@ -2083,6 +2088,325 @@ def deform_conv2d( ) +class PyFuncRegistry: + _register_funcs = [] + + def __init__(self, func): + if func is None or not callable(func): + raise TypeError('func must be a Python function') + + self._func = func + # find named args using reflection + args = inspect.getfullargspec(self._func) + if len(args[0]) == 0 and args[1] is None and args[2] is None: + # Function with no inputs + self._named_args = None + else: + self._named_args = args[0] + self._id = core._append_python_callable_object_and_return_id(self) + ''' + Why record self here? + + 1. For debug usage. Users can call + :code:`py_func.registered_func(idx)` method + to find the registered function corresponding + to :code:`idx`. + + 2. For increasing reference count of self. 
+ It seems that to release Python object + whose reference count is 1 would cause + segmentation fault error in C++ side. + May be lack of Python GC in C++ side? + ''' + PyFuncRegistry._register_funcs.append(self) + + @classmethod + def registered_func(cls, idx): + return cls._register_funcs[idx]._func + + @classmethod + def registered_func_num(cls): + return len(cls._register_funcs) + + @property + def id(self): + return self._id + + def __call__(self, *args): + if self._named_args is None: + func_ret = self._func() + else: + kwargs = dict() + idx = 0 + for arg in self._named_args: + kwargs[arg] = args[idx] + idx += 1 + func_ret = self._func(*args[idx:], **kwargs) + + if not isinstance(func_ret, (list, tuple)): + func_ret = (func_ret,) + + ret = [] + for each_ret in func_ret: + if each_ret is None or isinstance(each_ret, core.LoDTensor): + ret.append(each_ret) + continue + + if not isinstance(each_ret, np.ndarray): + each_ret = np.array(each_ret) + + tensor = core.LoDTensor() + tensor.set(each_ret, core.CPUPlace()) + ret.append(tensor) + + return tuple(ret) + + +@static_only +@templatedoc() +def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): + """ + This is used to register customized Python OP to Paddle. The design + principe of py_func is that Tensor and numpy array can be converted to each + other easily. So you can use Python and numpy API to register a python OP. + + The forward function of the registered OP is ``func`` and the backward function + of that is ``backward_func``. Paddle will call ``func`` at forward runtime and + call ``backward_func`` at backward runtime(if ``backward_func`` is not None). + ``x`` is the input of ``func``, whose type must be Tensor; ``out`` is + the output of ``func``, whose type can be either Tensor or numpy array. + + The input of the backward function ``backward_func`` is ``x``, ``out`` and + the gradient of ``out``. If ``out`` have no gradient, the relevant input of + ``backward_func`` is None. If ``x`` do not have a gradient, the user should + return None in ``backward_func``. + + The data type and shape of ``out`` should also be set correctly before this + API is called, and the data type and shape of the gradient of ``out`` and + ``x`` will be inferred automatically. + + This API can also be used to debug the neural network by setting the ``func`` + as a function that only print variables. + + Args: + func (callable): The forward function of the registered OP. When the network + is running, the forward output ``out`` will be calculated according to this + function and the forward input ``x``. In ``func`` , it's suggested that we + actively convert Tensor into a numpy array, so that we can use Python and + numpy API arbitrarily. If not, some operations of numpy may not be compatible. + x (Tensor|tuple(Tensor)|list[Tensor]): The input of the forward function ``func``. + It can be Tensor|tuple(Tensor)|list[Tensor]. In addition, Multiple Tensor + should be passed in the form of tuple(Tensor) or list[Tensor]. + out (T|tuple(T)|list[T]): The output of the forward function ``func``, it can be + T|tuple(T)|list[T], where T can be either Tensor or numpy array. Since Paddle + cannot automatically infer the shape and type of ``out``, you must create + ``out`` in advance. + backward_func (callable, optional): The backward function of the registered OP. + Its default value is None, which means there is no reverse calculation. 
If + it is not None, ``backward_func`` is called to calculate the gradient of + ``x`` when the network is at backward runtime. + skip_vars_in_backward_input (Tensor, optional): It's used to limit the input + list of ``backward_func``, and it can be Tensor|tuple(Tensor)|list[Tensor]. + It must belong to either ``x`` or ``out``. The default value is None, which means + that no tensors need to be removed from ``x`` and ``out``. If it is not None, + these tensors will not be the input of ``backward_func``. This parameter is only + useful when ``backward_func`` is not None. + + Returns: + Tensor|tuple(Tensor)|list[Tensor]: The output ``out`` of the forward function ``func``. + + Examples: + .. code-block:: python + + # example 1: + import paddle + import numpy as np + + paddle.enable_static() + + # Creates a forward function, Tensor can be input directly without + # being converted into numpy array. + def tanh(x): + return np.tanh(x) + + # Skip x in backward function and return the gradient of x + # Tensor must be actively converted to numpy array, otherwise, + # operations such as +/- can't be used. + def tanh_grad(y, dy): + return np.array(dy) * (1 - np.square(np.array(y))) + + # Creates a forward function for debugging running networks(print value) + def debug_func(x): + print(x) + + def create_tmp_var(name, dtype, shape): + return paddle.static.default_main_program().current_block().create_var( + name=name, dtype=dtype, shape=shape) + + def simple_net(img, label): + hidden = img + for idx in range(4): + hidden = paddle.static.nn.fc(hidden, size=200) + new_hidden = create_tmp_var(name='hidden_{}'.format(idx), + dtype=hidden.dtype, shape=hidden.shape) + + # User-defined forward and backward + hidden = paddle.static.py_func(func=tanh, x=hidden, + out=new_hidden, backward_func=tanh_grad, + skip_vars_in_backward_input=hidden) + + # User-defined debug functions that print out the input Tensor + paddle.static.py_func(func=debug_func, x=hidden, out=None) + + prediction = paddle.static.nn.fc(hidden, size=10, activation='softmax') + ce_loss = paddle.nn.loss.CrossEntropyLoss() + return ce_loss(prediction, label) + + x = paddle.static.data(name='x', shape=[1,4], dtype='float32') + y = paddle.static.data(name='y', shape=[1], dtype='int64') + res = simple_net(x, y) + + exe = paddle.static.Executor(paddle.CPUPlace()) + exe.run(paddle.static.default_startup_program()) + input1 = np.random.random(size=[1,4]).astype('float32') + input2 = np.random.randint(1, 10, size=[1], dtype='int64') + out = exe.run(paddle.static.default_main_program(), + feed={'x':input1, 'y':input2}, + fetch_list=[res.name]) + print(out) + + .. code-block:: python + + # example 2: + # This example shows how to turn Tensor into numpy array and + # use numpy API to register an Python OP + import paddle + import numpy as np + + paddle.enable_static() + + def element_wise_add(x, y): + # Tensor must be actively converted to numpy array, otherwise, + # numpy.shape can't be used. 
+ x = np.array(x) + y = np.array(y) + + if x.shape != y.shape: + raise AssertionError("the shape of inputs must be the same!") + + result = np.zeros(x.shape, dtype='int32') + for i in range(len(x)): + for j in range(len(x[0])): + result[i][j] = x[i][j] + y[i][j] + + return result + + def create_tmp_var(name, dtype, shape): + return paddle.static.default_main_program().current_block().create_var( + name=name, dtype=dtype, shape=shape) + + def py_func_demo(): + start_program = paddle.static.default_startup_program() + main_program = paddle.static.default_main_program() + + # Input of the forward function + x = paddle.static.data(name='x', shape=[2,3], dtype='int32') + y = paddle.static.data(name='y', shape=[2,3], dtype='int32') + + # Output of the forward function, name/dtype/shape must be specified + output = create_tmp_var('output','int32', [3,1]) + + # Multiple Variable should be passed in the form of tuple(Variale) or list[Variale] + paddle.static.py_func(func=element_wise_add, x=[x,y], out=output) + + exe=paddle.static.Executor(paddle.CPUPlace()) + exe.run(start_program) + + # Feed numpy array to main_program + input1 = np.random.randint(1, 10, size=[2,3], dtype='int32') + input2 = np.random.randint(1, 10, size=[2,3], dtype='int32') + out = exe.run(main_program, + feed={'x':input1, 'y':input2}, + fetch_list=[output.name]) + print("{0} + {1} = {2}".format(input1, input2, out)) + + py_func_demo() + + # Reference output: + # [[5, 9, 9] + [[7, 8, 4] = [array([[12, 17, 13] + # [7, 5, 2]] [1, 3, 3]] [8, 8, 5]], dtype=int32)] + """ + helper = LayerHelper('py_func', **locals()) + check_type(x, 'X', (list, tuple, Variable, type(None)), 'py_func') + if x is None: + x = [] + elif isinstance(x, Variable): + x = [x] + elif isinstance(x, tuple): + x = list(x) + elif not isinstance(x, (list, tuple, Variable)): + raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)') + check_type(out, 'Out', (list, tuple, Variable, type(None)), 'py_func') + if out is None: + out_list = [] + elif isinstance(out, Variable): + out_list = [out] + elif isinstance(out, tuple): + out_list = list(out) + elif isinstance(out, list): + out_list = out + else: + raise TypeError( + 'Output must be Variable/list(Variable)/tuple(Variable)' + ) + + fwd_func_id = PyFuncRegistry(func).id + bwd_func_id = ( + PyFuncRegistry(backward_func).id if backward_func is not None else -1 + ) + + for each_out in out_list: + if len(each_out.shape) == 0: + raise ValueError( + 'Output shapes of py_func should be provided by users manually' + ) + + backward_skip_vars = set() + if backward_func is not None and skip_vars_in_backward_input is not None: + if isinstance(skip_vars_in_backward_input, Variable): + skip_vars_in_backward_input = [skip_vars_in_backward_input] + + fwd_in_out = [v.name for v in x] + fwd_in_out.extend([v.name for v in out_list]) + fwd_in_out = set(fwd_in_out) + backward_skip_vars = set() + for v in skip_vars_in_backward_input: + if v.name not in fwd_in_out: + raise ValueError( + 'Variable {} is not found in forward inputs and outputs'.format( + v.name + ) + ) + backward_skip_vars.add(v.name) + + helper.append_op( + type='py_func', + inputs={'X': x}, + outputs={'Out': out_list}, + attrs={ + 'forward_callable_id': fwd_func_id, + 'backward_callable_id': bwd_func_id, + 'backward_skip_vars': list(backward_skip_vars), + }, + ) + return out + + +# For debug usage +py_func.registered_func = PyFuncRegistry.registered_func +py_func.registered_func_num = PyFuncRegistry.registered_func_num + + @static_only def prelu(x, mode, 
param_attr=None, data_format="NCHW", name=None): r""" -- GitLab
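
For reference, below is a minimal usage sketch under the import path this patch introduces (paddle.static.py_func instead of fluid.layers.py_func). It mirrors the tanh/tanh_grad example from the docstring being moved and assumes a Paddle build that already contains this change; the variable name 'tanh_out' and the inline shapes are illustrative, not part of the API.

    import numpy as np
    import paddle

    paddle.enable_static()

    def tanh(x):
        # Forward function: convert the input Tensor to a numpy array explicitly.
        return np.tanh(np.array(x))

    def tanh_grad(y, dy):
        # Backward function: receives the forward output y and the gradient dy
        # (x is skipped via skip_vars_in_backward_input below).
        return np.array(dy) * (1 - np.square(np.array(y)))

    x = paddle.static.data(name='x', shape=[1, 4], dtype='float32')

    # py_func cannot infer the output shape/dtype, so the output variable
    # must be created in advance.
    out = (
        paddle.static.default_main_program()
        .current_block()
        .create_var(name='tanh_out', dtype='float32', shape=[1, 4])
    )

    # New import path after this patch: paddle.static.py_func
    # (previously fluid.layers.py_func).
    out = paddle.static.py_func(
        func=tanh,
        x=x,
        out=out,
        backward_func=tanh_grad,
        skip_vars_in_backward_input=x,
    )

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())
    res = exe.run(
        paddle.static.default_main_program(),
        feed={'x': np.random.random(size=[1, 4]).astype('float32')},
        fetch_list=[out.name],
    )
    print(res)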