Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
b9acbcc8
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b9acbcc8
编写于
9月 18, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init lstm kernel
上级
c260bf94
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
63 addition
and
11 deletion
+63
-11
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+39
-1
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+23
-4
paddle/fluid/operators/math/jit_kernel_impl.h
paddle/fluid/operators/math/jit_kernel_impl.h
+1
-6
未找到文件。
paddle/fluid/operators/math/jit_kernel.cc
浏览文件 @
b9acbcc8
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <functional>
#include <string>
#include <string>
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -25,13 +28,48 @@ KernelPool& KernelPool::Instance() {
...
@@ -25,13 +28,48 @@ KernelPool& KernelPool::Instance() {
return
g_jit_kernels
;
return
g_jit_kernels
;
}
}
template
<
>
LSTMKernel
<
float
>::
LSTMKernel
(
int
d
,
const
std
::
string
&
act_gate_str
,
const
std
::
string
&
act_cand_str
,
const
std
::
string
&
act_cell_str
)
:
Kernel
(),
d_
(
d
)
{
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx512_common
))
{
math
::
VecActivations
<
float
,
platform
::
jit
::
avx512_common
>
act_functor
;
act_gate_
=
act_functor
(
act_gate_str
);
act_cell_
=
act_functor
(
act_cell_str
);
act_cand_
=
act_functor
(
act_cand_str
);
}
else
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx2
))
{
math
::
VecActivations
<
float
,
platform
::
jit
::
avx2
>
act_functor
;
act_gate_
=
act_functor
(
act_gate_str
);
act_cell_
=
act_functor
(
act_cell_str
);
act_cand_
=
act_functor
(
act_cand_str
);
}
else
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
math
::
VecActivations
<
float
,
platform
::
jit
::
avx
>
act_functor
;
act_gate_
=
act_functor
(
act_gate_str
);
act_cell_
=
act_functor
(
act_cell_str
);
act_cand_
=
act_functor
(
act_cand_str
);
}
else
{
math
::
VecActivations
<
float
,
platform
::
jit
::
isa_any
>
act_functor
;
act_gate_
=
act_functor
(
act_gate_str
);
act_cell_
=
act_functor
(
act_cell_str
);
act_cand_
=
act_functor
(
act_cand_str
);
}
}
template
<
>
template
<
>
const
std
::
shared_ptr
<
LSTMKernel
<
float
>>
const
std
::
shared_ptr
<
LSTMKernel
<
float
>>
KernelPool
::
Get
<
LSTMKernel
<
float
>
,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
KernelPool
::
Get
<
LSTMKernel
<
float
>
,
int
,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
int
d
,
const
std
::
string
&
act_gate
,
const
std
::
string
&>
(
int
d
,
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_cand
,
const
std
::
string
&
act_cand
,
const
std
::
string
&
act_cell
)
{
const
std
::
string
&
act_cell
)
{
return
nullptr
;
std
::
string
key
=
"f"
+
std
::
to_string
(
d
)
+
act_gate
+
act_cand
+
act_cell
;
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
auto
p
=
std
::
make_shared
<
LSTMKernel
<
float
>>
(
d
,
act_gate
,
act_cand
,
act_cell
);
kers_
.
insert
({
key
,
std
::
dynamic_pointer_cast
<
Kernel
>
(
p
)});
return
p
;
}
return
std
::
dynamic_pointer_cast
<
LSTMKernel
<
float
>>
(
kers_
.
at
(
key
));
}
}
}
// namespace jitkernel
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
b9acbcc8
...
@@ -14,10 +14,9 @@ limitations under the License. */
...
@@ -14,10 +14,9 @@ limitations under the License. */
#pragma once
#pragma once
#include <functional>
#include <functional>
#include <map>
#include <memory> // for shared_ptr
#include <memory> // for shared_ptr
#include <string>
#include <string>
#include <
vector
>
#include <
unordered_map
>
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/macros.h"
// Note: Only support on CPU yet.
// Note: Only support on CPU yet.
...
@@ -27,23 +26,43 @@ namespace math {
...
@@ -27,23 +26,43 @@ namespace math {
namespace
jitkernel
{
namespace
jitkernel
{
class
Kernel
{
class
Kernel
{
public:
Kernel
()
{}
virtual
~
Kernel
()
=
default
;
private:
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
};
};
class
KernelPool
{
class
KernelPool
{
public:
public:
static
KernelPool
&
Instance
();
static
KernelPool
&
Instance
();
template
<
typename
Ker
,
typename
...
ARGS
>
template
<
typename
Ker
,
typename
...
ARGS
>
const
std
::
shared_ptr
<
Ker
>
Get
(
ARGS
...
args
);
const
std
::
shared_ptr
<
Ker
>
Get
(
ARGS
...
args
);
private:
private:
KernelPool
()
=
default
;
KernelPool
()
=
default
;
// std::unordered_map<std::string, Kernel
> kers_;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Kernel
>
>
kers_
;
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
};
};
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
public:
explicit
LSTMKernel
(
int
d
,
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_cand
,
const
std
::
string
&
act_cell
);
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
);
void
ComputeCtHt_NoC0H0
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
);
private:
int
d_
;
std
::
function
<
void
(
const
int
,
const
T
*
,
T
*
)
>
act_gate_
,
act_cell_
,
act_cand_
;
};
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel_impl.h
浏览文件 @
b9acbcc8
...
@@ -21,12 +21,7 @@ limitations under the License. */
...
@@ -21,12 +21,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
namespace
jitkernel
{
namespace
jitkernel
{}
// namespace jitkernel
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{};
}
// namespace jitkernel
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录