未验证 提交 ecefd4e3 编写于 作者: R Roc 提交者: GitHub

Replace built-in print with logger in distributed_strategy.py (#47761)

上级 267b218f
......@@ -14,6 +14,7 @@
# limitations under the License.
import paddle
from paddle.distributed.fleet.utils.log_util import logger
from paddle.distributed.fleet.proto import distributed_strategy_pb2
from paddle.fluid.framework import _global_flags
from paddle.fluid.wrapped_decorator import wrap_decorator
......@@ -141,6 +142,7 @@ class DistributedStrategy:
self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])
self.__lock_attr = True
logger.info("distributed strategy initialized")
def __setattr__(self, key, value):
if self.__lock_attr and not hasattr(self, key):
......@@ -503,12 +505,12 @@ class DistributedStrategy:
for field in msg.DESCRIPTOR.fields:
name = config_name + "." + field.name
if field.type == FieldDescriptor.TYPE_MESSAGE:
# print("message:", name)
logger.debug(f"message: {name}")
if field.label == FieldDescriptor.LABEL_REPEATED:
if name + ".num" not in configs:
continue
num = configs[name + ".num"]
# print("message num:", name, num)
logger.debug(f"message num: {name} {num}")
for i in range(num):
data = getattr(msg, field.name).add()
set_table_config(data, name, configs, i)
......@@ -517,7 +519,7 @@ class DistributedStrategy:
getattr(msg, field.name), name, configs
)
else:
# print("not message:", name)
logger.debug("not message:", name)
if name not in configs:
continue
if field.label == FieldDescriptor.LABEL_REPEATED:
......@@ -529,7 +531,7 @@ class DistributedStrategy:
setattr(msg, field.name, configs[name])
if not configs:
print("table configs is empty")
logger.info("table configs is empty")
else:
for table_name in configs:
table_data = table_param.add()
......@@ -814,7 +816,7 @@ class DistributedStrategy:
add_graph_config(table_data.accessor.graph_sgd_param, config)
if not configs:
print("fleet desc config is empty")
logger.info("fleet desc config is empty")
else:
for table_name in configs:
if (
......@@ -851,7 +853,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.amp = flag
else:
print("WARNING: amp should have value of bool type")
logger.warning("amp should have value of bool type")
@property
def amp_configs(self):
......@@ -938,7 +940,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.asp = flag
else:
print("WARNING: asp should have value of bool type")
logger.warning("asp should have value of bool type")
@property
def recompute(self):
......@@ -980,7 +982,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.sync_nccl_allreduce = flag
else:
print("WARNING: sync_nccl_allreduce should have value of bool type")
logger.warning("sync_nccl_allreduce should have value of bool type")
@property
def use_hierarchical_allreduce(self):
......@@ -1005,8 +1007,8 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.use_hierarchical_allreduce = flag
else:
print(
"WARNING: use_hierarchical_allreduce should have value of bool type"
logger.warning(
"use_hierarchical_allreduce should have value of bool type"
)
@property
......@@ -1031,8 +1033,8 @@ class DistributedStrategy:
if isinstance(value, int):
self.strategy.hierarchical_allreduce_inter_nranks = value
else:
print(
"WARNING: hierarchical_allreduce_inter_nranks should have value of int type"
logger.warning(
"hierarchical_allreduce_inter_nranks should have value of int type"
)
@property
......@@ -1059,7 +1061,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.sync_batch_norm = flag
else:
print("WARNING: sync_batch_norm should have value of bool type")
logger.warning("sync_batch_norm should have value of bool type")
@property
def fuse_all_reduce_ops(self):
......@@ -1083,7 +1085,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.fuse_all_reduce_ops = flag
else:
print("WARNING: fuse_all_reduce_ops should have value of bool type")
logger.warning("fuse_all_reduce_ops should have value of bool type")
@property
def fuse_grad_size_in_MB(self):
......@@ -1108,7 +1110,7 @@ class DistributedStrategy:
if isinstance(value, int):
self.strategy.fuse_grad_size_in_MB = value
else:
print("WARNING: fuse_grad_size_in_MB should have value of int type")
logger.warning("fuse_grad_size_in_MB should have value of int type")
@property
def last_comm_group_size_MB(self):
......@@ -1161,8 +1163,8 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.find_unused_parameters = flag
else:
print(
"WARNING: find_unused_parameters should have value of bool type"
logger.warning(
"find_unused_parameters should have value of bool type"
)
@property
......@@ -1175,8 +1177,8 @@ class DistributedStrategy:
if isinstance(value, float):
self.strategy.fuse_grad_size_in_TFLOPS = value
else:
print(
"WARNING: fuse_grad_size_in_TFLOPS should have value of float type"
logger.warning(
"fuse_grad_size_in_TFLOPS should have value of float type"
)
@property
......@@ -1203,7 +1205,7 @@ class DistributedStrategy:
if isinstance(value, int):
self.strategy.nccl_comm_num = value
else:
print("WARNING: nccl_comm_num should have value of int type")
logger.warning("nccl_comm_num should have value of int type")
@recompute.setter
@is_strict_auto
......@@ -1211,7 +1213,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.recompute = flag
else:
print("WARNING: recompute should have value of bool type")
logger.warning("recompute should have value of bool type")
@property
def recompute_configs(self):
......@@ -1282,7 +1284,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.sharding = flag
else:
print("WARNING: sharding should have value of bool type")
logger.warning("sharding should have value of bool type")
@property
def sharding_configs(self):
......@@ -1371,8 +1373,8 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.without_graph_optimization = flag
else:
print(
"WARNING: without_graph_optimization should have value of bool type"
logger.warning(
"without_graph_optimization should have value of bool type"
)
@property
......@@ -1395,8 +1397,8 @@ class DistributedStrategy:
if isinstance(same, bool):
self.strategy.calc_comm_same_stream = same
else:
print(
"WARNING: calc_comm_same_stream should have value of boolean type"
logger.warning(
"calc_comm_same_stream should have value of boolean type"
)
@property
......@@ -1419,7 +1421,7 @@ class DistributedStrategy:
if isinstance(fuse_grad_merge, bool):
self.strategy.fuse_grad_merge = fuse_grad_merge
else:
print("WARNING: fuse_grad_merge should have value of boolean type")
logger.warning("fuse_grad_merge should have value of boolean type")
@property
def fuse_grad_size_in_num(self):
......@@ -1439,8 +1441,8 @@ class DistributedStrategy:
if isinstance(num, int):
self.strategy.fuse_grad_size_in_num = num
else:
print(
"WARNING: fuse_grad_size_in_num should have value of int32 type"
logger.warning(
"fuse_grad_size_in_num should have value of int32 type"
)
@property
......@@ -1472,7 +1474,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.is_fl_ps_mode = flag
else:
print("WARNING: is_fl_ps_mode should have value of bool type")
logger.warning("is_fl_ps_mode should have value of bool type")
@property
def is_with_coordinator(self):
......@@ -1484,7 +1486,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.with_coordinator = flag
else:
print("WARNING: with_coordinator should have value of bool type")
logger.warning("with_coordinator should have value of bool type")
@pipeline.setter
@is_strict_auto
......@@ -1492,7 +1494,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.pipeline = flag
else:
print("WARNING: pipeline should have value of bool type")
logger.warning("pipeline should have value of bool type")
@property
def pipeline_configs(self):
......@@ -1554,7 +1556,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.tensor_parallel = flag
else:
print("WARNING: tensor_parallel should have value of bool type")
logger.warning("tensor_parallel should have value of bool type")
@property
def tensor_parallel_configs(self):
......@@ -1650,7 +1652,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.localsgd = flag
else:
print("WARNING: localsgd should have value of bool type")
logger.warning("localsgd should have value of bool type")
@property
def localsgd_configs(self):
......@@ -1708,7 +1710,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.adaptive_localsgd = flag
else:
print("WARNING: adaptive_localsgd should have value of bool type")
logger.warning("adaptive_localsgd should have value of bool type")
@property
def adaptive_localsgd_configs(self):
......@@ -1770,7 +1772,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.dgc = flag
else:
print("WARNING: dgc should have value of bool type")
logger.warning("dgc should have value of bool type")
@property
def dgc_configs(self):
......@@ -1860,7 +1862,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.gradient_merge = flag
else:
print("WARNING: gradient_merge should have value of bool type")
logger.warning("gradient_merge should have value of bool type")
@property
def gradient_merge_configs(self):
......@@ -1916,7 +1918,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.lars = flag
else:
print("WARNING: lars should have value of bool type")
logger.warning("lars should have value of bool type")
@property
def lars_configs(self):
......@@ -1980,7 +1982,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.lamb = flag
else:
print("WARNING: lamb should have value of bool type")
logger.warning("lamb should have value of bool type")
@property
def lamb_configs(self):
......@@ -2026,7 +2028,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.elastic = flag
else:
print("WARNING: elastic should have value of bool type")
logger.warning("elastic should have value of bool type")
@property
def auto(self):
......@@ -2061,7 +2063,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.auto = flag
else:
print("WARNING: auto should have value of bool type")
logger.warning("auto should have value of bool type")
@property
def semi_auto(self):
......@@ -2096,7 +2098,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.semi_auto = flag
else:
print("WARNING: semi-auto should have value of bool type")
logger.warning("semi-auto should have value of bool type")
@property
def auto_search(self):
......@@ -2119,7 +2121,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.auto_search = flag
else:
print("WARNING: auto-search should have value of bool type")
logger.warning("auto-search should have value of bool type")
@property
def split_data(self):
......@@ -2141,7 +2143,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.split_data = flag
else:
print("WARNING: split_data should have value of bool type")
logger.warning("split_data should have value of bool type")
@property
def qat(self):
......@@ -2156,7 +2158,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.qat = flag
else:
print("WARNING: qat should have value of bool type")
logger.warning("qat should have value of bool type")
@property
def qat_configs(self):
......@@ -2226,7 +2228,7 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.heter_ccl_mode = flag
else:
print("WARNING: heter_ccl_mode should have value of bool type")
logger.warning("heter_ccl_mode should have value of bool type")
@property
def cudnn_exhaustive_search(self):
......@@ -2258,8 +2260,8 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.cudnn_exhaustive_search = flag
else:
print(
"WARNING: cudnn_exhaustive_search should have value of bool type"
logger.warning(
"cudnn_exhaustive_search should have value of bool type"
)
@property
......@@ -2293,8 +2295,8 @@ class DistributedStrategy:
if isinstance(value, int):
self.strategy.conv_workspace_size_limit = value
else:
print(
"WARNING: conv_workspace_size_limit should have value of int type"
logger.warning(
"conv_workspace_size_limit should have value of int type"
)
@property
......@@ -2326,8 +2328,8 @@ class DistributedStrategy:
if isinstance(flag, bool):
self.strategy.cudnn_batchnorm_spatial_persistent = flag
else:
print(
"WARNING: cudnn_batchnorm_spatial_persistent should have value of bool type"
logger.warning(
"cudnn_batchnorm_spatial_persistent should have value of bool type"
)
def _enable_env(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册