提交 b50c33fd 编写于 作者: D dangqingqing

Use fixed activation in the lstm kernel, since there is some bug in the...

Use fixed activation in the lstm kernel, since there is some bug in the activation function pointer. It will be fixed later.
上级 bd680f15
......@@ -82,6 +82,13 @@ class LSTMOp : public framework::OperatorWithKernel {
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
}
};
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -239,6 +246,13 @@ class LSTMGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(b_g_name))
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
}
};
} // namespace operators
......
......@@ -26,10 +26,7 @@ namespace detail {
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int frameSize) {
T rValueIn;
T rValueIg;
T rValueFg;
......@@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[i];
}
hppl::cpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
act(active_state));
rOut, rCheckI, rCheckF, rCheckO);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
......@@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
LstmMetaGrad<T> grad, int frameSize) {
T rValueIn;
T rValueIg;
T rValueFg;
......@@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[i];
}
hppl::cpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, act(active_node), act(active_gate), act(active_state));
rCheckOGrad);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
......@@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
naive_lstm_forward_one_sequence<T>(op, value, frameSize);
}
}
......@@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize);
}
}
......
......@@ -32,9 +32,7 @@ namespace detail {
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int batchSize) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
......@@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
rPrevState = value.prevStateValue[frameIdx];
}
hppl::gpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
act(active_state));
rOut, rCheckI, rCheckF, rCheckO);
value.gateValue[frameIdx] = rValueIn;
value.gateValue[frameIdx + frameSize] = rValueIg;
......@@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int batchSize) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
......@@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[frameIdx];
}
hppl::gpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
act(active_node), act(active_gate), act(active_state));
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad);
grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg;
......@@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) {
KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
op, value, frameSize, batchSize);
} else {
KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
op, value, frameSize, batchSize);
}
}
......@@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) {
KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
op, value, grad, frameSize, batchSize);
} else {
KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
op, value, grad, frameSize, batchSize);
}
}
......
......@@ -24,15 +24,29 @@ namespace detail {
namespace forward {
template <typename T>
DEVICE inline T sigmoid(const T a) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
}
template <typename T>
DEVICE inline T tanh(const T a) {
T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO,
typename hppl::ForwardActType<T>::type actInput,
typename hppl::ForwardActType<T>::type actGate,
typename hppl::ForwardActType<T>::type actState) {
T &checkI, T &checkF, T &checkO) {
#if 0
// TODO(qingqing) support to activation speficed by users
valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF);
......@@ -40,6 +54,15 @@ class lstm {
valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state);
output = valueOg * stateAtv;
#else
valueIn = tanh<T>(valueIn);
valueIg = sigmoid<T>(valueIg + prevState * checkI);
valueFg = sigmoid<T>(valueFg + prevState * checkF);
state = valueIn * valueIg + prevState * valueFg;
valueOg = sigmoid<T>(valueOg + state * checkO);
stateAtv = tanh<T>(state);
output = valueOg * stateAtv;
#endif
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
......@@ -72,6 +95,16 @@ class lstm {
namespace backward {
template <typename T>
DEVICE inline T sigmoid(const T a, const T b) {
return a * b * (1.0 - b);
}
template <typename T>
DEVICE inline T tanh(const T a, const T b) {
return a * (1.0 - b * b);
}
template <class T>
class lstm {
public:
......@@ -80,10 +113,9 @@ class lstm {
T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad,
T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
typename hppl::BackwardActType<T>::type actInput,
typename hppl::BackwardActType<T>::type actGate,
typename hppl::BackwardActType<T>::type actState) {
T &checkFGrad, T &checkOGrad) {
#if 0
// TODO(qingqing) support to activation speficed by users
gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn);
......@@ -93,6 +125,17 @@ class lstm {
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
#else
gradOg = sigmoid<T>(outputGrad * stateAtv, valueOg);
stateGrad += tanh<T>(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = tanh<T>(stateGrad * valueIg, valueIn);
gradIg = sigmoid<T>(stateGrad * valueIn, valueIg);
gradFg = sigmoid<T>(stateGrad * prevState, valueFg);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
#endif
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
......
......@@ -110,7 +110,7 @@ def lstm(
class TestLstmOp(OpTest):
def set_argument(self):
self.lod = [[0, 2, 6]]
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.act_gate = 'sigmoid'
......@@ -164,12 +164,13 @@ class TestLstmOp(OpTest):
# TODO(qingqing) remove folowing two lines after the check_grad is refined.
self.outputs['BatchGate'] = None
self.outputs['BatchCellPreAct'] = None
self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
self.check_grad(
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.02)
class TestLstmOpHasNoInitial(TestLstmOp):
def set_argument(self):
self.lod = [[0, 2, 6]]
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.act_gate = 'sigmoid'
......@@ -182,7 +183,7 @@ class TestLstmOpHasNoInitial(TestLstmOp):
class TestLstmOpRerverse(TestLstmOp):
def set_argument(self):
self.lod = [[0, 2, 6]]
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.act_gate = 'sigmoid'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册