diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index 74baa8a350625f566993009417e21c1d7765b7a2..99944b800c42224cd579209eb7d15cb419d8776e 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -89,6 +89,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(false);
 
     // for parameter prefetch
+    AddAttr<bool>("remote_prefetch", "").SetDefault(false);
     AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<std::vector<int64_t>>("height_sections",
                                   "Height for each output SelectedRows.")
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index ccd9175b64d46d745c8be5f64d7ddc21a117c181..a2a47ce384f4b5c00e51d86e84f15534576a4cfc 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -285,6 +285,7 @@ def embedding(input,
               size,
               is_sparse=False,
               is_distributed=False,
+              remote_prefetch=False,
               padding_idx=None,
               param_attr=None,
               dtype='float32'):
@@ -326,6 +327,8 @@ def embedding(input,
     """
 
     helper = LayerHelper('embedding', **locals())
+    if remote_prefetch:
+        assert is_sparse is True and is_distributed is False
     w = helper.create_parameter(
         attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
     tmp = helper.create_variable_for_type_inference(dtype)
@@ -339,6 +342,7 @@ def embedding(input,
         attrs={
             'is_sparse': is_sparse,
             'is_distributed': is_distributed,
+            'remote_prefetch': remote_prefetch,
             'padding_idx': padding_idx
         })
     return tmp
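For orientation, this is how the new flag is driven from user code. The following is a minimal usage sketch, not part of the patch; the data layer name and embedding size are illustrative. `remote_prefetch` only makes sense for a sparse, non-distributed embedding, which is exactly what the new assertion in `embedding` enforces:

    import paddle.fluid as fluid

    ids = fluid.layers.data(name='ids', shape=[1], dtype='int64', lod_level=1)
    # remote_prefetch requires is_sparse=True and is_distributed=False; the
    # lookup_table op created here is later rewritten by the transpiler to
    # prefetch only the needed rows from the parameter servers.
    emb = fluid.layers.embedding(
        input=ids,
        size=[1000, 64],
        dtype='float32',
        is_sparse=True,
        is_distributed=False,
        remote_prefetch=True)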
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index d132dd3c48f55c07725515e40faeb5076398adeb..dbc4583763d36dfb584ef87b42b36abded43a0c0 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -447,19 +447,23 @@ class TestEmptyPserverOptimizeBlocks(TranspilerTest):
 
 
 class TestDistLookupTableBase(TranspilerTest):
-    def network_with_table(self, is_sparse, is_distributed):
+    def network_with_table(self,
+                           is_sparse,
+                           is_distributed,
+                           remote_prefetch=False):
         self.table_size = 1000
         self.emb_size = 64
         self.lookup_table_name = 'shared_w'
 
-        def emb_pool(ids, table_name, is_distributed):
+        def emb_pool(ids, table_name, is_distributed, remote_prefetch):
             emb = fluid.layers.embedding(
                 input=ids,
                 size=[self.table_size, self.emb_size],
                 dtype='float32',
                 param_attr=table_name,
                 is_sparse=is_sparse,
-                is_distributed=is_distributed)
+                is_distributed=is_distributed,
+                remote_prefetch=remote_prefetch)
             pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
             return pool
 
@@ -469,9 +473,12 @@ class TestDistLookupTableBase(TranspilerTest):
             name='brand_ids', shape=[1], dtype='int64', lod_level=1)
         profile_ids = fluid.layers.data(
             name='brand_ids', shape=[1], dtype='int64', lod_level=1)
-        title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed)
-        brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed)
-        profile_emb = emb_pool(profile_ids, "profile_emb", False)
+        title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed,
+                             False)
+        brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed,
+                             False)
+        profile_emb = emb_pool(profile_ids, "profile_emb", False,
+                               remote_prefetch)
         fc0 = fluid.layers.concat(
             input=[title_emb, brand_emb, profile_emb], axis=1)
         predict = fluid.layers.fc(input=fc0,
@@ -575,6 +582,57 @@ class TestDistLookupTable(TestDistLookupTableBase):
                          startup_ops)
 
 
+class TestRemoteLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(
+            is_sparse=True, is_distributed=False, remote_prefetch=True)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for profile_emb adam
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 3 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["sum", "sgd"])
+        # 4 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer, trainer_startup = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
+            'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'split_selected_rows', 'send',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_ids', 'send', 'send_barrier',
+            'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+        startup_ops = [
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'uniform_random',
+            'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
+            'fake_init'
+        ]
+        self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
+                         startup_ops)
+
+
 class TestAsyncLocalLookupTable(TestDistLookupTableBase):
     def net_conf(self):
         self.network_with_table(is_sparse=True, is_distributed=False)
@@ -782,5 +840,45 @@ class TestNCCL2Transpile(TranspilerTest):
             pass
 
 
+# test for remote prefetch
+class TestRemoteLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(
+            is_sparse=True, is_distributed=False, remote_prefetch=True)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 4)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+
+        # 3 optimize for table 2 adam
+        # NOTE: if param is not selected rows, the grad will be scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+
+        trainer, _ = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
+            'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
+            'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
+            'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'split_selected_rows', 'send', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
+            'recv', 'recv', 'fetch_barrier', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 7a3cf1230b29b40d762b978d8ca504f5c53c85d4..ddf7468cddb863fbf70efe2ea089917b415529d3 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -236,21 +236,29 @@ class DistributeTranspiler(object):
         else:
             raise ValueError("must set trainer_id > 0")
 
-    def _get_all_sparse_update_op(self, main_program):
+    def _get_all_remote_sparse_update_op(self, main_program):
         sparse_update_ops = []
         sparse_update_op_types = ["lookup_table"]
         for op in main_program.global_block().ops:
             if op.type in sparse_update_op_types and op.attr(
-                    'is_sparse') is True and not op.attr('is_distributed'):
+                    'remote_prefetch') is True and not op.attr(
+                        'is_distributed'):
                 sparse_update_ops.append(op)
         return sparse_update_ops
 
-    def _update_sparse_update_op(self, param_varname, height_sections,
-                                 endpint_map):
+    def _update_remote_sparse_update_op(self, param_varname, height_sections,
+                                        endpint_map):
         for op in self.sparse_update_ops:
             if param_varname in op.input_arg_names:
                 op._set_attr('epmap', endpint_map)
                 op._set_attr('height_sections', height_sections)
+                op._set_attr('trainer_id', self.trainer_id)
+
+    def _is_input_of_remote_sparse_update_op(self, param_name):
+        for op in self.sparse_update_ops:
+            if param_name in op.input_arg_names:
+                return True
+        return False
 
     def transpile(self,
                   trainer_id,
@@ -316,7 +324,7 @@ class DistributeTranspiler(object):
                 self.grad_name_to_param_name[grad_var.name] = param_var.name
 
         # get all sparse update ops
-        self.sparse_update_ops = self._get_all_sparse_update_op(
+        self.sparse_update_ops = self._get_all_remote_sparse_update_op(
             self.origin_program)
         self.sparse_param_to_height_sections = dict()
 
@@ -449,8 +457,8 @@ class DistributeTranspiler(object):
                 if param_varname in self.sparse_param_to_height_sections:
                     height_sections = self.sparse_param_to_height_sections[
                         param_varname]
-                    self._update_sparse_update_op(param_varname, height_sections,
-                                                  eps)
+                    self._update_remote_sparse_update_op(param_varname,
+                                                         height_sections, eps)
                 else:
                     program.global_block().append_op(
                         type="recv",
@@ -481,8 +489,6 @@ class DistributeTranspiler(object):
             if len(splited_var) <= 1:
                 continue
             orig_param = program.global_block().vars[param_varname]
-            print("sparse_param_to_height_sections: " + str(
-                self.sparse_param_to_height_sections))
             if param_varname not in self.sparse_param_to_height_sections:
                 program.global_block().append_op(
                     type="concat",
@@ -1448,7 +1454,7 @@ to transpile() call.")
                 for v in splited_vars:
                     height_sections.append(v.shape[0])
                 sparse_param_name = self.grad_name_to_param_name[orig_var.name]
-                if sparse_param_name != self.table_name:
+                if self._is_input_of_remote_sparse_update_op(sparse_param_name):
                     self.sparse_param_to_height_sections[
                         sparse_param_name] = height_sections
                 program.global_block()._insert_op(
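Putting it together, a hedged end-to-end sketch of how the transpiler is driven so the remote-prefetch rewrite takes effect. The endpoints and trainer count below are placeholder values, and the program is assumed to already contain a remote_prefetch embedding like the sketch earlier:

    import paddle.fluid as fluid

    # Build a program containing a remote_prefetch embedding first, then:
    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id=0,
        program=fluid.default_main_program(),
        pservers="127.0.0.1:6174,127.0.0.1:6175",  # placeholder endpoints
        trainers=2)
    trainer_prog = t.get_trainer_program()
    # _update_remote_sparse_update_op has now stamped 'epmap',
    # 'height_sections' and 'trainer_id' onto every remote-prefetch
    # lookup_table op, so at runtime the op fetches only the embedding rows
    # it needs directly from the parameter servers instead of recv-ing the
    # whole table.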