提交 27b5bc6d 编写于 作者: Y Yi Huaijie

don't change shape of Initializer when init slice of a Parameter

上级 634bfd35
......@@ -64,7 +64,7 @@ class Initializer:
def dtype(self, dtype):
self._dtype = dtype
def to_tensor(self, slice_index=None):
def to_tensor(self, slice_index=None, shape=None):
"""
Get the tensor format data of this Initializer.
......@@ -72,12 +72,16 @@ class Initializer:
slice_index (int): Slice index of a parameter's slices.
Used when initialize a slice of a parameter, it guarantee that
devices use the same slice can generate the same tensor.
shape (list[int]): Shape of the slice, used when initialize a slice of the parameter.
"""
arr = None
if shape is None:
shape = self.shape
try:
arr = np.ndarray(self.shape)
arr = np.ndarray(shape)
except ValueError:
msg = "Error shape={}".format(self.shape)
msg = "Error shape={}".format(shape)
logger.error(msg)
raise ValueError(msg)
......
......@@ -249,9 +249,8 @@ class Parameter:
if len(layout) != 3:
raise ValueError("The length of layout must be 3! layout is {}."
.format(layout))
self.init_mode.shape = layout[2]
slice_index = int(_get_slice_index(layout[0], layout[1]))
self.default_input = self.init_mode.to_tensor(slice_index)
self.default_input = self.init_mode.to_tensor(slice_index, layout[2])
else:
self.default_input = self.init_mode.to_tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册