未验证 提交 a19b082e 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix some bugs hidden in build_cinn_pass. (#46843)

* Fix some bugs hidden in build_cinn_pass.

* Update codes about OpTransInfo.

* Only support for the static reshape/reshape2 op.
上级 4297d05e
......@@ -100,7 +100,7 @@ else()
legacy_eager_codegen
COMMAND
${CMAKE_COMMAND} -E env
"LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind"
"LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind:${PADDLE_BINARY_DIR}/third_party/install/mklml/lib"
"${CMAKE_CURRENT_BINARY_DIR}/eager_generator"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
"${CODE_GEN_SPLIT_FILE_COUNT}"
......
......@@ -52,6 +52,37 @@ using framework::ir::Node;
using GraphNodeVec = std::vector<Node*>;
using GraphNodeMap = std::unordered_map<Node*, Node*>;
OpTransInfo::OpTransInfo() {
// judgment condition for the dynamic slice
dynamic_op_cond_.emplace("slice", [](const ir::Node& node) -> bool {
if (!node.IsOp()) {
return false;
}
auto* op_desc = node.Op();
auto infer_flags =
op_desc->GetAttrIfExists<std::vector<int>>("infer_flags");
return std::find_if(infer_flags.begin(), infer_flags.end(), [](int v) {
return v < 0;
}) != infer_flags.end();
});
// judgment condition for the dynamic reshape
dynamic_op_cond_.emplace("reshape", [](const ir::Node& node) -> bool {
if (!node.IsOp()) {
return false;
}
auto* op_desc = node.Op();
bool has_shape_tensor = op_desc->Inputs().count("ShapeTensor") &&
op_desc->Inputs().at("ShapeTensor").size();
bool has_shape = op_desc->Inputs().count("Shape") &&
op_desc->Inputs().at("Shape").size();
return has_shape_tensor || has_shape;
});
// judgment condition for the dynamic reshape2
dynamic_op_cond_.emplace("reshape2", dynamic_op_cond_.at("reshape"));
}
std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const GraphNodeSet& cluster) const {
std::unordered_set<std::string> deny_var_set;
......@@ -67,16 +98,16 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
};
for (auto* op : cluster) {
if (deny_param_cond.count(op->Name())) {
if (deny_param_cond_.count(op->Name())) {
const auto* desc = op->Op();
PADDLE_ENFORCE_NE(desc,
nullptr,
platform::errors::PreconditionNotMet(
"The Op %s's OpDesc should not be NULL, which has "
"a parameter in deny_param_cond.",
"a parameter in deny_param_cond_.",
op->Name().c_str()));
auto deny_param_names = deny_param_cond.at(op->Name());
auto deny_param_names = deny_param_cond_.at(op->Name());
VLOG(4) << "We found deny param " << get_debug_info(deny_param_names)
<< " in op [" << op->Name() << "].";
......@@ -107,6 +138,15 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
return deny_var_set;
}
bool OpTransInfo::IsInplaceOp(const OpDesc& op_desc) {
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0) return true;
}
return false;
}
namespace {
// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops
// & FLAGS_deny_cinn_ops.
......@@ -539,15 +579,6 @@ void ReplaceSubGraphWithCinnOpNode(
RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
}
static bool IsInplaceOp(const OpDesc& op_desc) {
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0) return true;
}
return false;
}
// Search all subgraphs which all op node supported by CINN,
// Here we using SubgraphDetector to detecte the subgraph that
// all of op node supported by CINN. We using OpMapperRegistry
......@@ -562,24 +593,27 @@ void SearchAllSubgraphs(Graph* graph) {
node_name) != nullptr;
// skip the dynamic ops
bool is_dynamic = false;
if (trans_info.dynamic_op_cond.count(node_name)) {
is_dynamic = trans_info.dynamic_op_cond.at(node_name)(node);
if (trans_info.dynamic_op_cond().count(node_name)) {
is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node);
}
bool is_support =
registered && !trans_info.default_deny_ops().count(node_name) &&
!is_dynamic && (node->IsOp() && !trans_info.IsInplaceOp(*node->Op()));
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (!allow_ops.empty()) {
return registered && !is_dynamic && allow_ops.count(node_name);
return is_support && allow_ops.count(node_name);
}
// if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops
if (!deny_ops.empty()) {
return registered && !is_dynamic && !deny_ops.count(node_name);
return is_support && !deny_ops.count(node_name);
}
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN
return registered && !trans_info.default_deny_ops.count(node_name) &&
!is_dynamic && (node->IsOp() && !IsInplaceOp(*node->Op()));
return is_support;
};
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
......
......@@ -21,8 +21,6 @@ limitations under the License. */
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace framework {
......@@ -46,33 +44,37 @@ using Name2VarInfoMap =
std::shared_ptr<framework::ir::MemOptVarInfo>>;
using GraphNodeSet = std::unordered_set<ir::Node*>;
struct OpTransInfo {
const std::unordered_set<std::string> default_deny_ops{"feed", "fetch"};
// OpTransInfo contains informations used to detect subgraphs
// supported by the CINN compiler.
class OpTransInfo {
using DyOpCondT =
std::unordered_map<std::string, std::function<bool(const ir::Node&)>>;
using DeParamCondT =
std::unordered_map<std::string, std::unordered_set<std::string>>;
const std::unordered_map<std::string, std::function<bool(const ir::Node*)>>
dynamic_op_cond{
{"slice", [](const ir::Node* node) -> bool {
if (!node->IsOp()) {
return false;
}
auto* op_desc = node->Op();
auto infer_flags =
op_desc->GetAttrIfExists<std::vector<int>>("infer_flags");
if (std::find_if(
infer_flags.begin(), infer_flags.end(), [](int v) {
return v < 0;
}) != infer_flags.end()) {
return true;
}
return false;
}}};
public:
OpTransInfo();
const std::unordered_map<std::string, std::unordered_set<std::string>>
deny_param_cond{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
const DyOpCondT& dynamic_op_cond() const { return dynamic_op_cond_; }
const DeParamCondT& deny_param_cond() const { return deny_param_cond_; }
const std::unordered_set<std::string>& default_deny_ops() const {
return default_deny_ops_;
}
std::unordered_set<std::string> GetDenyVarNames(
const GraphNodeSet& cluster) const;
static bool IsInplaceOp(const OpDesc& op_desc);
private:
DyOpCondT dynamic_op_cond_;
DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"};
};
// A pass named BuildCinnPass, the function of this pass is:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册