Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
50945685
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
50945685
编写于
1月 28, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add hmax, hsum jitcode
test=develop
上级
81177258
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
200 addition
and
1 deletion
+200
-1
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+2
-0
paddle/fluid/operators/jit/gen/hopv.cc
paddle/fluid/operators/jit/gen/hopv.cc
+103
-0
paddle/fluid/operators/jit/gen/hopv.h
paddle/fluid/operators/jit/gen/hopv.h
+90
-0
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+1
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+4
-1
未找到文件。
paddle/fluid/operators/jit/gen/CMakeLists.txt
浏览文件 @
50945685
...
@@ -28,3 +28,5 @@ USE_JITKERNEL_GEN(kGRUHtPart1)
...
@@ -28,3 +28,5 @@ USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN
(
kGRUHtPart2
)
USE_JITKERNEL_GEN
(
kGRUHtPart2
)
USE_JITKERNEL_GEN
(
kNCHW16CMulNC
)
USE_JITKERNEL_GEN
(
kNCHW16CMulNC
)
USE_JITKERNEL_GEN
(
kSeqPool
)
USE_JITKERNEL_GEN
(
kSeqPool
)
USE_JITKERNEL_GEN
(
kHMax
)
USE_JITKERNEL_GEN
(
kHSum
)
paddle/fluid/operators/jit/gen/hopv.cc
0 → 100644
浏览文件 @
50945685
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
void
HOPVJitCode
::
genCode
()
{
const
int
num_blocks
=
num_
/
YMM_FLOAT_BLOCK
;
int
offset
=
0
;
if
(
num_blocks
>
0
)
{
// load one firstly
vmovups
(
ymm_tmp
,
ptr
[
param_src
]);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
for
(
int
i
=
1
;
i
<
num_blocks
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param_src
+
offset
]);
process
(
ymm_tmp
,
ymm_src
,
ymm_tmp
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
vextractf128
(
xmm_dst
,
ymm_tmp
,
1
);
process
(
xmm_dst
,
xmm_dst
,
xmm_tmp
);
}
else
{
if
(
type_
==
operand_type
::
MAX
)
{
vbroadcastss
(
ymm_dst
,
ptr
[
param_src
]);
}
else
if
(
type_
==
operand_type
::
ADD
)
{
vxorps
(
ymm_dst
,
ymm_dst
,
ymm_dst
);
}
}
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param_src
+
offset
]);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
process
(
xmm_dst
,
xmm_dst
,
xmm_src
);
}
vpermilps
(
xmm_tmp
,
xmm_dst
,
16
+
8
+
3
);
process
(
xmm_dst
,
xmm_dst
,
xmm_tmp
);
if
(
rest
>=
2
)
{
vmovq
(
xmm_src
,
ptr
[
param_src
+
offset
]);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
process
(
xmm_dst
,
xmm_dst
,
xmm_src
);
}
vpermilps
(
xmm_tmp
,
xmm_dst
,
1
);
process
(
xmm_dst
,
xmm_dst
,
xmm_tmp
);
if
(
rest
>=
1
)
{
vmovss
(
xmm_src
,
ptr
[
param_src
+
offset
]);
process
(
xmm_dst
,
xmm_dst
,
xmm_src
);
}
vmovss
(
ptr
[
param_dst
],
xmm_dst
);
ret
();
}
#define DECLARE_HOP_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_HOP_CREATOR
(
HMax
);
DECLARE_HOP_CREATOR
(
HSum
);
#undef DECLARE_HOP_CREATOR
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kHMax
,
gen
::
HMaxCreator
);
REGISTER_JITKERNEL_GEN
(
kHSum
,
gen
::
HSumCreator
);
paddle/fluid/operators/jit/gen/hopv.h
0 → 100644
浏览文件 @
50945685
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
// horizontal operand vector
class
HOPVJitCode
:
public
JitCode
{
public:
explicit
HOPVJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{
if
(
!
(
type_
==
operand_type
::
MAX
||
type_
==
operand_type
::
ADD
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
this
->
genCode
();
}
virtual
const
char
*
name
()
const
{
std
::
string
base
=
"VXXJitCode"
;
if
(
type_
==
operand_type
::
MAX
)
{
base
+=
"_MAX"
;
}
else
{
base
+=
"_SUM"
;
}
return
base
.
c_str
();
}
void
genCode
()
override
;
protected:
template
<
typename
JMM
>
void
process
(
JMM
&
dst
,
JMM
&
src1
,
JMM
&
src2
)
{
// NOLINT
if
(
type_
==
operand_type
::
MAX
)
{
vmaxps
(
dst
,
src1
,
src2
);
}
else
if
(
type_
==
operand_type
::
ADD
)
{
vaddps
(
dst
,
src1
,
src2
);
}
}
private:
int
num_
;
operand_type
type_
;
reg64_t
param_src
{
abi_param1
};
reg64_t
param_dst
{
abi_param2
};
reg64_t
param_attr
{
abi_param3
};
ymm_t
ymm_tmp
=
ymm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
2
);
xmm_t
xmm_tmp
=
xmm_t
(
0
);
xmm_t
xmm_src
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
2
);
};
#define DECLARE_HOP_JITCODE(name, op_type) \
class name##JitCode : public HOPVJitCode { \
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: HOPVJitCode(d, op_type, code_size, code_ptr) {} \
};
DECLARE_HOP_JITCODE
(
HMax
,
operand_type
::
MAX
);
DECLARE_HOP_JITCODE
(
HSum
,
operand_type
::
ADD
);
#undef DECLARE_HOP_JITCODE
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
50945685
...
@@ -47,6 +47,7 @@ using Label = Xbyak::Label;
...
@@ -47,6 +47,7 @@ using Label = Xbyak::Label;
typedef
enum
{
typedef
enum
{
MUL
=
0
,
MUL
=
0
,
MAX
,
ADD
,
ADD
,
SUB
,
SUB
,
RELU
,
RELU
,
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
50945685
...
@@ -383,16 +383,19 @@ void TestAXYNKernel() {
...
@@ -383,16 +383,19 @@ void TestAXYNKernel() {
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestXRNKernel
()
{
void
TestXRNKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
auto
last_acc
=
acc
;
acc
=
1e-4
;
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
XRNTuples
<
T
>>
();
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
XRNTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
);
std
::
vector
<
T
>
x
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
x
.
data
()
,
-
2.
f
,
2.
f
);
T
ref_res
;
T
ref_res
;
ref
(
x
.
data
(),
&
ref_res
,
d
);
ref
(
x
.
data
(),
&
ref_res
,
d
);
TestAllImpls
<
KT
,
jit
::
XRNTuples
<
T
>
,
PlaceType
,
std
::
vector
<
T
>
,
T
>
(
d
,
x
,
TestAllImpls
<
KT
,
jit
::
XRNTuples
<
T
>
,
PlaceType
,
std
::
vector
<
T
>
,
T
>
(
d
,
x
,
ref_res
);
ref_res
);
}
}
acc
=
last_acc
;
}
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录