未验证 提交 4b3f9e5c 编写于 作者: T tangwei12 提交者: GitHub

fix params with only 1 dim (#15828) (#15832)

* fix params with only 1 dim
* test=develop
上级 2307baf5
...@@ -766,7 +766,10 @@ def _load_distributed_persistables(executor, dirname, main_program=None): ...@@ -766,7 +766,10 @@ def _load_distributed_persistables(executor, dirname, main_program=None):
dtype=slice_var.dtype, dtype=slice_var.dtype,
persistable=True) persistable=True)
dim1_flatten = 1
if len(slice.shape) >= 2:
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:]) dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
start = int(offset / dim1_flatten) start = int(offset / dim1_flatten)
end = int(offset / dim1_flatten + slice.shape[0]) end = int(offset / dim1_flatten + slice.shape[0])
......
...@@ -1020,7 +1020,11 @@ class DistributeTranspiler(object): ...@@ -1020,7 +1020,11 @@ class DistributeTranspiler(object):
skip_dim0 = 0 skip_dim0 = 0
slice_vars = self.param_var_mapping[orig_var_name] slice_vars = self.param_var_mapping[orig_var_name]
orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:]) orig_dim1_flatten = 1
if len(slice_vars[0].shape) >= 2:
orig_dim1_flatten = reduce(lambda x, y: x * y,
slice_vars[0].shape[1:])
for slice_var in slice_vars[:block_idx]: for slice_var in slice_vars[:block_idx]:
skip_dim0 += slice_var.shape[0] skip_dim0 += slice_var.shape[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册