提交 d4cab7d9 编写于 作者: T tensor-tang

use jitkernel in one file

上级 adc7ba2e
# set(use_jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
# file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n")
# file(APPEND ${pass_file} "\#pragma once\n")
# file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs jit_test.cc)
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc)
cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
# refer must go first
add_subdirectory(refer)
add_subdirectory(more)
if(WITH_XBYAK)
......
......@@ -3,3 +3,10 @@ file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
function(USE_JITKERNEL_GEN TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n")
endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(vmul)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
// Refer code do not related with attr, and always on CPUPlace
template <KernelType KT, typename T, typename Func, typename Attr>
inline Func GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
template <KernelType KT, typename T, typename Func, typename Attr,
typename PlaceType = platform::CPUPlace>
const Func Get(Attr attr) {
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
KernelKey kkey(KT, PlaceType());
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
}
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelImpl<T, Func, Attr>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
}
}
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KT, T, Func, Attr>();
}
} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -22,7 +22,7 @@ namespace jit {
typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType;
template <typename T>
struct VMulTypes {
struct VMulTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int);
......
......@@ -114,69 +114,6 @@ class ReferKernelPool {
DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
};
// Refer code do not related with attr, and always on CPUPlace
template <KernelType KT, typename T, typename Func, typename Attr>
inline Func GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<T, Func, Attr>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
template <KernelType KT, typename T, typename Func, typename Attr,
typename PlaceType = platform::CPUPlace>
const Func Get(Attr attr) {
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
KernelKey kkey(KT, PlaceType());
if (std::is_same<PlaceType, platform::CPUPlace>::value) {
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
}
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelImpl<T, Func, Attr>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
}
}
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KT, T, Func, Attr>();
}
} // namespace jit
} // namespace operators
} // namespace paddle
function(USE_JITKERNEL_MORE TARGET TYPE)
file(APPEND ${jit_file} "USE_JITKERNEL_MORE(${TARGET} ${TYPE});\n")
endfunction()
if(WITH_MKLML)
add_subdirectory(mkl)
......
cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(vmul, mkl)
......@@ -28,8 +28,8 @@ template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
template <typename T>
class VMulKernel : public KernelImpl<T, typename VMulTypes<T>::func_type,
typename VMulTypes<T>::attr_type> {
class VMulKernel : public KernelImpl<T, typename VMulTuples<T>::func_type,
typename VMulTuples<T>::attr_type> {
public:
VMulKernel() { this->func = VMul<T>; }
bool UseMe(int d) const override {
......
cc_library(jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_refer PARENT_SCOPE)
function(USE_JITKERNEL_REFER TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_REFER(${TARGET});\n")
endfunction()
# use refer kernel by name
USE_JITKERNEL_REFER(vmul)
......@@ -29,8 +29,8 @@ void VMul(const T* x, const T* y, T* z, int n) {
}
template <typename T>
class VMulKernel : public ReferKernel<T, typename VMulTypes<T>::func_type,
typename VMulTypes<T>::attr_type> {
class VMulKernel : public ReferKernel<T, typename VMulTuples<T>::func_type,
typename VMulTuples<T>::attr_type> {
public:
VMulKernel() { this->func = VMul<T>; }
};
......
......@@ -19,10 +19,7 @@
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
// TODO(TJ): remove me
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
......@@ -58,11 +55,6 @@ void ExpectEQ(const T* target, const T* refer, int n) {
}
}
// TODO(TJ): remove me
USE_JITKERNEL_MORE(vmul, mkl);
USE_JITKERNEL_REFER(vmul);
USE_JITKERNEL_GEN(vmul);
TEST(JitKernel, vmul) {
using T = float;
using PlaceType = paddle::platform::CPUPlace;
......@@ -70,10 +62,10 @@ TEST(JitKernel, vmul) {
namespace jit = paddle::operators::jit;
// TODO(TJ): test more vector size
for (int d = 1; d < 30; ++d) {
auto ref = jit::GetRefer<jit::vmul, T, jit::VMulTypes<T>::func_type,
jit::VMulTypes<T>::attr_type>();
auto tgt = jit::Get<jit::vmul, T, jit::VMulTypes<T>::func_type,
jit::VMulTypes<T>::attr_type, PlaceType>(d);
auto ref = jit::GetRefer<jit::vmul, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type>();
auto tgt = jit::Get<jit::vmul, T, jit::VMulTuples<T>::func_type,
jit::VMulTuples<T>::attr_type, PlaceType>(d);
EXPECT_TRUE(ref != nullptr);
EXPECT_TRUE(tgt != nullptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册