Unverified Commit e7bda1dd authored by Zhanlue Yang, committed by GitHub

Added Eager Dygraph AutoCodeGen dependencies #2 (#37575)

Parent d2934a70
@@ -40,11 +40,10 @@ static std::unordered_set<std::string> operators_to_skip = {
"fused_attention",
"diag_v2",
};
/*
static std::unordered_set<std::string> operators_to_codegen = {
"sigmoid", "matmul_v2", "reduce_sum", "elementwise_add",
"share_buffer", "var_conv_2d", "split"};
*/
static std::unordered_set<std::string> skipped_operators = {};
@@ -107,8 +106,10 @@ static std::string AttrTypeToString(const proto::AttrType& type) {
break;
}
default: {
PADDLE_THROW(
platform::errors::Fatal("Unable to recognize AttrType: %d", type));
PADDLE_THROW(platform::errors::Fatal(
"AttrType of type boost::variant only supports specific data types."
"However, detected unrecognized AttrType: %d",
type));
}
}
return ret;
@@ -214,8 +215,10 @@ static std::pair<std::string, std::string> GetAttrType(
break;
}
default: {
PADDLE_THROW(platform::errors::Fatal("Unable to recognize AttrType: %d",
variant_pos));
PADDLE_THROW(platform::errors::Fatal(
"AttrType of type boost::variant only supports specific data types."
"However, detected unrecognized AttrType: %d",
variant_pos));
}
}
return {ret, val};
@@ -259,6 +262,7 @@ static void SlotNameMatching(
if (grad_fwd_slotname_map.count(grad_slot_name) &&
grad_fwd_slotname_map[grad_slot_name] != fwd_slot_name) {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"grad_slot_name %s matches both %s and %s fwd_slot_name",
grad_slot_name, grad_fwd_slotname_map[grad_slot_name],
fwd_slot_name));
@@ -271,6 +275,7 @@ static void SlotNameMatching(
if (grad_grad_slotname_map.count(grad_slot_name) &&
grad_grad_slotname_map[grad_slot_name] != fwd_slot_name) {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"grad_slot_name %s matches both %s and %s fwd_slot_name",
grad_slot_name, grad_grad_slotname_map[grad_slot_name],
fwd_slot_name));
@@ -290,6 +295,7 @@ static void SlotNameMatching(
if (grad_fwd_slotname_map.count(grad_slot_name) &&
grad_fwd_slotname_map[grad_slot_name] != fwd_slot_name) {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names"
"grad_slot_name %s matches both %s and %s fwd_slot_name",
grad_slot_name, grad_fwd_slotname_map[grad_slot_name],
fwd_slot_name));
@@ -302,6 +308,7 @@ static void SlotNameMatching(
if (grad_grad_slotname_map.count(grad_slot_name) &&
grad_grad_slotname_map[grad_slot_name] != fwd_slot_name) {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"grad_slot_name %s matches both %s and %s fwd_slot_name",
grad_slot_name, grad_grad_slotname_map[grad_slot_name],
fwd_slot_name));
@@ -315,6 +322,7 @@ static void SlotNameMatching(
if (!found_matching) {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Found no matching fwd_slot_name for grad_slot_name: %s",
grad_slot_name));
@@ -344,7 +352,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// Only handle matmul_v2 for now
VLOG(1) << "------ Analyzing Op ------: " << op_type;
// if (!operators_to_codegen.count(op_type)) return false;
if (!operators_to_codegen.count(op_type)) return false;
if (operators_to_skip.count(op_type)) return false;
return true;
@@ -741,5 +749,794 @@ static std::string AppendUseOp(const std::string& op_type) {
return return_str;
}
/* -------------------------------- */
/* --------- CodeGen: Forward ----- */
/* -------------------------------- */
static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const std::vector<paddle::framework::AttributeMap>&
grad_node_default_attr_maps,
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::map<std::string, std::string>& grad_ins_grad_slotname_map,
const std::map<std::string, std::string>& grad_outs_slotname_map,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_ins,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const proto::OpProto& op_proto) {
/*
// Forward Function Example:
std::tuple<vector<Tensor>, Tensor, vector<Tensor>>
kernel_function(vector<Tensor>& X, Tensor& Y,
                const paddle::AttributeMap& attr_map,
                size_t Out0Num, size_t Out1Num) {
  // Forward Function Body
  // According to fwd_inputs_name_pos_map
  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> ins =
      { {"X", SyncToVars(X)}, {"Y", SyncToVars(Y)} };

  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> outs =
      { {"Out0", ConstructDuplicableOutput(Out0Num)},
        {"Out1", ConstructDuplicableOutput(Out1Num)} };

  // According to op_proto->attrs()
  egr::RunOp("op_type", ins, outs, attr_map,
             Controller.Instance().GetExpectedPlace(), {});

  // According to fwd_outputs_names
  std::vector<egr::EagerTensor> Out0 = GetOutputs(outs["Out0"]);
  egr::EagerTensor Out1 = GetOutput(outs["Out1"][0]);
  std::vector<egr::EagerTensor> Out2 = GetOutputs(outs["Out2"]);

  // Grad Node Generation Codes
  ...

  return std::make_tuple(Out0, Out1, Out2);
}
*/
VLOG(6) << "Generating Dygraph Forward Function";
const std::string& op_type = op_proto.type();
std::string generated_function_body = "";
std::string dygraph_function_args_str = "";
/* ------ Dygraph forward function generation ------ */
generated_function_body += " // Dygraph Forward Pass\n";
generated_function_body += "\n";
// [Generation] Get Ins Map
std::string ins_contents_str = "";
std::vector<std::string> input_args_str_list(op_proto.inputs().size());
for (const proto::OpProto::Var& input : op_proto.inputs()) {
const std::string& input_name = input.name();
size_t input_position = fwd_inputs_name_pos_map.at(input_name);
if (input.duplicable()) {
const char* FWD_INS_ARG_TEMPLATE =
"const std::vector<egr::EagerTensor>& %s";
input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
} else {
const char* FWD_INS_ARG_TEMPLATE = "const egr::EagerTensor& %s";
input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
}
const char* FWD_INS_CONTENT_TEMPLATE = "{ \"%s\", egr::SyncToVars(%s) },";
ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE,
input_name, input_name);
}
if (ins_contents_str.size() > 0)
ins_contents_str.pop_back(); // Remove trailing ","
for (const std::string& arg : input_args_str_list) {
dygraph_function_args_str += arg;
dygraph_function_args_str += ",";
}
if (dygraph_function_args_str.size() > 0)
dygraph_function_args_str.pop_back();
const char* FWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> ins = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(FWD_INS_MAP_TEMPLATE, ins_contents_str);
generated_function_body += ins_map_str;
generated_function_body += "\n";
VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map
std::string outs_contents_str = "";
for (const proto::OpProto::Var& output : op_proto.outputs()) {
const std::string& output_name = output.name();
std::string outnum = "1";
if (output.duplicable()) {
outnum = output_name + "Num";
const char* FWD_NUM_ARG_TEMPLATE = ", size_t %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::ConstructDuplicableOutput(%s) },";
outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE,
output_name, outnum);
} else {
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
"GenerateUniqueName())}},";
outs_contents_str +=
paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name);
}
}
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back(); // Remove trailing ","
const char* FWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> outs = { "
"%s };\n";
std::string outs_map_str =
paddle::string::Sprintf(FWD_OUTS_MAP_TEMPLATE, outs_contents_str);
generated_function_body += outs_map_str;
generated_function_body += "\n";
VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs
dygraph_function_args_str +=
", const paddle::framework::AttributeMap& attr_map";
generated_function_body += "\n";
// [Generation] Get TraceOp
const char* FWD_TRACE_OP_TEMPLATE =
" paddle::framework::AttributeMap attrs = attr_map;\n"
" paddle::framework::AttributeMap default_attrs;\n"
" egr::RunOp(\"%s\", ins, outs, attrs, \n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &default_attrs, true, {});\n";
std::string trace_op_str =
paddle::string::Sprintf(FWD_TRACE_OP_TEMPLATE, op_proto.type());
generated_function_body += trace_op_str;
generated_function_body += "\n";
VLOG(6) << "Generated AttrMap & TraceOp";
// [Generation] Convert output VarBase to Vector/Tensor
size_t output_size = op_proto.outputs().size();
std::vector<std::string> return_contents(output_size);
std::vector<std::string> return_types(output_size);
for (const proto::OpProto::Var& output : op_proto.outputs()) {
const std::string& output_name = output.name();
std::string out_tensor_str;
size_t return_position = fwd_outputs_name_pos_map.at(output_name);
if (output.duplicable()) {
const char* FWD_OUT_TENSORS_TEMPLATE =
" std::vector<egr::EagerTensor> %s = "
"egr::GetOutputs(outs[\"%s\"]);\n";
out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE,
output_name, output_name);
return_types[return_position] = "std::vector<egr::EagerTensor>";
} else {
const char* FWD_OUT_TENSOR_TEMPLATE =
" egr::EagerTensor %s = "
"egr::GetOutput(outs[\"%s\"][0]);\n";
out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE,
output_name, output_name);
return_types[return_position] = "egr::EagerTensor";
}
return_contents[return_position] = output_name;
generated_function_body += out_tensor_str;
}
generated_function_body += "\n";
VLOG(6) << "Converted Output VarBase to EagerTensor(s)";
// [Generation] ComputeRequireGrad -> GradNodeCreation
std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
grad_node_default_attr_maps, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, op_proto);
generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n";
VLOG(6) << "Generated GradNode Creation codes";
// [Generation] Handle return: Tuple/Vector/Tensor
generated_function_body += "\n";
std::string return_str;
std::string return_type_str = "";
std::string function_proto_return_type_str = "";
if (return_contents.size() > 1) {
// Return tuple
std::string return_content_str = "";
for (const std::string& s : return_contents) {
return_content_str += s + ",";
}
return_content_str.pop_back(); // Remove trailing ","
for (const std::string& s : return_types) {
return_type_str += s + ",";
}
return_type_str.pop_back(); // Remove trailing ","
const char* FWD_TUPLE_RETURN_TEMPLATE = " return std::make_tuple(%s);";
return_str =
paddle::string::Sprintf(FWD_TUPLE_RETURN_TEMPLATE, return_content_str);
const char* FWD_FUNCTION_PROTO_RETURN_TEMPLATE = "std::tuple<%s>";
function_proto_return_type_str = paddle::string::Sprintf(
FWD_FUNCTION_PROTO_RETURN_TEMPLATE, return_type_str);
} else {
// Return vector<Tensor> or Tensor
return_type_str = return_types[0];
const char* FWD_TENSOR_RETURN_TEMPLATE = " return %s;";
return_str =
paddle::string::Sprintf(FWD_TENSOR_RETURN_TEMPLATE, return_contents[0]);
function_proto_return_type_str = return_type_str;
}
generated_function_body += return_str;
generated_function_body += "\n";
VLOG(6) << "Generated return codes";
// [Generation] Get Full Function
std::string function_name = op_type + "_dygraph_function";
const char* FWD_FUNCTION_TEMPLATE = "%s %s(%s) {\n\n%s\n}\n\n";
std::string fwd_function_str = paddle::string::Sprintf(
FWD_FUNCTION_TEMPLATE, function_proto_return_type_str, function_name,
dygraph_function_args_str, generated_function_body);
// [Generation] Append USE_OP
fwd_function_str += AppendUseOp(op_type);
// [Generation] Generate forward functions header
const char* FWD_HEADER_TEMPLATE = "%s %s(%s);\n";
std::string dygraph_function_declaration_str = paddle::string::Sprintf(
FWD_HEADER_TEMPLATE, function_proto_return_type_str, function_name,
dygraph_function_args_str);
return {fwd_function_str, dygraph_function_declaration_str};
}
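/* Illustrative only (not generator output taken from this diff): stitching
   the templates above together for a hypothetical op "relu" with one
   non-duplicable input X and one non-duplicable output Out would yield
   roughly the following forward function:

egr::EagerTensor relu_dygraph_function(
    const egr::EagerTensor& X,
    const paddle::framework::AttributeMap& attr_map) {
  // Dygraph Forward Pass

  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> ins =
      { { "X", egr::SyncToVars(X) } };

  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> outs =
      { { "Out", {std::make_shared<egr::EagerTensor>(
                      egr::Controller::Instance().GenerateUniqueName())} } };

  paddle::framework::AttributeMap attrs = attr_map;
  paddle::framework::AttributeMap default_attrs;
  egr::RunOp("relu", ins, outs, attrs,
             egr::Controller::Instance().GetExpectedPlace(),
             &default_attrs, true, {});

  egr::EagerTensor Out = egr::GetOutput(outs["Out"][0]);

  // ... GradNode creation from GenerateGradNodeCreationContent() ...

  return Out;
}
*/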
/* ---------------------------------------------- */
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
static std::string GenerateGradNodeCCContents(
const std::vector<paddle::framework::AttributeMap>&
grad_node_default_attr_maps,
const std::vector<std::string>& grad_op_types,
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::map<std::string, std::string>& grad_ins_grad_slotname_map,
const std::map<std::string, std::string>& grad_outs_slotname_map,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_ins,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const proto::OpProto& op_proto) {
VLOG(6) << "Generating Grad Node CC";
/* [Outline]
vector<vector<Tensor>> GradNodeXXX::operator()(vector<vector<Tensor>>& grads)
{
const std::shared_ptr<Tracer>& tracer = imperative::GetCurrentTracer();
// Comes from "grad_ins"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> ins =
{
"X" : this->"X", "Y" : this->"Y",
"Out0@Grad":
SyncToVars(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out0@Grad"]]"]),
"Out1@Grad":
TensorsToVarBases(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out1@Grad"]]"])
};
// Comes from "grad_outs"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> outs =
{
"X@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["X@Grad"]]"].Size()),
"Y@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["Y@Grad"]]"].Size())
};
// Visit each OpBase
for(auto iter = "grad_node->begin()"; iter < "grad_node->end()"; iter++) {
// Simply pass entire attribute map to kernels
egr::RunOp("iter->Type()", ins, outs, this->attr_map_,
egr::Controller::Instance().ExpectedPlace(), false, {});
}
vector<vector<egr::EagerTensor>> outputs(outs.size());
for(auto& kv : outs) {
outputs["fwd_inputs_name_pos_map[grad_outs_slotname_map[kv.first]]"] =
GetOutputs(outs["kv.first"]);
}
return outputs;
}
*/
const std::string& op_type = op_proto.type();
std::string generated_grad_function_body = "";
// [Generation] Get Tracer
generated_grad_function_body += "\n";
generated_grad_function_body += "\n";
// [Generation] Get Ins Map
std::string ins_contents_str = "";
for (auto iter : grad_ins) {
const std::string& grad_input_name = iter.first;
if (grad_ins_fwd_slotname_map.count(grad_input_name)) {
// Fwd Tensor
std::string struct_fwd_input_name =
grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
"{ \"%s\", "
"egr::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&this->%s, "
"nullptr)) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
grad_input_name, struct_fwd_input_name);
} else if (grad_ins_grad_slotname_map.count(grad_input_name)) {
// Fwd Tensor's Grad
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::SyncToVars(grads[%d]) },";
ins_contents_str += paddle::string::Sprintf(
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_input_name));
}
}
if (ins_contents_str.size() > 0)
ins_contents_str.pop_back(); // Remove trailing ","
const char* BWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> ins = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_contents_str);
generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map
std::unordered_set<std::string> duplicable_input_name_set;
for (const auto& in : op_proto.inputs()) {
if (in.duplicable()) duplicable_input_name_set.insert(in.name());
}
std::string outs_contents_str = "";
for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first;
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_input_name =
grad_outs_slotname_map.at(grad_output_name);
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_input_name);
if (duplicable_input_name_set.count(fwd_input_name)) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
"GenerateUniqueName())}},";
outs_contents_str += paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
grad_output_name);
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_output_name));
}
}
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back(); // Remove trailing ","
const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> outs = { "
"%s };\n";
std::string outs_map_str =
paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_contents_str);
generated_grad_function_body += outs_map_str;
generated_grad_function_body += "\n";
VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs Map
std::string trace_opbase_str = "";
for (size_t i = 0; i < grad_node_default_attr_maps.size(); i++) {
const std::string& op_base_type = grad_op_types[i];
const char* TRACE_OP_TEMPLATE =
" // Pass the entire attribute map to TraceOp\n"
" // The underlying kernel will pickup whatever attribute they need "
"at runtime\n"
" egr::RunOp(\"%s\", ins, outs, this->attr_map_,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &this->default_attr_map_, false, {});\n";
trace_opbase_str = paddle::string::Sprintf(TRACE_OP_TEMPLATE, op_base_type);
}
generated_grad_function_body += trace_opbase_str;
VLOG(6) << "Generated Attrs Map";
// [Generation] Get Return
std::string outputs_str = "";
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
size_t fwd_input_position =
fwd_inputs_name_pos_map.at(grad_outs_slotname_map.at(grad_out_name));
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = GetOutputs(outs[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE,
fwd_input_position, grad_out_name);
}
const char* BWD_RETURN_TEMPLATE =
" std::vector<std::vector<egr::EagerTensor>> "
"outputs(outs.size());\n%s\n "
"return outputs;";
std::string return_str =
paddle::string::Sprintf(BWD_RETURN_TEMPLATE, outputs_str);
generated_grad_function_body += "\n";
generated_grad_function_body += return_str;
// [Generation] Get Full Grad Function
const char* GRAD_FUNCTION_TEMPLATE =
"std::vector<std::vector<egr::EagerTensor>> "
"GradNode%s::operator()(const "
"std::vector<std::vector<egr::EagerTensor>>& grads) {\n%s\n}";
std::string grad_function_str = paddle::string::Sprintf(
GRAD_FUNCTION_TEMPLATE, op_type, generated_grad_function_body);
VLOG(6) << "Generated returns";
return grad_function_str;
}
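/* Illustrative only: for the same hypothetical "relu" op, assuming its grad
   op "relu_grad" consumes the saved forward output Out plus Out@GRAD and
   produces X@GRAD, the generated definition would look roughly like:

std::vector<std::vector<egr::EagerTensor>> GradNoderelu::operator()(
    const std::vector<std::vector<egr::EagerTensor>>& grads) {

  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> ins =
      { { "Out", egr::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(
                     &this->Out_, nullptr)) },
        { "Out@GRAD", egr::SyncToVars(grads[0]) } };

  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> outs =
      { { "X@GRAD", {std::make_shared<egr::EagerTensor>(
                         egr::Controller::Instance().GenerateUniqueName())} } };

  // Pass the entire attribute map to TraceOp
  egr::RunOp("relu_grad", ins, outs, this->attr_map_,
             egr::Controller::Instance().GetExpectedPlace(),
             &this->default_attr_map_, false, {});

  std::vector<std::vector<egr::EagerTensor>> outputs(outs.size());
  outputs[0] = GetOutputs(outs["X@GRAD"]);
  return outputs;
}
*/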
/* ----------------------------------------- */
/* --------- CodeGen: GradNode Header ------ */
/* ----------------------------------------- */
static std::string GenerateGradNodeHeaderContents(
const std::vector<paddle::framework::AttributeMap>&
grad_node_default_attr_maps,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const proto::OpProto& op_proto) {
VLOG(6) << "Generating Grad Node Header";
const char* GRAD_NODE_TEMPLATE =
"class GradNode%s : public egr::GradNodeBase {\n"
" public:\n"
" GradNode%s() : egr::GradNodeBase() {}\n"
" GradNode%s(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : "
"egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}\n"
" ~GradNode%s() override = default;\n"
"\n"
" virtual std::vector<std::vector<egr::EagerTensor>> "
"operator()(const "
"std::vector<std::vector<egr::EagerTensor>>& grads) "
"override;\n"
"\n"
" // SetX, SetY, ...\n"
"%s\n"
" // SetAttrMap\n"
"%s\n"
"\n"
" private:\n"
" // TensorWrappers\n"
"%s\n"
" // Attribute Map\n"
"%s\n"
"};";
const std::string& op_type = op_proto.type();
// [Generation] Handle Attributes
std::string set_attr_map_str =
" void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {\n "
"attr_map_ = std::move(attr_map);\n }\n";
set_attr_map_str +=
" void SetDefaultAttrMap(paddle::framework::AttributeMap&& "
"default_attr_map) {\n default_attr_map_ = "
"std::move(default_attr_map);\n }\n";
std::string attr_members_str =
" paddle::framework::AttributeMap attr_map_;\n";
attr_members_str += " paddle::framework::AttributeMap default_attr_map_;";
VLOG(6) << "Generated SetAttr";
// [Generation] Handle TensorWrappers
std::unordered_set<std::string> duplicable_tensors;
for (const proto::OpProto::Var& input : op_proto.inputs()) {
if (input.duplicable()) {
duplicable_tensors.insert(input.name());
}
}
for (const proto::OpProto::Var& output : op_proto.outputs()) {
if (output.duplicable()) {
duplicable_tensors.insert(output.name());
}
}
std::string set_tensor_wrappers_str = "";
std::string tensor_wrapper_members_str = "";
for (const auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
const std::string& struct_tensor_wrapper_name = kv.second + "_";
std::string tensor_wrapper_arg_str;
std::string tensor_wrapper_body_str;
if (duplicable_tensors.count(tensor_wrapper_name)) {
const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
"const std::vector<egr::EagerTensor>& %s";
tensor_wrapper_arg_str = paddle::string::Sprintf(
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name);
const char* TENSOR_WRAPPER_MEMBER_TEMPLATE =
" std::vector<egr::TensorWrapper> %s;\n";
tensor_wrapper_members_str += paddle::string::Sprintf(
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"for(const auto& eager_tensor : %s) {\n"
" %s.emplace_back( egr::TensorWrapper(eager_tensor, true "
"/*full_reserved*/) );\n"
" }\n";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
struct_tensor_wrapper_name);
} else {
const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
"const egr::EagerTensor& %s";
tensor_wrapper_arg_str = paddle::string::Sprintf(
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name);
const char* TENSOR_WRAPPER_MEMBER_TEMPLATE =
" egr::TensorWrapper %s;\n";
tensor_wrapper_members_str += paddle::string::Sprintf(
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"%s = egr::TensorWrapper(%s, true /*full_reserved*/);";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
tensor_wrapper_name);
}
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" void SetTensorWrapper%s(%s) {\n %s\n }\n";
set_tensor_wrappers_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
tensor_wrapper_arg_str, tensor_wrapper_body_str);
}
VLOG(6) << "Generated TensorWrapper";
std::string grad_node_str = paddle::string::Sprintf(
GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type,
set_tensor_wrappers_str, set_attr_map_str, tensor_wrapper_members_str,
attr_members_str);
return grad_node_str;
}
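/* Illustrative only: instantiating GRAD_NODE_TEMPLATE for the hypothetical
   "relu" op above (one saved, non-duplicable forward tensor Out) produces
   roughly:

class GradNoderelu : public egr::GradNodeBase {
 public:
  GradNoderelu() : egr::GradNodeBase() {}
  GradNoderelu(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
      egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
  ~GradNoderelu() override = default;

  virtual std::vector<std::vector<egr::EagerTensor>> operator()(const
      std::vector<std::vector<egr::EagerTensor>>& grads) override;

  // SetX, SetY, ...
  void SetTensorWrapperOut(const egr::EagerTensor& Out) {
    Out_ = egr::TensorWrapper(Out, true);  // true: full_reserved
  }

  // SetAttrMap
  void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {
    attr_map_ = std::move(attr_map);
  }
  void SetDefaultAttrMap(paddle::framework::AttributeMap&& default_attr_map) {
    default_attr_map_ = std::move(default_attr_map);
  }

 private:
  // TensorWrappers
  egr::TensorWrapper Out_;

  // Attribute Map
  paddle::framework::AttributeMap attr_map_;
  paddle::framework::AttributeMap default_attr_map_;
};
*/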
/* ---------------------------------- */
/* --------- FileGeneration --------- */
/* ---------------------------------- */
static void GenerateForwardHFile(const std::string& output_dir,
const std::string& dygraph_forward_api_str) {
std::string dygraph_forward_api_path = output_dir + "/dygraph_forward_api.h";
std::ofstream forward_header_stream(dygraph_forward_api_path, std::ios::out);
forward_header_stream << dygraph_forward_api_str;
forward_header_stream.close();
}
static void GenerateForwardDygraphFile(const std::string& op_type,
const std::string& output_dir,
const std::string& fwd_function_str) {
std::string forwards_dir = output_dir + "/forwards/";
std::string node_h_filename = op_type + "_node.h";
std::string forward_cc_filename = op_type + "_dygraph.cc";
std::string forward_cc_path = forwards_dir + forward_cc_filename;
const char* FORWARD_INCLUDE_TEMPLATE =
"#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h\"\n"
"#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/nodes/%s\"\n\n"
"#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n"
"#include \"paddle/fluid/eager/legacy/op_runner.h\"\n";
std::string forward_cc_include_str =
paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE, node_h_filename);
std::ofstream forward_cc_stream(forward_cc_path, std::ios::out);
forward_cc_stream << forward_cc_include_str;
forward_cc_stream << fwd_function_str;
forward_cc_stream.close();
}
static void GenerateNodeHFile(const std::string& op_type,
const std::string& output_dir,
const std::string& grad_node_str) {
std::string nodes_dir = output_dir + "/nodes/";
std::string node_h_filename = op_type + "_node.h";
std::string node_h_path = nodes_dir + node_h_filename;
std::string node_h_include_str =
"#pragma once\n"
"#include \"paddle/fluid/eager/tensor_wrapper.h\"\n"
"#include \"paddle/fluid/eager/legacy/op_runner.h\"\n"
"#include \"paddle/fluid/eager/grad_node_info.h\"\n\n";
std::ofstream node_h_stream(node_h_path, std::ios::out);
node_h_stream << node_h_include_str;
node_h_stream << grad_node_str;
node_h_stream.close();
}
static void GenerateNodeCCFile(const std::string& op_type,
const std::string& output_dir,
const std::string& grad_function_str) {
std::string nodes_dir = output_dir + "/nodes/";
std::string node_h_filename = op_type + "_node.h";
std::string node_cc_filename = op_type + "_node.cc";
std::string node_cc_path = nodes_dir + node_cc_filename;
const char* NODE_CC_INCLUDE_TEMPLATE =
"#include \"glog/logging.h\"\n"
"#include \"paddle/pten/api/all.h\"\n"
"#include \"paddle/fluid/imperative/tracer.h\"\n"
"#include \"paddle/fluid/framework/op_registry.h\"\n"
"#include \"paddle/fluid/eager/utils.h\"\n"
"#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n"
"#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/nodes/%s\"\n\n";
std::string node_cc_include_str =
paddle::string::Sprintf(NODE_CC_INCLUDE_TEMPLATE, node_h_filename);
std::ofstream node_cc_stream(node_cc_path, std::ios::out);
node_cc_stream << node_cc_include_str;
node_cc_stream << grad_function_str;
node_cc_stream.close();
}
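/* Putting the four writers above together, a single op (e.g. "matmul_v2")
   ends up spread across the following files under output_dir:

     output_dir/dygraph_forward_api.h          // declarations of all forward functions
     output_dir/forwards/matmul_v2_dygraph.cc  // matmul_v2_dygraph_function definition
     output_dir/nodes/matmul_v2_node.h         // class GradNodematmul_v2
     output_dir/nodes/matmul_v2_node.cc        // GradNodematmul_v2::operator() definition
*/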
static std::string GenerateDygraphHFileIncludes() {
std::string dygraph_forward_api_includes_str =
"#pragma once\n"
"#include \"glog/logging.h\"\n"
"#include \"paddle/fluid/eager/autograd_meta.h\"\n"
"#include \"paddle/pten/api/all.h\"\n"
"#include \"paddle/fluid/eager/utils.h\"\n"
"#include \"paddle/fluid/framework/op_registry.h\"\n\n";
return dygraph_forward_api_includes_str;
}
static void DygraphCodeGeneration(const std::string& output_dir) {
std::string dygraph_forward_api_str = GenerateDygraphHFileIncludes();
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
for (auto& pair : op_info_map) {
const OpInfo& op_info = pair.second;
proto::OpProto* op_proto = op_info.proto_;
if (!CheckOpProto(op_proto)) continue;
const std::string& op_type = op_proto->type();
/* ----------------------------- */
/* ---- Collect Information ---- */
/* ----------------------------- */
std::vector<paddle::framework::AttributeMap> grad_node_default_attr_maps;
std::vector<std::string> grad_op_types;
std::unordered_map<std::string, size_t> fwd_inputs_name_pos_map;
std::unordered_map<std::string, size_t> fwd_outputs_name_pos_map;
std::map<std::string, std::string> grad_outs_slotname_map;
std::map<std::string, std::string> grad_ins_fwd_slotname_map;
std::map<std::string, std::string> grad_ins_grad_slotname_map;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_ins;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_outs;
VLOG(6) << "-------- CollectInformationFromOpInfo -------";
bool is_available = CollectInformationFromOpInfo(
op_info, &grad_node_default_attr_maps, &grad_op_types,
&fwd_inputs_name_pos_map, &fwd_outputs_name_pos_map,
&grad_outs_slotname_map, &grad_ins_fwd_slotname_map,
&grad_ins_grad_slotname_map, &grad_ins, &grad_outs);
if (!is_available) continue;
/* --------------------------- */
/* --------- CodeGen --------- */
/* --------------------------- */
/* ---- xxx_dygraph.cc ---- */
VLOG(6) << "-------- GenerateForwardFunctionContents -------";
std::pair<std::string, std::string> body_and_declaration =
GenerateForwardFunctionContents(
grad_node_default_attr_maps, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins,
grad_outs, *op_proto);
std::string fwd_function_str = body_and_declaration.first;
GenerateForwardDygraphFile(op_type, output_dir, fwd_function_str);
/* ---- dygraph_forward_api.h ---- */
std::string fwd_function_declare_str = body_and_declaration.second;
dygraph_forward_api_str += fwd_function_declare_str;
/* ---- xxx_node.h ---- */
VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
std::string grad_node_h_str = GenerateGradNodeHeaderContents(
grad_node_default_attr_maps, grad_ins_fwd_slotname_map, *op_proto);
GenerateNodeHFile(op_type, output_dir, grad_node_h_str);
/* ---- xxx_node.cc ---- */
VLOG(6) << "-------- GenerateGradNodeCCContents -------";
std::string grad_node_cc_str = GenerateGradNodeCCContents(
grad_node_default_attr_maps, grad_op_types, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs,
*op_proto);
GenerateNodeCCFile(op_type, output_dir, grad_node_cc_str);
VLOG(6) << op_type << ": Finished Generation";
}
/* ---- dygraph_forward_api.h ---- */
VLOG(6) << "-------- GenerateForwardHFile -------";
GenerateForwardHFile(output_dir, dygraph_forward_api_str);
}
} // namespace framework
} // namespace paddle

int main(int argc, char* argv[]) {
  if (argc != 2) {
    std::cerr << "argc must be 2" << std::endl;
    return -1;
  }

  std::string eager_root = argv[1];
  paddle::framework::DygraphCodeGeneration(eager_root);

  return 0;
}
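/* Usage sketch (the binary name is an assumption based on this file's role,
   not stated in the diff): the generator is compiled into a standalone tool
   and invoked once at build time with the output directory as its single
   argument, e.g.

     ./eager_generator paddle/fluid/eager/api/generated/fluid_generated

   which matches the include paths hard-coded into the file templates above. */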