Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
53709e7e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
53709e7e
编写于
12月 06, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine names
上级
ce674b68
变更
24
显示空白变更内容
内联
并排
Showing
24 changed file
with
96 addition
and
95 deletion
+96
-95
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-2
paddle/fluid/operators/jit/CMakeLists.txt
paddle/fluid/operators/jit/CMakeLists.txt
+3
-3
paddle/fluid/operators/jit/README.md
paddle/fluid/operators/jit/README.md
+1
-1
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+0
-0
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+9
-8
paddle/fluid/operators/jit/gen/blas.h
paddle/fluid/operators/jit/gen/blas.h
+5
-5
paddle/fluid/operators/jit/gen/jitcode.cc
paddle/fluid/operators/jit/gen/jitcode.cc
+5
-5
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+6
-6
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+4
-4
paddle/fluid/operators/jit/gen_base.h
paddle/fluid/operators/jit/gen_base.h
+5
-5
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+2
-2
paddle/fluid/operators/jit/kernel_key.h
paddle/fluid/operators/jit/kernel_key.h
+3
-3
paddle/fluid/operators/jit/kernel_pool.cc
paddle/fluid/operators/jit/kernel_pool.cc
+3
-3
paddle/fluid/operators/jit/kernel_pool.h
paddle/fluid/operators/jit/kernel_pool.h
+10
-10
paddle/fluid/operators/jit/more/CMakeLists.txt
paddle/fluid/operators/jit/more/CMakeLists.txt
+0
-0
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+0
-0
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+5
-5
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+3
-3
paddle/fluid/operators/jit/more/more.h
paddle/fluid/operators/jit/more/more.h
+0
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+0
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+3
-3
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+3
-3
paddle/fluid/operators/jit/registry.h
paddle/fluid/operators/jit/registry.h
+21
-21
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+3
-3
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
53709e7e
...
...
@@ -16,7 +16,7 @@ add_subdirectory(metrics)
add_subdirectory
(
optimizers
)
add_subdirectory
(
reduce_ops
)
add_subdirectory
(
sequence_ops
)
add_subdirectory
(
jit
kernels
)
add_subdirectory
(
jit
)
if
(
WITH_DISTRIBUTE
)
add_subdirectory
(
distributed
)
...
...
@@ -68,7 +68,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_ten
if
(
NOT WIN32
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
dynload_warpctc
)
endif
()
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel
_helper
concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions
)
if
(
WITH_GPU
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv prelu
)
...
...
paddle/fluid/operators/jit
kernels
/CMakeLists.txt
→
paddle/fluid/operators/jit/CMakeLists.txt
浏览文件 @
53709e7e
...
...
@@ -14,8 +14,8 @@ cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
add_subdirectory
(
refer
)
add_subdirectory
(
more
)
if
(
WITH_XBYAK
)
add_subdirectory
(
jitcode
)
add_subdirectory
(
gen
)
endif
()
cc_library
(
jit_kernel SRCS
${
jit_kernel_cc_srcs
}
DEPS
${
JIT_KERNEL_DEPS
}
)
cc_test
(
jit_kernel_test SRCS test.cc DEPS jit_kernel
)
cc_library
(
jit_kernel
_helper
SRCS
${
jit_kernel_cc_srcs
}
DEPS
${
JIT_KERNEL_DEPS
}
)
cc_test
(
jit_kernel_test SRCS test.cc DEPS jit_kernel
_helper
)
paddle/fluid/operators/jit
kernels
/README.md
→
paddle/fluid/operators/jit/README.md
浏览文件 @
53709e7e
...
...
@@ -13,7 +13,7 @@ PaddlePaddle/Paddle/paddle/fluid/
│ ├── .../
└── jit/
├── ...
├──
jitcode
/
├──
gen
/
│ └── ...
|── more/
│ ├── ...
...
...
paddle/fluid/operators/jit
kernels/jitcode
/CMakeLists.txt
→
paddle/fluid/operators/jit
/gen
/CMakeLists.txt
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels/jitcode
/blas.cc
→
paddle/fluid/operators/jit
/gen
/blas.cc
浏览文件 @
53709e7e
...
...
@@ -11,13 +11,14 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jitkernels/jitcode/blas.h"
#include "paddle/fluid/operators/jitkernels/registry.h"
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jitcode
{
namespace
jit
{
namespace
gen
{
void
VXXJitCode
::
genCode
()
{
// do not need push stack, and do not need save avx512reg if do not use avx512
...
...
@@ -102,17 +103,17 @@ void VXXJitCode::genCode() {
ret
();
}
}
// namespace
jitcode
}
// namespace
gen
template
<
>
std
::
unique_ptr
<
Jit
Base
>
CreateJitCode
<
KernelType
::
vmul
,
float
,
int
>
(
int
attr
)
{
std
::
unique_ptr
<
Gen
Base
>
CreateJitCode
<
KernelType
::
vmul
,
float
,
int
>
(
int
attr
)
{
if
(
UseJitCode
<
KernelType
::
vmul
,
float
,
int
>
(
attr
))
{
return
make_unique
<
jitcode
::
VMulJitCode
>
(
return
make_unique
<
gen
::
VMulJitCode
>
(
attr
,
CodeSize
<
KernelType
::
vmul
,
float
,
int
>
(
attr
));
}
return
nullptr
;
}
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
/blas.h
→
paddle/fluid/operators/jit
/gen
/blas.h
浏览文件 @
53709e7e
...
...
@@ -15,12 +15,12 @@
#pragma once
#include <string>
#include "paddle/fluid/operators/jit
kernels/jitcode
/jitcode.h"
#include "paddle/fluid/operators/jit
/gen
/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jitcode
{
namespace
jit
{
namespace
gen
{
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
...
...
@@ -82,7 +82,7 @@ class VMulJitCode : public VXXJitCode {
:
VXXJitCode
(
d
,
operand_type
::
mul
,
0
,
false
,
code_size
,
code_ptr
)
{}
};
}
// namespace
jitcode
}
// namespace jit
kernels
}
// namespace
gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
/jitcode.cc
→
paddle/fluid/operators/jit
/gen
/jitcode.cc
浏览文件 @
53709e7e
...
...
@@ -12,11 +12,11 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels/jitcode
/jitcode.h"
#include "paddle/fluid/operators/jit
/gen
/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
template
<
>
size_t
GetKey
<
int
>
(
int
d
)
{
...
...
@@ -24,15 +24,15 @@ size_t GetKey<int>(int d) {
}
// template <>
// std::shared_ptr<const
Jit
Base> CreateJitCode<KernelType::vmul, int>(int attr)
// std::shared_ptr<const
Gen
Base> CreateJitCode<KernelType::vmul, int>(int attr)
// {
// if (UseJitCode<KernelType::vmul, int>(attr)) {
// return std::make_shared<
jitcode
::VMulJitCode<int>>(attr,
// return std::make_shared<
gen
::VMulJitCode<int>>(attr,
// CodeSize<KernelType::vmul, int>(attr)));
// }
// return nullptr;
// }
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
/jitcode.h
→
paddle/fluid/operators/jit
/gen
/jitcode.h
浏览文件 @
53709e7e
...
...
@@ -15,7 +15,7 @@
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit
kernels/jitcode
_base.h"
#include "paddle/fluid/operators/jit
/gen
_base.h"
#include "paddle/fluid/platform/cpu_info.h"
#define XBYAK_USE_MMAP_ALLOCATOR
...
...
@@ -24,8 +24,8 @@
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jitcode
{
namespace
jit
{
namespace
gen
{
// Application Binary Interface
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
...
...
@@ -67,7 +67,7 @@ typedef enum {
#define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; }
class
JitCode
:
public
Jit
Base
,
public
Xbyak
::
CodeGenerator
{
class
JitCode
:
public
Gen
Base
,
public
Xbyak
::
CodeGenerator
{
public:
explicit
JitCode
(
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{
...
...
@@ -128,7 +128,7 @@ class JitCode : public JitBase, public Xbyak::CodeGenerator {
}
};
}
// namespace
jitcode
}
// namespace jit
kernels
}
// namespace
gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
_base.cc
→
paddle/fluid/operators/jit
/gen
_base.cc
浏览文件 @
53709e7e
...
...
@@ -12,7 +12,7 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels/jitcode
_base.h"
#include "paddle/fluid/operators/jit
/gen
_base.h"
#include <fstream>
#include <iostream>
#include <sstream>
...
...
@@ -21,10 +21,10 @@ DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
// refer do not need useme, it would be the last one.
void
Jit
Base
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
void
Gen
Base
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
if
(
code
)
{
static
int
counter
=
0
;
std
::
ostringstream
filename
;
...
...
@@ -38,6 +38,6 @@ void JitBase::dumpCode(const unsigned char* code) const {
}
}
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels/jitcode
_base.h
→
paddle/fluid/operators/jit
/gen
_base.h
浏览文件 @
53709e7e
...
...
@@ -16,14 +16,14 @@
#include <gflags/gflags.h>
#include <memory> // for shared_ptr
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/macros.h"
DECLARE_bool
(
dump_jitcode
);
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
// TODO(TJ): make these functions as virtual of a class
...
...
@@ -43,7 +43,7 @@ bool UseJitCode(Attr attr) {
template
<
typename
Attr
>
size_t
GetKey
(
Attr
attr
);
class
Jit
Base
:
public
Kernel
{
class
Gen
Base
:
public
Kernel
{
public:
virtual
const
char
*
name
()
const
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
=
0
;
...
...
@@ -62,8 +62,8 @@ class JitBase : public Kernel {
};
template
<
KernelType
KT
,
typename
T
,
typename
Attr
>
std
::
unique_ptr
<
Jit
Base
>
CreateJitCode
(
Attr
attr
);
std
::
unique_ptr
<
Gen
Base
>
CreateJitCode
(
Attr
attr
);
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/kernel_base.h
→
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
53709e7e
...
...
@@ -17,7 +17,7 @@
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
typedef
enum
{
vmul
=
0
,
vadd
=
1
,
vsub
,
vexp
}
KernelType
;
...
...
@@ -54,6 +54,6 @@ class ReferKernel : public KernelImpl<T, Func, Attr> {
bool
UseMe
(
Attr
attr
)
const
override
{
return
true
;
}
};
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/kernel_key.h
→
paddle/fluid/operators/jit/kernel_key.h
浏览文件 @
53709e7e
...
...
@@ -13,12 +13,12 @@
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
struct
KernelKey
{
struct
Hash
{
...
...
@@ -44,6 +44,6 @@ struct KernelKey {
bool
operator
!=
(
const
KernelKey
&
o
)
const
{
return
!
(
*
this
==
o
);
}
};
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/kernel_pool.cc
→
paddle/fluid/operators/jit/kernel_pool.cc
浏览文件 @
53709e7e
...
...
@@ -12,14 +12,14 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels
/kernel_pool.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
KernelPool
&
KernelPool
::
Instance
()
{
static
KernelPool
g_kernel_pool
;
...
...
@@ -31,6 +31,6 @@ ReferKernelPool& ReferKernelPool::Instance() {
return
g_refer_kernel_pool
;
}
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/kernel_pool.h
→
paddle/fluid/operators/jit/kernel_pool.h
浏览文件 @
53709e7e
...
...
@@ -18,19 +18,19 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit
kernels/jitcode
_base.h"
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit
kernels
/kernel_key.h"
#include "paddle/fluid/operators/jit
/gen
_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
template
<
KernelType
KT
>
class
JitCodePool
{
typedef
std
::
unique_ptr
<
JitBase
>
Jit
BasePtr
;
typedef
std
::
unordered_map
<
size_t
,
JitBasePtr
>
JitBas
eMap
;
typedef
std
::
unique_ptr
<
GenBase
>
Gen
BasePtr
;
typedef
std
::
unordered_map
<
size_t
,
GenBasePtr
>
JitCod
eMap
;
public:
JitCodePool
()
=
default
;
...
...
@@ -39,16 +39,16 @@ class JitCodePool {
return
g_jit_codes
;
}
const
Jit
Bas
eMap
&
AllKernels
()
{
return
codes_
;
}
const
Jit
Cod
eMap
&
AllKernels
()
{
return
codes_
;
}
bool
Has
(
size_t
key
)
const
{
return
codes_
.
find
(
key
)
!=
codes_
.
end
();
}
void
Insert
(
size_t
key
,
Jit
BasePtr
value
)
{
void
Insert
(
size_t
key
,
Gen
BasePtr
value
)
{
codes_
.
emplace
(
key
,
std
::
move
(
value
));
}
private:
Jit
Bas
eMap
codes_
;
Jit
Cod
eMap
codes_
;
DISABLE_COPY_AND_ASSIGN
(
JitCodePool
);
};
...
...
@@ -146,6 +146,6 @@ const Func Get(Attr attr) {
return
GetRefer
<
KT
,
T
,
Func
,
Attr
>
();
}
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/more/CMakeLists.txt
→
paddle/fluid/operators/jit/more/CMakeLists.txt
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels
/more/mkl/CMakeLists.txt
→
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels
/more/mkl/mkl.cc
→
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
53709e7e
...
...
@@ -12,13 +12,13 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels
/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit
kernels
/registry.h"
#include "paddle/fluid/operators/jit/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
namespace
more
{
namespace
mkl
{
...
...
@@ -34,11 +34,11 @@ void VMul<double>(const double* x, const double* y, double* z, int n) {
}
// namespace mkl
}
// namespace more
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
mkl
=
paddle
::
operators
::
jit
kernels
::
more
::
mkl
;
namespace
mkl
=
paddle
::
operators
::
jit
::
more
::
mkl
;
REGISTER_JITKERNEL_MORE
(
vmul
,
mkl
,
mkl
::
VMulKernel
<
float
>
,
mkl
::
VMulKernel
<
double
>
);
paddle/fluid/operators/jit
kernels
/more/mkl/mkl.h
→
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
53709e7e
...
...
@@ -15,12 +15,12 @@
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
namespace
more
{
namespace
mkl
{
...
...
@@ -43,6 +43,6 @@ class VMulKernel : public KernelImpl<T, typename VMulTypes<T>::func_type,
}
// namespace mkl
}
// namespace more
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/more/more.h
→
paddle/fluid/operators/jit/more/more.h
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels
/refer/CMakeLists.txt
→
paddle/fluid/operators/jit/refer/CMakeLists.txt
浏览文件 @
53709e7e
文件已移动
paddle/fluid/operators/jit
kernels
/refer/refer.cc
→
paddle/fluid/operators/jit/refer/refer.cc
浏览文件 @
53709e7e
...
...
@@ -12,10 +12,10 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit
kernels
/refer/refer.h"
#include "paddle/fluid/operators/jit
kernels
/registry.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
namespace
refer
=
paddle
::
operators
::
jit
kernels
::
refer
;
namespace
refer
=
paddle
::
operators
::
jit
::
refer
;
REGISTER_JITKERNEL_REFER
(
vmul
,
refer
::
VMulKernel
<
float
>
,
refer
::
VMulKernel
<
double
>
);
paddle/fluid/operators/jit
kernels
/refer/refer.h
→
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
53709e7e
...
...
@@ -13,12 +13,12 @@
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
namespace
refer
{
template
<
typename
T
>
...
...
@@ -36,6 +36,6 @@ class VMulKernel : public ReferKernel<T, typename VMulTypes<T>::func_type,
};
}
// namespace refer
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/registry.h
→
paddle/fluid/operators/jit/registry.h
浏览文件 @
53709e7e
...
...
@@ -17,14 +17,14 @@
#include <memory>
#include <tuple>
#include <type_traits>
#include "paddle/fluid/operators/jit
kernels
/kernel_base.h"
#include "paddle/fluid/operators/jit
kernels
/kernel_pool.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" // for UNUSED
namespace
paddle
{
namespace
operators
{
namespace
jit
kernels
{
namespace
jit
{
// make_unique is supported since c++14
template
<
typename
T
,
typename
...
Args
>
...
...
@@ -80,17 +80,17 @@ class JitKernelRegistrar {
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
"REGISTER_KERNEL_REFER must be called in global namespace"); \
static ::paddle::operators::jit
kernels::JitKernelRegistrar<
\
::paddle::operators::jit
kernels::ReferKernelPool,
\
::paddle::platform::CPUPlace, __VA_ARGS__>
\
static ::paddle::operators::jit
::JitKernelRegistrar<
\
::paddle::operators::jit
::ReferKernelPool, ::paddle::platform::CPUPlace,
\
__VA_ARGS__>
\
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
::paddle::operators::jit
kernels::KernelType::kernel_type);
\
::paddle::operators::jit
::KernelType::kernel_type);
\
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
return 0; \
}
// kernel_type: should be in paddle::operators::jit
kernels
::KernelType
// kernel_type: should be in paddle::operators::jit::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
...
...
@@ -99,11 +99,11 @@ class JitKernelRegistrar {
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit
kernels::JitKernelRegistrar<
\
::paddle::operators::jit
kernels::KernelPool,
\
::paddle::platform::place_type, __VA_ARGS__>
\
static ::paddle::operators::jit
::JitKernelRegistrar<
\
::paddle::operators::jit
::KernelPool, ::paddle::platform::place_type,
\
__VA_ARGS__>
\
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::operators::jit
kernels::KernelType::kernel_type);
\
::paddle::operators::jit
::KernelType::kernel_type);
\
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \
...
...
@@ -139,6 +139,6 @@ class JitKernelRegistrar {
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
}
// namespace jit
kernels
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit
kernels
/test.cc
→
paddle/fluid/operators/jit/test.cc
浏览文件 @
53709e7e
...
...
@@ -19,9 +19,9 @@
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit
kernels
/kernel_pool.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
// TODO(TJ): remove me
#include "paddle/fluid/operators/jit
kernels
/registry.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
...
...
@@ -66,7 +66,7 @@ TEST(JitKernel, vmul) {
using
T
=
float
;
using
PlaceType
=
paddle
::
platform
::
CPUPlace
;
namespace
jit
=
paddle
::
operators
::
jit
kernels
;
namespace
jit
=
paddle
::
operators
::
jit
;
// TODO(TJ): test more vector size
for
(
int
d
=
1
;
d
<
30
;
++
d
)
{
auto
ref
=
jit
::
GetRefer
<
jit
::
vmul
,
T
,
jit
::
VMulTypes
<
T
>::
func_type
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录