未验证 提交 fc90903f 编写于 作者: G Guanghua Yu 提交者: GitHub

update tinypose act demo (#1227)

上级 b13ff649
Global: Global:
arch: 'keypoint'
reader_config: configs/tinypose_reader.yml reader_config: configs/tinypose_reader.yml
input_list: ['image'] input_list: ['image']
Evaluation: False Evaluation: True
model_dir: ./tinypose_128x96 model_dir: ./tinypose_128x96
model_filename: model.pdmodel model_filename: model.pdmodel
params_filename: model.pdiparams params_filename: model.pdiparams
...@@ -13,19 +14,21 @@ Distillation: ...@@ -13,19 +14,21 @@ Distillation:
- conv2d_441.tmp_0 - conv2d_441.tmp_0
Quantization: Quantization:
activation_quantize_type: 'range_abs_max' use_pact: true
weight_quantize_type: 'abs_max' activation_quantize_type: 'moving_average_abs_max'
weight_quantize_type: 'channel_wise_abs_max' # 'abs_max' is layer wise quant
quantize_op_types: quantize_op_types:
- conv2d - conv2d
- depthwise_conv2d - depthwise_conv2d
TrainConfig: TrainConfig:
epochs: 1 train_iter: 30000
eval_iter: 1000 eval_iter: 1000
learning_rate: 0.0001 learning_rate:
type: CosineAnnealingDecay
learning_rate: 0.015
T_max: 30000
optimizer_builder: optimizer_builder:
optimizer: optimizer:
type: SGD type: Momentum
weight_decay: 4.0e-05 weight_decay: 0.00002
#origin_metric: 0.291
...@@ -45,4 +45,4 @@ EvalReader: ...@@ -45,4 +45,4 @@ EvalReader:
std: *global_std std: *global_std
is_scale: true is_scale: true
- Permute: {} - Permute: {}
batch_size: 4 batch_size: 16
...@@ -19,8 +19,9 @@ import argparse ...@@ -19,8 +19,9 @@ import argparse
import paddle import paddle
from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from keypoint_utils import keypoint_post_process
def argsparser(): def argsparser():
...@@ -99,12 +100,16 @@ def eval(): ...@@ -99,12 +100,16 @@ def eval():
fetch_list=fetch_targets, fetch_list=fetch_targets,
return_numpy=False) return_numpy=False)
res = {} res = {}
for out in outs: if 'arch' in global_config and global_config['arch'] == 'keypoint':
v = np.array(out) res = keypoint_post_process(data, data_input, exe, val_program,
if len(v.shape) > 1: fetch_targets, outs)
res['bbox'] = v else:
else: for out in outs:
res['bbox_num'] = v v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res) metric.update(data_all, res)
if batch_id % 100 == 0: if batch_id % 100 == 0:
print('Eval iter:', batch_id) print('Eval iter:', batch_id)
...@@ -135,6 +140,10 @@ def main(): ...@@ -135,6 +140,10 @@ def main():
label_list=dataset.get_label_list(), label_list=dataset.get_label_list(),
class_num=reader_cfg['num_classes'], class_num=reader_cfg['num_classes'],
map_type=reader_cfg['map_type']) map_type=reader_cfg['map_type'])
elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
anno_file = dataset.get_anno()
metric = KeyPointTopDownCOCOEval(anno_file,
len(dataset), 17, 'output_eval')
else: else:
raise ValueError("metric currently only supports COCO and VOC.") raise ValueError("metric currently only supports COCO and VOC.")
global_config['metric'] = metric global_config['metric'] = metric
......
...@@ -11,39 +11,30 @@ ...@@ -11,39 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import os
import json
import numpy as np import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from scipy.io import loadmat, savemat
import cv2 import cv2
import copy
from paddleslim.common import get_logger from paddleslim.common import get_logger
logger = get_logger(__name__, level=logging.INFO) logger = get_logger(__name__, level=logging.INFO)
__all__ = ['keypoint_post_process']
def get_affine_mat_kernel(h, w, s, inv=False):
if w < h:
w_ = s
h_ = int(np.ceil((s / w * h) / 64.) * 64)
scale_w = w
scale_h = h_ / w_ * w
else: def flip_back(output_flipped, matched_parts):
h_ = s assert output_flipped.ndim == 4,\
w_ = int(np.ceil((s / h * w) / 64.) * 64) 'output_flipped should be [batch_size, num_joints, height, width]'
scale_h = h
scale_w = w_ / h_ * h
center = np.array([np.round(w / 2.), np.round(h / 2.)]) output_flipped = output_flipped[:, :, :, ::-1]
size_resized = (w_, h_) for pair in matched_parts:
trans = get_affine_transform( tmp = output_flipped[:, pair[0], :, :].copy()
center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv) output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return trans, size_resized return output_flipped
def get_affine_transform(center, def get_affine_transform(center,
...@@ -101,37 +92,6 @@ def get_affine_transform(center, ...@@ -101,37 +92,6 @@ def get_affine_transform(center,
return trans return trans
def get_warp_matrix(theta, size_input, size_dst, size_target):
"""This code is based on
https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
Calculate the transformation matrix under the constraint of unbiased.
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
Data Processing for Human Pose Estimation (CVPR 2020).
Args:
theta (float): Rotation angle in degrees.
size_input (np.ndarray): Size of input image [w, h].
size_dst (np.ndarray): Size of output image [w, h].
size_target (np.ndarray): Size of ROI in input plane [w, h].
Returns:
matrix (np.ndarray): A matrix for transformation.
"""
theta = np.deg2rad(theta)
matrix = np.zeros((2, 3), dtype=np.float32)
scale_x = size_dst[0] / size_target[0]
scale_y = size_dst[1] / size_target[1]
matrix[0, 0] = np.cos(theta) * scale_x
matrix[0, 1] = -np.sin(theta) * scale_x
matrix[0, 2] = scale_x * (
-0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
np.sin(theta) + 0.5 * size_target[0])
matrix[1, 0] = np.sin(theta) * scale_y
matrix[1, 1] = np.cos(theta) * scale_y
matrix[1, 2] = scale_y * (
-0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
np.cos(theta) + 0.5 * size_target[1])
return matrix
def _get_3rd_point(a, b): def _get_3rd_point(a, b):
"""To calculate the affine matrix, three pairs of points are required. This """To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b. function is used to get the 3rd point, given 2D points a & b.
...@@ -170,29 +130,6 @@ def rotate_point(pt, angle_rad): ...@@ -170,29 +130,6 @@ def rotate_point(pt, angle_rad):
return rotated_pt return rotated_pt
def transpred(kpts, h, w, s):
trans, _ = get_affine_mat_kernel(h, w, s, inv=True)
return warp_affine_joints(kpts[..., :2].copy(), trans)
def warp_affine_joints(joints, mat):
"""Apply affine transformation defined by the transform matrix on the
joints.
Args:
joints (np.ndarray[..., 2]): Origin coordinate of joints.
mat (np.ndarray[3, 2]): The affine matrix.
Returns:
matrix (np.ndarray[..., 2]): Result coordinate of joints.
"""
joints = np.array(joints)
shape = joints.shape
joints = joints.reshape(-1, 2)
return np.dot(np.concatenate(
(joints, joints[:, 0:1] * 0 + 1), axis=1),
mat.T).reshape(shape)
def affine_transform(pt, t): def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt) new_pt = np.dot(t, new_pt)
...@@ -207,130 +144,6 @@ def transform_preds(coords, center, scale, output_size): ...@@ -207,130 +144,6 @@ def transform_preds(coords, center, scale, output_size):
return target_coords return target_coords
def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
if not isinstance(sigmas, np.ndarray):
sigmas = np.array([
.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07,
.87, .87, .89, .89
]) / 10.0
vars = (sigmas * 2)**2
xg = g[0::3]
yg = g[1::3]
vg = g[2::3]
ious = np.zeros((d.shape[0]))
for n_d in range(0, d.shape[0]):
xd = d[n_d, 0::3]
yd = d[n_d, 1::3]
vd = d[n_d, 2::3]
dx = xd - xg
dy = yd - yg
e = (dx**2 + dy**2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
if in_vis_thre is not None:
ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
e = e[ind]
ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
return ious
def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
"""greedily select boxes with high confidence and overlap with current maximum <= thresh
rule out overlap >= thresh
Args:
kpts_db (list): The predicted keypoints within the image
thresh (float): The threshold to select the boxes
sigmas (np.array): The variance to calculate the oks iou
Default: None
in_vis_thre (float): The threshold to select the high confidence boxes
Default: None
Return:
keep (list): indexes to keep
"""
if len(kpts_db) == 0:
return []
scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))])
kpts = np.array(
[kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))])
areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))])
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]],
sigmas, in_vis_thre)
inds = np.where(oks_ovr <= thresh)[0]
order = order[inds + 1]
return keep
def rescore(overlap, scores, thresh, type='gaussian'):
assert overlap.shape[0] == scores.shape[0]
if type == 'linear':
inds = np.where(overlap >= thresh)[0]
scores[inds] = scores[inds] * (1 - overlap[inds])
else:
scores = scores * np.exp(-overlap**2 / thresh)
return scores
def soft_oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
"""greedily select boxes with high confidence and overlap with current maximum <= thresh
rule out overlap >= thresh
Args:
kpts_db (list): The predicted keypoints within the image
thresh (float): The threshold to select the boxes
sigmas (np.array): The variance to calculate the oks iou
Default: None
in_vis_thre (float): The threshold to select the high confidence boxes
Default: None
Return:
keep (list): indexes to keep
"""
if len(kpts_db) == 0:
return []
scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))])
kpts = np.array(
[kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))])
areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))])
order = scores.argsort()[::-1]
scores = scores[order]
# max_dets = order.size
max_dets = 20
keep = np.zeros(max_dets, dtype=np.intp)
keep_cnt = 0
while order.size > 0 and keep_cnt < max_dets:
i = order[0]
oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]],
sigmas, in_vis_thre)
order = order[1:]
scores = rescore(oks_ovr, scores[1:], thresh)
tmp = scores.argsort()[::-1]
order = order[tmp]
scores = scores[tmp]
keep[keep_cnt] = i
keep_cnt += 1
keep = keep[:keep_cnt]
return keep
class HRNetPostProcess(object): class HRNetPostProcess(object):
def __init__(self, use_dark=True): def __init__(self, use_dark=True):
self.use_dark = use_dark self.use_dark = use_dark
...@@ -468,3 +281,27 @@ class HRNetPostProcess(object): ...@@ -468,3 +281,27 @@ class HRNetPostProcess(object):
maxvals, axis=1) maxvals, axis=1)
]] ]]
return outputs return outputs
def keypoint_post_process(data, data_input, exe, val_program, fetch_targets,
outs):
data_input['image'] = np.flip(data_input['image'], [3])
output_flipped = exe.run(val_program,
feed=data_input,
fetch_list=fetch_targets,
return_numpy=False)
output_flipped = np.array(output_flipped[0])
flip_perm = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14],
[15, 16]]
output_flipped = flip_back(output_flipped, flip_perm)
output_flipped[:, :, :, 1:] = copy.copy(output_flipped)[:, :, :, 0:-1]
hrnet_outputs = (np.array(outs[0]) + output_flipped) * 0.5
imshape = (
np.array(data['im_shape']))[:, ::-1] if 'im_shape' in data else None
center = np.array(data['center']) if 'center' in data else np.round(
imshape / 2.)
scale = np.array(data['scale']) if 'scale' in data else imshape / 200.
post_process = HRNetPostProcess()
outputs = post_process(hrnet_outputs, center, scale)
return {'keypoint': outputs}
...@@ -19,9 +19,10 @@ import argparse ...@@ -19,9 +19,10 @@ import argparse
import paddle import paddle
from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression from paddleslim.auto_compression import AutoCompression
from keypoint_utils import keypoint_post_process
def argsparser(): def argsparser():
...@@ -95,12 +96,17 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -95,12 +96,17 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
fetch_list=test_fetch_list, fetch_list=test_fetch_list,
return_numpy=False) return_numpy=False)
res = {} res = {}
for out in outs: if 'arch' in global_config and global_config['arch'] == 'keypoint':
v = np.array(out) res = keypoint_post_process(data, data_input, exe,
if len(v.shape) > 1: compiled_test_program, test_fetch_list,
res['bbox'] = v outs)
else: else:
res['bbox_num'] = v for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res) metric.update(data_all, res)
if batch_id % 100 == 0: if batch_id % 100 == 0:
...@@ -109,7 +115,9 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -109,7 +115,9 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric.log() metric.log()
map_res = metric.get_results() map_res = metric.get_results()
metric.reset() metric.reset()
return map_res['bbox'][0] map_key = 'keypoint' if 'arch' in global_config and global_config[
'arch'] == 'keypoint' else 'bbox'
return map_res[map_key][0]
def main(): def main():
...@@ -146,6 +154,10 @@ def main(): ...@@ -146,6 +154,10 @@ def main():
label_list=dataset.get_label_list(), label_list=dataset.get_label_list(),
class_num=reader_cfg['num_classes'], class_num=reader_cfg['num_classes'],
map_type=reader_cfg['map_type']) map_type=reader_cfg['map_type'])
elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval':
anno_file = dataset.get_anno()
metric = KeyPointTopDownCOCOEval(anno_file,
len(dataset), 17, 'output_eval')
else: else:
raise ValueError("metric currently only supports COCO and VOC.") raise ValueError("metric currently only supports COCO and VOC.")
global_config['metric'] = metric global_config['metric'] = metric
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
import copy
import cv2
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import KeyPointTopDownCOCOEval
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
from paddleslim.quant import quant_post_static
from keypoint_utils import HRNetPostProcess, transform_preds
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.",
required=True)
parser.add_argument(
'--save_dir',
type=str,
default='output',
help="directory to save compressed model.")
parser.add_argument(
'--devices',
type=str,
default='gpu',
help="which device used to compress.")
parser.add_argument(
'--eval', type=bool, default=False, help="whether to run evaluation.")
parser.add_argument(
'--quant', type=bool, default=False, help="whether to run evaluation.")
return parser
def print_arguments(args):
print('----------- Running Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------')
def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
for input_name in input_list:
in_dict[input_name] = data[input_name]
yield in_dict
return gen
def flip_back(output_flipped, matched_parts):
assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped = output_flipped[:, :, :, ::-1]
for pair in matched_parts:
tmp = output_flipped[:, pair[0], :, :].copy()
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return output_flipped
def eval(config):
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
config["model_dir"],
exe,
model_filename=config["model_filename"],
params_filename=config["params_filename"], )
dataset.check_or_download_dataset()
anno_file = dataset.get_anno()
metric = KeyPointTopDownCOCOEval(anno_file, len(dataset), 17, 'output_eval')
post_process = HRNetPostProcess()
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
data_input = {}
for k, v in data.items():
if k in config['input_list']:
data_input[k] = np.array(v)
outs = exe.run(val_program,
feed=data_input,
fetch_list=fetch_targets,
return_numpy=False)
data_input['image'] = np.flip(data_input['image'], [3])
output_flipped = exe.run(val_program,
feed=data_input,
fetch_list=fetch_targets,
return_numpy=False)
output_flipped = np.array(output_flipped[0])
flip_perm = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
[13, 14], [15, 16]]
output_flipped = flip_back(output_flipped, flip_perm)
output_flipped[:, :, :, 1:] = copy.copy(output_flipped)[:, :, :, 0:-1]
hrnet_outputs = (np.array(outs[0]) + output_flipped) * 0.5
imshape = (
np.array(data['im_shape']))[:, ::-1] if 'im_shape' in data else None
center = np.array(data['center']) if 'center' in data else np.round(
imshape / 2.)
scale = np.array(data['scale']) if 'scale' in data else imshape / 200.
outputs = post_process(hrnet_outputs, center, scale)
outputs = {'keypoint': outputs}
metric.update(data_all, outputs)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
metric.reset()
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
dataset.check_or_download_dataset()
anno_file = dataset.get_anno()
metric = KeyPointTopDownCOCOEval(anno_file, len(dataset), 17, 'output_eval')
post_process = HRNetPostProcess()
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
data_input = {}
for k, v in data.items():
if k in test_feed_names:
data_input[k] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
data_input['image'] = np.flip(data_input['image'], [3])
output_flipped = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
output_flipped = np.array(output_flipped[0])
flip_perm = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12],
[13, 14], [15, 16]]
output_flipped = flip_back(output_flipped, flip_perm)
output_flipped[:, :, :, 1:] = copy.copy(output_flipped)[:, :, :, 0:-1]
hrnet_outputs = (np.array(outs[0]) + output_flipped) * 0.5
imshape = (
np.array(data['im_shape']))[:, ::-1] if 'im_shape' in data else None
center = np.array(data['center']) if 'center' in data else np.round(
imshape / 2.)
scale = np.array(data['scale']) if 'scale' in data else imshape / 200.
outputs = post_process(hrnet_outputs, center, scale)
outputs = {'keypoint': outputs}
metric.update(data_all, outputs)
if batch_id % 100 == 0:
print('Eval iter:', batch_id)
metric.accumulate()
metric.log()
map_res = metric.get_results()
metric.reset()
return map_res['keypoint'][0]
def main():
all_config = load_slim_config(FLAGS.config_path)
global global_config
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
reader_cfg['worker_num'],
return_list=True)
train_loader = reader_wrapper(train_loader, global_config['input_list'])
global dataset
dataset = reader_cfg['EvalDataset']
global val_loader
val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
reader_cfg['worker_num'],
return_list=True)
if FLAGS.eval:
eval(global_config)
sys.exit(0)
if 'Evaluation' in global_config.keys() and global_config['Evaluation']:
eval_func = eval_function
else:
eval_func = None
ac = AutoCompression(
model_dir=global_config["model_dir"],
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
save_dir=FLAGS.save_dir,
config=all_config,
train_dataloader=train_loader,
eval_callback=eval_func)
ac.compress()
if __name__ == '__main__':
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
paddle.set_device(FLAGS.devices)
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册