Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
fd0a954f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fd0a954f
编写于
12月 14, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable blas jitcode vmul, vadd, vaddrelu, vscal and vaddbias
上级
5e97be7b
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
52 addition
and
16 deletion
+52
-16
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+5
-0
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+27
-11
paddle/fluid/operators/jit/gen/blas.h
paddle/fluid/operators/jit/gen/blas.h
+20
-5
未找到文件。
paddle/fluid/operators/jit/gen/CMakeLists.txt
浏览文件 @
fd0a954f
...
@@ -10,3 +10,8 @@ endfunction()
...
@@ -10,3 +10,8 @@ endfunction()
# use gen jitcode kernel by name
# use gen jitcode kernel by name
USE_JITKERNEL_GEN
(
vmul
)
USE_JITKERNEL_GEN
(
vmul
)
USE_JITKERNEL_GEN
(
vadd
)
#USE_JITKERNEL_GEN(vsub) # TODO(TJ): enable me
USE_JITKERNEL_GEN
(
vaddrelu
)
USE_JITKERNEL_GEN
(
vscal
)
USE_JITKERNEL_GEN
(
vaddbias
)
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
fd0a954f
...
@@ -104,18 +104,28 @@ void VXXJitCode::genCode() {
...
@@ -104,18 +104,28 @@ void VXXJitCode::genCode() {
ret
();
ret
();
}
}
class
VMulCreator
:
public
JitCodeCreator
<
int
>
{
#define DECLARE_BLAS_CREATOR(name) \
public:
class name##Creator : public JitCodeCreator<int> { \
bool
UseMe
(
const
int
&
attr
)
const
override
{
public: \
return
platform
::
MayIUse
(
platform
::
avx
);
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
}
size_t
CodeSize
(
const
int
&
d
)
const
override
{
return
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
DECLARE_BLAS_CREATOR
(
VMul
);
}
DECLARE_BLAS_CREATOR
(
VAdd
);
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
int
&
attr
)
const
override
{
DECLARE_BLAS_CREATOR
(
VSub
);
return
make_unique
<
VMulJitCode
>
(
attr
,
CodeSize
(
attr
));
DECLARE_BLAS_CREATOR
(
VAddRelu
);
}
DECLARE_BLAS_CREATOR
(
VScal
);
};
DECLARE_BLAS_CREATOR
(
VAddBias
);
#undef DECLARE_BLAS_CREATOR
}
// namespace gen
}
// namespace gen
}
// namespace jit
}
// namespace jit
...
@@ -125,3 +135,9 @@ class VMulCreator : public JitCodeCreator<int> {
...
@@ -125,3 +135,9 @@ class VMulCreator : public JitCodeCreator<int> {
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
vmul
,
gen
::
VMulCreator
);
REGISTER_JITKERNEL_GEN
(
vmul
,
gen
::
VMulCreator
);
REGISTER_JITKERNEL_GEN
(
vadd
,
gen
::
VAddCreator
);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(vsub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN
(
vaddrelu
,
gen
::
VAddReluCreator
);
REGISTER_JITKERNEL_GEN
(
vscal
,
gen
::
VScalCreator
);
REGISTER_JITKERNEL_GEN
(
vaddbias
,
gen
::
VAddBiasCreator
);
paddle/fluid/operators/jit/gen/blas.h
浏览文件 @
fd0a954f
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <string>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -33,6 +34,9 @@ class VXXJitCode : public JitCode {
...
@@ -33,6 +34,9 @@ class VXXJitCode : public JitCode {
type_
(
type
),
type_
(
type
),
scalar_index_
(
scalar_index
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{
with_relu_
(
with_relu
)
{
if
(
!
(
type_
==
operand_type
::
mul
||
type_
==
operand_type
::
add
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
this
->
genCode
();
this
->
genCode
();
}
}
...
@@ -78,11 +82,22 @@ class VXXJitCode : public JitCode {
...
@@ -78,11 +82,22 @@ class VXXJitCode : public JitCode {
ymm_t
ymm_zero
=
ymm_t
(
3
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
};
class
VMulJitCode
:
public
VXXJitCode
{
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
public:
class name##JitCode : public VXXJitCode { \
explicit
VMulJitCode
(
int
d
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
public: \
:
VXXJitCode
(
d
,
operand_type
::
mul
,
0
,
false
,
code_size
,
code_ptr
)
{}
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
};
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE
(
VMul
,
operand_type
::
mul
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VAdd
,
operand_type
::
add
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VSub
,
operand_type
::
sub
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VAddRelu
,
operand_type
::
add
,
0
,
true
);
DECLARE_BLAS_JITCODE
(
VScal
,
operand_type
::
mul
,
1
,
false
);
DECLARE_BLAS_JITCODE
(
VAddBias
,
operand_type
::
add
,
1
,
false
);
#undef DECLARE_BLAS_JITCODE
}
// namespace gen
}
// namespace gen
}
// namespace jit
}
// namespace jit
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录