diff --git a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc index 9c4e6192be424230cca1df798e520e280c48d3d1..de84742146cabbaf10c6f194d930cf7008f4c134 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc @@ -32,6 +32,32 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const { "assign_value", "gaussian_random", "set_value"}; + // NOTE: Hack squeeze2 0D-Tensor input + // If squeeze2 inputs 0D-Tensor and axes, The 0D-Tensor's shape will convert + // to 1D-Tensor, which could lead error. We hack squeeze2's axes attribute to + // resolve this. Change 0D-Tensor input to 1D-Tensor input and then make + // axes->axes[: -1] + for (const ir::Node* n : graph->Nodes()) { + if (n->IsOp() && n->Op()->Type() == "unsqueeze2") { + if (n->Op()->HasAttr("axes")) { + auto axes = + PADDLE_GET_CONST(std::vector, n->Op()->GetAttr("axes")); + for (const ir::Node* var : n->inputs) { + if (var->Var() && + var->Var()->GetType() == proto::VarType::LOD_TENSOR) { + std::vector shape = var->Var()->GetShape(); + if (shape.empty()) { + axes.pop_back(); + n->Op()->SetAttr("axes", axes); + VLOG(4) << "unsqueeze2 axes dims is full, fix dim -> dim[:-1] to " + "avoid 0D-Tensor input error"; + } + } + } + } + } + } + for (const ir::Node* n : graph->Nodes()) { if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) { if (n->Op()->HasAttr("shape")) {