diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc
index e3cac2605a0892506815c595cf257b0268a07f02..3600f199770c4b8c9a6561b4c270a91bc8b20c0b 100644
--- a/paddle/operators/lstm_unit_op.cc
+++ b/paddle/operators/lstm_unit_op.cc
@@ -64,7 +64,7 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(Lstm-Unit Operator
 
 Equation:
-  i, j, f, o = split(X)
+  i, f, o, j = split(X)
   C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
   H = C * sigm(o)
 
@@ -99,3 +99,5 @@
 REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker, lstm_unit_grad,
             ops::LstmUnitGradOp);
 REGISTER_OP_CPU_KERNEL(lstm_unit, ops::LstmUnitKernel<float>);
+REGISTER_OP_CPU_KERNEL(
+    lstm_unit_grad, ops::LstmUnitGradKernel<float>);
diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h
index 6e870f65e22c5a85ce34ec5c187c14c29570d89b..683034fe15df8cabfdff5e856adb5c0467055064 100644
--- a/paddle/operators/lstm_unit_op.h
+++ b/paddle/operators/lstm_unit_op.h
@@ -13,6 +13,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "glog/logging.h"
 #include "paddle/framework/op_registry.h"
 
 namespace paddle {
diff --git a/python/paddle/v2/framework/tests/test_lstm_unit_op.py b/python/paddle/v2/framework/tests/test_lstm_unit_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5888e908dc7806c11610f64734a117f21a3117b
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_lstm_unit_op.py
@@ -0,0 +1,39 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+import glog as log
+
+
+def sigmoid_np(x):
+    return 1. / (1. + np.exp(-x))
+
+
+def tanh_np(x):
+    return 2 * sigmoid_np(2. * x) - 1.
+
+
+class LstmUnitTest(OpTest):
+    def setUp(self):
+        self.op_type = "lstm_unit"
+        x_np = np.random.normal(size=(5, 16)).astype("float32")
+        c_np = np.random.normal(size=(5, 4)).astype("float32")
+        i_np, f_np, o_np, j_np = np.split(x_np, 4, axis=1)
+        forget_bias_np = 0.
+        self.attrs = {'forget_bias': 0.}
+
+        new_c = c_np * sigmoid_np(f_np + forget_bias_np) + sigmoid_np(
+            i_np) * tanh_np(j_np)
+        new_h = tanh_np(new_c) * sigmoid_np(o_np)
+
+        self.inputs = {'X': x_np, 'C_prev': c_np}
+        self.outputs = {'C': new_c, 'H': new_h}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X', 'C_prev'], ['C', 'H'], max_relative_error=0.01)
+
+
+if __name__ == "__main__":
+    unittest.main()
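
Note for reviewers: the DOC string above writes `H = C * sigm(o)`, while the new test computes `new_h = tanh_np(new_c) * sigmoid_np(o_np)`; the test follows the usual LSTM formulation, in which the cell state passes through tanh before being gated. Since this diff registers the new `lstm_unit_grad` CPU kernel but does not show its body, here is a hand-derived NumPy sketch of the backward pass that kernel should compute, obtained by differentiating the forward equations used in the test. It is a reference for sanity-checking only, not the kernel's actual code, and the names (`lstm_unit_grad_np`, `d_c`, etc.) are hypothetical.

```python
import numpy as np


def sigmoid_np(x):
    return 1. / (1. + np.exp(-x))


def lstm_unit_grad_np(x, c_prev, grad_c, grad_h, forget_bias=0.):
    """Hand-derived gradients of C and H w.r.t. X and C_prev.

    Forward (as in the test):
        i, f, o, j = split(X)
        C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
        H = tanh(C) * sigm(o)
    """
    i, f, o, j = np.split(x, 4, axis=1)
    sig_i = sigmoid_np(i)
    sig_f = sigmoid_np(f + forget_bias)
    sig_o = sigmoid_np(o)
    tanh_j = np.tanh(j)
    c = c_prev * sig_f + sig_i * tanh_j
    tanh_c = np.tanh(c)

    # Total gradient reaching C: the direct upstream gradient on the C
    # output plus the path through H = tanh(C) * sigm(o).
    d_c = grad_c + grad_h * sig_o * (1. - tanh_c * tanh_c)

    d_c_prev = d_c * sig_f
    d_i = d_c * tanh_j * sig_i * (1. - sig_i)      # via sigm(i) * tanh(j)
    d_f = d_c * c_prev * sig_f * (1. - sig_f)      # via C_prev * sigm(f + b)
    d_o = grad_h * tanh_c * sig_o * (1. - sig_o)   # only through H
    d_j = d_c * sig_i * (1. - tanh_j * tanh_j)     # via sigm(i) * tanh(j)

    # Reassemble in the same i, f, o, j order that split(X) uses.
    d_x = np.concatenate([d_i, d_f, d_o, d_j], axis=1)
    return d_x, d_c_prev
```

`OpTest.check_grad` in the test above performs essentially this comparison numerically: it perturbs `X` and `C_prev` by finite differences and checks the kernel's analytic gradients against them within `max_relative_error=0.01`, so the sketch can be cross-checked the same way.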