未验证 提交 488071af 编写于 作者: C Chen Zhiyang 提交者: GitHub

Add vjp autogen v1.0 (#56369)

* add vjp autogen v1.0

* resolve type conflict
上级 3779412c
......@@ -20,6 +20,7 @@ from op_build_gen import gen_build_func_str
from op_interface_gen import (
gen_exclusive_interface_str,
gen_op_infer_meta_str,
gen_op_vjp_str,
vjp_interface_gen_op_list,
)
from op_member_func_gen import gen_op_get_inputs_outputs_str
......@@ -286,6 +287,9 @@ class OpInfoParser:
self.attribute_build_arg_type_list = (
self.parse_attribute_build_arg_type_list()
)
self.attribute_gen_arg_type_list = (
self.parse_attribute_gen_arg_type_list()
)
self.attribute_data_type_list = self.parse_attribute_data_type_list()
self.attribute_default_value_list = (
self.parse_attribute_default_value_list()
......@@ -584,6 +588,17 @@ class OpInfoParser:
type_list.append(self.get_phi_dtype_name(temp_type))
return type_list
def parse_attribute_gen_arg_type_list(self):
type_list = []
for attribute_info in self.op_yaml_item['attrs']:
assert (
attribute_info['typename'] in self.attr_types_map
), f"{self.op_phi_name} : Attr type error."
temp_type = self.attr_types_map[attribute_info['typename']][1]
type_list.append(self.get_phi_dtype_name(temp_type))
return type_list
def parse_attribute_type_list(self):
type_list = []
for attribute_info in self.op_yaml_item['attrs']:
......@@ -1038,12 +1053,17 @@ def OpGenerator(
op_vjp_str = ''
# TODO(chenzhiyang) add vjp gen code
# if op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_gen_op_list:
# op_vjp_str = gen_op_vjp_str(op_class_name,
# op_info.backward_name,
# op_name,
# op_info_items[op_info.op_phi_name[0]],
# op_info_items[op_info.backward_name])
if (
op_info.backward_name
and op_info.op_phi_name[0] in vjp_interface_gen_op_list
):
op_vjp_str = gen_op_vjp_str(
op_class_name,
op_info.backward_name,
op_name,
op_info_items[op_info.op_phi_name[0]],
op_info_items[op_info.backward_name],
)
ops_name_list.append(op_class_name)
ops_declare_list.append(op_declare_str)
......
......@@ -23,57 +23,61 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
"""
OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
{input_type} {input_name}(std::make_shared<primitive::experimental::StaticTensor>(op_obj.{input_name}()));
"""
{input_type} {input_name}(std::make_shared<primitive::experimental::StaticTensor>(op_obj.{input_name}()));"""
OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
Tensor {output_grad_name}(std::make_shared<primitive::experimental::StaticTensor>((out_grads[{idx1}][{idx2}]);
"""
Tensor {output_grad_name}(std::make_shared<primitive::experimental::StaticTensor>(out_grads[{idx1}][{idx2}]));"""
OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::experimental::StaticTensor>((out_grads[{idx1}]);
"""
std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::experimental::StaticTensor>(out_grads[{idx1}]));"""
OP_VJP_CALL_VJP_TEMPLATE = """
Tensor std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::{op_phi_name}_vjp({inputs_list}, stop_gradients);
"""
OP_VJP_ATTRIBUTE_TEMPLATE = """
{attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();"""
OP_VJP_STOPGRADIENT_TEMPLATE = """
if(!stop_gradients[{idx1}][{idx2}]){{
res[{idx1}][{idx2}] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
tensor_res[idx1][idx2].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}}
"""
OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """
{attr_type} {attr_name} = {default_value};"""
OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
VLOG(6) << "Prepare inputs of {op_grad_name}";
OP_VJP_CALL_VJP_TEMPLATE = """ std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::{op_phi_name}_vjp(
{inputs_list}stop_gradients);"""
{forward_input_code}
{forward_output_code}
{forward_output_grad_code}
OP_VJP_STOPGRADIENT_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> res(tensor_res.size());
for (size_t i = 0; i < tensor_res.size(); ++i) {{
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {{
if(tensor_res[i][j].defined()){{
res[i][j] = std::static_pointer_cast<primitive::experimental::StaticTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
}}
}}
}}"""
OP_VJP_DEFINE_TEMPLATE = """
std::vector<std::vector<ir::OpResult>> {op_class_name}::Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients){{
{op_class_name} op_obj = op->dyn_cast<{op_class_name}>();
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
{attribute_code}
VLOG(6) << "Prepare inputs of {op_grad_name}";
{forward_input_code}
{forward_output_grad_code}
VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
{call_vjp_code}
VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}";
{attribute_code}
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
{stop_gradient_input_grad_code}
VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface";
{call_vjp_code}
return res;
VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}";
{stop_gradient_input_grad_code}
return res;
}}
"""
input_types_map = {
'paddle::dialect::DenseTensorType': 'Tensor',
'ir::VectorType<paddle::dialect::DenseTensorType>': 'Tensor[]',
}
def gen_op_vjp_str(
op_class_name,
......@@ -82,19 +86,62 @@ def gen_op_vjp_str(
op_info,
op_grad_info,
):
bw_input_list = op_grad_info.input_name_list
forward_input_code = ''
forward_output_code = ''
forward_output_grad_code = ''
build_args_str = ''
grad_idx = -1
for idx in range(len(bw_input_list)):
build_args_str += bw_input_list[idx] + ", "
if (
bw_input_list[idx] in op_info.input_name_list
or bw_input_list[idx] in op_info.output_name_list
):
forward_input_code += (
OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
input_type=input_types_map[
op_grad_info.input_type_list[idx]
],
input_name=bw_input_list[idx],
)
)
else:
grad_idx += 1
forward_output_grad_code += (
OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format(
output_grad_name=bw_input_list[idx], idx1=grad_idx, idx2=0
)
)
op_attribute_list = op_grad_info.attribute_name_list
attribute_code = ''
call_vjp_code = ''
stop_gradient_input_grad_code = ''
for idx in range(len(op_attribute_list)):
build_args_str += op_attribute_list[idx] + ", "
if op_attribute_list[idx] in op_info.attribute_name_list:
attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx],
attr_parse_type=op_grad_info.attribute_type_list[idx],
)
else:
attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format(
attr_type=op_grad_info.attribute_gen_arg_type_list[idx],
attr_name=op_attribute_list[idx],
default_value=op_grad_info.attribute_default_value_list[idx],
)
if op_phi_name[-1] == '_':
op_phi_name = op_phi_name[:-1]
call_vjp_code = OP_VJP_CALL_VJP_TEMPLATE.format(
op_phi_name=op_phi_name,
inputs_list=build_args_str,
)
stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE
str = OP_VJP_DEFINE_TEMPLATE.format(
op_class_name=op_class_name,
op_grad_name=op_grad_name,
op_phi_name=op_phi_name,
res_size=len(op_info.input_name_list),
forward_input_code=forward_input_code,
forward_output_code=forward_output_code,
forward_output_grad_code=forward_output_grad_code,
attribute_code=attribute_code,
call_vjp_code=call_vjp_code,
......
......@@ -23,132 +23,5 @@
// this file will be generated in pd_op.cc
namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;
std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
TanhOp op_obj = op->dyn_cast<TanhOp>();
Tensor out(
std::make_shared<primitive::experimental::StaticTensor>(op_obj.out()));
Tensor grad_out(
std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (tensor_res[0][0].defined()) {
res[0][0] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
tensor_res[0][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
return res;
}
std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
// TODO(wanghao107)
// we don't support inplace now,
// so use the non-inplace version instead currently.
// Support inplace in the future.
Tanh_Op op_obj = op->dyn_cast<Tanh_Op>();
Tensor out(
std::make_shared<primitive::experimental::StaticTensor>(op_obj.out()));
Tensor grad_out(
std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (tensor_res[0][0].defined()) {
res[0][0] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
tensor_res[0][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
return res;
}
std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
MeanOp op_obj = op->dyn_cast<MeanOp>();
Tensor x(std::make_shared<primitive::experimental::StaticTensor>(op_obj.x()));
Tensor out_grad(
std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
IntArray axis = op->attribute("axis")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data();
bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data();
bool reduce_all = false;
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::mean_vjp(
x, out_grad, axis, keepdim, reduce_all, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (tensor_res[0][0].defined()) {
res[0][0] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
tensor_res[0][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
return res;
}
std::vector<std::vector<ir::OpResult>> AddOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
AddOp op_obj = op->dyn_cast<AddOp>();
Tensor x(std::make_shared<primitive::experimental::StaticTensor>(op_obj.x()));
Tensor y(std::make_shared<primitive::experimental::StaticTensor>(op_obj.y()));
Tensor out_grad(
std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
int axis = -1;
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
for (size_t i = 0; i < 2; ++i) {
if (tensor_res[i][0].defined()) {
res[i][0] =
std::static_pointer_cast<primitive::experimental::StaticTensor>(
tensor_res[i][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
return res;
}
std::vector<std::vector<ir::OpResult>> Add_Op::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
Add_Op op_obj = op->dyn_cast<Add_Op>();
Tensor x(std::make_shared<primitive::experimental::StaticTensor>(op_obj.x()));
Tensor y(std::make_shared<primitive::experimental::StaticTensor>(op_obj.y()));
Tensor out_grad(
std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
int axis = -1;
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
for (size_t i = 0; i < 2; ++i) {
if (tensor_res[i][0].defined()) {
res[i][0] =
std::static_pointer_cast<primitive::experimental::StaticTensor>(
tensor_res[i][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
return res;
}
} // namespace dialect
namespace dialect {} // namespace dialect
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册