You need to sign in or sign up before continuing.
提交 0bc5a122 编写于 作者: G guosheng

Refine gru_unit_op by optional bias

上级 1cabdb87
...@@ -31,8 +31,6 @@ class GRUUnitOp : public framework::OperatorWithKernel { ...@@ -31,8 +31,6 @@ class GRUUnitOp : public framework::OperatorWithKernel {
"Input(%s) of GRUUnitOp should not be null.", "HiddenPrev"); "Input(%s) of GRUUnitOp should not be null.", "HiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Weight"), PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUUnitOp should not be null.", "Weight"); "Input(%s) of GRUUnitOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(%s) of GRUUnitOp should not be null.", "Bias");
PADDLE_ENFORCE(ctx->HasOutput("Gate"), PADDLE_ENFORCE(ctx->HasOutput("Gate"),
"Output(%s) of GRUUnitOp should not be null.", "Gate"); "Output(%s) of GRUUnitOp should not be null.", "Gate");
PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"), PADDLE_ENFORCE(ctx->HasOutput("ResetHiddenPrev"),
...@@ -43,14 +41,11 @@ class GRUUnitOp : public framework::OperatorWithKernel { ...@@ -43,14 +41,11 @@ class GRUUnitOp : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
auto bias_dims = ctx->GetInputDim("Bias");
int batch_size = input_dims[0]; int batch_size = input_dims[0];
int input_size = input_dims[1]; int input_size = input_dims[1];
int frame_size = hidden_prev_dims[1]; int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0]; int weight_height = weight_dims[0];
int weight_width = weight_dims[1]; int weight_width = weight_dims[1];
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_size, frame_size * 3, input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUUnitOp."); "The input_size must be 3 times of frame_size in GRUUnitOp.");
...@@ -60,10 +55,16 @@ class GRUUnitOp : public framework::OperatorWithKernel { ...@@ -60,10 +55,16 @@ class GRUUnitOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3, weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); "The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_height, 1, auto bias = Input("Bias");
"The shape of Bias must be [1, frame_size * 3]."); if (bias != framework::kEmptyVarName) {
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, auto bias_dims = ctx->GetInputDim("Bias");
"The shape of Bias must be [1, frame_size * 3]."); int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
}
ctx->SetOutputDim("Gate", {batch_size, frame_size * 3}); ctx->SetOutputDim("Gate", {batch_size, frame_size * 3});
ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size}); ctx->SetOutputDim("ResetHiddenPrev", {batch_size, frame_size});
ctx->SetOutputDim("Hidden", {batch_size, frame_size}); ctx->SetOutputDim("Hidden", {batch_size, frame_size});
...@@ -139,8 +140,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -139,8 +140,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
"HiddenPrev"); "HiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Weight"), PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUUnitGradOp should not be null.", "Weight"); "Input(%s) of GRUUnitGradOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(%s) of GRUUnitGradOp should not be null.", "Bias");
PADDLE_ENFORCE(ctx->HasInput("Gate"), PADDLE_ENFORCE(ctx->HasInput("Gate"),
"Input(%s) of GRUUnitGradOp should not be null.", "Gate"); "Input(%s) of GRUUnitGradOp should not be null.", "Gate");
PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"), PADDLE_ENFORCE(ctx->HasInput("ResetHiddenPrev"),
...@@ -160,14 +159,11 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -160,14 +159,11 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev"); auto hidden_prev_dims = ctx->GetInputDim("HiddenPrev");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
auto bias_dims = ctx->GetInputDim("Bias");
// int batch_size = input_dims[0]; // int batch_size = input_dims[0];
int input_size = input_dims[1]; int input_size = input_dims[1];
int frame_size = hidden_prev_dims[1]; int frame_size = hidden_prev_dims[1];
int weight_height = weight_dims[0]; int weight_height = weight_dims[0];
int weight_width = weight_dims[1]; int weight_width = weight_dims[1];
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input_size, frame_size * 3, input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUUnitOp."); "The input_size must be 3 times of frame_size in GRUUnitOp.");
...@@ -177,10 +173,19 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -177,10 +173,19 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3, weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3]."); "The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_height, 1, auto bias = Input("Bias");
"The shape of Bias must be [1, frame_size * 3]."); if (bias != framework::kEmptyVarName) {
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, auto bias_dims = ctx->GetInputDim("Bias");
"The shape of Bias must be [1, frame_size * 3]."); int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
}
auto input_grad_name = framework::GradVarName("Input"); auto input_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(input_grad_name)) if (ctx->HasOutput(input_grad_name))
ctx->SetOutputDim(input_grad_name, input_dims); ctx->SetOutputDim(input_grad_name, input_dims);
...@@ -190,9 +195,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -190,9 +195,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
auto weight_grad_name = framework::GradVarName("Weight"); auto weight_grad_name = framework::GradVarName("Weight");
if (ctx->HasOutput(weight_grad_name)) if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims); ctx->SetOutputDim(weight_grad_name, weight_dims);
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
} }
}; };
......
...@@ -64,16 +64,20 @@ class GRUUnitKernel : public framework::OpKernel<T> { ...@@ -64,16 +64,20 @@ class GRUUnitKernel : public framework::OpKernel<T> {
auto x = EigenMatrix<T>::From(*input); auto x = EigenMatrix<T>::From(*input);
auto h_p = EigenMatrix<T>::From(*hidden_prev); auto h_p = EigenMatrix<T>::From(*hidden_prev);
auto b = EigenMatrix<T>::From(*bias);
auto g = EigenMatrix<T>::From(*gate); auto g = EigenMatrix<T>::From(*gate);
auto r_h_p = EigenMatrix<T>::From(*reset_hidden_prev); auto r_h_p = EigenMatrix<T>::From(*reset_hidden_prev);
auto h = EigenMatrix<T>::From(*hidden); auto h = EigenMatrix<T>::From(*hidden);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
// calculate unactivated gate outputs // calculate unactivated gate outputs
g.device(place) = x + if (bias) {
b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}})) auto b = EigenMatrix<T>::From(*bias);
.broadcast(Eigen::array<int, 2>({{batch_size, 1}})); g.device(place) = x +
b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
.broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
} else {
g.device(place) = x;
}
const T* hidden_prev_data = hidden_prev->data<T>(); const T* hidden_prev_data = hidden_prev->data<T>();
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
T* gate_data = gate->data<T>(); T* gate_data = gate->data<T>();
...@@ -145,7 +149,6 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -145,7 +149,6 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
hidden_prev_grad->mutable_data<T>(context.GetPlace()); hidden_prev_grad->mutable_data<T>(context.GetPlace());
weight_grad->mutable_data<T>(context.GetPlace()); weight_grad->mutable_data<T>(context.GetPlace());
bias_grad->mutable_data<T>(context.GetPlace());
Tensor gate_grad; Tensor gate_grad;
gate_grad.mutable_data<T>(input->dims(), context.GetPlace()); gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
Tensor reset_hidden_prev_grad; Tensor reset_hidden_prev_grad;
...@@ -168,7 +171,6 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -168,7 +171,6 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
auto d_h = EigenMatrix<T>::From(*hidden_grad); auto d_h = EigenMatrix<T>::From(*hidden_grad);
auto d_x = EigenMatrix<T>::From(*input_grad); auto d_x = EigenMatrix<T>::From(*input_grad);
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad); auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
auto d_b = EigenMatrix<T>::From(*bias_grad);
auto d_g = EigenMatrix<T>::From(gate_grad); auto d_g = EigenMatrix<T>::From(gate_grad);
auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad); auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
...@@ -216,7 +218,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -216,7 +218,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
// backward for input // backward for input
d_x.device(place) = d_g; d_x.device(place) = d_g;
// backward for bias // backward for bias
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}})); if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace());
auto d_b = EigenMatrix<T>::From(*bias_grad);
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
}
} }
}; };
......
...@@ -28,6 +28,8 @@ def relu(x): ...@@ -28,6 +28,8 @@ def relu(x):
class TestGRUUnitOp(OpTest): class TestGRUUnitOp(OpTest):
batch_size = 3
frame_size = 5
activate = { activate = {
GRUActivationType.identity: identity, GRUActivationType.identity: identity,
GRUActivationType.sigmoid: sigmoid, GRUActivationType.sigmoid: sigmoid,
...@@ -35,9 +37,9 @@ class TestGRUUnitOp(OpTest): ...@@ -35,9 +37,9 @@ class TestGRUUnitOp(OpTest):
GRUActivationType.relu: relu, GRUActivationType.relu: relu,
} }
def setUp(self): def set_inputs(self):
batch_size = 3 batch_size = self.batch_size
frame_size = 5 frame_size = self.frame_size
self.op_type = 'gru_unit' self.op_type = 'gru_unit'
self.inputs = { self.inputs = {
'Input': np.random.uniform( 'Input': np.random.uniform(
...@@ -47,18 +49,21 @@ class TestGRUUnitOp(OpTest): ...@@ -47,18 +49,21 @@ class TestGRUUnitOp(OpTest):
'Weight': np.random.uniform( 'Weight': np.random.uniform(
-1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size), -1. / math.sqrt(frame_size), 1. / math.sqrt(frame_size),
(frame_size, frame_size * 3)).astype('float32'), (frame_size, frame_size * 3)).astype('float32'),
'Bias': np.random.uniform(-0.1, 0.1,
(1, frame_size * 3)).astype('float32')
} }
self.attrs = { self.attrs = {
'activation': GRUActivationType.tanh, 'activation': GRUActivationType.tanh,
'gate_activation': GRUActivationType.sigmoid 'gate_activation': GRUActivationType.sigmoid
} }
def set_outputs(self):
# GRU calculations # GRU calculations
batch_size = self.batch_size
frame_size = self.frame_size
x = self.inputs['Input'] x = self.inputs['Input']
h_p = self.inputs['HiddenPrev'] h_p = self.inputs['HiddenPrev']
w = self.inputs['Weight'] w = self.inputs['Weight']
b = self.inputs['Bias'] b = self.inputs['Bias'] if self.inputs.has_key('Bias') else np.zeros(
(1, frame_size * 3))
g = x + np.tile(b, (batch_size, 1)) g = x + np.tile(b, (batch_size, 1))
w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape( w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
(frame_size, frame_size * 2)) (frame_size, frame_size * 2))
...@@ -73,12 +78,33 @@ class TestGRUUnitOp(OpTest): ...@@ -73,12 +78,33 @@ class TestGRUUnitOp(OpTest):
g[:, frame_size * 2:]) g[:, frame_size * 2:])
g = np.hstack((u_r, c)) g = np.hstack((u_r, c))
h = u * h_p + (1 - u) * c h = u * h_p + (1 - u) * c
self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h} self.outputs = {'Gate': g, 'ResetHiddenPrev': r_h_p, 'Hidden': h}
def setUp(self):
self.set_inputs()
self.set_outputs()
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self):
self.check_grad(
['Input', 'HiddenPrev', 'Weight'], ['Hidden'],
max_relative_error=0.007)
class TestGRUUnitOpWithBias(TestGRUUnitOp):
def set_inputs(self):
batch_size = self.batch_size
frame_size = self.frame_size
super(TestGRUUnitOpWithBias, self).set_inputs()
self.inputs['Bias'] = np.random.uniform(
-0.1, 0.1, (1, frame_size * 3)).astype('float32')
self.attrs = {
'activation': GRUActivationType.identity,
'gate_activation': GRUActivationType.sigmoid
}
def test_check_grad(self): def test_check_grad(self):
self.check_grad( self.check_grad(
['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'], ['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册