/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace framework { DDim make_ddim(std::initializer_list dims) { return DDim(dims.begin(), dims.size()); } DDim make_ddim(const std::vector& dims) { return DDim(dims.data(), dims.size()); } DDim make_ddim(const std::vector& dims) { return DDim(dims.data(), dims.size()); } struct DDimEqualityVisitor { explicit DDimEqualityVisitor(const int64_t* d) : d_(d) {} template inline bool operator()(const Dim& self) const { return UnrollCompare::Run(self.Get(), d_); } const int64_t* d_; }; bool DDim::operator==(const DDim& d) const { return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.Get())); } bool DDim::operator!=(const DDim& d) const { return !(*this == d); } struct DDimPlusVisitor { explicit DDimPlusVisitor(const int64_t* d1, const int64_t* d2) : d1_(d1), d2_(d2) {} template inline void operator()(Dim& self) const { UnrollAdd::Run(d1_, d2_, self.GetMutable()); } const int64_t* d1_; const int64_t* d2_; }; DDim DDim::operator+(const DDim& d) const { PADDLE_ENFORCE(rank_ == d.rank_); DDim ret; ret.rank_ = rank_; ret.apply_visitor(DDimPlusVisitor(Get(), d.Get())); return ret; } struct DDimMulVisitor { explicit DDimMulVisitor(const int64_t* d1, const int64_t* d2) : d1_(d1), d2_(d2) {} template inline void operator()(Dim& self) const { UnrollMul::Run(d1_, d2_, self.GetMutable()); } const int64_t* d1_; const int64_t* d2_; }; DDim DDim::operator*(const DDim& d) const { PADDLE_ENFORCE(rank_ == d.rank_); DDim ret; ret.rank_ = rank_; ret.apply_visitor(DDimMulVisitor(Get(), d.Get())); return ret; } int64_t get(const DDim& ddim, int idx) { return ddim[idx]; } void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } // NOLINT std::vector vectorize(const DDim& ddim) { std::vector result(DDim::kMaxRank); dynamic_dim_assign(ddim.Get(), result.data(), ddim.size()); result.resize(ddim.size()); return result; } // NOTE: framework::vectorize converts to type int64_t // which does not fit cudnn inputs. std::vector vectorize2int(const DDim& ddim) { std::vector result(DDim::kMaxRank); dynamic_dim_assign(ddim.Get(), result.data(), ddim.size()); result.resize(ddim.size()); return result; } struct ProductVisitor { template inline int64_t operator()(const Dim& dim) { return product(dim); } }; int64_t product(const DDim& ddim) { return ddim.apply_visitor(ProductVisitor()); } DDim slice_ddim(const DDim& dim, int begin, int end) { PADDLE_ENFORCE(begin >= 0, "Begin index can't be less than zero in ddim slice."); int len = end - begin; DDim ret; ret.rank_ = len; dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_); return ret; } int arity(const DDim& d) { return d.size(); } /// \cond HIDDEN struct DDimPrinter { std::ostream& os; explicit DDimPrinter(std::ostream& os_) : os(os_) {} template void operator()(const T& t) { os << t; } }; /// \endcond std::ostream& operator<<(std::ostream& os, const DDim& ddim) { ddim.apply_visitor(DDimPrinter(os)); return os; } DDim flatten_to_2d(const DDim& src, int num_col_dims) { int rank = src.size(); return make_ddim({product(slice_ddim(src, 0, num_col_dims)), product(slice_ddim(src, num_col_dims, rank))}); } DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } DDim stride(const DDim& ddim) { DDim strides; strides.rank_ = ddim.size(); strides[ddim.size() - 1] = 1; for (int i = ddim.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * ddim[i + 1]; } return strides; } DDim stride_numel(const DDim& ddim) { DDim strides; strides.rank_ = ddim.size(); strides[ddim.size() - 1] = ddim[ddim.size() - 1]; for (int i = ddim.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * ddim[i]; } return strides; } } // namespace framework } // namespace paddle