/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include // for shared_ptr #include #include #include #include "paddle/fluid/operators/jitkernels/jitcode_base.h" #include "paddle/fluid/operators/jitkernels/kernel_base.h" #include "paddle/fluid/operators/jitkernels/kernel_key.h" #ifdef PADDLE_WITH_XBYAK #include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h" #endif namespace paddle { namespace operators { namespace jitkernels { template class JitCodePool { public: static JitCodePool& Instance() { static thread_local JitCodePool g_jit_codes; return g_jit_codes; } std::shared_ptr Get(size_t key) const { if (codes_.find(key) == codes_.end()) { return nullptr; } return codes_.at(key); } void Insert(size_t key, const std::shared_ptr& value) { codes_.insert({key, value}); } private: JitCodePool() = default; std::unordered_map> codes_; DISABLE_COPY_AND_ASSIGN(JitCodePool); }; // std::tuple template struct KernelAttr { typedef T data_type; typedef Func return_type; typedef Attr attr_type; }; class KernelPool { public: static KernelPool& Instance(); typedef std::unique_ptr KernelPtr; typedef std::unordered_map, KernelKey::Hash> KernelMap; KernelMap& AllKernels() { return pool_; } void Insert(const KernelKey& key, KernelPtr value) { if (pool_.find(key) == pool_.end()) { pool_.emplace(key, std::vector()); } pool_.at(key).emplace_back(std::move(value)); } KernelPool() = default; private: KernelMap pool_; DISABLE_COPY_AND_ASSIGN(KernelPool); }; // TODO(TJ): create_jitcode; // TODO(TJ): make tuple? named KernelAttr template Func Get(Attr attr) { size_t key = GetKey(attr); auto jitcode = JitCodePool().Instance().Get(key); if (jitcode) { return jitcode->template getCode(); } #ifdef PADDLE_WITH_XBYAK // // jitcode::JitCode is under protection of PADDLE_WITH_XBYAK // if (std::is_same::value) { // if (UseJitCode(attr)) { // std::shared_ptr p(std::make_shared>( // attr, CodeSize(attr))); // JitCodePool().Instance().Insert(key, p); // return p->getCode(); // } // } #endif // (KernelKey(type, place), vector) auto& pool = KernelPool().Instance().AllKernels(); KernelKey kkey(KT, PlaceType()); auto iter = pool.find(kkey); if (iter != pool.end()) { auto impls = iter->second; for (auto impl : impls) { auto i = std::dynamic_pointer_cast>(impl.get()); if (i && i->UseMe(attr)) { return i->GetFunc(); } } } // The last implementation should be reference function on CPU // Every kernel should have refer code. // because of test refer should have it's own pool // PADDLE_ENFORCE_GT(list.size(), 1) << "Should have refer implemtation"; // const auto& refer = KernelRefer().AllKernels(); // return refer.Get(); return nullptr; } } // namespace jitkernels } // namespace operators } // namespace paddle