Commit b50c33fd authored by dangqingqing

Use fixed activation functions in the LSTM kernel, since there is a bug in the activation function pointer. It will be fixed later.
Parent bd680f15
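The patch drops the runtime-selected activation functors from the LSTM kernels and hardcodes sigmoid for the gates and tanh for the cell input and cell state. As a minimal, standalone sketch of the per-frame forward step the kernels now compute (names and clipping bounds are illustrative, not taken from the diff):

// Hypothetical sketch of the fixed-activation LSTM step: sigmoid on the
// gates, tanh on the cell input and cell state, with peephole ("check")
// connections. Not the patch itself; values are made up for the example.
#include <algorithm>
#include <cmath>
#include <cstdio>

template <typename T>
T sigmoid(T a) {
  // Clip the pre-activation before exp(), mirroring the threshold clipping
  // done in the kernel (illustrative bounds).
  a = std::min(std::max(a, T(-40.0)), T(13.0));
  return T(1.0) / (T(1.0) + std::exp(-a));
}

int main() {
  // One frame of gate pre-activations: cell input, input/forget/output gates.
  double valueIn = 0.3, valueIg = -0.2, valueFg = 0.5, valueOg = 0.1;
  double prevState = 0.4;                              // c_{t-1}
  double checkI = 0.05, checkF = 0.02, checkO = 0.01;  // peephole weights

  valueIn = std::tanh(valueIn);                        // candidate cell input
  valueIg = sigmoid(valueIg + prevState * checkI);     // input gate
  valueFg = sigmoid(valueFg + prevState * checkF);     // forget gate
  double state = valueIn * valueIg + prevState * valueFg;  // c_t
  valueOg = sigmoid(valueOg + state * checkO);         // output gate
  double stateAtv = std::tanh(state);
  double output = valueOg * stateAtv;                  // h_t

  std::printf("c_t = %f, h_t = %f\n", state, output);
  return 0;
}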
@@ -82,6 +82,13 @@ class LSTMOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("Input", "Hidden");
     ctx->ShareLoD("Input", "Cell");
   }
+
+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(
+        ctx.Input<framework::LoDTensor>("Input")->type());
+  }
 };
 
 class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -239,6 +246,13 @@ class LSTMGradOp : public framework::OperatorWithKernel {
     if (ctx->HasOutput(b_g_name))
       ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
   }
+
+ protected:
+  framework::DataType IndicateDataType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::ToDataType(
+        ctx.Input<framework::LoDTensor>("Input")->type());
+  }
 };
 
 }  // namespace operators
...
@@ -26,10 +26,7 @@ namespace detail {
 
 template <class T, class Op>
 void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                     int frameSize,
-                                     activation_mode_t active_node,
-                                     activation_mode_t active_gate,
-                                     activation_mode_t active_state) {
+                                     int frameSize) {
   T rValueIn;
   T rValueIg;
   T rValueFg;
@@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       rPrevState = value.prevStateValue[i];
     }
 
-    hppl::cpu::ForwardAct<T> act;
     op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-       rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
-       act(active_state));
+       rOut, rCheckI, rCheckF, rCheckO);
 
     valueIn[i] = rValueIn;
     valueIg[i] = rValueIg;
@@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
 
 template <class T, class Op>
 void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
-                                      LstmMetaGrad<T> grad, int frameSize,
-                                      activation_mode_t active_node,
-                                      activation_mode_t active_gate,
-                                      activation_mode_t active_state) {
+                                      LstmMetaGrad<T> grad, int frameSize) {
   T rValueIn;
   T rValueIg;
   T rValueFg;
@@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       rPrevState = value.prevStateValue[i];
     }
 
-    hppl::cpu::BackwardAct<T> act;
     op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
        rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
        rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
-       rCheckOGrad, act(active_node), act(active_gate), act(active_state));
+       rCheckOGrad);
 
     gradIn[i] = rGradIn;
     gradIg[i] = rGradIg;
@@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
     avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
                                      active_gate, active_state);
   } else {
-    naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
-                                       active_gate, active_state);
+    naive_lstm_forward_one_sequence<T>(op, value, frameSize);
   }
 }
 
@@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
     avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
                                       active_gate, active_state);
   } else {
-    naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
-                                        active_gate, active_state);
+    naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize);
   }
 }
 
...
@@ -32,9 +32,7 @@ namespace detail {
  */
 template <class T, class Op, bool isBatch>
 __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
-                              int batchSize, activation_mode_t active_node,
-                              activation_mode_t active_gate,
-                              activation_mode_t active_state) {
+                              int batchSize) {
   const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frameIdx >= frameSize) return;
 
@@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
     rPrevState = value.prevStateValue[frameIdx];
   }
 
-  hppl::gpu::ForwardAct<T> act;
   op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-     rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
-     act(active_state));
+     rOut, rCheckI, rCheckF, rCheckO);
 
   value.gateValue[frameIdx] = rValueIn;
   value.gateValue[frameIdx + frameSize] = rValueIg;
@@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
 template <class T, class Op, bool isBatch>
 __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
                                LstmMetaGrad<T> grad, int frameSize,
-                               int batchSize, activation_mode_t active_node,
-                               activation_mode_t active_gate,
-                               activation_mode_t active_state) {
+                               int batchSize) {
   const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frameIdx >= frameSize) return;
 
@@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
     rPrevState = value.prevStateValue[frameIdx];
   }
 
-  hppl::gpu::BackwardAct<T> act;
   op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
      rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
-     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
-     act(active_node), act(active_gate), act(active_state));
+     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad);
 
   grad.gateGrad[frameIdx] = rGradIn;
   grad.gateGrad[frameIdx + frameSize] = rGradIg;
@@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
   if (batchSize == 1) {
     KeLstmForward<T, Op,
                   /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, frameSize, batchSize);
   } else {
     KeLstmForward<T, Op,
                   /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, frameSize, batchSize);
   }
 }
 
@@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
   if (batchSize == 1) {
     KeLstmBackward<T, Op,
                    /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, grad, frameSize, batchSize);
   } else {
     KeLstmBackward<T, Op,
                    /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, grad, frameSize, batchSize);
   }
 }
 
...
@@ -24,15 +24,29 @@ namespace detail {
 
 namespace forward {
 
+template <typename T>
+DEVICE inline T sigmoid(const T a) {
+  const T min = SIGMOID_THRESHOLD_MIN;
+  const T max = SIGMOID_THRESHOLD_MAX;
+  T tmp = (a < min) ? min : ((a > max) ? max : a);
+  return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
+}
+
+template <typename T>
+DEVICE inline T tanh(const T a) {
+  T tmp = -2.0 * a;
+  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
+  return (2.0 / (1.0 + exp(tmp))) - 1.0;
+}
+
 template <class T>
 class lstm {
  public:
   HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
                              T &prevState, T &state, T &stateAtv, T &output,
-                             T &checkI, T &checkF, T &checkO,
-                             typename hppl::ForwardActType<T>::type actInput,
-                             typename hppl::ForwardActType<T>::type actGate,
-                             typename hppl::ForwardActType<T>::type actState) {
+                             T &checkI, T &checkF, T &checkO) {
+#if 0
+    // TODO(qingqing) support to activation speficed by users
     valueIn = actInput(valueIn);
     valueIg = actGate(valueIg + prevState * checkI);
     valueFg = actGate(valueFg + prevState * checkF);
@@ -40,6 +54,15 @@ class lstm {
     valueOg = actGate(valueOg + state * checkO);
     stateAtv = actState(state);
     output = valueOg * stateAtv;
+#else
+    valueIn = tanh<T>(valueIn);
+    valueIg = sigmoid<T>(valueIg + prevState * checkI);
+    valueFg = sigmoid<T>(valueFg + prevState * checkF);
+    state = valueIn * valueIg + prevState * valueFg;
+    valueOg = sigmoid<T>(valueOg + state * checkO);
+    stateAtv = tanh<T>(state);
+    output = valueOg * stateAtv;
+#endif
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
@@ -72,6 +95,16 @@ class lstm {
 
 namespace backward {
 
+template <typename T>
+DEVICE inline T sigmoid(const T a, const T b) {
+  return a * b * (1.0 - b);
+}
+
+template <typename T>
+DEVICE inline T tanh(const T a, const T b) {
+  return a * (1.0 - b * b);
+}
+
 template <class T>
 class lstm {
  public:
@@ -80,10 +113,9 @@ class lstm {
                              T &prevState, T &prevStateGrad, T &state,
                              T &stateGrad, T &stateAtv, T &outputGrad,
                              T &checkI, T &checkF, T &checkO, T &checkIGrad,
-                             T &checkFGrad, T &checkOGrad,
-                             typename hppl::BackwardActType<T>::type actInput,
-                             typename hppl::BackwardActType<T>::type actGate,
-                             typename hppl::BackwardActType<T>::type actState) {
+                             T &checkFGrad, T &checkOGrad) {
+#if 0
+    // TODO(qingqing) support to activation speficed by users
     gradOg = actGate(outputGrad * stateAtv, valueOg);
     stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
     gradIn = actInput(stateGrad * valueIg, valueIn);
@@ -93,6 +125,17 @@ class lstm {
     checkIGrad = gradIg * prevState;
     checkFGrad = gradFg * prevState;
     checkOGrad = gradOg * state;
+#else
+    gradOg = sigmoid<T>(outputGrad * stateAtv, valueOg);
+    stateGrad += tanh<T>(outputGrad * valueOg, stateAtv) + gradOg * checkO;
+    gradIn = tanh<T>(stateGrad * valueIg, valueIn);
+    gradIg = sigmoid<T>(stateGrad * valueIn, valueIg);
+    gradFg = sigmoid<T>(stateGrad * prevState, valueFg);
+    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
+    checkIGrad = gradIg * prevState;
+    checkFGrad = gradFg * prevState;
+    checkOGrad = gradOg * state;
+#endif
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
...
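For reference, the two-argument backward helpers added above take the upstream gradient a and the forward output b, and return a times the activation derivative expressed through b: sigmoid'(x) = b(1 - b) and tanh'(x) = 1 - b*b. A small standalone check (not part of the patch) comparing the analytic form against a central finite difference:

// Minimal sanity check that the backward helpers are the derivatives of the
// forward activations, written in terms of the forward output b.
#include <cmath>
#include <cstdio>

int main() {
  const double x = 0.7, eps = 1e-6, upstream = 1.0;

  // sigmoid: analytic grad upstream * b * (1 - b) vs. numeric grad.
  auto sig = [](double v) { return 1.0 / (1.0 + std::exp(-v)); };
  double b = sig(x);
  double analytic = upstream * b * (1.0 - b);   // backward::sigmoid(a, b)
  double numeric = (sig(x + eps) - sig(x - eps)) / (2 * eps);
  std::printf("sigmoid: analytic %.8f numeric %.8f\n", analytic, numeric);

  // tanh: analytic grad upstream * (1 - b * b) vs. numeric grad.
  double tb = std::tanh(x);
  double tanalytic = upstream * (1.0 - tb * tb);  // backward::tanh(a, b)
  double tnumeric = (std::tanh(x + eps) - std::tanh(x - eps)) / (2 * eps);
  std::printf("tanh:    analytic %.8f numeric %.8f\n", tanalytic, tnumeric);
  return 0;
}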
@@ -110,7 +110,7 @@ def lstm(
 
 class TestLstmOp(OpTest):
     def set_argument(self):
-        self.lod = [[0, 2, 6]]
+        self.lod = [[0, 2, 5, 7]]
         self.D = 16
 
         self.act_gate = 'sigmoid'
@@ -164,12 +164,13 @@ class TestLstmOp(OpTest):
         # TODO(qingqing) remove folowing two lines after the check_grad is refined.
         self.outputs['BatchGate'] = None
         self.outputs['BatchCellPreAct'] = None
-        self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
+        self.check_grad(
+            ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.02)
 
 
 class TestLstmOpHasNoInitial(TestLstmOp):
     def set_argument(self):
-        self.lod = [[0, 2, 6]]
+        self.lod = [[0, 2, 5, 7]]
         self.D = 16
 
         self.act_gate = 'sigmoid'
@@ -182,7 +183,7 @@ class TestLstmOpHasNoInitial(TestLstmOp):
 
 class TestLstmOpRerverse(TestLstmOp):
     def set_argument(self):
-        self.lod = [[0, 2, 6]]
+        self.lod = [[0, 2, 5, 7]]
         self.D = 16
 
         self.act_gate = 'sigmoid'
...
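The test changes replace the LoD offsets [[0, 2, 6]] (two sequences of lengths 2 and 4) with [[0, 2, 5, 7]] (three sequences of lengths 2, 3, and 2) and relax the gradient-check tolerance to max_relative_error=0.02. Purely as an illustration of how LoD offsets map to sequence lengths:

// Illustrative only: derive sequence lengths from the LoD offsets used in the
// updated test, [[0, 2, 5, 7]] -> lengths {2, 3, 2}.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> lod = {0, 2, 5, 7};
  for (size_t i = 1; i < lod.size(); ++i) {
    std::printf("sequence %zu: length %d\n", i - 1, lod[i] - lod[i - 1]);
  }
  return 0;
}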