Unverified commit 128f95a1, authored by Chen Zhiyang, committed by GitHub

Vjp autogen for grad list op(split) (#56720)

* add vjp code gen for SplitOp

* change vjp manual file name
Parent b0b827c7
@@ -38,7 +38,11 @@ OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
Tensor {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}]));"""
OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
-  std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}]));"""
+  std::vector<Tensor> {output_grad_name};
+  for (size_t idx = 0; idx < out_grads[{index}].size(); idx++) {{
+    {output_grad_name}.emplace_back(
+        std::make_shared<primitive::LazyTensor>(out_grads[{index}][idx]));
+  }}"""
OP_VJP_ATTRIBUTE_TEMPLATE = """
{attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();"""
@@ -47,9 +51,10 @@ OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """
{attr_type} {attr_name} = {default_value};"""
-OP_VJP_CALL_VJP_TEMPLATE = """  std::vector<std::vector<Tensor>> tensor_res =
-      primitive::{op_phi_name}_vjp(
-      {inputs_list}stop_gradients);"""
+OP_VJP_CALL_VJP_TEMPLATE = """
+  std::vector<std::vector<Tensor>> tensor_res =
+      primitive::{op_phi_name}_vjp(
+      {inputs_list}stop_gradients);"""
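And a hedged sketch of what this call template is expected to produce once op_phi_name and inputs_list are filled in for split (argument names follow the manual implementation removed at the end of this diff and are illustrative):

  // Hedged sketch (illustrative names): the generated call into the primitive
  // vjp for split, using the grad list and axis prepared earlier in the body.
  std::vector<std::vector<Tensor>> tensor_res =
      primitive::split_vjp(out_grad_, axis, stop_gradients);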
OP_VJP_STOPGRADIENT_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
@@ -73,10 +78,10 @@ std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, c
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
{attribute_code}
VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
VLOG(6) << "Vjp prepare call {op_phi_name}'s vjp inteface";
{call_vjp_code}
VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}";
VLOG(6) << "Vjp prepare stop gradient of {op_grad_name}";
{stop_gradient_input_grad_code}
return res;
}}
@@ -122,11 +127,21 @@ def gen_op_vjp_str(
)
else:
grad_idx += 1
-            forward_output_grad_code += (
-                OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format(
-                    output_grad_name=bw_input_list[idx], idx1=grad_idx, idx2=0
+            input_type = input_types_map[op_grad_info.input_type_list[idx]]
+            if input_type == 'Tensor':
+                forward_output_grad_code += (
+                    OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format(
+                        output_grad_name=bw_input_list[idx],
+                        idx1=grad_idx,
+                        idx2=0,
+                    )
+                )
+            else:
+                forward_input_output_code += (
+                    OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE.format(
+                        output_grad_name=bw_input_list[idx], index=grad_idx
+                    )
                )
-            )
op_attribute_list = op_grad_info.attribute_name_list
attribute_code = ''
for idx in range(len(op_attribute_list)):
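For comparison, when input_type is 'Tensor' the scalar template (OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE) is used instead of the list template; a hedged sketch of its expansion, with an illustrative grad name and indices:

  // Hedged sketch (illustrative names): a single-Tensor grad input is wrapped
  // directly, without the loop used for the vector<Tensor> case.
  Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));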
@@ -37,4 +37,5 @@ vjp_interface_implementation_gen_op_list = [
"divide",
"add",
"concat",
"split",
]
@@ -134,6 +134,6 @@ target_include_directories(pd_dialect_api PRIVATE ${PD_DIALECT_BINARY_DIR})
cc_library(
pd_dialect
-  SRCS pd_dialect.cc pd_op_vjp_manual.cc ${op_vjp_source_file}
+  SRCS pd_dialect.cc pd_manual_op_vjp.cc ${op_vjp_source_file}
DEPS pd_dialect_api param_to_variable primitive_vjp_experimental
pd_dialect_utils op_yaml_info_parser)
@@ -54,38 +54,5 @@ std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
return res;
}
-std::vector<std::vector<ir::OpResult>> SplitOp::Vjp(
-    ir::Operation* op,
-    const std::vector<std::vector<ir::OpResult>>& out_grads,
-    const std::vector<std::vector<bool>>& stop_gradients) {
-  SplitOp op_obj = op->dyn_cast<SplitOp>();
-  Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
-  std::vector<Tensor> out_grads_;
-  for (size_t idx = 0; idx < out_grads[0].size(); idx++) {
-    out_grads_.emplace_back(
-        std::make_shared<primitive::LazyTensor>(out_grads[0][idx]));
-  }
-  std::vector<std::vector<Tensor>> tensor_res =
-      primitive::split_vjp(out_grads_, axis, stop_gradients);
-  std::vector<std::vector<ir::OpResult>> res(tensor_res.size(),
-                                             std::vector<ir::OpResult>());
-  for (uint64_t i = 0; i < tensor_res.size(); i++) {
-    res[i].resize(tensor_res[i].size());
-    for (uint64_t j = 0; j < tensor_res[i].size(); j++) {
-      if (tensor_res[i][j].defined()) {
-        res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
-                        tensor_res[i][j].impl())
-                        ->getValue()
-                        .dyn_cast<ir::OpResult>();
-      }
-    }
-  }
-  return res;
-}
} // namespace dialect
} // namespace paddle