Unverified commit 971945ab, authored by Chen Zhiyang and committed by GitHub

【NewIR】Vjp autogen for multi-input op(concat) (#56657)

* gen-temp-save

* add concat vjp

* remove useless print

* code style

* remove manual concat vjp
Parent d8b2a1ed
@@ -25,6 +25,15 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
{input_type} {input_name}(std::make_shared<primitive::LazyTensor>(op_obj.{input_name}()));"""
+OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """
+ir::CombineOp combine_op_obj =
+op_obj.{input_name}().GetDefiningOp()->dyn_cast<ir::CombineOp>();
+std::vector<Tensor> {input_name};
+for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{
+{input_name}.emplace_back(
+std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
+}}"""
OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
Tensor {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}]));"""
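The new `OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE` takes only an `input_name`: the generated C++ walks the `ir::CombineOp` that feeds the vector input and rebuilds a `std::vector<Tensor>` from its operands. A minimal sketch of the expansion, using `x` (concat's forward input) as an illustrative name:

```python
# Minimal sketch: expanding the new multi-input template for an input named "x".
# The template text is the one defined above; only input_name is substituted.
OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """
    ir::CombineOp combine_op_obj =
        op_obj.{input_name}().GetDefiningOp()->dyn_cast<ir::CombineOp>();
    std::vector<Tensor> {input_name};
    for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{
      {input_name}.emplace_back(
          std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
    }}"""

print(OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format(input_name="x"))
# Emits the same CombineOp loop that the hand-written ConcatOp::Vjp
# (removed at the bottom of this commit) used to gather concat's inputs.
```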
@@ -44,21 +53,21 @@ OP_VJP_CALL_VJP_TEMPLATE = """ std::vector<std::vector<Tensor>> tensor_res =
OP_VJP_STOPGRADIENT_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
-for (size_t i = 0; i < tensor_res.size(); ++i) {{
+for (size_t i = 0; i < tensor_res.size(); ++i) {
res[i].resize(tensor_res[i].size());
-for (size_t j = 0; j < tensor_res[i].size(); ++j) {{
-if(tensor_res[i][j].defined()){{
+for (size_t j = 0; j < tensor_res[i].size(); ++j) {
+if(tensor_res[i][j].defined()){
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
-}}
-}}
-}}"""
+}
+}
+}"""
OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
VLOG(6) << "Prepare inputs of {op_grad_name}";
-{forward_input_code}
+{forward_input_output_code}
{forward_output_grad_code}
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
@@ -87,7 +96,7 @@ def gen_op_vjp_str(
op_grad_info,
):
bw_input_list = op_grad_info.input_name_list
-forward_input_code = ''
+forward_input_output_code = ''
forward_output_grad_code = ''
build_args_str = ''
grad_idx = -1
@@ -97,14 +106,20 @@ def gen_op_vjp_str(
bw_input_list[idx] in op_info.input_name_list
or bw_input_list[idx] in op_info.output_name_list
):
-forward_input_code += (
-OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
-input_type=input_types_map[
-op_grad_info.input_type_list[idx]
-],
-input_name=bw_input_list[idx],
+input_type = input_types_map[op_grad_info.input_type_list[idx]]
+if input_type == 'Tensor':
+forward_input_output_code += (
+OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
+input_type=input_type,
+input_name=bw_input_list[idx],
+)
+)
+else:
+forward_input_output_code += (
+OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format(
+input_name=bw_input_list[idx],
+)
+)
-)
else:
grad_idx += 1
forward_output_grad_code += (
@@ -117,21 +132,31 @@ def gen_op_vjp_str(
for idx in range(len(op_attribute_list)):
build_args_str += op_attribute_list[idx] + ", "
if op_attribute_list[idx] in op_info.attribute_name_list:
-attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
-attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
-attr_name=op_attribute_list[idx],
-attr_parse_type=op_grad_info.attribute_type_list[idx],
-)
+if op_attribute_list[idx] in op_info.mutable_attribute_name_list:
+attribute_code += (
+OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
+input_type="Tensor",
+input_name=op_attribute_list[idx],
+)
+)
+else:
+attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
+attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
+attr_name=op_attribute_list[idx],
+attr_parse_type=op_grad_info.attribute_type_list[idx],
+)
else:
attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx],
default_value=op_grad_info.attribute_default_value_list[idx],
)
+op_phi_name_format = op_phi_name
if op_phi_name[-1] == '_':
-op_phi_name = op_phi_name[:-1]
+op_phi_name_format = op_phi_name[:-1]
call_vjp_code = OP_VJP_CALL_VJP_TEMPLATE.format(
-op_phi_name=op_phi_name,
+op_phi_name=op_phi_name_format,
inputs_list=build_args_str,
)
stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE
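This hunk also routes mutable attributes (e.g. concat's `axis`, which arrives as a runtime value) through the `Tensor` input template instead of the attribute parser, and fixes the inplace-name handling: the trailing `_` is now stripped into a separate `op_phi_name_format` used for the vjp call, so `op_phi_name` itself keeps its original spelling for the later `OP_VJP_DEFINE_TEMPLATE` arguments. A tiny worked example with a hypothetical inplace op name:

```python
# Hypothetical inplace op name, only to show the new name handling.
op_phi_name = "scale_"
op_phi_name_format = op_phi_name
if op_phi_name[-1] == '_':
    op_phi_name_format = op_phi_name[:-1]
# The vjp call uses the stripped name; the original name is preserved.
assert (op_phi_name, op_phi_name_format) == ("scale_", "scale")
```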
@@ -141,7 +166,7 @@ def gen_op_vjp_str(
op_grad_name=op_grad_name,
op_phi_name=op_phi_name,
res_size=len(op_info.input_name_list),
-forward_input_code=forward_input_code,
+forward_input_output_code=forward_input_output_code,
forward_output_grad_code=forward_output_grad_code,
attribute_code=attribute_code,
call_vjp_code=call_vjp_code,
......
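Taken together, `gen_op_vjp_str` now resolves each backward input to its C++ argument type and picks a template accordingly: plain `Tensor` inputs keep the single-input template, anything else (a vector of tensors) is rebuilt from its `CombineOp`. A condensed sketch of that dispatch, with a hypothetical `input_types_map` (the real mapping lives elsewhere in the generator):

```python
# Hypothetical mapping from IR type strings to C++ argument types; the real
# input_types_map is defined elsewhere in the op generator.
input_types_map = {
    'paddle::dialect::DenseTensorType': 'Tensor',
    'ir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<Tensor>',
}

def emit_forward_input(raw_type, name):
    # OP_VJP_FORWARD_* are the module-level templates shown above.
    input_type = input_types_map[raw_type]
    if input_type == 'Tensor':
        return OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
            input_type=input_type, input_name=name)
    # Vector-of-Tensor input (e.g. concat's x): rebuild it from the CombineOp.
    return OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format(input_name=name)
```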
@@ -30,4 +30,10 @@ vjp_interface_declare_gen_op_list = [
"add",
"concat",
]
vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"]
vjp_interface_implementation_gen_op_list = [
"tanh",
"mean",
"divide",
"add",
"concat",
]
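With `"concat"` added to `vjp_interface_implementation_gen_op_list` (it was already in the declare list above), the generator now emits `ConcatOp::Vjp` automatically, which is why the hand-written version below can be deleted. A hedged sketch of how the two lists gate generation (the helper names are illustrative, not the generator's real API):

```python
# Illustrative only: membership in the two lists decides what gets generated.
def should_declare_vjp(op_name):
    return op_name in vjp_interface_declare_gen_op_list

def should_autogen_vjp_body(op_name):
    return op_name in vjp_interface_implementation_gen_op_list
```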
@@ -27,40 +27,6 @@ namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;
-std::vector<std::vector<ir::OpResult>> ConcatOp::Vjp(
-ir::Operation* op,
-const std::vector<std::vector<ir::OpResult>>& out_grads,
-const std::vector<std::vector<bool>>& stop_gradients) {
-ConcatOp op_obj = op->dyn_cast<ConcatOp>();
-ir::CombineOp combine_op_obj =
-op_obj.x().GetDefiningOp()->dyn_cast<ir::CombineOp>();
-std::vector<Tensor> x;
-for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {
-x.emplace_back(
-std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
-}
-Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));
-Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
-std::vector<std::vector<Tensor>> tensor_res =
-primitive::concat_vjp(x, out_grad, axis, stop_gradients);
-std::vector<std::vector<ir::OpResult>> res(tensor_res.size(),
-std::vector<ir::OpResult>());
-for (uint64_t i = 0; i < tensor_res.size(); i++) {
-res[i].resize(tensor_res[i].size());
-for (uint64_t j = 0; j < tensor_res[i].size(); j++) {
-if (tensor_res[i][j].defined()) {
-res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
-tensor_res[i][j].impl())
-->getValue()
-.dyn_cast<ir::OpResult>();
-}
-}
-}
-return res;
-}
std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
......