From 09f8e31dbcc43c853201a755aea653a3003b2029 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Fri, 28 Apr 2023 20:55:23 +0800 Subject: [PATCH] [CINN Support 0D-Tensor] CINN hack squeeze2 with trick temporarily (#53454) --- .../cinn_zero_tensor_trick_pass.cc | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc index 9c4e6192be4..de84742146c 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc @@ -32,6 +32,32 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const { "assign_value", "gaussian_random", "set_value"}; + // NOTE: Hack squeeze2 0D-Tensor input + // If squeeze2 inputs 0D-Tensor and axes, The 0D-Tensor's shape will convert + // to 1D-Tensor, which could lead error. We hack squeeze2's axes attribute to + // resolve this. Change 0D-Tensor input to 1D-Tensor input and then make + // axes->axes[: -1] + for (const ir::Node* n : graph->Nodes()) { + if (n->IsOp() && n->Op()->Type() == "unsqueeze2") { + if (n->Op()->HasAttr("axes")) { + auto axes = + PADDLE_GET_CONST(std::vector, n->Op()->GetAttr("axes")); + for (const ir::Node* var : n->inputs) { + if (var->Var() && + var->Var()->GetType() == proto::VarType::LOD_TENSOR) { + std::vector shape = var->Var()->GetShape(); + if (shape.empty()) { + axes.pop_back(); + n->Op()->SetAttr("axes", axes); + VLOG(4) << "unsqueeze2 axes dims is full, fix dim -> dim[:-1] to " + "avoid 0D-Tensor input error"; + } + } + } + } + } + } + for (const ir::Node* n : graph->Nodes()) { if (n->IsOp() && op_cases_fix_attr.count(n->Op()->Type())) { if (n->Op()->HasAttr("shape")) { -- GitLab