Unverified · Commit 4d3e4184 authored by zhouzj, committed by GitHub

Solve the bug of distributed training. (#1122)

Parent: f9b8dce4
@@ -26,6 +26,11 @@ def predict_compressed_model(model_dir,
     Returns:
         latency_dict(dict): The latency of the model under various compression strategies.
     """
+    local_rank = paddle.distributed.get_rank()
+    quant_model_path = f'quant_model/rank_{local_rank}'
+    prune_model_path = f'prune_model/rank_{local_rank}'
+    sparse_model_path = f'sparse_model/rank_{local_rank}'
     latency_dict = {}
     model_file = os.path.join(model_dir, model_filename)
@@ -43,13 +48,13 @@ def predict_compressed_model(model_dir,
         model_dir=model_dir,
         model_filename=model_filename,
         params_filename=params_filename,
-        save_model_path='quant_model',
+        save_model_path=quant_model_path,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)
     latency = predictor.predict(
         model_file=quant_model_file,
@@ -62,9 +67,9 @@ def predict_compressed_model(model_dir,
         model_file=model_file,
         param_file=param_file,
         ratio=prune_ratio,
-        save_path='prune_model')
-    prune_model_file = os.path.join('prune_model', model_filename)
-    prune_param_file = os.path.join('prune_model', params_filename)
+        save_path=prune_model_path)
+    prune_model_file = os.path.join(prune_model_path, model_filename)
+    prune_param_file = os.path.join(prune_model_path, params_filename)
     latency = predictor.predict(
         model_file=prune_model_file,
@@ -74,16 +79,16 @@ def predict_compressed_model(model_dir,
     post_quant_fake(
         exe,
-        model_dir='prune_model',
+        model_dir=prune_model_path,
         model_filename=model_filename,
         params_filename=params_filename,
-        save_model_path='quant_model',
+        save_model_path=quant_model_path,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)
     latency = predictor.predict(
         model_file=quant_model_file,
@@ -96,9 +101,9 @@ def predict_compressed_model(model_dir,
         model_file=model_file,
         param_file=param_file,
         ratio=sparse_ratio,
-        save_path='sparse_model')
-    sparse_model_file = os.path.join('sparse_model', model_filename)
-    sparse_param_file = os.path.join('sparse_model', params_filename)
+        save_path=sparse_model_path)
+    sparse_model_file = os.path.join(sparse_model_path, model_filename)
+    sparse_param_file = os.path.join(sparse_model_path, params_filename)
     latency = predictor.predict(
         model_file=sparse_model_file,
@@ -108,7 +113,7 @@ def predict_compressed_model(model_dir,
     post_quant_fake(
         exe,
-        model_dir='sparse_model',
+        model_dir=sparse_model_path,
         model_filename=model_filename,
         params_filename=params_filename,
         save_model_path='quant_model',
@@ -116,8 +121,8 @@ def predict_compressed_model(model_dir,
         is_full_quantize=False,
         activation_bits=8,
         weight_bits=8)
-    quant_model_file = os.path.join('quant_model', model_filename)
-    quant_param_file = os.path.join('quant_model', params_filename)
+    quant_model_file = os.path.join(quant_model_path, model_filename)
+    quant_param_file = os.path.join(quant_model_path, params_filename)
     latency = predictor.predict(
         model_file=quant_model_file,
@@ -125,8 +130,11 @@ def predict_compressed_model(model_dir,
         data_type='int8')
     latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})
-    # Delete temporary model files
-    shutil.rmtree('./quant_model')
-    shutil.rmtree('./prune_model')
-    shutil.rmtree('./sparse_model')
+    # NOTE: Delete temporary model files
+    if os.path.exists('quant_model'):
+        shutil.rmtree('quant_model', ignore_errors=True)
+    if os.path.exists('prune_model'):
+        shutil.rmtree('prune_model', ignore_errors=True)
+    if os.path.exists('sparse_model'):
+        shutil.rmtree('sparse_model', ignore_errors=True)
     return latency_dict
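
For context, the following is a minimal, self-contained sketch of the per-rank temporary-directory pattern this commit introduces. It is illustrative only: save_dummy_model and the placeholder file name are invented for the example; only paddle.distributed.get_rank(), the rank_{local_rank} path scheme, and the guarded shutil.rmtree cleanup come from the diff above.

import os
import shutil

import paddle


def save_dummy_model(save_dir):
    # Stand-in for the real export step (post_quant_fake / pruning / sparsity);
    # it only writes a placeholder file so the example runs anywhere.
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'model.pdmodel'), 'w') as f:
        f.write('placeholder')


def run():
    # Each process sees its own rank (0 when not launched in distributed mode).
    # Without per-rank paths, every rank would export into the same
    # 'quant_model' directory and race with the cleanup at the end of
    # predict_compressed_model.
    local_rank = paddle.distributed.get_rank()
    quant_model_path = f'quant_model/rank_{local_rank}'

    save_dummy_model(quant_model_path)
    print(f'rank {local_rank} wrote {quant_model_path}')

    # Idempotent cleanup, mirroring the commit: guard with exists() and pass
    # ignore_errors so a directory already removed by another rank does not raise.
    if os.path.exists('quant_model'):
        shutil.rmtree('quant_model', ignore_errors=True)


if __name__ == '__main__':
    run()

Run as a single process the sketch simply uses rank 0; to exercise several ranks it could be started through PaddlePaddle's distributed launcher, for example python -m paddle.distributed.launch --gpus "0,1" sketch.py, where sketch.py is whatever file the snippet is saved as.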