op.c.j2 1.8 KB
Newer Older
J
Jiabin Yang 已提交
1
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version, composite_grad_op_maker %}
2
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
3
#include <string>
J
Jiabin Yang 已提交
4
#include "paddle/fluid/framework/convert_utils.h"
5 6
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
7
#include "paddle/fluid/framework/op_version_registry.h"
8
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
J
Jiabin Yang 已提交
9 10
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
11
#include "paddle/phi/core/infermeta_utils.h"
J
Jiabin Yang 已提交
12
#include "paddle/phi/infermeta/backward.h"
13
#include "paddle/phi/infermeta/binary.h"
14
#include "paddle/phi/infermeta/fusion.h"
15
#include "paddle/phi/infermeta/multiary.h"
J
Jiabin Yang 已提交
16 17 18
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"
19 20 21 22 23 24

namespace paddle {
namespace operators {

using paddle::framework::GradVarName;

25 26
{# Forward ops: for each entry of `ops` that passes the `base_op` test, expand
   the `op_maker` macro and then the `operator` macro (both imported from
   operator_utils.c.j2). Entries that fail the test emit nothing in this loop.
   NOTE(review): the bare numeric lines in this template look like line-number
   residue from a web blame view rather than template content; confirm. #}
{% for op in ops %}
  {% if op is base_op %}
27

28
{{op_maker(op)}}
29

30
{{operator(op)}}
31 32 33
  {% endif %}
{% endfor %}

34 35
{# Backward ops: each entry of `backward_ops` takes exactly one of two paths.
     - passes `base_op`: expand `backward_op_maker` with the op and its forward
       op entry (looked up in `op_dict` by op["forward"]["name"]), then expand
       `operator` for the backward op itself.
     - otherwise: expand `backward_op_reused_maker` with the op, the same
       forward-op lookup, and op["invoke"].
   Independently of that choice, an entry that passes the `composite_op` test
   also gets `composite_grad_op_maker`, which is given the `op_dict` entry keyed
   by the backward op's own name (op["name"]), not the forward op's name. #}
{% for op in backward_ops %}
  {% if op is base_op %}
36

37
{{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
38

39
{{operator(op)}}
40
  {% else %}
41
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
42
  {% endif %}
43
  {% if op is composite_op %}
J
Jiabin Yang 已提交
44
{{composite_grad_op_maker(op_dict[op["name"]])}}
45
  {% endif %}
46 47 48 49 50
{% endfor %}
}  // namespace operators
}  // namespace paddle

{# The C++ alias `ops` on the next line is literal emitted text; it is unrelated
   to the Jinja variable `ops` iterated in the loop below. #}
namespace ops = paddle::operators;
51 52
{# Registration: walk the forward ops followed by the backward ops (`ops +
   backward_ops`) and, for each entry that passes the `base_op` test, expand
   `register_op_with_components` and then `register_op_version`. Entries that
   fail the test emit nothing here. #}
{% for op in ops + backward_ops %}
{% if op is base_op %}
53
{{register_op_with_components(op)}}
54
{{register_op_version(op)}}
55 56
{% endif %}
{% endfor %}