提交 89b9d86d 编写于 作者: S sneaxiy

fix windows compile bug

test=develop
上级 83ac8515
...@@ -62,7 +62,8 @@ struct DDimPlusVisitor { ...@@ -62,7 +62,8 @@ struct DDimPlusVisitor {
DDim DDim::operator+(const DDim& d) const { DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_); PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret(rank_); DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get())); ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
return ret; return ret;
} }
...@@ -82,7 +83,8 @@ struct DDimMulVisitor { ...@@ -82,7 +83,8 @@ struct DDimMulVisitor {
DDim DDim::operator*(const DDim& d) const { DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_); PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret(rank_); DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimMulVisitor(Get(), d.Get())); ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
return ret; return ret;
} }
...@@ -121,7 +123,9 @@ int64_t product(const DDim& ddim) { ...@@ -121,7 +123,9 @@ int64_t product(const DDim& ddim) {
DDim slice_ddim(const DDim& dim, int begin, int end) { DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0, PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice."); "Begin index can't be less than zero in ddim slice.");
DDim ret(end - begin); int len = end - begin;
DDim ret;
ret.rank_ = len;
dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_); dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
return ret; return ret;
} }
...@@ -156,7 +160,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) { ...@@ -156,7 +160,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
DDim stride(const DDim& ddim) { DDim stride(const DDim& ddim) {
DDim strides(ddim.size()); DDim strides;
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = 1; strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) { for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1]; strides[i] = strides[i + 1] * ddim[i + 1];
...@@ -165,7 +170,8 @@ DDim stride(const DDim& ddim) { ...@@ -165,7 +170,8 @@ DDim stride(const DDim& ddim) {
} }
DDim stride_numel(const DDim& ddim) { DDim stride_numel(const DDim& ddim) {
DDim strides(ddim.size()); DDim strides;
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = ddim[ddim.size() - 1]; strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) { for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i]; strides[i] = strides[i + 1] * ddim[i];
......
...@@ -136,12 +136,6 @@ class DDim { ...@@ -136,12 +136,6 @@ class DDim {
return *reinterpret_cast<const Dim<M>*>(p); return *reinterpret_cast<const Dim<M>*>(p);
} }
// Construct DDim with given rank
// Only used in friend functions
explicit DDim(int rank) : rank_(rank) {
PADDLE_ENFORCE(rank_ >= 0 && rank_ < kMaxRank, "Invalid rank %d", rank);
}
friend DDim slice_ddim(const DDim& dim, int begin, int end); friend DDim slice_ddim(const DDim& dim, int begin, int end);
friend DDim stride(const DDim& ddim); friend DDim stride(const DDim& ddim);
friend DDim stride_numel(const DDim& ddim); friend DDim stride_numel(const DDim& ddim);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册