Unverified commit 7f696804, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Reorganize the folder structure (#54059)

* [Auto Parallel] Reorganize the folder structure

* [Auto Parallel] Fix some import errors
Parent 88e43625
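Summary of the change: the static-graph implementation of auto parallel (engine, completion, partitioner, reshard, process groups, cost model, tuner, utils, and related modules) moves from paddle.distributed.auto_parallel into the new paddle.distributed.auto_parallel.static subpackage (a dygraph subpackage is also registered in setup.py), while process_mesh, strategy, and interface stay at the top level. A minimal sketch of how downstream imports migrate, using only modules that appear in the hunks below (treat any path not shown there as an assumption):

    # Before this commit
    from paddle.distributed.auto_parallel.engine import Engine
    from paddle.distributed.auto_parallel.utils import set_var_dist_attr

    # After this commit
    from paddle.distributed.auto_parallel.static.engine import Engine
    from paddle.distributed.auto_parallel.static.utils import set_var_dist_attr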
@@ -14,7 +14,7 @@
 from .strategy import Strategy
 from .process_mesh import ProcessMesh
-from .engine import Engine
+from .static.engine import Engine
 from .interface import shard_tensor
 from .interface import shard_op
 from .interface import recompute
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -14,11 +14,11 @@
 import paddle
-from .dist_context import get_default_distributed_context
-from .dist_op import DistributedOperatorHelper
-from .dist_tensor import DistributedTensor
 from .process_mesh import ProcessMesh, get_current_process_mesh
-from .utils import (
+from .static.dist_context import get_default_distributed_context
+from .static.dist_op import DistributedOperatorHelper
+from .static.dist_tensor import DistributedTensor
+from .static.utils import (
     __no_shape_var_type__,
     convert_to_dims_mapping,
     verify_shard_spec,
...
@@ -140,12 +140,12 @@ class ProcessMesh(core.ProcessMesh):
             )
         # Store all process meshes
-        from .dist_context import get_default_distributed_context
+        from .static.dist_context import get_default_distributed_context
         default_dist_cxt = get_default_distributed_context()
         default_dist_cxt.add_process_mesh(self)
         # Add new processes to process group 0
-        from .process_group import get_process_group
+        from .static.process_group import get_process_group
         pg0 = get_process_group(0)
         pg0.add_ranks(self.process_ids)
@@ -204,14 +204,14 @@ class ProcessMesh(core.ProcessMesh):
         self._old_op_size = len(cur_block.ops)
     def __exit__(self, exc_type, exc_value, exc_traceback):
-        from .dist_op import DistributedOperator
-        from .dist_tensor import DistributedTensor
+        from .static.dist_op import DistributedOperator
+        from .static.dist_tensor import DistributedTensor
         default_prog = paddle.static.default_main_program()
         cur_block = default_prog.current_block()
         new_var_names = list(cur_block.vars.keys())
         new_op_size = len(cur_block.ops)
-        from .dist_context import get_default_distributed_context
+        from .static.dist_context import get_default_distributed_context
         default_dist_ctx = get_default_distributed_context()
         for name in new_var_names:
...
@@ -17,7 +17,7 @@ import paddle
 from ..utils.log_utils import get_logger
 from .process_mesh import retrive_unique_id_for_process_mesh
-from .utils import _get_idx_in_axis
+from .static.utils import _get_idx_in_axis
 _logger = get_logger(logging.INFO)
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -21,11 +21,11 @@ import numpy as np
 import paddle
 import paddle.distributed as dist
-from paddle.distributed.auto_parallel.converter import Converter
-from paddle.distributed.auto_parallel.dist_context import (
+from paddle.distributed.auto_parallel.static.converter import Converter
+from paddle.distributed.auto_parallel.static.dist_context import (
     get_default_distributed_context,
 )
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     is_backward_op,
     is_forward_op,
     is_loss_op,
...
@@ -24,7 +24,7 @@ from paddle.hapi.callbacks import (
     ProgBarLogger,
 )
-from .interface import CollectionNames, get_collection
+from ..interface import CollectionNames, get_collection
 def config_callbacks(
...
@@ -20,7 +20,7 @@ from enum import IntEnum, unique
 import paddle
-from ..utils.log_utils import get_logger
+from ...utils.log_utils import get_logger
 @unique
...
@@ -18,11 +18,11 @@ import logging
 from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.framework import core
+from ..process_mesh import ProcessMesh, compute_compatible_process_mesh
 from .dist_attribute import OperatorDistAttr, TensorDistAttr
 from .dist_context import _node_id
 from .operators import find_compatible_distributed_operator_impls
 from .process_group import get_world_process_group
-from .process_mesh import ProcessMesh, compute_compatible_process_mesh
 from .utils import (
     __no_shape_var_type__,
     get_logger,
@@ -1641,7 +1641,7 @@ class Completer:
         """Complete the annotation of vars and ops in the update phase for parallel program."""
         # Copy the dist tensors and dist ops annotated by users from the default context
         # global mesh
-        from paddle.distributed.auto_parallel.process_group import (
+        from paddle.distributed.auto_parallel.static.process_group import (
             get_world_process_group,
         )
@@ -1895,7 +1895,7 @@ class Completer:
     def _init_global_mesh_for_program(self):
         # Copy the dist tensors and dist ops annotated by users from the default context
         # global mesh
-        from paddle.distributed.auto_parallel.process_group import (
+        from paddle.distributed.auto_parallel.static.process_group import (
             get_world_process_group,
         )
...
@@ -19,7 +19,7 @@ import numpy as np
 import paddle
-from ..utils.log_utils import get_logger
+from ...utils.log_utils import get_logger
 class Converter:
...
@@ -15,7 +15,9 @@
 from functools import reduce
 import paddle
-from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
+from paddle.distributed.auto_parallel.static.dist_tensor import (
+    DistributedTensor,
+)
 from paddle.static import Variable
 from .base_cost import Cost
...
@@ -18,9 +18,9 @@ from collections import defaultdict
 from paddle.distributed.passes import PassContext
 from paddle.framework import IrGraph, core, set_flags
+from ..process_mesh import ProcessMesh
 from .dist_op import DistributedOperator
 from .dist_tensor import DistributedTensor
-from .process_mesh import ProcessMesh
 from .utils import (
     __no_shape_var_type__,
     _copy_dist_attr_to_cpp,
...
@@ -23,7 +23,7 @@ import numpy as np
 import paddle
 from paddle.framework import core
-from ..utils.log_utils import get_logger
+from ...utils.log_utils import get_logger
 from .process_group import _g_process_group_map
 from .utils import get_dist_attr
...
@@ -22,7 +22,7 @@ import random
 import numpy as np
 import paddle
-import paddle.distributed.auto_parallel.utils as auto_utils
+import paddle.distributed.auto_parallel.static.utils as auto_utils
 from paddle import static, utils
 from paddle.distributed import fleet
 from paddle.fluid.executor import _to_name_str
@@ -32,7 +32,9 @@ from paddle.framework import core, in_dynamic_mode
 from paddle.metric import Metric
 from paddle.static import InputSpec, Operator, Variable, global_scope
-from ..utils.log_utils import get_logger
+from ...utils.log_utils import get_logger
+from ..interface import CollectionNames, fetch, get_collection
+from ..strategy import Strategy
 from .callbacks import config_callbacks
 from .cluster import Cluster, get_default_cluster
 from .converter import Converter
@@ -45,11 +47,9 @@ from .dist_loader import (
 from .dist_op import DistributedOperator
 from .dist_saver import DistributedSaver
 from .helper import ProgramHelper
-from .interface import CollectionNames, fetch, get_collection
 from .parallelizer_v2 import Parallelizer
 from .planner_v2 import Planner
 from .process_group import get_all_process_groups, new_process_group
-from .strategy import Strategy
 class Engine:
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
-from paddle.distributed.auto_parallel.process_group import (
+from paddle.distributed.auto_parallel.static.process_group import (
     get_world_process_group,
 )
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
...
@@ -18,10 +18,10 @@ import paddle
 from paddle.framework import core
 from paddle.utils import unique_name
-from ...utils.log_utils import get_logger
+from ....utils.log_utils import get_logger
 _logger = get_logger(logging.INFO)
-from ..random import determinate_rng, is_enable_auto_rand_ctrl
+from ...random import determinate_rng, is_enable_auto_rand_ctrl
 from ..utils import (
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
     set_var_dist_attr,
...
@@ -13,7 +13,7 @@
 # limitations under the License
 from paddle.common_ops_import import check_dtype, check_variable_and_dtype
-from paddle.distributed.auto_parallel.cost.comm_op_cost import (
+from paddle.distributed.auto_parallel.static.cost.comm_op_cost import (
     AllreduceSumOpCost,
     IdentityOpCost,
 )
...
@@ -14,10 +14,10 @@
 import logging
-from ...utils.log_utils import get_logger
+from ....utils.log_utils import get_logger
 _logger = get_logger(logging.INFO)
-from ..random import determinate_rng, is_enable_auto_rand_ctrl
+from ...random import determinate_rng, is_enable_auto_rand_ctrl
 from .common import (
     DistributedOperatorImplContainer,
     register_distributed_operator_impl,
...
@@ -18,10 +18,10 @@ import paddle
 from paddle.framework import core
 from paddle.utils import unique_name
-from ...utils.log_utils import get_logger
+from ....utils.log_utils import get_logger
 _logger = get_logger(logging.INFO)
-from ..random import determinate_rng, is_enable_auto_rand_ctrl
+from ...random import determinate_rng, is_enable_auto_rand_ctrl
 from ..utils import (
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
     set_var_dist_attr,
...
@@ -15,7 +15,7 @@
 import copy
 from paddle.common_ops_import import check_dtype, check_variable_and_dtype
-from paddle.distributed.auto_parallel.cost.comm_op_cost import (
+from paddle.distributed.auto_parallel.static.cost.comm_op_cost import (
     AllreduceSumOpCost,
     IdentityOpCost,
 )
...
@@ -20,10 +20,10 @@ from paddle.distributed.passes import PassManager, new_pass
 from paddle.static import append_backward, program_guard
 from paddle.utils import unique_name
-from ..utils.log_utils import get_logger
+from ...utils.log_utils import get_logger
+from ..random import init_auto_parallel_rng
 from .partitioner import Partitioner
 from .process_group import get_world_process_group
-from .random import init_auto_parallel_rng
 from .reshard import Resharder
 from .utils import set_grad_var_shape
...
@@ -15,8 +15,10 @@
 import copy
 import paddle
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.operators.common import (
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.operators.common import (
     get_distributed_operator_impl_container,
 )
 from paddle.framework import Program, core
...
@@ -18,15 +18,17 @@ import pickle
 import numpy as np
-from paddle.distributed.auto_parallel.dist_attribute import (
+from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.distributed.auto_parallel.static.dist_attribute import (
     OperatorDistAttr,
     TensorDistAttr,
 )
-from paddle.distributed.auto_parallel.dist_op import DistributedOperator
-from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
-from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
-from ..utils.log_utils import get_logger
+from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
+from paddle.distributed.auto_parallel.static.dist_tensor import (
+    DistributedTensor,
+)
+from ...utils.log_utils import get_logger
 from .completion import Completer
 from .dist_context import get_default_distributed_context
 from .tuner.parallel_tuner import ParallelTuner
...
@@ -17,8 +17,8 @@ from collections import OrderedDict
 import paddle
 from paddle.framework import core
-from ..collective import _get_global_env, _new_ring_id
-from ..utils.log_utils import get_logger
+from ...collective import _get_global_env, _new_ring_id
+from ...utils.log_utils import get_logger
 from .utils import dygraph_guard
 logger = get_logger("INFO", __name__)
...
@@ -15,7 +15,7 @@
 import copy
 import os
-from ..strategy import Strategy
+from ...strategy import Strategy
 _tuning_supported_passes = ["sharding", "recompute"]
...
@@ -27,16 +27,18 @@ import sys
 import time
 import paddle
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.process_group import (
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.process_group import (
     clear_all_process_groups,
     get_all_process_groups,
     new_process_group,
 )
-from paddle.distributed.auto_parallel.reshard import Resharder
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.reshard import Resharder
+from paddle.distributed.auto_parallel.static.utils import (
     debug_program,
     set_grad_var_shape,
 )
@@ -465,7 +467,7 @@ class OptimizationTuner:
             ]
         )
         cmd_args = (
-            "-m paddle.distributed.auto_parallel.tuner.profiler"
+            "-m paddle.distributed.auto_parallel.static.tuner.profiler"
             + " "
             + profile_args
         )
...
@@ -21,13 +21,13 @@ from collections import defaultdict
 import numpy as np
+from ...process_mesh import ProcessMesh
 from ..completion import Completer
 from ..cost import CostEstimator
 from ..dist_context import _node_id
 from ..dist_op import DistributedOperator
 from ..operators.common import find_compatible_distributed_operator_impls
 from ..parallelizer_v2 import Parallelizer
-from ..process_mesh import ProcessMesh
 from .trial import Trial, TrialStatus
 from .tunable_space import TunableSpace
 from .tunable_variable import Boolean, IntRange
...
@@ -21,10 +21,10 @@ import time
 import traceback
 import paddle
-from paddle.distributed.auto_parallel.dist_loader import (
+from paddle.distributed.auto_parallel.static.dist_loader import (
     DistributedDataLoaderFromGenerator,
 )
-from paddle.distributed.auto_parallel.process_group import (
+from paddle.distributed.auto_parallel.static.process_group import (
     get_all_process_groups,
     new_process_group,
 )
...
@@ -26,20 +26,24 @@ from functools import reduce
 import numpy as np
 import paddle
-from paddle.distributed.auto_parallel.cluster_v2 import DeviceMesh
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.cost import CostEstimator
-from paddle.distributed.auto_parallel.dist_attribute import (
+from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.distributed.auto_parallel.static.cluster_v2 import DeviceMesh
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.cost import CostEstimator
+from paddle.distributed.auto_parallel.static.dist_attribute import (
     OperatorDistAttr,
     TensorDistAttr,
 )
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
-from paddle.distributed.auto_parallel.process_group import (
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.dist_tensor import (
+    DistributedTensor,
+)
+from paddle.distributed.auto_parallel.static.process_group import (
     get_world_process_group,
 )
-from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     is_gradient_clip_op,
     print_program_with_dist_attr,
 )
@@ -48,7 +52,7 @@ from paddle.fluid import program_guard
 from paddle.fluid.backward import append_backward
 from paddle.fluid.framework import Parameter, unique_name
-from ...utils.log_utils import get_logger
+from ....utils.log_utils import get_logger
 from ..graph import Graph
 _PATTERNS = {}
...
@@ -27,8 +27,8 @@ from paddle.framework import core
 from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
 from paddle.static import Variable
+from ..process_mesh import ProcessMesh
 from .dist_attribute import OperatorDistAttr, TensorDistAttr
-from .process_mesh import ProcessMesh
 OpRole = core.op_proto_and_checker_maker.OpRole
 OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
@@ -1868,7 +1868,7 @@ def get_lr(optimizer):
 def initialize_pg_in_full_mode(all_process_groups, cur_rank):
     import socket
-    from ..collective import _get_global_env
+    from ...collective import _get_global_env
     has_recv_by_socket = []
     # This is a magic number
@@ -1946,7 +1946,7 @@ def is_recompute_op(op):
 def set_recompute_segments(model, losses, strategy, program):
-    from ..passes.auto_parallel_recompute import RecomputeState
+    from ...passes.auto_parallel_recompute import RecomputeState
     if not losses:
         return
@@ -2054,7 +2054,7 @@ def validate_opt(optimizer):
 def set_data_parallel(x):
-    from .interface import ProcessMesh, shard_tensor
+    from ..interface import ProcessMesh, shard_tensor
     from .process_group import get_world_process_group
     world_ranks = get_world_process_group().ranks
@@ -2095,7 +2095,7 @@ def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
 def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
-    from .process_mesh import ProcessMesh
+    from ..process_mesh import ProcessMesh
     cpp_process_mesh = cpp_dist_attr.process_mesh
     if cpp_process_mesh is not None:
@@ -2128,7 +2128,7 @@ def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
 def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
-    from .process_mesh import ProcessMesh
+    from ..process_mesh import ProcessMesh
     cpp_process_mesh = cpp_dist_attr.process_mesh
     if cpp_process_mesh is not None:
...
@@ -1335,7 +1335,7 @@ class Fleet:
             self._user_defined_strategy.semi_auto
             or self._user_defined_strategy.auto_search
         ):
-            from ..auto_parallel.parallelizer import AutoParallelizer
+            from ..auto_parallel.static.parallelizer import AutoParallelizer
             auto_parallelizer = AutoParallelizer(self)
             (
...
@@ -13,11 +13,13 @@
 # limitations under the License.
 import paddle
-from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
-from paddle.distributed.auto_parallel.process_group import (
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    OperatorDistAttr,
+)
+from paddle.distributed.auto_parallel.static.process_group import (
     get_world_process_group,
 )
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
     set_var_dist_attr,
 )
@@ -42,7 +44,7 @@ from paddle.static.amp.fp16_utils import (
 from paddle.utils import unique_name
 from ..auto_parallel.process_mesh import ProcessMesh
-from ..auto_parallel.utils import (
+from ..auto_parallel.static.utils import (
     is_backward_op,
     is_forward_op,
     is_loss_grad_op,
...
@@ -15,16 +15,16 @@
 from collections import OrderedDict
 import paddle
-from paddle.distributed.auto_parallel.dist_attribute import (
+from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.distributed.auto_parallel.static.dist_attribute import (
     OperatorDistAttr,
     TensorDistAttr,
 )
-from paddle.distributed.auto_parallel.operators.common import (
+from paddle.distributed.auto_parallel.static.operators.common import (
     is_data_parallel_reduce_op,
     is_data_parallel_scale_op,
 )
-from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     find_higher_order_backward_op,
     get_var_numel,
     insert_dependencies_for_vars,
...
@@ -16,11 +16,13 @@ from collections import defaultdict
 import paddle
 from paddle.common_ops_import import check_type, check_variable_and_dtype
-from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
-from paddle.distributed.auto_parallel.process_group import (
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    OperatorDistAttr,
+)
+from paddle.distributed.auto_parallel.static.process_group import (
     get_world_process_group,
 )
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     is_backward_op,
     is_forward_op,
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
...
@@ -19,18 +19,21 @@ import numpy as np
 import paddle
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
-from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
-from ..auto_parallel.operators.common import (
+from ..auto_parallel.process_mesh import ProcessMesh
+from ..auto_parallel.static.dist_attribute import (
+    OperatorDistAttr,
+    TensorDistAttr,
+)
+from ..auto_parallel.static.operators.common import (
     SyncMode,
     is_data_parallel_reduce_op,
 )
-from ..auto_parallel.process_group import (
+from ..auto_parallel.static.process_group import (
     get_all_process_groups,
     get_world_process_group,
 )
-from ..auto_parallel.process_mesh import ProcessMesh
-from ..auto_parallel.reshard import Resharder
-from ..auto_parallel.utils import (
+from ..auto_parallel.static.reshard import Resharder
+from ..auto_parallel.static.utils import (
     _get_comm_group,
     insert_dependencies_for_vars,
     is_gradient_clip_op,
...
@@ -15,11 +15,11 @@
 from typing import Any, Dict, List, Tuple
 import paddle
-from paddle.distributed.auto_parallel.process_group import (
+from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.distributed.auto_parallel.static.process_group import (
     get_world_process_group,
 )
-from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     is_optimize_op,
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
     set_var_dist_attr,
...
@@ -26,8 +26,11 @@ from paddle.static.quantization import (
     quant_config,
 )
-from ..auto_parallel.converter import Converter
-from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
+from ..auto_parallel.static.converter import Converter
+from ..auto_parallel.static.dist_attribute import (
+    OperatorDistAttr,
+    TensorDistAttr,
+)
 from .pass_base import PassBase, register_pass
 TRANSFORM_PASS_OP_TYPES = list(
...
@@ -26,8 +26,8 @@ from paddle.fluid.backward import (
 from paddle.framework import core
 from paddle.utils import unique_name
-from ..auto_parallel.dist_attribute import OperatorDistAttr
-from ..auto_parallel.utils import (
+from ..auto_parallel.static.dist_attribute import OperatorDistAttr
+from ..auto_parallel.static.utils import (
     get_loss_op,
     insert_dependencies_for_two_ops,
     is_backward_op,
...
@@ -16,13 +16,15 @@ import logging
 from functools import reduce
 import paddle
-from paddle.distributed.auto_parallel.operators.common import (
+from paddle.distributed.auto_parallel.static.operators.common import (
     ParallelMode,
     is_data_parallel_reduce_op,
     is_parameter_related,
 )
-from paddle.distributed.auto_parallel.process_group import new_process_group
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.process_group import (
+    new_process_group,
+)
+from paddle.distributed.auto_parallel.static.utils import (
     _get_comm_group,
     get_logger,
     get_var_numel,
...
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from paddle.distributed.auto_parallel.operators.common import (
+from paddle.distributed.auto_parallel.static.operators.common import (
     is_amp_flag_sync_op,
     is_data_parallel_reduce_op,
     is_global_norm_sync_op,
 )
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     OpRole,
     insert_dependencies_for_vars,
 )
...
@@ -1439,7 +1439,7 @@ def _append_backward_ops_(
             )
         else:
             default_ctx = getattr(
-                paddle.distributed.auto_parallel.dist_context,
+                paddle.distributed.auto_parallel.static.dist_context,
                 '_g_default_distributed_context',
                 None,
             )
...
@@ -1681,7 +1681,7 @@ class Variable(metaclass=VariableMetaClass):
         if self.persistable:
             var_str = "persist " + var_str
-        from paddle.distributed.auto_parallel.dist_context import (
+        from paddle.distributed.auto_parallel.static.dist_context import (
             get_default_distributed_context,
         )
@@ -3137,7 +3137,7 @@ class Operator:
             if i != len(attr_names) - 1:
                 attrs_str += ", "
-        from paddle.distributed.auto_parallel.dist_context import (
+        from paddle.distributed.auto_parallel.static.dist_context import (
             get_default_distributed_context,
         )
...
@@ -22,10 +22,10 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.dist_context import (
+from paddle.distributed.auto_parallel.static.dist_context import (
     set_default_distributed_context,
 )
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     get_dist_attr,
     load_checkpoint_into_program,
     load_distributed_checkpoint,
...
@@ -23,7 +23,7 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.utils import (
     load_checkpoint_into_program,
     save_distributed_checkpoint,
 )
...
@@ -25,7 +25,7 @@ import numpy as np
 import paddle
 from paddle import distributed as dist
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel import engine
+from paddle.distributed.auto_parallel.static import engine
 from paddle.distributed.fleet.layers.mpu.mp_layers import (
     ColumnParallelLinear,
     RowParallelLinear,
...
@@ -17,7 +17,7 @@ import os
 import tempfile
 import unittest
-from paddle.distributed.auto_parallel.cluster import (
+from paddle.distributed.auto_parallel.static.cluster import (
     Cluster,
     DeviceType,
     LinkType,
...
@@ -18,8 +18,10 @@ import unittest.mock
 import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, tensor, utils
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
 from paddle.distributed.fleet import auto
 paddle.enable_static()
@@ -188,7 +190,7 @@ class TestMLPAutoCompletion(unittest.TestCase):
         # # dist_context)
         # dist_context.finalize_distributed_attr_for_program(
         # complete_train_program)
-        # from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
+        # from paddle.distributed.auto_parallel.static.interface import _g_process_mesh_map
         # for block in complete_train_program.blocks:
         # for tensor in block.vars.values():
         # desc = tensor.desc
...
@@ -18,8 +18,10 @@ import unittest
 import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, tensor, utils
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
 from paddle.distributed.fleet import auto
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
...
@@ -18,12 +18,16 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.cost_model import estimate_cost
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
-from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import Resharder
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.cost_model import estimate_cost
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.parallelizer import (
+    AutoParallelizer,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.reshard import Resharder
 from paddle.distributed.fleet import auto
 from paddle.fluid import core
...
@@ -20,12 +20,20 @@ from test_auto_parallel_reshard import mlp_forward
 import paddle
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.dist_attribute import TensorDistAttr
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
-from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
-from paddle.distributed.auto_parallel.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    TensorDistAttr,
+)
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.dist_tensor import (
+    DistributedTensor,
+)
+from paddle.distributed.auto_parallel.static.parallelizer import (
+    AutoParallelizer,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
 from paddle.distributed.fleet import auto
...
@@ -14,7 +14,7 @@
 import unittest
-from paddle.distributed.auto_parallel.graph import Graph
+from paddle.distributed.auto_parallel.static.graph import Graph
 class TestAutoParallelGraph(unittest.TestCase):
...
@@ -23,17 +23,21 @@ import paddle
 import paddle.nn.functional as F
 from paddle import fluid, nn, static, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.cluster import Cluster
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.mapper import (
+from paddle.distributed.auto_parallel.static.cluster import Cluster
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.mapper import (
     get_comm_volume,
     get_dtype_bytes,
     mapping,
 )
-from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
-from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import Resharder
+from paddle.distributed.auto_parallel.static.parallelizer import (
+    AutoParallelizer,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.reshard import Resharder
 from paddle.distributed.fleet import auto
 from paddle.fluid import core
...
@@ -19,11 +19,15 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, tensor, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.process_group import new_process_group
-from paddle.distributed.auto_parallel.utils import _get_comm_group
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.process_group import (
+    new_process_group,
+)
+from paddle.distributed.auto_parallel.static.utils import _get_comm_group
 from paddle.distributed.fleet import auto
 paddle.enable_static()
...
@@ -18,11 +18,15 @@ import unittest
 import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, tensor, utils
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
-from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.process_group import new_process_group
-from paddle.distributed.auto_parallel.utils import _get_comm_group
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.parallelizer import (
+    AutoParallelizer,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.process_group import (
+    new_process_group,
+)
+from paddle.distributed.auto_parallel.static.utils import _get_comm_group
 from paddle.distributed.fleet import auto
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
...
@@ -18,15 +18,19 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
-from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.process_group import (
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.parallelizer import (
+    AutoParallelizer,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.process_group import (
     ProcessGroup,
     _g_process_group_map,
 )
-from paddle.distributed.auto_parallel.reshard import Resharder
+from paddle.distributed.auto_parallel.static.reshard import Resharder
 from paddle.distributed.fleet import auto
 paddle.enable_static()
...
@@ -18,11 +18,15 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
-from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import Resharder
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.parallelizer import (
+    AutoParallelizer,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.reshard import Resharder
 from paddle.distributed.fleet import auto
 paddle.enable_static()
...
@@ -18,13 +18,17 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.cluster import Cluster
-from paddle.distributed.auto_parallel.completion import Completer
-from paddle.distributed.auto_parallel.cost import CostEstimator
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
-from paddle.distributed.auto_parallel.partitioner import Partitioner
-from paddle.distributed.auto_parallel.reshard import Resharder
+from paddle.distributed.auto_parallel.static.cluster import Cluster
+from paddle.distributed.auto_parallel.static.completion import Completer
+from paddle.distributed.auto_parallel.static.cost import CostEstimator
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.parallelizer import (
+    AutoParallelizer,
+)
+from paddle.distributed.auto_parallel.static.partitioner import Partitioner
+from paddle.distributed.auto_parallel.static.reshard import Resharder
 from paddle.distributed.fleet import auto
 paddle.enable_static()
...
@@ -22,7 +22,7 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
 from paddle.distributed import fleet
-from paddle.distributed.auto_parallel.dist_context import (
+from paddle.distributed.auto_parallel.static.dist_context import (
     get_default_distributed_context,
 )
 from paddle.distributed.fleet import auto
@@ -80,6 +80,7 @@ class MLPLayer(nn.Layer):
 def mlp_forward(train_program, start_program):
+    print("mlp_forward outer", flush=True)
     with static.program_guard(
         train_program, start_program
     ), utils.unique_name.guard():
@@ -99,6 +100,7 @@ def mlp_forward(train_program, start_program):
         elif _global_parallel_strategy == "dp":
             auto.shard_tensor(input, _global_process_mesh, ["x", None])
         else:
+            print("mlp_forward inner", flush=True)
             auto.shard_tensor(input, _global_process_mesh, [None, None])
         mlp = MLPLayer(
@@ -128,10 +130,14 @@ def get_dist_prog_with_parallelizer(
     dist_strategy.semi_auto = True
     fleet.init(is_collective=True, strategy=dist_strategy)
+    print("mlp_forward before", flush=True)
     loss, train_program, startup_program = mlp_forward(
         train_program, startup_program
     )
+    print("mlp_forward after", flush=True)
     optimizer = paddle.fluid.optimizer.AdamOptimizer(
         learning_rate=0.00001,
         beta1=0.9,
@@ -185,6 +191,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
 )
 class TestMLPReshard(unittest.TestCase):
     def test_mlp_serial(self):
+        print("################-0")
         global _global_parallel_strategy
         _global_parallel_strategy = None
         global _global_process_mesh
...
@@ -17,13 +17,15 @@ import unittest
 import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
-from paddle.distributed.auto_parallel.dist_attribute import (
+from paddle.distributed.auto_parallel.static.dist_attribute import (
     OperatorDistAttr,
     TensorDistAttr,
 )
-from paddle.distributed.auto_parallel.dist_context import DistributedContext
-from paddle.distributed.auto_parallel.planner import PlanSpace
-from paddle.distributed.auto_parallel.utils import (
+from paddle.distributed.auto_parallel.static.dist_context import (
+    DistributedContext,
+)
+from paddle.distributed.auto_parallel.static.planner import PlanSpace
+from paddle.distributed.auto_parallel.static.utils import (
     update_op_dims_mapping_by_default_dist_impl,
     update_op_dims_mapping_by_elementwise_like_dist_impl,
 )
@@ -177,8 +179,10 @@ class TestMLPSearcher(unittest.TestCase):
         set_default_dist_attr(train_program, dist_context, global_process_mesh)
         ops = train_program.global_block().ops
         vars = train_program.global_block().vars
-        from paddle.distributed.auto_parallel.dist_op import DistributedOperator
-        from paddle.distributed.auto_parallel.operators.common import (
+        from paddle.distributed.auto_parallel.static.dist_op import (
+            DistributedOperator,
+        )
+        from paddle.distributed.auto_parallel.static.operators.common import (
             get_distributed_operator_impl_container,
             is_elementwise_op,
         )
...
@@ -16,9 +16,11 @@ import unittest
 import paddle
 import paddle.nn.functional as F
 from paddle import nn, static, utils
-from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
-from paddle.distributed.auto_parallel.dist_op import DistributedOperator
-from paddle.distributed.auto_parallel.operators.common import (
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    OperatorDistAttr,
+)
+from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
+from paddle.distributed.auto_parallel.static.operators.common import (
     get_distributed_operator_impl_container,
 )
 from paddle.framework import core
...
...@@ -16,9 +16,11 @@ import unittest ...@@ -16,9 +16,11 @@ import unittest
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, utils from paddle import nn, static, utils
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr from paddle.distributed.auto_parallel.static.dist_attribute import (
from paddle.distributed.auto_parallel.dist_op import DistributedOperator OperatorDistAttr,
from paddle.distributed.auto_parallel.operators.common import ( )
from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.static.operators.common import (
get_distributed_operator_impl_container, get_distributed_operator_impl_container,
) )
from paddle.fluid import core from paddle.fluid import core
......
...@@ -426,9 +426,11 @@ packages=['paddle', ...@@ -426,9 +426,11 @@ packages=['paddle',
'paddle.distributed.fleet.meta_parallel.sharding', 'paddle.distributed.fleet.meta_parallel.sharding',
'paddle.distributed.fleet.meta_parallel.parallel_layers', 'paddle.distributed.fleet.meta_parallel.parallel_layers',
'paddle.distributed.auto_parallel', 'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel.operators', 'paddle.distributed.auto_parallel.dygraph',
'paddle.distributed.auto_parallel.tuner', 'paddle.distributed.auto_parallel.static',
'paddle.distributed.auto_parallel.cost', 'paddle.distributed.auto_parallel.static.operators',
'paddle.distributed.auto_parallel.static.tuner',
'paddle.distributed.auto_parallel.static.cost',
'paddle.distributed.passes', 'paddle.distributed.passes',
'paddle.distributed.models', 'paddle.distributed.models',
'paddle.distributed.models.moe', 'paddle.distributed.models.moe',
......
...@@ -1430,9 +1430,11 @@ def get_setup_parameters(): ...@@ -1430,9 +1430,11 @@ def get_setup_parameters():
'paddle.distributed.fleet.meta_parallel.sharding', 'paddle.distributed.fleet.meta_parallel.sharding',
'paddle.distributed.fleet.meta_parallel.parallel_layers', 'paddle.distributed.fleet.meta_parallel.parallel_layers',
'paddle.distributed.auto_parallel', 'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel.operators', 'paddle.distributed.auto_parallel.dygraph',
'paddle.distributed.auto_parallel.tuner', 'paddle.distributed.auto_parallel.static',
'paddle.distributed.auto_parallel.cost', 'paddle.distributed.auto_parallel.static.operators',
'paddle.distributed.auto_parallel.static.tuner',
'paddle.distributed.auto_parallel.static.cost',
'paddle.distributed.passes', 'paddle.distributed.passes',
'paddle.distributed.models', 'paddle.distributed.models',
'paddle.distributed.models.moe', 'paddle.distributed.models.moe',
......
...@@ -120,7 +120,10 @@ class TestShardingStage2WithNewEXE(unittest.TestCase): ...@@ -120,7 +120,10 @@ class TestShardingStage2WithNewEXE(unittest.TestCase):
# bf16 # bf16
mp_bf16_engine = self.get_engine(use_amp=True) mp_bf16_engine = self.get_engine(use_amp=True)
if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000: if not (
paddle.amp.is_bfloat16_supported()
and paddle.device.cuda.get_device_capability()[0] >= 8
):
return return
mp_bf16_history = mp_bf16_engine.fit( mp_bf16_history = mp_bf16_engine.fit(
......
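For context on the changed guard in the hunk above: the bf16 case is now skipped unless the runtime reports bfloat16 support and the GPU is compute capability 8.0 or newer, instead of checking the CUDA toolkit version. A minimal sketch of that guard, using only the paddle calls already visible in the hunk (the helper name is illustrative, not part of the PR):

    import paddle

    def bf16_case_supported():
        # Run the bf16 test only when bfloat16 kernels are available
        # and the GPU is Ampere (compute capability >= 8.0) or newer.
        return (
            paddle.amp.is_bfloat16_supported()
            and paddle.device.cuda.get_device_capability()[0] >= 8
        )

    # Usage mirroring the test body above:
    # if not bf16_case_supported():
    #     return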
...@@ -20,7 +20,7 @@ import paddle ...@@ -20,7 +20,7 @@ import paddle
from paddle import static from paddle import static
from paddle.distributed import fleet from paddle.distributed import fleet
sys.path.append("..") sys.path.append("../legacy_test")
import auto_parallel_gpt_model as modeling import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import ( from auto_parallel_gpt_model import (
GPTForPretraining, GPTForPretraining,
...@@ -151,7 +151,7 @@ def train(): ...@@ -151,7 +151,7 @@ def train():
}, },
fetch_list=[loss], fetch_list=[loss],
) )
print(f"step: {step}, loss: {loss_print[0]:f}") print(f"step: {step}, loss: {loss_print:f}")
else: else:
exe.run( exe.run(
distributed_main_program, distributed_main_program,
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
import paddle import paddle
from paddle import static from paddle import static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.cost import CostEstimator from paddle.distributed.auto_parallel.static.cost import CostEstimator
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context, get_default_distributed_context,
) )
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import numpy as np import numpy as np
import paddle import paddle
from paddle.distributed.auto_parallel.converter import Converter from paddle.distributed.auto_parallel.static.converter import Converter
def test_convert(): def test_convert():
......
...@@ -20,7 +20,9 @@ import numpy as np ...@@ -20,7 +20,9 @@ import numpy as np
import paddle import paddle
from paddle import fluid, nn, optimizer, static from paddle import fluid, nn, optimizer, static
from paddle.distributed.auto_parallel.auto_align_tool import AutoAlignTool from paddle.distributed.auto_parallel.static.auto_align_tool import (
AutoAlignTool,
)
from paddle.vision.datasets import MNIST from paddle.vision.datasets import MNIST
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
......
...@@ -23,21 +23,25 @@ import paddle ...@@ -23,21 +23,25 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, utils from paddle import nn, static, utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.cost import ( from paddle.distributed.auto_parallel.static.cost import (
AllreduceSumOpCost, AllreduceSumOpCost,
_g_op_cost_factory, _g_op_cost_factory,
) )
from paddle.distributed.auto_parallel.cost.base_cost import ( from paddle.distributed.auto_parallel.static.cost.base_cost import (
build_comm_costs_from_descs, build_comm_costs_from_descs,
build_comm_desc_from_dist_op, build_comm_desc_from_dist_op,
build_comp_costs_from_descs, build_comp_costs_from_descs,
build_comp_desc_from_dist_op, build_comp_desc_from_dist_op,
build_dp_costs, build_dp_costs,
) )
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer DistributedContext,
)
from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
paddle.enable_static() paddle.enable_static()
......
...@@ -17,7 +17,7 @@ import os ...@@ -17,7 +17,7 @@ import os
import tempfile import tempfile
import unittest import unittest
from paddle.distributed.auto_parallel.cluster import ( from paddle.distributed.auto_parallel.static.cluster import (
Cluster, Cluster,
get_default_cluster, get_default_cluster,
) )
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
class TestClusterPartition(unittest.TestCase): class TestClusterPartition(unittest.TestCase):
def test_cluster_partition(self): def test_cluster_partition(self):
clusters = [(5, 8), (1, 8), (4, 8), (16, 8), (2, 8), (3, 8)] clusters = [(5, 8), (1, 8), (4, 8), (16, 8), (2, 8), (3, 8)]
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( from paddle.distributed.auto_parallel.static.tuner.rule_based_tuner import (
ClusterPartitionUtil, ClusterPartitionUtil,
) )
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
from paddle.distributed.auto_parallel.cluster_v2 import DeviceMesh from paddle.distributed.auto_parallel.static.cluster_v2 import DeviceMesh
from paddle.framework import core from paddle.framework import core
......
...@@ -20,8 +20,8 @@ import unittest ...@@ -20,8 +20,8 @@ import unittest
from test_cluster import cluster_json, multi_cluster_json from test_cluster import cluster_json, multi_cluster_json
import paddle import paddle
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.cost import ( from paddle.distributed.auto_parallel.static.cost import (
AllgatherOpCost, AllgatherOpCost,
AllreduceSumOpCost, AllreduceSumOpCost,
BroadcastOpCost, BroadcastOpCost,
......
...@@ -18,8 +18,8 @@ import unittest ...@@ -18,8 +18,8 @@ import unittest
from test_cluster import cluster_json from test_cluster import cluster_json
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.cost.comp_op_cost import ( from paddle.distributed.auto_parallel.static.cost.comp_op_cost import (
AssignOpCost, AssignOpCost,
AssignValueOpCost, AssignValueOpCost,
BeamSearchDecodeOpCost, BeamSearchDecodeOpCost,
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
class TestConvertToProcessMeshes(unittest.TestCase): class TestConvertToProcessMeshes(unittest.TestCase):
def test_convert_to_process_meshes(self): def test_convert_to_process_meshes(self):
device_meshes = [[1, 8], [4, 8], [15, 8]] device_meshes = [[1, 8], [4, 8], [15, 8]]
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( from paddle.distributed.auto_parallel.static.tuner.rule_based_tuner import (
convert_to_process_meshes, convert_to_process_meshes,
) )
......
...@@ -18,7 +18,7 @@ import sys ...@@ -18,7 +18,7 @@ import sys
import tempfile import tempfile
import unittest import unittest
from paddle.distributed.auto_parallel.converter import Converter from paddle.distributed.auto_parallel.static.converter import Converter
class TestConverter(unittest.TestCase): class TestConverter(unittest.TestCase):
......
...@@ -38,9 +38,11 @@ def make_program(): ...@@ -38,9 +38,11 @@ def make_program():
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program = program_func() main_program, start_program = program_func()
......
...@@ -21,12 +21,12 @@ import paddle ...@@ -21,12 +21,12 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static from paddle import nn, static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
set_default_distributed_context, set_default_distributed_context,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.static.utils import (
from paddle.distributed.auto_parallel.utils import (
_copy_dist_attr_from_cpp, _copy_dist_attr_from_cpp,
_copy_dist_attr_from_cpp_for_graph, _copy_dist_attr_from_cpp_for_graph,
_copy_dist_attr_to_cpp, _copy_dist_attr_to_cpp,
......
...@@ -21,7 +21,9 @@ import paddle ...@@ -21,7 +21,9 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static from paddle import nn, static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
paddle.enable_static() paddle.enable_static()
......
...@@ -103,9 +103,11 @@ def matmulv2_dp2mp2(init_x, init_y, trans_x, trans_y): ...@@ -103,9 +103,11 @@ def matmulv2_dp2mp2(init_x, init_y, trans_x, trans_y):
def parallelizer(program_func, *args, **kwargs): def parallelizer(program_func, *args, **kwargs):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program, loss = program_func(*args, **kwargs) main_program, start_program, loss = program_func(*args, **kwargs)
......
...@@ -16,8 +16,8 @@ import copy ...@@ -16,8 +16,8 @@ import copy
import unittest import unittest
import paddle import paddle
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.operators.common import ( from paddle.distributed.auto_parallel.static.operators.common import (
get_distributed_operator_impl_container, get_distributed_operator_impl_container,
is_elementwise_op, is_elementwise_op,
) )
...@@ -29,8 +29,10 @@ paddle.enable_static() ...@@ -29,8 +29,10 @@ paddle.enable_static()
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
main_program, startup_program, loss = program_func() main_program, startup_program, loss = program_func()
......
...@@ -75,9 +75,11 @@ def make_program_serial(): ...@@ -75,9 +75,11 @@ def make_program_serial():
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program, loss = program_func() main_program, start_program, loss = program_func()
......
...@@ -37,9 +37,11 @@ def make_program_dp2(): ...@@ -37,9 +37,11 @@ def make_program_dp2():
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program = program_func() main_program, start_program = program_func()
......
...@@ -34,9 +34,11 @@ def make_program(): ...@@ -34,9 +34,11 @@ def make_program():
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program = program_func() main_program, start_program = program_func()
......
...@@ -34,9 +34,11 @@ def make_program(): ...@@ -34,9 +34,11 @@ def make_program():
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program = program_func() main_program, start_program = program_func()
......
...@@ -56,9 +56,11 @@ def make_program_serial(): ...@@ -56,9 +56,11 @@ def make_program_serial():
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program = program_func() main_program, start_program = program_func()
......
...@@ -34,9 +34,11 @@ def make_program_dp2(): ...@@ -34,9 +34,11 @@ def make_program_dp2():
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program = program_func() main_program, start_program = program_func()
......
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
import paddle import paddle
import paddle.vision.transforms as T import paddle.vision.transforms as T
from paddle.distributed.auto_parallel.callbacks import config_callbacks from paddle.distributed.auto_parallel.static.callbacks import config_callbacks
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.vision.datasets import MNIST from paddle.vision.datasets import MNIST
......
...@@ -64,9 +64,11 @@ def make_program(): ...@@ -64,9 +64,11 @@ def make_program():
def parallelizer(program_func, rank): def parallelizer(program_func, rank):
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
main_program, start_program = program_func() main_program, start_program = program_func()
......
...@@ -112,10 +112,10 @@ class TestGroupOperators(unittest.TestCase): ...@@ -112,10 +112,10 @@ class TestGroupOperators(unittest.TestCase):
sequence_len, sequence_len,
vocab_size, vocab_size,
) )
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
) )
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( from paddle.distributed.auto_parallel.static.tuner.rule_based_tuner import (
RuleBasedTuner, RuleBasedTuner,
) )
......
...@@ -17,10 +17,10 @@ import unittest ...@@ -17,10 +17,10 @@ import unittest
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static from paddle import nn, static
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context, get_default_distributed_context,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
paddle.enable_static() paddle.enable_static()
......
...@@ -20,10 +20,10 @@ import unittest ...@@ -20,10 +20,10 @@ import unittest
from test_cluster import cluster_json from test_cluster import cluster_json
import paddle import paddle
import paddle.distributed.auto_parallel.cost as cost_model import paddle.distributed.auto_parallel.static.cost as cost_model
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.cost import CommContext from paddle.distributed.auto_parallel.static.cost import CommContext
from paddle.distributed.auto_parallel.cost.base_cost import ( from paddle.distributed.auto_parallel.static.cost.base_cost import (
build_comp_desc_from_op, build_comp_desc_from_op,
build_comp_desc_str_for_predict, build_comp_desc_str_for_predict,
calc_time_by_modeling, calc_time_by_modeling,
......
...@@ -18,13 +18,15 @@ import unittest ...@@ -18,13 +18,15 @@ import unittest
import paddle import paddle
from paddle import static from paddle import static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
set_default_distributed_context, set_default_distributed_context,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.static.tuner.parallel_tuner import (
from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner ParallelTuner,
)
sys.path.append("../legacy_test") sys.path.append("../legacy_test")
import auto_parallel_gpt_model as modeling import auto_parallel_gpt_model as modeling
......
...@@ -18,15 +18,17 @@ import unittest ...@@ -18,15 +18,17 @@ import unittest
import paddle import paddle
from paddle import static from paddle import static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
set_default_distributed_context, set_default_distributed_context,
) )
from paddle.distributed.auto_parallel.planner_v2 import Planner from paddle.distributed.auto_parallel.static.planner_v2 import Planner
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.static.tuner.parallel_tuner import (
ParallelTuner,
)
from paddle.distributed.auto_parallel.strategy import Strategy from paddle.distributed.auto_parallel.strategy import Strategy
from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner
sys.path.append("../legacy_test") sys.path.append("../legacy_test")
import auto_parallel_gpt_model as modeling import auto_parallel_gpt_model as modeling
......
...@@ -18,13 +18,15 @@ import unittest ...@@ -18,13 +18,15 @@ import unittest
import paddle import paddle
from paddle import static from paddle import static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
set_default_distributed_context, set_default_distributed_context,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.static.tuner.parallel_tuner import (
from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner ParallelTuner,
)
sys.path.append("../legacy_test") sys.path.append("../legacy_test")
import auto_parallel_gpt_model as modeling import auto_parallel_gpt_model as modeling
......
...@@ -112,7 +112,7 @@ class TestGroupOperatorsAndPatterns(unittest.TestCase): ...@@ -112,7 +112,7 @@ class TestGroupOperatorsAndPatterns(unittest.TestCase):
sequence_len, sequence_len,
vocab_size, vocab_size,
) )
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( from paddle.distributed.auto_parallel.static.tuner.rule_based_tuner import (
_PATTERNS, _PATTERNS,
GraphUtil, GraphUtil,
) )
......
...@@ -112,10 +112,10 @@ class TestPatternMatch(unittest.TestCase): ...@@ -112,10 +112,10 @@ class TestPatternMatch(unittest.TestCase):
sequence_len, sequence_len,
vocab_size, vocab_size,
) )
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
) )
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( from paddle.distributed.auto_parallel.static.tuner.rule_based_tuner import (
GraphUtil, GraphUtil,
RuleBasedTuner, RuleBasedTuner,
) )
......
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
import unittest import unittest
import paddle import paddle
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
get_default_distributed_context, get_default_distributed_context,
) )
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import set_var_dist_attr from paddle.distributed.auto_parallel.static.utils import set_var_dist_attr
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.incubate.autograd import enable_prim from paddle.incubate.autograd import enable_prim
......
...@@ -19,14 +19,14 @@ import numpy as np ...@@ -19,14 +19,14 @@ import numpy as np
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static from paddle import nn, static
from paddle.distributed.auto_parallel.dist_context import (
get_default_distributed_context,
)
from paddle.distributed.auto_parallel.process_mesh import ( from paddle.distributed.auto_parallel.process_mesh import (
ProcessMesh, ProcessMesh,
compute_compatible_process_mesh, compute_compatible_process_mesh,
merge_process_meshes, merge_process_meshes,
) )
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
paddle.enable_static() paddle.enable_static()
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
from paddle.distributed.auto_parallel.process_mesh_v2 import ( from paddle.distributed.auto_parallel.static.process_mesh_v2 import (
ProcessMesh, ProcessMesh,
compute_compatible_process_mesh, compute_compatible_process_mesh,
merge_process_mesh, merge_process_mesh,
......
...@@ -16,7 +16,7 @@ import unittest ...@@ -16,7 +16,7 @@ import unittest
import numpy as np import numpy as np
from paddle.distributed.auto_parallel.tuner import recorder as rd from paddle.distributed.auto_parallel.static.tuner import recorder as rd
class TestRecorder(unittest.TestCase): class TestRecorder(unittest.TestCase):
......
...@@ -112,11 +112,11 @@ class TestRuleBasedTuner(unittest.TestCase): ...@@ -112,11 +112,11 @@ class TestRuleBasedTuner(unittest.TestCase):
sequence_len, sequence_len,
vocab_size, vocab_size,
) )
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
) )
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( from paddle.distributed.auto_parallel.static.tuner.rule_based_tuner import (
RuleBasedTuner, RuleBasedTuner,
) )
......
...@@ -112,11 +112,11 @@ class TestRuleBasedTuner(unittest.TestCase): ...@@ -112,11 +112,11 @@ class TestRuleBasedTuner(unittest.TestCase):
sequence_len, sequence_len,
vocab_size, vocab_size,
) )
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
) )
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( from paddle.distributed.auto_parallel.static.tuner.rule_based_tuner import (
RuleBasedTuner, RuleBasedTuner,
) )
......
...@@ -20,11 +20,11 @@ import paddle ...@@ -20,11 +20,11 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static from paddle import nn, static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext, DistributedContext,
set_default_distributed_context, set_default_distributed_context,
) )
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh from paddle.distributed.auto_parallel.static.process_mesh_v2 import ProcessMesh
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid.core import TensorDistAttr from paddle.fluid.core import TensorDistAttr
from paddle.fluid.framework import Program from paddle.fluid.framework import Program
......
...@@ -19,7 +19,10 @@ import numpy as np ...@@ -19,7 +19,10 @@ import numpy as np
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import LazyGuard, nn from paddle import LazyGuard, nn
from paddle.distributed.auto_parallel.helper import ProgramHelper, ProxyLayer from paddle.distributed.auto_parallel.static.helper import (
ProgramHelper,
ProxyLayer,
)
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.framework import in_dynamic_mode from paddle.framework import in_dynamic_mode
from paddle.io import Dataset from paddle.io import Dataset
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
from paddle.distributed.auto_parallel.topo import SingleNodeTopology from paddle.distributed.auto_parallel.static.topo import SingleNodeTopology
def check_empty_json_object(json_object): def check_empty_json_object(json_object):
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
import unittest import unittest
from paddle.distributed.auto_parallel.tuner import trial as tr from paddle.distributed.auto_parallel.static.tuner import trial as tr
from paddle.distributed.auto_parallel.tuner import tunable_space as ts from paddle.distributed.auto_parallel.static.tuner import tunable_space as ts
class TestTiral(unittest.TestCase): class TestTiral(unittest.TestCase):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
from paddle.distributed.auto_parallel.tuner import tunable_space as ts from paddle.distributed.auto_parallel.static.tuner import tunable_space as ts
class TestTunableSpace(unittest.TestCase): class TestTunableSpace(unittest.TestCase):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
from paddle.distributed.auto_parallel.tuner import tunable_variable as tv from paddle.distributed.auto_parallel.static.tuner import tunable_variable as tv
class TestTunableVariable(unittest.TestCase): class TestTunableVariable(unittest.TestCase):
......
...@@ -20,8 +20,10 @@ import paddle ...@@ -20,8 +20,10 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static from paddle import nn, static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
paddle.enable_static() paddle.enable_static()
......
...@@ -20,12 +20,12 @@ import paddle ...@@ -20,12 +20,12 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import fluid, nn, static from paddle import fluid, nn, static
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context, get_default_distributed_context,
) )
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard from paddle.distributed.auto_parallel.static.utils import make_data_unshard
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
paddle.enable_static() paddle.enable_static()
......
...@@ -23,10 +23,10 @@ from auto_parallel_pass_test_base import AutoPallelPassTestBase ...@@ -23,10 +23,10 @@ from auto_parallel_pass_test_base import AutoPallelPassTestBase
import paddle import paddle
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context, get_default_distributed_context,
) )
from paddle.distributed.auto_parallel.operators.common import ( from paddle.distributed.auto_parallel.static.operators.common import (
is_data_parallel_reduce_op, is_data_parallel_reduce_op,
) )
from paddle.distributed.passes import PassContext, new_pass from paddle.distributed.passes import PassContext, new_pass
......
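Taken together, the hunks above apply one mechanical rule: static-graph components of auto parallel move one level down into the `static` subpackage, while front-end modules such as `process_mesh` and `strategy` keep their old paths. A hedged summary of the rewrite pattern, with the import paths copied from the hunks (the comments are explanatory, not from the PR):

    # Before this PR:
    #   from paddle.distributed.auto_parallel.dist_context import DistributedContext
    #   from paddle.distributed.auto_parallel.completion import Completer
    #   from paddle.distributed.auto_parallel.partitioner import Partitioner
    # After this PR, static-graph modules live under auto_parallel.static:
    from paddle.distributed.auto_parallel.static.dist_context import DistributedContext
    from paddle.distributed.auto_parallel.static.completion import Completer
    from paddle.distributed.auto_parallel.static.partitioner import Partitioner
    # Front-end modules are unchanged:
    from paddle.distributed.auto_parallel.process_mesh import ProcessMesh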