未验证 提交 a19b082e 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix some bugs hidden in build_cinn_pass. (#46843)

* Fix some bugs hidden in build_cinn_pass.

* Update codes about OpTransInfo.

* Only support for the static reshape/reshape2 op.
上级 4297d05e
...@@ -100,7 +100,7 @@ else() ...@@ -100,7 +100,7 @@ else()
legacy_eager_codegen legacy_eager_codegen
COMMAND COMMAND
${CMAKE_COMMAND} -E env ${CMAKE_COMMAND} -E env
"LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind" "LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind:${PADDLE_BINARY_DIR}/third_party/install/mklml/lib"
"${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${CMAKE_CURRENT_BINARY_DIR}/eager_generator"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
"${CODE_GEN_SPLIT_FILE_COUNT}" "${CODE_GEN_SPLIT_FILE_COUNT}"
......
...@@ -52,6 +52,37 @@ using framework::ir::Node; ...@@ -52,6 +52,37 @@ using framework::ir::Node;
using GraphNodeVec = std::vector<Node*>; using GraphNodeVec = std::vector<Node*>;
using GraphNodeMap = std::unordered_map<Node*, Node*>; using GraphNodeMap = std::unordered_map<Node*, Node*>;
OpTransInfo::OpTransInfo() {
// judgment condition for the dynamic slice
dynamic_op_cond_.emplace("slice", [](const ir::Node& node) -> bool {
if (!node.IsOp()) {
return false;
}
auto* op_desc = node.Op();
auto infer_flags =
op_desc->GetAttrIfExists<std::vector<int>>("infer_flags");
return std::find_if(infer_flags.begin(), infer_flags.end(), [](int v) {
return v < 0;
}) != infer_flags.end();
});
// judgment condition for the dynamic reshape
dynamic_op_cond_.emplace("reshape", [](const ir::Node& node) -> bool {
if (!node.IsOp()) {
return false;
}
auto* op_desc = node.Op();
bool has_shape_tensor = op_desc->Inputs().count("ShapeTensor") &&
op_desc->Inputs().at("ShapeTensor").size();
bool has_shape = op_desc->Inputs().count("Shape") &&
op_desc->Inputs().at("Shape").size();
return has_shape_tensor || has_shape;
});
// judgment condition for the dynamic reshape2
dynamic_op_cond_.emplace("reshape2", dynamic_op_cond_.at("reshape"));
}
std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const GraphNodeSet& cluster) const { const GraphNodeSet& cluster) const {
std::unordered_set<std::string> deny_var_set; std::unordered_set<std::string> deny_var_set;
...@@ -67,16 +98,16 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -67,16 +98,16 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
}; };
for (auto* op : cluster) { for (auto* op : cluster) {
if (deny_param_cond.count(op->Name())) { if (deny_param_cond_.count(op->Name())) {
const auto* desc = op->Op(); const auto* desc = op->Op();
PADDLE_ENFORCE_NE(desc, PADDLE_ENFORCE_NE(desc,
nullptr, nullptr,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The Op %s's OpDesc should not be NULL, which has " "The Op %s's OpDesc should not be NULL, which has "
"a parameter in deny_param_cond.", "a parameter in deny_param_cond_.",
op->Name().c_str())); op->Name().c_str()));
auto deny_param_names = deny_param_cond.at(op->Name()); auto deny_param_names = deny_param_cond_.at(op->Name());
VLOG(4) << "We found deny param " << get_debug_info(deny_param_names) VLOG(4) << "We found deny param " << get_debug_info(deny_param_names)
<< " in op [" << op->Name() << "]."; << " in op [" << op->Name() << "].";
...@@ -107,6 +138,15 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -107,6 +138,15 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
return deny_var_set; return deny_var_set;
} }
bool OpTransInfo::IsInplaceOp(const OpDesc& op_desc) {
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0) return true;
}
return false;
}
namespace { namespace {
// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops // The delim(`;`) that is used to split the FLAGS_allow_cinn_ops
// & FLAGS_deny_cinn_ops. // & FLAGS_deny_cinn_ops.
...@@ -539,15 +579,6 @@ void ReplaceSubGraphWithCinnOpNode( ...@@ -539,15 +579,6 @@ void ReplaceSubGraphWithCinnOpNode(
RemoveSubGraphFromGraph(cluster, cluster_internals, graph); RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
} }
static bool IsInplaceOp(const OpDesc& op_desc) {
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0) return true;
}
return false;
}
// Search all subgraphs which all op node supported by CINN, // Search all subgraphs which all op node supported by CINN,
// Here we using SubgraphDetector to detecte the subgraph that // Here we using SubgraphDetector to detecte the subgraph that
// all of op node supported by CINN. We using OpMapperRegistry // all of op node supported by CINN. We using OpMapperRegistry
...@@ -562,24 +593,27 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -562,24 +593,27 @@ void SearchAllSubgraphs(Graph* graph) {
node_name) != nullptr; node_name) != nullptr;
// skip the dynamic ops // skip the dynamic ops
bool is_dynamic = false; bool is_dynamic = false;
if (trans_info.dynamic_op_cond.count(node_name)) { if (trans_info.dynamic_op_cond().count(node_name)) {
is_dynamic = trans_info.dynamic_op_cond.at(node_name)(node); is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node);
} }
bool is_support =
registered && !trans_info.default_deny_ops().count(node_name) &&
!is_dynamic && (node->IsOp() && !trans_info.IsInplaceOp(*node->Op()));
// if the op type is registered in CINN and allow_ops is not empty, return // if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops // true only when it is in allow_ops
if (!allow_ops.empty()) { if (!allow_ops.empty()) {
return registered && !is_dynamic && allow_ops.count(node_name); return is_support && allow_ops.count(node_name);
} }
// if the op type is registered in CINN and deny_ops is not empty, return // if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops // true only when it is not in deny_ops
if (!deny_ops.empty()) { if (!deny_ops.empty()) {
return registered && !is_dynamic && !deny_ops.count(node_name); return is_support && !deny_ops.count(node_name);
} }
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN // return true only when it is registered in CINN
return registered && !trans_info.default_deny_ops.count(node_name) && return is_support;
!is_dynamic && (node->IsOp() && !IsInplaceOp(*node->Op()));
}; };
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
......
...@@ -21,8 +21,6 @@ limitations under the License. */ ...@@ -21,8 +21,6 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -46,33 +44,37 @@ using Name2VarInfoMap = ...@@ -46,33 +44,37 @@ using Name2VarInfoMap =
std::shared_ptr<framework::ir::MemOptVarInfo>>; std::shared_ptr<framework::ir::MemOptVarInfo>>;
using GraphNodeSet = std::unordered_set<ir::Node*>; using GraphNodeSet = std::unordered_set<ir::Node*>;
struct OpTransInfo { // OpTransInfo contains informations used to detect subgraphs
const std::unordered_set<std::string> default_deny_ops{"feed", "fetch"}; // supported by the CINN compiler.
class OpTransInfo {
using DyOpCondT =
std::unordered_map<std::string, std::function<bool(const ir::Node&)>>;
using DeParamCondT =
std::unordered_map<std::string, std::unordered_set<std::string>>;
const std::unordered_map<std::string, std::function<bool(const ir::Node*)>> public:
dynamic_op_cond{ OpTransInfo();
{"slice", [](const ir::Node* node) -> bool {
if (!node->IsOp()) {
return false;
}
auto* op_desc = node->Op();
auto infer_flags =
op_desc->GetAttrIfExists<std::vector<int>>("infer_flags");
if (std::find_if(
infer_flags.begin(), infer_flags.end(), [](int v) {
return v < 0;
}) != infer_flags.end()) {
return true;
}
return false;
}}};
const std::unordered_map<std::string, std::unordered_set<std::string>> const DyOpCondT& dynamic_op_cond() const { return dynamic_op_cond_; }
deny_param_cond{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}}; const DeParamCondT& deny_param_cond() const { return deny_param_cond_; }
const std::unordered_set<std::string>& default_deny_ops() const {
return default_deny_ops_;
}
std::unordered_set<std::string> GetDenyVarNames( std::unordered_set<std::string> GetDenyVarNames(
const GraphNodeSet& cluster) const; const GraphNodeSet& cluster) const;
static bool IsInplaceOp(const OpDesc& op_desc);
private:
DyOpCondT dynamic_op_cond_;
DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"};
}; };
// A pass named BuildCinnPass, the function of this pass is: // A pass named BuildCinnPass, the function of this pass is:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册