From 0922fca41ef2d1a0d71e85c0261ab58d954df369 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Fri, 15 Sep 2017 20:47:03 +0800
Subject: [PATCH] Add gru_unit_op

---
 paddle/operators/gru_unit_op.cc                    | 198 ++++++++++++++++++
 paddle/operators/gru_unit_op.cu                    |  22 ++
 paddle/operators/gru_unit_op.h                     | 191 +++++++++++++++++
 .../v2/framework/tests/test_gru_unit_op.py         |  59 ++++++
 4 files changed, 470 insertions(+)
 create mode 100644 paddle/operators/gru_unit_op.cc
 create mode 100644 paddle/operators/gru_unit_op.cu
 create mode 100644 paddle/operators/gru_unit_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_gru_unit_op.py

diff --git a/paddle/operators/gru_unit_op.cc b/paddle/operators/gru_unit_op.cc
new file mode 100644
index 00000000000..d6d766cef0d
--- /dev/null
+++ b/paddle/operators/gru_unit_op.cc
@@ -0,0 +1,198 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/gru_unit_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class GRUUnitOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("input"),
+                   "Input(%s) of GRUUnitOp should not be null.", "input");
+    PADDLE_ENFORCE(ctx->HasInput("hidden_prev"),
+                   "Input(%s) of GRUUnitOp should not be null.",
+                   "hidden_prev");
+    PADDLE_ENFORCE(ctx->HasInput("weight"),
+                   "Input(%s) of GRUUnitOp should not be null.", "weight");
+    PADDLE_ENFORCE(ctx->HasInput("bias"),
+                   "Input(%s) of GRUUnitOp should not be null.", "bias");
+    PADDLE_ENFORCE(ctx->HasOutput("gate"),
+                   "Output(%s) of GRUUnitOp should not be null.", "gate");
+    PADDLE_ENFORCE(ctx->HasOutput("reset_hidden_prev"),
+                   "Output(%s) of GRUUnitOp should not be null.",
+                   "reset_hidden_prev");
+    PADDLE_ENFORCE(ctx->HasOutput("hidden"),
+                   "Output(%s) of GRUUnitOp should not be null.", "hidden");
+    auto input_dims = ctx->GetInputDim("input");
+    auto hidden_prev_dims = ctx->GetInputDim("hidden_prev");
+    auto weight_dims = ctx->GetInputDim("weight");
+    auto bias_dims = ctx->GetInputDim("bias");
+    int batch_size = input_dims[0];
+    int input_size = input_dims[1];
+    int frame_size = hidden_prev_dims[1];
+    int weight_height = weight_dims[0];
+    int weight_width = weight_dims[1];
+    int bias_height = bias_dims[0];
+    int bias_width = bias_dims[1];
+    PADDLE_ENFORCE_EQ(
+        input_size, frame_size * 3,
+        "The input_size must be 3 times the frame_size in GRUUnitOp.");
+    PADDLE_ENFORCE_EQ(
+        weight_height, frame_size,
+        "The shape of weight matrix must be [frame_size, frame_size * 3].");
+    PADDLE_ENFORCE_EQ(
+        weight_width, frame_size * 3,
+        "The shape of weight matrix must be [frame_size, frame_size * 3].");
+    PADDLE_ENFORCE_EQ(bias_height, 1,
+                      "The shape of bias must be [1, frame_size * 3].");
+    PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
+                      "The shape of bias must be [1, frame_size * 3].");
+    ctx->SetOutputDim("gate", {batch_size, frame_size * 3});
+    ctx->SetOutputDim("reset_hidden_prev", {batch_size, frame_size});
+    ctx->SetOutputDim("hidden", {batch_size, frame_size});
+  }
+};
+
+class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  GRUUnitOpMaker(framework::OpProto *proto,
+                 framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("input",
+             "(Tensor) Matrix with shape [batch_size, frame_size * 3] for "
+             "the input.");
+    AddInput("hidden_prev",
+             "(Tensor) Matrix with shape [batch_size, frame_size] for the "
+             "state of the previous time step.");
+    AddInput("weight",
+             "(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. "
+             "The elements, contiguous in memory, can be divided into two "
+             "parts: the first part holds the weights of the update gate and "
+             "the reset gate with shape [frame_size, frame_size * 2], and the "
+             "second part holds the weights of the output candidate with "
+             "shape [frame_size, frame_size].");
+    AddInput("bias",
+             "(Tensor) Bias vector with shape [1, frame_size * 3] "
+             "concatenating the biases of the update gate, the reset gate "
+             "and the output candidate.");
+    AddOutput("gate",
+              "(Tensor) Matrix with shape [batch_size, frame_size * 3] for "
+              "the output of the update gate, the reset gate and the output "
+              "candidate.")
+        .AsIntermediate();
+    AddOutput("reset_hidden_prev",
+              "(Tensor) Matrix with shape [batch_size, frame_size] for the "
+              "reset hidden state of the previous time step.")
+        .AsIntermediate();
+    AddOutput("hidden",
+              "(Tensor) The GRU hidden state of the current time step "
+              "with shape [batch_size, frame_size].");
+    AddComment(R"DOC(
+GRUUnitOp implements part of the calculations of the GRU unit as follows:
+
+\f[
+update \ gate: u_t = actGate(xu_t + W_u * hidden_prev + bias_u) \\
+reset \ gate: r_t = actGate(xr_t + W_r * hidden_prev + bias_r) \\
+output \ candidate: \tilde{h}_t = actNode(xc_t + W_c * dot(r_t, hidden_prev) + bias_c) \\
+output: h_t = dot((1 - u_t), \tilde{h}_t) + dot(u_t, hidden_prev)
+\f]
+
+The rest of the GRU unit can be completed by using FCOp's output as the input
+of GRUUnitOp.
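+
+In this kernel implementation, actGate is the sigmoid function and actNode is
+the tanh function, and xu_t, xr_t, xc_t denote the three frame_size-wide
+slices of the (already projected) input, so that, for example,
+
+\f[
+u_t = sigmoid(xu_t + W_u * hidden_prev + bias_u) \\
+\tilde{h}_t = tanh(xc_t + W_c * dot(r_t, hidden_prev) + bias_c)
+\f]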
+)DOC"); + } +}; + +class GRUUnitGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("input"), + "Input(%s) of GRUUnitGradOp should not be null.", "input"); + PADDLE_ENFORCE(ctx->HasInput("hidden_prev"), + "Input(%s) of GRUUnitGradOp should not be null.", + "hidden_prev"); + PADDLE_ENFORCE(ctx->HasInput("weight"), + "Input(%s) of GRUUnitGradOp should not be null.", "weight"); + PADDLE_ENFORCE(ctx->HasInput("bias"), + "Input(%s) of GRUUnitGradOp should not be null.", "bias"); + PADDLE_ENFORCE(ctx->HasInput("gate"), + "Input(%s) of GRUUnitGradOp should not be null.", "gate"); + PADDLE_ENFORCE(ctx->HasInput("reset_hidden_prev"), + "Input(%s) of GRUUnitGradOp should not be null.", + "reset_hidden_prev"); + PADDLE_ENFORCE(ctx->HasInput("hidden"), + "Input(%s) of GRUUnitGradOp should not be null.", "hidden"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("gate")), + "Input(%s@GRAD) of GRUUnitGradOp should not be null.", + "gate"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("reset_hidden_prev")), + "Input(%s@GRAD) of GRUUnitGradOp should not be null.", + "reset_hidden_prev"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("hidden")), + "Input(%s@GRAD) of GRUUnitGradOp should not be null.", + "hidden"); + auto input_dims = ctx->GetInputDim("input"); + auto hidden_prev_dims = ctx->GetInputDim("hidden_prev"); + auto weight_dims = ctx->GetInputDim("weight"); + auto bias_dims = ctx->GetInputDim("bias"); + // int batch_size = input_dims[0]; + int input_size = input_dims[1]; + int frame_size = hidden_prev_dims[1]; + int weight_height = weight_dims[0]; + int weight_width = weight_dims[1]; + int bias_height = bias_dims[0]; + int bias_width = bias_dims[1]; + PADDLE_ENFORCE_EQ( + input_size, frame_size * 3, + "The innput_size must be 3 times of frame_size in GRUUnitOp."); + PADDLE_ENFORCE_EQ( + weight_height, frame_size, + "The shape of weight matrix must be [frame_size, frame_size * 3]."); + PADDLE_ENFORCE_EQ( + weight_width, frame_size * 3, + "The shape of weight matrix must be [frame_size, frame_size * 3]."); + PADDLE_ENFORCE_EQ(bias_height, 1, + "The shape of bias must be [1, frame_size * 3]."); + PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, + "The shape of bias must be [1, frame_size * 3]."); + auto input_grad_name = framework::GradVarName("input"); + if (ctx->HasOutput(input_grad_name)) + ctx->SetOutputDim(input_grad_name, input_dims); + auto hidden_prev_grad_name = framework::GradVarName("hidden_prev"); + if (ctx->HasOutput(hidden_prev_grad_name)) + ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims); + auto weight_grad_name = framework::GradVarName("weight"); + if (ctx->HasOutput(weight_grad_name)) + ctx->SetOutputDim(weight_grad_name, weight_dims); + auto bias_grad_name = framework::GradVarName("bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker, gru_unit_grad, + ops::GRUUnitGradOp); +REGISTER_OP_CPU_KERNEL(gru_unit, + ops::GRUUnitKernel); +REGISTER_OP_CPU_KERNEL( + gru_unit_grad, ops::GRUUnitGradKernel); diff --git a/paddle/operators/gru_unit_op.cu b/paddle/operators/gru_unit_op.cu new file mode 100644 index 00000000000..365f656523d --- /dev/null +++ b/paddle/operators/gru_unit_op.cu @@ 
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/gru_unit_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(gru_unit,
+                       ops::GRUUnitKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    gru_unit_grad, ops::GRUUnitGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/gru_unit_op.h b/paddle/operators/gru_unit_op.h
new file mode 100644
index 00000000000..e48734b0660
--- /dev/null
+++ b/paddle/operators/gru_unit_op.h
@@ -0,0 +1,191 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "paddle/operators/math/math_function.h"
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class GRUUnitKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("input");
+    auto* hidden_prev = context.Input<Tensor>("hidden_prev");
+    auto* weight = context.Input<Tensor>("weight");
+    auto* bias = context.Input<Tensor>("bias");
+    auto* gate = context.Output<Tensor>("gate");
+    gate->mutable_data<T>(context.GetPlace());
+    auto* reset_hidden_prev = context.Output<Tensor>("reset_hidden_prev");
+    reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<Tensor>("hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    int batch_size = input->dims()[0];
+    int frame_size = hidden_prev->dims()[1];
+
+    auto x = EigenMatrix<T>::From(*input);
+    auto h_p = EigenMatrix<T>::From(*hidden_prev);
+    auto b = EigenMatrix<T>::From(*bias);
+    auto g = EigenMatrix<T>::From(*gate);
+    auto r_h_p = EigenMatrix<T>::From(*reset_hidden_prev);
+    auto h = EigenMatrix<T>::From(*hidden);
+    auto place = context.GetEigenDevice<Place>();
+
+    // calculate unactivated gate outputs
+    g.device(place) = x +
+                      b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
+                          .broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
+    const T* hidden_prev_data = hidden_prev->data<T>();
+    const T* weight_data = weight->data<T>();
+    T* gate_data = gate->data<T>();
+    T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
+    math::gemm<Place, T>(context.device_context(), false, false, batch_size,
+                         2 * frame_size, frame_size, 1, hidden_prev_data,
+                         frame_size, weight_data, frame_size * 2, 1, gate_data,
+                         frame_size * 3);
+
+    // calculate activated gates
+    Eigen::array<int, 2> extents({{batch_size, frame_size}});
+    Eigen::array<int, 2> u_offsets({{0, 0}});
+    g.slice(u_offsets, extents).device(place) =
+        g.slice(u_offsets, extents).sigmoid();
+    auto u = g.slice(u_offsets, extents);  // update gate
+    Eigen::array<int, 2> r_offsets({{0, frame_size}});
+    g.slice(r_offsets, extents).device(place) =
+        g.slice(r_offsets, extents).sigmoid();
+    auto r = g.slice(r_offsets, extents);  // reset gate
+    r_h_p.device(place) = r * h_p;         // reset previous hidden state
+    math::gemm<Place, T>(context.device_context(), false, false, batch_size,
+                         frame_size, frame_size, 1, reset_hidden_prev_data,
+                         frame_size, weight_data + frame_size * frame_size * 2,
+                         frame_size, 1, gate_data + frame_size * 2,
+                         frame_size * 3);
+
+    Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
+    g.slice(c_offsets, extents).device(place) =
+        g.slice(c_offsets, extents).tanh();
+    auto c = g.slice(c_offsets, extents);  // output candidate
+
+    // calculate final output
+    h.device(place) = u * (h_p - c) + c;
+  }
+};
+
+template <typename Place, typename T>
+class GRUUnitGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>("input");
+    auto* hidden_prev = context.Input<Tensor>("hidden_prev");
+    auto* weight = context.Input<Tensor>("weight");
+    auto* gate = context.Input<Tensor>("gate");
+    auto* reset_hidden_prev = context.Input<Tensor>("reset_hidden_prev");
+    auto* hidden_grad =
+        context.Input<Tensor>(framework::GradVarName("hidden"));
+    auto* input_grad =
+        context.Output<Tensor>(framework::GradVarName("input"));
+    auto* hidden_prev_grad =
+        context.Output<Tensor>(framework::GradVarName("hidden_prev"));
+    auto* weight_grad =
+        context.Output<Tensor>(framework::GradVarName("weight"));
+    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("bias"));
+    input_grad->mutable_data<T>(context.GetPlace());
+    hidden_prev_grad->mutable_data<T>(context.GetPlace());
+    weight_grad->mutable_data<T>(context.GetPlace());
+    bias_grad->mutable_data<T>(context.GetPlace());
+    Tensor gate_grad;
+    gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
+    Tensor reset_hidden_prev_grad;
+    reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
+                                           context.GetPlace());
+
+    int batch_size = input->dims()[0];
+    int frame_size = hidden_prev->dims()[1];
+
+    const T* hidden_prev_data = hidden_prev->data<T>();
+    T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
+    const T* weight_data = weight->data<T>();
+    T* weight_grad_data = weight_grad->data<T>();
+    T* gate_grad_data = gate_grad.data<T>();
+    const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
+    T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>();
+
+    auto h_p = EigenMatrix<T>::From(*hidden_prev);
+    auto g = EigenMatrix<T>::From(*gate);
+    auto d_h = EigenMatrix<T>::From(*hidden_grad);
+    auto d_x = EigenMatrix<T>::From(*input_grad);
+    auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
+    auto d_b = EigenMatrix<T>::From(*bias_grad);
+    auto d_g = EigenMatrix<T>::From(gate_grad);
+    auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
+    auto place = context.GetEigenDevice<Place>();
+
+    Eigen::array<int, 2> extents({{batch_size, frame_size}});
+    Eigen::array<int, 2> u_offsets({{0, 0}});
+    auto u = g.slice(u_offsets, extents);  // update gate
+    Eigen::array<int, 2> r_offsets({{0, frame_size}});
+    auto r = g.slice(r_offsets, extents);  // reset gate
+    Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
+    auto c = g.slice(c_offsets, extents);  // output candidate
+
+    // backward for unactivated update gate
+    d_g.slice(u_offsets, extents).device(place) =
+        d_h * (h_p - c) * u * (u.constant(T(1)) - u);
+    // backward for unactivated output candidate
+    d_g.slice(c_offsets, extents).device(place) =
+        d_h * (u.constant(T(1)) - u) * (c.constant(T(1)) - c * c);
+    // backward for reset_hidden_prev
+    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
+                         frame_size, frame_size, 1,
+                         gate_grad_data + frame_size * 2, frame_size * 3,
+                         weight_data + frame_size * frame_size * 2, frame_size,
+                         0, reset_hidden_prev_grad_data, frame_size);
+    // backward for state_weight
+    math::gemm<Place, T>(
+        context.device_context(), true, false, frame_size, frame_size,
+        batch_size, 1, reset_hidden_prev_data, frame_size,
+        gate_grad_data + frame_size * 2, frame_size * 3, 0,
+        weight_grad_data + frame_size * frame_size * 2, frame_size);
+    // backward for unactivated reset gate
+    d_g.slice(r_offsets, extents).device(place) =
+        d_r_h_p * h_p * r * (r.constant(T(1)) - r);
+    // backward for update_gate_weight and reset_gate_weight
+    math::gemm<Place, T>(context.device_context(), true, false, frame_size,
+                         frame_size * 2, batch_size, 1, hidden_prev_data,
+                         frame_size, gate_grad_data, frame_size * 3, 0,
+                         weight_grad_data, frame_size * 2);
+    // backward for hidden_prev
+    d_h_p.device(place) = d_r_h_p * r + d_h * u;
+    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
+                         frame_size, frame_size * 2, 1, gate_grad_data,
+                         frame_size * 3, weight_data, frame_size * 2, 1,
+                         hidden_prev_grad_data, frame_size);
+    // backward for input
+    d_x.device(place) = d_g;
+    // backward for bias
+    d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_gru_unit_op.py b/python/paddle/v2/framework/tests/test_gru_unit_op.py
new file mode 100644
index 00000000000..f7b3fab817d
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_gru_unit_op.py
@@ -0,0 +1,59 @@
+import math
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def sigmoid_np(x):
+    return 1. / (1. + np.exp(-x))
+
+
+def tanh_np(x):
+    return 2. * sigmoid_np(2. * x) - 1.
+
+
+class TestGRUUnitOp(OpTest):
+    def setUp(self):
+        batch_size = 3
+        frame_size = 5
+        self.op_type = "gru_unit"
+        self.inputs = {
+            'input': np.random.uniform(
+                -0.1, 0.1, (batch_size, frame_size * 3)).astype("float32"),
+            'hidden_prev': np.random.uniform(
+                -0.1, 0.1, (batch_size, frame_size)).astype("float32"),
+            'weight': np.random.uniform(
+                -1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
+                (frame_size, frame_size * 3)).astype("float32"),
+            'bias': np.random.uniform(-0.1, 0.1,
+                                      (1, frame_size * 3)).astype("float32")
+        }
+        x = self.inputs['input']
+        h_p = self.inputs['hidden_prev']
+        w = self.inputs['weight']
+        b = self.inputs['bias']
+        g = x + np.tile(b, (batch_size, 1))
+        w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
+            (frame_size, frame_size * 2))
+        u_r = sigmoid_np(np.dot(h_p, w_u_r) + g[:, :frame_size * 2])
+        u = u_r[:, :frame_size]
+        r = u_r[:, frame_size:frame_size * 2]
+        r_h_p = r * h_p
+        w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
+            (frame_size, frame_size))
+        c = tanh_np(np.dot(r_h_p, w_c) + g[:, frame_size * 2:])
+        g = np.hstack((u_r, c))
+        h = u * h_p + (1 - u) * c
+        self.outputs = {'gate': g, 'reset_hidden_prev': r_h_p, 'hidden': h}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['input', 'hidden_prev', 'weight', 'bias'], ['hidden'],
+            max_relative_error=0.007)
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
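
For reference, a complete GRU step is the composition of an input projection
(the FCOp output mentioned in the operator documentation) and the computation
gru_unit performs. Below is a minimal NumPy sketch that mirrors the reference
math in test_gru_unit_op.py; the projection weights w_x and bias b_x are
illustrative placeholders, not part of this patch.

import numpy as np


def sigmoid(x):
    return 1. / (1. + np.exp(-x))


def gru_step(x_t, h_prev, w_x, b_x, weight, bias):
    # Input projection for all three gates (what FCOp would produce);
    # w_x: [input_dim, frame_size * 3], b_x: [1, frame_size * 3].
    gru_input = np.dot(x_t, w_x) + b_x
    frame_size = h_prev.shape[1]
    # The part computed by gru_unit: add bias, gates, candidate, new state.
    g = gru_input + bias
    w_u_r = weight.flatten()[:frame_size * frame_size * 2].reshape(
        (frame_size, frame_size * 2))
    u_r = sigmoid(np.dot(h_prev, w_u_r) + g[:, :frame_size * 2])
    u = u_r[:, :frame_size]  # update gate
    r = u_r[:, frame_size:]  # reset gate
    w_c = weight.flatten()[frame_size * frame_size * 2:].reshape(
        (frame_size, frame_size))
    c = np.tanh(np.dot(r * h_prev, w_c) + g[:, frame_size * 2:])  # candidate
    return u * h_prev + (1. - u) * c  # new hidden state


batch_size, input_dim, frame_size = 3, 4, 5
x_t = np.random.uniform(-0.1, 0.1, (batch_size, input_dim))
h_prev = np.random.uniform(-0.1, 0.1, (batch_size, frame_size))
w_x = np.random.uniform(-0.1, 0.1, (input_dim, frame_size * 3))
b_x = np.zeros((1, frame_size * 3))
weight = np.random.uniform(-0.1, 0.1, (frame_size, frame_size * 3))
bias = np.random.uniform(-0.1, 0.1, (1, frame_size * 3))
h = gru_step(x_t, h_prev, w_x, b_x, weight, bias)  # (batch_size, frame_size)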