提交 42758f54 编写于 作者: H HydrogenSulfate

fix benchmark train's config for GeneralRecognitionV2, and polish TIPC docs

上级 9b12161b
......@@ -53,14 +53,14 @@ class PKSampler(DistributedBatchSampler):
f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
assert hasattr(self.dataset,
"labels"), "Dataset must have labels attribute."
self.sample_per_label = sample_per_id
self.sample_per_id = sample_per_id
self.label_dict = defaultdict(list)
self.sample_method = sample_method
for idx, label in enumerate(self.dataset.labels):
self.label_dict[label].append(idx)
self.label_list = list(self.label_dict)
assert len(self.label_list) * self.sample_per_label > self.batch_size, \
"batch size should be smaller than "
assert len(self.label_list) * self.sample_per_id >= self.batch_size, \
f"batch size({self.batch_size}) should not be bigger than than #classes({len(self.label_list)})*sample_per_id({self.sample_per_id})"
if self.sample_method == "id_avg_prob":
self.prob_list = np.array([1 / len(self.label_list)] *
len(self.label_list))
......@@ -94,7 +94,7 @@ class PKSampler(DistributedBatchSampler):
format(diff))
def __iter__(self):
label_per_batch = self.batch_size // self.sample_per_label
label_per_batch = self.batch_size // self.sample_per_id
for _ in range(len(self)):
batch_index = []
batch_label_list = np.random.choice(
......@@ -104,17 +104,17 @@ class PKSampler(DistributedBatchSampler):
p=self.prob_list)
for label_i in batch_label_list:
label_i_indexes = self.label_dict[label_i]
if self.sample_per_label <= len(label_i_indexes):
if self.sample_per_id <= len(label_i_indexes):
batch_index.extend(
np.random.choice(
label_i_indexes,
size=self.sample_per_label,
size=self.sample_per_id,
replace=False))
else:
batch_index.extend(
np.random.choice(
label_i_indexes,
size=self.sample_per_label,
size=self.sample_per_id,
replace=True))
if not self.drop_last or len(batch_index) == self.batch_size:
yield batch_index
......@@ -115,3 +115,4 @@ bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/MobileNetV3/Mo
- [test_serving_infer_python 使用](docs/test_serving_infer_python.md):测试python serving功能。
- [test_serving_infer_cpp 使用](docs/test_serving_infer_cpp.md):测试cpp serving功能。
- [test_train_fleet_inference_python 使用](./docs/test_train_fleet_inference_python.md):测试基于Python的多机多卡训练与推理等基本功能。
- [benchmark_train 使用](./docs/benchmark_train.md):测试基于Python的训练benchmark等基本功能。
......@@ -51,7 +51,7 @@ inference:python/predict_rec.py -c configs/inference_rec.yaml
null:null
null:null
===========================train_benchmark_params==========================
batch_size:256
batch_size:128
fp_items:fp32|fp16
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册