From 1cabdb870893da8e242bcae88b76498d9aee3be1 Mon Sep 17 00:00:00 2001 From: guosheng Date: Wed, 11 Oct 2017 16:19:49 +0800 Subject: [PATCH] Refine gru_unit_op according to comments to support multiple activation types --- paddle/operators/gru_unit_op.cc | 150 ++++++++++-------- paddle/operators/gru_unit_op.h | 95 +++++++---- .../v2/framework/tests/test_gru_unit_op.py | 70 +++++--- 3 files changed, 194 insertions(+), 121 deletions(-) diff --git a/paddle/operators/gru_unit_op.cc b/paddle/operators/gru_unit_op.cc index d6d766cef0d..9a34daf9849 100644 --- a/paddle/operators/gru_unit_op.cc +++ b/paddle/operators/gru_unit_op.cc @@ -24,26 +24,26 @@ class GRUUnitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("input"), - "Input(%s) of GRUUnitOp should not be null.", "input"); - PADDLE_ENFORCE(ctx->HasInput("hidden_prev"), - "Input(%s) of GRUUnitOp should not be null.", "hidden_prev"); - PADDLE_ENFORCE(ctx->HasInput("weight"), - "Input(%s) of GRUUnitOp should not be null.", "weight"); - PADDLE_ENFORCE(ctx->HasInput("bias"), - "Input(%s) of GRUUnitOp should not be null.", "bias"); - PADDLE_ENFORCE(ctx->HasOutput("gate"), - "Output(%s) of GRUUnitOp should not be null.", "gate"); - PADDLE_ENFORCE(ctx->HasOutput("reset_hidden_prev"), + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(%s) of GRUUnitOp should not be null.", "Input"); + PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"), + "Input(%s) of GRUUnitOp should not be null.", "HiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(%s) of GRUUnitOp should not be null.", "Weight"); + PADDLE_ENFORCE(ctx->HasInput("Bias"), + "Input(%s) of GRUUnitOp should not be null.", "Bias"); + PADDLE_ENFORCE(ctx->HasOutput("Gate"), + "Output(%s) of GRUUnitOp should not be null.", "Gate"); + PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"), "Output(%s) of GRUUnitOp should not be null.", - "reset_hidden_prev"); - PADDLE_ENFORCE(ctx->HasOutput("hidden"), - "Output(%s) of GRUUnitOp should not be null.", "hidden"); - auto input_dims = ctx->GetInputDim("input"); - auto hidden_prev_dims = ctx->GetInputDim("hidden_prev"); - auto weight_dims = ctx->GetInputDim("weight"); - auto bias_dims = ctx->GetInputDim("bias"); + "ResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(%s) of GRUUnitOp should not be null.", "Hidden"); + auto input_dims = ctx->GetInputDim("Input"); + auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); + auto weight_dims = ctx->GetInputDim("Weight"); + auto bias_dims = ctx->GetInputDim("Bias"); int batch_size = input_dims[0]; int input_size = input_dims[1]; int frame_size = hidden_prev_dims[1]; @@ -53,54 +53,64 @@ class GRUUnitOp : public framework::OperatorWithKernel { int bias_width = bias_dims[1]; PADDLE_ENFORCE_EQ( input_size, frame_size * 3, - "The innput_size must be 3 times of frame_size in GRUUnitOp."); + "The input_size must be 3 times of frame_size in GRUUnitOp."); PADDLE_ENFORCE_EQ( weight_height, frame_size, - "The shape of weight matrix must be [frame_size, frame_size * 3]."); + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); PADDLE_ENFORCE_EQ( weight_width, frame_size * 3, - "The shape of weight matrix must be [frame_size, frame_size * 3]."); + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); PADDLE_ENFORCE_EQ(bias_height, 1, - "The shape of bias must be [1, frame_size * 3]."); + "The shape of Bias must be [1, frame_size * 3]."); PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, - "The shape of bias must be [1, frame_size * 3]."); - ctx->SetOutputDim("gate", {batch_size, frame_size * 3}); - ctx->SetOutputDim("reset_hidden_prev", {batch_size, frame_size}); - ctx->SetOutputDim("hidden", {batch_size, frame_size}); + "The shape of Bias must be [1, frame_size * 3]."); + ctx->SetOutputDim("Gate", {batch_size, frame_size * 3}); + ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size}); + ctx->SetOutputDim("Hidden", {batch_size, frame_size}); } }; class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker { public: - GRUUnitOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + GRUUnitOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", + AddInput("Input", "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the " "input."); - AddInput("hidden_prev", + AddInput("HiddenPrev", "(Tensor) Matrix with shape [batch_size, frame_size] for the " "states of previous time step."); - AddInput("weight", + AddInput("Weight", "(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. " "The elements continuous in memory can be divided into two parts. " "The first part are weights of the update gate and reset gate " "with shape [frame_size, frame_size * 2], and the second part are " "weights of output candidate with shape [frame_size, frame_size]"); - AddInput("bias", + AddInput("Bias", "(Tensor) Bias vector with shape [1, frame_size * 3] concating " "bias of the update gate, reset gate and output candidate."); - AddOutput("gate", + AddOutput("Gate", "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the " "output of update gate, reset gate and output candidate") .AsIntermediate(); - AddOutput("reset_hidden_prev", + AddOutput("ResetHiddenPrev", "(Tensor) Matrix with shape [batch_size, frame_size] for the " "reseted hidden state of previous time step.") .AsIntermediate(); - AddOutput("hidden", + AddOutput("Hidden", "(Tensor) The GRU hidden state of the current time step " "with shape [batch_size, frame_size]."); + AddAttr("activation", + "(enum int, default tanh) " + "The activation type used for output candidate {h}_t.") + .SetDefault(tanh) + .InEnum({identity, sigmoid, tanh, relu}); + AddAttr("gate_activation", + "(enum int, default sigmoid) " + "The activation type used in update gate and reset gate.") + .SetDefault(sigmoid) + .InEnum({identity, sigmoid, tanh, relu}); AddComment(R"DOC( GRUUnitOp implements part calculations of the GRU unit as following: @@ -121,36 +131,36 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("input"), - "Input(%s) of GRUUnitGradOp should not be null.", "input"); - PADDLE_ENFORCE(ctx->HasInput("hidden_prev"), + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(%s) of GRUUnitGradOp should not be null.", "Input"); + PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"), "Input(%s) of GRUUnitGradOp should not be null.", - "hidden_prev"); - PADDLE_ENFORCE(ctx->HasInput("weight"), - "Input(%s) of GRUUnitGradOp should not be null.", "weight"); - PADDLE_ENFORCE(ctx->HasInput("bias"), - "Input(%s) of GRUUnitGradOp should not be null.", "bias"); - PADDLE_ENFORCE(ctx->HasInput("gate"), - "Input(%s) of GRUUnitGradOp should not be null.", "gate"); - PADDLE_ENFORCE(ctx->HasInput("reset_hidden_prev"), + "HiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(%s) of GRUUnitGradOp should not be null.", "Weight"); + PADDLE_ENFORCE(ctx->HasInput("Bias"), + "Input(%s) of GRUUnitGradOp should not be null.", "Bias"); + PADDLE_ENFORCE(ctx->HasInput("Gate"), + "Input(%s) of GRUUnitGradOp should not be null.", "Gate"); + PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"), "Input(%s) of GRUUnitGradOp should not be null.", - "reset_hidden_prev"); - PADDLE_ENFORCE(ctx->HasInput("hidden"), - "Input(%s) of GRUUnitGradOp should not be null.", "hidden"); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("gate")), + "ResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput("Hidden"), + "Input(%s) of GRUUnitGradOp should not be null.", "Hidden"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")), "Input(%s@GRAD) of GRUUnitGradOp should not be null.", - "gate"); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("reset_hidden_prev")), + "Gate"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")), "Input(%s@GRAD) of GRUUnitGradOp should not be null.", - "reset_hidden_prev"); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("hidden")), + "ResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), "Input(%s@GRAD) of GRUUnitGradOp should not be null.", - "hidden"); - auto input_dims = ctx->GetInputDim("input"); - auto hidden_prev_dims = ctx->GetInputDim("hidden_prev"); - auto weight_dims = ctx->GetInputDim("weight"); - auto bias_dims = ctx->GetInputDim("bias"); + "Hidden"); + auto input_dims = ctx->GetInputDim("Input"); + auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); + auto weight_dims = ctx->GetInputDim("Weight"); + auto bias_dims = ctx->GetInputDim("Bias"); // int batch_size = input_dims[0]; int input_size = input_dims[1]; int frame_size = hidden_prev_dims[1]; @@ -160,27 +170,27 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { int bias_width = bias_dims[1]; PADDLE_ENFORCE_EQ( input_size, frame_size * 3, - "The innput_size must be 3 times of frame_size in GRUUnitOp."); + "The input_size must be 3 times of frame_size in GRUUnitOp."); PADDLE_ENFORCE_EQ( weight_height, frame_size, - "The shape of weight matrix must be [frame_size, frame_size * 3]."); + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); PADDLE_ENFORCE_EQ( weight_width, frame_size * 3, - "The shape of weight matrix must be [frame_size, frame_size * 3]."); + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); PADDLE_ENFORCE_EQ(bias_height, 1, - "The shape of bias must be [1, frame_size * 3]."); + "The shape of Bias must be [1, frame_size * 3]."); PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, - "The shape of bias must be [1, frame_size * 3]."); - auto input_grad_name = framework::GradVarName("input"); + "The shape of Bias must be [1, frame_size * 3]."); + auto input_grad_name = framework::GradVarName("Input"); if (ctx->HasOutput(input_grad_name)) ctx->SetOutputDim(input_grad_name, input_dims); - auto hidden_prev_grad_name = framework::GradVarName("hidden_prev"); + auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev"); if (ctx->HasOutput(hidden_prev_grad_name)) ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims); - auto weight_grad_name = framework::GradVarName("weight"); + auto weight_grad_name = framework::GradVarName("Weight"); if (ctx->HasOutput(weight_grad_name)) ctx->SetOutputDim(weight_grad_name, weight_dims); - auto bias_grad_name = framework::GradVarName("bias"); + auto bias_grad_name = framework::GradVarName("Bias"); if (ctx->HasOutput(bias_grad_name)) ctx->SetOutputDim(bias_grad_name, bias_dims); } diff --git a/paddle/operators/gru_unit_op.h b/paddle/operators/gru_unit_op.h index e48734b0660..e97aa38ac64 100644 --- a/paddle/operators/gru_unit_op.h +++ b/paddle/operators/gru_unit_op.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/operators/activation_op.h" #include "paddle/operators/math/math_function.h" #include "paddle/framework/eigen.h" @@ -27,19 +28,35 @@ template using EigenMatrix = framework::EigenMatrix; +enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 }; + template -class GRUUnitKernel : public framework::OpKernel { +class GRUUnitKernel : public framework::OpKernel { public: + template + void ActCompute(const int act_type, const Device& d, X x, Y y) const { + if (act_type == identity) + y.device(d) = x; + else if (act_type == sigmoid) + SigmoidFunctor()(d, x, y); + else if (act_type == tanh) + TanhFunctor()(d, x, y); + else if (act_type == relu) + ReluFunctor()(d, x, y); + else + PADDLE_THROW("unsupported activation type"); + } + void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("input"); - auto* hidden_prev = context.Input("hidden_prev"); - auto* weight = context.Input("weight"); - auto* bias = context.Input("bias"); - auto* gate = context.Output("gate"); + auto* input = context.Input("Input"); + auto* hidden_prev = context.Input("HiddenPrev"); + auto* weight = context.Input("Weight"); + auto* bias = context.Input("Bias"); + auto* gate = context.Output("Gate"); gate->mutable_data(context.GetPlace()); - auto* reset_hidden_prev = context.Output("reset_hidden_prev"); + auto* reset_hidden_prev = context.Output("ResetHiddenPrev"); reset_hidden_prev->mutable_data(context.GetPlace()); - auto* hidden = context.Output("hidden"); + auto* hidden = context.Output("Hidden"); hidden->mutable_data(context.GetPlace()); int batch_size = input->dims()[0]; @@ -69,12 +86,12 @@ class GRUUnitKernel : public framework::OpKernel { // calculate activited gate Eigen::array extents({{batch_size, frame_size}}); Eigen::array u_offsets({{0, 0}}); - g.slice(u_offsets, extents).device(place) = - g.slice(u_offsets, extents).sigmoid(); + ActCompute(context.Attr("gate_activation"), place, + g.slice(u_offsets, extents), g.slice(u_offsets, extents)); auto u = g.slice(u_offsets, extents); // update gate Eigen::array r_offsets({{0, frame_size}}); - g.slice(r_offsets, extents).device(place) = - g.slice(r_offsets, extents).sigmoid(); + ActCompute(context.Attr("gate_activation"), place, + g.slice(r_offsets, extents), g.slice(r_offsets, extents)); auto r = g.slice(r_offsets, extents); // reset gate r_h_p.device(place) = r * h_p; // reset previous hidden state math::gemm(context.device_context(), false, false, batch_size, @@ -84,8 +101,8 @@ class GRUUnitKernel : public framework::OpKernel { frame_size * 3); Eigen::array c_offsets({{0, frame_size * 2}}); - g.slice(c_offsets, extents).device(place) = - g.slice(c_offsets, extents).tanh(); + ActCompute(context.Attr("activation"), place, + g.slice(c_offsets, extents), g.slice(c_offsets, extents)); auto c = g.slice(c_offsets, extents); // output candidate // calculate final output @@ -94,21 +111,37 @@ class GRUUnitKernel : public framework::OpKernel { }; template -class GRUUnitGradKernel : public framework::OpKernel { +class GRUUnitGradKernel : public framework::OpKernel { public: + template + void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx, + DY dy) const { + // x is dummy and won't be used even in Relu(use y instead) + if (act_type == identity) + dx.device(d) = dy; + else if (act_type == sigmoid) + SigmoidGradFunctor()(d, x, y, dy, dx); + else if (act_type == tanh) + TanhGradFunctor()(d, x, y, dy, dx); + else if (act_type == relu) + ReluGradFunctor()(d, x, y, dy, dx); + else + PADDLE_THROW("unsupported activation type"); + } + void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("input"); - auto* hidden_prev = context.Input("hidden_prev"); - auto* weight = context.Input("weight"); - auto* gate = context.Input("gate"); - auto* reset_hidden_prev = context.Input("reset_hidden_prev"); - auto* hidden_grad = context.Input(framework::GradVarName("hidden")); - auto* input_grad = context.Output(framework::GradVarName("input")); + auto* input = context.Input("Input"); + auto* hidden_prev = context.Input("HiddenPrev"); + auto* weight = context.Input("Weight"); + auto* gate = context.Input("Gate"); + auto* reset_hidden_prev = context.Input("ResetHiddenPrev"); + auto* hidden_grad = context.Input(framework::GradVarName("Hidden")); + auto* input_grad = context.Output(framework::GradVarName("Input")); auto* hidden_prev_grad = - context.Output(framework::GradVarName("hidden_prev")); + context.Output(framework::GradVarName("HiddenPrev")); auto* weight_grad = - context.Output(framework::GradVarName("weight")); - auto* bias_grad = context.Output(framework::GradVarName("bias")); + context.Output(framework::GradVarName("Weight")); + auto* bias_grad = context.Output(framework::GradVarName("Bias")); input_grad->mutable_data(context.GetPlace()); hidden_prev_grad->mutable_data(context.GetPlace()); weight_grad->mutable_data(context.GetPlace()); @@ -149,11 +182,11 @@ class GRUUnitGradKernel : public framework::OpKernel { auto c = g.slice(c_offsets, extents); // output candidate // backward for unactivated update gate - d_g.slice(u_offsets, extents).device(place) = - d_h * (h_p - c) * u * (u.constant(T(1)) - u); + ActGradCompute(context.Attr("gate_activation"), place, u, u, + d_g.slice(u_offsets, extents), d_h * (h_p - c)); // backward for unactivated output candidate - d_g.slice(c_offsets, extents).device(place) = - d_h * (u.constant(T(1)) - u) * (c.constant(T(1)) - c * c); + ActGradCompute(context.Attr("activation"), place, c, c, + d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u)); // backward for reset_hidden_prev math::gemm(context.device_context(), false, true, batch_size, frame_size, frame_size, 1, @@ -167,8 +200,8 @@ class GRUUnitGradKernel : public framework::OpKernel { gate_grad_data + frame_size * 2, frame_size * 3, 0, weight_grad_data + frame_size * frame_size * 2, frame_size); // backward for unactivated reset gate - d_g.slice(r_offsets, extents).device(place) = - d_r_h_p * h_p * r * (r.constant(T(1)) - r); + ActGradCompute(context.Attr("gate_activation"), place, r, r, + d_g.slice(r_offsets, extents), d_r_h_p * h_p); // backward for update_gate_weight and reset_gate_weight math::gemm(context.device_context(), true, false, frame_size, frame_size * 2, batch_size, 1, hidden_prev_data, diff --git a/python/paddle/v2/framework/tests/test_gru_unit_op.py b/python/paddle/v2/framework/tests/test_gru_unit_op.py index f7b3fab817d..bc8b3406e65 100644 --- a/python/paddle/v2/framework/tests/test_gru_unit_op.py +++ b/python/paddle/v2/framework/tests/test_gru_unit_op.py @@ -4,54 +4,84 @@ import numpy as np from op_test import OpTest -def sigmoid_np(x): +class GRUActivationType(OpTest): + identity = 0 + sigmoid = 1 + tanh = 2 + relu = 3 + + +def identity(x): + return x + + +def sigmoid(x): return 1. / (1. + np.exp(-x)) -def tanh_np(x): - return 2. * sigmoid_np(2. * x) - 1. +def tanh(x): + return 2. * sigmoid(2. * x) - 1. + + +def relu(x): + return np.maximum(x, 0) class TestGRUUnitOp(OpTest): + activate = { + GRUActivationType.identity: identity, + GRUActivationType.sigmoid: sigmoid, + GRUActivationType.tanh: tanh, + GRUActivationType.relu: relu, + } + def setUp(self): batch_size = 3 frame_size = 5 - self.op_type = "gru_unit" + self.op_type = 'gru_unit' self.inputs = { - 'input': np.random.uniform( - -0.1, 0.1, (batch_size, frame_size * 3)).astype("float32"), - 'hidden_prev': np.random.uniform( - -0.1, 0.1, (batch_size, frame_size)).astype("float32"), - 'weight': np.random.uniform( + 'Input': np.random.uniform( + -0.1, 0.1, (batch_size, frame_size * 3)).astype('float32'), + 'HiddenPrev': np.random.uniform( + -0.1, 0.1, (batch_size, frame_size)).astype('float32'), + 'Weight': np.random.uniform( -1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size), - (frame_size, frame_size * 3)).astype("float32"), - 'bias': np.random.uniform(-0.1, 0.1, - (1, frame_size * 3)).astype("float32") + (frame_size, frame_size * 3)).astype('float32'), + 'Bias': np.random.uniform(-0.1, 0.1, + (1, frame_size * 3)).astype('float32') } - x = self.inputs['input'] - h_p = self.inputs['hidden_prev'] - w = self.inputs['weight'] - b = self.inputs['bias'] + self.attrs = { + 'activation': GRUActivationType.tanh, + 'gate_activation': GRUActivationType.sigmoid + } + # GRU calculations + x = self.inputs['Input'] + h_p = self.inputs['HiddenPrev'] + w = self.inputs['Weight'] + b = self.inputs['Bias'] g = x + np.tile(b, (batch_size, 1)) w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape( (frame_size, frame_size * 2)) - u_r = sigmoid_np(np.dot(h_p, w_u_r) + g[:, :frame_size * 2]) + u_r = self.activate[self.attrs['gate_activation']](np.dot( + h_p, w_u_r) + g[:, :frame_size * 2]) u = u_r[:, :frame_size] r = u_r[:, frame_size:frame_size * 2] r_h_p = r * h_p w_c = w.flatten()[frame_size * frame_size * 2:].reshape( (frame_size, frame_size)) - c = tanh_np(np.dot(r_h_p, w_c) + g[:, frame_size * 2:]) + c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) + + g[:, frame_size * 2:]) g = np.hstack((u_r, c)) h = u * h_p + (1 - u) * c - self.outputs = {'gate': g, 'reset_hidden_prev': r_h_p, 'hidden': h} + + self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h} def test_check_output(self): self.check_output() def test_check_grad(self): self.check_grad( - ['input', 'hidden_prev', 'weight', 'bias'], ['hidden'], + ['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.007) -- GitLab