From 4b3f9e5c61d687ea90e6599bf9494df92ed088fb Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 21 Feb 2019 13:45:05 +0800 Subject: [PATCH] fix params with only 1 dim (#15828) (#15832) * fix params with only 1 dim * test=develop --- python/paddle/fluid/io.py | 5 ++++- python/paddle/fluid/transpiler/distribute_transpiler.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index a2abbf36c0..24e102b6c2 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -766,7 +766,10 @@ def _load_distributed_persistables(executor, dirname, main_program=None): dtype=slice_var.dtype, persistable=True) - dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:]) + dim1_flatten = 1 + if len(slice.shape) >= 2: + dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:]) + start = int(offset / dim1_flatten) end = int(offset / dim1_flatten + slice.shape[0]) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index a3293afbbd..eb54068650 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1020,7 +1020,11 @@ class DistributeTranspiler(object): skip_dim0 = 0 slice_vars = self.param_var_mapping[orig_var_name] - orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:]) + orig_dim1_flatten = 1 + + if len(slice_vars[0].shape) >= 2: + orig_dim1_flatten = reduce(lambda x, y: x * y, + slice_vars[0].shape[1:]) for slice_var in slice_vars[:block_idx]: skip_dim0 += slice_var.shape[0] -- GitLab