提交 3d62c6da 编写于 作者: F fengjiayi

Fix bug

上级 823bdd67
...@@ -288,14 +288,11 @@ DDim::DDim(std::initializer_list<int64_t> init_list) { ...@@ -288,14 +288,11 @@ DDim::DDim(std::initializer_list<int64_t> init_list) {
// will be the product of tensor's first `num_col_dims` dimensions // will be the product of tensor's first `num_col_dims` dimensions
DDim flatten_to_2d(const DDim& src, int num_col_dims) { DDim flatten_to_2d(const DDim& src, int num_col_dims) {
int rank = src.size(); int rank = src.size();
return make_ddim( return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
{static_cast<int>(product(slice_ddim(src, 0, num_col_dims))), product(slice_ddim(src, num_col_dims, rank))});
static_cast<int>(product(slice_ddim(src, num_col_dims, rank)))});
} }
DDim flatten_to_1d(const DDim& src) { DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
return make_ddim({static_cast<int>(product(src))});
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册