/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include // for std::move #include #include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_key.h" #include "paddle/fluid/operators/jit/kernel_pool.h" #include "paddle/fluid/platform/place.h" namespace paddle { namespace operators { namespace jit { template inline typename std::enable_if< std::is_same::value && std::is_same::value, const Kernel*>::type GetJitCode(const typename KernelTuple::attr_type& attr) { using Attr = typename KernelTuple::attr_type; int64_t key = JitCodeKey(attr); auto& codes = JitCodePool::Instance(); if (codes.Has(key)) { return codes.AllKernels().at(key).get(); } // creator is not related with attr, so can use KernelKey as key KernelKey kkey(KernelTuple::kernel_type, PlaceType()); // pool: (KernelKey(type, place), vector) auto& creator_map = JitCodeCreatorPool::Instance().AllCreators(); auto iter = creator_map.find(kkey); if (iter != creator_map.end()) { auto& creators = iter->second; for (auto& cur : creators) { auto i = dynamic_cast*>(cur.get()); if (i && i->CanBeUsed(attr)) { auto p = i->CreateJitCode(attr); if (p) { auto res = p.get(); codes.Insert(key, std::move(p)); return res; } } } } return nullptr; } template inline typename std::enable_if< !std::is_same::value || !std::is_same::value, const Kernel*>::type GetJitCode(const typename KernelTuple::attr_type& attr) { return nullptr; } // Refer code do not related with attr, which is just for cast // Refer is always on CPUPlace template inline const Kernel* GetReferKernel() { auto& ref_pool = ReferKernelPool::Instance().AllKernels(); KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace()); auto ref_iter = ref_pool.find(kkey); PADDLE_ENFORCE(ref_iter != ref_pool.end(), "Every Kernel should have reference function."); auto& ref_impls = ref_iter->second; for (auto& impl : ref_impls) { auto i = dynamic_cast*>(impl.get()); if (i) { return i; } } return nullptr; } template inline typename KernelTuple::func_type GetReferFunc() { auto ker = GetReferKernel(); auto p = dynamic_cast*>(ker); PADDLE_ENFORCE(p, "The Refer kernel should exsit"); return p->GetFunc(); } // Return all Kernels that can be used template std::vector GetAllCandidateKernels( const typename KernelTuple::attr_type& attr) { // the search order shoudl be jitcode > more > refer std::vector res; auto jitker = GetJitCode(attr); if (jitker) { res.emplace_back(jitker); } // more kernelpool: (KernelKey(type, place), vector) KernelKey kkey(KernelTuple::kernel_type, PlaceType()); auto& pool = KernelPool::Instance().AllKernels(); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { auto i = dynamic_cast*>(impl.get()); if (i && i->CanBeUsed(attr)) { res.emplace_back(i); } } } // The last implementation should be reference function on CPUPlace. auto ref = GetReferKernel(); PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty."); res.emplace_back(ref); return res; } template std::vector> GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) { using Func = typename KernelTuple::func_type; auto kers = GetAllCandidateKernels(attr); std::vector> res; for (auto k : kers) { std::string name = k->ImplType(); if (name == "JitCode") { auto i = dynamic_cast(k); PADDLE_ENFORCE(i, "jitcode kernel cast can not fail."); res.emplace_back(std::make_pair(name, i->template getCode())); } else { auto i = dynamic_cast*>(k); PADDLE_ENFORCE(i, "kernel cast can not fail."); res.emplace_back(std::make_pair(name, i->GetFunc())); } } return res; } template std::vector GetAllCandidateFuncs( const typename KernelTuple::attr_type& attr) { auto funcs = GetAllCandidateFuncsWithTypes(attr); std::vector res; for (auto& i : funcs) { res.emplace_back(i.second); } return res; } template typename KernelTuple::func_type GetDefaultBestFunc( const typename KernelTuple::attr_type& attr) { auto funcs = GetAllCandidateFuncs(attr); PADDLE_ENFORCE_GE(funcs.size(), 1UL); // Here could do some runtime benchmark of this attr and return the best one. // But yet just get the first one as the default best one, // which is searched in order and tuned by offline. return funcs[0]; } template class KernelFuncs { public: KernelFuncs() = default; static KernelFuncs& Cache() { static thread_local KernelFuncs g_func_cache; return g_func_cache; } // the exposed interface to use typename KernelTuple::func_type At( const typename KernelTuple::attr_type& attr) { // Maybe here is not good enough, not all kernels should have jitcode int64_t key = JitCodeKey(attr); if (Has(key)) { return funcs_.at(key); } // If do not have this attr in cache then get the default best auto func = GetDefaultBestFunc(attr); Insert(key, func); return func; } typename KernelTuple::func_type operator[]( const typename KernelTuple::attr_type& attr) { return At(attr); } protected: bool Has(int64_t key) const { return funcs_.find(key) != funcs_.end(); } void Insert(int64_t key, typename KernelTuple::func_type func) { funcs_.emplace(key, func); } private: std::unordered_map funcs_; DISABLE_COPY_AND_ASSIGN(KernelFuncs); }; const char* to_string(KernelType kt); const char* to_string(SeqPoolType kt); KernelType to_kerneltype(const std::string& act); inline std::ostream& operator<<(std::ostream& os, const lstm_attr_t& attr) { os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate) << "],act_cand[" << to_string(attr.act_cand) << "],act_cell[" << to_string(attr.act_cell) << "],use_peephole[" << (attr.use_peephole ? "True" : "False") << "]"; return os; } inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) { os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate) << "],act_cand[" << to_string(attr.act_cand) << "]"; return os; } inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) { os << "height_size[" << attr.h << "],width_size[" << attr.w << "],pool_type[" << to_string(attr.type) << "]"; return os; } inline std::ostream& operator<<(std::ostream& os, const emb_seq_pool_attr_t& attr) { os << "table_height[" << attr.table_height << "],table_width[" << attr.table_width << "],index_height[" << attr.index_height << "],index_width[" << attr.index_width << "],output_width[" << attr.out_width << "],pool_type[" << to_string(attr.pool_type) << "]"; return os; } inline std::ostream& operator<<(std::ostream& os, const sgd_attr_t& attr) { os << "param_height[" << attr.param_height << "],param_width[" << attr.param_width << "],grad_height[" << attr.grad_height << "],grad_width[" << attr.grad_width << "],selected_rows_size[" << attr.selected_rows_size << "]"; return os; } inline std::ostream& operator<<(std::ostream& os, const matmul_attr_t& attr) { os << "M[" << attr.m << "],N[" << attr.n << "],K[" << attr.k << "]"; return os; } // expose the method to pack matmul weight template void pack_weights(const T* src, T* dst, int n, int k); } // namespace jit } // namespace operators } // namespace paddle