diff --git a/src/operators/flatten_op.cpp b/src/operators/flatten_op.cpp index 844053b21c734849fc65bbe9d2d91b400f4d931b..4e52485345b4f891738e1de147e53746e03928c3 100644 --- a/src/operators/flatten_op.cpp +++ b/src/operators/flatten_op.cpp @@ -36,7 +36,7 @@ void FlattenOp::InferShape() const { "The axis should be less than or equal to input tensor's rank."); const auto &out_dims = GetOutputShape(axis, in_dims); - this->param_.Out()->Resize(in_dims); + this->param_.Out()->Resize(framework::make_ddim(out_dims)); } } // namespace operators