提交 a022d421 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1992 fix lstm weight initializer

Merge pull request !1992 from baihuawei/cpulstm
......@@ -27,6 +27,7 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
using dim = dnnl::memory::dims;
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
......@@ -41,6 +42,12 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "conv2d only support 3-D input!";
}
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
......
......@@ -31,6 +31,7 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
auto eng = MKLKernelEngine::Get().engine();
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "hidden_size");
......@@ -45,6 +46,12 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "error iteration shape!";
}
if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "layers must be greater than zero!";
}
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "conv2d only support 3-D input!";
}
const int gate_size = 4 * hidden_size_;
for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
......
......@@ -13,8 +13,8 @@
# limitations under the License.
# ============================================================================
"""lstm"""
import math
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore._checkparam import Validator as validator
......@@ -148,7 +148,9 @@ class LSTM(Cell):
if self.has_bias:
increment_size += 2 * gate_size
weight_size += increment_size * num_directions
self.weight = Parameter(initializer(0.0, [weight_size, 1, 1]), name='weight')
stdv = 1 / math.sqrt(hidden_size)
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
else:
input_size_list = []
input_size_list.append(self.input_size)
......@@ -157,14 +159,13 @@ class LSTM(Cell):
weights = []
layers = []
bias_size = 0 if not self.has_bias else num_directions * self.hidden_size * 4
stdv = 1 / math.sqrt(hidden_size)
for i in range(num_layers):
weight_size = (input_size_list[i] + self.hidden_size) * num_directions * self.hidden_size * 4
w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.01
if has_bias:
bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32)
w_np = np.concatenate([w_np, bias_np], axis=0)
weight_size = weight_size + bias_size
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))
layers.append(nn.LSTMCell(input_size=input_size_list[i],
hidden_size=self.hidden_size,
has_bias=self.has_bias,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册