diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py
index d990141add5a01912b385012033b557aff0b234b..a204d64b00f48cc5133ba2732e45ce5612c40c31 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py
@@ -110,6 +110,9 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
 #include "paddle/phi/infermeta/ternary.h"
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
+#include "paddle/fluid/primitive/rule/vjp/vjp.h"
+#include "paddle/fluid/primitive/type/desc_tensor.h"
+#include "paddle/ir/core/op_base.h"
 
 {input}
 
@@ -679,17 +682,16 @@ def OpGenerator(
         with open(yaml_file, "r") as f:
             ops = yaml.safe_load(f)
             op_yaml_items = op_yaml_items + ops
-    op_info_items = []
+    op_info_items = {}
     for op in op_yaml_items:
-        op_info_items.append(
-            OpInfoParser(op, op_compat_parser.get_compat(op['name']))
+        op_info_items[op['name']] = OpInfoParser(
+            op, op_compat_parser.get_compat(op['name'])
         )
-
     # (3) CodeGen: Traverse op_info_items and generate
     ops_name_list = []  # all op class name store in this list
     ops_declare_list = []  # all op class declare store in this list
     ops_defined_list = []  # all op class defined store in this list
-    for op_info in op_info_items:
+    for key, op_info in op_info_items.items():
         # get op inputs info
         op_input_name_list = op_info.input_name_list
         op_input_type_list = op_info.input_type_list
@@ -1028,6 +1030,21 @@
 
         op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name)
 
+        # =================================== #
+        #          gen Vjp func str           #
+        # =================================== #
+
+        # generate op vjp function str
+        op_vjp_str = ''
+
+        # TODO(chenzhiyang) add vjp gen code
+        # if op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_gen_op_list:
+        #     op_vjp_str = gen_op_vjp_str(op_class_name,
+        #                                 op_info.backward_name,
+        #                                 op_name,
+        #                                 op_info_items[op_info.op_phi_name[0]],
+        #                                 op_info_items[op_info.backward_name])
+
         ops_name_list.append(op_class_name)
         ops_declare_list.append(op_declare_str)
         ops_defined_list.append(op_defined_str)
@@ -1038,6 +1055,7 @@
         ops_defined_list.append(build_func_with_muta_attr_is_input)
         ops_defined_list.append(op_verify_str)
         ops_defined_list.append(op_infer_meta_str)
+        ops_defined_list.append(op_vjp_str)
 
     # (4) Generate head file str
     op_namespaces_prev = ""
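
[Editor's note] The list-to-dict change above is what makes the commented-out TODO in this file workable: generating a Vjp body needs the parsed OpInfoParser of both the forward op and its backward op, and a dict keyed by op name provides that lookup directly. Below is a minimal sketch, not part of the patch, restating the intended call from the TODO; `vjp_interface_gen_op_list` is assumed to be an allow-list defined elsewhere, since this patch only mentions it inside the comment.

    # Editor's sketch: how the dict-based op_info_items is meant to be used
    # once the TODO above is enabled.
    def maybe_gen_vjp_str(op_class_name, op_name, op_info, op_info_items):
        # `vjp_interface_gen_op_list` is an assumed allow-list of op names
        # that should get a generated Vjp interface.
        if (
            op_info.backward_name
            and op_info.op_phi_name[0] in vjp_interface_gen_op_list
        ):
            return gen_op_vjp_str(
                op_class_name,
                op_info.backward_name,
                op_name,
                op_info_items[op_info.op_phi_name[0]],  # forward op info
                op_info_items[op_info.backward_name],  # backward op info
            )
        return ''
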
diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
index 4833111c9d2ab3cb36254ca802fce4c501abfcb0..ef5f2e1b4ccab00276dbcdd5c7a52a52e1dbde00 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
@@ -22,6 +22,86 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
 }}
 """
 
+OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
+    {input_type} {input_name}(std::make_shared<primitive::experimental::DescTensor>(op_obj.{input_name}()));
+"""
+
+OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
+    Tensor {output_grad_name}(std::make_shared<primitive::experimental::DescTensor>(out_grads[{idx1}][{idx2}]));
+"""
+
+OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
+    std::vector<Tensor> {output_grad_name};
+    for (size_t idx = 0; idx < out_grads[{idx1}].size(); idx++) {{
+        {output_grad_name}.emplace_back(
+            std::make_shared<primitive::experimental::DescTensor>(out_grads[{idx1}][idx]));
+    }}
+"""
+
+OP_VJP_CALL_VJP_TEMPLATE = """
+    std::vector<std::vector<Tensor>> tensor_res =
+        primitive::experimental::{op_phi_name}_vjp({inputs_list}, stop_gradients);
+"""
+
+OP_VJP_STOPGRADIENT_TEMPLATE = """
+    if (!stop_gradients[{idx1}][{idx2}]) {{
+        res[{idx1}][{idx2}] =
+            std::static_pointer_cast<primitive::experimental::DescTensor>(
+                tensor_res[{idx1}][{idx2}].impl())
+                ->getValue()
+                .dyn_cast<ir::OpResult>();
+    }}
+"""
+
+OP_VJP_DEFINE_TEMPLATE = """
+std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(
+    ir::Operation* op,
+    const std::vector<std::vector<ir::OpResult>>& out_grads,
+    const std::vector<std::vector<bool>>& stop_gradients) {{
+  {op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
+
+  VLOG(6) << "Prepare inputs of {op_grad_name}";
+  {forward_input_code}
+  {forward_output_code}
+  {forward_output_grad_code}
+
+  VLOG(6) << "Prepare attributes of {op_grad_name}";
+  {attribute_code}
+
+  VLOG(4) << "Call {op_phi_name}'s vjp interface";
+  {call_vjp_code}
+
+  std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
+  {stop_gradient_input_grad_code}
+
+  return res;
+}}
+"""
+
+
+def gen_op_vjp_str(
+    op_class_name,
+    op_grad_name,
+    op_phi_name,
+    op_info,
+    op_grad_info,
+):
+    # TODO(chenzhiyang): derive these code fragments from op_info/op_grad_info;
+    # they are left empty for now, so only the Vjp skeleton is emitted.
+    forward_input_code = ''
+    forward_output_code = ''
+    forward_output_grad_code = ''
+    attribute_code = ''
+    call_vjp_code = ''
+    stop_gradient_input_grad_code = ''
+
+    vjp_str = OP_VJP_DEFINE_TEMPLATE.format(
+        op_class_name=op_class_name,
+        op_grad_name=op_grad_name,
+        op_phi_name=op_phi_name,
+        forward_input_code=forward_input_code,
+        forward_output_code=forward_output_code,
+        forward_output_grad_code=forward_output_grad_code,
+        attribute_code=attribute_code,
+        call_vjp_code=call_vjp_code,
+        stop_gradient_input_grad_code=stop_gradient_input_grad_code,
+    )
+    return vjp_str
+
 
 def gen_op_infer_meta_str(op_info, op_class_name):
     op_infer_meta_str = ""
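
[Editor's note] Because gen_op_vjp_str currently leaves every fragment empty, its output is just the Vjp skeleton rendered from OP_VJP_DEFINE_TEMPLATE. The snippet below is a sketch for inspecting that skeleton from a repository checkout; the op names are illustrative, not ops wired up by this patch, and the stub ignores its op_info and op_grad_info arguments.

    import sys

    # The op generators are build-time scripts, not an installed package,
    # so import by path from the repository root (assumed working directory).
    sys.path.append("paddle/fluid/ir/dialect/op_generator")
    from op_interface_gen import gen_op_vjp_str

    print(
        gen_op_vjp_str(
            op_class_name="TanhOp",  # illustrative dialect class name
            op_grad_name="tanh_grad",  # illustrative backward op name
            op_phi_name="tanh",  # illustrative phi kernel name
            op_info=None,  # unused by the current stub
            op_grad_info=None,  # unused by the current stub
        )
    )
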