Commit f60f0eae authored by Yu Yang

Using double precision to stabilize LSTM gradient check

Parent 9fbf94b6
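For context: the gradient check compares the operator's analytic gradient against a finite-difference estimate of the outputs. With float32 test inputs, the difference quotient for this composition of sigmoid and tanh carries enough rounding noise that the old test had to allow max_relative_error=0.01; moving the test data to float64 (and registering double kernels so the op can consume it) makes the estimate several orders of magnitude tighter, so the loosened tolerance can be dropped. A rough standalone sketch of the effect in plain NumPy (not Paddle's OpTest harness; lstm_unit_np and grad_check are illustrative helpers, and the gradient is taken w.r.t. C_prev only):

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def lstm_unit_np(x, c_prev, forget_bias=0.0):
    # Element-wise LSTM unit; x packs the i, f, o, j gate pre-activations.
    i, f, o, j = np.split(x, 4, axis=1)
    c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(j)
    h = np.tanh(c) * sigmoid(o)
    return c, h

def grad_check(dtype, eps):
    rng = np.random.RandomState(0)
    x = rng.normal(size=(5, 16)).astype(dtype)
    c_prev = rng.normal(size=(5, 4)).astype(dtype)
    i, f, o, j = np.split(x, 4, axis=1)  # only f and o are needed below

    # Analytic gradient of sum(C) + sum(H) w.r.t. C_prev (forget_bias = 0).
    c, _ = lstm_unit_np(x, c_prev)
    analytic = sigmoid(f) * (1.0 + sigmoid(o) * (1.0 - np.tanh(c) ** 2))

    # Central finite-difference estimate of the same gradient.
    numeric = np.zeros_like(c_prev)
    for idx in np.ndindex(*c_prev.shape):
        orig = c_prev[idx]
        c_prev[idx] = orig + eps
        hi = sum(t.sum() for t in lstm_unit_np(x, c_prev))
        c_prev[idx] = orig - eps
        lo = sum(t.sum() for t in lstm_unit_np(x, c_prev))
        c_prev[idx] = orig
        numeric[idx] = (hi - lo) / (2.0 * eps)

    rel_err = np.abs(numeric - analytic) / np.maximum(np.abs(analytic), 1e-8)
    print("%s: max relative error %.1e" % (np.dtype(dtype).name, rel_err.max()))

grad_check(np.float32, 1e-2)  # noisy estimate: needs a loose tolerance like the old 0.01
grad_check(np.float64, 1e-5)  # far tighter estimate: the default tolerance suffices
```

The same reasoning applies to the gradient w.r.t. X, which is why the commit switches the test data to float64 and adds double-precision CPU and GPU kernels instead of keeping the 1% tolerance.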
@@ -47,7 +47,6 @@ class LstmUnitOp : public framework::OperatorWithKernel {
}
};
-template <typename AttrType>
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LstmUnitOpMaker(framework::OpProto* proto,
@@ -68,7 +67,7 @@ Equation:
H = C * sigm(o)
)DOC");
-AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.")
+AddAttr<float>("forget_bias", "The forget bias of Lstm Unit.")
.SetDefault(0.0);
}
};
@@ -93,9 +92,11 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker<float>,
-lstm_unit_grad, ops::LstmUnitGradOp);
+REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker, lstm_unit_grad,
+ops::LstmUnitGradOp);
REGISTER_OP_CPU_KERNEL(lstm_unit,
-ops::LstmUnitKernel<paddle::platform::CPUPlace, float>);
+ops::LstmUnitKernel<paddle::platform::CPUPlace, float>,
+ops::LstmUnitKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
-lstm_unit_grad, ops::LstmUnitGradKernel<paddle::platform::CPUPlace, float>);
+lstm_unit_grad, ops::LstmUnitGradKernel<paddle::platform::CPUPlace, float>,
+ops::LstmUnitGradKernel<paddle::platform::CPUPlace, double>);
@@ -89,7 +89,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
}
}
-template <typename T, typename AttrType = T>
+template <typename T>
class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -101,7 +101,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
auto* c_tensor = ctx.Output<framework::Tensor>("C");
auto* h_tensor = ctx.Output<framework::Tensor>("H");
-auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
+auto forget_bias = static_cast<T>(ctx.Attr<float>("forget_bias"));
int b_size = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
@@ -120,7 +120,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
}
};
-template <typename T, typename AttrType = T>
+template <typename T>
class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -153,7 +153,7 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
int N = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
-auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
+auto forget_bias = static_cast<T>(ctx.Attr<float>("forget_bias"));
int block = 512;
int n = N * D;
@@ -169,5 +169,7 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>);
-REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>,
+ops::LstmUnitOpCUDAKernel<double>);
+REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>,
+ops::LstmUnitGradOpCUDAKernel<double>);
@@ -32,7 +32,7 @@ inline T tanh(T x) {
return 2. * sigmoid(2. * x) - 1.;
}
-template <typename Place, typename T, typename AttrType = T>
+template <typename Place, typename T>
class LstmUnitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -44,7 +44,7 @@ class LstmUnitKernel : public framework::OpKernel<T> {
auto* c_tensor = ctx.Output<framework::Tensor>("C");
auto* h_tensor = ctx.Output<framework::Tensor>("H");
-auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
+auto forget_bias = static_cast<T>(ctx.Attr<float>("forget_bias"));
int b_size = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
@@ -75,7 +75,7 @@ class LstmUnitKernel : public framework::OpKernel<T> {
}
};
-template <typename Place, typename T, typename AttrType = T>
+template <typename Place, typename T>
class LstmUnitGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -108,7 +108,7 @@ class LstmUnitGradKernel : public framework::OpKernel<T> {
int N = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
-auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
+auto forget_bias = static_cast<T>(ctx.Attr<float>("forget_bias"));
for (int n = 0; n < N; ++n) {
for (int d = 0; d < D; ++d) {
@@ -14,8 +14,8 @@ def tanh_np(x):
class LstmUnitTest(OpTest):
def setUp(self):
self.op_type = "lstm_unit"
-x_np = np.random.normal(size=(5, 16)).astype("float32")
-c_np = np.random.normal(size=(5, 4)).astype("float32")
+x_np = np.random.normal(size=(5, 16)).astype("float64")
+c_np = np.random.normal(size=(5, 4)).astype("float64")
i_np, f_np, o_np, j_np = np.split(x_np, 4, axis=1)
forget_bias_np = 0.
self.attrs = {'forget_bias': 0.}
@@ -31,7 +31,7 @@ class LstmUnitTest(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['X', 'C_prev'], ['C', 'H'], max_relative_error=0.01)
+self.check_grad(['X', 'C_prev'], ['C', 'H'])
if __name__ == "__main__":