From e6b26393e6bb51626b7210ab72bf0bcd59f202fe Mon Sep 17 00:00:00 2001
From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
Date: Mon, 14 Aug 2023 14:21:25 +0800
Subject: [PATCH] [NewIR]Codegen templete of Op vjp interface (#56196)

* op_vjp generate template

* op_vjp_gen templete

* delete print

Co-authored-by: Aurelius84

---------

Co-authored-by: Aurelius84
---
 .../fluid/ir/dialect/op_generator/op_gen.py   | 28 +++++--
 .../dialect/op_generator/op_interface_gen.py  | 80 +++++++++++++++++++
 2 files changed, 103 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index d990141add5..a204d64b00f 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -110,6 +110,9 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
 #include "paddle/phi/infermeta/ternary.h"
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
+#include "paddle/fluid/primitive/rule/vjp/vjp.h"
+#include "paddle/fluid/primitive/type/desc_tensor.h"
+#include "paddle/ir/core/op_base.h"
 
 {input}
 
@@ -679,17 +682,16 @@ def OpGenerator(
         with open(yaml_file, "r") as f:
             ops = yaml.safe_load(f)
             op_yaml_items = op_yaml_items + ops
-    op_info_items = []
+    op_info_items = {}
     for op in op_yaml_items:
-        op_info_items.append(
-            OpInfoParser(op, op_compat_parser.get_compat(op['name']))
+        op_info_items[op['name']] = OpInfoParser(
+            op, op_compat_parser.get_compat(op['name'])
         )
-
     # (3) CodeGen: Traverse op_info_items and generate
     ops_name_list = []  # all op class name store in this list
     ops_declare_list = []  # all op class declare store in this list
     ops_defined_list = []  # all op class defined store in this list
-    for op_info in op_info_items:
+    for key, op_info in op_info_items.items():
         # get op inputs info
         op_input_name_list = op_info.input_name_list
         op_input_type_list = op_info.input_type_list
@@ -1028,6 +1030,21 @@
 
         op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name)
 
+        # =================================== #
+        #        gen Vjp func str             #
+        # =================================== #
+
+        # generate op vjp function str
+        op_vjp_str = ''
+
+        # TODO(chenzhiyang) add vjp gen code
+        # if op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_gen_op_list:
+        #     op_vjp_str = gen_op_vjp_str(op_class_name,
+        #                                 op_info.backward_name,
+        #                                 op_name,
+        #                                 op_info_items[op_info.op_phi_name[0]],
+        #                                 op_info_items[op_info.backward_name])
+
         ops_name_list.append(op_class_name)
         ops_declare_list.append(op_declare_str)
         ops_defined_list.append(op_defined_str)
@@ -1038,6 +1055,7 @@
         ops_defined_list.append(build_func_with_muta_attr_is_input)
         ops_defined_list.append(op_verify_str)
         ops_defined_list.append(op_infer_meta_str)
+        ops_defined_list.append(op_vjp_str)
 
     # (4) Generate head file str
     op_namespaces_prev = ""
diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
index 4833111c9d2..ef5f2e1b4cc 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
@@ -22,6 +22,86 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
 }}
 """
 
+OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
+    {input_type} {input_name}(std::make_shared<primitive::experimental::DescTensor>(op_obj.{input_name}()));
+"""
+
+OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
+    Tensor {output_grad_name}(std::make_shared<primitive::experimental::DescTensor>(out_grads[{idx1}][{idx2}]));
+"""
+
+OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
+    std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::experimental::DescTensor>(out_grads[{idx1}]));
+"""
+
+OP_VJP_CALL_VJP_TEMPLATE = """
+    std::vector<std::vector<Tensor>> tensor_res =
+        primitive::experimental::{op_phi_name}_vjp({inputs_list}, stop_gradients);
+"""
+
+OP_VJP_STOPGRADIENT_TEMPLATE = """
+    if(!stop_gradients[{idx1}][{idx2}]){{
+        res[{idx1}][{idx2}] = std::static_pointer_cast<primitive::experimental::DescTensor>(
+                                  tensor_res[{idx1}][{idx2}].impl())
+                                  ->getValue()
+                                  .dyn_cast<ir::OpResult>();
+    }}
+"""
+
+OP_VJP_DEFINE_TEMPLATE = """
+std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(
+    ir::Operation* op,
+    const std::vector<std::vector<ir::OpResult>>& out_grads,
+    const std::vector<std::vector<bool>>& stop_gradients){{
+    {op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
+
+    VLOG(6) << "Prepare inputs of {op_grad_name}";
+
+    {forward_input_code}
+    {forward_output_code}
+    {forward_output_grad_code}
+
+    VLOG(6) << "Vjp prepare attributes of {op_grad_name}";
+    {attribute_code}
+
+    VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp interface";
+    {call_vjp_code}
+
+    std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
+    {stop_gradient_input_grad_code}
+
+    return res;
+}}
+"""
+
+
+def gen_op_vjp_str(
+    op_class_name,
+    op_grad_name,
+    op_phi_name,
+    op_info,
+    op_grad_info,
+):
+    forward_input_code = ''
+    forward_output_code = ''
+    forward_output_grad_code = ''
+    attribute_code = ''
+    call_vjp_code = ''
+    stop_gradient_input_grad_code = ''
+
+    str = OP_VJP_DEFINE_TEMPLATE.format(
+        op_class_name=op_class_name,
+        op_grad_name=op_grad_name,
+        op_phi_name=op_phi_name,
+        forward_input_code=forward_input_code,
+        forward_output_code=forward_output_code,
+        forward_output_grad_code=forward_output_grad_code,
+        attribute_code=attribute_code,
+        call_vjp_code=call_vjp_code,
+        stop_gradient_input_grad_code=stop_gradient_input_grad_code,
+    )
+    return str
+
 
 def gen_op_infer_meta_str(op_info, op_class_name):
     op_infer_meta_str = ""
-- 
GitLab
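
To illustrate what the scaffold in this patch renders today, the standalone sketch below replays a trimmed gen_op_vjp_str outside the generator. Every code slot is still an empty string, exactly as in the patch, so the output is only a bare C++ Vjp skeleton; the real call site in op_gen.py remains commented out behind TODO(chenzhiyang). The op names used here (TanhOp, tanh_grad, tanh) and the trimmed template are illustrative assumptions, not part of the patch.

# Minimal sketch (not part of the patch): renders a trimmed
# OP_VJP_DEFINE_TEMPLATE the way gen_op_vjp_str does, with all
# code slots left empty as in this commit.
OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(
    ir::Operation* op,
    const std::vector<std::vector<ir::OpResult>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients){{
    {op_class_name} op_obj = op->dyn_cast<{op_class_name}>();

    VLOG(6) << "Prepare inputs of {op_grad_name}";
    {forward_input_code}

    VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp interface";
    {call_vjp_code}

    std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
    {stop_gradient_input_grad_code}
    return res;
}}
"""


def gen_op_vjp_str(op_class_name, op_grad_name, op_phi_name):
    # Every slot is an empty string for now, mirroring the scaffold;
    # follow-up changes are expected to fill them from the op YAML info.
    return OP_VJP_DEFINE_TEMPLATE.format(
        op_class_name=op_class_name,
        op_grad_name=op_grad_name,
        op_phi_name=op_phi_name,
        forward_input_code='',
        call_vjp_code='',
        stop_gradient_input_grad_code='',
    )


if __name__ == '__main__':
    # Prints an empty TanhOp::Vjp skeleton; "TanhOp", "tanh_grad" and
    # "tanh" are hypothetical stand-in names for illustration only.
    print(gen_op_vjp_str('TanhOp', 'tanh_grad', 'tanh'))

Keeping the codegen in .format-style templates means the generated C++ is declared once as a readable literal, and later PRs only have to populate the per-op slot strings (forward inputs, attribute unpacking, the vjp call, and stop-gradient handling) rather than restructure the emitter.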