From f736f15160aaf57141b28a19d32dad255223a1ef Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Mon, 17 Jul 2023 10:42:37 +0800
Subject: [PATCH] [0D-Tensor] CINN supports unsqueeze, delete hack in Paddle's
 pass (#55336)

---
 paddle/cinn/hlir/op/elementwise.cc    |  4 +-
 .../cinn_zero_tensor_trick_pass.cc    | 26 ---------
 test/cinn/ops/test_zero_dim_tensor.py | 54 +++++++++++++++++++
 3 files changed, 56 insertions(+), 28 deletions(-)

diff --git a/paddle/cinn/hlir/op/elementwise.cc b/paddle/cinn/hlir/op/elementwise.cc
index c225ea48118..f1e2616ecd9 100644
--- a/paddle/cinn/hlir/op/elementwise.cc
+++ b/paddle/cinn/hlir/op/elementwise.cc
@@ -755,8 +755,8 @@ std::shared_ptr<OpStrategy> StrategyForExpandDims(
 std::vector<std::vector<int>> InferShapeForExpandDims(
     const std::vector<std::vector<int>> &inputs_shape,
     const framework::AttrMapType &attrs) {
-  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty())
-      << "The input's shape size is 0! Please check again.";
+  CHECK(!inputs_shape.empty())
+      << "At least 1 input tensor for expand_dims operator.";
   CHECK_EQ(inputs_shape.size(), 1U);
 
   const std::vector<int> &axes =
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc
index e1833609ef4..678c49e6106 100644
--- a/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_zero_tensor_trick_pass.cc
@@ -32,32 +32,6 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
       "assign_value",
       "gaussian_random",
       "set_value"};
-  // NOTE: Hack squeeze2 0D-Tensor input
-  // If squeeze2 inputs 0D-Tensor and axes, The 0D-Tensor's shape will convert
-  // to 1D-Tensor, which could lead error. We hack squeeze2's axes attribute to
-  // resolve this. Change 0D-Tensor input to 1D-Tensor input and then make
-  // axes->axes[: -1]
-  for (const ir::Node* n : graph->Nodes()) {
-    if (n->IsOp() && n->Op()->Type() == "unsqueeze2") {
-      if (n->Op()->HasAttr("axes")) {
-        auto axes =
-            PADDLE_GET_CONST(std::vector<int>, n->Op()->GetAttr("axes"));
-        for (const ir::Node* var : n->inputs) {
-          if (var->Var() &&
-              var->Var()->GetType() == proto::VarType::LOD_TENSOR) {
-            std::vector<int64_t> shape = var->Var()->GetShape();
-            if (shape.empty()) {
-              axes.pop_back();
-              n->Op()->SetAttr("axes", axes);
-              VLOG(4) << "unsqueeze2 axes dims is full, fix dim -> dim[:-1] to "
-                         "avoid 0D-Tensor input error";
-            }
-          }
-        }
-      }
-    }
-  }
-
   // CINN ops in this white list support 0D-Tensor, wait-list = {"remainder"}
   const std::unordered_set<std::string> white_op_list{"elementwise_add",
                                                       "elementwise_sub",
diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py
index 26a7ce52375..2b5ae1f5cac 100644
--- a/test/cinn/ops/test_zero_dim_tensor.py
+++ b/test/cinn/ops/test_zero_dim_tensor.py
@@ -668,6 +668,60 @@ class TestTransposeOp(OpTest):
         self.check_outputs_and_grads()
 
 
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestExpandDimsOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.unsqueeze_dim = [0]
+        self.target_shape = (1,)
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.unsqueeze(x, self.unsqueeze_dim)
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("unsqueeze_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.expand_dims(x, self.unsqueeze_dim)
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestExpandDimsOp2D(TestExpandDimsOp):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.unsqueeze_dim = [0, 1]
+        self.target_shape = (
+            1,
+            1,
+        )
+
+
 @OpTestTool.skip_if(
     not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
 )
-- 
GitLab