Unverified commit 4d5265b8, authored by Charles-hit and committed by GitHub

[static code gen]Add phi and fluid info in static code gen (#49763)

* polish static grad op maker gen

* fix some bugs

* fix static code gen

* solve conflict

* modify composite grad maker name
Parent 70378584
@@ -63,7 +63,7 @@ using OpRegistryClasses = std::tuple<  // NOLINT
     TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>,      // NOLINT
     TypePair<GradOpDescMakerBase, kGradOpDescMaker>,                // NOLINT
     TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>,    // NOLINT
-    TypePair<prim::GradCompositeOpMakerBase, kGradCompOpDescMaker>, // NOLINT
+    TypePair<prim::CompositeGradOpMakerBase, kGradCompOpDescMaker>, // NOLINT
     TypePair<VarTypeInference, kVarTypeInference>,                  // NOLINT
     TypePair<InferShapeBase, kShapeInference>,                      // NOLINT
     TypePair<InplaceOpInference, kInplaceOpInference>,              // NOLINT
@@ -262,7 +262,7 @@ struct OpInfoFiller<T, kGradCompOpDescMaker> {
         info->grad_comp_op_maker_,
         nullptr,
         platform::errors::AlreadyExists(
-            "GradCompositeOpMakerBase of %s has been registered", op_type));
+            "CompositeGradOpMakerBase of %s has been registered", op_type));
     info->grad_comp_op_maker_ =
         [](const OpDesc& fwd_op,
......
@@ -52,8 +52,8 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker {
 };
 class ElementwiseAddGradCompositeOpMaker
-    : public prim::GradCompositeOpMakerBase {
-  using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
+    : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
  public:
   void Apply() override {
......
@@ -68,8 +68,8 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
 };
 class ElementwiseDivGradCompositeOpMaker
-    : public prim::GradCompositeOpMakerBase {
-  using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
+    : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
  public:
   void Apply() override {
......
@@ -55,8 +55,8 @@ class ElementwiseSubOpMaker : public ElementwiseOpMaker {
 };
 class ElementwiseSubGradCompositeOpMaker
-    : public prim::GradCompositeOpMakerBase {
-  using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
+    : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
  public:
   void Apply() override {
......
@@ -14,6 +14,7 @@
 import itertools
 import re
+from typing import Dict, List
 from type_mapping import (
     attr_types_map,
@@ -137,17 +138,23 @@ def to_composite_grad_opmaker_name(backward_op_name):
     for i in range(len(words)):
         words[i] = words[i].strip()
         words[i] = words[i].capitalize()
-    composite_grad_opmaker_name = words[0] + "Composite"
-    composite_grad_opmaker_name += "".join(word for word in words[1:])
-    composite_grad_opmaker_name += "OpMaker"
+    composite_grad_opmaker_name = "".join(word for word in words)
+    composite_grad_opmaker_name += "CompositeGradOpMaker"
     return composite_grad_opmaker_name
+def to_variable_names(dict_list: List[Dict], key: str) -> List[str]:
+    names = []
+    for var in dict_list:
+        names.append(var[key])
+    return names
 def cartesian_prod_attrs(attrs):
     items = []
     for attr in attrs:
         type_name = attr["typename"]
-        name = attr["name"]
+        name = attr["fluid_name"]
         if type_name == "Scalar":
             items.append((name, to_scalar_tensor_name(attr)))
         elif type_name == "IntArray":
@@ -176,11 +183,15 @@ def cartesian_prod_attrs(attrs):
 def cartesian_prod_mapping(op):
     kernels = op["kernel"]["func"]
     inputs = [
-        x["name"] for x in op["inputs"] if x["name"] in op["kernel"]["param"]
+        x["fluid_name"]
+        for x in op["inputs"]
+        if x["fluid_name"] in op["kernel"]["param"]
     ]
     inputs = [to_opmaker_name_cstr(input) for input in inputs]
     attrs = cartesian_prod_attrs(op["attrs"])
-    outputs = [to_opmaker_name_cstr(output["name"]) for output in op["outputs"]]
+    outputs = [
+        to_opmaker_name_cstr(output["fluid_name"]) for output in op["outputs"]
+    ]
     def vec(items):
         return "{" + ', '.join(items) + "}"
......
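The two filters touched above are small enough to exercise standalone. Below is a minimal sketch, assuming the earlier (unchanged) part of to_composite_grad_opmaker_name splits the backward op name on underscores; the sample names ("tanh_grad", the toy input dict) are illustrative and not taken from the real YAML files.

from typing import Dict, List


def to_composite_grad_opmaker_name(backward_op_name: str) -> str:
    # Assumed: the part of this filter not shown in the hunk splits on "_".
    words = backward_op_name.split("_")
    for i in range(len(words)):
        words[i] = words[i].strip()
        words[i] = words[i].capitalize()
    composite_grad_opmaker_name = "".join(word for word in words)
    composite_grad_opmaker_name += "CompositeGradOpMaker"
    return composite_grad_opmaker_name


def to_variable_names(dict_list: List[Dict], key: str) -> List[str]:
    names = []
    for var in dict_list:
        names.append(var[key])
    return names


print(to_composite_grad_opmaker_name("tanh_grad"))
# -> TanhGradCompositeGradOpMaker (the old rule produced TanhCompositeGradOpMaker)
print(to_variable_names([{"name": "x", "fluid_name": "X"}], "fluid_name"))
# -> ['X']

With the new naming rule, a backward op such as "tanh_grad" maps to "TanhGradCompositeGradOpMaker", which lines up with the GradCompositeOpMakerBase to CompositeGradOpMakerBase rename in the C++ sources above.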
@@ -28,6 +28,7 @@ from filters import (
     to_opmaker_name_cstr,
     to_pascal_case,
     to_scalar_tensor_name,
+    to_variable_names,
 )
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
 from parse_utils import to_named_dict
@@ -60,6 +61,7 @@ env.filters["to_input_name"] = to_input_name
 env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
 env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
 env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
+env.filters["to_variable_names"] = to_variable_names
 env.tests["base_op"] = is_base_op
 env.tests["composite_op"] = is_composite_op
 env.tests["vec"] = is_vec
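Once registered on env.filters, to_variable_names can be called from any template with a key argument. The inline template and toy op dict below are an illustrative sketch only, not an excerpt from the real op.c.j2 or ks.c.j2 templates.

from jinja2 import Environment


def to_variable_names(dict_list, key):
    return [var[key] for var in dict_list]


env = Environment()
env.filters["to_variable_names"] = to_variable_names

# Hypothetical one-line template using the filter exactly as registered above.
template = env.from_string(
    "{{ op['inputs'] | to_variable_names('fluid_name') | join(', ') }}"
)
op = {"inputs": [{"name": "x", "fluid_name": "X"}, {"name": "y", "fluid_name": "Y"}]}
print(template.render(op=op))  # -> X, Y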
...@@ -157,29 +159,26 @@ def process_int_array(op_item, int_array_configs): ...@@ -157,29 +159,26 @@ def process_int_array(op_item, int_array_configs):
] ]
def parse_composite_info(ops, backward_ops, backward_op_dict): def add_composite_info(ops, backward_ops, backward_op_dict):
for op in ops: # add backward composite name in forward
if "backward" in op: for op in ops + backward_ops:
op["phi_backward"] = op["backward"] if (
for backward_op in backward_ops: op["backward"] in backward_op_dict
if "backward" in backward_op: and "composite" in backward_op_dict[op["backward"]]
backward_op["phi_backward"] = backward_op["backward"] ):
for backward_op_name, op_dict in backward_op_dict.items(): op["backward_composite"] = op["backward"]
if "composite" not in op_dict: else:
continue op["backward_composite"] = None
op_dict["composite"]["phi_inputs"] = []
op_dict["composite"]["phi_attrs"] = []
op_dict["composite"]["phi_outputs"] = [] # add fluid name in ops and backward ops info
for input in op_dict["inputs"]: def add_fluid_name(dict_list):
op_dict["composite"]["phi_inputs"].append(input['name']) for item in dict_list:
for attr in op_dict["attrs"]: item["fluid_name"] = item["name"]
op_dict["composite"]["phi_attrs"].append(attr['name'])
for output in op_dict["outputs"]:
op_dict["composite"]["phi_outputs"].append(output['name']) # add fluid name of op and params for OpMaker
def add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
# replace name of op and params for OpMaker
def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
def get_phi_and_fluid_op_name(op_item): def get_phi_and_fluid_op_name(op_item):
names = op_item.split('(') names = op_item.split('(')
if len(names) == 1: if len(names) == 1:
...@@ -187,12 +186,14 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): ...@@ -187,12 +186,14 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
else: else:
return names[0].strip(), names[1].split(')')[0].strip() return names[0].strip(), names[1].split(')')[0].strip()
def update_op_param_name(op_args, args_alias_map): def add_op_param_name(op_args, args_alias_map):
for item in op_args: for item in op_args:
if item['name'] in args_alias_map: if item['name'] in args_alias_map:
item['name'] = args_alias_map[item['name']] item['fluid_name'] = args_alias_map[item['name']]
else:
item['fluid_name'] = item['name']
def update_grad_args_name(op_args, args_alias_map): def add_grad_args_name(op_args, args_alias_map):
for item in op_args: for item in op_args:
if ( if (
item['name'].endswith('_grad') item['name'].endswith('_grad')
...@@ -201,38 +202,12 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): ...@@ -201,38 +202,12 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
args_alias_map[item['name']] = ( args_alias_map[item['name']] = (
args_alias_map[item['name'][:-5]] + '_grad' args_alias_map[item['name'][:-5]] + '_grad'
) )
item['name'] = args_alias_map[item['name'][:-5]] + '_grad' item['fluid_name'] = args_alias_map[item['name'][:-5]] + '_grad'
elif (
def add_fluid_info_in_composite(composite_map, args_alias_map): item['name'].endswith('_grad')
fluid_input_list = [] and item['name'][:-5] not in args_alias_map
fluid_attr_list = [] ):
fluid_output_list = [] item['fluid_name'] = item['name']
# add fluid op inputs
for input in composite_map["phi_inputs"]:
if input in args_alias_map:
fluid_input_list.append(args_alias_map[input])
else:
fluid_input_list.append(input)
# add fluid op attrs
for attr in composite_map["phi_attrs"]:
if attr in args_alias_map:
fluid_attr_list.append(args_alias_map[attr])
else:
fluid_attr_list.append(attr)
# add fluid op outputs
for output in composite_map["phi_outputs"]:
if output in args_alias_map:
fluid_output_list.append(args_alias_map[output])
else:
fluid_output_list.append(output)
composite_map.update(
{
"fluid_inputs": fluid_input_list,
"fluid_attrs": fluid_attr_list,
"fluid_outputs": fluid_output_list,
}
)
def get_param_list_alias(param_list, args_map): def get_param_list_alias(param_list, args_map):
return [ return [
...@@ -287,15 +262,15 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): ...@@ -287,15 +262,15 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
op_item['kernel']['layout']['candidates'], args_name_map op_item['kernel']['layout']['candidates'], args_name_map
) )
def update_grad_op_compat_name(grad_op_item, args_name_map): def add_grad_op_compat_name(grad_op_item, args_name_map):
update_op_param_name(grad_op_item['inputs'], args_name_map) add_op_param_name(grad_op_item['inputs'], args_name_map)
update_op_param_name(grad_op_item['outputs'], args_name_map) add_op_param_name(grad_op_item['outputs'], args_name_map)
update_op_param_name(grad_op_item['attrs'], args_name_map) add_op_param_name(grad_op_item['attrs'], args_name_map)
update_op_param_name(grad_op_item['forward']['inputs'], args_name_map) add_op_param_name(grad_op_item['forward']['inputs'], args_name_map)
update_op_param_name(grad_op_item['forward']['outputs'], args_name_map) add_op_param_name(grad_op_item['forward']['outputs'], args_name_map)
update_op_param_name(grad_op_item['forward']['attrs'], args_name_map) add_op_param_name(grad_op_item['forward']['attrs'], args_name_map)
update_grad_args_name(grad_op_item['inputs'], args_map) add_grad_args_name(grad_op_item['inputs'], args_map)
update_grad_args_name(grad_op_item['outputs'], args_map) add_grad_args_name(grad_op_item['outputs'], args_map)
for op_args in op_fluid_map_list: for op_args in op_fluid_map_list:
new_op_name, op_name = get_phi_and_fluid_op_name(op_args['op']) new_op_name, op_name = get_phi_and_fluid_op_name(op_args['op'])
...@@ -340,39 +315,32 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): ...@@ -340,39 +315,32 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
int_array_configs[ int_array_configs[
op_args[key][args_item['name']] op_args[key][args_item['name']]
] = int_array_configs[args_item['name']] ] = int_array_configs[args_item['name']]
args_item['name'] = op_args[key][args_item['name']] args_item['fluid_name'] = op_args[key][
if has_backward: args_item['name']
for args_item in backward_op_item['forward'][key]: ]
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
forward_op_item["attr_dict"] = to_named_dict(forward_op_item["attrs"])
update_common_params_name( update_common_params_name(
forward_op_item, args_map, scalar_configs, int_array_configs forward_op_item, args_map, scalar_configs, int_array_configs
) )
if has_backward: if has_backward:
update_grad_op_compat_name(backward_op_item, args_map) # update fluid info in backward
add_grad_op_compat_name(backward_op_item, args_map)
update_common_params_name( update_common_params_name(
backward_op_item, args_map, scalar_configs, int_array_configs backward_op_item, args_map, scalar_configs, int_array_configs
) )
backward_op_item["attr_dict"] = to_named_dict(
backward_op_item["attrs"]
)
if 'backward' not in op_args: if 'backward' not in op_args:
continue continue
backward_op_list = op_args['backward'].split(',') backward_op_list = op_args['backward'].split(',')
# add fluid args name in composite map phi_bw_op_name, bw_op_name = get_phi_and_fluid_op_name(
for backward_op in backward_op_list: backward_op_list[0]
if ( )
"composite" if (
in backward_op_dict[backward_op.split('(')[0].strip()] forward_op_item["backward_composite"] is not None
): and phi_bw_op_name != bw_op_name
add_fluid_info_in_composite( ):
backward_op_dict[backward_op]["composite"], args_map forward_op_item["backward_composite"] = bw_op_name
)
_, bw_op_name = get_phi_and_fluid_op_name(backward_op_list[0])
forward_op_item['backward'] = bw_op_name forward_op_item['backward'] = bw_op_name
backward_op_item['op_name'] = bw_op_name backward_op_item['op_name'] = bw_op_name
...@@ -383,18 +351,20 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): ...@@ -383,18 +351,20 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
double_grad_op_name, double_grad_op_name,
) = get_phi_and_fluid_op_name(backward_op_list[1]) ) = get_phi_and_fluid_op_name(backward_op_list[1])
double_grad_item = backward_op_dict[phi_double_grad_op_name] double_grad_item = backward_op_dict[phi_double_grad_op_name]
if (
backward_op_item["backward_composite"] is not None
and phi_double_grad_op_name != double_grad_op_name
):
backward_op_item["backward_composite"] = double_grad_op_name
backward_op_item['backward'] = double_grad_op_name backward_op_item['backward'] = double_grad_op_name
double_grad_item['op_name'] = double_grad_op_name double_grad_item['op_name'] = double_grad_op_name
update_grad_op_compat_name(double_grad_item, args_map) add_grad_op_compat_name(double_grad_item, args_map)
update_common_params_name( update_common_params_name(
double_grad_item, double_grad_item,
args_map, args_map,
scalar_configs, scalar_configs,
int_array_configs, int_array_configs,
) )
double_grad_item["attr_dict"] = to_named_dict(
double_grad_item["attrs"]
)
# for triple grad # for triple grad
if len(backward_op_list) > 2: if len(backward_op_list) > 2:
...@@ -403,18 +373,22 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict): ...@@ -403,18 +373,22 @@ def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
triple_grad_op_name, triple_grad_op_name,
) = get_phi_and_fluid_op_name(backward_op_list[2]) ) = get_phi_and_fluid_op_name(backward_op_list[2])
triple_grad_item = backward_op_dict[phi_triple_grad_op_name] triple_grad_item = backward_op_dict[phi_triple_grad_op_name]
if (
double_grad_item["backward_composite"] is not None
and phi_triple_grad_op_name != triple_grad_op_name
):
double_grad_item[
"backward_composite"
] = triple_grad_op_name
double_grad_item['backward'] = triple_grad_op_name double_grad_item['backward'] = triple_grad_op_name
triple_grad_item['op_name'] = triple_grad_op_name triple_grad_item['op_name'] = triple_grad_op_name
update_grad_op_compat_name(triple_grad_item, args_map) add_grad_op_compat_name(triple_grad_item, args_map)
update_common_params_name( update_common_params_name(
triple_grad_item, triple_grad_item,
args_map, args_map,
scalar_configs, scalar_configs,
int_array_configs, int_array_configs,
) )
triple_grad_item["attr_dict"] = to_named_dict(
triple_grad_item["attrs"]
)
def process_invoke_op(forward_op_dict, backward_op_dict): def process_invoke_op(forward_op_dict, backward_op_dict):
...@@ -432,20 +406,28 @@ def process_invoke_op(forward_op_dict, backward_op_dict): ...@@ -432,20 +406,28 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
for input_item in reuse_op['inputs']: for input_item in reuse_op['inputs']:
bw_op['invoke']['inputs'].append( bw_op['invoke']['inputs'].append(
{ {
'fluid_name': input_item['fluid_name'],
'name': input_item['name'], 'name': input_item['name'],
'value': args_list[args_index], 'value': args_list[args_index],
} }
) )
args_index = args_index + 1 args_index = args_index + 1
bw_fluid_attrs_set = [
item['fluid_name'] for item in bw_op['attrs']
]
for attr in reuse_op['attrs']: for attr in reuse_op['attrs']:
if args_index < len(args_list): if args_index < len(args_list):
attr_value = ( attr_value = (
f"this->GetAttr(\"{args_list[args_index]}\")" f"this->GetAttr(\"{args_list[args_index]}\")"
if args_list[args_index] in bw_op['attr_dict'] if args_list[args_index] in bw_fluid_attrs_set
else args_list[args_index] else args_list[args_index]
) )
bw_op['invoke']['attrs'].append( bw_op['invoke']['attrs'].append(
{'name': attr['name'], 'value': attr_value} {
'name': attr['name'],
'fluid_name': attr['fluid_name'],
'value': attr_value,
}
) )
args_index = args_index + 1 args_index = args_index + 1
else: else:
...@@ -454,7 +436,8 @@ def process_invoke_op(forward_op_dict, backward_op_dict): ...@@ -454,7 +436,8 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
bw_op['invoke']['outputs'].append( bw_op['invoke']['outputs'].append(
{ {
'name': output_item['name'], 'name': output_item['name'],
'value': bw_op['outputs'][idx]['name'], 'fluid_name': output_item['fluid_name'],
'value': bw_op['outputs'][idx]['fluid_name'],
} }
) )
...@@ -507,17 +490,26 @@ def main( ...@@ -507,17 +490,26 @@ def main(
for op in ops: for op in ops:
op['op_name'] = op['name'] op['op_name'] = op['name']
add_fluid_name(op['inputs'])
add_fluid_name(op['attrs'])
add_fluid_name(op['outputs'])
for bw_op in backward_ops: for bw_op in backward_ops:
bw_op['op_name'] = bw_op['name'] bw_op['op_name'] = bw_op['name']
add_fluid_name(bw_op['inputs'])
add_fluid_name(bw_op['attrs'])
add_fluid_name(bw_op['outputs'])
add_fluid_name(bw_op['forward']['inputs'])
add_fluid_name(bw_op['forward']['attrs'])
add_fluid_name(bw_op['forward']['outputs'])
for bw_output in bw_op['outputs']: for bw_output in bw_op['outputs']:
bw_output['drop_empty_grad'] = True bw_output['drop_empty_grad'] = True
# deal the drop_empty_grad of bw_op by op_compat.yaml # deal the drop_empty_grad of bw_op by op_compat.yaml
parse_drop_empty_grad(op_fluid_map_list, backward_op_dict) parse_drop_empty_grad(op_fluid_map_list, backward_op_dict)
parse_composite_info(ops, backward_ops, backward_op_dict) add_composite_info(ops, backward_ops, backward_op_dict)
replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict) add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict)
# prepare for invoke case # prepare for invoke case
process_invoke_op(forward_op_dict, backward_op_dict) process_invoke_op(forward_op_dict, backward_op_dict)
...@@ -545,7 +537,6 @@ def main( ...@@ -545,7 +537,6 @@ def main(
ops=ops, ops=ops,
backward_ops=backward_ops, backward_ops=backward_ops,
op_dict=op_dict, op_dict=op_dict,
composite_gen_flag=True,
) )
f.write(msg) f.write(msg)
ks_template = env.get_template('ks.c.j2') ks_template = env.get_template('ks.c.j2')
......
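A hedged sketch of the two helpers this file gains, add_fluid_name and add_composite_info, run against toy op dicts; the real inputs are parsed ops.yaml / backward.yaml entries and carry many more fields than shown here.

def add_fluid_name(dict_list):
    # fluid_name defaults to the phi name; add_compat_name later overwrites it
    # with the legacy (fluid) name from op_compat.yaml where one exists.
    for item in dict_list:
        item["fluid_name"] = item["name"]


def add_composite_info(ops, backward_ops, backward_op_dict):
    # An op records its backward op under "backward_composite" only when that
    # backward op declares a composite rule; otherwise the field is None.
    for op in ops + backward_ops:
        if (
            op["backward"] in backward_op_dict
            and "composite" in backward_op_dict[op["backward"]]
        ):
            op["backward_composite"] = op["backward"]
        else:
            op["backward_composite"] = None


# Toy data, illustrative only.
ops = [{"name": "add", "backward": "add_grad",
        "inputs": [{"name": "x"}, {"name": "y"}]}]
backward_ops = [{"name": "add_grad", "backward": None}]
backward_op_dict = {"add_grad": {"composite": {"func_name": "add_grad"}}}

add_fluid_name(ops[0]["inputs"])
add_composite_info(ops, backward_ops, backward_op_dict)
print(ops[0]["inputs"][0])                    # {'name': 'x', 'fluid_name': 'x'}
print(ops[0]["backward_composite"])           # add_grad
print(backward_ops[0]["backward_composite"])  # None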
@@ -28,12 +28,14 @@ from filters import (
     to_opmaker_name_cstr,
     to_pascal_case,
     to_scalar_tensor_name,
+    to_variable_names,
 )
-from generate_op import process_invoke_op
+from generate_op import add_fluid_name, process_invoke_op
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
 from parse_utils import to_named_dict
 from tests import (
     is_base_op,
+    is_composite_op,
     is_initializer_list,
     is_scalar,
     is_vec,
@@ -60,7 +62,9 @@ env.filters["to_input_name"] = to_input_name
 env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
 env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
 env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
+env.filters["to_variable_names"] = to_variable_names
 env.tests["base_op"] = is_base_op
+env.tests["composite_op"] = is_composite_op
 env.tests["vec"] = is_vec
 env.tests["scalar"] = is_scalar
 env.tests["initializer_list"] = is_initializer_list
@@ -96,9 +100,18 @@ def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
         op['name'] = op['op_name']
         if op["backward"] is not None:
             op["backward"] = SPARSE_OP_PREFIX + op["backward"]
+        add_fluid_name(op["inputs"])
+        add_fluid_name(op["attrs"])
+        add_fluid_name(op["outputs"])
     for bw_op in backward_ops:
         bw_op['op_name'] = SPARSE_OP_PREFIX + bw_op['name']
         bw_op['name'] = bw_op['op_name']
+        add_fluid_name(bw_op["inputs"])
+        add_fluid_name(bw_op["attrs"])
+        add_fluid_name(bw_op["outputs"])
+        add_fluid_name(bw_op["forward"]["inputs"])
+        add_fluid_name(bw_op["forward"]["attrs"])
+        add_fluid_name(bw_op["forward"]["outputs"])
         if 'invoke' in bw_op:
             bw_op['invoke']['args'] = [
                 param.strip() for param in bw_op['invoke']['args'].split(',')
@@ -139,7 +152,6 @@ def main(op_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path):
             ops=ops,
             backward_ops=backward_ops,
             op_dict=op_dict,
-            composite_gen_flag=False,
         )
         f.write(msg)
......
@@ -28,12 +28,14 @@ from filters import (
     to_opmaker_name_cstr,
     to_pascal_case,
     to_scalar_tensor_name,
+    to_variable_names,
 )
-from generate_op import replace_compat_name
+from generate_op import add_compat_name, add_fluid_name
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
 from parse_utils import to_named_dict
 from tests import (
     is_base_op,
+    is_composite_op,
     is_initializer_list,
     is_scalar,
     is_vec,
@@ -60,7 +62,9 @@ env.filters["to_input_name"] = to_input_name
 env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
 env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
 env.filters["to_composite_grad_opmaker_name"] = to_composite_grad_opmaker_name
+env.filters["to_variable_names"] = to_variable_names
 env.tests["base_op"] = is_base_op
+env.tests["composite_op"] = is_composite_op
 env.tests["vec"] = is_vec
 env.tests["scalar"] = is_scalar
 env.tests["initializer_list"] = is_initializer_list
@@ -100,8 +104,11 @@ def main(
     for op in ops:
         op['op_name'] = op['name']
+        add_fluid_name(op["inputs"])
+        add_fluid_name(op["attrs"])
+        add_fluid_name(op["outputs"])
-    replace_compat_name(op_op_map, forward_op_dict, {})
+    add_compat_name(op_op_map, forward_op_dict, {})
     if len(ops) == 0:
         if os.path.isfile(output_op_path):
@@ -116,7 +123,6 @@ def main(
             ops=ops,
             backward_ops=[],
             op_dict=forward_op_dict,
-            composite_gen_flag=False,
         )
         f.write(msg)
......
@@ -294,14 +294,13 @@ def parse_composite(
     composite_config: str,
 ) -> Dict[str, Any]:
     # composite_config: func(args1, args2,.....)
-    fname = r'(.*?)'
-    wspace = r'\s*'
-    fargs = r'(.*?)'
-    pattern = fr'{fname}{wspace}\({wspace}{fargs}{wspace}\)'
-    m = re.search(pattern, composite_config)
-    func_name = m.group(1)
-    func_args = m.group(2)
+    result = re.search(
+        r"(?P<func_name>[a-z][a-z0-9_]+)\s*\((?P<func_args>[^\)]+)\)",
+        composite_config,
+    )
+    func_name = result.group("func_name")
+    func_args = result.group("func_args")
     composite_dict = {}
     composite_dict["func_name"] = func_name
......
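The tightened regex above can be checked in isolation. The composite_config string below is an illustrative "func(arg1, arg2, ...)" example rather than an entry copied from backward.yaml.

import re

composite_config = "tanh_grad(out, out_grad, x_grad)"  # illustrative input
result = re.search(
    r"(?P<func_name>[a-z][a-z0-9_]+)\s*\((?P<func_args>[^\)]+)\)",
    composite_config,
)
print(result.group("func_name"))  # tanh_grad
print(result.group("func_args"))  # out, out_grad, x_grad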
...@@ -39,11 +39,9 @@ using paddle::framework::GradVarName; ...@@ -39,11 +39,9 @@ using paddle::framework::GradVarName;
{% else %} {% else %}
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}} {{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
{% endif %} {% endif %}
{% if composite_gen_flag == True %} {% if op is composite_op %}
{% if op is composite_op %}
{{composite_grad_op_maker(op_dict[op["name"]])}} {{composite_grad_op_maker(op_dict[op["name"]])}}
{% endif %} {% endif %}
{% endif %}
{% endfor %} {% endfor %}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -51,7 +49,7 @@ using paddle::framework::GradVarName; ...@@ -51,7 +49,7 @@ using paddle::framework::GradVarName;
namespace ops = paddle::operators; namespace ops = paddle::operators;
{% for op in ops + backward_ops %} {% for op in ops + backward_ops %}
{% if op is base_op %} {% if op is base_op %}
{{register_op_with_components(op, op_dict)}} {{register_op_with_components(op)}}
{{register_op_version(op)}} {{register_op_version(op)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -12,7 +12,7 @@ class {{op_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerM ...@@ -12,7 +12,7 @@ class {{op_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerM
{{add_output(loop.index0, output, op_name)}}; {{add_output(loop.index0, output, op_name)}};
{% endfor %} {% endfor %}
{% for attr in op["attrs"] %} {% for attr in op["attrs"] %}
{% if attr["name"] in op["kernel"]["param"] %} {% if attr["fluid_name"] in op["kernel"]["param"] %}
{{add_attr(loop.index0, attr, op_name)}}; {{add_attr(loop.index0, attr, op_name)}};
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -27,7 +27,7 @@ TODO: Documentation of {{op_name}} op. ...@@ -27,7 +27,7 @@ TODO: Documentation of {{op_name}} op.
{# add input, it could be duplicable or dispensable #} {# add input, it could be duplicable or dispensable #}
{% macro add_input(i, input, op_name) %}{# inline #} {% macro add_input(i, input, op_name) %}{# inline #}
{% set name = input["name"] %} {% set name = input["fluid_name"] %}
{% set typename = input["typename"] %} {% set typename = input["typename"] %}
AddInput({{name| to_opmaker_name}}, "({{typename}}), input {{i}} of {{op_name}} op.") AddInput({{name| to_opmaker_name}}, "({{typename}}), input {{i}} of {{op_name}} op.")
{%- if typename is vec %} {%- if typename is vec %}
...@@ -42,7 +42,7 @@ AddInput({{name| to_opmaker_name}}, "({{typename}}), input {{i}} of {{op_name}} ...@@ -42,7 +42,7 @@ AddInput({{name| to_opmaker_name}}, "({{typename}}), input {{i}} of {{op_name}}
{# add output, it could be duplicable or intermediate, however, optional output is not supported #} {# add output, it could be duplicable or intermediate, however, optional output is not supported #}
{% macro add_output(i, output, op_name) %}{# inline #} {% macro add_output(i, output, op_name) %}{# inline #}
{% set name = output["name"] %} {% set name = output["fluid_name"] %}
{% set typename = output["typename"] %} {% set typename = output["typename"] %}
{% set is_intermediate = output["intermediate"] %} {% set is_intermediate = output["intermediate"] %}
AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name}} op.") AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name}} op.")
...@@ -66,7 +66,7 @@ AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name ...@@ -66,7 +66,7 @@ AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name
{# add attribute, and process default value if needed #} {# add attribute, and process default value if needed #}
{% macro add_attr(i, attr, op_name) %}{# inline #} {% macro add_attr(i, attr, op_name) %}{# inline #}
{% set name = attr["name"] %} {% set name = attr["fluid_name"] %}
{% set typename = attr["typename"] %} {% set typename = attr["typename"] %}
{% if typename is scalar %} {% if typename is scalar %}
AddInput("{{attr | to_scalar_tensor_name}}", "attribute {{i}} for {{op_name}} op from 0D Tensor.") AddInput("{{attr | to_scalar_tensor_name}}", "attribute {{i}} for {{op_name}} op from 0D Tensor.")
...@@ -153,15 +153,15 @@ All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgu ...@@ -153,15 +153,15 @@ All possible KernelSignatures returned by {{op["name"] | to_pascal_case }}OpArgu
{% set kernel_in_type_list = kernel_config["dispatch"][kernel_func][0] %} {% set kernel_in_type_list = kernel_config["dispatch"][kernel_func][0] %}
if ( {%- for input in inputs %} if ( {%- for input in inputs %}
{%- if input["name"] in kernel_config["param"] %} {%- if input["fluid_name"] in kernel_config["param"] %}
{%- if kernel_in_type_list[input_idx.idx] == "dense" %} {%- if kernel_in_type_list[input_idx.idx] == "dense" %}
ctx.IsDenseTensorInput("{{input["name"]}}"){{" && " if not loop.last}} ctx.IsDenseTensorInput("{{input["fluid_name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "selected_rows" %} {%- elif kernel_in_type_list[input_idx.idx] == "selected_rows" %}
ctx.IsSelectedRowsInput("{{input["name"]}}"){{" && " if not loop.last}} ctx.IsSelectedRowsInput("{{input["fluid_name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "sparse_coo" %} {%- elif kernel_in_type_list[input_idx.idx] == "sparse_coo" %}
ctx.IsSparseCooTensorInput("{{input["name"]}}"){{" && " if not loop.last}} ctx.IsSparseCooTensorInput("{{input["fluid_name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "sparse_csr" %} {%- elif kernel_in_type_list[input_idx.idx] == "sparse_csr" %}
ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}} ctx.IsSparseCsrTensorInput("{{input["fluid_name"]}}"){{" && " if not loop.last}}
{%- endif %} {%- endif %}
{% set input_idx.idx = input_idx.idx + 1 %} {% set input_idx.idx = input_idx.idx + 1 %}
{%- endif %} {%- endif %}
...@@ -210,8 +210,8 @@ PD_REGISTER_ARG_MAPPING_FN({{op["op_name"]}}, phi::{{op["op_name"] | to_pascal_c ...@@ -210,8 +210,8 @@ PD_REGISTER_ARG_MAPPING_FN({{op["op_name"]}}, phi::{{op["op_name"] | to_pascal_c
{% macro get_input_list(inputs, kernel_args) %}{# inline #} {% macro get_input_list(inputs, kernel_args) %}{# inline #}
paddle::small_vector<const char*> inputs { paddle::small_vector<const char*> inputs {
{%- for input in inputs %} {%- for input in inputs %}
{%- if input["name"] in kernel_args %} {%- if input["fluid_name"] in kernel_args %}
{{input["name"] | to_opmaker_name_cstr}}{{", " if not loop.last}} {{input["fluid_name"] | to_opmaker_name_cstr}}{{", " if not loop.last}}
{%- endif %} {%- endif %}
{%- endfor %} {%- endfor %}
} }
...@@ -219,8 +219,8 @@ paddle::small_vector<const char*> inputs { ...@@ -219,8 +219,8 @@ paddle::small_vector<const char*> inputs {
{% macro get_an_attr(attr, kernel_args) %}{# inline #} {% macro get_an_attr(attr, kernel_args) %}{# inline #}
{% set typename = attr["typename"] %} {% set typename = attr["typename"] %}
{%- if attr["name"] in kernel_args %} {%- if attr["fluid_name"] in kernel_args %}
{% set name = attr["name"] %} {% set name = attr["fluid_name"] %}
{% if typename is scalar %}{# scalar correspond to a dispensable input and an attr in opmaker #} {% if typename is scalar %}{# scalar correspond to a dispensable input and an attr in opmaker #}
attrs.emplace_back(ctx.HasInput("{{attr | to_scalar_tensor_name}}") ? "{{attr | to_scalar_tensor_name}}" : "{{name}}"); attrs.emplace_back(ctx.HasInput("{{attr | to_scalar_tensor_name}}") ? "{{attr | to_scalar_tensor_name}}" : "{{name}}");
{%- elif typename == "IntArray" %} {%- elif typename == "IntArray" %}
...@@ -251,7 +251,7 @@ attrs.emplace_back("{{name}}"); ...@@ -251,7 +251,7 @@ attrs.emplace_back("{{name}}");
{% macro get_output_list(outputs, kernel_args) %}{# inline #} {% macro get_output_list(outputs, kernel_args) %}{# inline #}
paddle::small_vector<const char*> outputs { paddle::small_vector<const char*> outputs {
{%- for output in outputs %} {%- for output in outputs %}
{{output["name"] | to_opmaker_name_cstr}}{{", " if not loop.last}} {{output["fluid_name"] | to_opmaker_name_cstr}}{{", " if not loop.last}}
{%- endfor %} {%- endfor %}
} }
{%- endmacro %} {%- endmacro %}
...@@ -263,7 +263,7 @@ phi::KernelKey GetExpectedKernelType( ...@@ -263,7 +263,7 @@ phi::KernelKey GetExpectedKernelType(
{%if kernel["data_type"] is not none %}{# data type ---------------------------------#} {%if kernel["data_type"] is not none %}{# data type ---------------------------------#}
{% if kernel["data_type"]["candidates"] | length == 1 %} {% if kernel["data_type"]["candidates"] | length == 1 %}
{% set data_type_arg = kernel["data_type"]["candidates"][0] %} {% set data_type_arg = kernel["data_type"]["candidates"][0] %}
{% set inputs = op["inputs"] | map(attribute="name") | list %} {% set inputs = op["inputs"] | map(attribute="fluid_name") | list %}
{% if data_type_arg in inputs %} {% if data_type_arg in inputs %}
auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}}); auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}});
{% if kernel["data_type"]["to_complex_flag"][0] %} {% if kernel["data_type"]["to_complex_flag"][0] %}
...@@ -319,9 +319,8 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER({{op["op_name"] | to_pascal_case}}NoNeedBuff ...@@ -319,9 +319,8 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER({{op["op_name"] | to_pascal_case}}NoNeedBuff
{% endif %} {% endif %}
{% endmacro%} {% endmacro%}
{% macro register_op_with_components(op, op_dict) %} {% macro register_op_with_components(op) %}
{% set name = op["op_name"] %} {% set name = op["op_name"] %}
{% set phi_name = op["name"] %}
REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if not "forward" in op %}{# it is a forward op #} {% if not "forward" in op %}{# it is a forward op #}
ops::{{name | to_pascal_case}}OpMaker, ops::{{name | to_pascal_case}}OpMaker,
...@@ -337,8 +336,8 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, ...@@ -337,8 +336,8 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
{% if op is supports_inplace %}{# inplace#} {% if op is supports_inplace %}{# inplace#}
ops::{{name | to_pascal_case}}InplaceInferer, ops::{{name | to_pascal_case}}InplaceInferer,
{% endif %} {% endif %}
{% if "phi_backward" in op and op["phi_backward"] is not none and "composite" in op_dict[op["phi_backward"]] %} {% if "backward_composite" in op and op["backward_composite"] is not none %}
ops::{{op["phi_backward"] | to_composite_grad_opmaker_name}}, ops::{{op["backward_composite"] | to_composite_grad_opmaker_name}},
{% endif %} {% endif %}
{% if op is supports_no_need_buffer %}{# no_need_buffer #} {% if op is supports_no_need_buffer %}{# no_need_buffer #}
ops::{{name | to_pascal_case}}NoNeedBufferVarInferer, ops::{{name | to_pascal_case}}NoNeedBufferVarInferer,
...@@ -391,12 +390,12 @@ REGISTER_OP_VERSION({{name}}) ...@@ -391,12 +390,12 @@ REGISTER_OP_VERSION({{name}})
{# --------------------------------------- backward op maker ---------------------------------------------- #} {# --------------------------------------- backward op maker ---------------------------------------------- #}
{% macro backward_op_maker(op, forward_op ) %} {% macro backward_op_maker(op, forward_op ) %}
{% set name = op["op_name"] %} {% set name = op["op_name"] %}
{% set forward_input_names = op["forward"]["inputs"] | map(attribute="name") | list %} {% set forward_input_names = op["forward"]["inputs"] | map(attribute="fluid_name") | list %}
{% set forward_output_names = op["forward"]["outputs"] | map(attribute="name") | list %} {% set forward_output_names = op["forward"]["outputs"] | map(attribute="fluid_name") | list %}
{% set forward_attr_names = op["forward"]["attrs"] | map(attribute="name") | list %} {% set forward_attr_names = op["forward"]["attrs"] | map(attribute="fluid_name") | list %}
{% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %} {% set forward_input_orig_names = forward_op["inputs"] | map(attribute="fluid_name") | list %}
{% set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %} {% set forward_output_orig_names = forward_op["outputs"] | map(attribute="fluid_name") | list %}
{% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %} {% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="fluid_name") | list %}
template <typename T> template <typename T>
class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> { class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
public: public:
...@@ -407,8 +406,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -407,8 +406,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
grad_op->SetType("{{name}}"); grad_op->SetType("{{name}}");
{% for input in op["inputs"] %} {% for input in op["inputs"] %}
grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward( grad_op->SetInput({{input["fluid_name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
input["name"], input["fluid_name"],
forward_input_names, forward_input_names,
forward_output_names, forward_output_names,
forward_input_orig_names, forward_input_orig_names,
...@@ -416,8 +415,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -416,8 +415,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% endfor %} {% endfor %}
{% for output in op["outputs"] %} {% for output in op["outputs"] %}
grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward( grad_op->SetOutput({{output["fluid_name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
output["name"], output["fluid_name"],
forward_input_names, forward_input_names,
forward_output_names, forward_output_names,
forward_input_orig_names, forward_input_orig_names,
...@@ -427,7 +426,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -427,7 +426,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
{% for attr in op["attrs"] %} {% for attr in op["attrs"] %}
{% set attr_name = attr["name"] %} {% set attr_name = attr["fluid_name"] %}
{% if attr_name in forward_attr_names %} {% if attr_name in forward_attr_names %}
{% if attr["typename"] == "IntArray" %} {% if attr["typename"] == "IntArray" %}
{% if 'tensor_name' in attr or 'manual_flag' not in attr %} {% if 'tensor_name' in attr or 'manual_flag' not in attr %}
...@@ -455,12 +454,12 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -455,12 +454,12 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% macro backward_op_reused_maker(bw_op, forward_op, invoke_op) %} {% macro backward_op_reused_maker(bw_op, forward_op, invoke_op) %}
{% set name = bw_op["op_name"] %} {% set name = bw_op["op_name"] %}
{% set forward_input_names = bw_op["forward"]["inputs"] | map(attribute="name") | list %} {% set forward_input_names = bw_op["forward"]["inputs"] | map(attribute="fluid_name") | list %}
{% set forward_output_names = bw_op["forward"]["outputs"] | map(attribute="name") | list %} {% set forward_output_names = bw_op["forward"]["outputs"] | map(attribute="fluid_name") | list %}
{% set forward_attr_names = bw_op["forward"]["attrs"] | map(attribute="name") | list %} {% set forward_attr_names = bw_op["forward"]["attrs"] | map(attribute="fluid_name") | list %}
{% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %} {% set forward_input_orig_names = forward_op["inputs"] | map(attribute="fluid_name") | list %}
{% set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %} {% set forward_output_orig_names = forward_op["outputs"] | map(attribute="fluid_name") | list %}
{% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %} {% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="fluid_name") | list %}
template <typename T> template <typename T>
class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> { class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
public: public:
...@@ -471,7 +470,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -471,7 +470,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
grad_op->SetType("{{invoke_op["func"]}}"); grad_op->SetType("{{invoke_op["func"]}}");
{% for input in invoke_op["inputs"] %} {% for input in invoke_op["inputs"] %}
grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward( grad_op->SetInput({{input["fluid_name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
input["value"], input["value"],
forward_input_names, forward_input_names,
forward_output_names, forward_output_names,
...@@ -480,7 +479,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -480,7 +479,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% endfor %} {% endfor %}
{% for output in invoke_op["outputs"] %} {% for output in invoke_op["outputs"] %}
grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward( grad_op->SetOutput({{output["fluid_name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
output["value"], output["value"],
forward_input_names, forward_input_names,
forward_output_names, forward_output_names,
...@@ -490,42 +489,49 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -490,42 +489,49 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% endfor %} {% endfor %}
{% for attr in invoke_op["attrs"] %} {% for attr in invoke_op["attrs"] %}
grad_op->SetAttr("{{attr["name"]}}", {{attr["value"]}}); grad_op->SetAttr("{{attr["fluid_name"]}}", {{attr["value"]}});
{% endfor %} {% endfor %}
} }
}; };
{% endmacro %} {% endmacro %}
{% macro composite_grad_op_maker(composite_op_dict) %} {% macro composite_grad_op_maker(backward_op) %}
{% set op_name = composite_op_dict["name"] %} {% set op_name = backward_op["op_name"] %}
class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeOpMakerBase { {% set inputs = backward_op["inputs"] | to_variable_names("name")%}
{% set input_dict = backward_op["input_dict"] %}
{% set fluid_inputs = backward_op["inputs"] | to_variable_names("fluid_name")%}
{% set forward_fluid_inputs = backward_op["forward"]["inputs"] | to_variable_names("fluid_name")%}
{% set forward_fluid_outputs = backward_op["forward"]["outputs"] | to_variable_names("fluid_name")%}
{% set attrs = backward_op["attrs"] | to_variable_names("name") %}
{% set fluid_attrs = backward_op["attrs"] | to_variable_names("fluid_name") %}
{% set attr_dict = backward_op["attr_dict"] %}
{% set outputs = backward_op["outputs"] | to_variable_names("name")%}
{% set output_dict = backward_op["output_dict"] %}
{% set fluid_outputs = backward_op["outputs"] | to_variable_names("fluid_name")%}
{% set composite_func_info = backward_op["composite"] %}
class {{op_name | to_composite_grad_opmaker_name}} : public prim::CompositeGradOpMakerBase {
public: public:
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase; using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override { void Apply() override {
//get inputs //get inputs
{{construct_composite_input(composite_op_dict)}} {{construct_composite_input(inputs, fluid_inputs, forward_fluid_inputs, forward_fluid_outputs, input_dict)}}
//get attr //get attr
{{construct_composite_attr(composite_op_dict)}} {{construct_composite_attr(attrs, fluid_attrs, attr_dict)}}
//get output //get output
{{construct_composite_output(composite_op_dict)}} {{construct_composite_output(outputs, fluid_outputs, output_dict)}}
//get output ptr //get output ptr
{{construct_composite_output_ptr(composite_op_dict)}} {{construct_composite_output_ptr(outputs, output_dict)}}
//get output orginal name //get output orginal name
{{get_composite_output_orginal_name(composite_op_dict)}} {{get_composite_output_orginal_name(outputs, output_dict)}}
//call composite backward func //call composite backward func
{{call_composite_backward_api(composite_op_dict)}} {{call_composite_backward_api(composite_func_info)}}
//recover output name //recover output name
{{recover_composite_output_name(composite_op_dict)}} {{recover_composite_output_name(outputs)}}
} }
}; };
{%- endmacro %} {%- endmacro %}
{% macro construct_composite_input(composite_op_dict) %} {% macro construct_composite_input(inputs, fluid_inputs, forward_fluid_inputs, forward_fluid_outputs, input_dict) %}
{% set inputs = composite_op_dict["composite"]["phi_inputs"] %}
{% set input_dict = composite_op_dict["input_dict"] %}
{% set fluid_inputs = composite_op_dict["composite"]["fluid_inputs"] %}
{% set forward_fluid_inputs = composite_op_dict["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_fluid_outputs = composite_op_dict["forward"]["outputs"] | map(attribute="name") | list %}
{% set inputs_length = inputs | length %} {% set inputs_length = inputs | length %}
{% for i in range(inputs_length) %} {% for i in range(inputs_length) %}
{% set input_typename = input_dict[inputs[i]]["typename"] %} {% set input_typename = input_dict[inputs[i]]["typename"] %}
...@@ -534,13 +540,13 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO ...@@ -534,13 +540,13 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO
{% if input_typename == "Tensor" %} {% if input_typename == "Tensor" %}
{% if input_optional_flag == True %} {% if input_optional_flag == True %}
paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleForwardInput("{{fluid_inputs[i]}}"); paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleForwardInput("{{fluid_inputs[i]}}");
{% elif input_optional_flag == False %} {% else %}
paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleForwardInput("{{fluid_inputs[i]}}"); paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleForwardInput("{{fluid_inputs[i]}}");
{% endif %} {% endif %}
{% elif input_typename == "Tensor[]" %} {% elif input_typename == "Tensor[]" %}
{% if input_optional_flag == True %} {% if input_optional_flag == True %}
std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiForwardInput("{{fluid_inputs[i]}}"); std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiForwardInput("{{fluid_inputs[i]}}");
{% elif input_optional_flag == False %} {% else %}
std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiForwardInput("{{fluid_inputs[i]}}"); std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiForwardInput("{{fluid_inputs[i]}}");
{% endif %} {% endif %}
{% endif %} {% endif %}
...@@ -548,13 +554,13 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO ...@@ -548,13 +554,13 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO
{% if input_typename == "Tensor" %} {% if input_typename == "Tensor" %}
{% if input_optional_flag == True %} {% if input_optional_flag == True %}
paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleForwardOutput("{{fluid_inputs[i]}}"); paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleForwardOutput("{{fluid_inputs[i]}}");
{% elif input_optional_flag == False %} {% else %}
paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleForwardOutput("{{fluid_inputs[i]}}"); paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleForwardOutput("{{fluid_inputs[i]}}");
{% endif %} {% endif %}
{% elif input_typename == "Tensor[]" %} {% elif input_typename == "Tensor[]" %}
{% if input_optional_flag == True %} {% if input_optional_flag == True %}
std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiForwardOutput("{{fluid_inputs[i]}}"); std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiForwardOutput("{{fluid_inputs[i]}}");
{% elif input_optional_flag == False %} {% else %}
std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiForwardOutput("{{fluid_inputs[i]}}"); std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiForwardOutput("{{fluid_inputs[i]}}");
{% endif %} {% endif %}
{% endif %} {% endif %}
...@@ -562,13 +568,13 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO ...@@ -562,13 +568,13 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO
{% if input_typename == "Tensor" %} {% if input_typename == "Tensor" %}
{% if input_optional_flag == True %} {% if input_optional_flag == True %}
paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleOutputGrad("{{fluid_inputs[i][:-5]}}"); paddle::optional<paddle::experimental::Tensor> {{inputs[i]}} = this->GetOptionalSingleOutputGrad("{{fluid_inputs[i][:-5]}}");
{% elif input_optional_flag == False %} {% else %}
paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleOutputGrad("{{fluid_inputs[i][:-5]}}"); paddle::experimental::Tensor {{inputs[i]}} = this->GetSingleOutputGrad("{{fluid_inputs[i][:-5]}}");
{% endif %} {% endif %}
{% elif input_typename == "Tensor[]" %} {% elif input_typename == "Tensor[]" %}
{% if input_optional_flag == True %} {% if input_optional_flag == True %}
std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiOutputGrad("{{fluid_inputs[i][:-5]}}"); std::vector<paddle::optional<paddle::experimental::Tensor>> {{inputs[i]}} = this->GetOptionalMultiOutputGrad("{{fluid_inputs[i][:-5]}}");
{% elif input_optional_flag == False %} {% else %}
std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiOutputGrad("{{fluid_inputs[i][:-5]}}"); std::vector<paddle::experimental::Tensor> {{inputs[i]}} = this->GetMultiOutputGrad("{{fluid_inputs[i][:-5]}}");
{%- endif %} {%- endif %}
{%- endif %} {%- endif %}
...@@ -576,24 +582,18 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO ...@@ -576,24 +582,18 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO
{%- endfor %} {%- endfor %}
{%- endmacro %} {%- endmacro %}
{% macro construct_composite_attr(composite_op_dict) %} {% macro construct_composite_attr(attrs, fluid_attrs, attr_dict) %}
{% set attrs = composite_op_dict["composite"]["phi_attrs"] %}
{% set fluid_attrs = composite_op_dict["composite"]["fluid_attrs"] %}
{% set fluid_attrs_dict = composite_op_dict["attr_dict"] %}
{% set attrs_length = attrs | length %} {% set attrs_length = attrs | length %}
{% for i in range(attrs_length) %} {% for i in range(attrs_length) %}
{% set attrs_data_type = fluid_attrs_dict[fluid_attrs[i]]["typename"] | to_op_attr_type %} {% set attrs_data_type = attr_dict[attrs[i]]["typename"] | to_op_attr_type %}
{{attrs_data_type}} {{attrs[i]}} = this->Attr<{{attrs_data_type}}>("{{fluid_attrs[i]}}"); const {{attrs_data_type}} {{attrs[i]}} = this->Attr<{{attrs_data_type}}>("{{fluid_attrs[i]}}");
{% endfor %} {% endfor %}
{%- endmacro %} {%- endmacro %}
{% macro construct_composite_output(composite_op_dict) %} {% macro construct_composite_output(outputs, fluid_outputs, output_dict) %}
{% set outputs = composite_op_dict["composite"]["phi_outputs"] %}
{% set fluid_outputs = composite_op_dict["composite"]["fluid_outputs"] %}
{% set outputs_dict = composite_op_dict["output_dict"] %}
{% set outputs_length = outputs | length %} {% set outputs_length = outputs | length %}
{% for i in range(outputs_length) %} {% for i in range(outputs_length) %}
{% set output_typename = outputs_dict[outputs[i]]["typename"] %} {% set output_typename = output_dict[outputs[i]]["typename"] %}
{% if output_typename == "Tensor" %} {% if output_typename == "Tensor" %}
paddle::experimental::Tensor {{outputs[i] + "_t"}} = this->GetSingleInputGrad("{{fluid_outputs[i][:-5]}}"); paddle::experimental::Tensor {{outputs[i] + "_t"}} = this->GetSingleInputGrad("{{fluid_outputs[i][:-5]}}");
{% elif output_typename == "Tensor[]" %} {% elif output_typename == "Tensor[]" %}
...@@ -602,12 +602,10 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO ...@@ -602,12 +602,10 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO
{%- endfor %} {%- endfor %}
{%- endmacro %} {%- endmacro %}
{% macro construct_composite_output_ptr(composite_op_dict) %} {% macro construct_composite_output_ptr(outputs, output_dict) %}
{% set outputs = composite_op_dict["composite"]["phi_outputs"] %}
{% set outputs_dict = composite_op_dict["output_dict"] %}
{% set outputs_length = outputs | length %} {% set outputs_length = outputs | length %}
{% for i in range(outputs_length) %} {% for i in range(outputs_length) %}
{% set output_typename = outputs_dict[outputs[i]]["typename"] %} {% set output_typename = output_dict[outputs[i]]["typename"] %}
{% if output_typename == "Tensor" %} {% if output_typename == "Tensor" %}
paddle::experimental::Tensor* {{outputs[i]}} = this->GetOutputPtr(&{{outputs[i]+ "_t"}}); paddle::experimental::Tensor* {{outputs[i]}} = this->GetOutputPtr(&{{outputs[i]+ "_t"}});
{% elif output_typename == "Tensor[]" %} {% elif output_typename == "Tensor[]" %}
...@@ -620,12 +618,10 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO ...@@ -620,12 +618,10 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO
{%- endfor %} {%- endfor %}
{%- endmacro %} {%- endmacro %}
{% macro get_composite_output_orginal_name(composite_op_dict) %} {% macro get_composite_output_orginal_name(outputs, output_dict) %}
{% set outputs = composite_op_dict["composite"]["phi_outputs"] %}
{% set outputs_dict = composite_op_dict["output_dict"] %}
{% set outputs_length = outputs | length %} {% set outputs_length = outputs | length %}
{% for i in range(outputs_length) %} {% for i in range(outputs_length) %}
{% set output_typename = outputs_dict[outputs[i]]["typename"] %} {% set output_typename = output_dict[outputs[i]]["typename"] %}
{% if output_typename == "Tensor" %} {% if output_typename == "Tensor" %}
std::string {{outputs[i] + "_name"}} = this->GetOutputName({{outputs[i] + "_t"}}); std::string {{outputs[i] + "_name"}} = this->GetOutputName({{outputs[i] + "_t"}});
{% elif output_typename == "Tensor[]" %} {% elif output_typename == "Tensor[]" %}
...@@ -634,13 +630,12 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO ...@@ -634,13 +630,12 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO
{%- endfor %} {%- endfor %}
{%- endmacro %} {%- endmacro %}
{% macro call_composite_backward_api(composite_op_dict) %} {% macro call_composite_backward_api(composite_func_info) %}
VLOG(3) << "Runing {{composite_op_dict["composite"]["func_name"]}} composite func"; VLOG(3) << "Runing {{composite_func_info["func_name"]}} composite func";
prim::{{composite_op_dict["composite"]["func_name"]}}<prim::DescTensor>({{composite_op_dict["composite"]["func_args"]}}); prim::{{composite_func_info["func_name"]}}<prim::DescTensor>({{composite_func_info["func_args"]}});
{%- endmacro %} {%- endmacro %}
{% macro recover_composite_output_name(composite_op_dict) %} {% macro recover_composite_output_name(outputs) %}
{% set outputs = composite_op_dict["composite"]["phi_outputs"] %}
{% set outputs_length = outputs | length %} {% set outputs_length = outputs | length %}
{% for i in range(outputs_length) %} {% for i in range(outputs_length) %}
this->RecoverOutputName({{outputs[i] + "_t"}}, {{outputs[i] + "_name"}}); this->RecoverOutputName({{outputs[i] + "_t"}}, {{outputs[i] + "_name"}});
......
@@ -135,9 +135,9 @@ struct TestBaseProgram {
   int idx_{0};
 };
-class TestGradCompositeGradMaker : public GradCompositeOpMakerBase {
+class TestGradCompositeGradMaker : public CompositeGradOpMakerBase {
  public:
-  using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
   void Apply() override {}
 };
......
@@ -41,9 +41,9 @@ namespace prim {
   argument DropEmptyIG in the derived classes.
 */
-class GradCompositeOpMakerBase {
+class CompositeGradOpMakerBase {
  public:
-  explicit GradCompositeOpMakerBase(
+  explicit CompositeGradOpMakerBase(
       const framework::OpDesc& fwd_op,
       const std::unordered_set<std::string>& no_grad_set,
       std::unordered_map<std::string, std::string>* grad_to_var,
@@ -61,7 +61,7 @@ class GradCompositeOpMakerBase {
         acting_program_.MutableBlock(0));
   }
-  virtual ~GradCompositeOpMakerBase() = default;
+  virtual ~CompositeGradOpMakerBase() = default;
   virtual std::vector<std::unique_ptr<framework::OpDesc>> operator()() {
     this->Apply();
......