未验证 提交 971f3bc9 编写于 作者: T tangwei12 提交者: GitHub

fix params with only 1 dim (#15828)

* fix params with only 1 dim
* test=develop
上级 fbb54046
......@@ -766,7 +766,10 @@ def _load_distributed_persistables(executor, dirname, main_program=None):
dtype=slice_var.dtype,
persistable=True)
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
dim1_flatten = 1
if len(slice.shape) >= 2:
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
start = int(offset / dim1_flatten)
end = int(offset / dim1_flatten + slice.shape[0])
......
......@@ -1020,7 +1020,11 @@ class DistributeTranspiler(object):
skip_dim0 = 0
slice_vars = self.param_var_mapping[orig_var_name]
orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:])
orig_dim1_flatten = 1
if len(slice_vars[0].shape) >= 2:
orig_dim1_flatten = reduce(lambda x, y: x * y,
slice_vars[0].shape[1:])
for slice_var in slice_vars[:block_idx]:
skip_dim0 += slice_var.shape[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册