diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 00daaf986bfa086e5350109b16291d289179298b..007aaeb4fed67918de0cae3bade229ca314dd952 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -16,7 +16,7 @@ import warnings import os import paddle.fluid as fluid -import paddle.distributed.fleet as fleet +from paddle.distributed import fleet from paddle.fluid import core from paddle.distributed.ps.utils.public import * from paddle.fluid.framework import Program @@ -26,7 +26,7 @@ from paddle.fluid.parallel_executor import ParallelExecutor from paddle.fluid.framework import Variable, Parameter from paddle.distributed.fleet.runtime.runtime_base import RuntimeBase from paddle.distributed.fleet.base.private_helper_function import wait_server_ready -import paddle.distributed.fleet.proto.the_one_ps_pb2 as ps_pb2 +from paddle.distributed.fleet.proto import the_one_ps_pb2 from paddle.fluid.communicator import Communicator, HeterClient from google.protobuf import text_format @@ -518,7 +518,7 @@ class BarrierTable(Table): table_proto.table_id = self.idx table_proto.table_class = 'BarrierTable' table_proto.shard_num = 256 - table_proto.type = ps_pb2.PS_OTHER_TABLE + table_proto.type = the_one_ps_pb2.PS_OTHER_TABLE table_proto.accessor.accessor_class = "CommMergeAccessor" table_proto.accessor.fea_dim = 0 @@ -544,7 +544,7 @@ class TensorTable(Table): def _set(self, table_proto): table_proto.table_id = self.idx - table_proto.type = ps_pb2.PS_OTHER_TABLE + table_proto.type = the_one_ps_pb2.PS_OTHER_TABLE table_proto.table_class = self.tensor_dict.get("tensor_table_class", '') table_proto.accessor.accessor_class = "CommMergeAccessor" @@ -573,7 +573,7 @@ class SparseTable(Table): return table_proto.table_id = ctx.table_id() table_proto.table_class = self.table_class - table_proto.type = ps_pb2.PS_SPARSE_TABLE + table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE table_proto.shard_num = self.shard_num self.common.table_name = self.context['grad_name_to_param_name'][ @@ -632,7 +632,7 @@ class GeoSparseTable(SparseTable): return table_proto.table_id = ctx.table_id() table_proto.table_class = self.table_class - table_proto.type = ps_pb2.PS_SPARSE_TABLE + table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE table_proto.shard_num = self.shard_num table_proto.accessor.accessor_class = 'CommMergeAccessor' @@ -664,7 +664,7 @@ class DenseTable(Table): table_proto.table_id = ctx.table_id() - table_proto.type = ps_pb2.PS_DENSE_TABLE + table_proto.type = the_one_ps_pb2.PS_DENSE_TABLE table_proto.table_class = "CommonDenseTable" table_proto.shard_num = 256 @@ -748,7 +748,7 @@ class PsDescBuilder(object): self.service = self._get_service() self.fs_client = self._get_fs_client() - self.ps_desc = ps_pb2.PSParameter() + self.ps_desc = the_one_ps_pb2.PSParameter() def _get_tensor_tables(self): program_idx = 0 @@ -806,7 +806,7 @@ class PsDescBuilder(object): table_proto = self.ps_desc.server_param.downpour_server_param.downpour_table_param.add( ) table._set(table_proto) - if table_proto.type == ps_pb2.PS_SPARSE_TABLE and table_proto.common is not None: + if table_proto.type == the_one_ps_pb2.PS_SPARSE_TABLE and table_proto.common is not None: self.sparse_table_maps[ table_proto.common.table_name] = table_proto.table_id