diff --git a/python/paddle/fluid/distributed/helper.py b/python/paddle/fluid/distributed/helper.py index ca6dd5dabfa1ea19da56187113335a81b090df86..999c8d77b83b6fd6629d9b91a75c339ef7f3cad5 100644 --- a/python/paddle/fluid/distributed/helper.py +++ b/python/paddle/fluid/distributed/helper.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mpi4py import MPI import ps_pb2 as pslib @@ -59,7 +58,7 @@ class FileSystem(object): class MPIHelper(object): """ - MPIHelper is a wrapper of mpi4py, supprot get_rank get_size etc. + MPIHelper is a wrapper of mpi4py, support get_rank get_size etc. Args: No params Examples: @@ -68,7 +67,9 @@ class MPIHelper(object): """ def __init__(self): + from mpi4py import MPI self.comm = MPI.COMM_WORLD + self.MPI = MPI def get_rank(self): return self.comm.Get_rank() @@ -86,4 +87,4 @@ class MPIHelper(object): return socket.gethostname() def finalize(self): - MPI.Finalize() + self.MPI.Finalize() diff --git a/python/requirements.txt b/python/requirements.txt index 36313333b2b42601b0dabf8ffe899342820cdad5..2f81d85df0626b294f4d861706b5c1b7ec9841d5 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -9,4 +9,3 @@ Pillow nltk>=3.2.2 graphviz six -mpi4py==3.0.0