diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc index b324e1fd25961fbb5d4bd6133a8446c32f0b732c..dbd056df69541503c65af4a6c72027c3ccefa1b4 100644 --- a/paddle/cinn/ir/tensor.cc +++ b/paddle/cinn/ir/tensor.cc @@ -553,14 +553,17 @@ ir::Tensor _Tensor_::Reshape(const std::vector &shape, auto selft = Tensor(const_cast(this)); { - Expr this_num_elements = Expr(1); - for (auto &e : this->shape) this_num_elements = this_num_elements * e; + int32_t this_num_elements = 1; + for (auto &e : this->shape) { + this_num_elements = this_num_elements * e.as_int32(); + } - Expr num_elements = Expr(1); - for (auto &e : shape) num_elements = num_elements * e; + int32_t num_elements = 1; + for (auto &e : shape) { + num_elements = num_elements * e.as_int32(); + } - CHECK(MathIsZero(this_num_elements - num_elements)) - << "number of elements mismatch"; + CHECK_EQ(this_num_elements, num_elements) << "number of elements mismatch."; } n->name = Context::Global().NewName(name + "_reshape");