{# Header of the template for a generated C++ source that defines and
   registers operators. The macros imported below are defined in
   operator_utils.c.j2. Any literal text placed outside Jinja tags in this
   file is copied verbatim into the generated C++, so only template
   directives, comments, and valid C++ may appear here. #}
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version, composite_grad_op_maker %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/operators/generator/get_expected_kernel_func.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

using paddle::framework::GradVarName;

{# Forward ops: for each op in `ops` that passes the `base_op` test, emit
   op_maker(op) followed by operator(op). Ops that fail the test produce no
   output in this loop. #}
{% for op in ops %}
  {% if op is base_op %}

{{op_maker(op)}}

{{operator(op)}}
  {% endif %}
{% endfor %}

{# Backward ops, for each op in `backward_ops`:
   - if it passes `base_op`: emit backward_op_maker(op, <forward op>) and
     operator(op), where <forward op> is op_dict[op["forward"]["name"]];
   - otherwise: emit backward_op_reused_maker(op, <forward op>, op["invoke"]);
   - independently, if it passes `composite_op`: also emit
     composite_grad_op_maker(op_dict[op["name"]]), i.e. looked up by the
     backward op's own name.
   NOTE(review): `base_op` / `composite_op` are Jinja tests supplied by the
   rendering environment (presumably generate_op.py, per the file banner);
   confirm their exact semantics there. #}
{% for op in backward_ops %}
  {% if op is base_op %}

{{backward_op_maker(op, op_dict[op["forward"]["name"]])}}

{{operator(op)}}
  {% else %}
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
  {% endif %}
  {% if op is composite_op %}
{{composite_grad_op_maker(op_dict[op["name"]])}}
  {% endif %}
{% endfor %}
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
{# Registration: iterate forward ops followed by backward ops. For every op
   that passes the `base_op` test, emit register_op_with_components(op) and
   then register_op_version(op); other ops are skipped. #}
{% for op in ops + backward_ops %}
{% if op is base_op %}
{{register_op_with_components(op)}}
{{register_op_version(op)}}
{% endif %}
{% endfor %}