diff --git a/mindspore/model_zoo/lstm.py b/mindspore/model_zoo/lstm.py index 35fe67430304afca581857dbf1be8d73e3bbfc83..7368bbf8e5afbadfdb9e5de7a2e645ffe178a5cc 100644 --- a/mindspore/model_zoo/lstm.py +++ b/mindspore/model_zoo/lstm.py @@ -17,7 +17,7 @@ import math import numpy as np -from mindspore import Parameter, Tensor, nn +from mindspore import Parameter, Tensor, nn, context, ParameterTuple from mindspore.common.initializer import initializer from mindspore.ops import operations as P @@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): if bidirectional: num_directions = 2 + if context.get_context("device_target") == "CPU": + h_list = [] + c_list = [] + for i in range(num_layers): + hi = Parameter(initializer( + Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), + [num_directions, batch_size, hidden_size] + ), name='h' + str(i)) + h_list.append(hi) + ci = Parameter(initializer( + Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)), + [num_directions, batch_size, hidden_size] + ), name='c' + str(i)) + c_list.append(ci) + h = ParameterTuple(tuple(h_list)) + c = ParameterTuple(tuple(c_list)) + return h, c + h = Tensor( np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) c = Tensor(