Unverified commit 7f696804, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Reorganize the folder structure (#54059)

* [Auto Parallel] Reorganize the folder structure

* [Auto Parallel] Fix some import errors
Parent 88e43625
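The commit moves the static-graph implementation of auto parallel into a new `static` subpackage: every absolute import of the form `paddle.distributed.auto_parallel.<module>` becomes `paddle.distributed.auto_parallel.static.<module>`, and relative imports inside the package gain one extra `..` level. The package's own `__init__.py` keeps re-exporting public symbols such as `Engine` from the new location, as the first hunk shows. A minimal sketch of the path change for internal modules follows, using only module and symbol names that appear in the hunks of this diff:

    # Before this commit (old layout):
    #   from paddle.distributed.auto_parallel.completion import Completer
    #   from paddle.distributed.auto_parallel.dist_context import DistributedContext
    #   from paddle.distributed.auto_parallel.utils import _get_comm_group

    # After this commit (new layout, with the extra `static` level):
    from paddle.distributed.auto_parallel.static.completion import Completer
    from paddle.distributed.auto_parallel.static.dist_context import (
        DistributedContext,
    )
    from paddle.distributed.auto_parallel.static.utils import _get_comm_group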
......@@ -14,7 +14,7 @@
from .strategy import Strategy
from .process_mesh import ProcessMesh
from .engine import Engine
from .static.engine import Engine
from .interface import shard_tensor
from .interface import shard_op
from .interface import recompute
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -14,11 +14,11 @@
import paddle
from .dist_context import get_default_distributed_context
from .dist_op import DistributedOperatorHelper
from .dist_tensor import DistributedTensor
from .process_mesh import ProcessMesh, get_current_process_mesh
from .utils import (
from .static.dist_context import get_default_distributed_context
from .static.dist_op import DistributedOperatorHelper
from .static.dist_tensor import DistributedTensor
from .static.utils import (
__no_shape_var_type__,
convert_to_dims_mapping,
verify_shard_spec,
......
......@@ -140,12 +140,12 @@ class ProcessMesh(core.ProcessMesh):
)
# Store all process meshes
from .dist_context import get_default_distributed_context
from .static.dist_context import get_default_distributed_context
default_dist_cxt = get_default_distributed_context()
default_dist_cxt.add_process_mesh(self)
# Add new processes to process group 0
from .process_group import get_process_group
from .static.process_group import get_process_group
pg0 = get_process_group(0)
pg0.add_ranks(self.process_ids)
......@@ -204,14 +204,14 @@ class ProcessMesh(core.ProcessMesh):
self._old_op_size = len(cur_block.ops)
def __exit__(self, exc_type, exc_value, exc_traceback):
from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .static.dist_op import DistributedOperator
from .static.dist_tensor import DistributedTensor
default_prog = paddle.static.default_main_program()
cur_block = default_prog.current_block()
new_var_names = list(cur_block.vars.keys())
new_op_size = len(cur_block.ops)
from .dist_context import get_default_distributed_context
from .static.dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for name in new_var_names:
......
......@@ -17,7 +17,7 @@ import paddle
from ..utils.log_utils import get_logger
from .process_mesh import retrive_unique_id_for_process_mesh
from .utils import _get_idx_in_axis
from .static.utils import _get_idx_in_axis
_logger = get_logger(logging.INFO)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -21,11 +21,11 @@ import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.converter import Converter
from paddle.distributed.auto_parallel.dist_context import (
from paddle.distributed.auto_parallel.static.converter import Converter
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
is_backward_op,
is_forward_op,
is_loss_op,
......
......@@ -24,7 +24,7 @@ from paddle.hapi.callbacks import (
ProgBarLogger,
)
from .interface import CollectionNames, get_collection
from ..interface import CollectionNames, get_collection
def config_callbacks(
......
......@@ -20,7 +20,7 @@ from enum import IntEnum, unique
import paddle
from ..utils.log_utils import get_logger
from ...utils.log_utils import get_logger
@unique
......
......@@ -18,11 +18,11 @@ import logging
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.framework import core
from ..process_mesh import ProcessMesh, compute_compatible_process_mesh
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .dist_context import _node_id
from .operators import find_compatible_distributed_operator_impls
from .process_group import get_world_process_group
from .process_mesh import ProcessMesh, compute_compatible_process_mesh
from .utils import (
__no_shape_var_type__,
get_logger,
......@@ -1641,7 +1641,7 @@ class Completer:
"""Complete the annotation of vars and ops in the update phase for parallel program."""
# Copy the dist tensors and dist ops annotated by users from the default context
# global mesh
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
......@@ -1895,7 +1895,7 @@ class Completer:
def _init_global_mesh_for_program(self):
# Copy the dist tensors and dist ops annotated by users from the default context
# global mesh
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
......
......@@ -19,7 +19,7 @@ import numpy as np
import paddle
from ..utils.log_utils import get_logger
from ...utils.log_utils import get_logger
class Converter:
......
......@@ -15,7 +15,9 @@
from functools import reduce
import paddle
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.static.dist_tensor import (
DistributedTensor,
)
from paddle.static import Variable
from .base_cost import Cost
......
......@@ -18,9 +18,9 @@ from collections import defaultdict
from paddle.distributed.passes import PassContext
from paddle.framework import IrGraph, core, set_flags
from ..process_mesh import ProcessMesh
from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .process_mesh import ProcessMesh
from .utils import (
__no_shape_var_type__,
_copy_dist_attr_to_cpp,
......
......@@ -23,7 +23,7 @@ import numpy as np
import paddle
from paddle.framework import core
from ..utils.log_utils import get_logger
from ...utils.log_utils import get_logger
from .process_group import _g_process_group_map
from .utils import get_dist_attr
......
......@@ -22,7 +22,7 @@ import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel.utils as auto_utils
import paddle.distributed.auto_parallel.static.utils as auto_utils
from paddle import static, utils
from paddle.distributed import fleet
from paddle.fluid.executor import _to_name_str
......@@ -32,7 +32,9 @@ from paddle.framework import core, in_dynamic_mode
from paddle.metric import Metric
from paddle.static import InputSpec, Operator, Variable, global_scope
from ..utils.log_utils import get_logger
from ...utils.log_utils import get_logger
from ..interface import CollectionNames, fetch, get_collection
from ..strategy import Strategy
from .callbacks import config_callbacks
from .cluster import Cluster, get_default_cluster
from .converter import Converter
......@@ -45,11 +47,9 @@ from .dist_loader import (
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
from .interface import CollectionNames, fetch, get_collection
from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group
from .strategy import Strategy
class Engine:
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
......
......@@ -18,10 +18,10 @@ import paddle
from paddle.framework import core
from paddle.utils import unique_name
from ...utils.log_utils import get_logger
from ....utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from ...random import determinate_rng, is_enable_auto_rand_ctrl
from ..utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
......
......@@ -13,7 +13,7 @@
# limitations under the License
from paddle.common_ops_import import check_dtype, check_variable_and_dtype
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
from paddle.distributed.auto_parallel.static.cost.comm_op_cost import (
AllreduceSumOpCost,
IdentityOpCost,
)
......
......@@ -14,10 +14,10 @@
import logging
from ...utils.log_utils import get_logger
from ....utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from ...random import determinate_rng, is_enable_auto_rand_ctrl
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
......
......@@ -18,10 +18,10 @@ import paddle
from paddle.framework import core
from paddle.utils import unique_name
from ...utils.log_utils import get_logger
from ....utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from ...random import determinate_rng, is_enable_auto_rand_ctrl
from ..utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
......
......@@ -15,7 +15,7 @@
import copy
from paddle.common_ops_import import check_dtype, check_variable_and_dtype
from paddle.distributed.auto_parallel.cost.comm_op_cost import (
from paddle.distributed.auto_parallel.static.cost.comm_op_cost import (
AllreduceSumOpCost,
IdentityOpCost,
)
......
......@@ -20,10 +20,10 @@ from paddle.distributed.passes import PassManager, new_pass
from paddle.static import append_backward, program_guard
from paddle.utils import unique_name
from ..utils.log_utils import get_logger
from ...utils.log_utils import get_logger
from ..random import init_auto_parallel_rng
from .partitioner import Partitioner
from .process_group import get_world_process_group
from .random import init_auto_parallel_rng
from .reshard import Resharder
from .utils import set_grad_var_shape
......
......@@ -15,8 +15,10 @@
import copy
import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.operators.common import (
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.operators.common import (
get_distributed_operator_impl_container,
)
from paddle.framework import Program, core
......
......@@ -18,15 +18,17 @@ import pickle
import numpy as np
from paddle.distributed.auto_parallel.dist_attribute import (
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.static.dist_tensor import (
DistributedTensor,
)
from ..utils.log_utils import get_logger
from ...utils.log_utils import get_logger
from .completion import Completer
from .dist_context import get_default_distributed_context
from .tuner.parallel_tuner import ParallelTuner
......
......@@ -17,8 +17,8 @@ from collections import OrderedDict
import paddle
from paddle.framework import core
from ..collective import _get_global_env, _new_ring_id
from ..utils.log_utils import get_logger
from ...collective import _get_global_env, _new_ring_id
from ...utils.log_utils import get_logger
from .utils import dygraph_guard
logger = get_logger("INFO", __name__)
......
......@@ -15,7 +15,7 @@
import copy
import os
from ..strategy import Strategy
from ...strategy import Strategy
_tuning_supported_passes = ["sharding", "recompute"]
......
......@@ -27,16 +27,18 @@ import sys
import time
import paddle
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.process_group import (
clear_all_process_groups,
get_all_process_groups,
new_process_group,
)
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.auto_parallel.static.utils import (
debug_program,
set_grad_var_shape,
)
......@@ -465,7 +467,7 @@ class OptimizationTuner:
]
)
cmd_args = (
"-m paddle.distributed.auto_parallel.tuner.profiler"
"-m paddle.distributed.auto_parallel.static.tuner.profiler"
+ " "
+ profile_args
)
......
......@@ -21,13 +21,13 @@ from collections import defaultdict
import numpy as np
from ...process_mesh import ProcessMesh
from ..completion import Completer
from ..cost import CostEstimator
from ..dist_context import _node_id
from ..dist_op import DistributedOperator
from ..operators.common import find_compatible_distributed_operator_impls
from ..parallelizer_v2 import Parallelizer
from ..process_mesh import ProcessMesh
from .trial import Trial, TrialStatus
from .tunable_space import TunableSpace
from .tunable_variable import Boolean, IntRange
......
......@@ -21,10 +21,10 @@ import time
import traceback
import paddle
from paddle.distributed.auto_parallel.dist_loader import (
from paddle.distributed.auto_parallel.static.dist_loader import (
DistributedDataLoaderFromGenerator,
)
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.process_group import (
get_all_process_groups,
new_process_group,
)
......
......@@ -26,20 +26,24 @@ from functools import reduce
import numpy as np
import paddle
from paddle.distributed.auto_parallel.cluster_v2 import DeviceMesh
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.dist_attribute import (
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.cluster_v2 import DeviceMesh
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.cost import CostEstimator
from paddle.distributed.auto_parallel.static.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.dist_tensor import (
DistributedTensor,
)
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
is_gradient_clip_op,
print_program_with_dist_attr,
)
......@@ -48,7 +52,7 @@ from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Parameter, unique_name
from ...utils.log_utils import get_logger
from ....utils.log_utils import get_logger
from ..graph import Graph
_PATTERNS = {}
......
......@@ -27,8 +27,8 @@ from paddle.framework import core
from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
from paddle.static import Variable
from ..process_mesh import ProcessMesh
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .process_mesh import ProcessMesh
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
......@@ -1868,7 +1868,7 @@ def get_lr(optimizer):
def initialize_pg_in_full_mode(all_process_groups, cur_rank):
import socket
from ..collective import _get_global_env
from ...collective import _get_global_env
has_recv_by_socket = []
# This is a magic number
......@@ -1946,7 +1946,7 @@ def is_recompute_op(op):
def set_recompute_segments(model, losses, strategy, program):
from ..passes.auto_parallel_recompute import RecomputeState
from ...passes.auto_parallel_recompute import RecomputeState
if not losses:
return
......@@ -2054,7 +2054,7 @@ def validate_opt(optimizer):
def set_data_parallel(x):
from .interface import ProcessMesh, shard_tensor
from ..interface import ProcessMesh, shard_tensor
from .process_group import get_world_process_group
world_ranks = get_world_process_group().ranks
......@@ -2095,7 +2095,7 @@ def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh
from ..process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh
if cpp_process_mesh is not None:
......@@ -2128,7 +2128,7 @@ def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh
from ..process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh
if cpp_process_mesh is not None:
......
......@@ -1335,7 +1335,7 @@ class Fleet:
self._user_defined_strategy.semi_auto
or self._user_defined_strategy.auto_search
):
from ..auto_parallel.parallelizer import AutoParallelizer
from ..auto_parallel.static.parallelizer import AutoParallelizer
auto_parallelizer = AutoParallelizer(self)
(
......
......@@ -13,11 +13,13 @@
# limitations under the License.
import paddle
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.dist_attribute import (
OperatorDistAttr,
)
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
......@@ -42,7 +44,7 @@ from paddle.static.amp.fp16_utils import (
from paddle.utils import unique_name
from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.utils import (
from ..auto_parallel.static.utils import (
is_backward_op,
is_forward_op,
is_loss_grad_op,
......
......@@ -15,16 +15,16 @@
from collections import OrderedDict
import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.operators.common import (
from paddle.distributed.auto_parallel.static.operators.common import (
is_data_parallel_reduce_op,
is_data_parallel_scale_op,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
find_higher_order_backward_op,
get_var_numel,
insert_dependencies_for_vars,
......
......@@ -16,11 +16,13 @@ from collections import defaultdict
import paddle
from paddle.common_ops_import import check_type, check_variable_and_dtype
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.dist_attribute import (
OperatorDistAttr,
)
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
is_backward_op,
is_forward_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
......
......@@ -19,18 +19,21 @@ import numpy as np
import paddle
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
from ..auto_parallel.operators.common import (
from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.static.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from ..auto_parallel.static.operators.common import (
SyncMode,
is_data_parallel_reduce_op,
)
from ..auto_parallel.process_group import (
from ..auto_parallel.static.process_group import (
get_all_process_groups,
get_world_process_group,
)
from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.reshard import Resharder
from ..auto_parallel.utils import (
from ..auto_parallel.static.reshard import Resharder
from ..auto_parallel.static.utils import (
_get_comm_group,
insert_dependencies_for_vars,
is_gradient_clip_op,
......
......@@ -15,11 +15,11 @@
from typing import Any, Dict, List, Tuple
import paddle
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
is_optimize_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
......
......@@ -26,8 +26,11 @@ from paddle.static.quantization import (
quant_config,
)
from ..auto_parallel.converter import Converter
from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr
from ..auto_parallel.static.converter import Converter
from ..auto_parallel.static.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from .pass_base import PassBase, register_pass
TRANSFORM_PASS_OP_TYPES = list(
......
......@@ -26,8 +26,8 @@ from paddle.fluid.backward import (
from paddle.framework import core
from paddle.utils import unique_name
from ..auto_parallel.dist_attribute import OperatorDistAttr
from ..auto_parallel.utils import (
from ..auto_parallel.static.dist_attribute import OperatorDistAttr
from ..auto_parallel.static.utils import (
get_loss_op,
insert_dependencies_for_two_ops,
is_backward_op,
......
......@@ -16,13 +16,15 @@ import logging
from functools import reduce
import paddle
from paddle.distributed.auto_parallel.operators.common import (
from paddle.distributed.auto_parallel.static.operators.common import (
ParallelMode,
is_data_parallel_reduce_op,
is_parameter_related,
)
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.process_group import (
new_process_group,
)
from paddle.distributed.auto_parallel.static.utils import (
_get_comm_group,
get_logger,
get_var_numel,
......
......@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.auto_parallel.operators.common import (
from paddle.distributed.auto_parallel.static.operators.common import (
is_amp_flag_sync_op,
is_data_parallel_reduce_op,
is_global_norm_sync_op,
)
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
OpRole,
insert_dependencies_for_vars,
)
......
......@@ -1439,7 +1439,7 @@ def _append_backward_ops_(
)
else:
default_ctx = getattr(
paddle.distributed.auto_parallel.dist_context,
paddle.distributed.auto_parallel.static.dist_context,
'_g_default_distributed_context',
None,
)
......
......@@ -1681,7 +1681,7 @@ class Variable(metaclass=VariableMetaClass):
if self.persistable:
var_str = "persist " + var_str
from paddle.distributed.auto_parallel.dist_context import (
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
......@@ -3137,7 +3137,7 @@ class Operator:
if i != len(attr_names) - 1:
attrs_str += ", "
from paddle.distributed.auto_parallel.dist_context import (
from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context,
)
......
......@@ -22,10 +22,10 @@ import paddle
import paddle.nn.functional as F
from paddle import nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.dist_context import (
from paddle.distributed.auto_parallel.static.dist_context import (
set_default_distributed_context,
)
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
get_dist_attr,
load_checkpoint_into_program,
load_distributed_checkpoint,
......
......@@ -23,7 +23,7 @@ import paddle
import paddle.nn.functional as F
from paddle import nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.utils import (
from paddle.distributed.auto_parallel.static.utils import (
load_checkpoint_into_program,
save_distributed_checkpoint,
)
......
......@@ -25,7 +25,7 @@ import numpy as np
import paddle
from paddle import distributed as dist
from paddle.distributed import fleet
from paddle.distributed.auto_parallel import engine
from paddle.distributed.auto_parallel.static import engine
from paddle.distributed.fleet.layers.mpu.mp_layers import (
ColumnParallelLinear,
RowParallelLinear,
......
......@@ -17,7 +17,7 @@ import os
import tempfile
import unittest
from paddle.distributed.auto_parallel.cluster import (
from paddle.distributed.auto_parallel.static.cluster import (
Cluster,
DeviceType,
LinkType,
......
......@@ -18,8 +18,10 @@ import unittest.mock
import paddle
import paddle.nn.functional as F
from paddle import nn, static, tensor, utils
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.fleet import auto
paddle.enable_static()
......@@ -188,7 +190,7 @@ class TestMLPAutoCompletion(unittest.TestCase):
# # dist_context)
# dist_context.finalize_distributed_attr_for_program(
# complete_train_program)
# from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
# from paddle.distributed.auto_parallel.static.interface import _g_process_mesh_map
# for block in complete_train_program.blocks:
# for tensor in block.vars.values():
# desc = tensor.desc
......
......@@ -18,8 +18,10 @@ import unittest
import paddle
import paddle.nn.functional as F
from paddle import nn, static, tensor, utils
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.fleet import auto
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
......
......@@ -18,12 +18,16 @@ import paddle
import paddle.nn.functional as F
from paddle import nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.cost_model import estimate_cost
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.cost_model import estimate_cost
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.fleet import auto
from paddle.fluid import core
......
......@@ -20,12 +20,20 @@ from test_auto_parallel_reshard import mlp_forward
import paddle
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistAttr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.dist_attribute import (
TensorDistAttr,
)
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.dist_tensor import (
DistributedTensor,
)
from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.fleet import auto
......
......@@ -14,7 +14,7 @@
import unittest
from paddle.distributed.auto_parallel.graph import Graph
from paddle.distributed.auto_parallel.static.graph import Graph
class TestAutoParallelGraph(unittest.TestCase):
......
......@@ -23,17 +23,21 @@ import paddle
import paddle.nn.functional as F
from paddle import fluid, nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.mapper import (
from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.mapper import (
get_comm_volume,
get_dtype_bytes,
mapping,
)
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.fleet import auto
from paddle.fluid import core
......
......@@ -19,11 +19,15 @@ import paddle
import paddle.nn.functional as F
from paddle import nn, static, tensor, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.process_group import (
new_process_group,
)
from paddle.distributed.auto_parallel.static.utils import _get_comm_group
from paddle.distributed.fleet import auto
paddle.enable_static()
......
......@@ -18,11 +18,15 @@ import unittest
import paddle
import paddle.nn.functional as F
from paddle import nn, static, tensor, utils
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.process_group import (
new_process_group,
)
from paddle.distributed.auto_parallel.static.utils import _get_comm_group
from paddle.distributed.fleet import auto
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
......
......@@ -18,15 +18,19 @@ import paddle
import paddle.nn.functional as F
from paddle import nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process_group import (
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.process_group import (
ProcessGroup,
_g_process_group_map,
)
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.fleet import auto
paddle.enable_static()
......
Diffs for 9 more files have been collapsed and are not shown.