From f5830c0526f2a5e279967c03b95736ac512e36d6 Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Wed, 26 Jul 2023 12:54:28 +0800
Subject: [PATCH] [0D-Tensor] CINN supports `fill_constant`, fix infershape and
 pass (#55563)

* [0D-Tensor] CINN supports fill_constant, fix infershape and pass

* fix infershape of fill_constant

* add back fill_constant to zero_tensor_trick_pass
---
 .../frontend/pass/expand_zero_dim_pass.cc | 17 ++++++++++++
 paddle/cinn/hlir/op/elementwise.cc        |  1 -
 test/cinn/ops/test_zero_dim_tensor.py     | 27 +++++++++++++++++++
 3 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc b/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
index a4212555fc6..9732478c75b 100644
--- a/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
+++ b/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
@@ -39,6 +39,9 @@ class ExpandZeroDimPass : public ProgramPass {
       if (instr->op_type == "transpose") {
         builder.AppendInstruction(HandleTranspose(instr));
         continue;
+      } else if (instr->op_type == "fill_constant") {
+        builder.AppendInstruction(HandleFillConstant(instr));
+        continue;
       }
       for (auto& input : instr->inputs) {
         if (input->shape.empty()) {
@@ -101,6 +104,20 @@ class ExpandZeroDimPass : public ProgramPass {
     }
     return new_instr;
   }
+
+  // Before: out-0D = fill_constant([], 123.456, "out", "float32")
+  // After:  out-1D = fill_constant([1], 123.456, "out", "float32")
+  Instruction HandleFillConstant(const Instruction& instr) {
+    Instruction new_instr = instr;
+    std::vector<int> shape =
+        new_instr.GetAttrs<std::vector<int>>("shape");
+    if (shape.empty()) {
+      shape.push_back(1);
+      VLOG(4) << "Change fill_constant's attribute shape from [] to [1]";
+    }
+    new_instr.SetAttr<std::vector<int>>("shape", shape);
+    return new_instr;
+  }
 };
 
 }  // namespace pass
diff --git a/paddle/cinn/hlir/op/elementwise.cc b/paddle/cinn/hlir/op/elementwise.cc
index 27ac596de87..f1ac4fb2352 100644
--- a/paddle/cinn/hlir/op/elementwise.cc
+++ b/paddle/cinn/hlir/op/elementwise.cc
@@ -393,7 +393,6 @@ std::vector<std::vector<int>> InferShapeForFillConstant(
     const framework::AttrMapType &attrs) {
   CHECK(attrs.count("shape"));
   auto shape = absl::get<std::vector<int>>(attrs.at("shape"));
-  CHECK(!shape.empty()) << "shape attr is empty!";
   return {shape};
 }
 
diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py
index 86e57e7e2fe..0b90ac4b80d 100644
--- a/test/cinn/ops/test_zero_dim_tensor.py
+++ b/test/cinn/ops/test_zero_dim_tensor.py
@@ -1318,6 +1318,33 @@ class TestReshapeOp1DTo0D(TestReshapeOp):
         self.target_shape = []
 
 
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestFillConstantOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        out = paddle.full([], 123.456, "float32")
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("fill_constant_op")
+        out = builder.fill_constant([], 123.456, "out", "float32")
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [], [], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
 @OpTestTool.skip_if(
     not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
 )
-- 
GitLab