Commit a3cee90b authored by caojian05

Adapt the LSTM network to the CPU target.

Parent 65fe1608
@@ -17,7 +17,7 @@ import math
 import numpy as np
-from mindspore import Parameter, Tensor, nn
+from mindspore import Parameter, Tensor, nn, context, ParameterTuple
 from mindspore.common.initializer import initializer
 from mindspore.ops import operations as P
@@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
     if bidirectional:
         num_directions = 2
+    if context.get_context("device_target") == "CPU":
+        h_list = []
+        c_list = []
+        for i in range(num_layers):
+            hi = Parameter(initializer(
+                Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
+                [num_directions, batch_size, hidden_size]
+            ), name='h' + str(i))
+            h_list.append(hi)
+            ci = Parameter(initializer(
+                Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
+                [num_directions, batch_size, hidden_size]
+            ), name='c' + str(i))
+            c_list.append(ci)
+        h = ParameterTuple(tuple(h_list))
+        c = ParameterTuple(tuple(c_list))
+        return h, c
     h = Tensor(
         np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
     c = Tensor(
...
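
For reference, a minimal standalone sketch of the shape difference this change introduces (not part of the commit; the helper name default_state_sketch and the example sizes are invented for illustration): on CPU the initial h and c become ParameterTuples with one (num_directions, batch_size, hidden_size) Parameter per layer, while on GPU/Ascend they remain single stacked (num_layers * num_directions, batch_size, hidden_size) Tensors.

import numpy as np
from mindspore import Parameter, ParameterTuple, Tensor, context
from mindspore.common.initializer import initializer

def default_state_sketch(batch_size, hidden_size, num_layers, bidirectional):
    """Illustrative restatement of lstm_default_state; the name and sizes here are assumptions."""
    num_directions = 2 if bidirectional else 1
    if context.get_context("device_target") == "CPU":
        # CPU path: one zero-initialized Parameter per layer, grouped into ParameterTuples.
        shape = (num_directions, batch_size, hidden_size)
        h_list = [Parameter(initializer(Tensor(np.zeros(shape).astype(np.float32)), list(shape)),
                            name='h' + str(i)) for i in range(num_layers)]
        c_list = [Parameter(initializer(Tensor(np.zeros(shape).astype(np.float32)), list(shape)),
                            name='c' + str(i)) for i in range(num_layers)]
        return ParameterTuple(tuple(h_list)), ParameterTuple(tuple(c_list))
    # GPU/Ascend path: a single stacked zero Tensor for each of h and c.
    stacked = (num_layers * num_directions, batch_size, hidden_size)
    return (Tensor(np.zeros(stacked).astype(np.float32)),
            Tensor(np.zeros(stacked).astype(np.float32)))

# Example: on CPU, h0 has num_layers entries, each of shape (num_directions, batch_size, hidden_size).
context.set_context(device_target="CPU")
h0, c0 = default_state_sketch(batch_size=64, hidden_size=100, num_layers=1, bidirectional=True)

The per-layer layout suggests the CPU implementation consumes the recurrent state layer by layer, whereas the fused GPU/Ascend kernel takes the stacked tensors directly; this is a reading of the diff, not a statement from the commit.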