/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include // for shared_ptr #include #include #include #include "paddle/fluid/operators/jit/gen_base.h" #include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_key.h" #include "paddle/fluid/platform/place.h" namespace paddle { namespace operators { namespace jit { template class JitCodePool { typedef std::unique_ptr GenBasePtr; typedef std::unordered_map JitCodeMap; public: JitCodePool() = default; static JitCodePool& Instance() { static thread_local JitCodePool g_jit_codes; return g_jit_codes; } const JitCodeMap& AllKernels() { return codes_; } bool Has(size_t key) const { return codes_.find(key) != codes_.end(); } void Insert(size_t key, GenBasePtr value) { codes_.emplace(key, std::move(value)); } private: JitCodeMap codes_; DISABLE_COPY_AND_ASSIGN(JitCodePool); }; typedef std::unique_ptr KernelPtr; typedef std::unordered_map, KernelKey::Hash> KernelMap; class KernelPool { public: static KernelPool& Instance(); KernelPool() = default; KernelMap& AllKernels() { return pool_; } void Insert(const KernelKey& key, KernelPtr value) { if (pool_.find(key) == pool_.end()) { pool_.emplace(key, std::vector()); } pool_.at(key).emplace_back(std::move(value)); } private: KernelMap pool_; DISABLE_COPY_AND_ASSIGN(KernelPool); }; // Every kernel should have refer code and it should be used in unit tests, // so refer kernels should have it's independent kernel pool class ReferKernelPool { public: static ReferKernelPool& Instance(); ReferKernelPool() = default; KernelMap& AllKernels() { return pool_; } void Insert(const KernelKey& key, KernelPtr value) { if (pool_.find(key) == pool_.end()) { pool_.emplace(key, std::vector()); } pool_.at(key).emplace_back(std::move(value)); } private: KernelMap pool_; DISABLE_COPY_AND_ASSIGN(ReferKernelPool); }; // Refer code do not related with attr, and always on CPUPlace template inline Func GetRefer() { auto& ref_pool = ReferKernelPool().Instance().AllKernels(); KernelKey kkey(KT, platform::CPUPlace()); auto ref_iter = ref_pool.find(kkey); PADDLE_ENFORCE(ref_iter != ref_pool.end(), "Every Kernel should have reference function."); auto& ref_impls = ref_iter->second; for (auto& impl : ref_impls) { auto i = dynamic_cast*>(impl.get()); if (i) { return i->GetFunc(); } } return nullptr; } template const Func Get(Attr attr) { size_t key = GetKey(attr); auto& codes = JitCodePool().Instance(); if (codes.Has(key)) { return codes.AllKernels().at(key)->template getCode(); } if (std::is_same::value) { auto p = CreateJitCode(attr); if (p) { auto f = p->template getCode(); codes.Insert(key, std::move(p)); return f; } } // pool: (KernelKey(type, place), vector) auto& pool = KernelPool().Instance().AllKernels(); KernelKey kkey(KT, PlaceType()); auto iter = pool.find(kkey); if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { return i->GetFunc(); } } } // The last implementation should be reference function on CPUPlace. return GetRefer(); } } // namespace jit } // namespace operators } // namespace paddle