未验证 提交 766f4dcb 编写于 作者: Z zyfncg 提交者: GitHub

Support setting version for api in yaml (#43771)

* move trace into api.yaml

* add trace unittest

* fix trace test

* fix generate op
上级 1fca8f33
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class TraceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor) The input tensor, from which the diagonals are taken.");
AddOutput("Out", "(Tensor) the sum along diagonals of the input tensor");
AddAttr<int>(
"offset",
R"DOC((int, default 0), offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
)DOC")
.SetDefault(0);
AddAttr<int>(
"axis1",
R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken.
Can be either positive or negative. Default: 0.
)DOC")
.SetDefault(0);
AddAttr<int>(
"axis2",
R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken.
Can be either positive or negative. Default: 1.
)DOC")
.SetDefault(1);
AddComment(R"DOC(
Trace Operator.
Return the sum along diagonals of the input tensor.
The behavior of this operator is similar to how `numpy.trace` works.
If Input is 2-D, returns the sum of diagonal.
If Input has larger dimensions, then returns an tensor of diagonals sum, diagonals be taken from
the 2-D planes specified by dim1 and dim2.
)DOC");
}
};
class TraceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Input"), true,
platform::errors::NotFound("Input(Input) of TraceOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Input")), true,
platform::errors::NotFound(
"Output(Input@GRAD) of TraceGradOp is not found."));
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
template <typename T>
class TraceGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("trace_grad");
grad_op->SetInput("Input", this->Input("Input"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("Input"),
this->InputGrad("Input"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(TraceGradNoNeedBufferVarsInferer, "Input");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(trace, TraceInferShapeFunctor,
PD_INFER_META(phi::TraceInferMeta));
REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
ops::TraceGradOpMaker<paddle::framework::OpDesc>,
ops::TraceGradOpMaker<paddle::imperative::OpBase>,
TraceInferShapeFunctor);
REGISTER_OPERATOR(trace_grad, ops::TraceGradOp,
ops::TraceGradNoNeedBufferVarsInferer);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(trace).AddCheckpoint(
R"ROC(Upgrade trace add a new attribute [axis2])ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("axis1", "The added attribute 'axis1' is not yet registered.",
std::vector<float>{0.0f})
.NewAttr("axis2", "The added attribute 'axis2' is not yet registered.",
std::vector<float>{1.0f})
.DeleteAttr("dim1",
"The attribute 'dim1' is not recommend according to "
"the specification 2.0.")
.DeleteAttr("dim2",
"The attribute 'dim2' is not recommend according to "
"the specification 2.0."));
...@@ -159,8 +159,9 @@ execute_process( ...@@ -159,8 +159,9 @@ execute_process(
COMMAND COMMAND
${PYTHON_EXECUTABLE} generate_op.py --api_yaml_path ${PYTHON_EXECUTABLE} generate_op.py --api_yaml_path
./parsed_apis/api.parsed.yaml --backward_api_yaml_path ./parsed_apis/api.parsed.yaml --backward_api_yaml_path
./parsed_apis/backward_api.parsed.yaml --output_op_path ./parsed_apis/backward_api.parsed.yaml --api_version_yaml_path
"${generated_op_path}.tmp" --output_arg_map_path api_version.yaml --api_args_compat_yaml_path args_compat.yaml
--output_op_path "${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp" "${generated_argument_mapping_path}.tmp"
RESULT_VARIABLE _result) RESULT_VARIABLE _result)
if(${_result}) if(${_result})
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TraceOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"trace", {"Input"}, {"offset", "axis1", "axis2"}, {"Out"});
}
KernelSignature TraceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("trace_grad",
{"Input", "Out@GRAD"},
{"offset", "axis1", "axis2"},
{"Input@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(trace, phi::TraceOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(trace_grad, phi::TraceGradOpArgumentMapping);
...@@ -28,14 +28,15 @@ class TestTraceOp(OpTest): ...@@ -28,14 +28,15 @@ class TestTraceOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "trace" self.op_type = "trace"
self.python_api = paddle.trace
self.init_config() self.init_config()
self.outputs = {'Out': self.target} self.outputs = {'Out': self.target}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['Input'], 'Out') self.check_grad(['Input'], 'Out', check_eager=True)
def init_config(self): def init_config(self):
self.case = np.random.randn(20, 6).astype('float64') self.case = np.random.randn(20, 6).astype('float64')
......
# erf
# bernoulli
- api : bernoulli - api : bernoulli
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor
...@@ -27,7 +24,6 @@ ...@@ -27,7 +24,6 @@
func : mv func : mv
backward : mv_grad backward : mv_grad
# poisson
- api : poisson - api : poisson
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor
...@@ -37,6 +33,15 @@ ...@@ -37,6 +33,15 @@
func : poisson func : poisson
backward : poisson_grad backward : poisson_grad
- api : trace
args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output : Tensor
infer_meta :
func : TraceInferMeta
kernel :
func : trace
backward : trace_grad
- api : trunc - api : trunc
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor
......
- api : trace
version :
- checkpoint : Upgrade trace add a new attribute [axis2]
action :
- add_attr : axis1
comment : The added attribute 'axis1' is not yet registered.
default : std::vector<float>{0.0f}
- add_attr :
name : axis2
comment : The added attribute 'axis2' is not yet registered.
default : std::vector<float>{1.0f}
- delete_attr : dim1
comment : The attribute 'dim1' is not recommend according to the specification 2.0.
- delete_attr : dim2
comment : The attribute 'dim2' is not recommend according to the specification 2.0.
- api : trace
inputs :
x : Input
outputs :
out : Out
...@@ -29,6 +29,18 @@ ...@@ -29,6 +29,18 @@
kernel : kernel :
func : poisson_grad func : poisson_grad
- backward_api : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : trace_grad
data_type : out_grad
no_need_buffer : x
- backward_api : trunc_grad - backward_api : trunc_grad
forward : trunc (Tensor x) -> Tensor(out) forward : trunc (Tensor x) -> Tensor(out)
args : (Tensor out_grad) args : (Tensor out_grad)
......
...@@ -101,9 +101,9 @@ def to_input_name(s): ...@@ -101,9 +101,9 @@ def to_input_name(s):
x -> dx x -> dx
x -> d2x x -> d2x
x -> d3x x -> d3x
NOTE: for first order backward api NOTE: for first order backward api
x -> x_grad x -> x_grad
is more common. is more common.
""" """
match = re.match(r"(d\d*)(\w+)", s) match = re.match(r"(d\d*)(\w+)", s)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import argparse import argparse
import os import os
import re
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
...@@ -53,8 +54,8 @@ def restruct_io(api): ...@@ -53,8 +54,8 @@ def restruct_io(api):
return api return api
def main(api_yaml_path, backward_yaml_path, output_op_path, def main(api_yaml_path, backward_yaml_path, api_args_compat_yaml_path,
output_arg_map_path): api_version_yaml_path, output_op_path, output_arg_map_path):
with open(api_yaml_path, "rt") as f: with open(api_yaml_path, "rt") as f:
apis = yaml.safe_load(f) apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis] apis = [restruct_io(api) for api in apis]
...@@ -65,6 +66,113 @@ def main(api_yaml_path, backward_yaml_path, output_op_path, ...@@ -65,6 +66,113 @@ def main(api_yaml_path, backward_yaml_path, output_op_path,
backward_apis = [restruct_io(api) for api in backward_apis] backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis) backward_api_dict = to_named_dict(backward_apis)
with open(api_version_yaml_path, "rt") as f:
api_versions = yaml.safe_load(f)
# add api version info into api
for api_version in api_versions:
forward_api_dict[api_version['api']]['version'] = api_version['version']
with open(api_args_compat_yaml_path, "rt") as f:
api_args_map = yaml.safe_load(f)
# replace args name for OpMaker
for api_args in api_args_map:
forward_api_item = forward_api_dict[api_args['api']]
has_backward = True if forward_api_item['backward'] else False
if has_backward:
backward_api_item = backward_api_dict[forward_api_item['backward']]
key_set = ['inputs', 'attrs', 'outputs']
args_map = {}
for key in key_set:
if key in api_args:
args_map.update(api_args[key])
for args_item in forward_api_item[key]:
if args_item['name'] in api_args[key]:
args_item['name'] = api_args[key][args_item['name']]
if has_backward:
for args_item in backward_api_item['forward'][key]:
if args_item['name'] in api_args[key]:
args_item['name'] = api_args[key][args_item['name']]
forward_api_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['infer_meta']['param']
]
forward_api_item['kernel']['param'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['param']
]
if forward_api_item['kernel']['data_type']:
forward_api_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param for param in
forward_api_item['kernel']['data_type']['candidates']
]
if forward_api_item['kernel']['backend']:
forward_api_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['backend']['candidates']
]
if forward_api_item['kernel']['layout']:
forward_api_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['layout']['candidates']
]
if forward_api_item['inplace']:
inplace_map = {}
for key, val in forward_api_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
forward_api_item['inplace'] = inplace_map
if has_backward:
for args_item in backward_api_item['inputs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
elif args_item['name'].endswith(
'_grad') and args_item['name'][:-5] in args_map:
args_map[args_item['name']] = args_map[args_item['name']
[:-5]] + '_grad'
args_item['name'] = args_map[args_item['name']]
for args_item in backward_api_item['attrs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
for args_item in backward_api_item['outputs']:
if args_item['name'].endswith(
'_grad') and args_item['name'][:-5] in args_map:
args_map[args_item['name']] = args_map[args_item['name']
[:-5]] + '_grad'
args_item['name'] = args_map[args_item['name']]
backward_api_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['infer_meta']['param']
]
backward_api_item['kernel']['param'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['param']
]
if backward_api_item['kernel']['data_type']:
backward_api_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param for param in
backward_api_item['kernel']['data_type']['candidates']
]
if backward_api_item['kernel']['backend']:
backward_api_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param for param in
backward_api_item['kernel']['backend']['candidates']
]
if backward_api_item['kernel']['layout']:
backward_api_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param for param in
backward_api_item['kernel']['layout']['candidates']
]
if backward_api_item['no_need_buffer']:
backward_api_item['no_need_buffer'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['no_need_buffer']
]
# fill backward field for an api if another api claims it as forward # fill backward field for an api if another api claims it as forward
for name, backward_api in backward_api_dict.items(): for name, backward_api in backward_api_dict.items():
forward_name = backward_api["forward"]["name"] forward_name = backward_api["forward"]["name"]
...@@ -111,6 +219,12 @@ if __name__ == "__main__": ...@@ -111,6 +219,12 @@ if __name__ == "__main__":
parser.add_argument('--backward_api_yaml_path', parser.add_argument('--backward_api_yaml_path',
type=str, type=str,
help="parsed backward api yaml file.") help="parsed backward api yaml file.")
parser.add_argument('--api_args_compat_yaml_path',
type=str,
help="api args compat yaml file.")
parser.add_argument('--api_version_yaml_path',
type=str,
help="api version yaml file.")
parser.add_argument("--output_op_path", parser.add_argument("--output_op_path",
type=str, type=str,
help="path to save generated operators.") help="path to save generated operators.")
...@@ -120,5 +234,6 @@ if __name__ == "__main__": ...@@ -120,5 +234,6 @@ if __name__ == "__main__":
help="path to save generated argument mapping functions.") help="path to save generated argument mapping functions.")
args = parser.parse_args() args = parser.parse_args()
main(args.api_yaml_path, args.backward_api_yaml_path, args.output_op_path, main(args.api_yaml_path, args.backward_api_yaml_path,
args.output_arg_map_path) args.api_args_compat_yaml_path, args.api_version_yaml_path,
args.output_op_path, args.output_arg_map_path)
...@@ -115,7 +115,6 @@ def generate_intermediate_api(api_yaml_path, sparse_api_yaml_path, ...@@ -115,7 +115,6 @@ def generate_intermediate_api(api_yaml_path, sparse_api_yaml_path,
for api in sparse_apis: for api in sparse_apis:
sparse_api = SparseAPI(api) sparse_api = SparseAPI(api)
if sparse_api.is_dygraph_api: if sparse_api.is_dygraph_api:
print(sparse_api.api)
dygraph_header_file.write(sparse_api.gene_api_declaration()) dygraph_header_file.write(sparse_api.gene_api_declaration())
dygraph_source_file.write(sparse_api.gene_api_code()) dygraph_source_file.write(sparse_api.gene_api_code())
......
# The apis in this file are unstandardized that may caused by a variety of reasons,
# we are trying to fix these apis and will move standardized apis into api.yaml.
- api : abs - api : abs
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor
...@@ -1741,7 +1744,7 @@ ...@@ -1741,7 +1744,7 @@
- api : relu - api : relu
args : (Tensor x) args : (Tensor x)
output : Tensor output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
kernel : kernel :
...@@ -2153,15 +2156,6 @@ ...@@ -2153,15 +2156,6 @@
func : top_k func : top_k
backward : top_k_grad backward : top_k_grad
- api : trace
args : (Tensor x, int offset, int axis1, int axis2)
output : Tensor
infer_meta :
func : TraceInferMeta
kernel :
func : trace
backward : trace_grad
- api : transpose - api : transpose
args : (Tensor x, int[] axis) args : (Tensor x, int[] axis)
output : Tensor output : Tensor
......
...@@ -2180,17 +2180,6 @@ ...@@ -2180,17 +2180,6 @@
kernel : kernel :
func : top_k_grad func : top_k_grad
- backward_api : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : trace_grad
no_need_buffer : x
- backward_api : transpose_double_grad - backward_api : transpose_double_grad
forward : transpose_grad (Tensor grad_out, int[] axis) -> Tensor(grad_x) forward : transpose_grad (Tensor grad_out, int[] axis) -> Tensor(grad_x)
args : (Tensor grad_x_grad, int[] axis) args : (Tensor grad_x_grad, int[] axis)
......
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, operator, register_op_with_components %} {% from "operator_utils.c.j2" import op_maker, backward_op_maker, operator, register_op_with_components, register_op_version %}
// this file is generated by python/paddle/utils/code_gen/generate_op.py, do not edit. // this file is generated by python/paddle/utils/code_gen/generate_op.py, do not edit.
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/nullary.h" #include "paddle/phi/infermeta/nullary.h"
...@@ -41,5 +42,6 @@ namespace ops = paddle::operators; ...@@ -41,5 +42,6 @@ namespace ops = paddle::operators;
{% for api in apis + backward_apis %} {% for api in apis + backward_apis %}
{% if api is base_api %} {% if api is base_api %}
{{register_op_with_components(api)}} {{register_op_with_components(api)}}
{{register_op_version(api)}}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -6,9 +6,7 @@ class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndChecker ...@@ -6,9 +6,7 @@ class {{api_name | to_pascal_case}}OpMaker : public framework::OpProtoAndChecker
void Make() override { void Make() override {
{% filter indent(4, True) %} {% filter indent(4, True) %}
{% for input in api["inputs"] %} {% for input in api["inputs"] %}
{% if input["name"] in api["kernel"]["param"] %}
{{add_input(loop.index0, input, api_name)}}; {{add_input(loop.index0, input, api_name)}};
{% endif %}
{% endfor %} {% endfor %}
{% for output in api["outputs"] %} {% for output in api["outputs"] %}
{{add_output(loop.index0, output, api_name)}}; {{add_output(loop.index0, output, api_name)}};
...@@ -130,9 +128,7 @@ PD_REGISTER_ARG_MAPPING_FN({{api["name"]}}, phi::{{api["name"] | to_pascal_case} ...@@ -130,9 +128,7 @@ PD_REGISTER_ARG_MAPPING_FN({{api["name"]}}, phi::{{api["name"] | to_pascal_case}
{% macro get_input_list(inputs, kernel_args) %}{# inline #} {% macro get_input_list(inputs, kernel_args) %}{# inline #}
paddle::small_vector<const char*> inputs { paddle::small_vector<const char*> inputs {
{%- for input in inputs %} {%- for input in inputs %}
{%- if input["name"] in kernel_args %}
{{input["name"] | to_opmaker_name_cstr}}{{", " if not loop.last}} {{input["name"] | to_opmaker_name_cstr}}{{", " if not loop.last}}
{%- endif %}
{%- endfor %} {%- endfor %}
} }
{%- endmacro %} {%- endmacro %}
...@@ -188,8 +184,7 @@ framework::OpKernelType GetExpectedKernelType( ...@@ -188,8 +184,7 @@ framework::OpKernelType GetExpectedKernelType(
} }
{% endif %} {% endif %}
{% endif %} {% endif %}
platform::Place place = ctx.GetPlace(); return framework::OpKernelType(data_type, ctx.GetPlace());
return framework::OpKernelType(data_type, place);
} }
{% endmacro %} {% endmacro %}
...@@ -251,6 +246,47 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op, ...@@ -251,6 +246,47 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
ops::{{name | to_pascal_case}}InferShapeFunctor); ops::{{name | to_pascal_case}}InferShapeFunctor);
{% endmacro %} {% endmacro %}
{% macro register_op_version(api) %}
{% if "version" in api %}
{% set name = api["name"] %}
REGISTER_OP_VERSION({{name}})
{% for checkpoint in api["version"]%}
.AddCheckpoint(
R"ROC({{checkpoint["checkpoint"]}})ROC",
paddle::framework::compatible::OpVersionDesc()
{% for action in checkpoint["action"]%}
{% if "add_input" in action %}
.NewInput("{{action["add_input"]}}", "{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% if "delete_input" in action %}
.DeleteInput("{{action["delete_input"]}}", "{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% if "modify_input" in action %}
.ModifyInput("{{action["modify_input"]}}", "{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% if "add_output" in action %}
.NewOutput("{{action["add_output"]}}", "{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% if "delete_output" in action %}
.DeleteOutput("{{action["delete_output"]}}", "{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% if "modify_output" in action %}
.ModifyOutput("{{action["modify_output"]}}", "{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% if "add_attr" in action %}
.NewAttr("{{action["add_attr"]}}", "{{action["comment"]}}", {{action["default"]}}){{")" if loop.last}}
{% endif %}
{% if "delete_attr" in action %}
.DeleteAttr("{{action["delete_attr"]}}", "{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% if "fix_bug" in action %}
.BugfixWithBehaviorChanged("{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% endfor %}
{% endfor %};
{% endif %}
{% endmacro %}
{# --------------------------------------- backward op maker ---------------------------------------------- #} {# --------------------------------------- backward op maker ---------------------------------------------- #}
{% macro backward_op_maker(api, forward_api) %} {% macro backward_op_maker(api, forward_api) %}
...@@ -272,8 +308,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -272,8 +308,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% for input in api["inputs"] %} {% for input in api["inputs"] %}
grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward( grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
input["name"], input["name"],
forward_input_names, forward_input_names,
forward_output_names, forward_output_names,
forward_input_orig_names, forward_input_orig_names,
forward_output_orig_names)}}); forward_output_orig_names)}});
...@@ -281,8 +317,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -281,8 +317,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% for output in api["outputs"] %} {% for output in api["outputs"] %}
grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward( grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
output["name"], output["name"],
forward_input_names, forward_input_names,
forward_output_names, forward_output_names,
forward_input_orig_names, forward_input_orig_names,
forward_output_orig_names)}}); forward_output_orig_names)}});
...@@ -308,8 +344,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -308,8 +344,8 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
{% endmacro %} {% endmacro %}
{% macro extract_input_from_forward(name, {% macro extract_input_from_forward(name,
input_names, output_names, input_names, output_names,
input_orig_names, output_orig_names) %}{# inline #} input_orig_names, output_orig_names) %}{# inline #}
{% if name in input_names %} {% if name in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%} {% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
......
...@@ -50,7 +50,6 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']} ...@@ -50,7 +50,6 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']}
args = [] args = []
for input_name in api.inputs['names']: for input_name in api.inputs['names']:
if input_name in kernel_params: if input_name in kernel_params:
print("type", api.inputs['input_info'])
args.append( args.append(
tensor_type_map[api.inputs['input_info'][input_name]] + tensor_type_map[api.inputs['input_info'][input_name]] +
' ' + input_name) ' ' + input_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册