Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1d85b2bd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
1d85b2bd
编写于
11月 04, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine GRU Operator according to activation_functions
上级
4b8bcf32
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
36 addition
and
81 deletion
+36
-81
paddle/operators/math/detail/gru_cpu_kernel.h
paddle/operators/math/detail/gru_cpu_kernel.h
+9
-13
paddle/operators/math/detail/gru_gpu_kernel.h
paddle/operators/math/detail/gru_gpu_kernel.h
+4
-8
paddle/operators/math/detail/gru_kernel.h
paddle/operators/math/detail/gru_kernel.h
+23
-60
未找到文件。
paddle/operators/math/detail/gru_cpu_kernel.h
浏览文件 @
1d85b2bd
...
@@ -14,7 +14,7 @@ limitations under the License. */
...
@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <type_traits>
#include <type_traits>
#include "paddle/operators/math/detail/
hl_
activation_functions.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/gru_compute.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -43,9 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
...
@@ -43,9 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
rPrevOut
=
prevOutputValue
[
i
];
rPrevOut
=
prevOutputValue
[
i
];
}
}
hppl
::
cpu
::
ForwardAct
<
T
>
act
;
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutput
,
act
(
active_gate
)
);
rValueResetOutput
,
act
ive_gate
);
updateGate
[
i
]
=
rValueUpdateGate
;
updateGate
[
i
]
=
rValueUpdateGate
;
resetGate
[
i
]
=
rValueResetGate
;
resetGate
[
i
]
=
rValueResetGate
;
...
@@ -72,9 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput,
...
@@ -72,9 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput,
rPrevOut
=
prevOutputValue
[
i
];
rPrevOut
=
prevOutputValue
[
i
];
}
}
hppl
::
cpu
::
ForwardAct
<
T
>
act
;
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
act
(
active_node
)
);
act
ive_node
);
frameState
[
i
]
=
rValueFrameState
;
frameState
[
i
]
=
rValueFrameState
;
outputValue
[
i
]
=
rOutput
;
outputValue
[
i
]
=
rOutput
;
...
@@ -102,7 +100,7 @@ void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue,
...
@@ -102,7 +100,7 @@ void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue,
}
}
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutput
,
hppl
::
avx
::
forward
[
active_gate
]
);
rValueResetOutput
,
active_gate
);
updateGate
[
i
]
=
rValueUpdateGate
;
updateGate
[
i
]
=
rValueUpdateGate
;
resetGate
[
i
]
=
rValueResetGate
;
resetGate
[
i
]
=
rValueResetGate
;
...
@@ -132,7 +130,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue,
...
@@ -132,7 +130,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue,
}
}
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
hppl
::
avx
::
forward
[
active_node
]
);
active_node
);
frameState
[
i
]
=
rValueFrameState
;
frameState
[
i
]
=
rValueFrameState
;
((
__m256
*
)
outputValue
)[
i
]
=
rOutput
;
((
__m256
*
)
outputValue
)[
i
]
=
rOutput
;
...
@@ -215,10 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
...
@@ -215,10 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
rPrevOutGrad
=
prevOutGrad
[
i
];
rPrevOutGrad
=
prevOutGrad
[
i
];
}
}
hppl
::
cpu
::
BackwardAct
<
T
>
act
;
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
act
(
active_node
)
);
act
ive_node
);
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
frameStateGrad
[
i
]
=
rFrameStateGrad
;
frameStateGrad
[
i
]
=
rFrameStateGrad
;
...
@@ -261,10 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
...
@@ -261,10 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
rPrevOutGrad
=
prevOutGrad
[
i
];
rPrevOutGrad
=
prevOutGrad
[
i
];
}
}
hppl
::
cpu
::
BackwardAct
<
T
>
act
;
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
act
(
active_gate
)
);
act
ive_gate
);
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
resetGateGrad
[
i
]
=
rResetGateGrad
;
resetGateGrad
[
i
]
=
rResetGateGrad
;
...
@@ -306,7 +302,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
...
@@ -306,7 +302,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
hppl
::
avx
::
backward
[
active_node
]
);
active_node
);
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
frameStateGrad
[
i
]
=
rFrameStateGrad
;
frameStateGrad
[
i
]
=
rFrameStateGrad
;
...
@@ -353,7 +349,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
...
@@ -353,7 +349,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
hppl
::
avx
::
backward
[
active_gate
]
);
active_gate
);
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
updateGateGrad
[
i
]
=
rUpdateGateGrad
;
resetGateGrad
[
i
]
=
rResetGateGrad
;
resetGateGrad
[
i
]
=
rResetGateGrad
;
...
...
paddle/operators/math/detail/gru_gpu_kernel.h
浏览文件 @
1d85b2bd
...
@@ -57,9 +57,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput,
...
@@ -57,9 +57,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput,
rPrevOut
=
prevOutputValue
[
frameIdx
];
rPrevOut
=
prevOutputValue
[
frameIdx
];
}
}
hppl
::
gpu
::
ForwardAct
<
T
>
act
;
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutput
,
opResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutput
,
act
(
active_gate
)
);
act
ive_gate
);
gateValue
[
frameIdx
+
frameSize
*
0
]
=
rValueUpdateGate
;
gateValue
[
frameIdx
+
frameSize
*
0
]
=
rValueUpdateGate
;
gateValue
[
frameIdx
+
frameSize
*
1
]
=
rValueResetGate
;
gateValue
[
frameIdx
+
frameSize
*
1
]
=
rValueResetGate
;
...
@@ -96,9 +95,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput,
...
@@ -96,9 +95,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput,
rPrevOut
=
prevOutputValue
[
frameIdx
];
rPrevOut
=
prevOutputValue
[
frameIdx
];
}
}
hppl
::
gpu
::
ForwardAct
<
T
>
act
;
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
opFinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutput
,
act
(
active_node
)
);
act
ive_node
);
gateValue
[
frameIdx
+
frameSize
*
2
]
=
rValueFrameState
;
gateValue
[
frameIdx
+
frameSize
*
2
]
=
rValueFrameState
;
outputValue
[
frameIdx
]
=
rOutput
;
outputValue
[
frameIdx
]
=
rOutput
;
...
@@ -141,10 +139,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue,
...
@@ -141,10 +139,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue,
rPrevOutGrad
=
prevOutGrad
[
frameIdx
];
rPrevOutGrad
=
prevOutGrad
[
frameIdx
];
}
}
hppl
::
gpu
::
BackwardAct
<
T
>
act
;
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
opStateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateValue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutGrad
,
act
(
active_node
)
);
act
ive_node
);
gateGrad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateGrad
;
gateGrad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateGrad
;
gateGrad
[
frameIdx
+
frameSize
*
2
]
=
rFrameStateGrad
;
gateGrad
[
frameIdx
+
frameSize
*
2
]
=
rFrameStateGrad
;
...
@@ -190,10 +187,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue,
...
@@ -190,10 +187,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue,
rResetOutputGrad
=
resetOutputGrad
[
frameIdx
];
rResetOutputGrad
=
resetOutputGrad
[
frameIdx
];
}
}
hppl
::
gpu
::
BackwardAct
<
T
>
act
;
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
opResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateValue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputGrad
,
act
(
active_gate
)
);
act
ive_gate
);
gateGrad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateGrad
;
gateGrad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateGrad
;
gateGrad
[
frameIdx
+
frameSize
*
1
]
=
rResetGateGrad
;
gateGrad
[
frameIdx
+
frameSize
*
1
]
=
rResetGateGrad
;
...
...
paddle/operators/math/detail/gru_kernel.h
浏览文件 @
1d85b2bd
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/operators/math/detail/
hl_
activation_functions.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/hostdevice.h"
#include "paddle/platform/hostdevice.h"
#include <type_traits>
#include <type_traits>
...
@@ -27,18 +27,10 @@ namespace forward {
...
@@ -27,18 +27,10 @@ namespace forward {
template
<
typename
T
>
template
<
typename
T
>
class
gru_resetOutput
{
class
gru_resetOutput
{
public:
public:
/**
* @param[in,out] valueUpdateGate update gate
* @param[in,out] valueResetGate reset gate
* @param[in] prevOut previous output
* @param[out] valueResetOutput intermediate value for frame state
* @param[in] actGate forward function of gate
*/
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueResetGate
,
T
&
prevOut
,
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueResetGate
,
T
&
prevOut
,
T
&
valueResetOutput
,
T
&
valueResetOutput
,
activation_mode_t
actGate
)
{
typename
hppl
::
Active
<
T
>::
forward
actGate
)
{
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
valueUpdateGate
=
actGate
(
valueUpdateGate
);
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
valueResetGate
=
actGate
(
valueResetGate
);
valueResetOutput
=
prevOut
*
valueResetGate
;
valueResetOutput
=
prevOut
*
valueResetGate
;
}
}
#ifndef __NVCC__
#ifndef __NVCC__
...
@@ -48,9 +40,9 @@ class gru_resetOutput {
...
@@ -48,9 +40,9 @@ class gru_resetOutput {
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueResetGate
,
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueResetGate
,
__m256
&
prevOut
,
__m256
&
valueResetOutput
,
__m256
&
prevOut
,
__m256
&
valueResetOutput
,
typename
hppl
::
Active
<
__m256
>::
forward
actGate
)
{
activation_mode_t
actGate
)
{
valueUpdateGate
=
act
Gate
(
valueUpdate
Gate
);
valueUpdateGate
=
act
ivation
(
valueUpdateGate
,
act
Gate
);
valueResetGate
=
act
Gate
(
valueRese
tGate
);
valueResetGate
=
act
ivation
(
valueResetGate
,
ac
tGate
);
valueResetOutput
=
_mm256_mul_ps
(
prevOut
,
valueResetGate
);
valueResetOutput
=
_mm256_mul_ps
(
prevOut
,
valueResetGate
);
}
}
#endif
#endif
...
@@ -60,17 +52,9 @@ class gru_resetOutput {
...
@@ -60,17 +52,9 @@ class gru_resetOutput {
template
<
typename
T
>
template
<
typename
T
>
class
gru_finalOutput
{
class
gru_finalOutput
{
public:
public:
/**
* @param[in] valueUpdateGate update gate
* @param[in,out] valueFrameState frame state ({\tilde{h}_t})
* @param[in] prevOut previous output
* @param[out] valueOutput output
* @param[in] actInput forward function of node
*/
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueFrameState
,
T
&
prevOut
,
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueFrameState
,
T
&
prevOut
,
T
&
valueOutput
,
T
&
valueOutput
,
activation_mode_t
actInput
)
{
typename
hppl
::
Active
<
T
>::
forward
actInput
)
{
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
valueFrameState
=
actInput
(
valueFrameState
);
valueOutput
=
prevOut
-
(
valueUpdateGate
*
prevOut
)
+
valueOutput
=
prevOut
-
(
valueUpdateGate
*
prevOut
)
+
(
valueUpdateGate
*
valueFrameState
);
(
valueUpdateGate
*
valueFrameState
);
}
}
...
@@ -81,8 +65,8 @@ class gru_finalOutput {
...
@@ -81,8 +65,8 @@ class gru_finalOutput {
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueFrameState
,
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueFrameState
,
__m256
&
prevOut
,
__m256
&
valueOutput
,
__m256
&
prevOut
,
__m256
&
valueOutput
,
typename
hppl
::
Active
<
__m256
>::
forward
actInput
)
{
activation_mode_t
actInput
)
{
valueFrameState
=
act
Input
(
valueFrameState
);
valueFrameState
=
act
ivation
(
valueFrameState
,
actInput
);
valueOutput
=
_mm256_add_ps
(
valueOutput
=
_mm256_add_ps
(
_mm256_sub_ps
(
prevOut
,
_mm256_mul_ps
(
valueUpdateGate
,
prevOut
)),
_mm256_sub_ps
(
prevOut
,
_mm256_mul_ps
(
valueUpdateGate
,
prevOut
)),
_mm256_mul_ps
(
valueUpdateGate
,
valueFrameState
));
_mm256_mul_ps
(
valueUpdateGate
,
valueFrameState
));
...
@@ -97,25 +81,16 @@ namespace backward {
...
@@ -97,25 +81,16 @@ namespace backward {
template
<
typename
T
>
template
<
typename
T
>
class
gru_stateGrad
{
class
gru_stateGrad
{
public:
public:
/**
* @param[in] valueUpdateGate update gate value
* @param[out] gradUpdateGate update gate grad
* @param[in] valueFrameState frame state value
* @param[out] gradFrameState frame state grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradOutput output grad
* @param[in] actInput backward function of frame state
*/
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
T
&
valueFrameState
,
T
&
gradFrameState
,
T
&
valueFrameState
,
T
&
gradFrameState
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
gradOutput
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
gradOutput
,
typename
hppl
::
Active
<
T
>::
backward
actInput
)
{
activation_mode_t
actInput
)
{
gradUpdateGate
=
(
gradOutput
*
valueFrameState
);
gradUpdateGate
=
(
gradOutput
*
valueFrameState
);
gradUpdateGate
-=
(
gradOutput
*
valuePrevOut
);
gradUpdateGate
-=
(
gradOutput
*
valuePrevOut
);
gradPrevOut
-=
(
gradOutput
*
valueUpdateGate
);
gradPrevOut
-=
(
gradOutput
*
valueUpdateGate
);
gradPrevOut
+=
gradOutput
;
gradPrevOut
+=
gradOutput
;
gradFrameState
=
actInput
(
gradOutput
*
valueUpdateGate
,
valueFrameState
);
gradFrameState
=
activation
(
gradOutput
*
valueUpdateGate
,
valueFrameState
,
actInput
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
...
@@ -125,16 +100,15 @@ class gru_stateGrad {
...
@@ -125,16 +100,15 @@ class gru_stateGrad {
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
__m256
&
valueFrameState
,
__m256
&
gradFrameState
,
__m256
&
valueFrameState
,
__m256
&
gradFrameState
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
gradOutput
,
__m256
&
gradOutput
,
activation_mode_t
actInput
)
{
typename
hppl
::
Active
<
__m256
>::
backward
actInput
)
{
gradUpdateGate
=
_mm256_mul_ps
(
gradOutput
,
valueFrameState
);
gradUpdateGate
=
_mm256_mul_ps
(
gradOutput
,
valueFrameState
);
gradUpdateGate
=
gradUpdateGate
=
_mm256_sub_ps
(
gradUpdateGate
,
_mm256_mul_ps
(
gradOutput
,
valuePrevOut
));
_mm256_sub_ps
(
gradUpdateGate
,
_mm256_mul_ps
(
gradOutput
,
valuePrevOut
));
gradPrevOut
=
_mm256_add_ps
(
gradPrevOut
=
_mm256_add_ps
(
_mm256_sub_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
)),
_mm256_sub_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
)),
gradOutput
);
gradOutput
);
gradFrameState
=
gradFrameState
=
activation
(
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
),
actInput
(
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
),
valueFrameState
);
valueFrameState
,
actInput
);
}
}
#endif
#endif
#endif
#endif
...
@@ -143,25 +117,14 @@ class gru_stateGrad {
...
@@ -143,25 +117,14 @@ class gru_stateGrad {
template
<
typename
T
>
template
<
typename
T
>
class
gru_resetGrad
{
class
gru_resetGrad
{
public:
public:
/**
* @param[in] valueUpdateGate update gate value
* @param[in,out] gradUpdateGate update gate grad
* @param[in] valueResetGate reset gate value
* @param[out] gradResetGate reset gate grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradResetOutput reset output grad (temp val)
* @param[in] actGate backward function of gate
*/
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
T
&
valueResetGate
,
T
&
gradResetGate
,
T
&
valueResetGate
,
T
&
gradResetGate
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
gradResetOutput
,
T
&
gradResetOutput
,
activation_mode_t
actGate
)
{
typename
hppl
::
Active
<
T
>::
backward
actGate
)
{
gradResetGate
=
(
gradResetOutput
*
valuePrevOut
);
gradResetGate
=
(
gradResetOutput
*
valuePrevOut
);
gradPrevOut
+=
(
gradResetOutput
*
valueResetGate
);
gradPrevOut
+=
(
gradResetOutput
*
valueResetGate
);
gradUpdateGate
=
act
Gate
(
gradUpdateGate
,
valueUpdate
Gate
);
gradUpdateGate
=
act
ivation
(
gradUpdateGate
,
valueUpdateGate
,
act
Gate
);
gradResetGate
=
act
Gate
(
gradResetGate
,
valueRese
tGate
);
gradResetGate
=
act
ivation
(
gradResetGate
,
valueResetGate
,
ac
tGate
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
...
@@ -172,12 +135,12 @@ class gru_resetGrad {
...
@@ -172,12 +135,12 @@ class gru_resetGrad {
__m256
&
valueResetGate
,
__m256
&
gradResetGate
,
__m256
&
valueResetGate
,
__m256
&
gradResetGate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
gradResetOutput
,
__m256
&
gradResetOutput
,
typename
hppl
::
Active
<
__m256
>::
backward
actGate
)
{
activation_mode_t
actGate
)
{
gradResetGate
=
_mm256_mul_ps
(
gradResetOutput
,
valuePrevOut
);
gradResetGate
=
_mm256_mul_ps
(
gradResetOutput
,
valuePrevOut
);
gradPrevOut
=
_mm256_add_ps
(
gradPrevOut
,
gradPrevOut
=
_mm256_add_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradResetOutput
,
valueResetGate
));
_mm256_mul_ps
(
gradResetOutput
,
valueResetGate
));
gradUpdateGate
=
act
Gate
(
gradUpdateGate
,
valueUpdate
Gate
);
gradUpdateGate
=
act
ivation
(
gradUpdateGate
,
valueUpdateGate
,
act
Gate
);
gradResetGate
=
act
Gate
(
gradResetGate
,
valueRese
tGate
);
gradResetGate
=
act
ivation
(
gradResetGate
,
valueResetGate
,
ac
tGate
);
}
}
#endif
#endif
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录