未验证 提交 7f696804 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Reorganize the fold structure (#54059)

* [Auto Parallel] Reorganize the fold structure

* [Auto Parallel] Fix some import errors
上级 88e43625
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from .strategy import Strategy from .strategy import Strategy
from .process_mesh import ProcessMesh from .process_mesh import ProcessMesh
from .engine import Engine from .static.engine import Engine
from .interface import shard_tensor from .interface import shard_tensor
from .interface import shard_op from .interface import shard_op
from .interface import recompute from .interface import recompute
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
import paddle import paddle
from .dist_context import get_default_distributed_context
from .dist_op import DistributedOperatorHelper
from .dist_tensor import DistributedTensor
from .process_mesh import ProcessMesh, get_current_process_mesh from .process_mesh import ProcessMesh, get_current_process_mesh
from .utils import ( from .static.dist_context import get_default_distributed_context
from .static.dist_op import DistributedOperatorHelper
from .static.dist_tensor import DistributedTensor
from .static.utils import (
__no_shape_var_type__, __no_shape_var_type__,
convert_to_dims_mapping, convert_to_dims_mapping,
verify_shard_spec, verify_shard_spec,
......
...@@ -140,12 +140,12 @@ class ProcessMesh(core.ProcessMesh): ...@@ -140,12 +140,12 @@ class ProcessMesh(core.ProcessMesh):
) )
# Store all process meshes # Store all process meshes
from .dist_context import get_default_distributed_context from .static.dist_context import get_default_distributed_context
default_dist_cxt = get_default_distributed_context() default_dist_cxt = get_default_distributed_context()
default_dist_cxt.add_process_mesh(self) default_dist_cxt.add_process_mesh(self)
# Add new processes to process group 0 # Add new processes to process group 0
from .process_group import get_process_group from .static.process_group import get_process_group
pg0 = get_process_group(0) pg0 = get_process_group(0)
pg0.add_ranks(self.process_ids) pg0.add_ranks(self.process_ids)
...@@ -204,14 +204,14 @@ class ProcessMesh(core.ProcessMesh): ...@@ -204,14 +204,14 @@ class ProcessMesh(core.ProcessMesh):
self._old_op_size = len(cur_block.ops) self._old_op_size = len(cur_block.ops)
def __exit__(self, exc_type, exc_value, exc_traceback): def __exit__(self, exc_type, exc_value, exc_traceback):
from .dist_op import DistributedOperator from .static.dist_op import DistributedOperator
from .dist_tensor import DistributedTensor from .static.dist_tensor import DistributedTensor
default_prog = paddle.static.default_main_program() default_prog = paddle.static.default_main_program()
cur_block = default_prog.current_block() cur_block = default_prog.current_block()
new_var_names = list(cur_block.vars.keys()) new_var_names = list(cur_block.vars.keys())
new_op_size = len(cur_block.ops) new_op_size = len(cur_block.ops)
from .dist_context import get_default_distributed_context from .static.dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context() default_dist_ctx = get_default_distributed_context()
for name in new_var_names: for name in new_var_names:
......
...@@ -17,7 +17,7 @@ import paddle ...@@ -17,7 +17,7 @@ import paddle
from ..utils.log_utils import get_logger from ..utils.log_utils import get_logger
from .process_mesh import retrive_unique_id_for_process_mesh from .process_mesh import retrive_unique_id_for_process_mesh
from .utils import _get_idx_in_axis from .static.utils import _get_idx_in_axis
_logger = get_logger(logging.INFO) _logger = get_logger(logging.INFO)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -21,11 +21,11 @@ import numpy as np ...@@ -21,11 +21,11 @@ import numpy as np
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from paddle.distributed.auto_parallel.converter import Converter from paddle.distributed.auto_parallel.static.converter import Converter
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context, get_default_distributed_context,
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.static.utils import (
is_backward_op, is_backward_op,
is_forward_op, is_forward_op,
is_loss_op, is_loss_op,
......
...@@ -24,7 +24,7 @@ from paddle.hapi.callbacks import ( ...@@ -24,7 +24,7 @@ from paddle.hapi.callbacks import (
ProgBarLogger, ProgBarLogger,
) )
from .interface import CollectionNames, get_collection from ..interface import CollectionNames, get_collection
def config_callbacks( def config_callbacks(
......
...@@ -20,7 +20,7 @@ from enum import IntEnum, unique ...@@ -20,7 +20,7 @@ from enum import IntEnum, unique
import paddle import paddle
from ..utils.log_utils import get_logger from ...utils.log_utils import get_logger
@unique @unique
......
...@@ -18,11 +18,11 @@ import logging ...@@ -18,11 +18,11 @@ import logging
from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.framework import core from paddle.framework import core
from ..process_mesh import ProcessMesh, compute_compatible_process_mesh
from .dist_attribute import OperatorDistAttr, TensorDistAttr from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .dist_context import _node_id from .dist_context import _node_id
from .operators import find_compatible_distributed_operator_impls from .operators import find_compatible_distributed_operator_impls
from .process_group import get_world_process_group from .process_group import get_world_process_group
from .process_mesh import ProcessMesh, compute_compatible_process_mesh
from .utils import ( from .utils import (
__no_shape_var_type__, __no_shape_var_type__,
get_logger, get_logger,
...@@ -1641,7 +1641,7 @@ class Completer: ...@@ -1641,7 +1641,7 @@ class Completer:
"""Complete the annotation of vars and ops in the update phase for parallel program.""" """Complete the annotation of vars and ops in the update phase for parallel program."""
# Copy the dist tensors and dist ops annotated by users from the default context # Copy the dist tensors and dist ops annotated by users from the default context
# global mesh # global mesh
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group, get_world_process_group,
) )
...@@ -1895,7 +1895,7 @@ class Completer: ...@@ -1895,7 +1895,7 @@ class Completer:
def _init_global_mesh_for_program(self): def _init_global_mesh_for_program(self):
# Copy the dist tensors and dist ops annotated by users from the default context # Copy the dist tensors and dist ops annotated by users from the default context
# global mesh # global mesh
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group, get_world_process_group,
) )
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
import paddle import paddle
from ..utils.log_utils import get_logger from ...utils.log_utils import get_logger
class Converter: class Converter:
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
from functools import reduce from functools import reduce
import paddle import paddle
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor from paddle.distributed.auto_parallel.static.dist_tensor import (
DistributedTensor,
)
from paddle.static import Variable from paddle.static import Variable
from .base_cost import Cost from .base_cost import Cost
......
...@@ -18,9 +18,9 @@ from collections import defaultdict ...@@ -18,9 +18,9 @@ from collections import defaultdict
from paddle.distributed.passes import PassContext from paddle.distributed.passes import PassContext
from paddle.framework import IrGraph, core, set_flags from paddle.framework import IrGraph, core, set_flags
from ..process_mesh import ProcessMesh
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor from .dist_tensor import DistributedTensor
from .process_mesh import ProcessMesh
from .utils import ( from .utils import (
__no_shape_var_type__, __no_shape_var_type__,
_copy_dist_attr_to_cpp, _copy_dist_attr_to_cpp,
......
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
import paddle import paddle
from paddle.framework import core from paddle.framework import core
from ..utils.log_utils import get_logger from ...utils.log_utils import get_logger
from .process_group import _g_process_group_map from .process_group import _g_process_group_map
from .utils import get_dist_attr from .utils import get_dist_attr
......
...@@ -22,7 +22,7 @@ import random ...@@ -22,7 +22,7 @@ import random
import numpy as np import numpy as np
import paddle import paddle
import paddle.distributed.auto_parallel.utils as auto_utils import paddle.distributed.auto_parallel.static.utils as auto_utils
from paddle import static, utils from paddle import static, utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.fluid.executor import _to_name_str from paddle.fluid.executor import _to_name_str
...@@ -32,7 +32,9 @@ from paddle.framework import core, in_dynamic_mode ...@@ -32,7 +32,9 @@ from paddle.framework import core, in_dynamic_mode
from paddle.metric import Metric from paddle.metric import Metric
from paddle.static import InputSpec, Operator, Variable, global_scope from paddle.static import InputSpec, Operator, Variable, global_scope
from ..utils.log_utils import get_logger from ...utils.log_utils import get_logger
from ..interface import CollectionNames, fetch, get_collection
from ..strategy import Strategy
from .callbacks import config_callbacks from .callbacks import config_callbacks
from .cluster import Cluster, get_default_cluster from .cluster import Cluster, get_default_cluster
from .converter import Converter from .converter import Converter
...@@ -45,11 +47,9 @@ from .dist_loader import ( ...@@ -45,11 +47,9 @@ from .dist_loader import (
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver from .dist_saver import DistributedSaver
from .helper import ProgramHelper from .helper import ProgramHelper
from .interface import CollectionNames, fetch, get_collection
from .parallelizer_v2 import Parallelizer from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group from .process_group import get_all_process_groups, new_process_group
from .strategy import Strategy
class Engine: class Engine:
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
......
...@@ -18,10 +18,10 @@ import paddle ...@@ -18,10 +18,10 @@ import paddle
from paddle.framework import core from paddle.framework import core
from paddle.utils import unique_name from paddle.utils import unique_name
from ...utils.log_utils import get_logger from ....utils.log_utils import get_logger
_logger = get_logger(logging.INFO) _logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl from ...random import determinate_rng, is_enable_auto_rand_ctrl
from ..utils import ( from ..utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr, set_var_dist_attr,
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License # limitations under the License
from paddle.common_ops_import import check_dtype, check_variable_and_dtype from paddle.common_ops_import import check_dtype, check_variable_and_dtype
from paddle.distributed.auto_parallel.cost.comm_op_cost import ( from paddle.distributed.auto_parallel.static.cost.comm_op_cost import (
AllreduceSumOpCost, AllreduceSumOpCost,
IdentityOpCost, IdentityOpCost,
) )
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
import logging import logging
from ...utils.log_utils import get_logger from ....utils.log_utils import get_logger
_logger = get_logger(logging.INFO) _logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl from ...random import determinate_rng, is_enable_auto_rand_ctrl
from .common import ( from .common import (
DistributedOperatorImplContainer, DistributedOperatorImplContainer,
register_distributed_operator_impl, register_distributed_operator_impl,
......
...@@ -18,10 +18,10 @@ import paddle ...@@ -18,10 +18,10 @@ import paddle
from paddle.framework import core from paddle.framework import core
from paddle.utils import unique_name from paddle.utils import unique_name
from ...utils.log_utils import get_logger from ....utils.log_utils import get_logger
_logger = get_logger(logging.INFO) _logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl from ...random import determinate_rng, is_enable_auto_rand_ctrl
from ..utils import ( from ..utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr, set_var_dist_attr,
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import copy import copy
from paddle.common_ops_import import check_dtype, check_variable_and_dtype from paddle.common_ops_import import check_dtype, check_variable_and_dtype
from paddle.distributed.auto_parallel.cost.comm_op_cost import ( from paddle.distributed.auto_parallel.static.cost.comm_op_cost import (
AllreduceSumOpCost, AllreduceSumOpCost,
IdentityOpCost, IdentityOpCost,
) )
......
...@@ -20,10 +20,10 @@ from paddle.distributed.passes import PassManager, new_pass ...@@ -20,10 +20,10 @@ from paddle.distributed.passes import PassManager, new_pass
from paddle.static import append_backward, program_guard from paddle.static import append_backward, program_guard
from paddle.utils import unique_name from paddle.utils import unique_name
from ..utils.log_utils import get_logger from ...utils.log_utils import get_logger
from ..random import init_auto_parallel_rng
from .partitioner import Partitioner from .partitioner import Partitioner
from .process_group import get_world_process_group from .process_group import get_world_process_group
from .random import init_auto_parallel_rng
from .reshard import Resharder from .reshard import Resharder
from .utils import set_grad_var_shape from .utils import set_grad_var_shape
......
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
import copy import copy
import paddle import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.operators.common import ( DistributedContext,
)
from paddle.distributed.auto_parallel.static.operators.common import (
get_distributed_operator_impl_container, get_distributed_operator_impl_container,
) )
from paddle.framework import Program, core from paddle.framework import Program, core
......
...@@ -18,15 +18,17 @@ import pickle ...@@ -18,15 +18,17 @@ import pickle
import numpy as np import numpy as np
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_attribute import (
OperatorDistAttr, OperatorDistAttr,
TensorDistAttr, TensorDistAttr,
) )
from paddle.distributed.auto_parallel.dist_op import DistributedOperator from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor from paddle.distributed.auto_parallel.static.dist_tensor import (
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh DistributedTensor,
)
from ..utils.log_utils import get_logger from ...utils.log_utils import get_logger
from .completion import Completer from .completion import Completer
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .tuner.parallel_tuner import ParallelTuner from .tuner.parallel_tuner import ParallelTuner
......
...@@ -17,8 +17,8 @@ from collections import OrderedDict ...@@ -17,8 +17,8 @@ from collections import OrderedDict
import paddle import paddle
from paddle.framework import core from paddle.framework import core
from ..collective import _get_global_env, _new_ring_id from ...collective import _get_global_env, _new_ring_id
from ..utils.log_utils import get_logger from ...utils.log_utils import get_logger
from .utils import dygraph_guard from .utils import dygraph_guard
logger = get_logger("INFO", __name__) logger = get_logger("INFO", __name__)
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import copy import copy
import os import os
from ..strategy import Strategy from ...strategy import Strategy
_tuning_supported_passes = ["sharding", "recompute"] _tuning_supported_passes = ["sharding", "recompute"]
......
...@@ -27,16 +27,18 @@ import sys ...@@ -27,16 +27,18 @@ import sys
import time import time
import paddle import paddle
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
from paddle.distributed.auto_parallel.process_group import ( )
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.process_group import (
clear_all_process_groups, clear_all_process_groups,
get_all_process_groups, get_all_process_groups,
new_process_group, new_process_group,
) )
from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.static.utils import (
debug_program, debug_program,
set_grad_var_shape, set_grad_var_shape,
) )
...@@ -465,7 +467,7 @@ class OptimizationTuner: ...@@ -465,7 +467,7 @@ class OptimizationTuner:
] ]
) )
cmd_args = ( cmd_args = (
"-m paddle.distributed.auto_parallel.tuner.profiler" "-m paddle.distributed.auto_parallel.static.tuner.profiler"
+ " " + " "
+ profile_args + profile_args
) )
......
...@@ -21,13 +21,13 @@ from collections import defaultdict ...@@ -21,13 +21,13 @@ from collections import defaultdict
import numpy as np import numpy as np
from ...process_mesh import ProcessMesh
from ..completion import Completer from ..completion import Completer
from ..cost import CostEstimator from ..cost import CostEstimator
from ..dist_context import _node_id from ..dist_context import _node_id
from ..dist_op import DistributedOperator from ..dist_op import DistributedOperator
from ..operators.common import find_compatible_distributed_operator_impls from ..operators.common import find_compatible_distributed_operator_impls
from ..parallelizer_v2 import Parallelizer from ..parallelizer_v2 import Parallelizer
from ..process_mesh import ProcessMesh
from .trial import Trial, TrialStatus from .trial import Trial, TrialStatus
from .tunable_space import TunableSpace from .tunable_space import TunableSpace
from .tunable_variable import Boolean, IntRange from .tunable_variable import Boolean, IntRange
......
...@@ -21,10 +21,10 @@ import time ...@@ -21,10 +21,10 @@ import time
import traceback import traceback
import paddle import paddle
from paddle.distributed.auto_parallel.dist_loader import ( from paddle.distributed.auto_parallel.static.dist_loader import (
DistributedDataLoaderFromGenerator, DistributedDataLoaderFromGenerator,
) )
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.static.process_group import (
get_all_process_groups, get_all_process_groups,
new_process_group, new_process_group,
) )
......
...@@ -26,20 +26,24 @@ from functools import reduce ...@@ -26,20 +26,24 @@ from functools import reduce
import numpy as np import numpy as np
import paddle import paddle
from paddle.distributed.auto_parallel.cluster_v2 import DeviceMesh from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.cluster_v2 import DeviceMesh
from paddle.distributed.auto_parallel.cost import CostEstimator from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.static.cost import CostEstimator
from paddle.distributed.auto_parallel.static.dist_attribute import (
OperatorDistAttr, OperatorDistAttr,
TensorDistAttr, TensorDistAttr,
) )
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor DistributedContext,
from paddle.distributed.auto_parallel.process_group import ( )
from paddle.distributed.auto_parallel.static.dist_tensor import (
DistributedTensor,
)
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.static.utils import (
from paddle.distributed.auto_parallel.utils import (
is_gradient_clip_op, is_gradient_clip_op,
print_program_with_dist_attr, print_program_with_dist_attr,
) )
...@@ -48,7 +52,7 @@ from paddle.fluid import program_guard ...@@ -48,7 +52,7 @@ from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Parameter, unique_name from paddle.fluid.framework import Parameter, unique_name
from ...utils.log_utils import get_logger from ....utils.log_utils import get_logger
from ..graph import Graph from ..graph import Graph
_PATTERNS = {} _PATTERNS = {}
......
...@@ -27,8 +27,8 @@ from paddle.framework import core ...@@ -27,8 +27,8 @@ from paddle.framework import core
from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
from paddle.static import Variable from paddle.static import Variable
from ..process_mesh import ProcessMesh
from .dist_attribute import OperatorDistAttr, TensorDistAttr from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .process_mesh import ProcessMesh
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
...@@ -1868,7 +1868,7 @@ def get_lr(optimizer): ...@@ -1868,7 +1868,7 @@ def get_lr(optimizer):
def initialize_pg_in_full_mode(all_process_groups, cur_rank): def initialize_pg_in_full_mode(all_process_groups, cur_rank):
import socket import socket
from ..collective import _get_global_env from ...collective import _get_global_env
has_recv_by_socket = [] has_recv_by_socket = []
# This is a magic number # This is a magic number
...@@ -1946,7 +1946,7 @@ def is_recompute_op(op): ...@@ -1946,7 +1946,7 @@ def is_recompute_op(op):
def set_recompute_segments(model, losses, strategy, program): def set_recompute_segments(model, losses, strategy, program):
from ..passes.auto_parallel_recompute import RecomputeState from ...passes.auto_parallel_recompute import RecomputeState
if not losses: if not losses:
return return
...@@ -2054,7 +2054,7 @@ def validate_opt(optimizer): ...@@ -2054,7 +2054,7 @@ def validate_opt(optimizer):
def set_data_parallel(x): def set_data_parallel(x):
from .interface import ProcessMesh, shard_tensor from ..interface import ProcessMesh, shard_tensor
from .process_group import get_world_process_group from .process_group import get_world_process_group
world_ranks = get_world_process_group().ranks world_ranks = get_world_process_group().ranks
...@@ -2095,7 +2095,7 @@ def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr): ...@@ -2095,7 +2095,7 @@ def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr): def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh from ..process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh cpp_process_mesh = cpp_dist_attr.process_mesh
if cpp_process_mesh is not None: if cpp_process_mesh is not None:
...@@ -2128,7 +2128,7 @@ def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr): ...@@ -2128,7 +2128,7 @@ def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr):
def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr): def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr):
from .process_mesh import ProcessMesh from ..process_mesh import ProcessMesh
cpp_process_mesh = cpp_dist_attr.process_mesh cpp_process_mesh = cpp_dist_attr.process_mesh
if cpp_process_mesh is not None: if cpp_process_mesh is not None:
......
...@@ -1335,7 +1335,7 @@ class Fleet: ...@@ -1335,7 +1335,7 @@ class Fleet:
self._user_defined_strategy.semi_auto self._user_defined_strategy.semi_auto
or self._user_defined_strategy.auto_search or self._user_defined_strategy.auto_search
): ):
from ..auto_parallel.parallelizer import AutoParallelizer from ..auto_parallel.static.parallelizer import AutoParallelizer
auto_parallelizer = AutoParallelizer(self) auto_parallelizer = AutoParallelizer(self)
( (
......
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr from paddle.distributed.auto_parallel.static.dist_attribute import (
from paddle.distributed.auto_parallel.process_group import ( OperatorDistAttr,
)
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.static.utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr, set_var_dist_attr,
) )
...@@ -42,7 +44,7 @@ from paddle.static.amp.fp16_utils import ( ...@@ -42,7 +44,7 @@ from paddle.static.amp.fp16_utils import (
from paddle.utils import unique_name from paddle.utils import unique_name
from ..auto_parallel.process_mesh import ProcessMesh from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.utils import ( from ..auto_parallel.static.utils import (
is_backward_op, is_backward_op,
is_forward_op, is_forward_op,
is_loss_grad_op, is_loss_grad_op,
......
...@@ -15,16 +15,16 @@ ...@@ -15,16 +15,16 @@
from collections import OrderedDict from collections import OrderedDict
import paddle import paddle
from paddle.distributed.auto_parallel.dist_attribute import ( from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_attribute import (
OperatorDistAttr, OperatorDistAttr,
TensorDistAttr, TensorDistAttr,
) )
from paddle.distributed.auto_parallel.operators.common import ( from paddle.distributed.auto_parallel.static.operators.common import (
is_data_parallel_reduce_op, is_data_parallel_reduce_op,
is_data_parallel_scale_op, is_data_parallel_scale_op,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.static.utils import (
from paddle.distributed.auto_parallel.utils import (
find_higher_order_backward_op, find_higher_order_backward_op,
get_var_numel, get_var_numel,
insert_dependencies_for_vars, insert_dependencies_for_vars,
......
...@@ -16,11 +16,13 @@ from collections import defaultdict ...@@ -16,11 +16,13 @@ from collections import defaultdict
import paddle import paddle
from paddle.common_ops_import import check_type, check_variable_and_dtype from paddle.common_ops_import import check_type, check_variable_and_dtype
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistAttr from paddle.distributed.auto_parallel.static.dist_attribute import (
from paddle.distributed.auto_parallel.process_group import ( OperatorDistAttr,
)
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.static.utils import (
is_backward_op, is_backward_op,
is_forward_op, is_forward_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
......
...@@ -19,18 +19,21 @@ import numpy as np ...@@ -19,18 +19,21 @@ import numpy as np
import paddle import paddle
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr from ..auto_parallel.process_mesh import ProcessMesh
from ..auto_parallel.operators.common import ( from ..auto_parallel.static.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from ..auto_parallel.static.operators.common import (
SyncMode, SyncMode,
is_data_parallel_reduce_op, is_data_parallel_reduce_op,
) )
from ..auto_parallel.process_group import ( from ..auto_parallel.static.process_group import (
get_all_process_groups, get_all_process_groups,
get_world_process_group, get_world_process_group,
) )
from ..auto_parallel.process_mesh import ProcessMesh from ..auto_parallel.static.reshard import Resharder
from ..auto_parallel.reshard import Resharder from ..auto_parallel.static.utils import (
from ..auto_parallel.utils import (
_get_comm_group, _get_comm_group,
insert_dependencies_for_vars, insert_dependencies_for_vars,
is_gradient_clip_op, is_gradient_clip_op,
......
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
import paddle import paddle
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group, get_world_process_group,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh from paddle.distributed.auto_parallel.static.utils import (
from paddle.distributed.auto_parallel.utils import (
is_optimize_op, is_optimize_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr, set_var_dist_attr,
......
...@@ -26,8 +26,11 @@ from paddle.static.quantization import ( ...@@ -26,8 +26,11 @@ from paddle.static.quantization import (
quant_config, quant_config,
) )
from ..auto_parallel.converter import Converter from ..auto_parallel.static.converter import Converter
from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr from ..auto_parallel.static.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from .pass_base import PassBase, register_pass from .pass_base import PassBase, register_pass
TRANSFORM_PASS_OP_TYPES = list( TRANSFORM_PASS_OP_TYPES = list(
......
...@@ -26,8 +26,8 @@ from paddle.fluid.backward import ( ...@@ -26,8 +26,8 @@ from paddle.fluid.backward import (
from paddle.framework import core from paddle.framework import core
from paddle.utils import unique_name from paddle.utils import unique_name
from ..auto_parallel.dist_attribute import OperatorDistAttr from ..auto_parallel.static.dist_attribute import OperatorDistAttr
from ..auto_parallel.utils import ( from ..auto_parallel.static.utils import (
get_loss_op, get_loss_op,
insert_dependencies_for_two_ops, insert_dependencies_for_two_ops,
is_backward_op, is_backward_op,
......
...@@ -16,13 +16,15 @@ import logging ...@@ -16,13 +16,15 @@ import logging
from functools import reduce from functools import reduce
import paddle import paddle
from paddle.distributed.auto_parallel.operators.common import ( from paddle.distributed.auto_parallel.static.operators.common import (
ParallelMode, ParallelMode,
is_data_parallel_reduce_op, is_data_parallel_reduce_op,
is_parameter_related, is_parameter_related,
) )
from paddle.distributed.auto_parallel.process_group import new_process_group from paddle.distributed.auto_parallel.static.process_group import (
from paddle.distributed.auto_parallel.utils import ( new_process_group,
)
from paddle.distributed.auto_parallel.static.utils import (
_get_comm_group, _get_comm_group,
get_logger, get_logger,
get_var_numel, get_var_numel,
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.distributed.auto_parallel.operators.common import ( from paddle.distributed.auto_parallel.static.operators.common import (
is_amp_flag_sync_op, is_amp_flag_sync_op,
is_data_parallel_reduce_op, is_data_parallel_reduce_op,
is_global_norm_sync_op, is_global_norm_sync_op,
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.static.utils import (
OpRole, OpRole,
insert_dependencies_for_vars, insert_dependencies_for_vars,
) )
......
...@@ -1439,7 +1439,7 @@ def _append_backward_ops_( ...@@ -1439,7 +1439,7 @@ def _append_backward_ops_(
) )
else: else:
default_ctx = getattr( default_ctx = getattr(
paddle.distributed.auto_parallel.dist_context, paddle.distributed.auto_parallel.static.dist_context,
'_g_default_distributed_context', '_g_default_distributed_context',
None, None,
) )
......
...@@ -1681,7 +1681,7 @@ class Variable(metaclass=VariableMetaClass): ...@@ -1681,7 +1681,7 @@ class Variable(metaclass=VariableMetaClass):
if self.persistable: if self.persistable:
var_str = "persist " + var_str var_str = "persist " + var_str
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context, get_default_distributed_context,
) )
...@@ -3137,7 +3137,7 @@ class Operator: ...@@ -3137,7 +3137,7 @@ class Operator:
if i != len(attr_names) - 1: if i != len(attr_names) - 1:
attrs_str += ", " attrs_str += ", "
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
get_default_distributed_context, get_default_distributed_context,
) )
......
...@@ -22,10 +22,10 @@ import paddle ...@@ -22,10 +22,10 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, utils from paddle import nn, static, utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.dist_context import ( from paddle.distributed.auto_parallel.static.dist_context import (
set_default_distributed_context, set_default_distributed_context,
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.static.utils import (
get_dist_attr, get_dist_attr,
load_checkpoint_into_program, load_checkpoint_into_program,
load_distributed_checkpoint, load_distributed_checkpoint,
......
...@@ -23,7 +23,7 @@ import paddle ...@@ -23,7 +23,7 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, utils from paddle import nn, static, utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.static.utils import (
load_checkpoint_into_program, load_checkpoint_into_program,
save_distributed_checkpoint, save_distributed_checkpoint,
) )
......
...@@ -25,7 +25,7 @@ import numpy as np ...@@ -25,7 +25,7 @@ import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel import engine from paddle.distributed.auto_parallel.static import engine
from paddle.distributed.fleet.layers.mpu.mp_layers import ( from paddle.distributed.fleet.layers.mpu.mp_layers import (
ColumnParallelLinear, ColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
......
...@@ -17,7 +17,7 @@ import os ...@@ -17,7 +17,7 @@ import os
import tempfile import tempfile
import unittest import unittest
from paddle.distributed.auto_parallel.cluster import ( from paddle.distributed.auto_parallel.static.cluster import (
Cluster, Cluster,
DeviceType, DeviceType,
LinkType, LinkType,
......
...@@ -18,8 +18,10 @@ import unittest.mock ...@@ -18,8 +18,10 @@ import unittest.mock
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, tensor, utils from paddle import nn, static, tensor, utils
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
paddle.enable_static() paddle.enable_static()
...@@ -188,7 +190,7 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -188,7 +190,7 @@ class TestMLPAutoCompletion(unittest.TestCase):
# # dist_context) # # dist_context)
# dist_context.finalize_distributed_attr_for_program( # dist_context.finalize_distributed_attr_for_program(
# complete_train_program) # complete_train_program)
# from paddle.distributed.auto_parallel.interface import _g_process_mesh_map # from paddle.distributed.auto_parallel.static.interface import _g_process_mesh_map
# for block in complete_train_program.blocks: # for block in complete_train_program.blocks:
# for tensor in block.vars.values(): # for tensor in block.vars.values():
# desc = tensor.desc # desc = tensor.desc
......
...@@ -18,8 +18,10 @@ import unittest ...@@ -18,8 +18,10 @@ import unittest
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, tensor, utils from paddle import nn, static, tensor, utils
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid import layers from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list from paddle.nn.layer.transformer import _convert_param_attr_to_list
......
...@@ -18,12 +18,16 @@ import paddle ...@@ -18,12 +18,16 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, utils from paddle import nn, static, utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.cost_model import estimate_cost from paddle.distributed.auto_parallel.static.cost_model import estimate_cost
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer DistributedContext,
from paddle.distributed.auto_parallel.partitioner import Partitioner )
from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid import core from paddle.fluid import core
......
...@@ -20,12 +20,20 @@ from test_auto_parallel_reshard import mlp_forward ...@@ -20,12 +20,20 @@ from test_auto_parallel_reshard import mlp_forward
import paddle import paddle
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistAttr from paddle.distributed.auto_parallel.static.dist_attribute import (
from paddle.distributed.auto_parallel.dist_context import DistributedContext TensorDistAttr,
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor )
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
)
from paddle.distributed.auto_parallel.static.dist_tensor import (
DistributedTensor,
)
from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
from paddle.distributed.auto_parallel.graph import Graph from paddle.distributed.auto_parallel.static.graph import Graph
class TestAutoParallelGraph(unittest.TestCase): class TestAutoParallelGraph(unittest.TestCase):
......
...@@ -23,17 +23,21 @@ import paddle ...@@ -23,17 +23,21 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import fluid, nn, static, utils from paddle import fluid, nn, static, utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.mapper import ( DistributedContext,
)
from paddle.distributed.auto_parallel.static.mapper import (
get_comm_volume, get_comm_volume,
get_dtype_bytes, get_dtype_bytes,
mapping, mapping,
) )
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.static.parallelizer import (
from paddle.distributed.auto_parallel.partitioner import Partitioner AutoParallelizer,
from paddle.distributed.auto_parallel.reshard import Resharder )
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid import core from paddle.fluid import core
......
...@@ -19,11 +19,15 @@ import paddle ...@@ -19,11 +19,15 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, tensor, utils from paddle import nn, static, tensor, utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.partitioner import Partitioner DistributedContext,
from paddle.distributed.auto_parallel.process_group import new_process_group )
from paddle.distributed.auto_parallel.utils import _get_comm_group from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.process_group import (
new_process_group,
)
from paddle.distributed.auto_parallel.static.utils import _get_comm_group
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
paddle.enable_static() paddle.enable_static()
......
...@@ -18,11 +18,15 @@ import unittest ...@@ -18,11 +18,15 @@ import unittest
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, tensor, utils from paddle import nn, static, tensor, utils
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.static.parallelizer import (
from paddle.distributed.auto_parallel.partitioner import Partitioner AutoParallelizer,
from paddle.distributed.auto_parallel.process_group import new_process_group )
from paddle.distributed.auto_parallel.utils import _get_comm_group from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.process_group import (
new_process_group,
)
from paddle.distributed.auto_parallel.static.utils import _get_comm_group
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.fluid import layers from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list from paddle.nn.layer.transformer import _convert_param_attr_to_list
......
...@@ -18,15 +18,19 @@ import paddle ...@@ -18,15 +18,19 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle import nn, static, utils from paddle import nn, static, utils
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.static.dist_context import (
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer DistributedContext,
from paddle.distributed.auto_parallel.partitioner import Partitioner )
from paddle.distributed.auto_parallel.process_group import ( from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.process_group import (
ProcessGroup, ProcessGroup,
_g_process_group_map, _g_process_group_map,
) )
from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
paddle.enable_static() paddle.enable_static()
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册