Unverified commit 89d6d69c, authored by tensor-tang, committed by GitHub

Merge pull request #12781 from tensor-tang/feature/op/fusion_gru

add fusion gru 
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_gru_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
"Output(BatchResetHiddenPrev) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
"Output(BatchedHidden) of GRU should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of GRU should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2.");
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The first dimension of Input(WeightX) "
"should be %d.",
x_dims[1]);
int frame_size = wx_dims[1] / 3;
auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2.");
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
"The second dimension of Input(WeightH) "
"should be 3 * %d.",
frame_size);
if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
}
if (ctx->HasInput("Bias")) {
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
}
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
ctx->ShareLoD("X", "Hidden");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
}
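// Illustrative sizes only (taken from the unit-test defaults further below,
// lod = [[2, 4, 3]] so T = 9, M = 3, D = 5): X is 9 x 3, WeightX is 3 x 15,
// WeightH is 5 x 15, Bias is 1 x 15 and H0 is 3 x 5; InferShape then sets
// Hidden/BatchedHidden/BatchResetHiddenPrev to 9 x 5, BatchedGate to 9 x 15,
// and XX to 9 x min(M, 3D) = 9 x 3.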
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionGRUOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput("H0",
"(Tensor, optional) The initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.")
.AsDispensable();
AddInput("WeightX",
"(Tensor) The FC weight with shape (M x 3D),"
"where M is the dim size of x, D is the hidden size. ");
AddInput("WeightH",
"(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size. ");
AddInput("Bias",
"(Tensor, optional) (1 x 3D)."
"Almost same as GRUOp."
"Note: if have FC bias it should be added on this bias.")
.AsDispensable();
AddOutput("XX",
"(LoDTensor) the result after X * WeightX (size is T x 4D)"
" or batched_X (size is T x M), this will be automatically chosen,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input.")
.AsIntermediate();
AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp").AsIntermediate();
AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.")
.AsIntermediate();
AddOutput("BatchedHidden", "(LoDTensor) (T X D) Same as GRUOp.")
.AsIntermediate();
AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
AddAttr<std::string>("activation",
"(string, default tanh) "
"The activation type used for output candidate {h}_t.")
.SetDefault("tanh");
AddAttr<std::string>(
"gate_activation",
"(string, default sigmoid) "
"The activation type used in update gate and reset gate.")
.SetDefault("sigmoid");
AddAttr<bool>("is_reverse",
"(bool, defalut: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddComment(R"DOC(
The Fused GRU Operator.
This operator fuses the fully-connected (FC) operator into the GRU operator;
for more details, please refer to the GRU op.
)DOC");
}
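// For reference, the recurrence this fused op computes (it mirrors the numpy
// reference implementation in the unit tests of this PR): with the FC output
// xx = X * WeightX + Bias split per time step into [g_u, g_r, g_c],
//   u_t = gate_activation(g_u + h_{t-1} * W_u)
//   r_t = gate_activation(g_r + h_{t-1} * W_r)
//   c_t = activation(g_c + (r_t . h_{t-1}) * W_c)
//   h_t = u_t . c_t + (1 - u_t) . h_{t-1}
// where WeightH stores [W_u, W_r] (D x 2D) followed by W_c (D x D) and "."
// denotes element-wise multiplication.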
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src,
framework::Vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
row_shuffle(ctx, src, index_lod, dst, indexed_src);
}
template <typename DeviceContext, typename T>
class FusionGRUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* h0 = ctx.Input<Tensor>("H0");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
auto* batch_reset_hidden_prev =
ctx.Output<LoDTensor>("BatchResetHiddenPrev");
auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
bool is_reverse = ctx.Attr<bool>("is_reverse");
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
batch_hidden->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
auto x_dims = x->dims();
auto wx_dims = wx->dims();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
if (x_dims[1] > wx_dims[1]) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
x_data, wx_data, xx_data,
bias ? bias->data<T>() : NULL);
to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_gate->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
xx_data, wx_data, batched_gate_data,
bias ? bias->data<T>() : NULL);
}
int frame_size = static_cast<int>(wx_dims[1] / 3);
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(wh_data);
gru_value.state_weight =
const_cast<T*>(wh_data + 2 * frame_size * frame_size);
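// WeightH (D x 3D) is laid out as the update/reset gate weights (D x 2D)
// followed by the candidate-state weights (D x D); gate_weight/state_weight
// above point at the two blocks, matching the w.flatten()[:D*D*2] /
// w.flatten()[D*D*2:] split used in the reference test implementation.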
Tensor ordered_h0;
framework::Vector<size_t> order(batched_gate->lod()[2]);
if (h0) {
ReorderInitState<DeviceContext, T>(
ctx.template device_context<DeviceContext>(), *h0, order, &ordered_h0,
true);
gru_value.prev_out_value = ordered_h0.data<T>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batched_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1;
auto active_node =
math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
ctx.Attr<std::string>("gate_activation"));
#ifdef PADDLE_WITH_MKLML
// use MKL packed to speedup GEMM
if (FLAGS_paddle_num_threads >= 4) {
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size * 2 /*width of weight*/,
frame_size /*height of weight*/);
PADDLE_ENFORCE(packed_gate);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
packed_gate);
T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
frame_size /*width of weight*/,
frame_size /*height of weight*/);
PADDLE_ENFORCE(packed_state);
blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
frame_size, T(1.0), gru_value.state_weight, frame_size,
packed_state);
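// Both weight blocks are packed once with GEMM_PACK and reused as the B
// operand of GEMM_COMPUTE for every time step of the loop below (and freed
// with GEMM_FREE afterwards); avoiding the per-step repacking is where the
// speed-up over a plain per-step GEMM comes from.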
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
frame_size, gru_value.prev_out_value, frame_size, packed_gate,
frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
}
math::detail::forward_reset_output(
math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
cur_batch_size, active_gate);
if (gru_value.prev_out_value) {
blas.GEMM_COMPUTE(
CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
gru_value.reset_output_value, frame_size, packed_state,
frame_size, T(1), gru_value.gate_value + frame_size * 2,
frame_size * 3);
}
math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node);
gru_value.prev_out_value = gru_value.output_value;
}
blas.GEMM_FREE(packed_gate);
blas.GEMM_FREE(packed_state);
} else {
#endif
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<T>();
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
gru_value.prev_out_value = gru_value.output_value;
}
#ifdef PADDLE_WITH_MKLML
}
#endif
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden->set_lod(batched_gate->lod());
to_seq(dev_ctx, *batch_hidden, hidden_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(
fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionGRUOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionGRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
from op_test import OpTest
from test_gru_op import gru
from test_fusion_lstm_op import fc, ACTIVATION
def fusion_gru(
x, # T x M
lod, # 1 x N
h0, # N x D
wx, # M x 3D
wh, # D x 3D
bias, # 1 x 3D
is_reverse,
act_state,
act_gate):
return gru(fc(x, wx, bias),
lod,
h0,
wh,
np.zeros(
(1, wh.shape[1]), dtype='float64'),
is_reverse,
act_state,
act_gate)
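# fusion_gru reuses the plain GRU reference implementation from test_gru_op:
# fc(x, wx, bias) already produces the T x 3D gate input including the bias,
# so the inner gru() call is given an all-zero bias to avoid adding it twice.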
class TestFusionGRUOp(OpTest):
def set_confs(self):
pass
def setUp(self):
self.op_type = "fusion_gru"
self.lod = [[2, 4, 3]]
self.M = 3
self.D = 5
self.is_reverse = False
self.with_h0 = True
self.with_bias = True
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.set_confs()
T = sum(self.lod[0])
N = len(self.lod[0])
x = np.random.rand(T, self.M).astype('float64')
wx = np.random.rand(self.M, 3 * self.D).astype('float64')
wh = np.random.rand(self.D, 3 * self.D).astype('float64')
bias = np.random.rand(
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float64')
h0 = np.random.rand(
N, self.D).astype('float64') if self.with_h0 else np.zeros(
(N, self.D), dtype='float64')
_, _, _, hidden = fusion_gru(
x, self.lod, h0, wx, wh, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
self.inputs = {'X': (x, self.lod), 'WeightX': wx, 'WeightH': wh}
if self.with_bias:
self.inputs['Bias'] = bias
if self.with_h0:
self.inputs['H0'] = h0
self.outputs = {'Hidden': (hidden, self.lod)}
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
}
def test_check_output(self):
self.check_output(atol=1e-8)
class TestFusionGRUOpNoInitial(TestFusionGRUOp):
def set_confs(self):
self.with_h0 = False
class TestFusionGRUOpNoBias(TestFusionGRUOp):
def set_confs(self):
self.with_bias = False
class TestFusionGRUOpReverse(TestFusionGRUOp):
def set_confs(self):
self.is_reverse = True
class TestFusionGRUOpMD1(TestFusionGRUOp):
def set_confs(self):
self.M = 36
self.D = 8
class TestFusionGRUOpMD2(TestFusionGRUOp):
def set_confs(self):
self.M = 8
self.D = 8
class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self):
self.lod = [[3]]
self.D = 16
if __name__ == "__main__":
unittest.main()
@@ -19,22 +19,19 @@ import numpy as np
import math
import functools
from op_test import OpTest
from test_lstm_op import identity, sigmoid, tanh, relu
class TestGRUOp(OpTest):
lod = [[2, 4, 3]]
batch_size = sum(lod[0])
frame_size = 5
activate = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
'relu': relu
}
@staticmethod
def seq_to_batch(lod, is_reverse):
from test_lstm_op import ACTIVATION
def gru(
input, # T x 3D
lod, # 1 x N
h0, # N x D
weight, # D x 3D
bias, # 1 x 3D
is_reverse,
act_state,
act_gate):
def _seq_to_batch(lod, is_reverse):
idx_in_seq_list = []
seq_lens = lod[0]
seq_starts = [0]
@@ -56,121 +53,125 @@ class TestGRUOp(OpTest):
idx_in_seq_list.append(idx_in_seq)
return idx_in_seq_list, sorted_seqs
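# Illustrative only, assuming the usual sort-by-decreasing-length batching in
# the elided code above: for lod = [[2, 4, 3]] the three sequences occupy input
# rows 0-1, 2-5 and 6-8, sorted_seqs would be [1, 2, 0], and idx_in_seq_list
# would be [[2, 6, 0], [3, 7, 1], [4, 8], [5]], i.e. the rows active at time
# steps 0..3.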
def gru_step(self, x, h_p, w, b):
batch_size = x.shape[0]
frame_size = w.shape[0]
g = x + np.tile(b, (batch_size, 1))
w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
(frame_size, frame_size * 2))
u_r = self.activate[self.attrs['gate_activation']](np.dot(
h_p, w_u_r) + g[:, :frame_size * 2])
u = u_r[:, :frame_size]
r = u_r[:, frame_size:frame_size * 2]
def _step(x, h_p, w, b, act_state, act_gate):
T = x.shape[0]
D = w.shape[0]
g = x + np.tile(b, (T, 1))
w_u_r = w.flatten()[:D * D * 2].reshape((D, D * 2))
u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :D * 2])
u = u_r[:, :D]
r = u_r[:, D:D * 2]
r_h_p = r * h_p
w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
(frame_size, frame_size))
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
w_c = w.flatten()[D * D * 2:].reshape((D, D))
c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
g = np.hstack((u_r, c))
h = u * c + (1 - u) * h_p
return g, r_h_p, h
def gru(self):
input, lod = self.inputs['Input']
w = self.inputs['Weight']
b = self.inputs['Bias'] if 'Bias' in self.inputs else np.zeros(
(1, self.frame_size * 3))
batch_gate = self.outputs['BatchGate']
batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev']
batch_hidden = self.outputs['BatchHidden']
hidden = self.outputs['Hidden']
idx_in_seq_list = self.idx_in_seq_list
h_p = self.inputs['H0'][
self.sorted_seqs] if 'H0' in self.inputs else np.zeros(
(len(idx_in_seq_list[0]), self.frame_size))
num_batch = len(idx_in_seq_list)
end_idx = 0
for batch_idx in range(num_batch):
x = input[idx_in_seq_list[batch_idx]]
g, r_h_p, h = self.gru_step(x, h_p, w, b)
if batch_idx < (num_batch - 1):
h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
start_idx = end_idx
end_idx = start_idx + len(idx_in_seq_list[batch_idx])
batch_gate[start_idx:end_idx] = g
batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
batch_hidden[start_idx:end_idx] = h
hidden[idx_in_seq_list[batch_idx]] = h
return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self):
lod = self.lod
self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
lod, self.is_reverse)
batch_size = self.batch_size
frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
h0 = np.random.rand(len(self.idx_in_seq_list[0]),
frame_size).astype('float64')
weight = np.random.rand(frame_size, frame_size * 3).astype('float64')
bias = np.random.rand(1, frame_size * 3).astype('float64')
self.inputs = {
'Input': (input, lod),
'H0': h0,
'Weight': weight,
'Bias': bias
}
T = sum(lod[0])
N = len(lod[0])
D = weight.shape[0]
batch_gate = np.zeros((T, 3 * D), dtype='float64')
batch_reset_hidden_prev = np.zeros((T, D), dtype='float64')
batch_hidden = np.zeros((T, D), dtype='float64')
hidden = np.zeros((T, D), dtype='float64')
idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
h_p = h0[sorted_seqs]
max_seq_len = len(idx_in_seq_list)
assert len(idx_in_seq_list[0]) == N
end_idx = 0
for batch_idx in range(max_seq_len):
x = input[idx_in_seq_list[batch_idx]]
g, r_h_p, h = _step(x, h_p, weight, bias, act_state, act_gate)
if batch_idx < (max_seq_len - 1):
h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
start_idx = end_idx
end_idx = start_idx + len(idx_in_seq_list[batch_idx])
batch_gate[start_idx:end_idx] = g
batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
batch_hidden[start_idx:end_idx] = h
hidden[idx_in_seq_list[batch_idx]] = h
return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden
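# gru() returns the batch-ordered BatchGate, BatchResetHiddenPrev and
# BatchHidden plus the sequence-ordered Hidden; the refactored TestGRUOp below
# checks all four, while the fusion_gru test above only consumes Hidden.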
self.outputs = {
'BatchGate': np.zeros(
(batch_size, frame_size * 3), dtype='float64'),
'BatchResetHiddenPrev': np.zeros(
(batch_size, frame_size), dtype='float64'),
'BatchHidden': np.zeros(
(batch_size, frame_size), dtype='float64'),
'Hidden': np.zeros(
(batch_size, frame_size), dtype='float64')
}
class TestGRUOp(OpTest):
def set_confs(self):
self.is_reverse = False
self.attrs = {
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
pass
def setUp(self):
self.op_type = "gru"
self.lod = [[2, 4, 3]]
self.D = 5
self.is_reverse = False
self.with_h0 = True
self.with_bias = True
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.set_confs()
self.set_data()
self.gru()
T = sum(self.lod[0])
N = len(self.lod[0])
input = np.random.rand(T, 3 * self.D).astype('float64')
weight = np.random.rand(self.D, 3 * self.D).astype('float64')
bias = np.random.rand(
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
(1, 3 * self.D), dtype='float64')
h0 = np.random.rand(
N, self.D).astype('float64') if self.with_h0 else np.zeros(
(N, self.D), dtype='float64')
batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
input, self.lod, h0, weight, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
self.inputs = {'Input': (input, self.lod), 'Weight': weight}
if self.with_bias:
self.inputs['Bias'] = bias
if self.with_h0:
self.inputs['H0'] = h0
self.outputs = {
'Hidden': (hidden, self.lod),
'BatchGate': batch_gate,
'BatchResetHiddenPrev': batch_reset_hidden_prev,
'BatchHidden': batch_hidden,
}
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
}
def test_check_output(self):
self.check_output()
self.check_output(atol=1e-8)
def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpNoInitial(TestGRUOp):
def set_data(self):
super(TestGRUOpNoInitial, self).set_data()
self.inputs.pop('H0')
def set_confs(self):
self.with_h0 = False
def test_check_grad(self):
self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpNoBias(TestGRUOp):
def set_confs(self):
self.with_bias = False
def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight'], ['Hidden'])
class TestGRUOpReverse(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.attrs = {
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
if __name__ == "__main__":
unittest.main()