未验证 提交 6ed8221a 编写于 作者: C cyber-pioneer 提交者: GitHub

New feature: add register composite rule of ops (#49605)

上级 a48b8e2c
......@@ -1288,6 +1288,8 @@
- op : softmax
backward : softmax_grad
inputs :
x : X
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
......
......@@ -53,6 +53,35 @@ set(FLUID_CORE_DEPS ${FLUID_CORE})
add_custom_target(copy_libpaddle ALL DEPENDS ${FLUID_CORE_DEPS})
# Standard op(phi op) description is defined in ops.yaml and legacy_ops.yaml.
# When users define composite rules of some nonbasic op, as for defination of args,
# they are supposed to refer to standard op description. However, there exists
# some gap of description between current op and standard ones. So special dictionary
# is needed to record such gap for execution of composite rules.
# Todo: this custom_target will be moved to other place.
set(ops_yaml_path "${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/ops.yaml")
set(ops_legacy_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/legacy_ops.yaml")
set(ops_compat_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml")
set(phi_ops_map_path
"${PADDLE_SOURCE_DIR}/python/paddle/incubate/autograd/phi_ops_map.py")
add_custom_target(
op_map_codegen ALL
COMMAND
"${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/python/paddle/incubate/autograd/generate_op_map.py"
"--ops_yaml_path=${ops_yaml_path}"
"--ops_legacy_yaml_path=${ops_legacy_yaml_path}"
"--ops_compat_yaml_path=${ops_compat_yaml_path}"
"--phi_ops_map_path=${phi_ops_map_path}"
VERBATIM)
# NOTE(zhiqiu): WHY?
# In `setup.py.in`, some dynamic libraries (eg, libxpuapi.so) are modified using
# patchelf. In rare cases, if the a linker is linking that dynamic library for
......
......@@ -837,6 +837,7 @@ add_subdirectory(sequence)
add_subdirectory(dygraph_to_static)
add_subdirectory(rnn)
add_subdirectory(autograd)
add_subdirectory(composite_ops)
add_subdirectory(distribution)
add_subdirectory(prim)
......
file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
file(
GLOB TEST_OPS_GRAD
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*_grad.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
string(REPLACE ".py" "" TEST_OPS_GRAD "${TEST_OPS_GRAD}")
if(WIN32 OR APPLE)
# TODO: Fix these unittests failed on Windows and MAC.
list(REMOVE_ITEM TEST_OPS ${TEST_OPS_GRAD})
endif()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
def generate_data(shape, dtype="float32"):
np_data = np.random.random(shape).astype(dtype)
return np_data
class Attr:
def __init__(self) -> None:
self.dtype = None
self.axis = -1
self.shape = None
def set_dtype(self, dtype) -> None:
self.dtype = dtype
return
def set_axis(self, axis) -> None:
self.axis = axis
return
def set_shape(self, shape) -> None:
self.shape = shape
return
def get_rtol(self, flag):
rtol = TOLERANCE[self.dtype][flag].get("rtol")
return rtol
def get_atol(self, flag):
atol = TOLERANCE[self.dtype][flag].get("atol")
return atol
attrs = Attr()
def fn(x):
y = paddle.tan(x)
return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype)
def expect_forward(inputs):
return fn(inputs)
class TestCompositeSoftmax(unittest.TestCase):
def setUp(self):
self.dtypes = ["float32", "float64"]
self.shapes = [[2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]
def cal_composite(self, inputs):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
y = fn(x)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
paddle.disable_static()
return res
def compare_forward(self):
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)
expect = expect_forward(tensor_data).numpy()
actual = self.cal_composite(np_data)[0]
assert expect.dtype == actual.dtype
assert np.allclose(
expect,
actual,
rtol=attrs.get_rtol("forward"),
atol=attrs.get_atol("forward"),
)
def test_forward(self):
for i in self.axes:
for j in self.dtypes:
for t in self.shapes:
attrs.set_axis(i)
attrs.set_dtype(j)
attrs.set_shape(t)
self.compare_forward()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
def generate_data(shape, dtype="float32"):
np_data = np.random.random(shape).astype(dtype)
return np_data
class Attr:
def __init__(self) -> None:
self.dtype = None
self.axis = -1
self.shape = None
def set_dtype(self, dtype) -> None:
self.dtype = dtype
return
def set_axis(self, axis) -> None:
self.axis = axis
return
def set_shape(self, shape) -> None:
self.shape = shape
return
def get_rtol(self, flag):
rtol = TOLERANCE[self.dtype][flag].get("rtol")
return rtol
def get_atol(self, flag):
atol = TOLERANCE[self.dtype][flag].get("atol")
return atol
attrs = Attr()
def fn(x):
y = paddle.tan(x)
return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype)
def expect_grad(inputs):
inputs.stop_gradient = False
res = fn(inputs)
gradients = paddle.grad(res, inputs)
return gradients
class TestCompositeSoftmax(unittest.TestCase):
def setUp(self):
self.dtypes = ["float32", "float64"]
self.shapes = [[2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]
def cal_composite_grad(self, inputs):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
z = paddle.static.gradients([y], x)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static()
return res
def compare_backward(self):
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)
expect = expect_grad(tensor_data)[0].numpy()
actual = self.cal_composite_grad(np_data)[0]
assert expect.dtype == actual.dtype
assert np.allclose(
expect,
actual,
rtol=attrs.get_rtol("backward"),
atol=attrs.get_atol("backward"),
)
def test_backward(self):
for i in self.axes:
for j in self.dtypes:
for t in self.shapes:
attrs.set_axis(i)
attrs.set_dtype(j)
attrs.set_shape(t)
self.compare_backward()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TOLERANCE = {
"float32": {
"forward": {"rtol": 1e-6, "atol": 1e-6},
"backward": {"rtol": 1e-6, "atol": 1e-6},
},
"float64": {
"forward": {"rtol": 1e-7, "atol": 1e-7},
"backward": {"rtol": 1e-7, "atol": 1e-7},
},
}
# this file is generated during build system generation
phi_ops_map.py
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .functional import Hessian, Jacobian, jvp, vjp
from .primapi import forward_grad, grad
from .primapi import forward_grad, grad, to_prim
from .primx import prim2orig
from .utils import disable_prim, enable_prim, prim_enabled
......@@ -25,4 +25,5 @@ __all__ = [ # noqa
'disable_prim',
'forward_grad',
'grad',
'to_prim',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file contains composite rules of nonbasic operations. There are some notes:
# 1. When define composite rule of some op, you can only use primitive ops defined in primitives.py.
# 2. The name and args of target op must be corresponding with standard description of op in
# ops.yaml or legacy_ops.yaml.
from .primitives import * # noqa: F403
from .primreg import REGISTER_COMPOSITE, lookup_composite
def _composite(op, *args):
_lowerrule = lookup_composite(op.type)
return _lowerrule(op, *args)
@REGISTER_COMPOSITE('softmax')
def softmax_composite(x, axis):
"""define composite rule of op softmax"""
molecular = exp(x)
denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
res = divide(molecular, denominator)
return res
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Notice: This file will be automatically executed during building of whole paddle project.
# You can also run this file separately if you want to preview generated file without building.
import argparse
import json
import re
import yaml
def ParseArguments():
parser = argparse.ArgumentParser(
description='prim ops Code Generator Args Parser'
)
parser.add_argument('--ops_yaml_path', type=str, help="path to ops.yaml")
parser.add_argument(
'--ops_legacy_yaml_path', type=str, help="path to legacy_ops.yaml"
)
parser.add_argument(
'--ops_compat_yaml_path', type=str, help="path to op_compat.yaml"
)
parser.add_argument(
'--phi_ops_map_path',
type=str,
default="./phi_ops_map.py",
help='path to target phi_ops_map.py',
)
args = parser.parse_args()
return args
def _trans_value_type(item):
for key in item.keys():
for subkey in item[key]:
value = str(item[key][subkey])
item[key][subkey] = value
def generate_code(
ops_yaml_path, ops_legacy_yaml_path, ops_compat_yaml_path, phi_ops_map_path
):
"""
Generate dictiorary and save to file phi_ops_map.py. The target file records gap
of description between current op and standard ones.
"""
for op_path in [ops_yaml_path, ops_legacy_yaml_path]:
pattern = re.compile(r'[(](.*)[)]', re.S)
with open(op_path, "rt") as f:
ops = yaml.safe_load(f)
dct = {}
for item in ops:
key = item['op']
if key in dct:
raise ValueError(f"There already exists op {key}")
dct[key] = {
"args": re.findall(pattern, item["args"])[0],
"output": item["output"],
}
with open(ops_compat_yaml_path, "rt") as f:
ops_compat = yaml.safe_load(f)
map_dct = {}
for item in ops_compat:
key = item['op']
if key.endswith(")"):
tmp = re.match("(.*)\\((.*)\\)", key.replace(" ", ""))
phi_name, op_name = tmp.group(1), tmp.group(2)
map_dct[op_name] = {"phi_name": phi_name}
else:
op_name = key
map_dct[op_name] = {"phi_name": op_name}
for element in ["inputs", "attrs"]:
if element in item.keys():
map_dct[op_name][element] = item[element]
for element in ["scalar", "int_array"]:
if element in item.keys():
_trans_value_type(item[element])
map_dct[op_name][element] = item[element]
with open(phi_ops_map_path, "w") as f:
f.write("op_map = ")
json.dump(map_dct, f, indent=4)
f.write('\n')
f.write("op_info = ")
json.dump(dct, f, indent=4)
f.write('\n')
if __name__ == "__main__":
args = ParseArguments()
ops_yaml_path = args.ops_yaml_path
ops_legacy_yaml_path = args.ops_legacy_yaml_path
ops_compat_yaml_path = args.ops_compat_yaml_path
phi_ops_map_path = args.phi_ops_map_path
generate_code(
ops_yaml_path,
ops_legacy_yaml_path,
ops_compat_yaml_path,
phi_ops_map_path,
)
......@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import typing
import paddle
from paddle.fluid import backward, framework
from paddle.incubate.autograd import primx, utils
......@@ -211,3 +213,18 @@ def grad(outputs, inputs, grad_outputs=None):
ad.erase_dots(xs_dot)
return xs_bar[0] if isinstance(inputs, framework.Variable) else xs_bar
@framework.static_only
def to_prim(blocks):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
if isinstance(blocks, paddle.fluid.framework.Block):
logging.info("Atomize composite op to primitive ops begin.")
primx._lower_composite(blocks)
return
elif isinstance(blocks, typing.Sequence):
for item in blocks:
to_prim(item)
return
else:
raise TypeError
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.layers.tensor import cast # noqa: F401
from paddle.tensor import abs # noqa: F401
from paddle.tensor import acos # noqa: F401
from paddle.tensor import acosh # noqa: F401
from paddle.tensor import add # noqa: F401
from paddle.tensor import asin # noqa: F401
from paddle.tensor import asinh # noqa: F401
from paddle.tensor import atan # noqa: F401
from paddle.tensor import atanh # noqa: F401
from paddle.tensor import broadcast_shape # noqa: F401
from paddle.tensor import broadcast_to # noqa: F401
from paddle.tensor import cos # noqa: F401
from paddle.tensor import cosh # noqa: F401
from paddle.tensor import cumprod # noqa: F401
from paddle.tensor import cumsum # noqa: F401
from paddle.tensor import digamma # noqa: F401
from paddle.tensor import divide # noqa: F401
from paddle.tensor import erf # noqa: F401
from paddle.tensor import erfinv # noqa: F401
from paddle.tensor import exp # noqa: F401
from paddle.tensor import expm1 # noqa: F401
from paddle.tensor import lgamma # noqa: F401
from paddle.tensor import log # noqa: F401
from paddle.tensor import log1p # noqa: F401
from paddle.tensor import logcumsumexp # noqa: F401
from paddle.tensor import logit # noqa: F401
from paddle.tensor import logsumexp # noqa: F401
from paddle.tensor import multiply # noqa: F401
from paddle.tensor import pow # noqa: F401
from paddle.tensor import prod # noqa: F401
from paddle.tensor import sign # noqa: F401
from paddle.tensor import sin # noqa: F401
from paddle.tensor import sinh # noqa: F401
from paddle.tensor import subtract # noqa: F401
from paddle.tensor import sum # noqa: F401
from paddle.tensor import tan # noqa: F401
from paddle.tensor import tanh # noqa: F401
math_op = [
'add',
'subtract',
'multiply',
'divide',
'abs',
'pow',
'sign',
'sum',
'prod',
'cumsum',
'cumprod',
'digamma',
'lgamma',
'erf',
'erfinv',
'exp',
'expm1',
'log',
'log1p',
'logsumexp',
'logcumsumexp',
'logit',
]
trigonometric_op = [
'sin',
'cos',
'tan',
'sinh',
'cosh',
'tanh',
'asin',
'acos',
'atan',
'asinh',
'acosh',
'atanh',
]
others = [
'cast',
'broadcast_to',
]
__all__ = []
__all__.extend(math_op)
__all__.extend(trigonometric_op)
__all__.extend(others)
__all__.sort()
......@@ -38,6 +38,7 @@ _prim2orig = Registry('prim2orig')
_primop_jvp = Registry('primop_jvp')
_primop_transpose = Registry('primop_transpose')
_primop_position_argnames = Registry('primop_position_argnames')
_composite_ops = Registry('composite')
def lookup_fn(optype):
......@@ -60,6 +61,10 @@ def lookup_transpose(optype):
return _primop_transpose.lookup(optype)
def lookup_composite(optype):
return _composite_ops.lookup(optype)
def op_position_inputs(op):
"""
Returns the position inputs of `op` as registered with REGISTER_FN.
......@@ -200,6 +205,41 @@ def REGISTER_ORIG2PRIM(op_type):
return wrapper
def REGISTER_COMPOSITE(op_type):
"""
Decorator for registering the lower function for an original op into sequence of primitive ops.
Args:
op_type(str): The op name
Returns:
wrapper: Inner wrapper function
Examples:
.. code-block:: python
@REGISTER_COMPOSITE('softmax')
def softmax_composite(x, axis):
molecular = exp(x)
denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
res = divide(molecular, denominator)
return res
"""
if not isinstance(op_type, str):
raise TypeError(f'op_type must be str, but got {type(op_type)}.')
def wrapper(f):
def _lower(op, *args, **kwargs):
assert (
op.type == op_type
), f'op.type should be equal to op_type, but op.type is {op.type} and op_type is {op_type}'
return f(*args, **kwargs)
_composite_ops.register(op_type, _lower)
return wrapper
def REGISTER_PRIM2ORIG(op_type):
"""
Decorator for registering the lower function for an primitive op into sequence of original ops.
......
......@@ -19,8 +19,10 @@ from paddle.fluid import framework as framework
from paddle.fluid.framework import Operator, default_main_program
from paddle.incubate.autograd.utils import as_tensors
from .composite_rules import _composite
from .primops import add, fill_const
from .primreg import (
lookup_composite,
lookup_orig2prim,
lookup_prim2orig,
op_position_inputs,
......@@ -32,6 +34,7 @@ from .utils import (
flatten_and_remove_none,
get_input_var_list,
get_output_var_list,
prepare_python_api_arguments,
)
......@@ -543,6 +546,121 @@ def _lower(block, reverse, blacklist):
block._sync_with_cpp()
def _lower_composite(block, blacklist=[]):
# Some functions which are only used in _lower.
def bind(args, to_bind, value_table):
for i in range(len(args)):
if isinstance(args[i], list):
bind(args[i], to_bind, value_table)
if not isinstance(args[i], paddle.fluid.framework.Variable):
continue
elif args[i] is not None and args[i].name in to_bind:
args[i] = value_table[to_bind[args[i].name]]
def bind_name(names, to_bind):
return_list = []
for name in names:
if isinstance(name, list):
return_list.append(bind_name(name, to_bind))
else:
return_list.append(to_bind[name] if name in to_bind else name)
return return_list
def expand_nested_list(xs):
return_list = []
for x in xs:
if isinstance(x, list):
return_list = return_list + expand_nested_list(x)
else:
return_list.append(x)
return return_list
# Step1: Do some preparatory work for lower
lower_fn = _composite
lookup_fn = lookup_composite
value_table = {}
to_bind = {}
to_bind_rev = {}
for var in block.desc.all_vars():
value_table[var.name()] = block.var(var.name())
ops_to_remove = []
vars_to_remove = set()
# Step2: Process all ops in the target block
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist:
input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table)
for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)),
expand_nested_list(as_tensors(lower_fn(op, *input_args))),
):
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."
vars_to_remove.add(new_out.name)
value_table[new_out.name] = new_out
to_bind[orig_out.name] = new_out.name
to_bind_rev[new_out.name] = orig_out.name
else:
inputs = {}
for i in range(len(op.input_names)):
inputs[op.input_names[i]] = bind_name(
op.input(op.input_names[i]), to_bind
)
outputs = {}
for i in range(len(op.output_names)):
outputs[op.output_names[i]] = op.output(op.output_names[i])
attrs = {}
for name in sorted(op.attr_names):
attrs[name] = op.attr(name)
from paddle.fluid.dygraph.base import param_guard
new_op_desc = block.desc.append_op()
with param_guard(inputs), param_guard(outputs):
op = Operator(
block=block,
desc=new_op_desc,
type=op.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
)
block.ops.append(op)
# Step3: Do some post-processing work
for op_idx in reversed(ops_to_remove):
block.desc._remove_op(op_idx, op_idx + 1)
del block.ops[op_idx]
block._sync_with_cpp()
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
for in_name in op.input_arg_names:
if in_name in to_bind_rev:
op._rename_input(in_name, to_bind_rev[in_name])
for out_name in op.output_arg_names:
if out_name in to_bind_rev:
op._rename_output(out_name, to_bind_rev[out_name])
for var_name in sorted(vars_to_remove):
assert (
var_name in to_bind_rev
), 'var_name "{}" is not in to_bind_rev.'.format(var_name)
if var_name != to_bind_rev[var_name]:
block.desc._remove_var(var_name.encode())
del block.vars[var_name]
block._sync_with_cpp()
@framework.static_only
def orig2prim(block=None):
"""
......
......@@ -16,6 +16,8 @@ import typing
import paddle
from paddle.fluid import framework as framework
from .phi_ops_map import op_info, op_map
class PrimOption:
def __init__(self):
......@@ -148,6 +150,64 @@ def get_input_var_list(op):
]
def _solve_arg(item):
if "=" not in item:
res = item
else:
res = item.split('=')[0]
[arg_type, arg_name] = res.strip().split()
return arg_type.strip(), arg_name.strip()
def _get_args_values(op, phi_name):
"get attrs' values for api args' values"
args = op_info[phi_name]
args_list = args["args"].split(",")
inputs = []
attrs = []
for item in args_list:
arg_type, arg_name = _solve_arg(item)
op_content = op_map[op.type]
if arg_type in ("Tensor", "Tensor[]"):
if (
"inputs" in op_content.keys()
and arg_name in op_content["inputs"].keys()
):
inputs.append(op_content["inputs"][arg_name])
else:
inputs.append(arg_name)
else:
op_content = op_map[op.type]
if (
"attrs" in op_content.keys()
and arg_name in op_content["attrs"].keys()
):
attrs.append(op.attr(op_content["attrs"][arg_name]))
attrs.append(op.attr(arg_name))
return inputs, attrs
def prepare_python_api_arguments(op):
"""
Generate all args inputs of composite op. Because inputs of composite op is
the same as phi op desribed in ops.yaml. So we need to map origin op to phi op
and then push input data and attrs of origin op to correspondng phi op.
"""
if op.input_names is None:
return []
else:
if op.type in op_map:
phi_name = op_map[op.type]["phi_name"]
else:
phi_name = op.type
inputs, attrs = _get_args_values(op, phi_name)
res = [get_var_block(op.block, op.input(n)) for n in inputs]
if attrs:
res.extend(attrs)
return res
def get_output_var_list(op):
if op.output_names is None:
return []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册