未验证 提交 09f8e31d 编写于 作者: H HongyuJia 提交者: GitHub

[CINN Support 0D-Tensor] CINN hack squeeze2 with trick temporarily (#53454)

上级 d611e48c
...@@ -32,6 +32,32 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const { ...@@ -32,6 +32,32 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
"assign_value", "assign_value",
"gaussian_random", "gaussian_random",
"set_value"}; "set_value"};
// NOTE: Hack squeeze2 0D-Tensor input
// If squeeze2 inputs 0D-Tensor and axes, The 0D-Tensor's shape will convert
// to 1D-Tensor, which could lead error. We hack squeeze2's axes attribute to
// resolve this. Change 0D-Tensor input to 1D-Tensor input and then make
// axes->axes[: -1]
for (const ir::Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->Type() == "unsqueeze2") {
if (n->Op()->HasAttr("axes")) {
auto axes =
PADDLE_GET_CONST(std::vector<int32_t>, n->Op()->GetAttr("axes"));
for (const ir::Node* var : n->inputs) {
if (var->Var() &&
var->Var()->GetType() == proto::VarType::LOD_TENSOR) {
std::vector<int64_t> shape = var->Var()->GetShape();
if (shape.empty()) {
axes.pop_back();
n->Op()->SetAttr("axes", axes);
VLOG(4) << "unsqueeze2 axes dims is full, fix dim -> dim[:-1] to "
"avoid 0D-Tensor input error";
}
}
}
}
}
}
for (const ir::Node* n : graph->Nodes()) { for (const ir::Node* n : graph->Nodes()) {
if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) { if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) {
if (n->Op()->HasAttr("shape")) { if (n->Op()->HasAttr("shape")) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册