提交 c0d38e40 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1606 LSTM network adapt to cpu target.

Merge pull request !1606 from caojian05/ms_master_dev
......@@ -17,7 +17,7 @@ import math
import numpy as np
from mindspore import Parameter, Tensor, nn
from mindspore import Parameter, Tensor, nn, context, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
......@@ -57,6 +57,24 @@ def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
if bidirectional:
num_directions = 2
if context.get_context("device_target") == "CPU":
h_list = []
c_list = []
for i in range(num_layers):
hi = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='h' + str(i))
h_list.append(hi)
ci = Parameter(initializer(
Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)),
[num_directions, batch_size, hidden_size]
), name='c' + str(i))
c_list.append(ci)
h = ParameterTuple(tuple(h_list))
c = ParameterTuple(tuple(c_list))
return h, c
h = Tensor(
np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
c = Tensor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册