op.c.j2 1.8 KB
Newer Older
J
Jiabin Yang 已提交
1
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version, composite_grad_op_maker %}
2
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
3
#include <string>
J
Jiabin Yang 已提交
4
#include "paddle/fluid/framework/convert_utils.h"
5 6
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
7
#include "paddle/fluid/framework/op_version_registry.h"
J
Jiabin Yang 已提交
8 9 10
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
11
#include "paddle/phi/core/infermeta_utils.h"
J
Jiabin Yang 已提交
12
#include "paddle/phi/infermeta/backward.h"
13 14
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
J
Jiabin Yang 已提交
15 16 17
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"
18 19 20 21 22 23

namespace paddle {
namespace operators {

using paddle::framework::GradVarName;

24 25
{% for op in ops %}
  {% if op is base_op %}
26

27
{{op_maker(op)}}
28

29
{{operator(op)}}
30 31 32
  {% endif %}
{% endfor %}

33 34
{% for op in backward_ops %}
  {% if op is base_op %}
35

36
{{backward_op_maker(op, op_dict[op["forward"]["name"]])}}
37

38
{{operator(op)}}
39
  {% else %}
40
{{backward_op_reused_maker(op, op_dict[op["forward"]["name"]], op["invoke"])}}
41
  {% endif %}
J
Jiabin Yang 已提交
42 43 44 45 46
  {% if composite_gen_flag == True %}
    {% if op is composite_op %}
{{composite_grad_op_maker(op_dict[op["name"]])}}
    {% endif %}
  {% endif %}  
47 48 49 50 51
{% endfor %}
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
52 53
{% for op in ops + backward_ops %}
{% if op is base_op %}
J
Jiabin Yang 已提交
54
{{register_op_with_components(op, op_dict)}}
55
{{register_op_version(op)}}
56 57
{% endif %}
{% endfor %}