Unverified commit de874cdd authored by Zhanlue Yang, committed by GitHub

Enabled generation for special operators whose GradNode/Inputs/Outputs are empty (#37837)

Parent 27d1f811
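Below is a minimal, self-contained C++ sketch of the control flow this change introduces (the struct and function names are illustrative placeholders, not the actual generator types): operators whose GradOpMaker is missing or registered as empty are now emitted in a forward-only mode, with GradNode creation omitted, instead of being dropped from code generation entirely.

// Sketch: illustrative only. OpSketch and GenerateForwardSketch are
// placeholder names, not actual Paddle types or functions.
#include <iostream>
#include <string>
#include <vector>

struct OpSketch {
  std::string type;
  bool has_grad_op_maker;  // false for the ops this change re-enables
};

// When the op has no usable GradOpMaker, grad-node creation is simply
// omitted from the generated body instead of aborting generation.
static std::string GenerateForwardSketch(const OpSketch& op,
                                         bool generate_forward_only) {
  std::string body =
      op.type + "_dygraph_function(...) {\n  // trace forward op\n";
  if (!generate_forward_only) {
    body += "  // create GradNode, set grad in/out meta\n";
  }
  body += "  return out;\n}\n";
  return body;
}

int main() {
  std::vector<OpSketch> ops = {{"matmul_v2", true},
                               {"some_forward_only_op", false}};
  for (const auto& op : ops) {
    // Previously such ops were skipped; now they fall back to forward-only.
    bool generate_forward_only = !op.has_grad_op_maker;
    std::cout << GenerateForwardSketch(op, generate_forward_only) << "\n";
  }
  return 0;
}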
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <algorithm>
#include <fstream>
#include <iostream>
......@@ -27,69 +26,21 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
DEFINE_bool(generate_all, false,
"Generate all operators currently registered in Paddle");
namespace paddle {
namespace framework {
static std::unordered_map<std::string, paddle::framework::AttributeMap>
operators_with_attrs = {};
static std::unordered_set<std::string> operators_to_skip = {
"pull_sparse", "pull_box_extended_sparse", "pull_sparse_v2",
"pull_box_sparse", "fused_attention", "diag_v2",
"c_split"};
"chunk_eval", // Stupid tensor name
"minus", "pull_sparse", "pull_box_extended_sparse",
"pull_sparse_v2", "pull_box_sparse", "fused_attention",
"diag_v2", "c_split"};
static std::unordered_set<std::string> operators_to_codegen = {};
static std::unordered_set<std::string> skipped_operators = {};
static void PrepareAttrMapForOps() {
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
functor_list;
// Handle "fused_elemwise_activation"
operators_with_attrs["fused_elemwise_activation"] = {};
operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
functor_list;
// Handle "reverse"
std::vector<int> axis = {0};
operators_with_attrs["reverse"] = {};
operators_with_attrs["reverse"]["axis"] = axis;
// Handle "flip"
operators_with_attrs["flip"] = {};
operators_with_attrs["flip"]["axis"] = axis;
// Handle "cast"
operators_with_attrs["cast"] = {};
operators_with_attrs["cast"]["out_dtype"] = 5;
operators_with_attrs["cast"]["in_dtype"] = 5;
// Handle "transfer_dtype"
operators_with_attrs["transfer_dtype"] = {};
operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
}
static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
std::string line;
std::ifstream op_list_file(op_list_path);
if (op_list_file.is_open()) {
while (getline(op_list_file, line)) {
operators_to_codegen.insert(line);
}
op_list_file.close();
} else {
PADDLE_THROW(
paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
}
}
namespace paddle {
namespace framework {
static std::string AttrTypeToString(const proto::AttrType& type) {
std::string ret;
switch (type) {
......@@ -392,10 +343,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// Only handle matmul_v2 for now
VLOG(1) << "------ Analyzing Op ------: " << op_type;
if (!FLAGS_generate_all) {
if (!operators_to_codegen.count(op_type)) return false;
}
if (!operators_to_codegen.count(op_type)) return false;
if (operators_to_skip.count(op_type)) return false;
return true;
......@@ -404,21 +352,12 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
/* --------------------------------------- */
/* --------- Preprocess Ins/Outs --------- */
/* --------------------------------------- */
static void PurifyOpProto(
static void PurifyForwardOpProto(
const proto::OpProto& op_proto,
std::unordered_map<std::string, size_t>* fwd_inputs_name_pos_map,
std::unordered_map<std::string, size_t>* fwd_outputs_name_pos_map,
std::map<std::string, std::string>* grad_outs_slotname_map,
std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
std::map<std::string, std::string>* grad_ins_grad_slotname_map,
std::vector<proto::OpProto::Var>* in_vars,
std::vector<proto::OpProto::Var>* out_vars,
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins,
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs) {
std::vector<proto::OpProto::Var>* out_vars) {
// Op Name
const std::string op_name = op_proto.type();
......@@ -440,6 +379,72 @@ static void PurifyOpProto(
}
}
in_vars->erase(iter);
}
}
}
for (const proto::OpProto::Var& output : op_proto.outputs()) {
std::string output_name = output.name();
// Delete dispensable tensor unless specified in op_outs_map
if (output.dispensable()) {
if (!op_outs_map.count(op_name) ||
!op_outs_map[op_name].count(output_name)) {
VLOG(6) << "Removing Dispensable Output: " << output_name;
// out_vars
auto iter = out_vars->begin();
for (iter = out_vars->begin(); iter != out_vars->end(); iter++) {
if (iter->name() == output_name) {
break;
}
}
out_vars->erase(iter);
}
}
}
/* ------ Mapping forward slot name to fwd position ------ */
size_t in_pos = 0;
for (const auto& var : *in_vars) {
VLOG(6) << "Mapping input tensor: " << var.name()
<< " To position: " << in_pos;
(*fwd_inputs_name_pos_map)[var.name()] = in_pos;
in_pos++;
}
size_t out_pos = 0;
for (const auto& var : *out_vars) {
VLOG(6) << "Mapping output tensor: " << var.name()
<< " To position: " << out_pos;
(*fwd_outputs_name_pos_map)[var.name()] = out_pos;
out_pos++;
}
}
static void PurifyGradOpProto(
const proto::OpProto& op_proto,
std::map<std::string, std::string>* grad_outs_slotname_map,
std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
std::map<std::string, std::string>* grad_ins_grad_slotname_map,
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins,
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs) {
// Op Name
const std::string op_name = op_proto.type();
// Handle dispensable inputs
for (const proto::OpProto::Var& input : op_proto.inputs()) {
std::string input_name = input.name();
// Delete dispensable tensor unless specified in op_ins_map
if (input.dispensable()) {
if (!op_ins_map.count(op_name) ||
!op_ins_map[op_name].count(input_name)) {
VLOG(6) << "Removing Dispensable Input: " << input_name;
// grad_outs_slotname_map
auto grad_outs_slotname_map_purified = *grad_outs_slotname_map;
......@@ -478,15 +483,6 @@ static void PurifyOpProto(
!op_outs_map[op_name].count(output_name)) {
VLOG(6) << "Removing Dispensable Output: " << output_name;
// out_vars
auto iter = out_vars->begin();
for (iter = out_vars->begin(); iter != out_vars->end(); iter++) {
if (iter->name() == output_name) {
break;
}
}
out_vars->erase(iter);
// grad_ins_grad_slotname_map
auto grad_ins_grad_slotname_map_purified = *grad_ins_grad_slotname_map;
for (const auto& iter : *grad_ins_grad_slotname_map) {
......@@ -514,52 +510,40 @@ static void PurifyOpProto(
}
}
}
/* ------ Mapping forward slot name to fwd position ------ */
size_t in_pos = 0;
for (const auto& var : *in_vars) {
VLOG(6) << "Mapping input tensor: " << var.name()
<< " To position: " << in_pos;
(*fwd_inputs_name_pos_map)[var.name()] = in_pos;
in_pos++;
}
size_t out_pos = 0;
for (const auto& var : *out_vars) {
VLOG(6) << "Mapping output tensor: " << var.name()
<< " To position: " << out_pos;
(*fwd_outputs_name_pos_map)[var.name()] = out_pos;
out_pos++;
}
}
/* -------------------------------- */
/* --------- Collect Info --------- */
/* -------------------------------- */
static bool CollectInformationFromOpInfo(
static void CollectForwardInformationFromOpInfo(
const paddle::framework::OpInfo& op_info,
std::vector<std::string>* grad_op_types,
std::map<std::string, std::string>* grad_outs_slotname_map,
std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
std::map<std::string, std::string>* grad_ins_grad_slotname_map,
std::vector<proto::OpProto::Var>* in_vars,
std::vector<proto::OpProto::Var>* out_vars,
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins,
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs) {
std::vector<proto::OpProto::Var>* out_vars) {
const proto::OpProto& op_proto = *op_info.proto_;
const std::string& op_type = op_proto.type();
std::vector<int64_t> dims = {1, 1, 1, 1};
for (const proto::OpProto::Var& input : op_proto.inputs()) {
in_vars->push_back(input);
}
for (const proto::OpProto::Var& output : op_proto.outputs()) {
out_vars->push_back(output);
}
}
static bool CollectGradInformationFromOpInfo(
const paddle::framework::OpInfo& op_info, bool* generate_forward_only,
std::vector<std::string>* grad_op_types, // grad
std::map<std::string, std::string>* grad_outs_slotname_map, // grad
std::map<std::string, std::string>* grad_ins_fwd_slotname_map, // grad
std::map<std::string, std::string>* grad_ins_grad_slotname_map, // grad
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins, // grad
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs // grad
) {
const proto::OpProto& op_proto = *op_info.proto_;
const std::string& op_type = op_proto.type();
std::vector<int64_t> dims = {1, 1, 1, 1};
/* ------ Prepare "ins" ------ */
std::map<std::string,
......@@ -621,8 +605,6 @@ static bool CollectInformationFromOpInfo(
if (operators_with_attrs.count(op_type)) {
VLOG(6) << "Found operator " << op_type << " using special AttributeMap";
attrs = operators_with_attrs[op_type];
// default_attrs.insert(operators_with_attrs[op_type].begin(),
// operators_with_attrs[op_type].end());
}
VLOG(6) << "Prepared Default Attributes Map, size = " << default_attrs.size();
......@@ -655,8 +637,8 @@ static bool CollectInformationFromOpInfo(
/* ------ Run GradOpMaker ------ */
if (!op_info.dygraph_grad_op_maker_) {
VLOG(6) << op_type << " has no GradOpMaker, skip it";
skipped_operators.insert(op_type);
VLOG(6) << op_type << " has no GradOpMaker";
*generate_forward_only = true;
return false;
}
......@@ -666,17 +648,19 @@ static bool CollectInformationFromOpInfo(
if (!grad_node) {
VLOG(6) << "Got nullptr GradOpNode for " << op_type
<< " likely registered EmptyGradOpMaker, skip it";
skipped_operators.insert(op_type);
<< " likely registered EmptyGradOpMaker";
*generate_forward_only = true;
return false;
}
/*
if (grad_node->size() > 1) {
// Backward attributes can be super complicated
VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type;
skipped_operators.insert(op_type);
return false;
}
*/
VLOG(6) << "Prepared GradOpNode";
......@@ -901,6 +885,7 @@ static std::string GenerateGradNodeCreationContent(
/* --------- CodeGen: Forward ----- */
/* -------------------------------- */
static std::pair<std::string, std::string> GenerateForwardFunctionContents(
bool generate_forward_only,
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
......@@ -1044,7 +1029,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] Get Attrs
dygraph_function_args_str +=
", const paddle::framework::AttributeMap& attr_map";
generated_function_body += "\n";
// [Generation] Get TraceOp
const char* FWD_TRACE_OP_TEMPLATE =
......@@ -1092,16 +1076,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
VLOG(6) << "Converted Output VarBase to EagerTensor(s)";
// [Generation] ComputeRequireGrad -> GradNodeCreation
std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
fwd_inputs_name_pos_map, fwd_outputs_name_pos_map,
grad_ins_fwd_slotname_map, op_type, in_vars, out_vars);
generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n";
VLOG(6) << "Generated GradNode Creation codes";
if (!generate_forward_only) {
std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
fwd_inputs_name_pos_map, fwd_outputs_name_pos_map,
grad_ins_fwd_slotname_map, op_type, in_vars, out_vars);
generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n";
VLOG(6) << "Generated GradNode Creation codes";
}
// [Generation] Handle return: Tuple/Vector/Tensor
generated_function_body += "\n";
std::string return_str;
std::string return_str = "";
std::string return_type_str = "";
std::string function_proto_return_type_str = "";
if (return_contents.size() > 1) {
......@@ -1124,14 +1110,20 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const char* FWD_FUNCTION_PROTO_RETURN_TEMPLATE = "std::tuple<%s>";
function_proto_return_type_str = paddle::string::Sprintf(
FWD_FUNCTION_PROTO_RETURN_TEMPLATE, return_type_str);
} else {
} else if (return_contents.size() == 1) {
// Return vector<Tensor> or Tensor
return_type_str = return_types[0];
const char* FWD_TENSOR_RETURN_TEMPLATE = " return %s;";
return_str =
paddle::string::Sprintf(FWD_TENSOR_RETURN_TEMPLATE, return_contents[0]);
function_proto_return_type_str = return_type_str;
} else {
return_str = "return nullptr;";
function_proto_return_type_str = "void*";
}
generated_function_body += return_str;
generated_function_body += "\n";
VLOG(6) << "Generated return codes";
......@@ -1139,6 +1131,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] Get Full Function
std::string function_name = op_type + "_dygraph_function";
if (dygraph_function_args_str.size() > 0) {
auto iter = dygraph_function_args_str.begin();
if ((*iter) == ',') dygraph_function_args_str.erase(iter);
}
const char* FWD_FUNCTION_TEMPLATE = "%s %s(%s) {\n\n%s\n}\n\n";
std::string fwd_function_str = paddle::string::Sprintf(
FWD_FUNCTION_TEMPLATE, function_proto_return_type_str, function_name,
......@@ -1601,11 +1598,11 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
/* ---- Collect Information ---- */
/* ----------------------------- */
std::vector<std::string> grad_op_types;
std::vector<proto::OpProto::Var> in_vars;
std::vector<proto::OpProto::Var> out_vars;
std::map<std::string, std::string> grad_outs_slotname_map;
std::map<std::string, std::string> grad_ins_fwd_slotname_map;
std::map<std::string, std::string> grad_ins_grad_slotname_map;
std::vector<proto::OpProto::Var> in_vars;
std::vector<proto::OpProto::Var> out_vars;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_ins;
......@@ -1614,20 +1611,31 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
grad_outs;
VLOG(6) << "-------- CollectInformationFromOpInfo -------";
bool is_available = CollectInformationFromOpInfo(
op_info, &grad_op_types, &grad_outs_slotname_map,
&grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, &in_vars,
&out_vars, &grad_ins, &grad_outs);
if (!is_available) continue;
CollectForwardInformationFromOpInfo(op_info, &in_vars, &out_vars);
bool generate_forward_only = false;
bool is_available = CollectGradInformationFromOpInfo(
op_info, &generate_forward_only, &grad_op_types,
&grad_outs_slotname_map, &grad_ins_fwd_slotname_map,
&grad_ins_grad_slotname_map, &grad_ins, &grad_outs);
if (!is_available && !generate_forward_only) {
VLOG(6) << "Skipped operator: " << op_type;
continue;
}
VLOG(6) << "-------- PurifyOpProto -------";
std::unordered_map<std::string, size_t> fwd_inputs_name_pos_map;
std::unordered_map<std::string, size_t> fwd_outputs_name_pos_map;
PurifyOpProto(*op_proto, &fwd_inputs_name_pos_map,
&fwd_outputs_name_pos_map, &grad_outs_slotname_map,
&grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map,
&in_vars, &out_vars, &grad_ins, &grad_outs);
PurifyForwardOpProto(*op_proto, &fwd_inputs_name_pos_map,
&fwd_outputs_name_pos_map, &in_vars, &out_vars);
if (!generate_forward_only) {
PurifyGradOpProto(*op_proto, &grad_outs_slotname_map,
&grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map,
&grad_ins, &grad_outs);
}
/* --------------------------- */
/* --------- CodeGen --------- */
......@@ -1636,16 +1644,19 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
VLOG(6) << "-------- GenerateForwardFunctionContents -------";
std::pair<std::string, std::string> body_and_declaration =
GenerateForwardFunctionContents(
fwd_inputs_name_pos_map, fwd_outputs_name_pos_map,
grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map,
grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars,
out_vars);
generate_forward_only, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins,
grad_outs, op_type, in_vars, out_vars);
fwd_function_str += body_and_declaration.first + "\n";
/* ---- dygraph_forward_api.h ---- */
std::string fwd_function_declare_str = body_and_declaration.second;
dygraph_forward_api_str += fwd_function_declare_str;
if (generate_forward_only) continue;
/* ---- nodes.h ---- */
VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
grad_node_h_str +=
......@@ -1681,6 +1692,52 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
GenerateNodeCCFile(output_dir, grad_node_cc_str);
}
static void PrepareAttrMapForOps() {
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
functor_list;
// Handle "fused_elemwise_activation"
operators_with_attrs["fused_elemwise_activation"] = {};
operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
functor_list;
// Handle "reverse"
std::vector<int> axis = {0};
operators_with_attrs["reverse"] = {};
operators_with_attrs["reverse"]["axis"] = axis;
// Handle "flip"
operators_with_attrs["flip"] = {};
operators_with_attrs["flip"]["axis"] = axis;
// Handle "cast"
operators_with_attrs["cast"] = {};
operators_with_attrs["cast"]["out_dtype"] = 5;
operators_with_attrs["cast"]["in_dtype"] = 5;
// Handle "transfer_dtype"
operators_with_attrs["transfer_dtype"] = {};
operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
}
static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
std::string line;
std::ifstream op_list_file(op_list_path);
if (op_list_file.is_open()) {
while (getline(op_list_file, line)) {
operators_to_codegen.insert(line);
}
op_list_file.close();
} else {
PADDLE_THROW(
paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
}
}
} // namespace framework
} // namespace paddle
......@@ -1693,8 +1750,8 @@ int main(int argc, char* argv[]) {
std::string eager_root = argv[1];
std::string op_list_path = argv[2];
CollectOperatorsToCodeGen(op_list_path);
PrepareAttrMapForOps();
paddle::framework::CollectOperatorsToCodeGen(op_list_path);
paddle::framework::PrepareAttrMapForOps();
paddle::framework::DygraphCodeGeneration(eager_root);
......
......@@ -215,7 +215,6 @@ spp
floor
gelu
retinanet_detection_output
minus
push_dense
silu
sequence_erase
......