Unverified commit 4d3e4184 authored by zhouzj, committed by GitHub

Fix a bug in distributed training. (#1122)

Parent f9b8dce4
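The change routes every temporary model export through a rank-specific subdirectory, so workers started together (e.g. via `paddle.distributed.launch`) no longer overwrite or delete one another's intermediate quantized, pruned, and sparse models. Below is a minimal sketch of the pattern the patch introduces; `paddle.distributed.get_rank()` is the real Paddle API and returns 0 in single-process runs, while the `os.makedirs` call is only illustrative (the export helpers in the patched code write the directories themselves):

```python
import os

import paddle

# Derive a per-process scratch directory from the distributed rank so
# that concurrently launched ranks never share temporary model paths.
# get_rank() returns 0 when the process is not part of a parallel job.
local_rank = paddle.distributed.get_rank()
quant_model_path = f'quant_model/rank_{local_rank}'

os.makedirs(quant_model_path, exist_ok=True)  # illustrative only
print(quant_model_path)  # rank 0 prints: quant_model/rank_0
```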
@@ -26,6 +26,11 @@ def predict_compressed_model(model_dir,
     Returns:
         latency_dict(dict): The latency of the model under various compression strategies.
     """
+    local_rank = paddle.distributed.get_rank()
+    quant_model_path = f'quant_model/rank_{local_rank}'
+    prune_model_path = f'prune_model/rank_{local_rank}'
+    sparse_model_path = f'sparse_model/rank_{local_rank}'
     latency_dict = {}
     model_file = os.path.join(model_dir, model_filename)
@@ -43,13 +48,13 @@ def predict_compressed_model(model_dir,
         model_dir=model_dir,
         model_filename=model_filename,
         params_filename=params_filename,
-        save_model_path='quant_model',
+        save_model_path=quant_model_path,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)
     latency = predictor.predict(
         model_file=quant_model_file,
@@ -62,9 +67,9 @@ def predict_compressed_model(model_dir,
         model_file=model_file,
         param_file=param_file,
         ratio=prune_ratio,
-        save_path='prune_model')
-    prune_model_file = os.path.join('prune_model', model_filename)
-    prune_param_file = os.path.join('prune_model', params_filename)
+        save_path=prune_model_path)
+    prune_model_file = os.path.join(prune_model_path, model_filename)
+    prune_param_file = os.path.join(prune_model_path, params_filename)
     latency = predictor.predict(
         model_file=prune_model_file,
@@ -74,16 +79,16 @@ def predict_compressed_model(model_dir,
     post_quant_fake(
         exe,
-        model_dir='prune_model',
+        model_dir=prune_model_path,
         model_filename=model_filename,
         params_filename=params_filename,
-        save_model_path='quant_model',
+        save_model_path=quant_model_path,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)
     latency = predictor.predict(
         model_file=quant_model_file,
@@ -96,9 +101,9 @@ def predict_compressed_model(model_dir,
         model_file=model_file,
         param_file=param_file,
         ratio=sparse_ratio,
-        save_path='sparse_model')
-    sparse_model_file = os.path.join('sparse_model', model_filename)
-    sparse_param_file = os.path.join('sparse_model', params_filename)
+        save_path=sparse_model_path)
+    sparse_model_file = os.path.join(sparse_model_path, model_filename)
+    sparse_param_file = os.path.join(sparse_model_path, params_filename)
     latency = predictor.predict(
         model_file=sparse_model_file,
@@ -108,7 +113,7 @@ def predict_compressed_model(model_dir,
     post_quant_fake(
         exe,
-        model_dir='sparse_model',
+        model_dir=sparse_model_path,
         model_filename=model_filename,
         params_filename=params_filename,
         save_model_path='quant_model',
@@ -116,8 +121,8 @@ def predict_compressed_model(model_dir,
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)
     latency = predictor.predict(
         model_file=quant_model_file,
@@ -125,8 +130,11 @@ def predict_compressed_model(model_dir,
         data_type='int8')
     latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})
-    # Delete temporary model files
-    shutil.rmtree('./quant_model')
-    shutil.rmtree('./prune_model')
-    shutil.rmtree('./sparse_model')
+    # NOTE: Delete temporary model files
+    if os.path.exists('quant_model'):
+        shutil.rmtree('quant_model', ignore_errors=True)
+    if os.path.exists('prune_model'):
+        shutil.rmtree('prune_model', ignore_errors=True)
+    if os.path.exists('sparse_model'):
+        shutil.rmtree('sparse_model', ignore_errors=True)
     return latency_dict
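The cleanup at the end changes for the same reason: under a multi-process launch every rank reaches the `shutil.rmtree` calls, and the old unguarded `shutil.rmtree('./quant_model')` presumably raises `FileNotFoundError` in any rank that arrives after a peer has already deleted the shared directory. A standalone sketch of the idempotent-delete pattern the patch adopts (the helper name `cleanup_temp_models` is ours, not part of the codebase):

```python
import os
import shutil

def cleanup_temp_models(dirs=('quant_model', 'prune_model', 'sparse_model')):
    # Safe to run from every rank: a directory already removed by a
    # peer process is skipped, and ignore_errors=True swallows races
    # between the existence check and the recursive delete.
    for path in dirs:
        if os.path.exists(path):
            shutil.rmtree(path, ignore_errors=True)

cleanup_temp_models()
```

Note that the guards target the shared top-level directories (`quant_model` rather than `quant_model/rank_{local_rank}`), so one pass also sweeps away every per-rank subdirectory.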