Unverified · Commit de874cdd authored by Zhanlue Yang, committed by GitHub

Enabled generation for special operators, the GradNode/Inputs/Outputs of which are empty (#37837)

Parent 27d1f811
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <gflags/gflags.h>
 #include <algorithm>
 #include <fstream>
 #include <iostream>
@@ -27,69 +26,21 @@
 #include "paddle/fluid/pybind/pybind.h"
 #include "paddle/fluid/string/string_helper.h"
-DEFINE_bool(generate_all, false,
-            "Generate all operators currently registered in Paddle");
+namespace paddle {
+namespace framework {
 static std::unordered_map<std::string, paddle::framework::AttributeMap>
     operators_with_attrs = {};
 static std::unordered_set<std::string> operators_to_skip = {
-    "pull_sparse", "pull_box_extended_sparse", "pull_sparse_v2",
-    "pull_box_sparse", "fused_attention", "diag_v2", "minus",
-    "c_split"};
+    "chunk_eval",  // Stupid tensor name
+    "pull_sparse", "pull_box_extended_sparse",
+    "pull_sparse_v2", "pull_box_sparse", "fused_attention",
+    "diag_v2", "c_split"};
 static std::unordered_set<std::string> operators_to_codegen = {};
 static std::unordered_set<std::string> skipped_operators = {};
-static void PrepareAttrMapForOps() {
-  // Handle "fused_elemwise_add_activation"
-  std::vector<std::string> functor_list = {"a", "b"};
-  operators_with_attrs["fused_elemwise_add_activation"] = {};
-  operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
-      functor_list;
-  // Handle "fused_elemwise_activation"
-  operators_with_attrs["fused_elemwise_activation"] = {};
-  operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
-      functor_list;
-  // Handle "reverse"
-  std::vector<int> axis = {0};
-  operators_with_attrs["reverse"] = {};
-  operators_with_attrs["reverse"]["axis"] = axis;
-  // Handle "flip"
-  operators_with_attrs["flip"] = {};
-  operators_with_attrs["flip"]["axis"] = axis;
-  // Handle "cast"
-  operators_with_attrs["cast"] = {};
-  operators_with_attrs["cast"]["out_dtype"] = 5;
-  operators_with_attrs["cast"]["in_dtype"] = 5;
-  // Handle "transfer_dtype"
-  operators_with_attrs["transfer_dtype"] = {};
-  operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
-  operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
-}
-static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
-  std::string line;
-  std::ifstream op_list_file(op_list_path);
-  if (op_list_file.is_open()) {
-    while (getline(op_list_file, line)) {
-      operators_to_codegen.insert(line);
-    }
-    op_list_file.close();
-  } else {
-    PADDLE_THROW(
-        paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
-  }
-}
-namespace paddle {
-namespace framework {
 static std::string AttrTypeToString(const proto::AttrType& type) {
   std::string ret;
   switch (type) {
@@ -392,10 +343,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
   // Only handle matmul_v2 for now
   VLOG(1) << "------ Analyzing Op ------: " << op_type;
-  if (!FLAGS_generate_all) {
-    if (!operators_to_codegen.count(op_type)) return false;
-  }
+  if (!operators_to_codegen.count(op_type)) return false;
   if (operators_to_skip.count(op_type)) return false;
   return true;
@@ -404,21 +352,12 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
 /* --------------------------------------- */
 /* --------- Preprocess Ins/Outs --------- */
 /* --------------------------------------- */
-static void PurifyOpProto(
+static void PurifyForwardOpProto(
     const proto::OpProto& op_proto,
     std::unordered_map<std::string, size_t>* fwd_inputs_name_pos_map,
     std::unordered_map<std::string, size_t>* fwd_outputs_name_pos_map,
-    std::map<std::string, std::string>* grad_outs_slotname_map,
-    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
-    std::map<std::string, std::string>* grad_ins_grad_slotname_map,
     std::vector<proto::OpProto::Var>* in_vars,
-    std::vector<proto::OpProto::Var>* out_vars,
-    std::map<std::string,
-             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
-        grad_ins,
-    std::map<std::string,
-             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
-        grad_outs) {
+    std::vector<proto::OpProto::Var>* out_vars) {
   // Op Name
   const std::string op_name = op_proto.type();
@@ -440,6 +379,72 @@ static void PurifyOpProto(
           }
         }
         in_vars->erase(iter);
+      }
+    }
+  }
+  for (const proto::OpProto::Var& output : op_proto.outputs()) {
+    std::string output_name = output.name();
+    // Delete dispensable tensor unless specified in op_outs_map
+    if (output.dispensable()) {
+      if (!op_outs_map.count(op_name) ||
+          !op_outs_map[op_name].count(output_name)) {
+        VLOG(6) << "Removing Dispensable Output: " << output_name;
+        // out_vars
+        auto iter = out_vars->begin();
+        for (iter = out_vars->begin(); iter != out_vars->end(); iter++) {
+          if (iter->name() == output_name) {
+            break;
+          }
+        }
+        out_vars->erase(iter);
+      }
+    }
+  }
+  /* ------ Maping forward slot name to fwd position ------ */
+  size_t in_pos = 0;
+  for (const auto& var : *in_vars) {
+    VLOG(6) << "Mapping input tensor: " << var.name()
+            << " To position: " << in_pos;
+    (*fwd_inputs_name_pos_map)[var.name()] = in_pos;
+    in_pos++;
+  }
+  size_t out_pos = 0;
+  for (const auto& var : *out_vars) {
+    VLOG(6) << "Mapping output tensor: " << var.name()
+            << " To position: " << out_pos;
+    (*fwd_outputs_name_pos_map)[var.name()] = out_pos;
+    out_pos++;
+  }
+}
+static void PurifyGradOpProto(
+    const proto::OpProto& op_proto,
+    std::map<std::string, std::string>* grad_outs_slotname_map,
+    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
+    std::map<std::string, std::string>* grad_ins_grad_slotname_map,
+    std::map<std::string,
+             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
+        grad_ins,
+    std::map<std::string,
+             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
+        grad_outs) {
+  // Op Name
+  const std::string op_name = op_proto.type();
+  // Handle dispensable inputs
+  for (const proto::OpProto::Var& input : op_proto.inputs()) {
+    std::string input_name = input.name();
+    // Delete dispensable tensor unless specified in op_ins_map
+    if (input.dispensable()) {
+      if (!op_ins_map.count(op_name) ||
+          !op_ins_map[op_name].count(input_name)) {
+        VLOG(6) << "Removing Dispensable Input: " << input_name;
         // grad_outs_slotname_map
         auto grad_outs_slotname_map_purified = *grad_outs_slotname_map;
@@ -478,15 +483,6 @@ static void PurifyOpProto(
           !op_outs_map[op_name].count(output_name)) {
         VLOG(6) << "Removing Dispensable Output: " << output_name;
-        // out_vars
-        auto iter = out_vars->begin();
-        for (iter = out_vars->begin(); iter != out_vars->end(); iter++) {
-          if (iter->name() == output_name) {
-            break;
-          }
-        }
-        out_vars->erase(iter);
         // grad_ins_grad_slotname_map
         auto grad_ins_grad_slotname_map_purified = *grad_ins_grad_slotname_map;
         for (const auto& iter : *grad_ins_grad_slotname_map) {
@@ -514,52 +510,40 @@ static void PurifyOpProto(
       }
     }
   }
-  /* ------ Maping forward slot name to fwd position ------ */
-  size_t in_pos = 0;
-  for (const auto& var : *in_vars) {
-    VLOG(6) << "Mapping input tensor: " << var.name()
-            << " To position: " << in_pos;
-    (*fwd_inputs_name_pos_map)[var.name()] = in_pos;
-    in_pos++;
-  }
-  size_t out_pos = 0;
-  for (const auto& var : *out_vars) {
-    VLOG(6) << "Mapping output tensor: " << var.name()
-            << " To position: " << out_pos;
-    (*fwd_outputs_name_pos_map)[var.name()] = out_pos;
-    out_pos++;
-  }
 }
 /* -------------------------------- */
 /* --------- Collect Info --------- */
 /* -------------------------------- */
-static bool CollectInformationFromOpInfo(
+static void CollectForwardInformationFromOpInfo(
     const paddle::framework::OpInfo& op_info,
-    std::vector<std::string>* grad_op_types,
-    std::map<std::string, std::string>* grad_outs_slotname_map,
-    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
-    std::map<std::string, std::string>* grad_ins_grad_slotname_map,
     std::vector<proto::OpProto::Var>* in_vars,
-    std::vector<proto::OpProto::Var>* out_vars,
-    std::map<std::string,
-             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
-        grad_ins,
-    std::map<std::string,
-             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
-        grad_outs) {
+    std::vector<proto::OpProto::Var>* out_vars) {
   const proto::OpProto& op_proto = *op_info.proto_;
-  const std::string& op_type = op_proto.type();
-  std::vector<int64_t> dims = {1, 1, 1, 1};
   for (const proto::OpProto::Var& input : op_proto.inputs()) {
     in_vars->push_back(input);
   }
   for (const proto::OpProto::Var& output : op_proto.outputs()) {
     out_vars->push_back(output);
   }
+}
+static bool CollectGradInformationFromOpInfo(
+    const paddle::framework::OpInfo& op_info, bool* generate_forward_only,
+    std::vector<std::string>* grad_op_types,  // grad
+    std::map<std::string, std::string>* grad_outs_slotname_map,  // grad
+    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,  // grad
+    std::map<std::string, std::string>* grad_ins_grad_slotname_map,  // grad
+    std::map<std::string,
+             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
+        grad_ins,  // grad
+    std::map<std::string,
+             std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
+        grad_outs  // grad
+    ) {
+  const proto::OpProto& op_proto = *op_info.proto_;
+  const std::string& op_type = op_proto.type();
+  std::vector<int64_t> dims = {1, 1, 1, 1};
   /* ------ Prepare "ins" ------ */
   std::map<std::string,
@@ -621,8 +605,6 @@ static bool CollectInformationFromOpInfo(
   if (operators_with_attrs.count(op_type)) {
     VLOG(6) << "Found operator " << op_type << " using special AttributeMap";
     attrs = operators_with_attrs[op_type];
-    // default_attrs.insert(operators_with_attrs[op_type].begin(),
-    //                      operators_with_attrs[op_type].end());
   }
   VLOG(6) << "Prepared Default Attributes Map, size = " << default_attrs.size();
@@ -655,8 +637,8 @@ static bool CollectInformationFromOpInfo(
   /* ------ Run GradOpMaker ------ */
   if (!op_info.dygraph_grad_op_maker_) {
-    VLOG(6) << op_type << " has no GradOpMaker, skip it";
-    skipped_operators.insert(op_type);
+    VLOG(6) << op_type << " has no GradOpMaker";
+    *generate_forward_only = true;
     return false;
   }
@@ -666,17 +648,19 @@ static bool CollectInformationFromOpInfo(
   if (!grad_node) {
     VLOG(6) << "Got nullptr GradOpNode for " << op_type
-            << " likely registered EmptyGradOpMaker, skip it";
-    skipped_operators.insert(op_type);
+            << " likely registered EmptyGradOpMaker";
+    *generate_forward_only = true;
     return false;
   }
+  /*
   if (grad_node->size() > 1) {
     // Backward attributes can be super complicated
     VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type;
     skipped_operators.insert(op_type);
     return false;
   }
+  */
   VLOG(6) << "Prepared GradOpNode";
@@ -901,6 +885,7 @@ static std::string GenerateGradNodeCreationContent(
 /* --------- CodeGen: Forward ----- */
 /* -------------------------------- */
 static std::pair<std::string, std::string> GenerateForwardFunctionContents(
+    bool generate_forward_only,
     const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
     const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
     const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
@@ -1044,7 +1029,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   // [Generation] Get Attrs
   dygraph_function_args_str +=
       ", const paddle::framework::AttributeMap& attr_map";
-  generated_function_body += "\n";
   // [Generation] Get TraceOp
   const char* FWD_TRACE_OP_TEMPLATE =
@@ -1092,16 +1076,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
VLOG(6) << "Converted Output VarBase to EagerTensor(s)"; VLOG(6) << "Converted Output VarBase to EagerTensor(s)";
// [Generation] ComputeRequireGrad -> GradNodeCreation // [Generation] ComputeRequireGrad -> GradNodeCreation
if (!generate_forward_only) {
std::string grad_node_creation_body_str = GenerateGradNodeCreationContent( std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, fwd_inputs_name_pos_map, fwd_outputs_name_pos_map,
grad_ins_fwd_slotname_map, op_type, in_vars, out_vars); grad_ins_fwd_slotname_map, op_type, in_vars, out_vars);
generated_function_body += grad_node_creation_body_str; generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n"; generated_function_body += "\n";
VLOG(6) << "Generated GradNode Creation codes"; VLOG(6) << "Generated GradNode Creation codes";
}
// [Generation] Handle return: Tuple/Vector/Tensor // [Generation] Handle return: Tuple/Vector/Tensor
generated_function_body += "\n"; generated_function_body += "\n";
std::string return_str; std::string return_str = "";
std::string return_type_str = ""; std::string return_type_str = "";
std::string function_proto_return_type_str = ""; std::string function_proto_return_type_str = "";
if (return_contents.size() > 1) { if (return_contents.size() > 1) {
@@ -1124,14 +1110,20 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
     const char* FWD_FUNCTION_PROTO_RETURN_TEMPLATE = "std::tuple<%s>";
     function_proto_return_type_str = paddle::string::Sprintf(
         FWD_FUNCTION_PROTO_RETURN_TEMPLATE, return_type_str);
-  } else {
+  } else if (return_contents.size() == 1) {
     // Return vector<Tensor> or Tensor
     return_type_str = return_types[0];
     const char* FWD_TENSOR_RETURN_TEMPLATE = " return %s;";
     return_str =
         paddle::string::Sprintf(FWD_TENSOR_RETURN_TEMPLATE, return_contents[0]);
     function_proto_return_type_str = return_type_str;
+  } else {
+    return_str = "return nullptr;";
+    function_proto_return_type_str = "void*";
   }
   generated_function_body += return_str;
   generated_function_body += "\n";
   VLOG(6) << "Generated return codes";
@@ -1139,6 +1131,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   // [Generation] Get Full Function
   std::string function_name = op_type + "_dygraph_function";
+  if (dygraph_function_args_str.size() > 0) {
+    auto iter = dygraph_function_args_str.begin();
+    if ((*iter) == ',') dygraph_function_args_str.erase(iter);
+  }
   const char* FWD_FUNCTION_TEMPLATE = "%s %s(%s) {\n\n%s\n}\n\n";
   std::string fwd_function_str = paddle::string::Sprintf(
       FWD_FUNCTION_TEMPLATE, function_proto_return_type_str, function_name,
@@ -1601,11 +1598,11 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
     /* ---- Collect Information ---- */
     /* ----------------------------- */
     std::vector<std::string> grad_op_types;
+    std::vector<proto::OpProto::Var> in_vars;
+    std::vector<proto::OpProto::Var> out_vars;
     std::map<std::string, std::string> grad_outs_slotname_map;
     std::map<std::string, std::string> grad_ins_fwd_slotname_map;
     std::map<std::string, std::string> grad_ins_grad_slotname_map;
-    std::vector<proto::OpProto::Var> in_vars;
-    std::vector<proto::OpProto::Var> out_vars;
     std::map<std::string,
              std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
         grad_ins;
@@ -1614,20 +1611,31 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
         grad_outs;
     VLOG(6) << "-------- CollectInformationFromOpInfo -------";
-    bool is_available = CollectInformationFromOpInfo(
-        op_info, &grad_op_types, &grad_outs_slotname_map,
-        &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, &in_vars,
-        &out_vars, &grad_ins, &grad_outs);
-    if (!is_available) continue;
+    CollectForwardInformationFromOpInfo(op_info, &in_vars, &out_vars);
+    bool generate_forward_only = false;
+    bool is_available = CollectGradInformationFromOpInfo(
+        op_info, &generate_forward_only, &grad_op_types,
+        &grad_outs_slotname_map, &grad_ins_fwd_slotname_map,
+        &grad_ins_grad_slotname_map, &grad_ins, &grad_outs);
+    if (!is_available && !generate_forward_only) {
+      VLOG(6) << "Skipped operator: " << op_type;
+      continue;
+    }
     VLOG(6) << "-------- PurifyOpProto -------";
     std::unordered_map<std::string, size_t> fwd_inputs_name_pos_map;
     std::unordered_map<std::string, size_t> fwd_outputs_name_pos_map;
-    PurifyOpProto(*op_proto, &fwd_inputs_name_pos_map,
-                  &fwd_outputs_name_pos_map, &grad_outs_slotname_map,
-                  &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map,
-                  &in_vars, &out_vars, &grad_ins, &grad_outs);
+    PurifyForwardOpProto(*op_proto, &fwd_inputs_name_pos_map,
+                         &fwd_outputs_name_pos_map, &in_vars, &out_vars);
+    if (!generate_forward_only) {
+      PurifyGradOpProto(*op_proto, &grad_outs_slotname_map,
+                        &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map,
+                        &grad_ins, &grad_outs);
+    }
     /* --------------------------- */
     /* --------- CodeGen --------- */
@@ -1636,16 +1644,19 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
VLOG(6) << "-------- GenerateForwardFunctionContents -------"; VLOG(6) << "-------- GenerateForwardFunctionContents -------";
std::pair<std::string, std::string> body_and_declaration = std::pair<std::string, std::string> body_and_declaration =
GenerateForwardFunctionContents( GenerateForwardFunctionContents(
fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, generate_forward_only, fwd_inputs_name_pos_map,
grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map, fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars, grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins,
out_vars); grad_outs, op_type, in_vars, out_vars);
fwd_function_str += body_and_declaration.first + "\n"; fwd_function_str += body_and_declaration.first + "\n";
/* ---- dygraph_forward_api.h ---- */ /* ---- dygraph_forward_api.h ---- */
std::string fwd_function_declare_str = body_and_declaration.second; std::string fwd_function_declare_str = body_and_declaration.second;
dygraph_forward_api_str += fwd_function_declare_str; dygraph_forward_api_str += fwd_function_declare_str;
if (generate_forward_only) continue;
/* ---- nodes.h ---- */ /* ---- nodes.h ---- */
VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
grad_node_h_str += grad_node_h_str +=
...@@ -1681,6 +1692,52 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -1681,6 +1692,52 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
GenerateNodeCCFile(output_dir, grad_node_cc_str); GenerateNodeCCFile(output_dir, grad_node_cc_str);
} }
+static void PrepareAttrMapForOps() {
+  // Handle "fused_elemwise_add_activation"
+  std::vector<std::string> functor_list = {"a", "b"};
+  operators_with_attrs["fused_elemwise_add_activation"] = {};
+  operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
+      functor_list;
+  // Handle "fused_elemwise_activation"
+  operators_with_attrs["fused_elemwise_activation"] = {};
+  operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
+      functor_list;
+  // Handle "reverse"
+  std::vector<int> axis = {0};
+  operators_with_attrs["reverse"] = {};
+  operators_with_attrs["reverse"]["axis"] = axis;
+  // Handle "flip"
+  operators_with_attrs["flip"] = {};
+  operators_with_attrs["flip"]["axis"] = axis;
+  // Handle "cast"
+  operators_with_attrs["cast"] = {};
+  operators_with_attrs["cast"]["out_dtype"] = 5;
+  operators_with_attrs["cast"]["in_dtype"] = 5;
+  // Handle "transfer_dtype"
+  operators_with_attrs["transfer_dtype"] = {};
+  operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
+  operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
+}
+static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
+  std::string line;
+  std::ifstream op_list_file(op_list_path);
+  if (op_list_file.is_open()) {
+    while (getline(op_list_file, line)) {
+      operators_to_codegen.insert(line);
+    }
+    op_list_file.close();
+  } else {
+    PADDLE_THROW(
+        paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
+  }
+}
 }  // namespace framework
 }  // namespace paddle
@@ -1693,8 +1750,8 @@ int main(int argc, char* argv[]) {
   std::string eager_root = argv[1];
   std::string op_list_path = argv[2];
-  CollectOperatorsToCodeGen(op_list_path);
-  PrepareAttrMapForOps();
+  paddle::framework::CollectOperatorsToCodeGen(op_list_path);
+  paddle::framework::PrepareAttrMapForOps();
   paddle::framework::DygraphCodeGeneration(eager_root);
...
@@ -215,7 +215,6 @@ spp
 floor
 gelu
 retinanet_detection_output
-minus
 push_dense
 silu
 sequence_erase
...