机器未来 / Paddle (fork of PaddlePaddle / Paddle)

Commit 191948c9
enable jitcode
Authored by tensor-tang on Dec 05, 2018
Parent: 45bfa70c
Showing 9 changed files with 342 additions and 50 deletions (+342 -50)
paddle/fluid/operators/jitkernels/CMakeLists.txt (+1 -1)
paddle/fluid/operators/jitkernels/jitcode/CMakeLists.txt (+3 -1)
paddle/fluid/operators/jitkernels/jitcode/blas.cc (+118 -0)
paddle/fluid/operators/jitkernels/jitcode/blas.h (+88 -0)
paddle/fluid/operators/jitkernels/jitcode/jitcode.h (+87 -8)
paddle/fluid/operators/jitkernels/jitcode_base.cc (+4 -1)
paddle/fluid/operators/jitkernels/jitcode_base.h (+10 -9)
paddle/fluid/operators/jitkernels/kernels.h (+30 -29)
paddle/fluid/platform/cpu_info.h (+1 -1)
paddle/fluid/operators/jitkernels/CMakeLists.txt
@@ -7,7 +7,7 @@
 set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
-cc_library(jit_kernel_base SRCS kernels.cc DEPS ${JIT_KERNEL_DEPS})
+cc_library(jit_kernel_base SRCS kernels.cc jitcode_base.cc DEPS ${JIT_KERNEL_DEPS})
 add_subdirectory(refer)
 add_subdirectory(more)
paddle/fluid/operators/jitkernels/jitcode/CMakeLists.txt
-cc_library(jit_kernel_jitcode SRCS jitcode.cc DEPS jit_kernel_base xbyak)
+file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
+cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
 set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
paddle/fluid/operators/jitkernels/jitcode/blas.cc
0 → 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/jitkernels/jitcode/blas.h"
#include "paddle/fluid/operators/jitkernels/registry.h"

namespace paddle {
namespace operators {
namespace jitkernels {
namespace jitcode {

void VXXJitCode::genCode() {
  // do not need push stack, and do not need save avx512reg if do not use avx512
  int offset = 0;
  if (with_relu_) {
    vxorps(ymm_zero, ymm_zero, ymm_zero);
  }
  if (scalar_index_ == 1) {
    vbroadcastss(ymm_src1, ptr[param1]);
  } else if (scalar_index_ == 2) {
    vbroadcastss(ymm_src2, ptr[param2]);
  }
  for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
    if (scalar_index_ != 1) {
      vmovups(ymm_src1, ptr[param1 + offset]);
    }
    if (scalar_index_ != 2) {
      vmovups(ymm_src2, ptr[param2 + offset]);
    }
    if (type_ == operand_type::mul) {
      vmulps(ymm_dst, ymm_src1, ymm_src2);
    } else if (type_ == operand_type::add) {
      vaddps(ymm_dst, ymm_src1, ymm_src2);
    }
    if (with_relu_) {
      vmaxps(ymm_dst, ymm_zero, ymm_dst);
    }
    vmovups(ptr[param3 + offset], ymm_dst);
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
  }
  int rest = num_ % YMM_FLOAT_BLOCK;
  while (rest > 0) {
    int block = XMM_FLOAT_BLOCK;
    if (rest >= 4) {
      block = 4;
      if (scalar_index_ != 1) {
        vmovups(xmm_src1, ptr[param1 + offset]);
      }
      if (scalar_index_ != 2) {
        vmovups(xmm_src2, ptr[param2 + offset]);
      }
    } else if (rest >= 2) {
      block = 2;
      if (scalar_index_ != 1) {
        vmovq(xmm_src1, ptr[param1 + offset]);
      }
      if (scalar_index_ != 2) {
        vmovq(xmm_src2, ptr[param2 + offset]);
      }
    } else {
      block = 1;
      if (scalar_index_ != 1) {
        vmovss(xmm_src1, ptr[param1 + offset]);
      }
      if (scalar_index_ != 2) {
        vmovss(xmm_src2, ptr[param2 + offset]);
      }
    }
    switch (type_) {
      case operand_type::mul:
        vmulps(xmm_dst, xmm_src1, xmm_src2);
        break;
      case operand_type::add:
        vaddps(xmm_dst, xmm_src1, xmm_src2);
        break;
      default:
        break;
    }
    if (with_relu_) {
      vmaxps(xmm_dst, xmm_zero, xmm_dst);
    }
    if (rest >= 4) {
      vmovups(ptr[param3 + offset], xmm_dst);
    } else if (rest >= 2) {
      vmovq(ptr[param3 + offset], xmm_dst);
    } else {
      vmovss(ptr[param3 + offset], xmm_dst);
    }
    offset += sizeof(float) * block;
    rest -= block;
  }
  ret();
}

}  // namespace jitcode

template <>
std::unique_ptr<JitBase> CreateJitCode<KernelType::vmul, float, int>(int attr) {
  if (UseJitCode<KernelType::vmul, float, int>(attr)) {
    return make_unique<jitcode::VMulJitCode>(
        attr, CodeSize<KernelType::vmul, float, int>(attr));
  }
  return nullptr;
}

}  // namespace jitkernels
}  // namespace operators
}  // namespace paddle
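For orientation (not part of this commit): a minimal usage sketch of the vmul specialization above. CreateJitCode<KernelType::vmul, float, int> and JitBase::getCode come from this diff; the three-pointer signature of the generated kernel is an assumption read off the param1/param2/param3 usage in genCode(), and the wrapper function itself is hypothetical.

// Hypothetical caller, assuming this commit's headers are on the include path.
#include "paddle/fluid/operators/jitkernels/jitcode_base.h"

namespace jk = paddle::operators::jitkernels;

void vmul_with_jitcode(const float* x, const float* y, float* z, int n) {
  // Returns nullptr when UseJitCode<...> rejects the attribute (e.g. no AVX).
  auto code = jk::CreateJitCode<jk::KernelType::vmul, float, int>(n);
  if (!code) {
    return;  // a real caller would fall back to the refer/more kernels
  }
  // Assumed calling convention: (src1, src2, dst); n is baked in at codegen time.
  typedef void (*vmul_func)(const float*, const float*, float*);
  vmul_func f = code->getCode<vmul_func>();
  f(x, y, z);
}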
paddle/fluid/operators/jitkernels/jitcode/blas.h
0 → 100644
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <string>
#include "paddle/fluid/operators/jitkernels/jitcode/jitcode.h"

namespace paddle {
namespace operators {
namespace jitkernels {
namespace jitcode {

// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
 public:
  const char* name() const override {
    std::string base = "VXXJitCode";
    if (scalar_index_ == 1) {
      base += "_Scalar";
    } else {
      base += "_Vec";
    }
    if (type_ == operand_type::mul) {
      base += "_Mul";
    } else if (type_ == operand_type::add) {
      base += "_Add";
    }
    if (scalar_index_ == 2) {
      base += "_Scalar";
    } else {
      base += "_Vec";
    }
    base += (with_relu_ ? "_Relu" : "");
    return base.c_str();
  }
  explicit VXXJitCode(int d, operand_type type, int scalar_index,
                      bool with_relu, size_t code_size = 256 * 1024,
                      void* code_ptr = nullptr)
      : JitCode(code_size, code_ptr),
        num_(d),
        type_(type),
        scalar_index_(scalar_index),
        with_relu_(with_relu) {}
  // static bool init(int d, int scalar_index = 0);
  void genCode() override;

 private:
  int num_;
  operand_type type_;
  int scalar_index_;
  bool with_relu_;
  reg64_t param1{abi_param1};
  reg64_t param2{abi_param2};
  reg64_t param3{abi_param3};

  xmm_t xmm_src1 = xmm_t(0);
  xmm_t xmm_src2 = xmm_t(1);
  xmm_t xmm_dst = xmm_t(2);
  xmm_t xmm_zero = xmm_t(3);

  ymm_t ymm_src1 = ymm_t(0);
  ymm_t ymm_src2 = ymm_t(1);
  ymm_t ymm_dst = ymm_t(2);
  ymm_t ymm_zero = ymm_t(3);
};

class VMulJitCode : public VXXJitCode {
 public:
  explicit VMulJitCode(int d, size_t code_size, void* code_ptr = nullptr)
      : VXXJitCode(d, operand_type::mul, 0, false, code_size, code_ptr) {}
};

}  // namespace jitcode
}  // namespace jitkernels
}  // namespace operators
}  // namespace paddle
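As a hedged illustration (not in this commit), other element-wise variants could be layered on VXXJitCode exactly the way VMulJitCode is above, only varying the operand type, scalar index, and relu flag:

// Hypothetical subclass following the VMulJitCode pattern.
class VAddReluJitCode : public VXXJitCode {
 public:
  explicit VAddReluJitCode(int d, size_t code_size, void* code_ptr = nullptr)
      : VXXJitCode(d, operand_type::add, 0 /*both operands are vectors*/,
                   true /*fuse relu on the result*/, code_size, code_ptr) {}
};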
paddle/fluid/operators/jitkernels/jitcode/jitcode.h
@@ -16,7 +16,7 @@
 #include <type_traits>
 #include "paddle/fluid/operators/jitkernels/jitcode_base.h"
-#include "paddle/fluid/operators/jitkernels/kernels.h"
+#include "paddle/fluid/platform/cpu_info.h"
 
 #define XBYAK_USE_MMAP_ALLOCATOR
 #include "xbyak/xbyak.h"

@@ -30,23 +30,102 @@ namespace jitcode {
 // Application Binary Interface
 constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
     abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
-    abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);
+    abi_param4(Xbyak::Operand::RCX);
 
-template <typename Attr>
-class VMulJitCode : public JitBase, public Xbyak::CodeGenerator {
+constexpr Xbyak::Operand::Code g_abi_regs[] = {
+    Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
+    Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};
+
+constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);
+
+using reg64_t = const Xbyak::Reg64;
+using reg32_t = const Xbyak::Reg32;
+using xmm_t = const Xbyak::Xmm;
+using ymm_t = const Xbyak::Ymm;
+using zmm_t = const Xbyak::Zmm;
+using Label = Xbyak::Label;
+
+typedef enum {
+  mul = 0,
+  add,
+  sub,
+  relu,
+  exp,
+  sigmoid,
+  tanh,
+  identity
+} operand_type;
+
+#define XMM_FLOAT_BLOCK 4
+#define YMM_FLOAT_BLOCK 8
+#define ZMM_FLOAT_BLOCK 16
+
+#define SIGMOID_THRESHOLD_MIN -40.0
+#define SIGMOID_THRESHOLD_MAX 13.0
+#define EXP_MAX_INPUT 40.0
+
+#define DECLARE_JIT_CODE(codename) \
+  const char* name() const override { return #codename; }
+
+class JitCode : public JitBase, public Xbyak::CodeGenerator {
  public:
-  VMulJitCode(Attr attr, size_t code_size, void* code_ptr = nullptr)
+  explicit JitCode(size_t code_size, void* code_ptr = nullptr)
       : Xbyak::CodeGenerator(code_size, code_ptr) {
     this->genCode();
   }
-  virtual const char* name() const = 0;
-  virtual void genCode() = 0;
+
+  size_t getSize() const override { return CodeGenerator::getSize(); }
 
   const unsigned char* getCodeInternal() override {
     const Xbyak::uint8* code = CodeGenerator::getCode();
     return code;
   }
+
+  virtual const char* name() const = 0;
+  virtual void genCode() = 0;
+
+ protected:
+  Xbyak::Reg64 param1{abi_param1};
+  const int EVEX_max_8b_offt = 0x200;
+  const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
+
+  virtual void preCode() {
+    for (int i = 0; i < num_g_abi_regs; ++i) {
+      push(Xbyak::Reg64(g_abi_regs[i]));
+    }
+    if (platform::jit::MayIUse(platform::jit::avx512f)) {
+      mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
+    }
+  }
+  virtual void postCode() {
+    for (int i = 0; i < num_g_abi_regs; ++i) {
+      pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
+    }
+    ret();
+  }
+  void L(const char* label) { Xbyak::CodeGenerator::L(label); }
+  void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
+  // Enhanced vector extension
+  Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
+                                    bool bcast = false) {
+    int scale = 0;
+    // Learn from https://github.com/intel/mkl-dnn
+    if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
+      offt = offt - 2 * EVEX_max_8b_offt;
+      scale = 1;
+    } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
+      offt = offt - 4 * EVEX_max_8b_offt;
+      scale = 2;
+    }
+    auto re = Xbyak::RegExp() + base + offt;
+    if (scale) {
+      re = re + reg_EVEX_max_8b_offt * scale;
+    }
+    if (bcast) {
+      return zword_b[re];
+    } else {
+      return zword[re];
+    }
+  }
 };
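To show how the new base class is meant to be used (a hedged sketch, not in this commit): a kernel derives from JitCode, gets name() from DECLARE_JIT_CODE, and writes its genCode() body between preCode() and postCode(). The VScalJitCode name, its (scalar, src, dst) argument layout, and the register choices are all assumptions for illustration.

// Hypothetical kernel on top of the new JitCode base.
class VScalJitCode : public JitCode {
 public:
  explicit VScalJitCode(int d, size_t code_size, void* code_ptr = nullptr)
      : JitCode(code_size, code_ptr), num_(d) {}
  DECLARE_JIT_CODE(VScalJitCode);
  void genCode() override {
    preCode();  // push callee-saved registers (g_abi_regs)
    vbroadcastss(ymm_t(0), ptr[param1]);  // scalar multiplier from argument 1
    int offset = 0;
    for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
      vmovups(ymm_t(1), ptr[param2 + offset]);  // load 8 floats
      vmulps(ymm_t(1), ymm_t(0), ymm_t(1));     // scale
      vmovups(ptr[param3 + offset], ymm_t(1));  // store
      offset += sizeof(float) * YMM_FLOAT_BLOCK;
    }
    // remainder handling omitted for brevity
    postCode();  // pop registers and ret
  }

 private:
  int num_;
  reg64_t param2{abi_param2};
  reg64_t param3{abi_param3};
};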
paddle/fluid/operators/jitkernels/jitcode_base.cc
@@ -13,6 +13,9 @@
  * limitations under the License. */
 
 #include "paddle/fluid/operators/jitkernels/jitcode_base.h"
+#include <fstream>
+#include <iostream>
+#include <sstream>
 
 DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");

@@ -29,7 +32,7 @@ void JitBase::dumpCode(const unsigned char* code) const {
   counter++;
   std::ofstream fout(filename.str(), std::ios::out);
   if (fout.is_open()) {
-    fout.write(reinterpret_cast<const char*>(code), getSize());
+    fout.write(reinterpret_cast<const char*>(code), this->getSize());
     fout.close();
   }
 }
paddle/fluid/operators/jitkernels/jitcode_base.h
@@ -28,7 +28,7 @@ namespace jitkernels {
 // TODO(TJ): make these functions as virtual of a class
 
 // Every JitCode should estimate the code size itself
-template <KernelType KT, typename Attr>
+template <KernelType KT, typename T, typename Attr>
 size_t CodeSize(Attr attr) {
   return 4096;
 }

@@ -43,13 +43,11 @@ bool UseJitCode(Attr attr) {
 template <typename Attr>
 size_t GetKey(Attr attr);
 
-class JitBase {
+class JitBase : public Kernel {
  public:
-  JitBase() = default;
-  virtual ~JitBase() = default;
   virtual const char* name() const = 0;
   virtual const unsigned char* getCodeInternal() = 0;
+  virtual size_t getSize() const = 0;
   template <typename FUNC>
   const FUNC getCode() {
     const unsigned char* code = this->getCodeInternal();

@@ -58,14 +56,17 @@ class JitBase {
     }
     return reinterpret_cast<const FUNC>(code);
   }
-  DISABLE_COPY_AND_ASSIGN(JitBase);
 
  protected:
-  void dumpCode(const unsigned char* code);
+  void dumpCode(const unsigned char* code) const;
 };
 
-template <KernelType KT, typename Attr>
-std::shared_ptr<const JitBase> CreateJitCode(Attr attr);
+template <KernelType KT, typename T, typename Attr>
+std::unique_ptr<JitBase> CreateJitCode(Attr attr);
+//{
+//   if (UseJitCode<KT,T,Attr>) {
+//     return make_unique<xxxxclass>(attr, CodeSize<KT,T,Attr>());
+//   }
+// }
 
 }  // namespace jitkernels
 }  // namespace operators
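GetKey is only declared above; as a hypothetical sketch (not in this commit), a kernel whose attribute is a plain vector length could specialize it like this, and Get() in kernels.h would then use the result as the JitCodePool key:

// Hypothetical specialization: the length itself serves as the cache key.
template <>
size_t GetKey<int>(int d) {
  return static_cast<size_t>(d);
}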
paddle/fluid/operators/jitkernels/kernels.h
@@ -31,6 +31,9 @@ namespace jitkernels {
 template <KernelType KT>
 class JitCodePool {
+  typedef std::unique_ptr<JitBase> JitBasePtr;
+  typedef std::unordered_map<size_t, JitBasePtr> JitBaseMap;
+
  public:
   JitCodePool() = default;
   static JitCodePool& Instance() {

@@ -38,29 +41,26 @@ class JitCodePool {
     return g_jit_codes;
   }
 
-  std::shared_ptr<const JitBase> Get(size_t key) const {
-    if (codes_.find(key) == codes_.end()) {
-      return nullptr;
-    }
-    return codes_.at(key);
-  }
+  const JitBaseMap& AllKernels() { return codes_; }
 
-  void Insert(size_t key, const std::shared_ptr<const JitBase>& value) {
-    codes_.insert({key, value});
+  bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
+
+  void Insert(size_t key, JitBasePtr value) {
+    codes_.emplace(key, std::move(value));
   }
 
  private:
-  std::unordered_map<size_t, std::shared_ptr<const JitBase>> codes_;
+  JitBaseMap codes_;
   DISABLE_COPY_AND_ASSIGN(JitCodePool);
 };
 
 // TODO(TJ): std::tuple<T, Func, Attr>
-template <typename T, typename Func, typename Attr>
-struct KernelAttr {
-  typedef T data_type;
-  typedef Func return_type;
-  typedef Attr attr_type;
-};
+// template <typename T, typename Func, typename Attr>
+// struct KernelAttr {
+//   typedef T data_type;
+//   typedef Func return_type;
+//   typedef Attr attr_type;
+// };
 
 typedef std::unique_ptr<const Kernel> KernelPtr;
 typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>

@@ -123,20 +123,21 @@ inline Func GetRefer() {
 // TODO(TJ): make tuple? named KernelAttr
 template <KernelType KT, typename T, typename Func, typename Attr,
           typename PlaceType = platform::CPUPlace>
-Func Get(Attr attr) {
-  // size_t key = GetKey<Attr>(attr);
-  // auto jitcode = JitCodePool<KT>().Instance().Get(key);
-  // if (jitcode) {
-  //   return jitcode->template getCode<Func>();
-  // }
+const Func Get(Attr attr) {
+  size_t key = GetKey<Attr>(attr);
+  auto& codes = JitCodePool<KT>().Instance();
+  if (codes.Has(key)) {
+    return codes.AllKernels().at(key)->template getCode<Func>();
+  }
 
-  if (std::is_same<PlaceType, platform::CPUPlace>::value) {
-    // TODO(TJ): float
-    // move to create
-    // auto p = CreateJitCode<KT, Attr>(attr);
-    // if (p) {
-    //   JitCodePool<KT>().Instance().Insert(key, p);
-    //   return p->template getCode<Func>();
-    // }
+  if (std::is_same<PlaceType, platform::CPUPlace>::value &&
+      std::is_same<T, float>::value) {  // TODO(TJ): float move to create
+    auto p = CreateJitCode<KT, T, Attr>(attr);
+    if (p) {
+      auto f = p->template getCode<Func>();
+      codes.Insert(key, std::move(p));
+      return f;
+    }
   }
 
 // pool: (KernelKey(type, place), vector<Kernel>)
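Putting the pieces together (a hedged sketch, not part of this commit): a call through the rewritten Get<>() first hits the JitCodePool cache, then generates and inserts code via CreateJitCode on a miss. The vmul function signature is the same assumption as in the blas.cc note above, and the tail of Get() that falls back to the refer/more kernels lies outside this hunk.

// Hypothetical call site for the cache-then-generate path in Get<>().
#include "paddle/fluid/operators/jitkernels/kernels.h"

namespace jk = paddle::operators::jitkernels;

void elementwise_mul(const float* x, const float* y, float* z, int n) {
  // Assumed generated-kernel signature: (src1, src2, dst), with n fixed per key.
  typedef void (*vmul_func)(const float*, const float*, float*);
  // The first call for a given n generates and caches the code; later calls
  // with the same key reuse the cached JitBase.
  auto f = jk::Get<jk::KernelType::vmul, float, vmul_func, int>(n);
  f(x, y, z);
}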
paddle/fluid/platform/cpu_info.h
@@ -39,7 +39,7 @@ size_t CUDAPinnedMinChunkSize();
 //! Get the maximum chunk size for buddy allocator.
 size_t CUDAPinnedMaxChunkSize();
 
-namespace jit {
+namespace jit {  // remove this namespace
 typedef enum {
   isa_any,
   sse42,