未验证 提交 d7cf1207 编写于 作者: W Walter 提交者: GitHub

Merge pull request #2296 from HydrogenSulfate/fix_GeneralRecognitionV2

Fix GeneralRecognitionV2 benchmark TIPC
...@@ -53,14 +53,14 @@ class PKSampler(DistributedBatchSampler): ...@@ -53,14 +53,14 @@ class PKSampler(DistributedBatchSampler):
f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})." f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
assert hasattr(self.dataset, assert hasattr(self.dataset,
"labels"), "Dataset must have labels attribute." "labels"), "Dataset must have labels attribute."
self.sample_per_label = sample_per_id self.sample_per_id = sample_per_id
self.label_dict = defaultdict(list) self.label_dict = defaultdict(list)
self.sample_method = sample_method self.sample_method = sample_method
for idx, label in enumerate(self.dataset.labels): for idx, label in enumerate(self.dataset.labels):
self.label_dict[label].append(idx) self.label_dict[label].append(idx)
self.label_list = list(self.label_dict) self.label_list = list(self.label_dict)
assert len(self.label_list) * self.sample_per_label > self.batch_size, \ assert len(self.label_list) * self.sample_per_id >= self.batch_size, \
"batch size should be smaller than " f"batch size({self.batch_size}) should not be bigger than than #classes({len(self.label_list)})*sample_per_id({self.sample_per_id})"
if self.sample_method == "id_avg_prob": if self.sample_method == "id_avg_prob":
self.prob_list = np.array([1 / len(self.label_list)] * self.prob_list = np.array([1 / len(self.label_list)] *
len(self.label_list)) len(self.label_list))
...@@ -94,7 +94,7 @@ class PKSampler(DistributedBatchSampler): ...@@ -94,7 +94,7 @@ class PKSampler(DistributedBatchSampler):
format(diff)) format(diff))
def __iter__(self): def __iter__(self):
label_per_batch = self.batch_size // self.sample_per_label label_per_batch = self.batch_size // self.sample_per_id
for _ in range(len(self)): for _ in range(len(self)):
batch_index = [] batch_index = []
batch_label_list = np.random.choice( batch_label_list = np.random.choice(
...@@ -104,17 +104,17 @@ class PKSampler(DistributedBatchSampler): ...@@ -104,17 +104,17 @@ class PKSampler(DistributedBatchSampler):
p=self.prob_list) p=self.prob_list)
for label_i in batch_label_list: for label_i in batch_label_list:
label_i_indexes = self.label_dict[label_i] label_i_indexes = self.label_dict[label_i]
if self.sample_per_label <= len(label_i_indexes): if self.sample_per_id <= len(label_i_indexes):
batch_index.extend( batch_index.extend(
np.random.choice( np.random.choice(
label_i_indexes, label_i_indexes,
size=self.sample_per_label, size=self.sample_per_id,
replace=False)) replace=False))
else: else:
batch_index.extend( batch_index.extend(
np.random.choice( np.random.choice(
label_i_indexes, label_i_indexes,
size=self.sample_per_label, size=self.sample_per_id,
replace=True)) replace=True))
if not self.drop_last or len(batch_index) == self.batch_size: if not self.drop_last or len(batch_index) == self.batch_size:
yield batch_index yield batch_index
...@@ -115,3 +115,4 @@ bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/MobileNetV3/Mo ...@@ -115,3 +115,4 @@ bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/MobileNetV3/Mo
- [test_serving_infer_python 使用](docs/test_serving_infer_python.md):测试python serving功能。 - [test_serving_infer_python 使用](docs/test_serving_infer_python.md):测试python serving功能。
- [test_serving_infer_cpp 使用](docs/test_serving_infer_cpp.md):测试cpp serving功能。 - [test_serving_infer_cpp 使用](docs/test_serving_infer_cpp.md):测试cpp serving功能。
- [test_train_fleet_inference_python 使用](./docs/test_train_fleet_inference_python.md):测试基于Python的多机多卡训练与推理等基本功能。 - [test_train_fleet_inference_python 使用](./docs/test_train_fleet_inference_python.md):测试基于Python的多机多卡训练与推理等基本功能。
- [benchmark_train 使用](./docs/benchmark_train.md):测试基于Python的训练benchmark等基本功能。
...@@ -51,7 +51,7 @@ inference:python/predict_rec.py -c configs/inference_rec.yaml ...@@ -51,7 +51,7 @@ inference:python/predict_rec.py -c configs/inference_rec.yaml
null:null null:null
null:null null:null
===========================train_benchmark_params========================== ===========================train_benchmark_params==========================
batch_size:256 batch_size:128
fp_items:fp32|fp16 fp_items:fp32|fp16
epoch:1 epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册