Unverified commit 23c1ac2c authored by zyfncg, committed by GitHub

Support static graph code-gen for squeeze and unsqueeze op (#49430)

* support static graph code-gen for squeeze op

* generate static graph code of unsqueeze

* refine op name

* add extra output in op_compat

* remove debug log
Parent 18f0ab86
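For orientation, this change moves squeeze/unsqueeze from hand-written OpMaker C++ to the YAML-driven static-graph generator, keyed by op_compat entries such as "squeeze (squeeze2)". Below is a minimal, self-contained sketch of how such an entry can be read; the YAML fragment mirrors the op_compat.yaml hunk further down, and the helper simply mirrors the behaviour of get_phi_and_fluid_op_name from generate_op.py (the standalone function name here is illustrative, not the generator's API):

import yaml

# Illustrative fragment in the style of op_compat.yaml (see the hunk below).
compat_entry = yaml.safe_load("""
op : squeeze (squeeze2)
inputs :
  x : X
attrs :
  axis : axes
outputs :
  out : Out
  xshape : XShape
""")

def split_phi_and_fluid(name):
    # "squeeze (squeeze2)" -> ("squeeze", "squeeze2"); "relu" -> ("relu", "relu")
    parts = name.split('(')
    if len(parts) == 1:
        return parts[0].strip(), parts[0].strip()
    return parts[0].strip(), parts[1].split(')')[0].strip()

print(split_phi_and_fluid(compat_entry['op']))  # ('squeeze', 'squeeze2')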
@@ -739,6 +739,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.backward_returns_list,
) = ParseYamlBackward(backward_args_str, backward_returns_str)
# Remove the output which is intermediate
if 'intermediate' in grad_api_contents:
backward_returns_list_new = []
for return_item in self.backward_returns_list:
if return_item[0] not in grad_api_contents['intermediate']:
backward_returns_list_new.append(return_item)
self.backward_returns_list = backward_returns_list_new
def CollectForwardInfoFromBackwardContents(self):
backward_forward_str = self.backward_forward_str
@@ -1979,7 +1987,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);\n"
inplace_grad_input_str = ""
inplaced_tensor_wrapper = False
inplace_check_str = ""
optional_inplace_var_name = []
# Grad Ins from TensorWrappers
......
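The new filtering step above drops any backward return that the op declares as intermediate, so intermediate outputs such as xshape no longer surface in the eager grad-node returns. A minimal, self-contained sketch of that behaviour; the names grad_api_contents and backward_returns_list follow the generator code above, while the sample data is hypothetical:

# Hypothetical sample data mirroring the generator's structures.
grad_api_contents = {'intermediate': 'xshape'}
backward_returns_list = [
    ('x_grad', 'Tensor', 0),   # (name, type, position) is illustrative only
    ('xshape', 'Tensor', 1),
]

# Keep only the returns that are not declared intermediate.
# Here 'intermediate' is a plain string, so "not in" acts as a substring test.
if 'intermediate' in grad_api_contents:
    backward_returns_list = [
        item
        for item in backward_returns_list
        if item[0] not in grad_api_contents['intermediate']
    ]

print(backward_returns_list)  # [('x_grad', 'Tensor', 0)]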
@@ -131,9 +131,10 @@ def process_int_array(op_item, int_array_configs):
)
if attr_item['is_support_tensor']:
attr_item['typename'] = (
'int[]'
if 'data_type' in int_array_config
and int_array_config['data_type'] == 'int'
else 'int64_t[]'
)
else:
attr_item['data_type'] = (
@@ -153,21 +154,95 @@
# replace name of op and params for OpMaker
def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
def get_phi_and_fluid_op_name(op_item):
names = op_item.split('(')
if len(names) == 1:
return names[0].strip(), names[0].strip()
else:
return names[0].strip(), names[1].split(')')[0].strip()
def update_op_param_name(op_args, args_alias_map):
for item in op_args:
if item['name'] in args_alias_map:
item['name'] = args_alias_map[item['name']]
def update_grad_args_name(op_args, args_alias_map):
for item in op_args:
if (
item['name'].endswith('_grad')
and item['name'][:-5] in args_alias_map
):
args_alias_map[item['name']] = (
args_alias_map[item['name'][:-5]] + '_grad'
)
item['name'] = args_alias_map[item['name'][:-5]] + '_grad'
def get_param_list_alias(param_list, args_map):
return [
args_map[param] if param in args_map else param
for param in param_list
]
def update_common_params_name(
op_item, args_name_map, scalar_configs, int_array_configs
):
if 'inplace' in op_item and op_item['inplace']:
inplace_map = {}
for key, val in op_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
op_item['inplace'] = inplace_map
if 'no_need_buffer' in op_item and op_item['no_need_buffer']:
op_item['no_need_buffer'] = get_param_list_alias(
op_item['no_need_buffer'], args_map
)
process_scalar(op_item, scalar_configs)
process_int_array(op_item, int_array_configs)
if 'invoke' in op_item:
op_item['invoke']['args'] = [
args_map[param.strip()]
if param.strip() in args_map
else param.strip()
for param in op_item['invoke']['args'].split(',')
]
return
op_item['infer_meta']['param'] = get_param_list_alias(
op_item['infer_meta']['param'], args_name_map
)
op_item['kernel']['param'] = get_param_list_alias(
op_item['kernel']['param'], args_name_map
)
if op_item['kernel']['data_type']:
op_item['kernel']['data_type']['candidates'] = get_param_list_alias(
op_item['kernel']['data_type']['candidates'], args_name_map
)
if op_item['kernel']['backend']:
op_item['kernel']['backend']['candidates'] = get_param_list_alias(
op_item['kernel']['backend']['candidates'], args_name_map
)
if op_item['kernel']['layout']:
op_item['kernel']['layout']['candidates'] = get_param_list_alias(
op_item['kernel']['layout']['candidates'], args_name_map
)
def update_grad_op_compat_name(grad_op_item, args_name_map):
update_op_param_name(grad_op_item['inputs'], args_name_map)
update_op_param_name(grad_op_item['outputs'], args_name_map)
update_op_param_name(grad_op_item['attrs'], args_name_map)
update_op_param_name(grad_op_item['forward']['inputs'], args_name_map)
update_op_param_name(grad_op_item['forward']['outputs'], args_name_map)
update_op_param_name(grad_op_item['forward']['attrs'], args_name_map)
update_grad_args_name(grad_op_item['inputs'], args_map)
update_grad_args_name(grad_op_item['outputs'], args_map)
for op_args in op_fluid_map_list:
new_op_name, op_name = get_phi_and_fluid_op_name(op_args['op'])
if new_op_name not in forward_op_dict:
continue
forward_op_item = forward_op_dict[new_op_name]
@@ -179,190 +254,103 @@ def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
scalar_configs = None
int_array_configs = None
if 'scalar' in op_args:
scalar_configs = op_args['scalar']
if 'int_array' in op_args:
int_array_configs = op_args['int_array']
if 'extra' in op_args and 'outputs' in op_args['extra']:
for out_item in forward_op_item['outputs']:
if out_item['name'] in op_args['extra']['outputs']:
out_item['is_extra'] = True
key_set = ['inputs', 'attrs', 'outputs']
args_map = {}
for key in key_set:
if key in op_args:
args_map.update(op_args[key])
for args_item in forward_op_item[key]:
if args_item['name'] in op_args[key]:
if (
scalar_configs
and args_item['name'] in scalar_configs
):
scalar_configs[
op_args[key][args_item['name']]
] = scalar_configs[args_item['name']]
if (
int_array_configs
and args_item['name'] in int_array_configs
):
int_array_configs[
op_args[key][args_item['name']]
] = int_array_configs[args_item['name']]
args_item['name'] = op_args[key][args_item['name']]
if has_backward:
for args_item in backward_op_item['forward'][key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
forward_op_item["attr_dict"] = to_named_dict(forward_op_item["attrs"])
update_common_params_name(
forward_op_item, args_map, scalar_configs, int_array_configs
)
if has_backward:
update_grad_op_compat_name(backward_op_item, args_map)
update_common_params_name(
backward_op_item, args_map, scalar_configs, int_array_configs
)
backward_op_item["attr_dict"] = to_named_dict(
backward_op_item["attrs"]
)
if 'backward' not in op_args:
continue
if 'backward' in op_args and has_backward:
backward_op_list = op_args['backward'].split(',')
_, bw_op_name = get_phi_and_fluid_op_name(backward_op_list[0])
forward_op_item['backward'] = bw_op_name
backward_op_item['op_name'] = bw_op_name
process_scalar(backward_op_item, scalar_configs)
process_int_array(backward_op_item, int_array_configs)
# for double grad
if len(backward_op_list) > 1:
(
phi_double_grad_op_name,
double_grad_op_name,
) = get_phi_and_fluid_op_name(backward_op_list[1])
double_grad_item = backward_op_dict[phi_double_grad_op_name]
backward_op_item['backward'] = double_grad_op_name
double_grad_item['op_name'] = double_grad_op_name
update_grad_op_compat_name(double_grad_item, args_map)
update_common_params_name(
double_grad_item,
args_map,
scalar_configs,
int_array_configs,
)
double_grad_item["attr_dict"] = to_named_dict(
double_grad_item["attrs"]
)
process_scalar(double_grad_item, scalar_configs)
process_int_array(double_grad_item, int_array_configs)
# for triple grad
if len(backward_op_list) > 2:
(
phi_triple_grad_op_name,
triple_grad_op_name,
) = get_phi_and_fluid_op_name(backward_op_list[2])
triple_grad_item = backward_op_dict[phi_triple_grad_op_name]
double_grad_item['backward'] = triple_grad_op_name
triple_grad_item['op_name'] = triple_grad_op_name
update_grad_op_compat_name(triple_grad_item, args_map)
update_common_params_name(
triple_grad_item,
args_map,
scalar_configs,
int_array_configs,
)
triple_grad_item["attr_dict"] = to_named_dict(
triple_grad_item["attrs"]
)
process_scalar(triple_grad_item, scalar_configs)
process_int_array(triple_grad_item, int_array_configs)
key_set = ['inputs', 'attrs', 'outputs']
args_map = {}
for key in key_set:
if key in op_args:
args_map.update(op_args[key])
for args_item in forward_op_item[key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
if has_backward:
for args_item in backward_op_item['forward'][key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
forward_op_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['infer_meta']['param']
]
forward_op_item['kernel']['param'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['kernel']['param']
]
if forward_op_item['kernel']['data_type']:
forward_op_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['kernel']['data_type'][
'candidates'
]
]
if forward_op_item['kernel']['backend']:
forward_op_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['kernel']['backend']['candidates']
]
if forward_op_item['kernel']['layout']:
forward_op_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['kernel']['layout']['candidates']
]
if forward_op_item['inplace']:
inplace_map = {}
for key, val in forward_op_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
forward_op_item['inplace'] = inplace_map
if has_backward:
for args_item in backward_op_item['inputs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
elif (
args_item['name'].endswith('_grad')
and args_item['name'][:-5] in args_map
):
args_map[args_item['name']] = (
args_map[args_item['name'][:-5]] + '_grad'
)
args_item['name'] = args_map[args_item['name']]
for args_item in backward_op_item['attrs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
for args_item in backward_op_item['outputs']:
if (
args_item['name'].endswith('_grad')
and args_item['name'][:-5] in args_map
):
args_map[args_item['name']] = (
args_map[args_item['name'][:-5]] + '_grad'
)
args_item['name'] = args_map[args_item['name']]
if 'invoke' in backward_op_item:
backward_op_item['invoke']['args'] = [
args_map[param.strip()]
if param.strip() in args_map
else param.strip()
for param in backward_op_item['invoke']['args'].split(',')
]
continue
backward_op_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['infer_meta']['param']
]
backward_op_item['kernel']['param'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['kernel']['param']
]
if backward_op_item['kernel']['data_type']:
backward_op_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['kernel']['data_type'][
'candidates'
]
]
if backward_op_item['kernel']['backend']:
backward_op_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['kernel']['backend'][
'candidates'
]
]
if backward_op_item['kernel']['layout']:
backward_op_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['kernel']['layout'][
'candidates'
]
]
if backward_op_item['no_need_buffer']:
backward_op_item['no_need_buffer'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['no_need_buffer']
]
if backward_op_item['inplace']:
inplace_map = {}
for key, val in backward_op_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
backward_op_item['inplace'] = inplace_map
def process_invoke_op(forward_op_dict, backward_op_dict):
for bw_op in backward_op_dict.values():
@@ -372,6 +360,7 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
args_index = 0
if invoke_op in forward_op_dict:
reuse_op = forward_op_dict[invoke_op]
bw_op['invoke']['func'] = reuse_op['op_name']
bw_op['invoke']['inputs'] = []
bw_op['invoke']['attrs'] = []
bw_op['invoke']['outputs'] = []
@@ -430,14 +419,14 @@ def main(
forward_op_dict[op_version['op']]['version'] = op_version['version']
with open(op_compat_yaml_path, "rt") as f:
op_fluid_map_list = yaml.safe_load(f)
for op in ops:
op['op_name'] = op['name']
for bw_op in backward_ops:
bw_op['op_name'] = bw_op['name']
replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict)
# prepare for invoke case
process_invoke_op(forward_op_dict, backward_op_dict)
......
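The renaming logic above is driven by an args_map built from the compat entry's inputs, attrs and outputs blocks. A small, self-contained sketch of the get_param_list_alias behaviour; the mapping data below is hypothetical, the function body matches the helper in generate_op.py:

# args_map as replace_compat_name would assemble it for squeeze:
# phi-style names on the left, legacy fluid names on the right.
args_map = {'x': 'X', 'axis': 'axes', 'out': 'Out', 'xshape': 'XShape'}

def get_param_list_alias(param_list, args_map):
    # Rename every param that has an alias, leave the rest untouched.
    return [args_map[p] if p in args_map else p for p in param_list]

print(get_param_list_alias(['x', 'axis'], args_map))           # ['X', 'axes']
print(get_param_list_alias(['xshape', 'out_grad'], args_map))  # ['XShape', 'out_grad']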
@@ -54,6 +54,10 @@ AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name
.AsIntermediate()
{%- endif %}
{%- if "is_extra" in output and output["is_extra"] %}
.AsExtra()
{%- endif %}
{%- endmacro %}
{# add attribute, and process default value if needed #}
@@ -115,7 +119,7 @@ KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const Argum
paddle::small_vector<const char*> attrs;
{% for attr in op["attrs"]%}
{% filter indent(2)%}
{{get_an_attr(attr, kernel_args)}}
{% endfilter %}
{% endfor %}
{{get_output_list(op["outputs"], kernel_args)}};
@@ -170,7 +174,7 @@ KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const Argum
paddle::small_vector<const char*> attrs;
{% for attr in op["attrs"]%}
{% filter indent(2)%}
{{get_an_attr(attr, kernel_args)}}
{% endfilter %}
{% endfor %}
{{get_output_list(op["outputs"], kernel_args)}};
@@ -209,8 +213,9 @@ paddle::small_vector<const char*> inputs {
}
{%- endmacro %}
{% macro get_an_attr(attr, kernel_args) %}{# inline #}
{% set typename = attr["typename"] %}
{%- if attr["name"] in kernel_args %}
{% set name = attr["name"] %}
{% if typename is scalar %}{# scalar correspond to a dispensable input and an attr in opmaker #}
attrs.emplace_back(ctx.HasInput("{{attr | to_scalar_tensor_name}}") ? "{{attr | to_scalar_tensor_name}}" : "{{name}}");
@@ -236,6 +241,7 @@ attrs.emplace_back(
{%- else %}
attrs.emplace_back("{{name}}");
{%- endif %}
{%- endif %}
{%- endmacro %}
{% macro get_output_list(outputs, kernel_args) %}{# inline #}
@@ -502,10 +508,9 @@ OutputGrad({{name_in_forward_orig | to_opmaker_name}})
{% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad({{name_in_forward_orig | to_opmaker_name}})
{%- elif (name) in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input({{name | to_opmaker_name}})
InputGrad({{name | to_input_name | to_opmaker_name}})
{%- endif %}
{%- endmacro %}
......
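The template change above threads kernel_args into get_an_attr, so attributes that are not kernel arguments no longer get emplaced into the argument-mapping signature. A stripped-down sketch of just that guard, in plain Jinja2 and without the real macros and filters of the argument-mapping template:

from jinja2 import Template

# Only attrs that are actual kernel arguments should produce an emplace_back.
tpl = Template(
    "{% for attr in attrs %}"
    "{% if attr in kernel_args %}"
    'attrs.emplace_back("{{ attr }}");\n'
    "{% endif %}"
    "{% endfor %}"
)

# 'use_mkldnn' is an extra attr and is filtered out of the rendered code.
print(tpl.render(attrs=['axis', 'use_mkldnn'], kernel_args=['x', 'axis']))
# attrs.emplace_back("axis");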
@@ -195,17 +195,6 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class Squeeze2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class SqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
@@ -220,32 +209,6 @@ class SqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class Squeeze2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(
context->HasInput("XShape"), "Input", "XShape", "Squeeze2Grad");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"Squeeze2Grad");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class SqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
@@ -259,82 +222,6 @@ class SqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
// FIXME(zcd): squeeze2 adds an intermediate output(XShape) based on squeeze,
// the XShape is used to carry the shape and lod of X which will be used in
// squeeze_grad, in this way, the framework can reuse the memory of X
// immediately the squeeze2_op is finished.
// Considering compatibility issues, we could not fix squeeze2_op
class Squeeze2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor). The input tensor of squeeze operator.");
AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in SqueezeGradOp.")
.AsIntermediate()
.AsExtra();
AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of integers,"
" indicating the dimensions to squeeze.")
.SetDefault({})
.SupportTensor();
AddComment(R"DOC(
Squeeze2 Operator.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised.
Examples:
Case 1:
Given
X.shape = (1, 3, 1, 5)
and
axes = [0]
we get:
Out.shape = (3, 1, 5)
Case 2:
Given
X.shape = (1, 3, 1, 5)
and
axes = []
we get:
Out.shape = (3, 5)
)DOC");
}
};
template <typename T>
class Squeeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("squeeze2_grad");
grad_op->SetInput("XShape", this->Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
template <typename T>
class Squeeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("squeeze2");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetOutput("XShape", this->Input("XShape"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(SqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SqueezeGradInplaceInferer,
{framework::GradVarName("Out"),
@@ -345,10 +232,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X");
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
SqueezeInferShapeFunctor,
PD_INFER_META(phi::SqueezeWithXShapeInferMeta));
REGISTER_OPERATOR(squeeze,
ops::SqueezeOp,
ops::SqueezeOpMaker,
@@ -360,19 +243,6 @@ REGISTER_OPERATOR(squeeze_grad,
ops::SqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::SqueezeGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(squeeze2,
ops::Squeeze2Op,
ops::Squeeze2OpMaker,
ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>,
ops::Squeeze2GradOpMaker<paddle::imperative::OpBase>,
ops::SqueezeInplaceInferer,
SqueezeInferShapeFunctor);
REGISTER_OPERATOR(squeeze2_grad,
ops::Squeeze2GradOp,
ops::Squeeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::Squeeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
ops::SqueezeGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
squeeze,
ops::SqueezeKernel<phi::CPUContext, float>,
......
@@ -260,83 +260,6 @@ class UnsqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
// FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on
// unsqueeze, the XShape is used to carry the shape and lod of X which
// will be used in unsqueeze_grad, in this way, the framework can reuse
// the memory of X immediately the unsqueeze2_op is finished.
// Considering compatibility issues, we could not fix unsqueeze2_op
class Unsqueeze2Op : public UnsqueezeOp {
public:
using UnsqueezeOp::UnsqueezeOp;
};
class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
public:
void Make() override {
UnsqueezeOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in UnsqueezeGradOp.")
.AsIntermediate()
.AsExtra();
}
};
template <typename T>
class Unsqueeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("unsqueeze2_grad");
grad_op->SetInput("XShape", this->Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class Unsqueeze2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(
context->HasInput("XShape"),
true,
platform::errors::InvalidArgument("Input(XShape) shouldn't be null."));
PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")),
true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) shouldn't be null."));
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class Unsqueeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("unsqueeze2");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetOutput("XShape", this->Input("XShape"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
{framework::GradVarName("Out"),
@@ -345,10 +268,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2,
Unsqueeze2InferShapeFunctor,
PD_INFER_META(phi::UnsqueezeWithXShapeInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze,
ops::UnsqueezeOp,
@@ -362,20 +281,6 @@ REGISTER_OPERATOR(unsqueeze_grad,
ops::UnsqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::UnsqueezeGradOpNoNeedBufferVarInferer);
REGISTER_OPERATOR(unsqueeze2,
ops::Unsqueeze2Op,
ops::Unsqueeze2OpMaker,
ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
ops::Unsqueeze2GradOpMaker<paddle::imperative::OpBase>,
Unsqueeze2InferShapeFunctor,
ops::UnsqueezeInplaceInferer);
REGISTER_OPERATOR(unsqueeze2_grad,
ops::Unsqueeze2GradOp,
ops::Unsqueeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::Unsqueeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
ops::UnsqueezeGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
unsqueeze,
ops::UnsqueezeKernel<phi::CPUContext, float>,
......
@@ -1186,6 +1186,26 @@
backward : square_double_grad
inplace : (out_grad -> x_grad)
- backward_op : squeeze_double_grad
forward : squeeze_grad(Tensor xshape, Tensor grad_out, IntArray axis) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray axis)
output : Tensor(grad_out_grad), Tensor(xshape)
invoke: squeeze(grad_x_grad, axis)
intermediate : xshape
- backward_op : squeeze_grad
forward : squeeze(Tensor x, IntArray axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, IntArray axis)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param: [xshape]
kernel :
func : squeeze_grad
data_type : out_grad
inplace : (out_grad -> x_grad)
backward: squeeze_double_grad
- backward_op : svd_grad
forward : svd (Tensor x, bool full_matrices = false) -> Tensor(u), Tensor(s), Tensor(vh)
args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full_matrices)
@@ -1321,6 +1341,27 @@
data_type : out_grad
no_need_buffer : x
- backward_op : unsqueeze_double_grad
forward : unsqueeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray axes)
output : Tensor(grad_out_grad), Tensor(xshape)
invoke : unsqueeze(grad_x_grad, axes)
intermediate : xshape
- backward_op : unsqueeze_grad
forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, IntArray axes)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param: [xshape]
kernel :
func : unsqueeze_grad
param : [xshape, out_grad]
data_type : out_grad
inplace : (out_grad -> x_grad)
backward : unsqueeze_double_grad
- backward_op : unstack_grad
forward : unstack (Tensor x, int axis=0, int num=0) -> Tensor[](out)
args : (Tensor[] out_grad, int axis)
......
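The new backward entries above are consumed by the generator as plain dictionaries once the YAML file is loaded. A minimal sketch of loading such an entry and reading the fields the code-gen relies on; the YAML string is a trimmed copy of the squeeze_grad entry and yaml is PyYAML, the same library the generator uses:

import yaml

entry = yaml.safe_load("""
- backward_op : squeeze_grad
  forward : squeeze(Tensor x, IntArray axis) -> Tensor(out), Tensor(xshape)
  args : (Tensor xshape, Tensor out_grad, IntArray axis)
  output : Tensor(x_grad)
  kernel :
    func : squeeze_grad
    data_type : out_grad
  inplace : (out_grad -> x_grad)
  backward : squeeze_double_grad
""")[0]

print(entry['backward_op'])     # squeeze_grad
print(entry['kernel']['func'])  # squeeze_grad
print(entry['inplace'])         # (out_grad -> x_grad)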
@@ -1363,24 +1363,6 @@
kernel :
func : squared_l2_norm_grad
- backward_op : squeeze_double_grad
forward : squeeze_grad(Tensor xshape, Tensor grad_out, IntArray axis) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray axis)
output : Tensor(grad_out_grad)
invoke: squeeze(grad_x_grad, axis)
- backward_op : squeeze_grad
forward : squeeze(Tensor x, IntArray axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, IntArray axis)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param: [xshape]
kernel :
func : squeeze_grad
inplace : (out_grad -> x_grad)
backward: squeeze_double_grad
- backward_op : stack_grad
forward : stack (Tensor[] x, int axis) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int axis)
@@ -1574,25 +1556,6 @@
func : uniform_inplace_grad
inplace : (out_grad -> x_grad)
- backward_op : unsqueeze_double_grad
forward : unsqueeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray axes)
output : Tensor(grad_out_grad)
invoke : unsqueeze(grad_x_grad, axes)
- backward_op : unsqueeze_grad
forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, IntArray axes)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param: [xshape]
kernel :
func : unsqueeze_grad
param: [xshape, out_grad]
inplace : (out_grad -> x_grad)
backward : unsqueeze_double_grad
- backward_op : warpctc_grad
forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) -> Tensor(loss), Tensor(warpctcgrad)
args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)
......
@@ -1777,18 +1777,6 @@
func : squared_l2_norm
backward : squared_l2_norm_grad
- op : squeeze
args : (Tensor x, IntArray axis)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze_with_xshape
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
- op : stack
args : (Tensor[] x, int axis)
output : Tensor
@@ -2022,18 +2010,6 @@
data_type: x
backward: unpool3d_grad
- op : unsqueeze
args : (Tensor x, IntArray axis)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze_with_xshape
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
- op : update_loss_scaling_
args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update)
output : Tensor[](out){x.size()}, Tensor(loss_scaling), Tensor(out_good_steps), Tensor(out_bad_steps)
......
@@ -1270,9 +1270,20 @@
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : squeeze (squeeze2)
backward : squeeze_grad (squeeze2_grad), squeeze_double_grad(squeeze2_double_grad)
inputs :
x : X
attrs :
axis : axes
outputs :
{out : Out, xshape : XShape}
int_array:
axis :
data_type : int
support_tensor : true
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
outputs : [xshape]
- op : stack
backward : stack_grad
@@ -1389,6 +1400,22 @@
outputs :
out : Y
- op : unsqueeze (unsqueeze2)
backward : unsqueeze_grad (unsqueeze2_grad), unsqueeze_double_grad(unsqueeze2_double_grad)
inputs :
x : X
attrs :
axis : axes
outputs :
{out : Out, xshape : XShape}
int_array:
axis :
data_type : int
tensor_name : AxesTensor
tensors_name : AxesTensorList
extra :
outputs : [xshape]
- op : unstack
backward : unstack_grad
inputs :
......
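In the squeeze/unsqueeze compat entries above, axis is declared as an IntArray with data_type : int and support_tensor : true, which is exactly the case the reworked branch in process_int_array turns into an int[] attribute. A condensed sketch of just that decision; the standalone helper below is illustrative, the real function walks op_item['attrs'] in place:

def pick_int_array_typename(int_array_config):
    # Mirrors the updated conditional in process_int_array for the
    # support_tensor case: the attribute is typed int[] only when the compat
    # config pins data_type to int, otherwise int64_t[].
    return (
        'int[]'
        if 'data_type' in int_array_config
        and int_array_config['data_type'] == 'int'
        else 'int64_t[]'
    )

print(pick_int_array_typename({'data_type': 'int', 'support_tensor': True}))  # int[]
print(pick_int_array_typename({'support_tensor': True}))                      # int64_t[]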
@@ -1054,6 +1054,19 @@
square_sr {selected_rows -> selected_rows}
backward : square_grad
- op : squeeze
args : (Tensor x, IntArray axis={})
output : Tensor(out), Tensor(xshape)
infer_meta :
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze_with_xshape
data_type : x
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
- op : svd
args : (Tensor x, bool full_matrices = false)
output : Tensor(u), Tensor(s), Tensor(vh)
@@ -1149,6 +1162,19 @@
func : unfold
backward : unfold_grad
- op : unsqueeze
args : (Tensor x, IntArray axis = {})
output : Tensor(out), Tensor(xshape)
infer_meta :
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze_with_xshape
data_type : x
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
- op : unstack
args : (Tensor x, int axis=0, int num=0)
output : Tensor[](out){num}
......
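The forward definitions above declare xshape as an intermediate output, and the compat entries mark it as an extra output; the generator then tags that output so the emitted OpMaker adds .AsIntermediate().AsExtra(), as in the template hunk earlier. A tiny sketch of the tagging step, with dict shapes simplified from forward_op_item and op_args:

# Simplified stand-ins for the generator's dicts.
forward_op_item = {'outputs': [{'name': 'out'}, {'name': 'xshape'}]}
op_args = {'extra': {'outputs': ['xshape']}}

# Same loop as in replace_compat_name: flag outputs listed under extra.outputs.
if 'extra' in op_args and 'outputs' in op_args['extra']:
    for out_item in forward_op_item['outputs']:
        if out_item['name'] in op_args['extra']['outputs']:
            out_item['is_extra'] = True

print(forward_op_item['outputs'])
# [{'name': 'out'}, {'name': 'xshape', 'is_extra': True}]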
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
KernelSignature SqueezeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"squeeze_grad", {"XShape", "Out@GRAD"}, {"axes"}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(squeeze2, squeeze);
PD_REGISTER_BASE_KERNEL_NAME(squeeze2_grad, squeeze_grad);
PD_REGISTER_ARG_MAPPING_FN(squeeze2, phi::SqueezeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(squeeze2_grad, phi::SqueezeGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensorList"}, {"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
}
KernelSignature UnsqueezeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"unsqueeze_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(unsqueeze2, unsqueeze);
PD_REGISTER_BASE_KERNEL_NAME(unsqueeze2_grad, unsqueeze_grad);
PD_REGISTER_ARG_MAPPING_FN(unsqueeze2, phi::UnsqueezeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(unsqueeze2_grad,
phi::UnsqueezeGradOpArgumentMapping);