Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
53709e7e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
53709e7e
编写于
12月 06, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine names
上级
ce674b68
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
96 addition
and
95 deletion
+96
-95
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-2
paddle/fluid/operators/jit/CMakeLists.txt
paddle/fluid/operators/jit/CMakeLists.txt
+3
-3
paddle/fluid/operators/jit/README.md
paddle/fluid/operators/jit/README.md
+1
-1
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+0
-0
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+9
-8
paddle/fluid/operators/jit/gen/blas.h
paddle/fluid/operators/jit/gen/blas.h
+5
-5
paddle/fluid/operators/jit/gen/jitcode.cc
paddle/fluid/operators/jit/gen/jitcode.cc
+5
-5
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+6
-6
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+4
-4
paddle/fluid/operators/jit/gen_base.h
paddle/fluid/operators/jit/gen_base.h
+5
-5
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+2
-2
paddle/fluid/operators/jit/kernel_key.h
paddle/fluid/operators/jit/kernel_key.h
+3
-3
paddle/fluid/operators/jit/kernel_pool.cc
paddle/fluid/operators/jit/kernel_pool.cc
+3
-3
paddle/fluid/operators/jit/kernel_pool.h
paddle/fluid/operators/jit/kernel_pool.h
+10
-10
paddle/fluid/operators/jit/more/CMakeLists.txt
paddle/fluid/operators/jit/more/CMakeLists.txt
+0
-0
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+0
-0
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+5
-5
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+3
-3
paddle/fluid/operators/jit/more/more.h
paddle/fluid/operators/jit/more/more.h
+0
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+0
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+3
-3
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+3
-3
paddle/fluid/operators/jit/registry.h
paddle/fluid/operators/jit/registry.h
+21
-21
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+3
-3
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
53709e7e
...
...
@@ -16,7 +16,7 @@ add_subdirectory(metrics)
add_subdirectory
(
optimizers
)
add_subdirectory
(
reduce_ops
)
add_subdirectory
(
sequence_ops
)
add_subdirectory
(
jit
kernels
)
add_subdirectory
(
jit
)
if
(
WITH_DISTRIBUTE
)
add_subdirectory
(
distributed
)
...
...
@@ -68,7 +68,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_ten
if
(
NOT WIN32
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
dynload_warpctc
)
endif
()
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel
_helper
concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions
)
if
(
WITH_GPU
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv prelu
)
...
...
paddle/fluid/operators/jit
kernels
/CMakeLists.txt
→
paddle/fluid/operators/jit/CMakeLists.txt
浏览文件 @
53709e7e
...
...
@@ -14,8 +14,8 @@ cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
add_subdirectory
(
refer
)
add_subdirectory
(
more
)
if
(
WITH_XBYAK
)
add_subdirectory
(
jitcode
)
add_subdirectory
(
gen
)
endif
()
cc_library
(
jit_kernel SRCS
${
jit_kernel_cc_srcs
}
DEPS
${
JIT_KERNEL_DEPS
}
)
cc_test
(
jit_kernel_test SRCS test.cc DEPS jit_kernel
)
cc_library
(
jit_kernel
_helper
SRCS
${
jit_kernel_cc_srcs
}
DEPS
${
JIT_KERNEL_DEPS
}
)
cc_test
(
jit_kernel_test SRCS test.cc DEPS jit_kernel
_helper
)
paddle/fluid/operators/jit
kernels
/README.md
→
paddle/fluid/operators/jit/README.md
浏览文件 @
53709e7e
...
...
@@ -13,7 +13,7 @@ PaddlePaddle/Paddle/paddle/fluid/
│ ├── .../
└── jit/
├── ...
├──
jitcode
/
├──
gen
/
│ └── ...
|── more/
│ ├── ...
...
...
paddle/fluid/operators/jit
kernels/jitcode
/CMakeLists.txt
→
paddle/fluid/operators/jit
/gen
/CMakeLists.txt
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels/jitcode
/blas.cc
→
paddle/fluid/operators/jit
/gen
/blas.cc
浏览文件 @
53709e7e
...
...
@@ -11,13 +11,14 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jitkernels/jitcode/blas.h"
#include "paddle/fluid/operators/jitkernels/registry.h"
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jitcode
{
namespace
jit
{
namespace
gen
{
void
VXXJitCode
::
genCode
()
{
// do not need push stack, and do not need save avx512reg if do not use avx512
...
...
@@ -102,17 +103,17 @@ void VXXJitCode::genCode() {
ret
();
}
}
// namespace
jitcode
}
// namespace
gen
template
<
>
std
::
unique_ptr
<
Jit
Base
>
CreateJitCode
<
KernelType
::
vmul
,
float
,
int
>
(
int
attr
)
{
std
::
unique_ptr
<
Gen
Base
>
CreateJitCode
<
KernelType
::
vmul
,
float
,
int
>
(
int
attr
)
{
if
(
UseJitCode
<
KernelType
::
vmul
,
float
,
int
>
(
attr
))
{
return
make_unique
<
jitcode
::
VMulJitCode
>
(
return
make_unique
<
gen
::
VMulJitCode
>
(
attr
,
CodeSize
<
KernelType
::
vmul
,
float
,
int
>
(
attr
));
}
return
nullptr
;
}
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
/blas.h
→
paddle/fluid/operators/jit
/gen
/blas.h
浏览文件 @
53709e7e
...
...
@@ -15,12 +15,12 @@
#pragma once
#include <string>
#include "paddle/fluid/operators/jit
kernels/jitcode
/jitcode.h"
#include "paddle/fluid/operators/jit
/gen
/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jitcode
{
namespace
jit
{
namespace
gen
{
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
...
...
@@ -82,7 +82,7 @@ class VMulJitCode : public VXXJitCode {
:
VXXJitCode
(
d
,
operand_type
::
mul
,
0
,
false
,
code_size
,
code_ptr
)
{}
};
}
// namespace
jitcode
}
// namespace jit
kernels
}
// namespace
gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
/jitcode.cc
→
paddle/fluid/operators/jit
/gen
/jitcode.cc
浏览文件 @
53709e7e
...
...
@@ -12,11 +12,11 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels/jitcode
/jitcode.h"
#include "paddle/fluid/operators/jit
/gen
/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
template
<
>
size_t
GetKey
<
int
>
(
int
d
)
{
...
...
@@ -24,15 +24,15 @@ size_t GetKey<int>(int d) {
}
// template <>
// std::shared_ptr<const
Jit
Base> CreateJitCode<KernelType::vmul, int>(int attr)
// std::shared_ptr<const
Gen
Base> CreateJitCode<KernelType::vmul, int>(int attr)
// {
// if (UseJitCode<KernelType::vmul, int>(attr)) {
// return std::make_shared<
jitcode
::VMulJitCode<int>>(attr,
// return std::make_shared<
gen
::VMulJitCode<int>>(attr,
// CodeSize<KernelType::vmul, int>(attr)));
// }
// return nullptr;
// }
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
/jitcode.h
→
paddle/fluid/operators/jit
/gen
/jitcode.h
浏览文件 @
53709e7e
...
...
@@ -15,7 +15,7 @@
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit
kernels/jitcode
_base.h"
#include "paddle/fluid/operators/jit
/gen
_base.h"
#include "paddle/fluid/platform/cpu_info.h"
#define XBYAK_USE_MMAP_ALLOCATOR
...
...
@@ -24,8 +24,8 @@
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jitcode
{
namespace
jit
{
namespace
gen
{
// Application Binary Interface
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
...
...
@@ -67,7 +67,7 @@ typedef enum {
#define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; }
class
JitCode
:
public
Jit
Base
,
public
Xbyak
::
CodeGenerator
{
class
JitCode
:
public
Gen
Base
,
public
Xbyak
::
CodeGenerator
{
public:
explicit
JitCode
(
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{
...
...
@@ -128,7 +128,7 @@ class JitCode : public JitBase, public Xbyak::CodeGenerator {
}
};
}
// namespace
jitcode
}
// namespace jit
kernels
}
// namespace
gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
_base.cc
→
paddle/fluid/operators/jit
/gen
_base.cc
浏览文件 @
53709e7e
...
...
@@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels/jitcode
_base.h"
#include "paddle/fluid/operators/jit
/gen
_base.h"
#include <fstream>
#include <iostream>
#include <sstream>
...
...
@@ -21,10 +21,10 @@ DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
// refer do not need useme, it would be the last one.
void
Jit
Base
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
void
Gen
Base
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
if
(
code
)
{
static
int
counter
=
0
;
std
::
ostringstream
filename
;
...
...
@@ -38,6 +38,6 @@ void JitBase::dumpCode(const unsigned char* code) const {
}
}
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
_base.h
→
paddle/fluid/operators/jit
/gen
_base.h
浏览文件 @
53709e7e
...
...
@@ -16,14 +16,14 @@
#include <gflags/gflags.h>
#include <memory> // for shared_ptr
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/macros.h"
DECLARE_bool
(
dump_jitcode
);
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
// TODO(TJ): make these functions as virtual of a class
...
...
@@ -43,7 +43,7 @@ bool UseJitCode(Attr attr) {
template
<
typename
Attr
>
size_t
GetKey
(
Attr
attr
);
class
Jit
Base
:
public
Kernel
{
class
Gen
Base
:
public
Kernel
{
public:
virtual
const
char
*
name
()
const
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
=
0
;
...
...
@@ -62,8 +62,8 @@ class JitBase : public Kernel {
};
template
<
KernelType
KT
,
typename
T
,
typename
Attr
>
std
::
unique_ptr
<
Jit
Base
>
CreateJitCode
(
Attr
attr
);
std
::
unique_ptr
<
Gen
Base
>
CreateJitCode
(
Attr
attr
);
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/kernel_base.h
→
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
53709e7e
...
...
@@ -17,7 +17,7 @@
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
typedef
enum
{
vmul
=
0
,
vadd
=
1
,
vsub
,
vexp
}
KernelType
;
...
...
@@ -54,6 +54,6 @@ class ReferKernel : public KernelImpl<T, Func, Attr> {
bool
UseMe
(
Attr
attr
)
const
override
{
return
true
;
}
};
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/kernel_key.h
→
paddle/fluid/operators/jit/kernel_key.h
浏览文件 @
53709e7e
...
...
@@ -13,12 +13,12 @@
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
struct
KernelKey
{
struct
Hash
{
...
...
@@ -44,6 +44,6 @@ struct KernelKey {
bool
operator
!=
(
const
KernelKey
&
o
)
const
{
return
!
(
*
this
==
o
);
}
};
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/kernel_pool.cc
→
paddle/fluid/operators/jit/kernel_pool.cc
浏览文件 @
53709e7e
...
...
@@ -12,14 +12,14 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels
/kernel_pool.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
KernelPool
&
KernelPool
::
Instance
()
{
static
KernelPool
g_kernel_pool
;
...
...
@@ -31,6 +31,6 @@ ReferKernelPool& ReferKernelPool::Instance() {
return
g_refer_kernel_pool
;
}
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/kernel_pool.h
→
paddle/fluid/operators/jit/kernel_pool.h
浏览文件 @
53709e7e
...
...
@@ -18,19 +18,19 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit
kernels/jitcode
_base.h"
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit
kernels
/kernel_key.h"
#include "paddle/fluid/operators/jit
/gen
_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
template
<
KernelType
KT
>
class
JitCodePool
{
typedef
std
::
unique_ptr
<
JitBase
>
Jit
BasePtr
;
typedef
std
::
unordered_map
<
size_t
,
JitBasePtr
>
JitBas
eMap
;
typedef
std
::
unique_ptr
<
GenBase
>
Gen
BasePtr
;
typedef
std
::
unordered_map
<
size_t
,
GenBasePtr
>
JitCod
eMap
;
public:
JitCodePool
()
=
default
;
...
...
@@ -39,16 +39,16 @@ class JitCodePool {
return
g_jit_codes
;
}
const
Jit
Bas
eMap
&
AllKernels
()
{
return
codes_
;
}
const
Jit
Cod
eMap
&
AllKernels
()
{
return
codes_
;
}
bool
Has
(
size_t
key
)
const
{
return
codes_
.
find
(
key
)
!=
codes_
.
end
();
}
void
Insert
(
size_t
key
,
Jit
BasePtr
value
)
{
void
Insert
(
size_t
key
,
Gen
BasePtr
value
)
{
codes_
.
emplace
(
key
,
std
::
move
(
value
));
}
private:
Jit
Bas
eMap
codes_
;
Jit
Cod
eMap
codes_
;
DISABLE_COPY_AND_ASSIGN
(
JitCodePool
);
};
...
...
@@ -146,6 +146,6 @@ const Func Get(Attr attr) {
return
GetRefer
<
KT
,
T
,
Func
,
Attr
>
();
}
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/more/CMakeLists.txt
→
paddle/fluid/operators/jit/more/CMakeLists.txt
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels
/more/mkl/CMakeLists.txt
→
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels
/more/mkl/mkl.cc
→
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
53709e7e
...
...
@@ -12,13 +12,13 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels
/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit
kernels
/registry.h"
#include "paddle/fluid/operators/jit/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
namespace
more
{
namespace
mkl
{
...
...
@@ -34,11 +34,11 @@ void VMul<double>(const double* x, const double* y, double* z, int n) {
}
// namespace mkl
}
// namespace more
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
mkl
=
paddle
::
operators
::
jit
kernels
::
more
::
mkl
;
namespace
mkl
=
paddle
::
operators
::
jit
::
more
::
mkl
;
REGISTER_JITKERNEL_MORE
(
vmul
,
mkl
,
mkl
::
VMulKernel
<
float
>
,
mkl
::
VMulKernel
<
double
>
);
paddle/fluid/operators/jit
kernels
/more/mkl/mkl.h
→
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
53709e7e
...
...
@@ -15,12 +15,12 @@
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
namespace
more
{
namespace
mkl
{
...
...
@@ -43,6 +43,6 @@ class VMulKernel : public KernelImpl<T, typename VMulTypes<T>::func_type,
}
// namespace mkl
}
// namespace more
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/more/more.h
→
paddle/fluid/operators/jit/more/more.h
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels
/refer/CMakeLists.txt
→
paddle/fluid/operators/jit/refer/CMakeLists.txt
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels
/refer/refer.cc
→
paddle/fluid/operators/jit/refer/refer.cc
浏览文件 @
53709e7e
...
...
@@ -12,10 +12,10 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels
/refer/refer.h"
#include "paddle/fluid/operators/jit
kernels
/registry.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
namespace
refer
=
paddle
::
operators
::
jit
kernels
::
refer
;
namespace
refer
=
paddle
::
operators
::
jit
::
refer
;
REGISTER_JITKERNEL_REFER
(
vmul
,
refer
::
VMulKernel
<
float
>
,
refer
::
VMulKernel
<
double
>
);
paddle/fluid/operators/jit
kernels
/refer/refer.h
→
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
53709e7e
...
...
@@ -13,12 +13,12 @@
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
namespace
refer
{
template
<
typename
T
>
...
...
@@ -36,6 +36,6 @@ class VMulKernel : public ReferKernel<T, typename VMulTypes<T>::func_type,
};
}
// namespace refer
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/registry.h
→
paddle/fluid/operators/jit/registry.h
浏览文件 @
53709e7e
...
...
@@ -17,14 +17,14 @@
#include <memory>
#include <tuple>
#include <type_traits>
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit
kernels
/kernel_pool.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" // for UNUSED
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
// make_unique is supported since c++14
template
<
typename
T
,
typename
...
Args
>
...
...
@@ -76,21 +76,21 @@ class JitKernelRegistrar {
msg)
// Refer always on CPUPlace
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
"REGISTER_KERNEL_REFER must be called in global namespace"); \
static ::paddle::operators::jit
kernels::JitKernelRegistrar<
\
::paddle::operators::jit
kernels::ReferKernelPool,
\
::paddle::platform::CPUPlace, __VA_ARGS__>
\
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
::paddle::operators::jit
kernels::KernelType::kernel_type);
\
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
return 0; \
#define REGISTER_JITKERNEL_REFER(kernel_type, ...)
\
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(
\
__reg_jitkernel_##kernel_type##_refer_CPUPlace,
\
"REGISTER_KERNEL_REFER must be called in global namespace");
\
static ::paddle::operators::jit
::JitKernelRegistrar<
\
::paddle::operators::jit
::ReferKernelPool, ::paddle::platform::CPUPlace,
\
__VA_ARGS__>
\
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_(
\
::paddle::operators::jit
::KernelType::kernel_type);
\
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() {
\
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch();
\
return 0;
\
}
// kernel_type: should be in paddle::operators::jit
kernels
::KernelType
// kernel_type: should be in paddle::operators::jit::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
...
...
@@ -99,11 +99,11 @@ class JitKernelRegistrar {
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit
kernels::JitKernelRegistrar<
\
::paddle::operators::jit
kernels::KernelPool,
\
::paddle::platform::place_type, __VA_ARGS__>
\
static ::paddle::operators::jit
::JitKernelRegistrar<
\
::paddle::operators::jit
::KernelPool, ::paddle::platform::place_type,
\
__VA_ARGS__>
\
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::operators::jit
kernels::KernelType::kernel_type);
\
::paddle::operators::jit
::KernelType::kernel_type);
\
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \
...
...
@@ -139,6 +139,6 @@ class JitKernelRegistrar {
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/test.cc
→
paddle/fluid/operators/jit/test.cc
浏览文件 @
53709e7e
...
...
@@ -19,9 +19,9 @@
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit
kernels
/kernel_pool.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
// TODO(TJ): remove me
#include "paddle/fluid/operators/jit
kernels
/registry.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
...
...
@@ -66,7 +66,7 @@ TEST(JitKernel, vmul) {
using
T
=
float
;
using
PlaceType
=
paddle
::
platform
::
CPUPlace
;
namespace
jit
=
paddle
::
operators
::
jit
kernels
;
namespace
jit
=
paddle
::
operators
::
jit
;
// TODO(TJ): test more vector size
for
(
int
d
=
1
;
d
<
30
;
++
d
)
{
auto
ref
=
jit
::
GetRefer
<
jit
::
vmul
,
T
,
jit
::
VMulTypes
<
T
>::
func_type
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录