Commit 1cabdb87 authored by G guosheng

Refine gru_unit_op according to comments to support multiple activation types

Parent 0922fca4
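Background for the diff below: the forward kernel previously hard-coded sigmoid for the update/reset gates and tanh for the output candidate; this change routes both through the integer attributes "gate_activation" and "activation" (identity, sigmoid, tanh, relu). As a reading aid only (not text from the commit), the computation the op performs can be summarized from the kernel and the Python reference further down, writing act_g and act_c for the two configurable activations and splitting Input/Bias into update, reset and candidate slices:

\begin{aligned}
u_t &= act_g\bigl(x_u + h_{t-1} W_u + b_u\bigr) \\
r_t &= act_g\bigl(x_r + h_{t-1} W_r + b_r\bigr) \\
\tilde{h}_t &= act_c\bigl(x_c + (r_t \odot h_{t-1}) W_c + b_c\bigr) \\
h_t &= u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t
\end{aligned}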
@@ -24,26 +24,26 @@ class GRUUnitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("input"),
"Input(%s) of GRUUnitOp should not be null.", "input");
PADDLE_ENFORCE(ctx->HasInput("hidden_prev"),
"Input(%s) of GRUUnitOp should not be null.", "hidden_prev");
PADDLE_ENFORCE(ctx->HasInput("weight"),
"Input(%s) of GRUUnitOp should not be null.", "weight");
PADDLE_ENFORCE(ctx->HasInput("bias"),
"Input(%s) of GRUUnitOp should not be null.", "bias");
PADDLE_ENFORCE(ctx->HasOutput("gate"),
"Output(%s) of GRUUnitOp should not be null.", "gate");
PADDLE_ENFORCE(ctx->HasOutput("reset_hidden_prev"),
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUUnitOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
"Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUUnitOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(%s) of GRUUnitOp should not be null.", "Bias");
PADDLE_ENFORCE(ctx->HasOutput("Gate"),
"Output(%s) of GRUUnitOp should not be null.", "Gate");
PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
"Output(%s) of GRUUnitOp should not be null.",
"reset_hidden_prev");
PADDLE_ENFORCE(ctx->HasOutput("hidden"),
"Output(%s) of GRUUnitOp should not be null.", "hidden");
auto input_dims = ctx->GetInputDim("input");
auto hidden_prev_dims = ctx->GetInputDim("hidden_prev");
auto weight_dims = ctx->GetInputDim("weight");
auto bias_dims = ctx->GetInputDim("bias");
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUUnitOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight");
auto bias_dims = ctx->GetInputDim("Bias");
int batch_size = input_dims[0];
int input_size = input_dims[1];
int frame_size = hidden_prev_dims[1];
@@ -53,54 +53,64 @@ class GRUUnitOp : public framework::OperatorWithKernel {
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(
input_size, frame_size * 3,
"The innput_size must be 3 times of frame_size in GRUUnitOp.");
"The input_size must be 3 times of frame_size in GRUUnitOp.");
PADDLE_ENFORCE_EQ(
weight_height, frame_size,
"The shape of weight matrix must be [frame_size, frame_size * 3].");
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of weight matrix must be [frame_size, frame_size * 3].");
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of bias must be [1, frame_size * 3].");
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of bias must be [1, frame_size * 3].");
ctx->SetOutputDim("gate", {batch_size, frame_size * 3});
ctx->SetOutputDim("reset_hidden_prev", {batch_size, frame_size});
ctx->SetOutputDim("hidden", {batch_size, frame_size});
"The shape of Bias must be [1, frame_size * 3].");
ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
ctx->SetOutputDim("Hidden", {batch_size, frame_size});
}
};
class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUUnitOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
GRUUnitOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input",
AddInput("Input",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
"input.");
AddInput("hidden_prev",
AddInput("HiddenPrev",
"(Tensor) Matrix with shape [batch_size, frame_size] for the "
"states of previous time step.");
AddInput("weight",
AddInput("Weight",
"(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. "
"The elements continuous in memory can be divided into two parts. "
"The first part are weights of the update gate and reset gate "
"with shape [frame_size, frame_size * 2], and the second part are "
"weights of output candidate with shape [frame_size, frame_size]");
AddInput("bias",
AddInput("Bias",
"(Tensor) Bias vector with shape [1, frame_size * 3] concating "
"bias of the update gate, reset gate and output candidate.");
AddOutput("gate",
AddOutput("Gate",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
"output of update gate, reset gate and output candidate")
.AsIntermediate();
AddOutput("reset_hidden_prev",
AddOutput("ResetHiddenPrev",
"(Tensor) Matrix with shape [batch_size, frame_size] for the "
"reseted hidden state of previous time step.")
.AsIntermediate();
AddOutput("hidden",
AddOutput("Hidden",
"(Tensor) The GRU hidden state of the current time step "
"with shape [batch_size, frame_size].");
AddAttr<int>("activation",
"(enum int, default tanh) "
"The activation type used for output candidate {h}_t.")
.SetDefault(tanh)
.InEnum({identity, sigmoid, tanh, relu});
AddAttr<int>("gate_activation",
"(enum int, default sigmoid) "
"The activation type used in update gate and reset gate.")
.SetDefault(sigmoid)
.InEnum({identity, sigmoid, tanh, relu});
AddComment(R"DOC(
GRUUnitOp implements part of the calculations of the GRU unit as follows:
@@ -121,36 +131,36 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("input"),
"Input(%s) of GRUUnitGradOp should not be null.", "input");
PADDLE_ENFORCE(ctx->HasInput("hidden_prev"),
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUUnitGradOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.",
"hidden_prev");
PADDLE_ENFORCE(ctx->HasInput("weight"),
"Input(%s) of GRUUnitGradOp should not be null.", "weight");
PADDLE_ENFORCE(ctx->HasInput("bias"),
"Input(%s) of GRUUnitGradOp should not be null.", "bias");
PADDLE_ENFORCE(ctx->HasInput("gate"),
"Input(%s) of GRUUnitGradOp should not be null.", "gate");
PADDLE_ENFORCE(ctx->HasInput("reset_hidden_prev"),
"HiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUUnitGradOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(%s) of GRUUnitGradOp should not be null.", "Bias");
PADDLE_ENFORCE(ctx->HasInput("Gate"),
"Input(%s) of GRUUnitGradOp should not be null.", "Gate");
PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
"Input(%s) of GRUUnitGradOp should not be null.",
"reset_hidden_prev");
PADDLE_ENFORCE(ctx->HasInput("hidden"),
"Input(%s) of GRUUnitGradOp should not be null.", "hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("gate")),
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"gate");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("reset_hidden_prev")),
"Gate");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"reset_hidden_prev");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("hidden")),
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"hidden");
auto input_dims = ctx->GetInputDim("input");
auto hidden_prev_dims = ctx->GetInputDim("hidden_prev");
auto weight_dims = ctx->GetInputDim("weight");
auto bias_dims = ctx->GetInputDim("bias");
"Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight");
auto bias_dims = ctx->GetInputDim("Bias");
// int batch_size = input_dims[0];
int input_size = input_dims[1];
int frame_size = hidden_prev_dims[1];
@@ -160,27 +170,27 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(
input_size, frame_size * 3,
"The innput_size must be 3 times of frame_size in GRUUnitOp.");
"The input_size must be 3 times of frame_size in GRUUnitOp.");
PADDLE_ENFORCE_EQ(
weight_height, frame_size,
"The shape of weight matrix must be [frame_size, frame_size * 3].");
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of weight matrix must be [frame_size, frame_size * 3].");
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of bias must be [1, frame_size * 3].");
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of bias must be [1, frame_size * 3].");
auto input_grad_name = framework::GradVarName("input");
"The shape of Bias must be [1, frame_size * 3].");
auto input_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(input_grad_name))
ctx->SetOutputDim(input_grad_name, input_dims);
auto hidden_prev_grad_name = framework::GradVarName("hidden_prev");
auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev");
if (ctx->HasOutput(hidden_prev_grad_name))
ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims);
auto weight_grad_name = framework::GradVarName("weight");
auto weight_grad_name = framework::GradVarName("Weight");
if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims);
auto bias_grad_name = framework::GradVarName("bias");
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
}
@@ -14,6 +14,7 @@
#pragma once
#include "paddle/operators/activation_op.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/eigen.h"
@@ -27,19 +28,35 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 };
template <typename Place, typename T>
class GRUUnitKernel : public framework::OpKernel {
class GRUUnitKernel : public framework::OpKernel<T> {
public:
template <typename Device, typename X, typename Y>
void ActCompute(const int act_type, const Device& d, X x, Y y) const {
if (act_type == identity)
y.device(d) = x;
else if (act_type == sigmoid)
SigmoidFunctor<T>()(d, x, y);
else if (act_type == tanh)
TanhFunctor<T>()(d, x, y);
else if (act_type == relu)
ReluFunctor<T>()(d, x, y);
else
PADDLE_THROW("unsupported activation type");
}
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("input");
auto* hidden_prev = context.Input<Tensor>("hidden_prev");
auto* weight = context.Input<Tensor>("weight");
auto* bias = context.Input<Tensor>("bias");
auto* gate = context.Output<Tensor>("gate");
auto* input = context.Input<Tensor>("Input");
auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
auto* weight = context.Input<Tensor>("Weight");
auto* bias = context.Input<Tensor>("Bias");
auto* gate = context.Output<Tensor>("Gate");
gate->mutable_data<T>(context.GetPlace());
auto* reset_hidden_prev = context.Output<Tensor>("reset_hidden_prev");
auto* reset_hidden_prev = context.Output<Tensor>("ResetHiddenPrev");
reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<Tensor>("hidden");
auto* hidden = context.Output<Tensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
int batch_size = input->dims()[0];
@@ -69,12 +86,12 @@ class GRUUnitKernel : public framework::OpKernel {
// calculate activated gate
Eigen::array<int, 2> extents({{batch_size, frame_size}});
Eigen::array<int, 2> u_offsets({{0, 0}});
g.slice(u_offsets, extents).device(place) =
g.slice(u_offsets, extents).sigmoid();
ActCompute(context.Attr<int>("gate_activation"), place,
g.slice(u_offsets, extents), g.slice(u_offsets, extents));
auto u = g.slice(u_offsets, extents); // update gate
Eigen::array<int, 2> r_offsets({{0, frame_size}});
g.slice(r_offsets, extents).device(place) =
g.slice(r_offsets, extents).sigmoid();
ActCompute(context.Attr<int>("gate_activation"), place,
g.slice(r_offsets, extents), g.slice(r_offsets, extents));
auto r = g.slice(r_offsets, extents); // reset gate
r_h_p.device(place) = r * h_p; // reset previous hidden state
math::gemm<Place, T>(context.device_context(), false, false, batch_size,
@@ -84,8 +101,8 @@ class GRUUnitKernel : public framework::OpKernel {
frame_size * 3);
Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
g.slice(c_offsets, extents).device(place) =
g.slice(c_offsets, extents).tanh();
ActCompute(context.Attr<int>("activation"), place,
g.slice(c_offsets, extents), g.slice(c_offsets, extents));
auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output
@@ -94,21 +111,37 @@ class GRUUnitKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class GRUUnitGradKernel : public framework::OpKernel {
class GRUUnitGradKernel : public framework::OpKernel<T> {
public:
template <typename Device, typename X, typename Y, typename DX, typename DY>
void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx,
DY dy) const {
// x is a dummy argument and is not used, even in Relu (y is used instead)
if (act_type == identity)
dx.device(d) = dy;
else if (act_type == sigmoid)
SigmoidGradFunctor<T>()(d, x, y, dy, dx);
else if (act_type == tanh)
TanhGradFunctor<T>()(d, x, y, dy, dx);
else if (act_type == relu)
ReluGradFunctor<T>()(d, x, y, dy, dx);
else
PADDLE_THROW("unsupported activation type");
}
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("input");
auto* hidden_prev = context.Input<Tensor>("hidden_prev");
auto* weight = context.Input<Tensor>("weight");
auto* gate = context.Input<Tensor>("gate");
auto* reset_hidden_prev = context.Input<Tensor>("reset_hidden_prev");
auto* hidden_grad = context.Input<Tensor>(framework::GradVarName("hidden"));
auto* input_grad = context.Output<Tensor>(framework::GradVarName("input"));
auto* input = context.Input<Tensor>("Input");
auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
auto* weight = context.Input<Tensor>("Weight");
auto* gate = context.Input<Tensor>("Gate");
auto* reset_hidden_prev = context.Input<Tensor>("ResetHiddenPrev");
auto* hidden_grad = context.Input<Tensor>(framework::GradVarName("Hidden"));
auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
auto* hidden_prev_grad =
context.Output<Tensor>(framework::GradVarName("hidden_prev"));
context.Output<Tensor>(framework::GradVarName("HiddenPrev"));
auto* weight_grad =
context.Output<Tensor>(framework::GradVarName("weight"));
auto* bias_grad = context.Output<Tensor>(framework::GradVarName("bias"));
context.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
input_grad->mutable_data<T>(context.GetPlace());
hidden_prev_grad->mutable_data<T>(context.GetPlace());
weight_grad->mutable_data<T>(context.GetPlace());
@@ -149,11 +182,11 @@ class GRUUnitGradKernel : public framework::OpKernel {
auto c = g.slice(c_offsets, extents); // output candidate
// backward for unactivated update gate
d_g.slice(u_offsets, extents).device(place) =
d_h * (h_p - c) * u * (u.constant(T(1)) - u);
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (h_p - c));
// backward for unactivated output candidate
d_g.slice(c_offsets, extents).device(place) =
d_h * (u.constant(T(1)) - u) * (c.constant(T(1)) - c * c);
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
// backward for reset_hidden_prev
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
frame_size, frame_size, 1,
@@ -167,8 +200,8 @@ class GRUUnitGradKernel : public framework::OpKernel {
gate_grad_data + frame_size * 2, frame_size * 3, 0,
weight_grad_data + frame_size * frame_size * 2, frame_size);
// backward for unactivated reset gate
d_g.slice(r_offsets, extents).device(place) =
d_r_h_p * h_p * r * (r.constant(T(1)) - r);
ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
d_g.slice(r_offsets, extents), d_r_h_p * h_p);
// backward for update_gate_weight and reset_gate_weight
math::gemm<Place, T>(context.device_context(), true, false, frame_size,
frame_size * 2, batch_size, 1, hidden_prev_data,
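For reference (again not part of the commit), the gate gradients that the old kernel wrote out explicitly for sigmoid and tanh generalize under ActGradCompute to roughly the following, where act_g' and act_c' are the derivatives of the configured activations, evaluated from the activated outputs as in the grad functors above:

\begin{aligned}
dg_u &= dh_t \odot (h_{t-1} - \tilde{h}_t) \odot act_g'(u_t) \\
dg_c &= dh_t \odot (1 - u_t) \odot act_c'(\tilde{h}_t) \\
dg_r &= d(r_t \odot h_{t-1}) \odot h_{t-1} \odot act_g'(r_t)
\end{aligned}

With sigmoid gates and a tanh candidate these reduce to the u * (1 - u) and (1 - c * c) factors visible in the removed lines.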
@@ -4,54 +4,84 @@ import numpy as np
from op_test import OpTest
def sigmoid_np(x):
class GRUActivationType(OpTest):
identity = 0
sigmoid = 1
tanh = 2
relu = 3
def identity(x):
return x
def sigmoid(x):
return 1. / (1. + np.exp(-x))
def tanh_np(x):
return 2. * sigmoid_np(2. * x) - 1.
def tanh(x):
return 2. * sigmoid(2. * x) - 1.
def relu(x):
return np.maximum(x, 0)
class TestGRUUnitOp(OpTest):
activate = {
GRUActivationType.identity: identity,
GRUActivationType.sigmoid: sigmoid,
GRUActivationType.tanh: tanh,
GRUActivationType.relu: relu,
}
def setUp(self):
batch_size = 3
frame_size = 5
self.op_type = "gru_unit"
self.op_type = 'gru_unit'
self.inputs = {
'input': np.random.uniform(
-0.1, 0.1, (batch_size, frame_size * 3)).astype("float32"),
'hidden_prev': np.random.uniform(
-0.1, 0.1, (batch_size, frame_size)).astype("float32"),
'weight': np.random.uniform(
'Input': np.random.uniform(
-0.1, 0.1, (batch_size, frame_size * 3)).astype('float32'),
'HiddenPrev': np.random.uniform(
-0.1, 0.1, (batch_size, frame_size)).astype('float32'),
'Weight': np.random.uniform(
-1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
(frame_size, frame_size * 3)).astype("float32"),
'bias': np.random.uniform(-0.1, 0.1,
(1, frame_size * 3)).astype("float32")
(frame_size, frame_size * 3)).astype('float32'),
'Bias': np.random.uniform(-0.1, 0.1,
(1, frame_size * 3)).astype('float32')
}
x = self.inputs['input']
h_p = self.inputs['hidden_prev']
w = self.inputs['weight']
b = self.inputs['bias']
self.attrs = {
'activation': GRUActivationType.tanh,
'gate_activation': GRUActivationType.sigmoid
}
# GRU calculations
x = self.inputs['Input']
h_p = self.inputs['HiddenPrev']
w = self.inputs['Weight']
b = self.inputs['Bias']
g = x + np.tile(b, (batch_size, 1))
w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
(frame_size, frame_size * 2))
u_r = sigmoid_np(np.dot(h_p, w_u_r) + g[:, :frame_size * 2])
u_r = self.activate[self.attrs['gate_activation']](np.dot(
h_p, w_u_r) + g[:, :frame_size * 2])
u = u_r[:, :frame_size]
r = u_r[:, frame_size:frame_size * 2]
r_h_p = r * h_p
w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
(frame_size, frame_size))
c = tanh_np(np.dot(r_h_p, w_c) + g[:, frame_size * 2:])
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
h = u * h_p + (1 - u) * c
self.outputs = {'gate': g, 'reset_hidden_prev': r_h_p, 'hidden': h}
self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['input', 'hidden_prev', 'weight', 'bias'], ['hidden'],
['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
max_relative_error=0.007)
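To experiment with the two new attributes outside the test harness, the following is a minimal standalone NumPy sketch of the forward pass, mirroring the reference computation in setUp above; the helper name gru_unit_ref and its default attribute values are illustrative assumptions, not code from the commit:

import numpy as np

# Minimal NumPy sketch (illustration only) of the gru_unit forward pass with
# the two activation attributes made explicit. The name gru_unit_ref is
# hypothetical; shapes follow the op comments above.
def identity(x):
    return x

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

def tanh(x):
    return np.tanh(x)  # numerically the same as 2. * sigmoid(2. * x) - 1.

def relu(x):
    return np.maximum(x, 0.)

ACTIVATE = {0: identity, 1: sigmoid, 2: tanh, 3: relu}  # GRUActivationType order

def gru_unit_ref(x, h_p, w, b, activation=2, gate_activation=1):
    # x: [batch, frame * 3], h_p: [batch, frame],
    # w: [frame, frame * 3], b: [1, frame * 3]
    act, act_gate = ACTIVATE[activation], ACTIVATE[gate_activation]
    frame = h_p.shape[1]
    g = x + b  # bias broadcasts over the batch dimension
    # the first frame * frame * 2 elements of w hold the update/reset weights
    w_u_r = w.flatten()[:frame * frame * 2].reshape((frame, frame * 2))
    u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :frame * 2])
    u, r = u_r[:, :frame], u_r[:, frame:]  # update gate, reset gate
    r_h_p = r * h_p  # reset previous hidden state
    w_c = w.flatten()[frame * frame * 2:].reshape((frame, frame))
    c = act(np.dot(r_h_p, w_c) + g[:, frame * 2:])  # output candidate
    return u * h_p + (1. - u) * c  # Hidden output

For example, gru_unit_ref(x, h_p, w, b, activation=3, gate_activation=1) would use relu for the output candidate and sigmoid for the gates, matching activation=relu and gate_activation=sigmoid in the op attributes.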