diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index 0543e6238109bf5bab68a9f08fc34678936e756f..55f8b3eff874bce6ff55bdc07dac7cb0ab4ef4c9 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
+
 import unittest
 import paddle.fluid as fluid
 from paddle.fluid.transpiler.distribute_transpiler import delete_ops
@@ -362,12 +364,13 @@ class TestL2DecayWithPiecewise(TranspilerTest):
 
 class TestDistLookupTableBase(TranspilerTest):
     def network_with_table(self, is_sparse, is_distributed):
+        self.table_size = 1000
+        self.emb_size = 64
+
         def emb_pool(ids):
-            table_size = 1000
-            emb_size = 64
             emb = fluid.layers.embedding(
                 input=ids,
-                size=[table_size, emb_size],
+                size=[self.table_size, self.emb_size],
                 dtype='float32',
                 param_attr='shared_w',  # share parameter
                 is_sparse=is_sparse,
@@ -536,6 +539,22 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+class TestDistLookupTableSliceSize(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        config = fluid.DistributeTranspilerConfig()
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
+
+        self.assertTrue(self.transpiler.has_distributed_lookup_table)
+        lookup_table_var = pserver1.global_block().vars[
+            self.transpiler.table_name]
+        row_size = lookup_table_var.shape[0]
+        calc_row_size = int(math.ceil(self.table_size / self.pservers))
+        self.assertEqual(row_size, calc_row_size)
+
+
 class TestRMSPropOptimizer(TranspilerTest):
     def net_conf(self):
         x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index cd6cf558d5c4cd4a3fbced95e92c3f4dd94661af..c97beea1b3aef2c9b2718ff5b005e49a61e58109 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -885,9 +885,15 @@ class DistributeTranspiler(object):
         # create table param and grad var in pserver program
         origin_param_var = self.origin_program.global_block().vars[
             self.table_name]
+
+        zero_dim = int(
+            math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints)))
+        table_shape = list(origin_param_var.shape)
+        table_shape[0] = zero_dim
+
         param_var = pserver_program.global_block().create_var(
             name=origin_param_var.name,
-            shape=origin_param_var.shape,
+            shape=table_shape,
             dtype=origin_param_var.dtype,
             type=core.VarDesc.VarType.SELECTED_ROWS,
             persistable=True)
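
Note on the slice-size arithmetic: the new zero_dim computation and the TestDistLookupTableSliceSize check both split the lookup table's rows across parameter servers with a ceiling division. The following is a minimal standalone sketch of that arithmetic, not part of the transpiler itself; lookup_table_slice_rows is a hypothetical helper written only for illustration, and float() is used so the ceiling behaves the same under Python 2 and Python 3.

import math

def lookup_table_slice_rows(total_rows, num_pservers):
    # Each pserver holds ceil(total_rows / num_pservers) rows of the
    # distributed lookup table, mirroring zero_dim in the diff above.
    return int(math.ceil(float(total_rows) / num_pservers))

# With the test's table_size of 1000:
print(lookup_table_slice_rows(1000, 2))  # 500
print(lookup_table_slice_rows(1000, 3))  # 334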