未验证 提交 fa9d3fa5 编写于 作者: G Guo Sheng 提交者: GitHub

Incorporate cudnn_lstm into LSTM api (#27217)

* Incorporate cudnn_lstm into LSTM api.
test=develop

* Make coalesce_tensor support alignment optionally.
test=develop

* Reorganize RNN apis. test=develop

* Fix cudnn rnn layout conversion.
test=develop

* Add sequence_length support for RNN cudnn implement.
Add optional init_h and init_c gradient for cudnn_lstm_op.
test=develop

* Use create_parameter for rnn cudnn impl.
test=develop

* Move `self._flat_weight = self.create_parameter()` in RNNBase to main_program.
test=develop

* Update RNN api unittest to use set_device.
test=develop

* Fix set_place for unit tests of RNN apis.
test=develop

* Fix use_align in coalesce_tensor_op.
test=develop

* Adjust RNN apis arguments according to comments.
test=develop

* Polish documents for SimpleRNN apis.
test=develop

* Refine random seed in cudnn_lstm_op.
Expose rnn params from sublayers to RNN.
test=develop

* Fix RNN saving for jit.save.
Refine cudnn_lstm dropout behavior.
test=develop

* Fix doc of GRU. test=develop

* Use ShareDataWith to avoid copying for cudnn_lstm_op test.
test=develop

* Remove updates on cudnn_lstm temporarily.
test=develop

* Use ShareDataWith to avoid copying for cudnn_lstm_op test.
test=develop

* Refine random seed in cudnn_lstm_op.
test=develop

* Fix test_lstm by adjust ConcreteProgram buffer getter.
test=develop

* Use create_parameter instead of create_var for rnn._flat_weight for static graph usage.
test=develop

* Remove W input for cudnn_lstm to pass unused_var_check.
test=develop

* Add test_predict for RNN unit tests coverage.
test=develop

* Fix code style of rnn.
test=develop

* Fix F.rnn usage in rnn.py.
test=develop
上级 78b1026f
...@@ -67,6 +67,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -67,6 +67,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
} }
auto in_tensors = context.MultiInput<framework::LoDTensor>("Input"); auto in_tensors = context.MultiInput<framework::LoDTensor>("Input");
bool use_align = context.Attr<bool>("use_align");
if (context.Attr<bool>("check_name")) { if (context.Attr<bool>("check_name")) {
for (size_t i = 0; i < in_var_names.size(); ++i) { for (size_t i = 0; i < in_var_names.size(); ++i) {
...@@ -93,7 +94,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -93,7 +94,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
context.Attr<int>("dtype")); context.Attr<int>("dtype"));
size_t size_of_dtype = framework::SizeOfType(dtype); size_t size_of_dtype = framework::SizeOfType(dtype);
GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype, GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype,
context.GetPlace()); context.GetPlace(), use_align);
// Alloc the continuous space // Alloc the continuous space
auto fused_tensor = context.Output<framework::LoDTensor>("FusedOutput"); auto fused_tensor = context.Output<framework::LoDTensor>("FusedOutput");
...@@ -111,8 +112,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -111,8 +112,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx, framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx,
&sub_tensor); &sub_tensor);
offset += platform::Alignment(len * size_of_dtype, context.GetPlace()) / offset +=
size_of_dtype; use_align
? platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype
: len;
} }
} else if (context.Attr<bool>("set_constant")) { } else if (context.Attr<bool>("set_constant")) {
math::SetConstant<DeviceContext, T> set_constant; math::SetConstant<DeviceContext, T> set_constant;
...@@ -131,8 +135,10 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -131,8 +135,10 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
->ShareDataWith(fused_tensor->Slice( ->ShareDataWith(fused_tensor->Slice(
static_cast<int64_t>(offset), static_cast<int64_t>(offset + len))) static_cast<int64_t>(offset), static_cast<int64_t>(offset + len)))
.Resize(dim); .Resize(dim);
len = platform::Alignment(len * size_of_dtype, context.GetPlace()) / len = use_align
size_of_dtype; ? platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype
: len;
offset += len; offset += len;
ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")" ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")"
<< " address: " << out_tensors[i]->data<void>() << ", "; << " address: " << out_tensors[i]->data<void>() << ", ";
...@@ -144,7 +150,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -144,7 +150,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
void GetMemSizeAndDtype( void GetMemSizeAndDtype(
const std::vector<const framework::LoDTensor *> &lod_tensors, const std::vector<const framework::LoDTensor *> &lod_tensors,
const std::vector<std::string> var_names, size_t *numel, const std::vector<std::string> var_names, size_t *numel,
const size_t &size_of_dtype, const platform::Place &place) const { const size_t &size_of_dtype, const platform::Place &place,
const bool use_align = true) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
lod_tensors.size(), var_names.size(), lod_tensors.size(), var_names.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -167,9 +174,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -167,9 +174,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims() ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
<< ") " << ") "
<< " addres:" << lod_tensors[i]->data<void>() << ", "; << " addres:" << lod_tensors[i]->data<void>() << ", ";
*numel += platform::Alignment(static_cast<size_t>(size) * size_of_dtype, *numel += use_align
place) / ? platform::Alignment(
size_of_dtype; static_cast<size_t>(size) * size_of_dtype, place) /
size_of_dtype
: static_cast<size_t>(size);
} }
VLOG(10) << ss.str(); VLOG(10) << ss.str();
...@@ -223,6 +232,10 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -223,6 +232,10 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
"Whether to check the name of Input and Output to ensure " "Whether to check the name of Input and Output to ensure "
"they are the same separately.") "they are the same separately.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_align",
"Whether to consider memory chunk and take alignment into "
"account for inputs and outputs.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CoalesceTensor Operator. CoalesceTensor Operator.
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cudnn_lstm_cache.h" #include "paddle/fluid/operators/cudnn_lstm_cache.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -156,6 +157,21 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> { ...@@ -156,6 +157,21 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
bool is_test = ctx.Attr<bool>("is_test"); bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed"); int seed = ctx.Attr<int>("seed");
if (!is_test) {
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed == 0) {
// If perform `manual_seed` in python and inner seed is not specified
// (equals 0), use global generator generated seed.
seed = static_cast<int>(gen_cuda->Random64());
} else if (seed == 0) {
// use random generated seed
std::random_device rd;
seed = rd();
} // else use `ctx.Attr<int>("seed")` specified seed
}
bool has_seq_length = ctx.HasInput("SequenceLength"); bool has_seq_length = ctx.HasInput("SequenceLength");
std::vector<int> SequenceLength; std::vector<int> SequenceLength;
if (has_seq_length) { if (has_seq_length) {
...@@ -194,13 +210,25 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> { ...@@ -194,13 +210,25 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
if (!continuous) { if (!continuous) {
LOG_FIRST_N(WARNING, 2) LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not " << "If the memory space of the Input WeightList is not continuous, "
"continuous, less efficient calculation will be " "less efficient calculation will be called. Please call "
"called. Please call coalesce_tensor op to make the " "flatten_parameters() to make the input memory continuous.";
"input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place); weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole); weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
w_data = weight_whole.data<T>(); w_data = weight_whole.data<T>();
if (is_test) { // maybe also reset small weights' ptr for training
int offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
size_t len = weight_list[i]->numel();
auto dim = weight_list[i]->dims();
const_cast<Tensor *>(weight_list[i])
->ShareDataWith(
weight_whole.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
}
} else { } else {
w_data = const_cast<T *>(weight_list[0]->data<T>()); w_data = const_cast<T *>(weight_list[0]->data<T>());
} }
...@@ -226,12 +254,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> { ...@@ -226,12 +254,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data, LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
init_h_data, init_c_data, w_data, out_data, last_h_data, init_h_data, init_c_data, w_data, out_data, last_h_data,
last_c_data, &workspace_data_, workspace_size); last_c_data, &workspace_data_, workspace_size);
if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) {
auto *W = const_cast<Tensor *>(ctx.Input<Tensor>("W"));
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
W->mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, W);
}
} else { } else {
if (!has_seq_length) { if (!has_seq_length) {
// for train // for train
......
...@@ -89,6 +89,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -89,6 +89,7 @@ REGISTER_OP_CPU_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>, save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int16_t>, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, float>, save, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, double>, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int>, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, ops::SaveOpKernel<paddle::platform::CUDADeviceContext,
......
...@@ -174,7 +174,7 @@ class CacheKey(object): ...@@ -174,7 +174,7 @@ class CacheKey(object):
# 1. filter `self` in args # 1. filter `self` in args
if args and isinstance(args[0], layers.Layer): if args and isinstance(args[0], layers.Layer):
args = args[1:] args = args[1:]
# 2. convert tensor and numpy array into InputSpec # 2. convert tensor and numpy array into InputSpec
_args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs) _args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs)
input_with_spec = function_spec.args_to_input_spec(_args, _kwargs) input_with_spec = function_spec.args_to_input_spec(_args, _kwargs)
...@@ -592,9 +592,8 @@ class ConcreteProgram(object): ...@@ -592,9 +592,8 @@ class ConcreteProgram(object):
inputs = tuple([class_instance] + list(inputs)) inputs = tuple([class_instance] + list(inputs))
# 2. Gets all ParamBases and buffered VarBases in the function # 2. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = list( all_parameters_and_buffers = _extract_indeed_params_buffers(
get_parameters(class_instance).values()) + list( class_instance)
get_buffers(class_instance).values())
# 3. Builds program only once and returns the output Variables. # 3. Builds program only once and returns the output Variables.
with param_guard(get_parameters( with param_guard(get_parameters(
...@@ -622,6 +621,17 @@ class ConcreteProgram(object): ...@@ -622,6 +621,17 @@ class ConcreteProgram(object):
startup_program=startup_program) startup_program=startup_program)
def _extract_indeed_params_buffers(class_instance):
    """
    Gather the parameters and the initialized buffers of `class_instance`.

    Buffers whose `shape` equals `[]` are treated as not initialized and are
    dropped from the result.

    Args:
        class_instance: object passed to `get_parameters` and `get_buffers`.

    Returns:
        list: every parameter, followed by the buffers that were kept.
    """
    collected = list(get_parameters(class_instance).values())
    for candidate in list(get_buffers(class_instance).values()):
        # An empty shape marks a buffer that has not been initialized yet.
        if candidate.shape != []:
            collected.append(candidate)
    return collected
class ProgramCache(object): class ProgramCache(object):
""" """
Wrapper class for the program functions defined by dygraph function. Wrapper class for the program functions defined by dygraph function.
......
...@@ -29,11 +29,13 @@ class TestSimpleRNN(unittest.TestCase): ...@@ -29,11 +29,13 @@ class TestSimpleRNN(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
paddle.disable_static(self.place) # Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = SimpleRNN( rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction) 16, 32, 2, time_major=self.time_major, direction=self.direction)
rnn2 = paddle.nn.SimpleRNN( rnn2 = paddle.nn.SimpleRNN(
...@@ -103,11 +105,13 @@ class TestGRU(unittest.TestCase): ...@@ -103,11 +105,13 @@ class TestGRU(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
paddle.disable_static(self.place) # Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = GRU(16, rnn1 = GRU(16,
32, 32,
2, 2,
...@@ -183,11 +187,13 @@ class TestLSTM(unittest.TestCase): ...@@ -183,11 +187,13 @@ class TestLSTM(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
paddle.disable_static(self.place) # Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = LSTM( rnn1 = LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction) 16, 32, 2, time_major=self.time_major, direction=self.direction)
rnn2 = paddle.nn.LSTM( rnn2 = paddle.nn.LSTM(
...@@ -251,10 +257,68 @@ class TestLSTM(unittest.TestCase): ...@@ -251,10 +257,68 @@ class TestLSTM(unittest.TestCase):
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
def test_predict(self):
    """
    Check that an LSTM exported with `paddle.jit.save` and reloaded through
    `paddle.static.load_inference_model` reproduces the outputs that the same
    network produced in eval mode.

    The network is first updated by one optimizer step, then evaluated, then
    converted with `paddle.jit.to_static`, saved under `./inference`, and
    finally executed in static mode.
    """
    place = paddle.set_device(self.place)
    paddle.manual_seed(123)
    np.random.seed(123)

    class Net(paddle.nn.Layer):
        def __init__(self):
            super(Net, self).__init__()
            self.rnn1 = paddle.nn.LSTM(
                16, 32, 2, direction="bidirectional", dropout=0.1)

        def forward(self, input):
            return self.rnn1(input)

    x = paddle.randn((4, 10, 16))
    x.stop_gradient = False
    seq_len = paddle.to_tensor(np.array([10, 6, 8, 5]))
    mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype)
    mask = paddle.unsqueeze(mask, [2])
    rnn = Net()
    y, (h, c) = rnn(x)
    y = y * mask
    loss = paddle.mean(y)
    loss.backward()
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.1, parameters=rnn.parameters())
    optimizer.step()
    rnn.eval()
    y, (h, c) = rnn(x)
    # `jit.to_static` would include a train_program, eval mode might cause
    # some errors currently, such as dropout grad op gets `is_test == True`.
    rnn.train()

    rnn = paddle.jit.to_static(
        rnn,
        [paddle.static.InputSpec(
            shape=[None, None, 16], dtype=x.dtype)])
    paddle.jit.save(rnn, "./inference/lstm_infer")

    paddle.enable_static()
    try:
        new_scope = paddle.static.Scope()
        with paddle.static.scope_guard(new_scope):
            exe = paddle.static.Executor(place)
            [inference_program, feed_target_names,
             fetch_targets] = paddle.static.load_inference_model(
                 dirname="./inference",
                 executor=exe,
                 model_filename="lstm_infer.pdmodel",
                 params_filename="lstm_infer.pdiparams")
            results = exe.run(inference_program,
                              feed={feed_target_names[0]: x.numpy()},
                              fetch_list=fetch_targets)
            np.testing.assert_equal(
                y.numpy(), results[0])  # eval results equal predict results
    finally:
        # Static mode is a global switch: always restore dygraph mode, even
        # when loading, running or the comparison fails, so that a failure
        # here cannot leak static mode into the test cases that run next.
        paddle.disable_static()
def runTest(self): def runTest(self):
self.test_with_initial_state() self.test_with_initial_state()
self.test_with_zero_state() self.test_with_zero_state()
self.test_with_input_lengths() self.test_with_input_lengths()
self.test_predict()
def load_tests(loader, tests, pattern): def load_tests(loader, tests, pattern):
......
...@@ -30,10 +30,12 @@ class TestSimpleRNN(unittest.TestCase): ...@@ -30,10 +30,12 @@ class TestSimpleRNN(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = SimpleRNN( rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction) 16, 32, 2, time_major=self.time_major, direction=self.direction)
...@@ -48,7 +50,6 @@ class TestSimpleRNN(unittest.TestCase): ...@@ -48,7 +50,6 @@ class TestSimpleRNN(unittest.TestCase):
time_major=self.time_major, time_major=self.time_major,
direction=self.direction) direction=self.direction)
place = self.place
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope() scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope): with paddle.static.scope_guard(scope):
...@@ -172,10 +173,12 @@ class TestGRU(unittest.TestCase): ...@@ -172,10 +173,12 @@ class TestGRU(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = GRU(16, rnn1 = GRU(16,
32, 32,
2, 2,
...@@ -192,7 +195,6 @@ class TestGRU(unittest.TestCase): ...@@ -192,7 +195,6 @@ class TestGRU(unittest.TestCase):
time_major=self.time_major, time_major=self.time_major,
direction=self.direction) direction=self.direction)
place = self.place
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope() scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope): with paddle.static.scope_guard(scope):
...@@ -316,10 +318,12 @@ class TestLSTM(unittest.TestCase): ...@@ -316,10 +318,12 @@ class TestLSTM(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = LSTM( rnn1 = LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction) 16, 32, 2, time_major=self.time_major, direction=self.direction)
...@@ -334,7 +338,6 @@ class TestLSTM(unittest.TestCase): ...@@ -334,7 +338,6 @@ class TestLSTM(unittest.TestCase):
time_major=self.time_major, time_major=self.time_major,
direction=self.direction) direction=self.direction)
place = self.place
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope() scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope): with paddle.static.scope_guard(scope):
......
...@@ -21,8 +21,11 @@ import sys ...@@ -21,8 +21,11 @@ import sys
import warnings import warnings
from functools import partial, reduce from functools import partial, reduce
import numpy as np
import paddle import paddle
import paddle.fluid as fluid
from paddle import framework from paddle import framework
from paddle.device import get_device, get_cudnn_version
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
from paddle.fluid.dygraph import Layer, LayerList from paddle.fluid.dygraph import Layer, LayerList
...@@ -135,7 +138,7 @@ def concat_states(states, bidirectional=False, state_components=1): ...@@ -135,7 +138,7 @@ def concat_states(states, bidirectional=False, state_components=1):
componnets = [] componnets = []
for i in range(state_components): for i in range(state_components):
componnets.append(states[i::state_components]) componnets.append(states[i::state_components])
return [paddle.stack(item) for item in componnets] return tuple([paddle.stack(item) for item in componnets])
class RNNCellBase(Layer): class RNNCellBase(Layer):
...@@ -270,9 +273,12 @@ class SimpleRNNCell(RNNCellBase): ...@@ -270,9 +273,12 @@ class SimpleRNNCell(RNNCellBase):
The formula used is as follows: The formula used is as follows:
.. math:: .. math::
h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) h_{t} & = act(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t} y_{t} & = h_{t}
where :math:`act` is for :attr:`activation` , and * is the elementwise
multiplication operator.
Please refer to `Finding Structure in Time Please refer to `Finding Structure in Time
<https://crl.ucsd.edu/~elman/Papers/fsit.pdf>`_ for more details. <https://crl.ucsd.edu/~elman/Papers/fsit.pdf>`_ for more details.
...@@ -807,13 +813,14 @@ class RNN(Layer): ...@@ -807,13 +813,14 @@ class RNN(Layer):
initial_states=None, initial_states=None,
sequence_length=None, sequence_length=None,
**kwargs): **kwargs):
final_outputs, final_states = paddle.fluid.layers.rnn(self.cell, final_outputs, final_states = paddle.fluid.layers.rnn(
inputs, self.cell,
initial_states=initial_states, inputs,
sequence_length=sequence_length, initial_states=initial_states,
time_major=self.time_major, sequence_length=sequence_length,
is_reverse=self.is_reverse, time_major=self.time_major,
**kwargs) is_reverse=self.is_reverse,
**kwargs)
return final_outputs, final_states return final_outputs, final_states
...@@ -909,18 +916,194 @@ class BiRNN(Layer): ...@@ -909,18 +916,194 @@ class BiRNN(Layer):
assert len(initial_states) == 2, \ assert len(initial_states) == 2, \
"length of initial_states should be 2 when it is a list/tuple" "length of initial_states should be 2 when it is a list/tuple"
outputs, final_states = paddle.fluid.layers.birnn(self.cell_fw, self.cell_bw, inputs, outputs, final_states = paddle.fluid.layers.birnn(
initial_states, sequence_length, self.cell_fw, self.cell_bw, inputs, initial_states, sequence_length,
self.time_major, **kwargs) self.time_major, **kwargs)
return outputs, final_states return outputs, final_states
class RNNMixin(LayerList): class RNNBase(LayerList):
r""" r"""
A Mixin class for RNN networks. It provides `forward` method for SimpleRNN, RNNBase class for RNN networks. It provides `forward`, `flatten_parameters`
LSTM and GRU. and other common methods for SimpleRNN, LSTM and GRU.
""" """
def __init__(self,
             mode,
             input_size,
             hidden_size,
             num_layers=1,
             direction="forward",
             time_major=False,
             dropout=0.,
             weight_ih_attr=None,
             weight_hh_attr=None,
             bias_ih_attr=None,
             bias_hh_attr=None):
    """
    Build the stacked RNN sublayers and decide whether the cudnn kernel
    can be used.

    Args:
        mode: "LSTM" selects `LSTMCell`, "GRU" selects `GRUCell`, any other
            value selects `SimpleRNNCell` (which additionally receives
            `self.activation`, expected to be set before this call).
        input_size: feature size fed to the first layer.
        hidden_size: hidden size of every cell.
        num_layers: number of stacked layers.
        direction: "forward", "backward" or "bidirectional".
        time_major: forwarded to every `RNN` / `BiRNN` sublayer.
        dropout: stored on the instance; used as `dropout_prob` by the
            cudnn implementation.
        weight_ih_attr, weight_hh_attr, bias_ih_attr, bias_hh_attr:
            forwarded unchanged to every cell.

    Raises:
        ValueError: if `direction` is not one of the three accepted values.
    """
    super(RNNBase, self).__init__()
    self.mode = mode
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.dropout = dropout
    self.num_directions = 2 if direction == "bidirectional" else 1
    self.time_major = time_major
    self.num_layers = num_layers
    self.state_components = 2 if mode == "LSTM" else 1

    kwargs = {
        "weight_ih_attr": weight_ih_attr,
        "weight_hh_attr": weight_hh_attr,
        "bias_ih_attr": bias_ih_attr,
        "bias_hh_attr": bias_hh_attr
    }

    if mode == "LSTM":
        rnn_cls = LSTMCell
    elif mode == "GRU":
        rnn_cls = GRUCell
    else:
        rnn_cls = SimpleRNNCell
        kwargs["activation"] = self.activation

    if direction in ["forward", "backward"]:
        is_reverse = direction == "backward"
        cell = rnn_cls(input_size, hidden_size, **kwargs)
        self.append(RNN(cell, is_reverse, time_major))
        for i in range(1, num_layers):
            cell = rnn_cls(hidden_size, hidden_size, **kwargs)
            self.append(RNN(cell, is_reverse, time_major))
    elif direction == "bidirectional":
        cell_fw = rnn_cls(input_size, hidden_size, **kwargs)
        cell_bw = rnn_cls(input_size, hidden_size, **kwargs)
        self.append(BiRNN(cell_fw, cell_bw, time_major))
        for i in range(1, num_layers):
            cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
            cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
            self.append(BiRNN(cell_fw, cell_bw, time_major))
    else:
        raise ValueError(
            "direction should be forward, backward or bidirectional, "
            "received direction = {}".format(direction))

    # `get_cudnn_version()` yields a version number (or None), not a flag.
    # Coerce to bool before combining: `int &= bool` is a bitwise AND that
    # turns any even version number into 0, and `None &= bool` raises
    # TypeError.
    self.could_use_cudnn = bool(get_device().startswith("gpu:") and
                                get_cudnn_version())
    self.could_use_cudnn &= direction != "backward"
    self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * (
        2 if direction == "bidirectional" else 1)
    self.could_use_cudnn &= mode == "LSTM"  # currently only support LSTM

    # Expose params as RNN's attribute, which can make it compatible when
    # replacing small ops composed rnn with cpp rnn kernel.
    # Moreover, `jit.to_static` assumes params are added by current layer
    # and wouldn't include sublayer's params in current layer, which also
    # requires these params are added to current layer for `jit.save`.
    param_names = []
    for layer in range(self.num_layers):
        for direction_idx in range(self.num_directions):
            suffix = '_reverse' if direction_idx == 1 else ''
            templates = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
            if bias_ih_attr != False: templates.append('bias_ih_l{}{}')
            if bias_hh_attr != False: templates.append('bias_hh_l{}{}')
            param_names.extend(x.format(layer, suffix) for x in templates)
    for name, param in zip(param_names, self.parameters()):
        setattr(self, name, param)

    self.flatten_parameters()
def flatten_parameters(self):
    """
    Resets parameter data pointer to address in continuous memory block for
    cudnn usage.

    Does nothing when `self.could_use_cudnn` is falsy. Otherwise it reorders
    the layer's own parameters into `self._all_weights`, creates the flat
    weight and the dropout state variable, and appends a `coalesce_tensor`
    op to the default startup program.
    """
    if not self.could_use_cudnn:
        return

    # layer.parameters() is depth first and ordered
    # for i in layer: for j in direct: w_ih, w_hh, b_ih, b_hh
    # need to reorganize to cudnn param layout:
    # all bias following all weights
    params = self.parameters(include_sublayers=False)
    sizes = [np.prod(item.shape) for item in params]
    bias_start = 2 * self.num_layers * self.num_directions
    self._all_weights = [None] * len(params)
    for idx, item in enumerate(params):
        # Within each group of four, the first two entries are weights and
        # the last two are biases; biases are placed after all weights.
        base = bias_start if idx % 4 >= 2 else 0
        group = idx // 4
        self._all_weights[base + group * 2 + idx % 2] = item

    # Wrap using a list to avoid being registered into params and saved,
    # maybe need a better way to handle this later. Use `create_parameter`
    # to add both to main_program and startup_program for static-graph.
    # Use Constant initializer to avoid affecting the random generator.
    flat = self.create_parameter(
        shape=[np.sum(sizes)],
        dtype=params[0].dtype,
        default_initializer=I.Constant(0.0))
    self._flat_weight = [flat]

    # dropout state might also be hidden to avoid saving
    # should dropout state be persistable for static-graph
    self._dropout_state = self.create_variable(
        dtype=fluid.core.VarDesc.VarType.UINT8)

    # for static-graph, append coalesce_tensor into startup program
    startup = fluid.default_startup_program()
    with fluid.program_guard(startup, fluid.default_startup_program()), \
            framework.no_grad():
        coalesce_attrs = {
            "copy_data": True,
            "use_align": False,
            "dtype": params[0].dtype
        }
        coalesce_outputs = {
            "Output": self._all_weights,
            "FusedOutput": self._flat_weight
        }
        self._helper.append_op(
            type="coalesce_tensor",
            inputs={"Input": self._all_weights},
            outputs=coalesce_outputs,
            attrs=coalesce_attrs)
def _cudnn_impl(self, inputs, initial_states, sequence_length):
    """
    Run the whole recurrence through a single ``cudnn_lstm`` op.

    Parameters:
        inputs: the input sequence. It is transposed with perm ``[1, 0, 2]``
            before being fed to the op when ``self.time_major`` is falsy.
        initial_states: indexable; items 0 and 1 are fed to the op as
            ``InitH`` and ``InitC`` respectively.
        sequence_length: fed to the op as ``SequenceLength``.

    Returns:
        tuple: ``(out, (last_h, last_c))``. ``out`` is transposed back with
        perm ``[1, 0, 2]`` when ``self.time_major`` is falsy.
    """
    needs_transpose = not self.time_major
    if needs_transpose:
        seq = paddle.tensor.transpose(inputs, [1, 0, 2])
    else:
        seq = inputs

    # Only LSTM is supported for now; LSTM/GRU/SimpleRNN are to be unified
    # later.
    # TODO(guosheng): use `core.ops.cudnn_lstm` in dygraph mode once it
    # supports specifying outputs, since `dropout_state` should be a
    # persistable tensor rather than a temporary one.
    helper = self._helper
    out = helper.create_variable_for_type_inference(seq.dtype)
    last_h = helper.create_variable_for_type_inference(seq.dtype)
    last_c = helper.create_variable_for_type_inference(seq.dtype)
    reserve = helper.create_variable_for_type_inference(
        dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True)

    op_inputs = {
        'Input': seq,
        # 'W': self._flat_weight,  # would be unused_var
        'WeightList': self._all_weights,
        'InitH': initial_states[0],
        'InitC': initial_states[1],
        'SequenceLength': sequence_length
    }
    op_attrs = {
        'dropout_prob': self.dropout,
        'is_bidirec': self.num_directions == 2,
        'input_size': self.input_size,
        'hidden_size': self.hidden_size,
        'num_layers': self.num_layers,
        'is_test': not self.training
    }
    op_outputs = {
        'Out': out,
        'LastH': last_h,
        'LastC': last_c,
        'Reserve': reserve,
        'StateOut': self._dropout_state,
    }
    helper.append_op(
        type="cudnn_lstm",
        inputs=op_inputs,
        outputs=op_outputs,
        attrs=op_attrs)

    if needs_transpose:
        out = paddle.tensor.transpose(out, [1, 0, 2])
    return out, (last_h, last_c)
def forward(self, inputs, initial_states=None, sequence_length=None): def forward(self, inputs, initial_states=None, sequence_length=None):
batch_index = 1 if self.time_major else 0 batch_index = 1 if self.time_major else 0
dtype = inputs.dtype dtype = inputs.dtype
...@@ -937,6 +1120,10 @@ class RNNMixin(LayerList): ...@@ -937,6 +1120,10 @@ class RNNMixin(LayerList):
for _ in range(self.state_components) for _ in range(self.state_components)
]) ])
if self.could_use_cudnn:
# Add CPU kernel and dispatch in backend later
return self._cudnn_impl(inputs, initial_states, sequence_length)
states = split_states(initial_states, self.num_directions == 2, states = split_states(initial_states, self.num_directions == 2,
self.state_components) self.state_components)
final_states = [] final_states = []
...@@ -957,7 +1144,7 @@ class RNNMixin(LayerList): ...@@ -957,7 +1144,7 @@ class RNNMixin(LayerList):
return outputs, final_states return outputs, final_states
class SimpleRNN(RNNMixin): class SimpleRNN(RNNBase):
r""" r"""
Multilayer Elman network(SimpleRNN). It takes input sequences and initial Multilayer Elman network(SimpleRNN). It takes input sequences and initial
states as inputs, and returns the output sequences and the final states. states as inputs, and returns the output sequences and the final states.
...@@ -970,22 +1157,28 @@ class SimpleRNN(RNNMixin): ...@@ -970,22 +1157,28 @@ class SimpleRNN(RNNMixin):
.. math:: .. math::
h_{t} & = \mathrm{tanh}(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) h_{t} & = act(W_{ih}x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t} y_{t} & = h_{t}
where :math:`act` is the function selected by :attr:`activation`, and * is the elementwise
multiplication operator.
Using keyword arguments to construct the layer is recommended.
Parameters: Parameters:
input_size (int): The input size for the first layer's cell. input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell. hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1. num_layers (int, optional): Number of layers. Defaults to 1.
activation (str, optional): The activation in each SimpleRNN cell. It can be
`tanh` or `relu`. Defaults to `tanh`.
direction (str, optional): The direction of the network. It can be "forward", direction (str, optional): The direction of the network. It can be "forward",
"backward" and "bidirectional". Defaults to "forward". "backward" and "bidirectional". When "bidirectional", the way to merge
dropout (float, optional): The droput probability. Dropout is applied to the outputs of forward and backward is concatenating. Defaults to "forward".
input of each layer except for the first layer. Defaults to 0.
time_major (bool, optional): Whether the first dimension of the input means the time_major (bool, optional): Whether the first dimension of the input means the
time steps. Defaults to False. time steps. Defaults to False.
dropout (float, optional): The dropout probability. Dropout is applied to the
input of each layer except for the first layer. Defaults to 0.
activation (str, optional): The activation in each SimpleRNN cell. It can be
`tanh` or `relu`. Defaults to `tanh`.
weight_ih_attr (ParamAttr, optional): The parameter attribute for weight_ih_attr (ParamAttr, optional): The parameter attribute for
`weight_ih` of each cell. Defaults to None. `weight_ih` of each cell. Defaults to None.
weight_hh_attr (ParamAttr, optional): The parameter attribute for weight_hh_attr (ParamAttr, optional): The parameter attribute for
...@@ -1002,7 +1195,7 @@ class SimpleRNN(RNNMixin): ...@@ -1002,7 +1195,7 @@ class SimpleRNN(RNNMixin):
If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
else, the shape is `[batch_size, time_steps, hidden_size]`. else, the shape is `[batch_size, time_steps, hidden_size]`.
initial_states (Tensor, optional): the initial state. The shape is initial_states (Tensor, optional): the initial state. The shape is
`[num_lauers * num_directions, batch_size, hidden_size]`. `[num_layers * num_directions, batch_size, hidden_size]`.
If initial_state is not given, zero initial states are used. If initial_state is not given, zero initial states are used.
sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
or int32. The valid lengths of input sequences. Defaults to None. or int32. The valid lengths of input sequences. Defaults to None.
...@@ -1020,10 +1213,21 @@ class SimpleRNN(RNNMixin): ...@@ -1020,10 +1213,21 @@ class SimpleRNN(RNNMixin):
Note that `num_directions` is 2 if direction is "bidirectional" Note that `num_directions` is 2 if direction is "bidirectional"
else 1. else 1.
final_states (Tensor): final states. The shape is final_states (Tensor): final states. The shape is
`[num_lauers * num_directions, batch_size, hidden_size]`. `[num_layers * num_directions, batch_size, hidden_size]`.
Note that `num_directions` is 2 if direction is "bidirectional" Note that `num_directions` is 2 if direction is "bidirectional"
else 1. else 1.
Attributes:
weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise,
the shape is `[hidden_size, num_directions * hidden_size]`.
weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
with shape `[hidden_size, hidden_size]`.
bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
with shape `[hidden_size]`.
bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
with shape `[hidden_size]`.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -1048,59 +1252,28 @@ class SimpleRNN(RNNMixin): ...@@ -1048,59 +1252,28 @@ class SimpleRNN(RNNMixin):
input_size, input_size,
hidden_size, hidden_size,
num_layers=1, num_layers=1,
activation="tanh",
direction="forward", direction="forward",
dropout=0.,
time_major=False, time_major=False,
dropout=0.,
activation="tanh",
weight_ih_attr=None, weight_ih_attr=None,
weight_hh_attr=None, weight_hh_attr=None,
bias_ih_attr=None, bias_ih_attr=None,
bias_hh_attr=None, bias_hh_attr=None,
name=None): name=None):
super(SimpleRNN, self).__init__() if activation == "tanh":
mode = "RNN_TANH"
if direction in ["forward", "backward"]: elif activation == "relu":
is_reverse = direction == "backward" mode = "RNN_RELU"
cell = SimpleRNNCell(input_size, hidden_size, activation,
weight_ih_attr, weight_hh_attr, bias_ih_attr,
bias_hh_attr)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
cell = SimpleRNNCell(hidden_size, hidden_size, activation,
weight_ih_attr, weight_hh_attr,
bias_ih_attr, bias_hh_attr)
self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional":
cell_fw = SimpleRNNCell(input_size, hidden_size, activation,
weight_ih_attr, weight_hh_attr,
bias_ih_attr, bias_hh_attr)
cell_bw = SimpleRNNCell(input_size, hidden_size, activation,
weight_ih_attr, weight_hh_attr,
bias_ih_attr, bias_hh_attr)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = SimpleRNNCell(
2 * hidden_size, hidden_size, activation, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
cell_bw = SimpleRNNCell(
2 * hidden_size, hidden_size, activation, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else: else:
raise ValueError( raise ValueError("Unknown activation '{}'".format(activation))
"direction should be forward, backward or bidirectional, " self.activation = activation
"received direction = {}".format(direction)) super(SimpleRNN, self).__init__(
mode, input_size, hidden_size, num_layers, direction, time_major,
self.input_size = input_size dropout, weight_ih_attr, weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.hidden_size = hidden_size
self.dropout = dropout
self.num_directions = 2 if direction == "bidirectional" else 1
self.time_major = time_major
self.num_layers = num_layers
self.state_components = 1
class LSTM(RNNMixin): class LSTM(RNNBase):
r""" r"""
Multilayer LSTM. It takes a sequence and an initial state as inputs, and Multilayer LSTM. It takes a sequence and an initial state as inputs, and
returns the output sequences and the final states. returns the output sequences and the final states.
...@@ -1130,16 +1303,19 @@ class LSTM(RNNMixin): ...@@ -1130,16 +1303,19 @@ class LSTM(RNNMixin):
where :math:`\sigma` is the sigmoid function, and * is the elementwise where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator. multiplication operator.
Using keyword arguments to construct the layer is recommended.
Parameters: Parameters:
input_size (int): The input size for the first layer's cell. input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell. hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1. num_layers (int, optional): Number of layers. Defaults to 1.
direction (str, optional): The direction of the network. It can be direction (str, optional): The direction of the network. It can be "forward",
"forward", "backward" and "bidirectional". Defaults to "forward". "backward" and "bidirectional". When "bidirectional", the way to merge
dropout (float, optional): The droput probability. Dropout is applied outputs of forward and backward is concatenating. Defaults to "forward".
to the input of each layer except for the first layer. Defaults to 0.
time_major (bool, optional): Whether the first dimension of the input time_major (bool, optional): Whether the first dimension of the input
means the time steps. Defaults to False. means the time steps. Defaults to False.
dropout (float, optional): The dropout probability. Dropout is applied
to the input of each layer except for the first layer. Defaults to 0.
weight_ih_attr (ParamAttr, optional): The parameter attribute for weight_ih_attr (ParamAttr, optional): The parameter attribute for
`weight_ih` of each cell. Default: None. `weight_ih` of each cell. Default: None.
weight_hh_attr (ParamAttr, optional): The parameter attribute for weight_hh_attr (ParamAttr, optional): The parameter attribute for
...@@ -1156,7 +1332,7 @@ class LSTM(RNNMixin): ...@@ -1156,7 +1332,7 @@ class LSTM(RNNMixin):
If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
else, the shape is `[batch_size, time_steps, hidden_size]`. else, the shape is `[batch_size, time_steps, hidden_size]`.
initial_states (tuple, optional): the initial state, a tuple of (h, c), initial_states (tuple, optional): the initial state, a tuple of (h, c),
the shape of each is `[num_lauers * num_directions, batch_size, hidden_size]`. the shape of each is `[num_layers * num_directions, batch_size, hidden_size]`.
If initial_state is not given, zero initial states are used. If initial_state is not given, zero initial states are used.
sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
or int32. The valid lengths of input sequences. Defaults to None. or int32. The valid lengths of input sequences. Defaults to None.
...@@ -1175,10 +1351,21 @@ class LSTM(RNNMixin): ...@@ -1175,10 +1351,21 @@ class LSTM(RNNMixin):
else 1. else 1.
final_states (tuple): the final state, a tuple of two tensors, h and c. final_states (tuple): the final state, a tuple of two tensors, h and c.
The shape of each is The shape of each is
`[num_lauers * num_directions, batch_size, hidden_size]`. `[num_layers * num_directions, batch_size, hidden_size]`.
Note that `num_directions` is 2 if direction is "bidirectional" Note that `num_directions` is 2 if direction is "bidirectional"
else 1. else 1.
Attributes:
weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise,
the shape is `[hidden_size, num_directions * hidden_size]`.
weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
with shape `[hidden_size, hidden_size]`.
bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
with shape `[hidden_size]`.
bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
with shape `[hidden_size]`.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -1207,51 +1394,19 @@ class LSTM(RNNMixin): ...@@ -1207,51 +1394,19 @@ class LSTM(RNNMixin):
hidden_size, hidden_size,
num_layers=1, num_layers=1,
direction="forward", direction="forward",
dropout=0.,
time_major=False, time_major=False,
dropout=0.,
weight_ih_attr=None, weight_ih_attr=None,
weight_hh_attr=None, weight_hh_attr=None,
bias_ih_attr=None, bias_ih_attr=None,
bias_hh_attr=None, bias_hh_attr=None,
name=None): name=None):
super(LSTM, self).__init__() super(LSTM, self).__init__(
"LSTM", input_size, hidden_size, num_layers, direction, time_major,
if direction in ["forward", "backward"]: dropout, weight_ih_attr, weight_hh_attr, bias_ih_attr, bias_hh_attr)
is_reverse = direction == "backward"
cell = LSTMCell(input_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
cell = LSTMCell(hidden_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional":
cell_fw = LSTMCell(input_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
cell_bw = LSTMCell(input_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
cell_bw = LSTMCell(2 * hidden_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else:
raise ValueError(
"direction should be forward, backward or bidirectional, "
"received direction = {}".format(direction))
self.input_size = input_size
self.hidden_size = hidden_size
self.dropout = dropout
self.num_directions = 2 if direction == "bidirectional" else 1
self.time_major = time_major
self.num_layers = num_layers
self.state_components = 2
class GRU(RNNMixin): class GRU(RNNBase):
r""" r"""
Multilayer GRU. It takes input sequences and initial states as inputs, and Multilayer GRU. It takes input sequences and initial states as inputs, and
returns the output sequences and the final states. returns the output sequences and the final states.
...@@ -1277,16 +1432,19 @@ class GRU(RNNMixin): ...@@ -1277,16 +1432,19 @@ class GRU(RNNMixin):
where :math:`\sigma` is the sigmoid function, and * is the elementwise where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator. multiplication operator.
Using keyword arguments to construct the layer is recommended.
Parameters: Parameters:
input_size (int): The input size for the first layer's cell. input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell. hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1. num_layers (int, optional): Number of layers. Defaults to 1.
direction (str, optional): The direction of the network. It can be direction (str, optional): The direction of the network. It can be "forward",
"forward", "backward" and "bidirectional". Defaults to "forward". "backward" and "bidirectional". When "bidirectional", the way to merge
dropout (float, optional): The droput probability. Dropout is applied outputs of forward and backward is concatenating. Defaults to "forward".
to the input of each layer except for the first layer. Defaults to 0.
time_major (bool, optional): Whether the first dimension of the input time_major (bool, optional): Whether the first dimension of the input
means the time steps. Defaults to False. means the time steps. Defaults to False.
dropout (float, optional): The dropout probability. Dropout is applied
to the input of each layer except for the first layer. Defaults to 0.
weight_ih_attr (ParamAttr, optional): The parameter attribute for weight_ih_attr (ParamAttr, optional): The parameter attribute for
`weight_ih` of each cell. Default: None. `weight_ih` of each cell. Default: None.
weight_hh_attr (ParamAttr, optional): The parameter attribute for weight_hh_attr (ParamAttr, optional): The parameter attribute for
...@@ -1303,7 +1461,7 @@ class GRU(RNNMixin): ...@@ -1303,7 +1461,7 @@ class GRU(RNNMixin):
If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`, If `time_major` is True, the shape is `[time_steps, batch_size, input_size]`,
else, the shape is `[batch_size, time_steps, hidden_size]`. else, the shape is `[batch_size, time_steps, hidden_size]`.
initial_states (Tensor, optional): the initial state. The shape is initial_states (Tensor, optional): the initial state. The shape is
`[num_lauers * num_directions, batch_size, hidden_size]`. `[num_layers * num_directions, batch_size, hidden_size]`.
If initial_state is not given, zero initial states are used. If initial_state is not given, zero initial states are used.
Defaults to None. Defaults to None.
sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64 sequence_length (Tensor, optional): shape `[batch_size]`, dtype: int64
...@@ -1322,10 +1480,21 @@ class GRU(RNNMixin): ...@@ -1322,10 +1480,21 @@ class GRU(RNNMixin):
Note that `num_directions` is 2 if direction is "bidirectional" Note that `num_directions` is 2 if direction is "bidirectional"
else 1. else 1.
final_states (Tensor): final states. The shape is final_states (Tensor): final states. The shape is
`[num_lauers * num_directions, batch_size, hidden_size]`. `[num_layers * num_directions, batch_size, hidden_size]`.
Note that `num_directions` is 2 if direction is "bidirectional" Note that `num_directions` is 2 if direction is "bidirectional"
else 1. else 1.
Attributes:
weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise,
the shape is `[hidden_size, num_directions * hidden_size]`.
weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
with shape `[hidden_size, hidden_size]`.
bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
with shape `[hidden_size]`.
bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
with shape `[hidden_size]`.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -1351,45 +1520,13 @@ class GRU(RNNMixin): ...@@ -1351,45 +1520,13 @@ class GRU(RNNMixin):
hidden_size, hidden_size,
num_layers=1, num_layers=1,
direction="forward", direction="forward",
dropout=0.,
time_major=False, time_major=False,
dropout=0.,
weight_ih_attr=None, weight_ih_attr=None,
weight_hh_attr=None, weight_hh_attr=None,
bias_ih_attr=None, bias_ih_attr=None,
bias_hh_attr=None, bias_hh_attr=None,
name=None): name=None):
super(GRU, self).__init__() super(GRU, self).__init__(
"GRU", input_size, hidden_size, num_layers, direction, time_major,
if direction in ["forward", "backward"]: dropout, weight_ih_attr, weight_hh_attr, bias_ih_attr, bias_hh_attr)
is_reverse = direction == "backward"
cell = GRUCell(input_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
cell = GRUCell(hidden_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional":
cell_fw = GRUCell(input_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
cell_bw = GRUCell(input_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
cell_bw = GRUCell(2 * hidden_size, hidden_size, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else:
raise ValueError(
"direction should be forward, backward or bidirectional, "
"received direction = {}".format(direction))
self.input_size = input_size
self.hidden_size = hidden_size
self.dropout = dropout
self.num_directions = 2 if direction == "bidirectional" else 1
self.time_major = time_major
self.num_layers = num_layers
self.state_components = 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册