未验证 提交 e6b26393 编写于 作者: X xiaoguoguo626807 提交者: GitHub

[NewIR] Codegen template of Op vjp interface (#56196)

* op_vjp generate template

* op_vjp_gen template

* delete print
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
上级 9d40da31
......@@ -110,6 +110,9 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/desc_tensor.h"
#include "paddle/ir/core/op_base.h"
{input}
......@@ -679,17 +682,16 @@ def OpGenerator(
with open(yaml_file, "r") as f:
ops = yaml.safe_load(f)
op_yaml_items = op_yaml_items + ops
op_info_items = []
op_info_items = {}
for op in op_yaml_items:
op_info_items.append(
OpInfoParser(op, op_compat_parser.get_compat(op['name']))
op_info_items[op['name']] = OpInfoParser(
op, op_compat_parser.get_compat(op['name'])
)
# (3) CodeGen: Traverse op_info_items and generate
ops_name_list = [] # all op class name store in this list
ops_declare_list = [] # all op class declare store in this list
ops_defined_list = [] # all op class defined store in this list
for op_info in op_info_items:
for key, op_info in op_info_items.items():
# get op inputs info
op_input_name_list = op_info.input_name_list
op_input_type_list = op_info.input_type_list
......@@ -1028,6 +1030,21 @@ def OpGenerator(
op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name)
# =================================== #
# gen Vjp func str #
# =================================== #
# generate op vjp function str
op_vjp_str = ''
# TODO(chenzhiyang) add vjp gen code
# if op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_gen_op_list:
# op_vjp_str = gen_op_vjp_str(op_class_name,
# op_info.backward_name,
# op_name,
# op_info_items[op_info.op_phi_name[0]],
# op_info_items[op_info.backward_name])
ops_name_list.append(op_class_name)
ops_declare_list.append(op_declare_str)
ops_defined_list.append(op_defined_str)
......@@ -1038,6 +1055,7 @@ def OpGenerator(
ops_defined_list.append(build_func_with_muta_attr_is_input)
ops_defined_list.append(op_verify_str)
ops_defined_list.append(op_infer_meta_str)
ops_defined_list.append(op_vjp_str)
# (4) Generate head file str
op_namespaces_prev = ""
......
......@@ -22,6 +22,86 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
}}
"""
OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
{input_type} {input_name}(std::make_shared<primitive::experimental::DescTensor>(op_obj.{input_name}()));
"""
# C++ snippet: wraps a single incoming output-gradient ir::OpResult
# (out_grads[{idx1}][{idx2}]) as a primitive DescTensor.
# Fix: the original emitted `>((out_grads[...]);` — three opening but only one
# closing parenthesis — so the generated C++ could not compile.
OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
  Tensor {output_grad_name}(std::make_shared<primitive::experimental::DescTensor>(out_grads[{idx1}][{idx2}]));
"""
# C++ snippet: wraps a whole vector of incoming output gradients
# (out_grads[{idx1}]) for a vector-valued output.
# Fix: the original emitted unbalanced parentheses (`>((out_grads[{idx1}]);`),
# so the generated C++ could not compile.
# NOTE(review): constructing std::vector<Tensor> from a single shared_ptr
# looks suspect — presumably each element should be wrapped individually;
# confirm against the vjp interface before enabling this path.
OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
  std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::experimental::DescTensor>(out_grads[{idx1}]));
"""
# C++ snippet: invokes the primitive vjp rule for the op and captures the
# per-input gradient tensors.
# Fix: the original began the declaration with a stray `Tensor` token
# (`Tensor std::vector<std::vector<Tensor>> tensor_res = ...`), which is not
# valid C++ — two types on one declaration.
OP_VJP_CALL_VJP_TEMPLATE = """
  std::vector<std::vector<Tensor>> tensor_res =
      primitive::experimental::{op_phi_name}_vjp({inputs_list}, stop_gradients);
"""
# C++ snippet: for each (output, element) slot not marked stop-gradient,
# unwraps the DescTensor produced by the vjp call back into an ir::OpResult.
# Fix: the original indexed `tensor_res[idx1][idx2]` with literal identifiers
# instead of the `{idx1}`/`{idx2}` format placeholders, so the generated C++
# referenced undeclared variables `idx1`/`idx2`.
OP_VJP_STOPGRADIENT_TEMPLATE = """
  if(!stop_gradients[{idx1}][{idx2}]){{
    res[{idx1}][{idx2}] = std::static_pointer_cast<primitive::experimental::DescTensor>(
        tensor_res[{idx1}][{idx2}].impl())
                              ->getValue()
                              .dyn_cast<ir::OpResult>();
  }}
"""
# C++ skeleton of an op's static Vjp method. The {..._code} placeholders are
# filled with snippets built from the other OP_VJP_* templates.
# Fixes to the generated log strings: duplicated phrase
# "Vjp prepare Prepare attributes" and misspelling "inteface".
OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(
    ir::Operation* op,
    const std::vector<std::vector<ir::OpResult>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients){{
  {op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
  VLOG(6) << "Prepare inputs of {op_grad_name}";
  {forward_input_code}
  {forward_output_code}
  {forward_output_grad_code}
  VLOG(6) << "Prepare attributes of {op_grad_name}";
  {attribute_code}
  VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp interface";
  {call_vjp_code}
  std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
  {stop_gradient_input_grad_code}
  return res;
}}
"""


def gen_op_vjp_str(
    op_class_name,
    op_grad_name,
    op_phi_name,
    op_info,
    op_grad_info,
):
    """Return the C++ definition of ``{op_class_name}::Vjp`` as a string.

    Args:
        op_class_name: C++ class name of the forward op (e.g. ``TanhOp``).
        op_grad_name: name of the backward op, used in log messages.
        op_phi_name: phi kernel name; selects the ``<name>_vjp`` rule.
        op_info: parsed OpInfoParser of the forward op (unused until the
            code sections below are generated).
        op_grad_info: parsed OpInfoParser of the backward op (unused, see
            above).
    """
    # TODO(chenzhiyang): derive these sections from op_info/op_grad_info
    # using the OP_VJP_* snippet templates; they are empty placeholders for
    # now, so the emitted Vjp body is a bare skeleton.
    forward_input_code = ''
    forward_output_code = ''
    forward_output_grad_code = ''
    attribute_code = ''
    call_vjp_code = ''
    stop_gradient_input_grad_code = ''
    # NOTE: renamed from `str`, which shadowed the builtin.
    vjp_str = OP_VJP_DEFINE_TEMPLATE.format(
        op_class_name=op_class_name,
        op_grad_name=op_grad_name,
        op_phi_name=op_phi_name,
        forward_input_code=forward_input_code,
        forward_output_code=forward_output_code,
        forward_output_grad_code=forward_output_grad_code,
        attribute_code=attribute_code,
        call_vjp_code=call_vjp_code,
        stop_gradient_input_grad_code=stop_gradient_input_grad_code,
    )
    return vjp_str
def gen_op_infer_meta_str(op_info, op_class_name):
op_infer_meta_str = ""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册