Commit 1cabdb87 authored by guosheng

Refine gru_unit_op according to comments to support multiple activation types

Parent 0922fca4
@@ -24,26 +24,26 @@ class GRUUnitOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(framework::InferShapeContextBase *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("input"),
-                   "Input(%s) of GRUUnitOp should not be null.", "input");
-    PADDLE_ENFORCE(ctx->HasInput("hidden_prev"),
-                   "Input(%s) of GRUUnitOp should not be null.", "hidden_prev");
-    PADDLE_ENFORCE(ctx->HasInput("weight"),
-                   "Input(%s) of GRUUnitOp should not be null.", "weight");
-    PADDLE_ENFORCE(ctx->HasInput("bias"),
-                   "Input(%s) of GRUUnitOp should not be null.", "bias");
-    PADDLE_ENFORCE(ctx->HasOutput("gate"),
-                   "Output(%s) of GRUUnitOp should not be null.", "gate");
-    PADDLE_ENFORCE(ctx->HasOutput("reset_hidden_prev"),
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(%s) of GRUUnitOp should not be null.", "Input");
+    PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
+                   "Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(%s) of GRUUnitOp should not be null.", "Weight");
+    PADDLE_ENFORCE(ctx->HasInput("Bias"),
+                   "Input(%s) of GRUUnitOp should not be null.", "Bias");
+    PADDLE_ENFORCE(ctx->HasOutput("Gate"),
+                   "Output(%s) of GRUUnitOp should not be null.", "Gate");
+    PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
                    "Output(%s) of GRUUnitOp should not be null.",
-                   "reset_hidden_prev");
-    PADDLE_ENFORCE(ctx->HasOutput("hidden"),
-                   "Output(%s) of GRUUnitOp should not be null.", "hidden");
-    auto input_dims = ctx->GetInputDim("input");
-    auto hidden_prev_dims = ctx->GetInputDim("hidden_prev");
-    auto weight_dims = ctx->GetInputDim("weight");
-    auto bias_dims = ctx->GetInputDim("bias");
+                   "ResetHiddenPrev");
+    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                   "Output(%s) of GRUUnitOp should not be null.", "Hidden");
+    auto input_dims = ctx->GetInputDim("Input");
+    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
+    auto weight_dims = ctx->GetInputDim("Weight");
+    auto bias_dims = ctx->GetInputDim("Bias");
     int batch_size = input_dims[0];
     int input_size = input_dims[1];
     int frame_size = hidden_prev_dims[1];
@@ -53,54 +53,64 @@ class GRUUnitOp : public framework::OperatorWithKernel {
     int bias_width = bias_dims[1];
     PADDLE_ENFORCE_EQ(
         input_size, frame_size * 3,
-        "The innput_size must be 3 times of frame_size in GRUUnitOp.");
+        "The input_size must be 3 times of frame_size in GRUUnitOp.");
     PADDLE_ENFORCE_EQ(
         weight_height, frame_size,
-        "The shape of weight matrix must be [frame_size, frame_size * 3].");
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
     PADDLE_ENFORCE_EQ(
         weight_width, frame_size * 3,
-        "The shape of weight matrix must be [frame_size, frame_size * 3].");
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
     PADDLE_ENFORCE_EQ(bias_height, 1,
-                      "The shape of bias must be [1, frame_size * 3].");
+                      "The shape of Bias must be [1, frame_size * 3].");
     PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
-                      "The shape of bias must be [1, frame_size * 3].");
-    ctx->SetOutputDim("gate", {batch_size, frame_size * 3});
-    ctx->SetOutputDim("reset_hidden_prev", {batch_size, frame_size});
-    ctx->SetOutputDim("hidden", {batch_size, frame_size});
+                      "The shape of Bias must be [1, frame_size * 3].");
+    ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
+    ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
+    ctx->SetOutputDim("Hidden", {batch_size, frame_size});
   }
 };

 class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  GRUUnitOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
+  GRUUnitOpMaker(framework::OpProto* proto,
+                 framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input",
+    AddInput("Input",
              "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
              "input.");
-    AddInput("hidden_prev",
+    AddInput("HiddenPrev",
             "(Tensor) Matrix with shape [batch_size, frame_size] for the "
             "states of previous time step.");
-    AddInput("weight",
+    AddInput("Weight",
             "(Tensor) Weight matrix with shape [frame_size, frame_size * 3]. "
             "The elements continuous in memory can be divided into two parts. "
             "The first part are weights of the update gate and reset gate "
             "with shape [frame_size, frame_size * 2], and the second part are "
             "weights of output candidate with shape [frame_size, frame_size]");
-    AddInput("bias",
+    AddInput("Bias",
             "(Tensor) Bias vector with shape [1, frame_size * 3] concating "
             "bias of the update gate, reset gate and output candidate.");
-    AddOutput("gate",
+    AddOutput("Gate",
              "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
              "output of update gate, reset gate and output candidate")
        .AsIntermediate();
-    AddOutput("reset_hidden_prev",
+    AddOutput("ResetHiddenPrev",
              "(Tensor) Matrix with shape [batch_size, frame_size] for the "
              "reseted hidden state of previous time step.")
        .AsIntermediate();
-    AddOutput("hidden",
+    AddOutput("Hidden",
              "(Tensor) The GRU hidden state of the current time step "
              "with shape [batch_size, frame_size].");
+    AddAttr<int>("activation",
+                 "(enum int, default tanh) "
+                 "The activation type used for output candidate {h}_t.")
+        .SetDefault(tanh)
+        .InEnum({identity, sigmoid, tanh, relu});
+    AddAttr<int>("gate_activation",
+                 "(enum int, default sigmoid) "
+                 "The activation type used in update gate and reset gate.")
+        .SetDefault(sigmoid)
+        .InEnum({identity, sigmoid, tanh, relu});
     AddComment(R"DOC(
 GRUUnitOp implements part calculations of the GRU unit as following:
@@ -121,36 +131,36 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

  protected:
-  void InferShape(framework::InferShapeContextBase *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("input"),
-                   "Input(%s) of GRUUnitGradOp should not be null.", "input");
-    PADDLE_ENFORCE(ctx->HasInput("hidden_prev"),
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Input");
+    PADDLE_ENFORCE(ctx->HasInput("HiddenPrev"),
                    "Input(%s) of GRUUnitGradOp should not be null.",
-                   "hidden_prev");
-    PADDLE_ENFORCE(ctx->HasInput("weight"),
-                   "Input(%s) of GRUUnitGradOp should not be null.", "weight");
-    PADDLE_ENFORCE(ctx->HasInput("bias"),
-                   "Input(%s) of GRUUnitGradOp should not be null.", "bias");
-    PADDLE_ENFORCE(ctx->HasInput("gate"),
-                   "Input(%s) of GRUUnitGradOp should not be null.", "gate");
-    PADDLE_ENFORCE(ctx->HasInput("reset_hidden_prev"),
+                   "HiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput("Weight"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Weight");
+    PADDLE_ENFORCE(ctx->HasInput("Bias"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Bias");
+    PADDLE_ENFORCE(ctx->HasInput("Gate"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Gate");
+    PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
                    "Input(%s) of GRUUnitGradOp should not be null.",
-                   "reset_hidden_prev");
-    PADDLE_ENFORCE(ctx->HasInput("hidden"),
-                   "Input(%s) of GRUUnitGradOp should not be null.", "hidden");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("gate")),
+                   "ResetHiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput("Hidden"),
+                   "Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
                    "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
-                   "gate");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("reset_hidden_prev")),
+                   "Gate");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
                    "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
-                   "reset_hidden_prev");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("hidden")),
+                   "ResetHiddenPrev");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
                    "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
-                   "hidden");
-    auto input_dims = ctx->GetInputDim("input");
-    auto hidden_prev_dims = ctx->GetInputDim("hidden_prev");
-    auto weight_dims = ctx->GetInputDim("weight");
-    auto bias_dims = ctx->GetInputDim("bias");
+                   "Hidden");
+    auto input_dims = ctx->GetInputDim("Input");
+    auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
+    auto weight_dims = ctx->GetInputDim("Weight");
+    auto bias_dims = ctx->GetInputDim("Bias");
     // int batch_size = input_dims[0];
     int input_size = input_dims[1];
     int frame_size = hidden_prev_dims[1];
@@ -160,27 +170,27 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
     int bias_width = bias_dims[1];
     PADDLE_ENFORCE_EQ(
         input_size, frame_size * 3,
-        "The innput_size must be 3 times of frame_size in GRUUnitOp.");
+        "The input_size must be 3 times of frame_size in GRUUnitOp.");
     PADDLE_ENFORCE_EQ(
         weight_height, frame_size,
-        "The shape of weight matrix must be [frame_size, frame_size * 3].");
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
     PADDLE_ENFORCE_EQ(
         weight_width, frame_size * 3,
-        "The shape of weight matrix must be [frame_size, frame_size * 3].");
+        "The shape of Weight matrix must be [frame_size, frame_size * 3].");
     PADDLE_ENFORCE_EQ(bias_height, 1,
-                      "The shape of bias must be [1, frame_size * 3].");
+                      "The shape of Bias must be [1, frame_size * 3].");
     PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
-                      "The shape of bias must be [1, frame_size * 3].");
-    auto input_grad_name = framework::GradVarName("input");
+                      "The shape of Bias must be [1, frame_size * 3].");
+    auto input_grad_name = framework::GradVarName("Input");
     if (ctx->HasOutput(input_grad_name))
       ctx->SetOutputDim(input_grad_name, input_dims);
-    auto hidden_prev_grad_name = framework::GradVarName("hidden_prev");
+    auto hidden_prev_grad_name = framework::GradVarName("HiddenPrev");
     if (ctx->HasOutput(hidden_prev_grad_name))
       ctx->SetOutputDim(hidden_prev_grad_name, hidden_prev_dims);
-    auto weight_grad_name = framework::GradVarName("weight");
+    auto weight_grad_name = framework::GradVarName("Weight");
     if (ctx->HasOutput(weight_grad_name))
       ctx->SetOutputDim(weight_grad_name, weight_dims);
-    auto bias_grad_name = framework::GradVarName("bias");
+    auto bias_grad_name = framework::GradVarName("Bias");
     if (ctx->HasOutput(bias_grad_name))
       ctx->SetOutputDim(bias_grad_name, bias_dims);
   }
...
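The GRU-unit equations referred to by the AddComment DOC string are collapsed in this view. As a reference only (reconstructed from the kernel and the Python test in this commit, not quoted from it), the forward computation with the new attributes is, in LaTeX:

\begin{aligned}
u_t &= act_{gate}(x_{u,t} + h_{t-1} W_u + b_u) \\
r_t &= act_{gate}(x_{r,t} + h_{t-1} W_r + b_r) \\
\tilde{h}_t &= act(x_{c,t} + (r_t \odot h_{t-1}) W_c + b_c) \\
h_t &= u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t
\end{aligned}

where act_{gate} is the new "gate_activation" attribute (default sigmoid), act is the new "activation" attribute (default tanh), Input packs the three x slices, and Bias packs the three b slices.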
@@ -14,6 +14,7 @@
 #pragma once

+#include "paddle/operators/activation_op.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/framework/eigen.h"
@@ -27,19 +28,35 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

+enum GRUActivationType { identity = 0, sigmoid = 1, tanh = 2, relu = 3 };
+
 template <typename Place, typename T>
-class GRUUnitKernel : public framework::OpKernel {
+class GRUUnitKernel : public framework::OpKernel<T> {
  public:
+  template <typename Device, typename X, typename Y>
+  void ActCompute(const int act_type, const Device& d, X x, Y y) const {
+    if (act_type == identity)
+      y.device(d) = x;
+    else if (act_type == sigmoid)
+      SigmoidFunctor<T>()(d, x, y);
+    else if (act_type == tanh)
+      TanhFunctor<T>()(d, x, y);
+    else if (act_type == relu)
+      ReluFunctor<T>()(d, x, y);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* input = context.Input<Tensor>("input");
-    auto* hidden_prev = context.Input<Tensor>("hidden_prev");
-    auto* weight = context.Input<Tensor>("weight");
-    auto* bias = context.Input<Tensor>("bias");
-    auto* gate = context.Output<Tensor>("gate");
+    auto* input = context.Input<Tensor>("Input");
+    auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
+    auto* weight = context.Input<Tensor>("Weight");
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* gate = context.Output<Tensor>("Gate");
     gate->mutable_data<T>(context.GetPlace());
-    auto* reset_hidden_prev = context.Output<Tensor>("reset_hidden_prev");
+    auto* reset_hidden_prev = context.Output<Tensor>("ResetHiddenPrev");
     reset_hidden_prev->mutable_data<T>(context.GetPlace());
-    auto* hidden = context.Output<Tensor>("hidden");
+    auto* hidden = context.Output<Tensor>("Hidden");
     hidden->mutable_data<T>(context.GetPlace());
     int batch_size = input->dims()[0];
@@ -69,12 +86,12 @@ class GRUUnitKernel : public framework::OpKernel {
     // calculate activited gate
     Eigen::array<int, 2> extents({{batch_size, frame_size}});
     Eigen::array<int, 2> u_offsets({{0, 0}});
-    g.slice(u_offsets, extents).device(place) =
-        g.slice(u_offsets, extents).sigmoid();
+    ActCompute(context.Attr<int>("gate_activation"), place,
+               g.slice(u_offsets, extents), g.slice(u_offsets, extents));
     auto u = g.slice(u_offsets, extents);  // update gate
     Eigen::array<int, 2> r_offsets({{0, frame_size}});
-    g.slice(r_offsets, extents).device(place) =
-        g.slice(r_offsets, extents).sigmoid();
+    ActCompute(context.Attr<int>("gate_activation"), place,
+               g.slice(r_offsets, extents), g.slice(r_offsets, extents));
     auto r = g.slice(r_offsets, extents);  // reset gate
     r_h_p.device(place) = r * h_p;  // reset previous hidden state
     math::gemm<Place, T>(context.device_context(), false, false, batch_size,
@@ -84,8 +101,8 @@ class GRUUnitKernel : public framework::OpKernel {
                          frame_size * 3);
     Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
-    g.slice(c_offsets, extents).device(place) =
-        g.slice(c_offsets, extents).tanh();
+    ActCompute(context.Attr<int>("activation"), place,
+               g.slice(c_offsets, extents), g.slice(c_offsets, extents));
     auto c = g.slice(c_offsets, extents);  // output candidate

     // calculate final output
@@ -94,21 +111,37 @@ class GRUUnitKernel : public framework::OpKernel {
 };

 template <typename Place, typename T>
-class GRUUnitGradKernel : public framework::OpKernel {
+class GRUUnitGradKernel : public framework::OpKernel<T> {
  public:
+  template <typename Device, typename X, typename Y, typename DX, typename DY>
+  void ActGradCompute(const int act_type, const Device& d, X x, Y y, DX dx,
+                      DY dy) const {
+    // x is dummy and won't be used even in Relu(use y instead)
+    if (act_type == identity)
+      dx.device(d) = dy;
+    else if (act_type == sigmoid)
+      SigmoidGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == tanh)
+      TanhGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == relu)
+      ReluGradFunctor<T>()(d, x, y, dy, dx);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* input = context.Input<Tensor>("input");
-    auto* hidden_prev = context.Input<Tensor>("hidden_prev");
-    auto* weight = context.Input<Tensor>("weight");
-    auto* gate = context.Input<Tensor>("gate");
-    auto* reset_hidden_prev = context.Input<Tensor>("reset_hidden_prev");
-    auto* hidden_grad = context.Input<Tensor>(framework::GradVarName("hidden"));
-    auto* input_grad = context.Output<Tensor>(framework::GradVarName("input"));
+    auto* input = context.Input<Tensor>("Input");
+    auto* hidden_prev = context.Input<Tensor>("HiddenPrev");
+    auto* weight = context.Input<Tensor>("Weight");
+    auto* gate = context.Input<Tensor>("Gate");
+    auto* reset_hidden_prev = context.Input<Tensor>("ResetHiddenPrev");
+    auto* hidden_grad = context.Input<Tensor>(framework::GradVarName("Hidden"));
+    auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
     auto* hidden_prev_grad =
-        context.Output<Tensor>(framework::GradVarName("hidden_prev"));
+        context.Output<Tensor>(framework::GradVarName("HiddenPrev"));
     auto* weight_grad =
-        context.Output<Tensor>(framework::GradVarName("weight"));
-    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("bias"));
+        context.Output<Tensor>(framework::GradVarName("Weight"));
+    auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
     input_grad->mutable_data<T>(context.GetPlace());
     hidden_prev_grad->mutable_data<T>(context.GetPlace());
     weight_grad->mutable_data<T>(context.GetPlace());
@@ -149,11 +182,11 @@ class GRUUnitGradKernel : public framework::OpKernel {
     auto c = g.slice(c_offsets, extents);  // output candidate
     // backward for unactivated update gate
-    d_g.slice(u_offsets, extents).device(place) =
-        d_h * (h_p - c) * u * (u.constant(T(1)) - u);
+    ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
+                   d_g.slice(u_offsets, extents), d_h * (h_p - c));
     // backward for unactivated output candidate
-    d_g.slice(c_offsets, extents).device(place) =
-        d_h * (u.constant(T(1)) - u) * (c.constant(T(1)) - c * c);
+    ActGradCompute(context.Attr<int>("activation"), place, c, c,
+                   d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
     // backward for reset_hidden_prev
     math::gemm<Place, T>(context.device_context(), false, true, batch_size,
                          frame_size, frame_size, 1,
@@ -167,8 +200,8 @@ class GRUUnitGradKernel : public framework::OpKernel {
                          gate_grad_data + frame_size * 2, frame_size * 3, 0,
                          weight_grad_data + frame_size * frame_size * 2,
                          frame_size);
     // backward for unactivated reset gate
-    d_g.slice(r_offsets, extents).device(place) =
-        d_r_h_p * h_p * r * (r.constant(T(1)) - r);
+    ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
+                   d_g.slice(r_offsets, extents), d_r_h_p * h_p);
     // backward for update_gate_weight and reset_gate_weight
     math::gemm<Place, T>(context.device_context(), true, false, frame_size,
                          frame_size * 2, batch_size, 1, hidden_prev_data,
...
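For reference only (a derivation sketch, not text from the commit): the expressions passed to ActGradCompute above follow from differentiating the forward rule h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t, with the Grad functors supplying the activation-derivative factor:

\frac{\partial h_t}{\partial u_t} = h_{t-1} - \tilde{h}_t
\quad\Rightarrow\quad
dg_u = act'_{gate} \cdot dh_t \odot (h_{t-1} - \tilde{h}_t)

\frac{\partial h_t}{\partial \tilde{h}_t} = 1 - u_t
\quad\Rightarrow\quad
dg_c = act' \cdot dh_t \odot (1 - u_t)

The reset gate uses d(r_t \odot h_{t-1}), obtained by back-propagating dg_c through the candidate GEMM, giving dg_r = act'_{gate} \cdot d(r_t \odot h_{t-1}) \odot h_{t-1}; with act_{gate} = sigmoid and act = tanh these reduce to the hard-coded expressions on the removed lines.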
@@ -4,54 +4,84 @@ import numpy as np
 from op_test import OpTest

-def sigmoid_np(x):
+class GRUActivationType(OpTest):
+    identity = 0
+    sigmoid = 1
+    tanh = 2
+    relu = 3
+
+
+def identity(x):
+    return x
+
+
+def sigmoid(x):
     return 1. / (1. + np.exp(-x))

-def tanh_np(x):
-    return 2. * sigmoid_np(2. * x) - 1.
+def tanh(x):
+    return 2. * sigmoid(2. * x) - 1.
+
+
+def relu(x):
+    return np.maximum(x, 0)


 class TestGRUUnitOp(OpTest):
+    activate = {
+        GRUActivationType.identity: identity,
+        GRUActivationType.sigmoid: sigmoid,
+        GRUActivationType.tanh: tanh,
+        GRUActivationType.relu: relu,
+    }
+
     def setUp(self):
         batch_size = 3
         frame_size = 5
-        self.op_type = "gru_unit"
+        self.op_type = 'gru_unit'
         self.inputs = {
-            'input': np.random.uniform(
-                -0.1, 0.1, (batch_size, frame_size * 3)).astype("float32"),
-            'hidden_prev': np.random.uniform(
-                -0.1, 0.1, (batch_size, frame_size)).astype("float32"),
-            'weight': np.random.uniform(
+            'Input': np.random.uniform(
+                -0.1, 0.1, (batch_size, frame_size * 3)).astype('float32'),
+            'HiddenPrev': np.random.uniform(
+                -0.1, 0.1, (batch_size, frame_size)).astype('float32'),
+            'Weight': np.random.uniform(
                 -1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
-                (frame_size, frame_size * 3)).astype("float32"),
-            'bias': np.random.uniform(-0.1, 0.1,
-                                      (1, frame_size * 3)).astype("float32")
+                (frame_size, frame_size * 3)).astype('float32'),
+            'Bias': np.random.uniform(-0.1, 0.1,
+                                      (1, frame_size * 3)).astype('float32')
         }
-        x = self.inputs['input']
-        h_p = self.inputs['hidden_prev']
-        w = self.inputs['weight']
-        b = self.inputs['bias']
+        self.attrs = {
+            'activation': GRUActivationType.tanh,
+            'gate_activation': GRUActivationType.sigmoid
+        }
+        # GRU calculations
+        x = self.inputs['Input']
+        h_p = self.inputs['HiddenPrev']
+        w = self.inputs['Weight']
+        b = self.inputs['Bias']
         g = x + np.tile(b, (batch_size, 1))
         w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
             (frame_size, frame_size * 2))
-        u_r = sigmoid_np(np.dot(h_p, w_u_r) + g[:, :frame_size * 2])
+        u_r = self.activate[self.attrs['gate_activation']](np.dot(
+            h_p, w_u_r) + g[:, :frame_size * 2])
         u = u_r[:, :frame_size]
         r = u_r[:, frame_size:frame_size * 2]
         r_h_p = r * h_p
         w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
             (frame_size, frame_size))
-        c = tanh_np(np.dot(r_h_p, w_c) + g[:, frame_size * 2:])
+        c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
+                                                    g[:, frame_size * 2:])
         g = np.hstack((u_r, c))
         h = u * h_p + (1 - u) * c
-        self.outputs = {'gate': g, 'reset_hidden_prev': r_h_p, 'hidden': h}
+        self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
         self.check_grad(
-            ['input', 'hidden_prev', 'weight', 'bias'], ['hidden'],
+            ['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
             max_relative_error=0.007)
...
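For illustration only, the forward pass written inline in TestGRUUnitOp.setUp can be read as one standalone NumPy reference function. The helper below (gru_unit_ref is a hypothetical name, not part of this commit) repackages that computation; act and gate_act are callables such as the tanh and sigmoid helpers defined in the test above.

import numpy as np

def gru_unit_ref(x, h_p, w, b, act, gate_act):
    batch_size, frame_size = h_p.shape
    g = x + np.tile(b, (batch_size, 1))                   # add bias to packed gates
    w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
        (frame_size, frame_size * 2))                     # update/reset gate weights
    u_r = gate_act(np.dot(h_p, w_u_r) + g[:, :frame_size * 2])
    u = u_r[:, :frame_size]                               # update gate
    r = u_r[:, frame_size:frame_size * 2]                 # reset gate
    r_h_p = r * h_p                                       # reset previous hidden state
    w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
        (frame_size, frame_size))                         # candidate weights
    c = act(np.dot(r_h_p, w_c) + g[:, frame_size * 2:])   # output candidate
    gate = np.hstack((u_r, c))
    h = u * h_p + (1 - u) * c
    return gate, r_h_p, h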