diff --git a/paddle/operators/gru_unit_op.cc b/paddle/operators/gru_unit_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..24f84597cd7301af6521b8c1032e69569ba6f03a
--- /dev/null
+++ b/paddle/operators/gru_unit_op.cc
@@ -0,0 +1,210 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/gru_unit_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class GRUUnitOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(%s) of GRUUnitOp should not be null.", "Input");
+    PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
+                   "Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(%s) of GRUUnitOp should not be null.", "Weight");
+    PADDLE_ENFORCE(ctx->HasOutput("Gate"),
+                   "Output(%s) of GRUUnitOp should not be null.", "Gate");
+    PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
+                   "Output(%s) of GRUUnitOp should not be null.",
+                   "ResetHiddenPrev");
+    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                   "Output(%s) of GRUUnitOp should not be null.", "Hidden");
+    auto input_dims = ctx->GetInputDim("Input");
+    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
+    auto weight_dims = ctx->GetInputDim("Weight");
+    int batch_size = input_dims[0];
+    int input_size = input_dims[1];
+    int frame_size = hidden_prev_dims[1];
+    int weight_height = weight_dims[0];
+    int weight_width = weight_dims[1];
+    PADDLE_ENFORCE_EQ(
+        input_size, frame_size * 3,
+        "The input_size must be 3 times of frame_size in GRUUnitOp.");
+    PADDLE_ENFORCE_EQ(
+        weight_height, frame_size,
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
+    PADDLE_ENFORCE_EQ(
+        weight_width, frame_size * 3,
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
+    auto bias = Input("Bias");
+    if (bias != framework::kEmptyVarName) {
+      auto bias_dims = ctx->GetInputDim("Bias");
+      int bias_height = bias_dims[0];
+      int bias_width = bias_dims[1];
+      PADDLE_ENFORCE_EQ(bias_height, 1,
+                        "The shape of Bias must be [1, frame_size * 3].");
+      PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
+                        "The shape of Bias must be [1, frame_size * 3].");
+    }
+    ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
+    ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
+    ctx->SetOutputDim("Hidden", {batch_size, frame_size});
+  }
+};
+
+class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  GRUUnitOpMaker(framework::OpProto* proto,
+                 framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Input",
+             "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
+             "input.");
+    AddInput("HiddenPrev",
+             "(Tensor) Matrix with shape [batch_size, frame_size] for the "
+             "states of previous time step.");
+    AddInput("Weight",
+             "(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. "
+             "The elements, contiguous in memory, can be divided into two "
+             "parts. The first part contains the weights of the update gate "
+             "and reset gate with shape [frame_size, frame_size * 2], and the "
+             "second part contains the weights of the output candidate with "
+             "shape [frame_size, frame_size].");
+    AddInput("Bias",
+             "(Tensor) Bias vector with shape [1, frame_size * 3] "
+             "concatenating the bias of the update gate, reset gate and "
+             "output candidate.");
+    AddOutput("Gate",
+              "(Tensor) Matrix with shape [batch_size, frame_size * 3] for "
+              "the output of the update gate, reset gate and output "
+              "candidate.")
+        .AsIntermediate();
+    AddOutput("ResetHiddenPrev",
+              "(Tensor) Matrix with shape [batch_size, frame_size] for the "
+              "reset hidden state of the previous time step.")
+        .AsIntermediate();
+    AddOutput("Hidden",
+              "(Tensor) The GRU hidden state of the current time step "
+              "with shape [batch_size, frame_size].");
+    AddAttr<int>("activation",
+                 "(enum int, default tanh) "
+                 "The activation type used for the output candidate {h}_t.")
+        .SetDefault(tanh)
+        .InEnum({identity, sigmoid, tanh, relu});
+    AddAttr<int>("gate_activation",
+                 "(enum int, default sigmoid) "
+                 "The activation type used in the update gate and reset gate.")
+        .SetDefault(sigmoid)
+        .InEnum({identity, sigmoid, tanh, relu});
+    AddComment(R"DOC(
+GRUUnitOp implements part of the calculations of the GRU unit as follows:
+
+\f[
+update \ gate: u_t = actGate(xu_t + W_u * hidden\_prev + bias_u) \\
+reset \ gate: r_t = actGate(xr_t + W_r * hidden\_prev + bias_r) \\
+output \ candidate: \tilde{h}_t = actNode(xc_t + W_c * dot(r_t, hidden\_prev) + bias_c) \\
+output: h_t = dot((1 - u_t), \tilde{h}_t) + dot(u_t, hidden\_prev)
+\f]
+
+The rest of the GRU unit can be completed by using FCOp's output as the input of GRUUnitOp.
+)DOC");
+  }
+};
+
+class GRUUnitGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Input");
+    PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
+                   "Input(%s) of GRUUnitGradOp should not be null.",
+                   "HiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Weight");
+    PADDLE_ENFORCE(ctx->HasInput("Gate"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Gate");
+    PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
+                   "Input(%s) of GRUUnitGradOp should not be null.",
+                   "ResetHiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput("Hidden"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
+                   "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
+                   "Gate");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
+                   "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
+                   "ResetHiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
+                   "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
+                   "Hidden");
+    auto input_dims = ctx->GetInputDim("Input");
+    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
+    auto weight_dims = ctx->GetInputDim("Weight");
+    // int batch_size = input_dims[0];
+    int input_size = input_dims[1];
+    int frame_size = hidden_prev_dims[1];
+    int weight_height = weight_dims[0];
+    int weight_width = weight_dims[1];
+    PADDLE_ENFORCE_EQ(
+        input_size, frame_size * 3,
+        "The input_size must be 3 times of frame_size in GRUUnitOp.");
+    PADDLE_ENFORCE_EQ(
+        weight_height, frame_size,
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
+    PADDLE_ENFORCE_EQ(
+        weight_width, frame_size * 3,
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
+    auto bias = Input("Bias");
+    if (bias != framework::kEmptyVarName) {
+      auto bias_dims = ctx->GetInputDim("Bias");
+      int bias_height = bias_dims[0];
+      int bias_width = bias_dims[1];
+      PADDLE_ENFORCE_EQ(bias_height, 1,
+                        "The shape of Bias must be [1, frame_size * 3].");
+      PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
+                        "The shape of Bias must be [1, frame_size * 3].");
+      auto bias_grad_name = framework::GradVarName("Bias");
+      if (ctx->HasOutput(bias_grad_name))
+        ctx->SetOutputDim(bias_grad_name, bias_dims);
+    }
+    auto input_grad_name = framework::GradVarName("Input");
+    if (ctx->HasOutput(input_grad_name))
+      ctx->SetOutputDim(input_grad_name, input_dims);
+    auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev");
+    if (ctx->HasOutput(hidden_prev_grad_name))
+      ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims);
+    auto weight_grad_name = framework::GradVarName("Weight");
+    if (ctx->HasOutput(weight_grad_name))
+      ctx->SetOutputDim(weight_grad_name, weight_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, gru_unit_grad,
+            ops::GRUUnitGradOp);
+REGISTER_OP_CPU_KERNEL(gru_unit,
+                       ops::GRUUnitKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::CPUPlace, float>);
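A note on the Weight layout the kernels below rely on: the [frame_size, frame_size * 3] matrix stores the update-gate and reset-gate weights in the first frame_size * frame_size * 2 elements and the output-candidate weights in the remaining elements, which is why GRUUnitKernel (in gru_unit_op.h below) offsets the raw weight buffer by frame_size * frame_size * 2 before the second gemm. A minimal NumPy sketch of that split, mirroring the reference computation in the Python test at the end of this patch (variable names here are illustrative only, not part of the patch):

    import numpy as np

    frame_size = 5
    # Weight as handed to the op: shape [frame_size, frame_size * 3]
    w = np.random.uniform(-0.1, 0.1, (frame_size, frame_size * 3))
    flat = w.flatten()
    # first 2 * frame_size^2 values: update-gate and reset-gate weights
    w_update_reset = flat[:frame_size * frame_size * 2].reshape(frame_size, frame_size * 2)
    # remaining frame_size^2 values: output-candidate weights
    w_candidate = flat[frame_size * frame_size * 2:].reshape(frame_size, frame_size)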
diff --git a/paddle/operators/gru_unit_op.cu b/paddle/operators/gru_unit_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..365f656523ddfb7ec8e2a5b885de74674823325a
--- /dev/null
+++ b/paddle/operators/gru_unit_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/gru_unit_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(gru_unit,
+                       ops::GRUUnitKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/gru_unit_op.h b/paddle/operators/gru_unit_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..c53e7d9827e0395e6ce613302e732b2797f83cdd
--- /dev/null
+++ b/paddle/operators/gru_unit_op.h
@@ -0,0 +1,230 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "paddle/operators/activation_op.h"
+#include "paddle/operators/math/math_function.h"
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 };
+
+template <typename Place, typename T>
+class GRUUnitKernel : public framework::OpKernel {
+ public:
+  template <typename Device, typename X, typename Y>
+  void ActCompute(const int act_type, const Device& d, X x, Y y) const {
+    if (act_type == identity)
+      y.device(d) = x;
+    else if (act_type == sigmoid)
+      SigmoidFunctor<T>()(d, x, y);
+    else if (act_type == tanh)
+      TanhFunctor<T>()(d, x, y);
+    else if (act_type == relu)
+      ReluFunctor<T>()(d, x, y);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("Input");
+    auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
+    auto* weight = context.Input<Tensor>("Weight");
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* gate = context.Output<Tensor>("Gate");
+    gate->mutable_data<T>(context.GetPlace());
+    auto* reset_hidden_prev = context.Output<Tensor>("ResetHiddenPrev");
+    reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<Tensor>("Hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    int batch_size = input->dims()[0];
+    int frame_size = hidden_prev->dims()[1];
+
+    auto x = EigenMatrix<T>::From(*input);
+    auto h_p = EigenMatrix<T>::From(*hidden_prev);
+    auto g = EigenMatrix<T>::From(*gate);
+    auto r_h_p = EigenMatrix<T>::From(*reset_hidden_prev);
+    auto h = EigenMatrix<T>::From(*hidden);
+    auto place = context.GetEigenDevice<Place>();
+
+    // calculate unactivated gate outputs
+    if (bias) {
+      auto b = EigenMatrix<T>::From(*bias);
+      g.device(place) =
+          x + b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
+                  .broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
+    } else {
+      g.device(place) = x;
+    }
+    const T* hidden_prev_data = hidden_prev->data<T>();
+    const T* weight_data = weight->data<T>();
+    T* gate_data = gate->data<T>();
+    T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
+    math::gemm<Place, T>(context.device_context(), false, false, batch_size,
+                         2 * frame_size, frame_size, 1, hidden_prev_data,
+                         frame_size, weight_data, frame_size * 2, 1, gate_data,
+                         frame_size * 3);
+
+    // calculate activated gates
+    Eigen::array<int, 2> extents({{batch_size, frame_size}});
+    Eigen::array<int, 2> u_offsets({{0, 0}});
+    ActCompute(context.Attr<int>("gate_activation"), place,
+               g.slice(u_offsets, extents), g.slice(u_offsets, extents));
+    auto u = g.slice(u_offsets, extents);  // update gate
+    Eigen::array<int, 2> r_offsets({{0, frame_size}});
+    ActCompute(context.Attr<int>("gate_activation"), place,
+               g.slice(r_offsets, extents), g.slice(r_offsets, extents));
+    auto r = g.slice(r_offsets, extents);  // reset gate
+    r_h_p.device(place) = r * h_p;         // reset previous hidden state
+    math::gemm<Place, T>(context.device_context(), false, false, batch_size,
+                         frame_size, frame_size, 1, reset_hidden_prev_data,
+                         frame_size, weight_data + frame_size * frame_size * 2,
+                         frame_size, 1, gate_data + frame_size * 2,
+                         frame_size * 3);
+
+    Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
+    ActCompute(context.Attr<int>("activation"), place,
+               g.slice(c_offsets, extents), g.slice(c_offsets, extents));
+    auto c = g.slice(c_offsets, extents);  // output candidate
+
+    // calculate final output
+    h.device(place) = u * (h_p - c) + c;
+  }
+};
+
+template <typename Place, typename T>
+class GRUUnitGradKernel : public framework::OpKernel {
+ public:
+  template <typename Device, typename X, typename Y, typename DX, typename DY>
+  void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx,
+                      DY dy) const {
+    // x is a dummy argument and is not used, even for Relu (y is used instead)
+    if (act_type == identity)
+      dx.device(d) = dy;
+    else if (act_type == sigmoid)
+      SigmoidGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == tanh)
+      TanhGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == relu)
+      ReluGradFunctor<T>()(d, x, y, dy, dx);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("Input");
+    auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
+    auto* weight = context.Input<Tensor>("Weight");
+    auto* gate = context.Input<Tensor>("Gate");
+    auto* reset_hidden_prev = context.Input<Tensor>("ResetHiddenPrev");
+    auto* hidden_grad =
+        context.Input<Tensor>(framework::GradVarName("Hidden"));
+    auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
+    auto* hidden_prev_grad =
+        context.Output<Tensor>(framework::GradVarName("HiddenPrev"));
+    auto* weight_grad =
+        context.Output<Tensor>(framework::GradVarName("Weight"));
+    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
+    input_grad->mutable_data<T>(context.GetPlace());
+    hidden_prev_grad->mutable_data<T>(context.GetPlace());
+    weight_grad->mutable_data<T>(context.GetPlace());
+    Tensor gate_grad;
+    gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
+    Tensor reset_hidden_prev_grad;
+    reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
+                                           context.GetPlace());
+
+    int batch_size = input->dims()[0];
+    int frame_size = hidden_prev->dims()[1];
+
+    const T* hidden_prev_data = hidden_prev->data<T>();
+    T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
+    const T* weight_data = weight->data<T>();
+    T* weight_grad_data = weight_grad->data<T>();
+    T* gate_grad_data = gate_grad.data<T>();
+    const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
+    T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>();
+
+    auto h_p = EigenMatrix<T>::From(*hidden_prev);
+    auto g = EigenMatrix<T>::From(*gate);
+    auto d_h = EigenMatrix<T>::From(*hidden_grad);
+    auto d_x = EigenMatrix<T>::From(*input_grad);
+    auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
+    auto d_g = EigenMatrix<T>::From(gate_grad);
+    auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
+    auto place = context.GetEigenDevice<Place>();
+
+    Eigen::array<int, 2> extents({{batch_size, frame_size}});
+    Eigen::array<int, 2> u_offsets({{0, 0}});
+    auto u = g.slice(u_offsets, extents);  // update gate
+    Eigen::array<int, 2> r_offsets({{0, frame_size}});
+    auto r = g.slice(r_offsets, extents);  // reset gate
+    Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
+    auto c = g.slice(c_offsets, extents);  // output candidate
+
+    // backward for unactivated update gate
+    ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
+                   d_g.slice(u_offsets, extents), d_h * (h_p - c));
+    // backward for unactivated output candidate
+    ActGradCompute(context.Attr<int>("activation"), place, c, c,
+                   d_g.slice(c_offsets, extents),
+                   d_h * (u.constant(T(1)) - u));
+    // backward for reset_hidden_prev
+    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
+                         frame_size, frame_size, 1,
+                         gate_grad_data + frame_size * 2, frame_size * 3,
+                         weight_data + frame_size * frame_size * 2, frame_size,
+                         0, reset_hidden_prev_grad_data, frame_size);
+    // backward for state_weight
+    math::gemm<Place, T>(
+        context.device_context(), true, false, frame_size, frame_size,
+        batch_size, 1, reset_hidden_prev_data, frame_size,
+        gate_grad_data + frame_size * 2, frame_size * 3, 0,
+        weight_grad_data + frame_size * frame_size * 2, frame_size);
+    // backward for unactivated reset gate
+    ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
+                   d_g.slice(r_offsets, extents), d_r_h_p * h_p);
+    // backward for update_gate_weight and reset_gate_weight
+    math::gemm<Place, T>(context.device_context(), true, false, frame_size,
+                         frame_size * 2, batch_size, 1, hidden_prev_data,
+                         frame_size, gate_grad_data, frame_size * 3, 0,
+                         weight_grad_data, frame_size * 2);
+    // backward for hidden_prev
+    d_h_p.device(place) = d_r_h_p * r + d_h * u;
+    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
+                         frame_size, frame_size * 2, 1, gate_grad_data,
+                         frame_size * 3, weight_data, frame_size * 2, 1,
+                         hidden_prev_grad_data, frame_size);
+    // backward for input
+    d_x.device(place) = d_g;
+    // backward for bias
+    if (bias_grad) {
+      bias_grad->mutable_data<T>(context.GetPlace());
+      auto d_b = EigenMatrix<T>::From(*bias_grad);
+      d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_gru_unit_op.py b/python/paddle/v2/framework/tests/test_gru_unit_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..57625362d21905d257f46ff5330841a20438773a
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_gru_unit_op.py
@@ -0,0 +1,115 @@
+import math
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class GRUActivationType(OpTest):
+    identity = 0
+    sigmoid = 1
+    tanh = 2
+    relu = 3
+
+
+def identity(x):
+    return x
+
+
+def sigmoid(x):
+    return 1. / (1. + np.exp(-x))
+
+
+def tanh(x):
+    return 2. * sigmoid(2. * x) - 1.
+
+
+def relu(x):
+    return np.maximum(x, 0)
+
+
+class TestGRUUnitOp(OpTest):
+    batch_size = 3
+    frame_size = 5
+    activate = {
+        GRUActivationType.identity: identity,
+        GRUActivationType.sigmoid: sigmoid,
+        GRUActivationType.tanh: tanh,
+        GRUActivationType.relu: relu,
+    }
+
+    def set_inputs(self):
+        batch_size = self.batch_size
+        frame_size = self.frame_size
+        self.op_type = 'gru_unit'
+        self.inputs = {
+            'Input': np.random.uniform(
+                -0.1, 0.1, (batch_size, frame_size * 3)).astype('float32'),
+            'HiddenPrev': np.random.uniform(
+                -0.1, 0.1, (batch_size, frame_size)).astype('float32'),
+            'Weight': np.random.uniform(
+                -1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
+                (frame_size, frame_size * 3)).astype('float32'),
+        }
+        self.attrs = {
+            'activation': GRUActivationType.tanh,
+            'gate_activation': GRUActivationType.sigmoid
+        }
+
+    def set_outputs(self):
+        # GRU calculations
+        batch_size = self.batch_size
+        frame_size = self.frame_size
+        x = self.inputs['Input']
+        h_p = self.inputs['HiddenPrev']
+        w = self.inputs['Weight']
+        b = self.inputs['Bias'] if 'Bias' in self.inputs else np.zeros(
+            (1, frame_size * 3))
+        g = x + np.tile(b, (batch_size, 1))
+        w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
+            (frame_size, frame_size * 2))
+        u_r = self.activate[self.attrs['gate_activation']](np.dot(
+            h_p, w_u_r) + g[:, :frame_size * 2])
+        u = u_r[:, :frame_size]
+        r = u_r[:, frame_size:frame_size * 2]
+        r_h_p = r * h_p
+        w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
+            (frame_size, frame_size))
+        c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
+                                                    g[:, frame_size * 2:])
+        g = np.hstack((u_r, c))
+        h = u * h_p + (1 - u) * c
+        self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h}
+
+    def setUp(self):
+        self.set_inputs()
+        self.set_outputs()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['Input', 'HiddenPrev', 'Weight'], ['Hidden'],
+            max_relative_error=0.007)
+
+
+class TestGRUUnitOpWithBias(TestGRUUnitOp):
+    def set_inputs(self):
+        batch_size = self.batch_size
+        frame_size = self.frame_size
+        super(TestGRUUnitOpWithBias, self).set_inputs()
+        self.inputs['Bias'] = np.random.uniform(
+            -0.1, 0.1, (1, frame_size * 3)).astype('float32')
+        self.attrs = {
+            'activation': GRUActivationType.identity,
+            'gate_activation': GRUActivationType.sigmoid
+        }
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
+            max_relative_error=0.007)
+
+
+if __name__ == '__main__':
+    unittest.main()
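For reference, the dy arguments passed to ActGradCompute in GRUUnitGradKernel follow directly from the forward expression used in GRUUnitKernel, h = u * h_prev + (1 - u) * c (worked out here in plain notation; this is a sketch of the derivation, not part of the patch):

    dL/dc      = dL/dh * (1 - u)                   # dy for the output candidate
    dL/du      = dL/dh * (h_prev - c)              # dy for the update gate
    dL/dr      = dL/d(r * h_prev) * h_prev         # dy for the reset gate, where dL/d(r * h_prev)
                                                   # comes from the candidate-weight gemm
    dL/dh_prev = dL/d(r * h_prev) * r + dL/dh * u  # the d_h_p expression; the final gemm
                                                   # (beta = 1) then accumulates the contribution
                                                   # that flows through the gate pre-activations

The gradient-check tests above verify these analytic gradients numerically for Input, HiddenPrev, Weight, and (in the bias variant) Bias.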