Unverified commit 227ab74d, authored by zyfncg, committed by GitHub

Support generating OpMaker code for backward ops that invoke a forward op (#46912)

Parent: e896567e
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;

class FlipOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library = framework::LibraryType::kPlain;
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    int customized_type_value =
        framework::OpKernelType::kDefaultCustomizedTypeValue;
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(input_data_type,
                                   ctx.GetPlace(),
                                   layout,
                                   library,
                                   customized_type_value);
  }
};

class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of flip op.");
    AddOutput("Out", "(Tensor), The output tensor of flip op.");
    AddAttr<std::vector<int>>("axis", "The axes to flip on.");
    AddComment(R"DOC(
Flip Operator.
Reverse the order of an n-D tensor along the given axes.
)DOC");
  }
};

class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

template <typename T>
class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> retv) const override {
    // flip is its own inverse, so the gradient of flip is simply flip
    // applied to the gradient of the output.
    retv->SetType("flip");
    retv->SetInput("X", this->OutputGrad("Out"));
    retv->SetOutput("Out", this->InputGrad("X"));
    retv->SetAttrMap(this->Attrs());
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(flip,
                            FlipInferShapeFunctor,
                            PD_INFER_META(phi::FlipInferMeta));
REGISTER_OPERATOR(flip,
                  ops::FlipOp,
                  ops::FlipOpMaker,
                  ops::FlipOpInferVarType,
                  ops::FlipOpGradMaker<paddle::framework::OpDesc>,
                  ops::FlipOpGradMaker<paddle::imperative::OpBase>,
                  FlipInferShapeFunctor);

/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(flip).AddCheckpoint(
    R"ROC(Upgrade flip, add new attr [axis] and delete attr [dims].)ROC",
    paddle::framework::compatible::OpVersionDesc()
        .NewAttr("axis",
                 "The added attr 'axis' doesn't set default value.",
                 paddle::none)
        .DeleteAttr("dims", "The attr 'dims' is deleted."));
@@ -108,7 +108,6 @@ register_unity_group(
 register_unity_group(
   cc
   flatten_op.cc
-  flip_op.cc
   fsp_op.cc
   gather_nd_op.cc
   gather_op.cc
@@ -423,7 +422,6 @@ register_unity_group(cu expand_v2_op.cu fake_dequantize_op.cu
   fill_any_like_op.cu)
 register_unity_group(
   cu
-  flip_op.cu
   fsp_op.cu
   gather_nd_op.cu
   gather_op.cu
...
@@ -171,9 +171,9 @@ create or remove auto-geneated argument mappings: ${generated_argument_mapping_p
 execute_process(
   WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
   COMMAND
-    ${PYTHON_EXECUTABLE} generator/generate_op.py --api_yaml_path
-    ./parsed_apis/api.parsed.yaml --backward_api_yaml_path
-    ./parsed_apis/backward_api.parsed.yaml --api_version_yaml_path
+    ${PYTHON_EXECUTABLE} generator/generate_op.py --ops_yaml_path
+    ./parsed_apis/api.parsed.yaml --backward_yaml_path
+    ./parsed_apis/backward_api.parsed.yaml --op_version_yaml_path
     op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path
     "${generated_op_path}.tmp" --output_arg_map_path
     "${generated_argument_mapping_path}.tmp"
...
@@ -147,6 +147,12 @@
     data_type: out_grad
   no_need_buffer: x
 
+- backward_op : flip_grad
+  forward : flip (Tensor x, int[] axis) -> Tensor(out)
+  args : (Tensor out_grad, int[] axis)
+  output : Tensor(x_grad)
+  invoke : flip(out_grad, axis)
+
 - backward_op : graph_send_uv_grad
   forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
   args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
...
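
The five-line flip_grad entry added above is the heart of this change: instead of spelling out infer_meta and kernel sections, the backward op declares invoke : flip(out_grad, axis), and the generator emits an OpMaker equivalent to the hand-written FlipOpGradMaker deleted earlier in this commit. Inside generate_op.py the invoke clause arrives as a function name plus an argument string, roughly as in this sketch (the field layout is inferred from how the script accesses it below; the exact parsed form comes from the yaml parsing step, which is outside this diff):

    # Hypothetical parsed form of the flip_grad entry, for illustration only;
    # field names are inferred from how generate_op.py reads them below.
    flip_grad = {
        "name": "flip_grad",
        "forward": {"name": "flip"},
        "outputs": [{"name": "x_grad"}],
        "invoke": {"func": "flip", "args": "out_grad, axis"},
    }
    # Note the raw split keeps leading spaces; the generator strips them.
    print(flip_grad["invoke"]["args"].split(","))  # ['out_grad', ' axis']
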
@@ -143,6 +143,14 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
                                                      [:-5]] + '_grad'
                         args_item['name'] = args_map[args_item['name']]
 
+        if 'invoke' in backward_api_item:
+            backward_api_item['invoke']['args'] = [
+                args_map[param.strip()]
+                if param.strip() in args_map else param.strip()
+                for param in backward_api_item['invoke']['args'].split(',')
+            ]
+            continue
+
         backward_api_item['infer_meta']['param'] = [
             args_map[param] if param in args_map else param
             for param in backward_api_item['infer_meta']['param']
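
The list comprehension added above renames only those invoke arguments that appear in the compat name map, leaves the rest untouched, and turns the argument string into a list as a side effect. A runnable toy version (the args_map contents here are hypothetical):

    # Toy rerun of the invoke-args renaming; args_map content is hypothetical.
    args_map = {"x": "X"}
    invoke = {"func": "flip", "args": "out_grad, axis"}
    invoke["args"] = [
        args_map[p.strip()] if p.strip() in args_map else p.strip()
        for p in invoke["args"].split(",")
    ]
    print(invoke["args"])  # ['out_grad', 'axis'], nothing matched args_map

The continue matters: an invoke-style backward op has no infer_meta or kernel sections, so the renaming passes that follow would fail on it.
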
@@ -173,9 +181,9 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
         ]
 
 
-def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
-         api_version_yaml_path, output_op_path, output_arg_map_path):
-    with open(api_yaml_path, "rt") as f:
+def main(ops_yaml_path, backward_yaml_path, op_compat_yaml_path,
+         op_version_yaml_path, output_op_path, output_arg_map_path):
+    with open(ops_yaml_path, "rt") as f:
         apis = yaml.safe_load(f)
     apis = [restruct_io(api) for api in apis]
     forward_api_dict = to_named_dict(apis)
@@ -185,7 +193,7 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
     backward_apis = [restruct_io(api) for api in backward_apis]
     backward_api_dict = to_named_dict(backward_apis)
 
-    with open(api_version_yaml_path, "rt") as f:
+    with open(op_version_yaml_path, "rt") as f:
         api_versions = yaml.safe_load(f)
     # add api version info into api
     for api_version in api_versions:
@@ -201,6 +209,45 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
     replace_compat_name(api_op_map, forward_api_dict, backward_api_dict)
 
+    # prepare for invoke case
+    for bw_name, bw_api in backward_api_dict.items():
+        if 'invoke' in bw_api:
+            invoke_op = bw_api['invoke']['func']
+            args_list = bw_api['invoke']['args']
+            args_index = 0
+            if invoke_op in forward_api_dict.keys():
+                reuse_op = forward_api_dict[invoke_op]
+                bw_api['invoke']['inputs'] = []
+                bw_api['invoke']['attrs'] = []
+                bw_api['invoke']['outputs'] = []
+                for input_item in reuse_op['inputs']:
+                    bw_api['invoke']['inputs'].append({
+                        'name':
+                        input_item['name'],
+                        'value':
+                        args_list[args_index]
+                    })
+                    args_index = args_index + 1
+                for attr in reuse_op['attrs']:
+                    if args_index < len(args_list):
+                        attr_value = f"this->GetAttr(\"{args_list[args_index]}\")" if args_list[
+                            args_index] in bw_api['attr_dict'] else args_list[
+                                args_index]
+                        bw_api['invoke']['attrs'].append({
+                            'name': attr['name'],
+                            'value': attr_value
+                        })
+                        args_index = args_index + 1
+                    else:
+                        break
+                for idx, output_item in enumerate(reuse_op['outputs']):
+                    bw_api['invoke']['outputs'].append({
+                        'name':
+                        output_item['name'],
+                        'value':
+                        bw_api['outputs'][idx]['name']
+                    })
+
     # fill backward field for an api if another api claims it as forward
     for name, backward_api in backward_api_dict.items():
         forward_name = backward_api["forward"]["name"]
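
For flip_grad, the loop above pairs the reused forward op's inputs and attrs positionally with the invoke arguments, and pairs its outputs with the backward op's own outputs. Because 'axis' is one of the backward op's attributes, its value is rendered as a C++ this->GetAttr(...) expression for the template to splice into generated code. The populated dict should come out roughly like this (a sketch using the pre-compat yaml names, assuming attr_dict contains 'axis'):

    # Approximate result of the invoke-preparation loop for flip_grad.
    invoke = {
        "func": "flip",
        "args": ["out_grad", "axis"],
        "inputs": [{"name": "x", "value": "out_grad"}],
        "attrs": [{"name": "axis", "value": 'this->GetAttr("axis")'}],
        "outputs": [{"name": "out", "value": "x_grad"}],
    }
    print(invoke["attrs"][0]["value"])  # this->GetAttr("axis")

The backward_op_reused_maker macro later in this diff consumes exactly these three lists to emit the SetInput, SetAttr, and SetOutput calls of the generated OpMaker.
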
@@ -236,18 +283,18 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Generate operator file from api yaml.")
-    parser.add_argument('--api_yaml_path',
-                        type=str,
-                        help="parsed api yaml file.")
-    parser.add_argument('--backward_api_yaml_path',
-                        type=str,
-                        help="parsed backward api yaml file.")
+    parser.add_argument('--ops_yaml_path',
+                        type=str,
+                        help="parsed ops yaml file.")
+    parser.add_argument('--backward_yaml_path',
+                        type=str,
+                        help="parsed backward ops yaml file.")
     parser.add_argument('--op_compat_yaml_path',
                         type=str,
-                        help="api args compat yaml file.")
-    parser.add_argument('--api_version_yaml_path',
-                        type=str,
-                        help="api version yaml file.")
+                        help="ops args compat yaml file.")
+    parser.add_argument('--op_version_yaml_path',
+                        type=str,
+                        help="ops version yaml file.")
     parser.add_argument("--output_op_path",
                         type=str,
                         help="path to save generated operators.")
@@ -257,6 +304,6 @@ if __name__ == "__main__":
         help="path to save generated argument mapping functions.")
 
     args = parser.parse_args()
-    main(args.api_yaml_path, args.backward_api_yaml_path,
-         args.op_compat_yaml_path, args.api_version_yaml_path,
-         args.output_op_path, args.output_arg_map_path)
+    main(args.ops_yaml_path, args.backward_yaml_path, args.op_compat_yaml_path,
+         args.op_version_yaml_path, args.output_op_path,
+         args.output_arg_map_path)
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, operator, register_op_with_components, register_op_version %} {% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. // this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
...@@ -33,6 +33,8 @@ using paddle::framework::GradVarName; ...@@ -33,6 +33,8 @@ using paddle::framework::GradVarName;
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}} {{backward_op_maker(api, api_dict[api["forward"]["name"]])}}
{{operator(api)}} {{operator(api)}}
{% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
} // namespace operators } // namespace operators
......
@@ -352,6 +352,48 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
 };
 {% endmacro %}
 
+{% macro backward_op_reused_maker(bw_op, forward_op, invoke_op) %}
+  {% set name = bw_op["op_name"] %}
+  {% set forward_input_names = bw_op["forward"]["inputs"] | map(attribute="name") | list %}
+  {% set forward_output_names = bw_op["forward"]["outputs"] | map(attribute="name") | list %}
+  {% set forward_attr_names = bw_op["forward"]["attrs"] | map(attribute="name") | list %}
+  {% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %}
+  {% set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %}
+  {% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %}
+template <typename T>
+class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("{{invoke_op["func"]}}");
+  {% for input in invoke_op["inputs"] %}
+    grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
+      input["value"],
+      forward_input_names,
+      forward_output_names,
+      forward_input_orig_names,
+      forward_output_orig_names)}});
+  {% endfor %}
+  {% for output in invoke_op["outputs"] %}
+    grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
+      output["value"],
+      forward_input_names,
+      forward_output_names,
+      forward_input_orig_names,
+      forward_output_orig_names)}});
+  {% endfor %}
+  {% for attr in invoke_op["attrs"] %}
+    grad_op->SetAttr("{{attr["name"]}}", {{attr["value"]}});
+  {% endfor %}
+  }
+};
+{% endmacro %}
+
 {% macro extract_input_from_forward(name,
                                     input_names, output_names,
...
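
To see the template-to-C++ step in isolation, here is a minimal, self-contained sketch of how the prepared invoke dict drives Jinja rendering. The template below is a toy stand-in for backward_op_reused_maker: the real macro resolves names through extract_input_from_forward and custom filters such as to_opmaker_name, while here the resolved fluid names and C++ expressions are pre-baked into the data:

    # Toy stand-in for backward_op_reused_maker; not the real template.
    from jinja2 import Template

    toy = Template(
        'grad_op->SetType("{{ invoke.func }}");\n'
        '{% for i in invoke.inputs %}grad_op->SetInput("{{ i.name }}", {{ i.value }});\n'
        '{% endfor %}'
        '{% for a in invoke.attrs %}grad_op->SetAttr("{{ a.name }}", {{ a.value }});\n'
        '{% endfor %}'
        '{% for o in invoke.outputs %}grad_op->SetOutput("{{ o.name }}", {{ o.value }});\n'
        '{% endfor %}')

    # Names and values pre-mapped to their fluid forms; the real generator
    # derives them from the invoke dict and the forward op's signature.
    invoke = {
        "func": "flip",
        "inputs": [{"name": "X", "value": 'this->OutputGrad("Out")'}],
        "attrs": [{"name": "axis", "value": 'this->GetAttr("axis")'}],
        "outputs": [{"name": "Out", "value": 'this->InputGrad("X")'}],
    }
    print(toy.render(invoke=invoke))

The rendered body matches the Apply method of the hand-written FlipOpGradMaker deleted at the top of this commit, which is exactly the point of the change.
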
@@ -839,12 +839,6 @@
     layout: out_grad
   inplace : (out_grad -> x_grad)
 
-- backward_op : flip_grad
-  forward : flip (Tensor x, int[] axis) -> Tensor(out)
-  args : (Tensor out_grad, int[] axis)
-  output : Tensor(x_grad)
-  invoke : flip(out_grad, axis)
-
 - backward_op : floor_grad
   forward : floor(Tensor x) -> Tensor(out)
   args : (Tensor out_grad)
...
@@ -962,15 +962,6 @@
     intermediate : xshape
   backward : flatten_grad
 
-- op : flip
-  args : (Tensor x, int[] axis)
-  output : Tensor
-  infer_meta :
-    func : FlipInferMeta
-  kernel :
-    func : flip
-  backward : flip_grad
-
 - op : floor
   args : (Tensor x)
   output : Tensor(out)
...
@@ -324,6 +324,12 @@
   inputs: {x: X}
   outputs: {out: Out}
 
+- op : flip
+  inputs :
+    x : X
+  outputs :
+    out : Out
+
 - op : floor
   backward : floor_grad
   extra :
...
+- op : flip
+  version :
+    - checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims]
+      action :
+        - add_attr : axis
+          comment : The added attr 'axis' doesn't set default value
+          default : paddle::none
+        - delete_attr : dims
+          comment : The attr 'dims' is deleted.
+
 - op : trace
   version :
     - checkpoint : Upgrade trace add a new attribute [axis2]
...
@@ -199,3 +199,12 @@
   kernel :
     func : trunc
   backward : trunc_grad
+
+- op : flip
+  args : (Tensor x, int[] axis)
+  output : Tensor (out)
+  infer_meta :
+    func : FlipInferMeta
+  kernel :
+    func : flip
+  backward : flip_grad