未验证 提交 4d5265b8 编写于 作者: C Charles-hit 提交者: GitHub

[static code gen]Add phi and fluid info in static code gen (#49763)

* polish static grad op maker gen

* fix some bugs

* fix static code gen

* solve conflict

* modify composite grad maker name
上级 70378584
......@@ -63,7 +63,7 @@ using OpRegistryClasses = std::tuple< // NOLINT
TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>, // NOLINT
TypePair<GradOpDescMakerBase, kGradOpDescMaker>, // NOLINT
TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>, // NOLINT
TypePair<prim::GradCompositeOpMakerBase, kGradCompOpDescMaker>, // NOLINT
TypePair<prim::CompositeGradOpMakerBase, kGradCompOpDescMaker>, // NOLINT
TypePair<VarTypeInference, kVarTypeInference>, // NOLINT
TypePair<InferShapeBase, kShapeInference>, // NOLINT
TypePair<InplaceOpInference, kInplaceOpInference>, // NOLINT
......@@ -262,7 +262,7 @@ struct OpInfoFiller<T, kGradCompOpDescMaker> {
info->grad_comp_op_maker_,
nullptr,
platform::errors::AlreadyExists(
"GradCompositeOpMakerBase of %s has been registered", op_type));
"CompositeGradOpMakerBase of %s has been registered", op_type));
info->grad_comp_op_maker_ =
[](const OpDesc& fwd_op,
......
......@@ -52,8 +52,8 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker {
};
class ElementwiseAddGradCompositeOpMaker
: public prim::GradCompositeOpMakerBase {
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
: public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
......
......@@ -68,8 +68,8 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
};
class ElementwiseDivGradCompositeOpMaker
: public prim::GradCompositeOpMakerBase {
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
: public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
......
......@@ -55,8 +55,8 @@ class ElementwiseSubOpMaker : public ElementwiseOpMaker {
};
class ElementwiseSubGradCompositeOpMaker
: public prim::GradCompositeOpMakerBase {
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
: public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
......
......@@ -14,6 +14,7 @@
import itertools
import re
from typing import Dict, List
from type_mapping import (
attr_types_map,
......@@ -137,17 +138,23 @@ def to_composite_grad_opmaker_name(backward_op_name):
for i in range(len(words)):
words[i] = words[i].strip()
words[i] = words[i].capitalize()
composite_grad_opmaker_name = words[0] + "Composite"
composite_grad_opmaker_name += "".join(word for word in words[1:])
composite_grad_opmaker_name += "OpMaker"
composite_grad_opmaker_name = "".join(word for word in words)
composite_grad_opmaker_name += "CompositeGradOpMaker"
return composite_grad_opmaker_name
def to_variable_names(dict_list: List[Dict], key: str) -> List[str]:
    """Collect the value stored under ``key`` from each dict in ``dict_list``.

    Args:
        dict_list: sequence of dicts, each expected to contain ``key``.
        key: the dictionary key to read from every element.

    Returns:
        The looked-up values, in the same order as ``dict_list``.

    Raises:
        KeyError: if any element does not contain ``key``.
    """
    return [entry[key] for entry in dict_list]
def cartesian_prod_attrs(attrs):
items = []
for attr in attrs:
type_name = attr["typename"]
name = attr["name"]
name = attr["fluid_name"]
if type_name == "Scalar":
items.append((name, to_scalar_tensor_name(attr)))
elif type_name == "IntArray":
......@@ -176,11 +183,15 @@ def cartesian_prod_attrs(attrs):
def cartesian_prod_mapping(op):
kernels = op["kernel"]["func"]
inputs = [
x["name"] for x in op["inputs"] if x["name"] in op["kernel"]["param"]
x["fluid_name"]
for x in op["inputs"]
if x["fluid_name"] in op["kernel"]["param"]
]
inputs = [to_opmaker_name_cstr(input) for input in inputs]
attrs = cartesian_prod_attrs(op["attrs"])
outputs = [to_opmaker_name_cstr(output["name"]) for output in op["outputs"]]
outputs = [
to_opmaker_name_cstr(output["fluid_name"]) for output in op["outputs"]
]
def vec(items):
return "{" + ', '.join(items) + "}"
......
......@@ -28,6 +28,7 @@ from filters import (
to_opmaker_name_cstr,
to_pascal_case,
to_scalar_tensor_name,
to_variable_names,
)
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from parse_utils import to_named_dict
......@@ -60,6 +61,7 @@ env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
env.filters["to_variable_names"] = to_variable_names
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["vec"] = is_vec
......@@ -157,29 +159,26 @@ def process_int_array(op_item, int_array_configs):
]
def parse_composite_info(ops, backward_ops, backward_op_dict):
    """Record phi-side names on ops and on composite backward ops, in place.

    For every op in ``ops`` and ``backward_ops`` that has a "backward" key,
    its value is copied to "phi_backward". For every backward op in
    ``backward_op_dict`` that has a "composite" entry, the names of its
    inputs, attrs and outputs are stored in that entry under "phi_inputs",
    "phi_attrs" and "phi_outputs".
    """
    # Copy the backward op name so it stays available under "phi_backward".
    for op_item in (*ops, *backward_ops):
        if "backward" in op_item:
            op_item["phi_backward"] = op_item["backward"]

    # (key written into the composite entry, section of the op it is read from)
    sections = (
        ("phi_inputs", "inputs"),
        ("phi_attrs", "attrs"),
        ("phi_outputs", "outputs"),
    )
    for op_dict in backward_op_dict.values():
        if "composite" in op_dict:
            composite = op_dict["composite"]
            for phi_key, section in sections:
                composite[phi_key] = [arg['name'] for arg in op_dict[section]]
# replace name of op and params for OpMaker
def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
def add_composite_info(ops, backward_ops, backward_op_dict):
    """Set "backward_composite" on every forward and backward op, in place.

    Args:
        ops: list of forward op dicts.
        backward_ops: list of backward op dicts.
        backward_op_dict: mapping from backward op name to its op dict.

    For each op, "backward_composite" is set to the op's "backward" name when
    that backward op is present in ``backward_op_dict`` and has a "composite"
    entry; otherwise it is set to None. An op without a "backward" key is
    treated as having no backward op instead of raising KeyError.
    """
    # add backward composite name in forward
    for op in ops + backward_ops:
        backward_name = op.get("backward")
        backward_op = backward_op_dict.get(backward_name)
        has_composite = backward_op is not None and "composite" in backward_op
        op["backward_composite"] = backward_name if has_composite else None
# add fluid name in ops and backward ops info
def add_fluid_name(dict_list):
    """Give every dict in ``dict_list`` a "fluid_name" equal to its "name".

    The dicts are modified in place; nothing is returned.
    """
    for entry in dict_list:
        entry.update(fluid_name=entry["name"])
# add fluid name of op and params for OpMaker
def add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
def get_phi_and_fluid_op_name(op_item):
names = op_item.split('(')
if len(names) == 1:
......@@ -187,12 +186,14 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
else:
return names[0].strip(), names[1].split(')')[0].strip()
def update_op_param_name(op_args, args_alias_map):
def add_op_param_name(op_args, args_alias_map):
for item in op_args:
if item['name'] in args_alias_map:
item['name'] = args_alias_map[item['name']]
item['fluid_name'] = args_alias_map[item['name']]
else:
item['fluid_name'] = item['name']
def update_grad_args_name(op_args, args_alias_map):
def add_grad_args_name(op_args, args_alias_map):
for item in op_args:
if (
item['name'].endswith('_grad')
......@@ -201,38 +202,12 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
args_alias_map[item['name']] = (
args_alias_map[item['name'][:-5]] + '_grad'
)
item['name'] = args_alias_map[item['name'][:-5]] + '_grad'
def add_fluid_info_in_composite(composite_map, args_alias_map):
fluid_input_list = []
fluid_attr_list = []
fluid_output_list = []
# add fluid op inputs
for input in composite_map["phi_inputs"]:
if input in args_alias_map:
fluid_input_list.append(args_alias_map[input])
else:
fluid_input_list.append(input)
# add fluid op attrs
for attr in composite_map["phi_attrs"]:
if attr in args_alias_map:
fluid_attr_list.append(args_alias_map[attr])
else:
fluid_attr_list.append(attr)
# add fluid op outputs
for output in composite_map["phi_outputs"]:
if output in args_alias_map:
fluid_output_list.append(args_alias_map[output])
else:
fluid_output_list.append(output)
composite_map.update(
{
"fluid_inputs": fluid_input_list,
"fluid_attrs": fluid_attr_list,
"fluid_outputs": fluid_output_list,
}
)
item['fluid_name'] = args_alias_map[item['name'][:-5]] + '_grad'
elif (
item['name'].endswith('_grad')
and item['name'][:-5] not in args_alias_map
):
item['fluid_name'] = item['name']
def get_param_list_alias(param_list, args_map):
return [
......@@ -287,15 +262,15 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
op_item['kernel']['layout']['candidates'], args_name_map
)
def update_grad_op_compat_name(grad_op_item, args_name_map):
update_op_param_name(grad_op_item['inputs'], args_name_map)
update_op_param_name(grad_op_item['outputs'], args_name_map)
update_op_param_name(grad_op_item['attrs'], args_name_map)
update_op_param_name(grad_op_item['forward']['inputs'], args_name_map)
update_op_param_name(grad_op_item['forward']['outputs'], args_name_map)
update_op_param_name(grad_op_item['forward']['attrs'], args_name_map)
update_grad_args_name(grad_op_item['inputs'], args_map)
update_grad_args_name(grad_op_item['outputs'], args_map)
def add_grad_op_compat_name(grad_op_item, args_name_map):
add_op_param_name(grad_op_item['inputs'], args_name_map)
add_op_param_name(grad_op_item['outputs'], args_name_map)
add_op_param_name(grad_op_item['attrs'], args_name_map)
add_op_param_name(grad_op_item['forward']['inputs'], args_name_map)
add_op_param_name(grad_op_item['forward']['outputs'], args_name_map)
add_op_param_name(grad_op_item['forward']['attrs'], args_name_map)
add_grad_args_name(grad_op_item['inputs'], args_map)
add_grad_args_name(grad_op_item['outputs'], args_map)
for op_args in op_fluid_map_list:
new_op_name, op_name = get_phi_and_fluid_op_name(op_args['op'])
......@@ -340,39 +315,32 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
int_array_configs[
op_args[key][args_item['name']]
] = int_array_configs[args_item['name']]
args_item['name'] = op_args[key][args_item['name']]
if has_backward:
for args_item in backward_op_item['forward'][key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
forward_op_item["attr_dict"] = to_named_dict(forward_op_item["attrs"])
args_item['fluid_name'] = op_args[key][
args_item['name']
]
update_common_params_name(
forward_op_item, args_map, scalar_configs, int_array_configs
)
if has_backward:
update_grad_op_compat_name(backward_op_item, args_map)
# update fluid info in backward
add_grad_op_compat_name(backward_op_item, args_map)
update_common_params_name(
backward_op_item, args_map, scalar_configs, int_array_configs
)
backward_op_item["attr_dict"] = to_named_dict(
backward_op_item["attrs"]
)
if 'backward' not in op_args:
continue
backward_op_list = op_args['backward'].split(',')
# add fluid args name in composite map
for backward_op in backward_op_list:
if (
"composite"
in backward_op_dict[backward_op.split('(')[0].strip()]
):
add_fluid_info_in_composite(
backward_op_dict[backward_op]["composite"], args_map
)
_, bw_op_name = get_phi_and_fluid_op_name(backward_op_list[0])
phi_bw_op_name, bw_op_name = get_phi_and_fluid_op_name(
backward_op_list[0]
)
if (
forward_op_item["backward_composite"] is not None
and phi_bw_op_name != bw_op_name
):
forward_op_item["backward_composite"] = bw_op_name
forward_op_item['backward'] = bw_op_name
backward_op_item['op_name'] = bw_op_name
......@@ -383,18 +351,20 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
double_grad_op_name,
) = get_phi_and_fluid_op_name(backward_op_list[1])
double_grad_item = backward_op_dict[phi_double_grad_op_name]
if (
backward_op_item["backward_composite"] is not None
and phi_double_grad_op_name != double_grad_op_name
):
backward_op_item["backward_composite"] = double_grad_op_name
backward_op_item['backward'] = double_grad_op_name
double_grad_item['op_name'] = double_grad_op_name
update_grad_op_compat_name(double_grad_item, args_map)
add_grad_op_compat_name(double_grad_item, args_map)
update_common_params_name(
double_grad_item,
args_map,
scalar_configs,
int_array_configs,
)
double_grad_item["attr_dict"] = to_named_dict(
double_grad_item["attrs"]
)
# for triple grad
if len(backward_op_list) > 2:
......@@ -403,18 +373,22 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
triple_grad_op_name,
) = get_phi_and_fluid_op_name(backward_op_list[2])
triple_grad_item = backward_op_dict[phi_triple_grad_op_name]
if (
double_grad_item["backward_composite"] is not None
and phi_triple_grad_op_name != triple_grad_op_name
):
double_grad_item[
"backward_composite"
] = triple_grad_op_name
double_grad_item['backward'] = triple_grad_op_name
triple_grad_item['op_name'] = triple_grad_op_name
update_grad_op_compat_name(triple_grad_item, args_map)
add_grad_op_compat_name(triple_grad_item, args_map)
update_common_params_name(
triple_grad_item,
args_map,
scalar_configs,
int_array_configs,
)
triple_grad_item["attr_dict"] = to_named_dict(
triple_grad_item["attrs"]
)
def process_invoke_op(forward_op_dict, backward_op_dict):
......@@ -432,20 +406,28 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
for input_item in reuse_op['inputs']:
bw_op['invoke']['inputs'].append(
{
'fluid_name': input_item['fluid_name'],
'name': input_item['name'],
'value': args_list[args_index],
}
)
args_index = args_index + 1
bw_fluid_attrs_set = [
item['fluid_name'] for item in bw_op['attrs']
]
for attr in reuse_op['attrs']:
if args_index < len(args_list):
attr_value = (
f"this->GetAttr(\"{args_list[args_index]}\")"
if args_list[args_index] in bw_op['attr_dict']
if args_list[args_index] in bw_fluid_attrs_set
else args_list[args_index]
)
bw_op['invoke']['attrs'].append(
{'name': attr['name'], 'value': attr_value}
{
'name': attr['name'],
'fluid_name': attr['fluid_name'],
'value': attr_value,
}
)
args_index = args_index + 1
else:
......@@ -454,7 +436,8 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
bw_op['invoke']['outputs'].append(
{
'name': output_item['name'],
'value': bw_op['outputs'][idx]['name'],
'fluid_name': output_item['fluid_name'],
'value': bw_op['outputs'][idx]['fluid_name'],
}
)
......@@ -507,17 +490,26 @@ def main(
for op in ops:
op['op_name'] = op['name']
add_fluid_name(op['inputs'])
add_fluid_name(op['attrs'])
add_fluid_name(op['outputs'])
for bw_op in backward_ops:
bw_op['op_name'] = bw_op['name']
add_fluid_name(bw_op['inputs'])
add_fluid_name(bw_op['attrs'])
add_fluid_name(bw_op['outputs'])
add_fluid_name(bw_op['forward']['inputs'])
add_fluid_name(bw_op['forward']['attrs'])
add_fluid_name(bw_op['forward']['outputs'])
for bw_output in bw_op['outputs']:
bw_output['drop_empty_grad'] = True
# set drop_empty_grad of each bw_op according to op_compat.yaml
parse_drop_empty_grad(op_fluid_map_list, backward_op_dict)
parse_composite_info(ops, backward_ops, backward_op_dict)
add_composite_info(ops, backward_ops, backward_op_dict)
replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict)
add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict)
# prepare for invoke case
process_invoke_op(forward_op_dict, backward_op_dict)
......@@ -545,7 +537,6 @@ def main(
ops=ops,
backward_ops=backward_ops,
op_dict=op_dict,
composite_gen_flag=True,
)
f.write(msg)
ks_template = env.get_template('ks.c.j2')
......
......@@ -28,12 +28,14 @@ from filters import (
to_opmaker_name_cstr,
to_pascal_case,
to_scalar_tensor_name,
to_variable_names,
)
from generate_op import process_invoke_op
from generate_op import add_fluid_name, process_invoke_op
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from parse_utils import to_named_dict
from tests import (
is_base_op,
is_composite_op,
is_initializer_list,
is_scalar,
is_vec,
......@@ -60,7 +62,9 @@ env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
env.filters["to_variable_names"] = to_variable_names
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
......@@ -96,9 +100,18 @@ def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
op['name'] = op['op_name']
if op["backward"] is not None:
op["backward"] = SPARSE_OP_PREFIX + op["backward"]
add_fluid_name(op["inputs"])
add_fluid_name(op["attrs"])
add_fluid_name(op["outputs"])
for bw_op in backward_ops:
bw_op['op_name'] = SPARSE_OP_PREFIX + bw_op['name']
bw_op['name'] = bw_op['op_name']
add_fluid_name(bw_op["inputs"])
add_fluid_name(bw_op["attrs"])
add_fluid_name(bw_op["outputs"])
add_fluid_name(bw_op["forward"]["inputs"])
add_fluid_name(bw_op["forward"]["attrs"])
add_fluid_name(bw_op["forward"]["outputs"])
if 'invoke' in bw_op:
bw_op['invoke']['args'] = [
param.strip() for param in bw_op['invoke']['args'].split(',')
......@@ -139,7 +152,6 @@ def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
ops=ops,
backward_ops=backward_ops,
op_dict=op_dict,
composite_gen_flag=False,
)
f.write(msg)
......
......@@ -28,12 +28,14 @@ from filters import (
to_opmaker_name_cstr,
to_pascal_case,
to_scalar_tensor_name,
to_variable_names,
)
from generate_op import replace_compat_name
from generate_op import add_compat_name, add_fluid_name
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from parse_utils import to_named_dict
from tests import (
is_base_op,
is_composite_op,
is_initializer_list,
is_scalar,
is_vec,
......@@ -60,7 +62,9 @@ env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
env.filters["to_variable_names"] = to_variable_names
env.tests["base_op"] = is_base_op
env.tests["composite_op"] = is_composite_op
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
......@@ -100,8 +104,11 @@ def main(
for op in ops:
op['op_name'] = op['name']
add_fluid_name(op["inputs"])
add_fluid_name(op["attrs"])
add_fluid_name(op["outputs"])
replace_compat_name(op_op_map, forward_op_dict, {})
add_compat_name(op_op_map, forward_op_dict, {})
if len(ops) == 0:
if os.path.isfile(output_op_path):
......@@ -116,7 +123,6 @@ def main(
ops=ops,
backward_ops=[],
op_dict=forward_op_dict,
composite_gen_flag=False,
)
f.write(msg)
......
......@@ -294,14 +294,13 @@ def parse_composite(
composite_config: str,
) -> Dict[str, Any]:
# composite_config: func(args1, args2,.....)
fname = r'(.*?)'
wspace = r'\s*'
fargs = r'(.*?)'
pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)'
m = re.search(pattern, composite_config)
func_name = m.group(1)
func_args = m.group(2)
result = re.search(
r"(?P<func_name>[a-z][a-z0-9_]+)\s*\((?P<func_args>[^\)]+)\)",
composite_config,
)
func_name = result.group("func_name")
func_args = result.group("func_args")
composite_dict = {}
composite_dict["func_name"] = func_name
......
......@@ -39,11 +39,9 @@ using paddle::framework::GradVarName;
{% else %}
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %}
{% if composite_gen_flag == True %}
{% if op is composite_op %}
{% if op is composite_op %}
{{composite_grad_op_maker(op_dict[op["name"]])}}
{% endif %}
{% endif %}
{% endif %}
{% endfor %}
} // namespace operators
} // namespace paddle
......@@ -51,7 +49,7 @@ using paddle::framework::GradVarName;
namespace ops = paddle::operators;
{% for op in ops + backward_ops %}
{% if op is base_op %}
{{register_op_with_components(op, op_dict)}}
{{register_op_with_components(op)}}
{{register_op_version(op)}}
{% endif %}
{% endfor %}
......@@ -135,9 +135,9 @@ struct TestBaseProgram {
int idx_{0};
};
class TestGradCompositeGradMaker : public GradCompositeOpMakerBase {
class TestGradCompositeGradMaker : public CompositeGradOpMakerBase {
public:
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {}
};
......
......@@ -41,9 +41,9 @@ namespace prim {
argument DropEmptyIG in the derived classes.
*/
class GradCompositeOpMakerBase {
class CompositeGradOpMakerBase {
public:
explicit GradCompositeOpMakerBase(
explicit CompositeGradOpMakerBase(
const framework::OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var,
......@@ -61,7 +61,7 @@ class GradCompositeOpMakerBase {
acting_program_.MutableBlock(0));
}
virtual ~GradCompositeOpMakerBase() = default;
virtual ~CompositeGradOpMakerBase() = default;
virtual std::vector<std::unique_ptr<framework::OpDesc>> operator()() {
this->Apply();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册