未验证 提交 f7d08b7d 编写于 作者: D Dong Daxiang 提交者: GitHub

【paddle.fleet】refine launch and distributed repr string for print (#27093)

* refine launch and distributed repr string for print
上级 43b0445b
......@@ -17,6 +17,7 @@ from paddle.distributed.fleet.proto import distributed_strategy_pb2
from paddle.fluid.framework import Variable, set_flags, core
from paddle.fluid.wrapped_decorator import wrap_decorator
import google.protobuf.text_format
import google.protobuf
__all__ = ["DistributedStrategy"]
......@@ -1133,7 +1134,91 @@ class DistributedStrategy(object):
return False
def __repr__(self):
spacing = 2
max_k = 38
max_v = 38
length = max_k + max_v + spacing
h1_format = " " + "|{{:^{}s}}|\n".format(length)
h2_format = " " + "|{{:>{}s}}{}{{:^{}s}}|\n".format(max_k, " " *
spacing, max_v)
border = " +" + "".join(["="] * length) + "+"
line = " +" + "".join(["-"] * length) + "+"
draws = border + "\n"
draws += h1_format.format("")
draws += h1_format.format("DistributedStrategy Overview")
draws += h1_format.format("")
fields = self.strategy.DESCRIPTOR.fields
str_res = ""
env_draws = line + "\n"
for f in fields:
if "build_strategy" in f.name or "execution_strategy" in f.name:
continue
if "_configs" in f.name:
continue
else:
if isinstance(getattr(self.strategy, f.name), bool):
if hasattr(self.strategy, f.name + "_configs"):
if getattr(self.strategy, f.name):
draws += border + "\n"
draws += h1_format.format(
"{} = True, please check {}_configs".format(
f.name, f.name))
draws += line + "\n"
my_configs = getattr(self.strategy,
f.name + "_configs")
config_fields = my_configs.DESCRIPTOR.fields
for ff in config_fields:
if isinstance(
getattr(my_configs, ff.name),
google.protobuf.pyext._message.
RepeatedScalarContainer):
values = getattr(my_configs, ff.name)
for i, v in enumerate(values):
if i == 0:
draws += h2_format.format(ff.name,
str(v))
else:
draws += h2_format.format("",
str(v))
else:
draws += h2_format.format(
ff.name,
str(getattr(my_configs, ff.name)))
else:
env_draws += h2_format.format(
f.name, str(getattr(self.strategy, f.name)))
else:
env_draws += h2_format.format(
f.name, str(getattr(self.strategy, f.name)))
result_res = draws + border + "\n" + h1_format.format(
"Environment Flags, Communication Flags")
result_res += env_draws
build_strategy_str = border + "\n"
build_strategy_str += h1_format.format("Build Strategy")
build_strategy_str += line + "\n"
fields = self.strategy.build_strategy.DESCRIPTOR.fields
for f in fields:
print("{}: {}".format(f.name, f.default_value))
return str(self.strategy)
build_strategy_str += h2_format.format(
f.name, str(getattr(self.strategy.build_strategy, f.name)))
build_strategy_str += border + "\n"
execution_strategy_str = h1_format.format("Execution Strategy")
execution_strategy_str += line + "\n"
fields = self.strategy.execution_strategy.DESCRIPTOR.fields
for f in fields:
execution_strategy_str += h2_format.format(
f.name, str(getattr(self.strategy.execution_strategy, f.name)))
execution_strategy_str += border + "\n"
result_res += build_strategy_str + execution_strategy_str
return result_res
......@@ -347,12 +347,13 @@ def pretty_print_envs(envs, header=None):
for k, v in envs.items():
max_k = max(max_k, len(k))
h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v)
l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v)
h_format = " " + "|{{:>{}s}}{}{{:^{}s}}|\n".format(max_k, " " * spacing,
max_v)
l_format = " " + "|{{:>{}s}}{{}}{{:^{}s}}|\n".format(max_k, max_v)
length = max_k + max_v + spacing
border = "".join(["="] * length)
line = "".join(["-"] * length)
border = " +" + "".join(["="] * length) + "+"
line = " +" + "".join(["-"] * length) + "+"
draws = ""
draws += border + "\n"
......
......@@ -47,6 +47,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_dgc_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
......@@ -461,6 +462,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS})
py_test_modules(test_fleet_distributed_strategy MODULES test_fleet_distributed_strategy)
py_test_modules(test_fleet_auto MODULES test_fleet_auto ENVS ${dist_ENVS})
if(NOT WIN32)
py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
......
......@@ -316,6 +316,14 @@ class TestStrategyConfig(unittest.TestCase):
self.assertEqual(strategy.conv_workspace_size_limit, 1000)
strategy._enable_env()
def test_distributed_strategy_repr(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.recompute = True
strategy.recompute_configs = {"checkpoints": ["a1", "a2", "a3"]}
strategy.amp = True
strategy.localsgd = True
print(str(strategy))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册