未验证 提交 291c55a2 编写于 作者: W wangxiaoning 提交者: GitHub

[fluid clean]clean fluid.distribute_lookup_table (#50350)

* fluid clean

* fix optimizer

* fix distributed_transpiler

* fix fluid.__init__

* remove from fluid.init
上级 bf80664c
...@@ -67,7 +67,6 @@ from . import metrics ...@@ -67,7 +67,6 @@ from . import metrics
from . import transpiler from . import transpiler
from . import incubate from . import incubate
from .input import embedding, one_hot from .input import embedding, one_hot
from . import distribute_lookup_table
from .param_attr import ParamAttr, WeightNormParamAttr from .param_attr import ParamAttr, WeightNormParamAttr
from .data_feeder import DataFeeder from .data_feeder import DataFeeder
......
...@@ -16,11 +16,13 @@ ...@@ -16,11 +16,13 @@
__all__ = ["DistributedAdam", "FLEET_GLOBAL_DICT"] __all__ = ["DistributedAdam", "FLEET_GLOBAL_DICT"]
import paddle import paddle
from paddle.framework import core from paddle.framework import core
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table from paddle.distributed.distribute_lookup_table import (
from paddle.fluid.distribute_lookup_table import ( find_distributed_lookup_table,
)
from paddle.distributed.distribute_lookup_table import (
find_distributed_lookup_table_inputs, find_distributed_lookup_table_inputs,
) )
from paddle.fluid.distribute_lookup_table import ( from paddle.distributed.distribute_lookup_table import (
find_distributed_lookup_table_outputs, find_distributed_lookup_table_outputs,
) )
from google.protobuf import text_format from google.protobuf import text_format
......
...@@ -18,7 +18,7 @@ import logging ...@@ -18,7 +18,7 @@ import logging
from collections import defaultdict from collections import defaultdict
import paddle import paddle
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.framework import ( from paddle.fluid.framework import (
Program, Program,
Variable, Variable,
...@@ -944,6 +944,10 @@ class Optimizer: ...@@ -944,6 +944,10 @@ class Optimizer:
:param loss: the loss variable. :param loss: the loss variable.
:param startup_program: the startup program :param startup_program: the startup program
""" """
from paddle.distributed.distribute_lookup_table import (
find_distributed_lookup_table,
)
program = framework.default_main_program() program = framework.default_main_program()
global_block = framework.default_main_program().global_block() global_block = framework.default_main_program().global_block()
table_name = find_distributed_lookup_table(program) table_name = find_distributed_lookup_table(program)
......
...@@ -50,7 +50,6 @@ from ..framework import ( ...@@ -50,7 +50,6 @@ from ..framework import (
) )
from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed
from .details import delete_ops, find_op_by_output_arg from .details import delete_ops, find_op_by_output_arg
from ..distribute_lookup_table import find_distributed_lookup_table
from . import collective from . import collective
LOOKUP_TABLE_TYPE = ["lookup_table", "lookup_table_v2"] LOOKUP_TABLE_TYPE = ["lookup_table", "lookup_table_v2"]
...@@ -612,6 +611,9 @@ class DistributeTranspiler: ...@@ -612,6 +611,9 @@ class DistributeTranspiler:
sync_mode=False, sync_mode=False,
current_endpoint="127.0.0.1:7000") current_endpoint="127.0.0.1:7000")
""" """
from paddle.distributed.distribute_lookup_table import (
find_distributed_lookup_table,
)
err_msg = """ err_msg = """
......
...@@ -39,7 +39,6 @@ from ..framework import ( ...@@ -39,7 +39,6 @@ from ..framework import (
) )
from .details import wait_server_ready, VarsDistributed from .details import wait_server_ready, VarsDistributed
from .details import delete_ops from .details import delete_ops
from ..distribute_lookup_table import find_distributed_lookup_table
from .distribute_transpiler import ( from .distribute_transpiler import (
DistributeTranspiler, DistributeTranspiler,
DistributeTranspilerConfig, DistributeTranspilerConfig,
...@@ -48,6 +47,9 @@ from .distribute_transpiler import ( ...@@ -48,6 +47,9 @@ from .distribute_transpiler import (
ServerRuntimeConfig, ServerRuntimeConfig,
) )
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
from paddle.distributed.distribute_lookup_table import (
find_distributed_lookup_table,
)
RPC_OP_ROLE_ATTR_NAME = ( RPC_OP_ROLE_ATTR_NAME = (
op_role_attr_name op_role_attr_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册