From 6ed8221a44229b11e5751425e143a27841672c96 Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Fri, 13 Jan 2023 19:18:01 +0800 Subject: [PATCH] New feature: add register composite rule of ops (#49605) --- paddle/phi/api/yaml/op_compat.yaml | 2 + python/CMakeLists.txt | 29 ++++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../unittests/composite_ops/CMakeLists.txt | 20 +++ .../composite_ops/test_composite_softmax.py | 118 +++++++++++++++++ .../test_composite_softmax_grad.py | 124 ++++++++++++++++++ .../tests/unittests/composite_ops/utils.py | 25 ++++ python/paddle/incubate/autograd/.gitignore | 2 + python/paddle/incubate/autograd/__init__.py | 3 +- .../incubate/autograd/composite_rules.py | 35 +++++ .../incubate/autograd/generate_op_map.py | 115 ++++++++++++++++ python/paddle/incubate/autograd/primapi.py | 17 +++ python/paddle/incubate/autograd/primitives.py | 103 +++++++++++++++ python/paddle/incubate/autograd/primreg.py | 40 ++++++ python/paddle/incubate/autograd/primx.py | 118 +++++++++++++++++ python/paddle/incubate/autograd/utils.py | 60 +++++++++ 16 files changed, 811 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/composite_ops/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py create mode 100644 python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py create mode 100644 python/paddle/fluid/tests/unittests/composite_ops/utils.py create mode 100644 python/paddle/incubate/autograd/.gitignore create mode 100644 python/paddle/incubate/autograd/composite_rules.py create mode 100644 python/paddle/incubate/autograd/generate_op_map.py create mode 100644 python/paddle/incubate/autograd/primitives.py diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 47dc8a0fed..c2b22ba7af 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1288,6 +1288,8 @@ - op : softmax backward : softmax_grad + inputs : + x : X extra : attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false] diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 9523228eaf..76b99d3eea 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -53,6 +53,35 @@ set(FLUID_CORE_DEPS ${FLUID_CORE}) add_custom_target(copy_libpaddle ALL DEPENDS ${FLUID_CORE_DEPS}) +# Standard op(phi op) description is defined in ops.yaml and legacy_ops.yaml. +# When users define composite rules of some nonbasic op, as for defination of args, +# they are supposed to refer to standard op description. However, there exists +# some gap of description between current op and standard ones. So special dictionary +# is needed to record such gap for execution of composite rules. +# Todo: this custom_target will be moved to other place. + +set(ops_yaml_path "${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/ops.yaml") + +set(ops_legacy_yaml_path + "${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/legacy_ops.yaml") + +set(ops_compat_yaml_path + "${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml") + +set(phi_ops_map_path + "${PADDLE_SOURCE_DIR}/python/paddle/incubate/autograd/phi_ops_map.py") + +add_custom_target( + op_map_codegen ALL + COMMAND + "${PYTHON_EXECUTABLE}" + "${PADDLE_SOURCE_DIR}/python/paddle/incubate/autograd/generate_op_map.py" + "--ops_yaml_path=${ops_yaml_path}" + "--ops_legacy_yaml_path=${ops_legacy_yaml_path}" + "--ops_compat_yaml_path=${ops_compat_yaml_path}" + "--phi_ops_map_path=${phi_ops_map_path}" + VERBATIM) + # NOTE(zhiqiu): WHY? # In `setup.py.in`, some dynamic libraries (eg, libxpuapi.so) are modified using # patchelf. In rare cases, if the a linker is linking that dynamic library for diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1dd97cdccb..6d99deb2bf 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -837,6 +837,7 @@ add_subdirectory(sequence) add_subdirectory(dygraph_to_static) add_subdirectory(rnn) add_subdirectory(autograd) +add_subdirectory(composite_ops) add_subdirectory(distribution) add_subdirectory(prim) diff --git a/python/paddle/fluid/tests/unittests/composite_ops/CMakeLists.txt b/python/paddle/fluid/tests/unittests/composite_ops/CMakeLists.txt new file mode 100644 index 0000000000..2cc4413bb0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/composite_ops/CMakeLists.txt @@ -0,0 +1,20 @@ +file( + GLOB TEST_OPS + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") + +file( + GLOB TEST_OPS_GRAD + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*_grad.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +string(REPLACE ".py" "" TEST_OPS_GRAD "${TEST_OPS_GRAD}") + +if(WIN32 OR APPLE) + # TODO: Fix these unittests failed on Windows and MAC. + list(REMOVE_ITEM TEST_OPS ${TEST_OPS_GRAD}) +endif() + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) +endforeach() diff --git a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py new file mode 100644 index 0000000000..2ac671962c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from utils import TOLERANCE + +import paddle +import paddle.nn.functional as F + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class Attr: + def __init__(self) -> None: + self.dtype = None + self.axis = -1 + self.shape = None + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_axis(self, axis) -> None: + self.axis = axis + return + + def set_shape(self, shape) -> None: + self.shape = shape + return + + def get_rtol(self, flag): + rtol = TOLERANCE[self.dtype][flag].get("rtol") + return rtol + + def get_atol(self, flag): + atol = TOLERANCE[self.dtype][flag].get("atol") + return atol + + +attrs = Attr() + + +def fn(x): + y = paddle.tan(x) + return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype) + + +def expect_forward(inputs): + return fn(inputs) + + +class TestCompositeSoftmax(unittest.TestCase): + def setUp(self): + self.dtypes = ["float32", "float64"] + self.shapes = [[2, 3, 4], [2, 3]] + self.axes = [-1, 0, 1] + + def cal_composite(self, inputs): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + y = fn(x) + blocks = main_program.blocks + paddle.incubate.autograd.to_prim(blocks) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y]) + paddle.disable_static() + return res + + def compare_forward(self): + np_data = generate_data(attrs.shape) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_forward(tensor_data).numpy() + actual = self.cal_composite(np_data)[0] + + assert expect.dtype == actual.dtype + assert np.allclose( + expect, + actual, + rtol=attrs.get_rtol("forward"), + atol=attrs.get_atol("forward"), + ) + + def test_forward(self): + for i in self.axes: + for j in self.dtypes: + for t in self.shapes: + attrs.set_axis(i) + attrs.set_dtype(j) + attrs.set_shape(t) + self.compare_forward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py new file mode 100644 index 0000000000..c47399ba5a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from utils import TOLERANCE + +import paddle +import paddle.nn.functional as F + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class Attr: + def __init__(self) -> None: + self.dtype = None + self.axis = -1 + self.shape = None + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_axis(self, axis) -> None: + self.axis = axis + return + + def set_shape(self, shape) -> None: + self.shape = shape + return + + def get_rtol(self, flag): + rtol = TOLERANCE[self.dtype][flag].get("rtol") + return rtol + + def get_atol(self, flag): + atol = TOLERANCE[self.dtype][flag].get("atol") + return atol + + +attrs = Attr() + + +def fn(x): + y = paddle.tan(x) + return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype) + + +def expect_grad(inputs): + inputs.stop_gradient = False + res = fn(inputs) + + gradients = paddle.grad(res, inputs) + return gradients + + +class TestCompositeSoftmax(unittest.TestCase): + def setUp(self): + self.dtypes = ["float32", "float64"] + self.shapes = [[2, 3, 4], [2, 3]] + self.axes = [-1, 0, 1] + + def cal_composite_grad(self, inputs): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + y = fn(x) + blocks = main_program.blocks + paddle.incubate.autograd.to_prim(blocks) + z = paddle.static.gradients([y], x) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) + paddle.disable_static() + return res + + def compare_backward(self): + np_data = generate_data(attrs.shape) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_grad(tensor_data)[0].numpy() + actual = self.cal_composite_grad(np_data)[0] + + assert expect.dtype == actual.dtype + assert np.allclose( + expect, + actual, + rtol=attrs.get_rtol("backward"), + atol=attrs.get_atol("backward"), + ) + + def test_backward(self): + for i in self.axes: + for j in self.dtypes: + for t in self.shapes: + attrs.set_axis(i) + attrs.set_dtype(j) + attrs.set_shape(t) + self.compare_backward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/composite_ops/utils.py b/python/paddle/fluid/tests/unittests/composite_ops/utils.py new file mode 100644 index 0000000000..c43f79d1c0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/composite_ops/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +TOLERANCE = { + "float32": { + "forward": {"rtol": 1e-6, "atol": 1e-6}, + "backward": {"rtol": 1e-6, "atol": 1e-6}, + }, + "float64": { + "forward": {"rtol": 1e-7, "atol": 1e-7}, + "backward": {"rtol": 1e-7, "atol": 1e-7}, + }, +} diff --git a/python/paddle/incubate/autograd/.gitignore b/python/paddle/incubate/autograd/.gitignore new file mode 100644 index 0000000000..27e033cf8f --- /dev/null +++ b/python/paddle/incubate/autograd/.gitignore @@ -0,0 +1,2 @@ +# this file is generated during build system generation +phi_ops_map.py diff --git a/python/paddle/incubate/autograd/__init__.py b/python/paddle/incubate/autograd/__init__.py index d9b9e41781..3e73ff571e 100644 --- a/python/paddle/incubate/autograd/__init__.py +++ b/python/paddle/incubate/autograd/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from .functional import Hessian, Jacobian, jvp, vjp -from .primapi import forward_grad, grad +from .primapi import forward_grad, grad, to_prim from .primx import prim2orig from .utils import disable_prim, enable_prim, prim_enabled @@ -25,4 +25,5 @@ __all__ = [ # noqa 'disable_prim', 'forward_grad', 'grad', + 'to_prim', ] diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py new file mode 100644 index 0000000000..456ac20db2 --- /dev/null +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file contains composite rules of nonbasic operations. There are some notes: +# 1. When define composite rule of some op, you can only use primitive ops defined in primitives.py. +# 2. The name and args of target op must be corresponding with standard description of op in +# ops.yaml or legacy_ops.yaml. + +from .primitives import * # noqa: F403 +from .primreg import REGISTER_COMPOSITE, lookup_composite + + +def _composite(op, *args): + _lowerrule = lookup_composite(op.type) + return _lowerrule(op, *args) + + +@REGISTER_COMPOSITE('softmax') +def softmax_composite(x, axis): + """define composite rule of op softmax""" + molecular = exp(x) + denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape) + res = divide(molecular, denominator) + return res diff --git a/python/paddle/incubate/autograd/generate_op_map.py b/python/paddle/incubate/autograd/generate_op_map.py new file mode 100644 index 0000000000..45784ad950 --- /dev/null +++ b/python/paddle/incubate/autograd/generate_op_map.py @@ -0,0 +1,115 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Notice: This file will be automatically executed during building of whole paddle project. +# You can also run this file separately if you want to preview generated file without building. + +import argparse +import json +import re + +import yaml + + +def ParseArguments(): + parser = argparse.ArgumentParser( + description='prim ops Code Generator Args Parser' + ) + parser.add_argument('--ops_yaml_path', type=str, help="path to ops.yaml") + parser.add_argument( + '--ops_legacy_yaml_path', type=str, help="path to legacy_ops.yaml" + ) + parser.add_argument( + '--ops_compat_yaml_path', type=str, help="path to op_compat.yaml" + ) + parser.add_argument( + '--phi_ops_map_path', + type=str, + default="./phi_ops_map.py", + help='path to target phi_ops_map.py', + ) + + args = parser.parse_args() + return args + + +def _trans_value_type(item): + for key in item.keys(): + for subkey in item[key]: + value = str(item[key][subkey]) + item[key][subkey] = value + + +def generate_code( + ops_yaml_path, ops_legacy_yaml_path, ops_compat_yaml_path, phi_ops_map_path +): + """ + Generate dictiorary and save to file phi_ops_map.py. The target file records gap + of description between current op and standard ones. + """ + for op_path in [ops_yaml_path, ops_legacy_yaml_path]: + pattern = re.compile(r'[(](.*)[)]', re.S) + with open(op_path, "rt") as f: + ops = yaml.safe_load(f) + dct = {} + for item in ops: + key = item['op'] + if key in dct: + raise ValueError(f"There already exists op {key}") + dct[key] = { + "args": re.findall(pattern, item["args"])[0], + "output": item["output"], + } + + with open(ops_compat_yaml_path, "rt") as f: + ops_compat = yaml.safe_load(f) + map_dct = {} + for item in ops_compat: + key = item['op'] + if key.endswith(")"): + tmp = re.match("(.*)\\((.*)\\)", key.replace(" ", "")) + phi_name, op_name = tmp.group(1), tmp.group(2) + map_dct[op_name] = {"phi_name": phi_name} + else: + op_name = key + map_dct[op_name] = {"phi_name": op_name} + for element in ["inputs", "attrs"]: + if element in item.keys(): + map_dct[op_name][element] = item[element] + for element in ["scalar", "int_array"]: + if element in item.keys(): + _trans_value_type(item[element]) + map_dct[op_name][element] = item[element] + + with open(phi_ops_map_path, "w") as f: + f.write("op_map = ") + json.dump(map_dct, f, indent=4) + f.write('\n') + f.write("op_info = ") + json.dump(dct, f, indent=4) + f.write('\n') + + +if __name__ == "__main__": + args = ParseArguments() + ops_yaml_path = args.ops_yaml_path + ops_legacy_yaml_path = args.ops_legacy_yaml_path + ops_compat_yaml_path = args.ops_compat_yaml_path + phi_ops_map_path = args.phi_ops_map_path + generate_code( + ops_yaml_path, + ops_legacy_yaml_path, + ops_compat_yaml_path, + phi_ops_map_path, + ) diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py index 0cd6898380..7cfabdd9e5 100644 --- a/python/paddle/incubate/autograd/primapi.py +++ b/python/paddle/incubate/autograd/primapi.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import typing +import paddle from paddle.fluid import backward, framework from paddle.incubate.autograd import primx, utils @@ -211,3 +213,18 @@ def grad(outputs, inputs, grad_outputs=None): ad.erase_dots(xs_dot) return xs_bar[0] if isinstance(inputs, framework.Variable) else xs_bar + + +@framework.static_only +def to_prim(blocks): + """Search nonbasic ops which have be registered composite rules and replace them with primitive ops.""" + if isinstance(blocks, paddle.fluid.framework.Block): + logging.info("Atomize composite op to primitive ops begin.") + primx._lower_composite(blocks) + return + elif isinstance(blocks, typing.Sequence): + for item in blocks: + to_prim(item) + return + else: + raise TypeError diff --git a/python/paddle/incubate/autograd/primitives.py b/python/paddle/incubate/autograd/primitives.py new file mode 100644 index 0000000000..371746bf34 --- /dev/null +++ b/python/paddle/incubate/autograd/primitives.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.layers.tensor import cast # noqa: F401 +from paddle.tensor import abs # noqa: F401 +from paddle.tensor import acos # noqa: F401 +from paddle.tensor import acosh # noqa: F401 +from paddle.tensor import add # noqa: F401 +from paddle.tensor import asin # noqa: F401 +from paddle.tensor import asinh # noqa: F401 +from paddle.tensor import atan # noqa: F401 +from paddle.tensor import atanh # noqa: F401 +from paddle.tensor import broadcast_shape # noqa: F401 +from paddle.tensor import broadcast_to # noqa: F401 +from paddle.tensor import cos # noqa: F401 +from paddle.tensor import cosh # noqa: F401 +from paddle.tensor import cumprod # noqa: F401 +from paddle.tensor import cumsum # noqa: F401 +from paddle.tensor import digamma # noqa: F401 +from paddle.tensor import divide # noqa: F401 +from paddle.tensor import erf # noqa: F401 +from paddle.tensor import erfinv # noqa: F401 +from paddle.tensor import exp # noqa: F401 +from paddle.tensor import expm1 # noqa: F401 +from paddle.tensor import lgamma # noqa: F401 +from paddle.tensor import log # noqa: F401 +from paddle.tensor import log1p # noqa: F401 +from paddle.tensor import logcumsumexp # noqa: F401 +from paddle.tensor import logit # noqa: F401 +from paddle.tensor import logsumexp # noqa: F401 +from paddle.tensor import multiply # noqa: F401 +from paddle.tensor import pow # noqa: F401 +from paddle.tensor import prod # noqa: F401 +from paddle.tensor import sign # noqa: F401 +from paddle.tensor import sin # noqa: F401 +from paddle.tensor import sinh # noqa: F401 +from paddle.tensor import subtract # noqa: F401 +from paddle.tensor import sum # noqa: F401 +from paddle.tensor import tan # noqa: F401 +from paddle.tensor import tanh # noqa: F401 + +math_op = [ + 'add', + 'subtract', + 'multiply', + 'divide', + 'abs', + 'pow', + 'sign', + 'sum', + 'prod', + 'cumsum', + 'cumprod', + 'digamma', + 'lgamma', + 'erf', + 'erfinv', + 'exp', + 'expm1', + 'log', + 'log1p', + 'logsumexp', + 'logcumsumexp', + 'logit', +] + +trigonometric_op = [ + 'sin', + 'cos', + 'tan', + 'sinh', + 'cosh', + 'tanh', + 'asin', + 'acos', + 'atan', + 'asinh', + 'acosh', + 'atanh', +] + +others = [ + 'cast', + 'broadcast_to', +] + +__all__ = [] +__all__.extend(math_op) +__all__.extend(trigonometric_op) +__all__.extend(others) + +__all__.sort() diff --git a/python/paddle/incubate/autograd/primreg.py b/python/paddle/incubate/autograd/primreg.py index cce8c49eb4..05b7ea7812 100644 --- a/python/paddle/incubate/autograd/primreg.py +++ b/python/paddle/incubate/autograd/primreg.py @@ -38,6 +38,7 @@ _prim2orig = Registry('prim2orig') _primop_jvp = Registry('primop_jvp') _primop_transpose = Registry('primop_transpose') _primop_position_argnames = Registry('primop_position_argnames') +_composite_ops = Registry('composite') def lookup_fn(optype): @@ -60,6 +61,10 @@ def lookup_transpose(optype): return _primop_transpose.lookup(optype) +def lookup_composite(optype): + return _composite_ops.lookup(optype) + + def op_position_inputs(op): """ Returns the position inputs of `op` as registered with REGISTER_FN. @@ -200,6 +205,41 @@ def REGISTER_ORIG2PRIM(op_type): return wrapper +def REGISTER_COMPOSITE(op_type): + """ + Decorator for registering the lower function for an original op into sequence of primitive ops. + + Args: + op_type(str): The op name + + Returns: + wrapper: Inner wrapper function + + Examples: + .. code-block:: python + @REGISTER_COMPOSITE('softmax') + def softmax_composite(x, axis): + molecular = exp(x) + denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape) + res = divide(molecular, denominator) + return res + + """ + if not isinstance(op_type, str): + raise TypeError(f'op_type must be str, but got {type(op_type)}.') + + def wrapper(f): + def _lower(op, *args, **kwargs): + assert ( + op.type == op_type + ), f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}' + return f(*args, **kwargs) + + _composite_ops.register(op_type, _lower) + + return wrapper + + def REGISTER_PRIM2ORIG(op_type): """ Decorator for registering the lower function for an primitive op into sequence of original ops. diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 08489068de..6f2d4d9d52 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -19,8 +19,10 @@ from paddle.fluid import framework as framework from paddle.fluid.framework import Operator, default_main_program from paddle.incubate.autograd.utils import as_tensors +from .composite_rules import _composite from .primops import add, fill_const from .primreg import ( + lookup_composite, lookup_orig2prim, lookup_prim2orig, op_position_inputs, @@ -32,6 +34,7 @@ from .utils import ( flatten_and_remove_none, get_input_var_list, get_output_var_list, + prepare_python_api_arguments, ) @@ -543,6 +546,121 @@ def _lower(block, reverse, blacklist): block._sync_with_cpp() +def _lower_composite(block, blacklist=[]): + # Some functions which are only used in _lower. + def bind(args, to_bind, value_table): + for i in range(len(args)): + if isinstance(args[i], list): + bind(args[i], to_bind, value_table) + if not isinstance(args[i], paddle.fluid.framework.Variable): + continue + elif args[i] is not None and args[i].name in to_bind: + args[i] = value_table[to_bind[args[i].name]] + + def bind_name(names, to_bind): + return_list = [] + for name in names: + if isinstance(name, list): + return_list.append(bind_name(name, to_bind)) + else: + return_list.append(to_bind[name] if name in to_bind else name) + return return_list + + def expand_nested_list(xs): + return_list = [] + for x in xs: + if isinstance(x, list): + return_list = return_list + expand_nested_list(x) + else: + return_list.append(x) + return return_list + + # Step1: Do some preparatory work for lower + lower_fn = _composite + lookup_fn = lookup_composite + + value_table = {} + to_bind = {} + to_bind_rev = {} + for var in block.desc.all_vars(): + value_table[var.name()] = block.var(var.name()) + + ops_to_remove = [] + vars_to_remove = set() + + # Step2: Process all ops in the target block + for op_idx in range(len(block.ops)): + op = block.ops[op_idx] + ops_to_remove.append(op_idx) + if lookup_fn(op.type) is not None and op.type not in blacklist: + input_args = prepare_python_api_arguments(op) + bind(input_args, to_bind, value_table) + + for orig_out, new_out in zip( + expand_nested_list(get_output_var_list(op)), + expand_nested_list(as_tensors(lower_fn(op, *input_args))), + ): + assert not (orig_out is None) ^ ( + new_out is None + ), "orig_out and new_out should match." + vars_to_remove.add(new_out.name) + value_table[new_out.name] = new_out + to_bind[orig_out.name] = new_out.name + to_bind_rev[new_out.name] = orig_out.name + else: + inputs = {} + for i in range(len(op.input_names)): + inputs[op.input_names[i]] = bind_name( + op.input(op.input_names[i]), to_bind + ) + + outputs = {} + for i in range(len(op.output_names)): + outputs[op.output_names[i]] = op.output(op.output_names[i]) + + attrs = {} + for name in sorted(op.attr_names): + attrs[name] = op.attr(name) + from paddle.fluid.dygraph.base import param_guard + + new_op_desc = block.desc.append_op() + with param_guard(inputs), param_guard(outputs): + op = Operator( + block=block, + desc=new_op_desc, + type=op.type, + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) + block.ops.append(op) + + # Step3: Do some post-processing work + for op_idx in reversed(ops_to_remove): + block.desc._remove_op(op_idx, op_idx + 1) + del block.ops[op_idx] + block._sync_with_cpp() + + for op_idx in range(len(block.ops)): + op = block.ops[op_idx] + for in_name in op.input_arg_names: + if in_name in to_bind_rev: + op._rename_input(in_name, to_bind_rev[in_name]) + + for out_name in op.output_arg_names: + if out_name in to_bind_rev: + op._rename_output(out_name, to_bind_rev[out_name]) + + for var_name in sorted(vars_to_remove): + assert ( + var_name in to_bind_rev + ), 'var_name "{}" is not in to_bind_rev.'.format(var_name) + if var_name != to_bind_rev[var_name]: + block.desc._remove_var(var_name.encode()) + del block.vars[var_name] + block._sync_with_cpp() + + @framework.static_only def orig2prim(block=None): """ diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py index b5f93ebe97..211851160b 100644 --- a/python/paddle/incubate/autograd/utils.py +++ b/python/paddle/incubate/autograd/utils.py @@ -16,6 +16,8 @@ import typing import paddle from paddle.fluid import framework as framework +from .phi_ops_map import op_info, op_map + class PrimOption: def __init__(self): @@ -148,6 +150,64 @@ def get_input_var_list(op): ] +def _solve_arg(item): + if "=" not in item: + res = item + else: + res = item.split('=')[0] + [arg_type, arg_name] = res.strip().split() + return arg_type.strip(), arg_name.strip() + + +def _get_args_values(op, phi_name): + "get attrs' values for api args' values" + args = op_info[phi_name] + args_list = args["args"].split(",") + inputs = [] + attrs = [] + for item in args_list: + arg_type, arg_name = _solve_arg(item) + op_content = op_map[op.type] + if arg_type in ("Tensor", "Tensor[]"): + if ( + "inputs" in op_content.keys() + and arg_name in op_content["inputs"].keys() + ): + inputs.append(op_content["inputs"][arg_name]) + else: + inputs.append(arg_name) + else: + op_content = op_map[op.type] + if ( + "attrs" in op_content.keys() + and arg_name in op_content["attrs"].keys() + ): + attrs.append(op.attr(op_content["attrs"][arg_name])) + attrs.append(op.attr(arg_name)) + + return inputs, attrs + + +def prepare_python_api_arguments(op): + """ + Generate all args inputs of composite op. Because inputs of composite op is + the same as phi op desribed in ops.yaml. So we need to map origin op to phi op + and then push input data and attrs of origin op to correspondng phi op. + """ + if op.input_names is None: + return [] + else: + if op.type in op_map: + phi_name = op_map[op.type]["phi_name"] + else: + phi_name = op.type + inputs, attrs = _get_args_values(op, phi_name) + res = [get_var_block(op.block, op.input(n)) for n in inputs] + if attrs: + res.extend(attrs) + return res + + def get_output_var_list(op): if op.output_names is None: return [] -- GitLab