From d4cab7d94890c3bf43d20243ea9f21722b2738a3 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 11 Dec 2018 07:30:53 +0000 Subject: [PATCH] use jitkernel in one file --- paddle/fluid/operators/jit/CMakeLists.txt | 13 +-- paddle/fluid/operators/jit/gen/CMakeLists.txt | 7 ++ paddle/fluid/operators/jit/gen/jitcode.cc | 21 ---- paddle/fluid/operators/jit/helper.h | 96 +++++++++++++++++++ paddle/fluid/operators/jit/kernel_base.h | 2 +- paddle/fluid/operators/jit/kernel_pool.h | 63 ------------ .../fluid/operators/jit/more/CMakeLists.txt | 3 + .../operators/jit/more/mkl/CMakeLists.txt | 3 + paddle/fluid/operators/jit/more/mkl/mkl.h | 4 +- .../fluid/operators/jit/refer/CMakeLists.txt | 7 ++ paddle/fluid/operators/jit/refer/refer.h | 4 +- paddle/fluid/operators/jit/test.cc | 18 +--- 12 files changed, 133 insertions(+), 108 deletions(-) delete mode 100644 paddle/fluid/operators/jit/gen/jitcode.cc create mode 100644 paddle/fluid/operators/jit/helper.h diff --git a/paddle/fluid/operators/jit/CMakeLists.txt b/paddle/fluid/operators/jit/CMakeLists.txt index 77fd27666f..26903e0e44 100644 --- a/paddle/fluid/operators/jit/CMakeLists.txt +++ b/paddle/fluid/operators/jit/CMakeLists.txt @@ -1,16 +1,17 @@ -# set(use_jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h) -# file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n") -# file(APPEND ${pass_file} "\#pragma once\n") -# file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") - +set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h) +file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n") +file(APPEND ${jit_file} "\#pragma once\n") +file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n") +file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n") set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place) file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") -list(REMOVE_ITEM jit_kernel_cc_srcs jit_test.cc) +list(REMOVE_ITEM jit_kernel_cc_srcs test.cc) cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS}) +# refer must go first add_subdirectory(refer) add_subdirectory(more) if(WITH_XBYAK) diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index c678ea33b8..98d9231faa 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -3,3 +3,10 @@ file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc") cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE) + +function(USE_JITKERNEL_GEN TARGET) + file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n") +endfunction() + +# use gen jitcode kernel by name +USE_JITKERNEL_GEN(vmul) diff --git a/paddle/fluid/operators/jit/gen/jitcode.cc b/paddle/fluid/operators/jit/gen/jitcode.cc deleted file mode 100644 index 7aaf6a2ff6..0000000000 --- a/paddle/fluid/operators/jit/gen/jitcode.cc +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ - -#include "paddle/fluid/operators/jit/gen/jitcode.h" - -namespace paddle { -namespace operators { -namespace jit {} // namespace jit -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h new file mode 100644 index 0000000000..c8da960a1e --- /dev/null +++ b/paddle/fluid/operators/jit/helper.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include // for unique_ptr +#include +#include +#include +#include "paddle/fluid/operators/jit/gen_base.h" +#include "paddle/fluid/operators/jit/kernel_base.h" +#include "paddle/fluid/operators/jit/kernel_key.h" +#include "paddle/fluid/operators/jit/kernel_pool.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace operators { +namespace jit { + +// Refer code do not related with attr, and always on CPUPlace +template +inline Func GetRefer() { + auto& ref_pool = ReferKernelPool().Instance().AllKernels(); + KernelKey kkey(KT, platform::CPUPlace()); + auto ref_iter = ref_pool.find(kkey); + PADDLE_ENFORCE(ref_iter != ref_pool.end(), + "Every Kernel should have reference function."); + auto& ref_impls = ref_iter->second; + for (auto& impl : ref_impls) { + auto i = dynamic_cast*>(impl.get()); + if (i) { + return i->GetFunc(); + } + } + return nullptr; +} + +template +const Func Get(Attr attr) { + size_t key = JitCodeKey(attr); + auto& codes = JitCodePool().Instance(); + if (codes.Has(key)) { + return codes.AllKernels().at(key)->template getCode(); + } + + KernelKey kkey(KT, PlaceType()); + if (std::is_same::value) { + // pool: (KernelKey(type, place), vector) + auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); + auto iter = creator_map.find(kkey); + auto& creators = iter->second; + for (auto& cur : creators) { + auto i = dynamic_cast*>(cur.get()); + if (i && i->UseMe(attr)) { + auto p = i->CreateJitCode(attr); + if (p) { + auto f = p->template getCode(); + codes.Insert(key, std::move(p)); + return f; + } + } + } + } + + // pool: (KernelKey(type, place), vector) + auto& pool = KernelPool().Instance().AllKernels(); + auto iter = pool.find(kkey); + if (iter != pool.end()) { + auto& impls = iter->second; + for (auto& impl : impls) { + auto i = dynamic_cast*>(impl.get()); + if (i && i->UseMe(attr)) { + return i->GetFunc(); + } + } + } + + // The last implementation should be reference function on CPUPlace. + return GetRefer(); +} + +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 6a789c52c3..df7be6ab8e 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -22,7 +22,7 @@ namespace jit { typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType; template -struct VMulTypes { +struct VMulTuples { typedef T data_type; typedef int attr_type; typedef void (*func_type)(const T*, const T*, T*, int); diff --git a/paddle/fluid/operators/jit/kernel_pool.h b/paddle/fluid/operators/jit/kernel_pool.h index c9e7fc84e5..3e15242af2 100644 --- a/paddle/fluid/operators/jit/kernel_pool.h +++ b/paddle/fluid/operators/jit/kernel_pool.h @@ -114,69 +114,6 @@ class ReferKernelPool { DISABLE_COPY_AND_ASSIGN(ReferKernelPool); }; -// Refer code do not related with attr, and always on CPUPlace -template -inline Func GetRefer() { - auto& ref_pool = ReferKernelPool().Instance().AllKernels(); - KernelKey kkey(KT, platform::CPUPlace()); - auto ref_iter = ref_pool.find(kkey); - PADDLE_ENFORCE(ref_iter != ref_pool.end(), - "Every Kernel should have reference function."); - auto& ref_impls = ref_iter->second; - for (auto& impl : ref_impls) { - auto i = dynamic_cast*>(impl.get()); - if (i) { - return i->GetFunc(); - } - } - return nullptr; -} - -template -const Func Get(Attr attr) { - size_t key = JitCodeKey(attr); - auto& codes = JitCodePool().Instance(); - if (codes.Has(key)) { - return codes.AllKernels().at(key)->template getCode(); - } - - KernelKey kkey(KT, PlaceType()); - if (std::is_same::value) { - // pool: (KernelKey(type, place), vector) - auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); - auto iter = creator_map.find(kkey); - auto& creators = iter->second; - for (auto& cur : creators) { - auto i = dynamic_cast*>(cur.get()); - if (i && i->UseMe(attr)) { - auto p = i->CreateJitCode(attr); - if (p) { - auto f = p->template getCode(); - codes.Insert(key, std::move(p)); - return f; - } - } - } - } - - // pool: (KernelKey(type, place), vector) - auto& pool = KernelPool().Instance().AllKernels(); - auto iter = pool.find(kkey); - if (iter != pool.end()) { - auto& impls = iter->second; - for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); - if (i && i->UseMe(attr)) { - return i->GetFunc(); - } - } - } - - // The last implementation should be reference function on CPUPlace. - return GetRefer(); -} - } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/more/CMakeLists.txt b/paddle/fluid/operators/jit/more/CMakeLists.txt index 84f1811ced..5bb78b9304 100644 --- a/paddle/fluid/operators/jit/more/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/CMakeLists.txt @@ -1,4 +1,7 @@ +function(USE_JITKERNEL_MORE TARGET TYPE) + file(APPEND ${jit_file} "USE_JITKERNEL_MORE(${TARGET} ${TYPE});\n") +endfunction() if(WITH_MKLML) add_subdirectory(mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index 94d2487866..0c15c7060d 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -1,3 +1,6 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE) + +# use mkl kernels by name and type +USE_JITKERNEL_MORE(vmul, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index 45cfec1c47..c0f738cceb 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -28,8 +28,8 @@ template void VMul(const T* x, const T* y, T* z, int n); template -class VMulKernel : public KernelImpl::func_type, - typename VMulTypes::attr_type> { +class VMulKernel : public KernelImpl::func_type, + typename VMulTuples::attr_type> { public: VMulKernel() { this->func = VMul; } bool UseMe(int d) const override { diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 8c116e42dc..b6ff80d03d 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -1,3 +1,10 @@ cc_library(jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_refer PARENT_SCOPE) + +function(USE_JITKERNEL_REFER TARGET) + file(APPEND ${jit_file} "USE_JITKERNEL_REFER(${TARGET});\n") +endfunction() + +# use refer kernel by name +USE_JITKERNEL_REFER(vmul) diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 76a663633d..97aa5de8fc 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -29,8 +29,8 @@ void VMul(const T* x, const T* y, T* z, int n) { } template -class VMulKernel : public ReferKernel::func_type, - typename VMulTypes::attr_type> { +class VMulKernel : public ReferKernel::func_type, + typename VMulTuples::attr_type> { public: VMulKernel() { this->func = VMul; } }; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 5af9ed697d..e531ba1a2c 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -19,10 +19,7 @@ #include "gflags/gflags.h" #include "glog/logging.h" #include "gtest/gtest.h" -#include "paddle/fluid/operators/jit/kernel_pool.h" -// TODO(TJ): remove me -#include "paddle/fluid/operators/jit/registry.h" - +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/port.h" @@ -58,11 +55,6 @@ void ExpectEQ(const T* target, const T* refer, int n) { } } -// TODO(TJ): remove me -USE_JITKERNEL_MORE(vmul, mkl); -USE_JITKERNEL_REFER(vmul); -USE_JITKERNEL_GEN(vmul); - TEST(JitKernel, vmul) { using T = float; using PlaceType = paddle::platform::CPUPlace; @@ -70,10 +62,10 @@ TEST(JitKernel, vmul) { namespace jit = paddle::operators::jit; // TODO(TJ): test more vector size for (int d = 1; d < 30; ++d) { - auto ref = jit::GetRefer::func_type, - jit::VMulTypes::attr_type>(); - auto tgt = jit::Get::func_type, - jit::VMulTypes::attr_type, PlaceType>(d); + auto ref = jit::GetRefer::func_type, + jit::VMulTuples::attr_type>(); + auto tgt = jit::Get::func_type, + jit::VMulTuples::attr_type, PlaceType>(d); EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(tgt != nullptr); -- GitLab