未验证 提交 971945ab 编写于 作者: C Chen Zhiyang 提交者: GitHub

【NewIR】Vjp autogen for multi-input op(concat) (#56657)

* gen-temp-save

* add concat vjp

* remove useless print

* code style

* remove manual concat vjp
上级 d8b2a1ed
...@@ -25,6 +25,15 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ ...@@ -25,6 +25,15 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """ OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
{input_type} {input_name}(std::make_shared<primitive::LazyTensor>(op_obj.{input_name}()));""" {input_type} {input_name}(std::make_shared<primitive::LazyTensor>(op_obj.{input_name}()));"""
OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """
ir::CombineOp combine_op_obj =
op_obj.{input_name}().GetDefiningOp()->dyn_cast<ir::CombineOp>();
std::vector<Tensor> {input_name};
for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{
{input_name}.emplace_back(
std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
}}"""
OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """ OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
Tensor {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}]));""" Tensor {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}]));"""
...@@ -44,21 +53,21 @@ OP_VJP_CALL_VJP_TEMPLATE = """ std::vector<std::vector<Tensor>> tensor_res = ...@@ -44,21 +53,21 @@ OP_VJP_CALL_VJP_TEMPLATE = """ std::vector<std::vector<Tensor>> tensor_res =
OP_VJP_STOPGRADIENT_TEMPLATE = """ OP_VJP_STOPGRADIENT_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> res(tensor_res.size()); std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
for (size_t i = 0; i < tensor_res.size(); ++i) {{ for (size_t i = 0; i < tensor_res.size(); ++i) {
res[i].resize(tensor_res[i].size()); res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {{ for (size_t j = 0; j < tensor_res[i].size(); ++j) {
if(tensor_res[i][j].defined()){{ if(tensor_res[i][j].defined()){
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>(); res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
}} }
}} }
}}""" }"""
OP_VJP_DEFINE_TEMPLATE = """ OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{ std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); {op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
VLOG(6) << "Prepare inputs of {op_grad_name}"; VLOG(6) << "Prepare inputs of {op_grad_name}";
{forward_input_code} {forward_input_output_code}
{forward_output_grad_code} {forward_output_grad_code}
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
...@@ -87,7 +96,7 @@ def gen_op_vjp_str( ...@@ -87,7 +96,7 @@ def gen_op_vjp_str(
op_grad_info, op_grad_info,
): ):
bw_input_list = op_grad_info.input_name_list bw_input_list = op_grad_info.input_name_list
forward_input_code = '' forward_input_output_code = ''
forward_output_grad_code = '' forward_output_grad_code = ''
build_args_str = '' build_args_str = ''
grad_idx = -1 grad_idx = -1
...@@ -97,14 +106,20 @@ def gen_op_vjp_str( ...@@ -97,14 +106,20 @@ def gen_op_vjp_str(
bw_input_list[idx] in op_info.input_name_list bw_input_list[idx] in op_info.input_name_list
or bw_input_list[idx] in op_info.output_name_list or bw_input_list[idx] in op_info.output_name_list
): ):
forward_input_code += ( input_type = input_types_map[op_grad_info.input_type_list[idx]]
OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( if input_type == 'Tensor':
input_type=input_types_map[ forward_input_output_code += (
op_grad_info.input_type_list[idx] OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
], input_type=input_type,
input_name=bw_input_list[idx], input_name=bw_input_list[idx],
)
)
else:
forward_input_output_code += (
OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format(
input_name=bw_input_list[idx],
)
) )
)
else: else:
grad_idx += 1 grad_idx += 1
forward_output_grad_code += ( forward_output_grad_code += (
...@@ -117,21 +132,31 @@ def gen_op_vjp_str( ...@@ -117,21 +132,31 @@ def gen_op_vjp_str(
for idx in range(len(op_attribute_list)): for idx in range(len(op_attribute_list)):
build_args_str += op_attribute_list[idx] + ", " build_args_str += op_attribute_list[idx] + ", "
if op_attribute_list[idx] in op_info.attribute_name_list: if op_attribute_list[idx] in op_info.attribute_name_list:
attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format( if op_attribute_list[idx] in op_info.mutable_attribute_name_list:
attr_type=op_grad_info.attribute_gen_arg_type_list[idx], attribute_code += (
attr_name=op_attribute_list[idx], OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
attr_parse_type=op_grad_info.attribute_type_list[idx], input_type="Tensor",
) input_name=op_attribute_list[idx],
)
)
else:
attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx],
attr_parse_type=op_grad_info.attribute_type_list[idx],
)
else: else:
attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format( attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx], attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx], attr_name=op_attribute_list[idx],
default_value=op_grad_info.attribute_default_value_list[idx], default_value=op_grad_info.attribute_default_value_list[idx],
) )
op_phi_name_format = op_phi_name
if op_phi_name[-1] == '_': if op_phi_name[-1] == '_':
op_phi_name = op_phi_name[:-1] op_phi_name_format = op_phi_name[:-1]
call_vjp_code = OP_VJP_CALL_VJP_TEMPLATE.format( call_vjp_code = OP_VJP_CALL_VJP_TEMPLATE.format(
op_phi_name=op_phi_name, op_phi_name=op_phi_name_format,
inputs_list=build_args_str, inputs_list=build_args_str,
) )
stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE
...@@ -141,7 +166,7 @@ def gen_op_vjp_str( ...@@ -141,7 +166,7 @@ def gen_op_vjp_str(
op_grad_name=op_grad_name, op_grad_name=op_grad_name,
op_phi_name=op_phi_name, op_phi_name=op_phi_name,
res_size=len(op_info.input_name_list), res_size=len(op_info.input_name_list),
forward_input_code=forward_input_code, forward_input_output_code=forward_input_output_code,
forward_output_grad_code=forward_output_grad_code, forward_output_grad_code=forward_output_grad_code,
attribute_code=attribute_code, attribute_code=attribute_code,
call_vjp_code=call_vjp_code, call_vjp_code=call_vjp_code,
......
...@@ -30,4 +30,10 @@ vjp_interface_declare_gen_op_list = [ ...@@ -30,4 +30,10 @@ vjp_interface_declare_gen_op_list = [
"add", "add",
"concat", "concat",
] ]
vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"] vjp_interface_implementation_gen_op_list = [
"tanh",
"mean",
"divide",
"add",
"concat",
]
...@@ -27,40 +27,6 @@ namespace paddle { ...@@ -27,40 +27,6 @@ namespace paddle {
namespace dialect { namespace dialect {
using IntArray = paddle::experimental::IntArray; using IntArray = paddle::experimental::IntArray;
std::vector<std::vector<ir::OpResult>> ConcatOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
ConcatOp op_obj = op->dyn_cast<ConcatOp>();
ir::CombineOp combine_op_obj =
op_obj.x().GetDefiningOp()->dyn_cast<ir::CombineOp>();
std::vector<Tensor> x;
for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {
x.emplace_back(
std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
}
Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));
Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
std::vector<std::vector<Tensor>> tensor_res =
primitive::concat_vjp(x, out_grad, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(tensor_res.size(),
std::vector<ir::OpResult>());
for (uint64_t i = 0; i < tensor_res.size(); i++) {
res[i].resize(tensor_res[i].size());
for (uint64_t j = 0; j < tensor_res[i].size(); j++) {
if (tensor_res[i][j].defined()) {
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[i][j].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
}
return res;
}
std::vector<std::vector<ir::OpResult>> SumOp::Vjp( std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
ir::Operation* op, ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<ir::OpResult>>& out_grads,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册