提交 6efb07fe 编写于 作者: M mir-of

fix flow.function in of_cnn_infer_benchmarks.py

上级 bfc921b0
......@@ -114,24 +114,28 @@ model_dict = {
"alexnet": alexnet_model.alexnet,
}
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
flow.config.gpu_device_num(args.gpu_num_per_node)
if args.use_tensorrt:
func_config.use_tensorrt()
if args.use_xla_jit:
func_config.use_xla_jit()
if args.precision == "float16":
if not args.use_tensorrt:
func_config.enable_auto_mixed_precision()
else:
func_config.tensorrt.use_fp16()
@flow.function
@flow.function(func_config)
def InferenceNet():
total_device_num = args.node_num * args.gpu_num_per_node
batch_size = total_device_num * args.batch_size_per_device
if args.use_tensorrt:
flow.config.use_tensorrt()
if args.use_xla_jit:
flow.config.use_xla_jit()
if args.precision == "float16":
if not args.use_tensorrt:
flow.config.enable_auto_mixed_precision()
else:
flow.config.tensorrt.use_fp16()
if args.data_dir:
assert os.path.exists(args.data_dir)
print("Loading data from {}".format(args.data_dir))
......@@ -159,12 +163,9 @@ def main():
print("{} = {}".format(arg, getattr(args, arg)))
print("-".ljust(66, "-"))
print("Time stamp: {}".format(str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))))
flow.config.default_data_type(flow.float)
flow.config.gpu_device_num(args.gpu_num_per_node)
flow.env.grpc_use_no_signal()
flow.env.log_dir(args.log_dir)
# flow.config.enable_inplace(False)
# flow.config.ctrl_port(12140)
if args.node_num > 1:
nodes = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册