Unverified commit 2ed68d5d, authored by Tingquan Gao, committed by GitHub

Merge pull request #2264 from cuicheng01/add_table_attribute

add table_attribute_code
Global:
  infer_imgs: "images/PULC/table_attribute/val_3610.jpg"
  inference_model_dir: "./models/table_attribute_infer"
  batch_size: 1
  use_gpu: True
  enable_mkldnn: True
  cpu_num_threads: 10
  benchmark: False
  use_fp16: False
  ir_optim: True
  use_tensorrt: False
  gpu_mem: 8000
  enable_profile: False
PreProcess:
  transform_ops:
    - ResizeImage:
        size: [224, 224]
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
        channel_num: 3
    - ToCHWImage:
PostProcess:
  main_indicator: TableAttribute
  TableAttribute:
    source_threshold: 0.5
    number_threshold: 0.5
    color_threshold: 0.5
    clarity_threshold: 0.5
    obstruction_threshold: 0.5
    angle_threshold: 0.5
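For reference, a deploy config of this shape is normally consumed by the PULC inference script under deploy/; the yaml path below is an assumed save location, not something stated in this commit:

python python/predict_cls.py \
    -c configs/PULC/table_attribute/inference_table_attribute.yaml \
    -o Global.use_gpu=False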
@@ -320,3 +320,49 @@ class VehicleAttribute(object):
                        ).astype(np.int8).tolist()
            batch_res.append({"attributes": label_res, "output": pred_res})
        return batch_res

class TableAttribute(object):
    def __init__(
            self,
            source_threshold=0.5,
            number_threshold=0.5,
            color_threshold=0.5,
            clarity_threshold=0.5,
            obstruction_threshold=0.5,
            angle_threshold=0.5, ):
        self.source_threshold = source_threshold
        self.number_threshold = number_threshold
        self.color_threshold = color_threshold
        self.clarity_threshold = clarity_threshold
        self.obstruction_threshold = obstruction_threshold
        self.angle_threshold = angle_threshold

    def __call__(self, batch_preds, file_names=None):
        # postprocess output of predictor
        batch_res = []
        for res in batch_preds:
            res = res.tolist()
            source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
            number = 'Little' if res[1] > self.number_threshold else 'Numerous'
            color = 'Black-and-White' if res[
                2] > self.color_threshold else 'Multicolor'
            clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
            obstruction = 'Without-Obstacles' if res[
                4] > self.obstruction_threshold else 'With-Obstacles'
            angle = 'Horizontal' if res[
                5] > self.angle_threshold else 'Tilted'
            label_res = [source, number, color, clarity, obstruction, angle]
            threshold_list = [
                self.source_threshold, self.number_threshold,
                self.color_threshold, self.clarity_threshold,
                self.obstruction_threshold, self.angle_threshold
            ]
            pred_res = (np.array(res) > np.array(threshold_list)
                        ).astype(np.int8).tolist()
            batch_res.append({"attributes": label_res, "output": pred_res})
        return batch_res
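A quick way to sanity-check the postprocessor above is to feed it a fake batch of six per-attribute scores; the values below are made up for illustration only:

# Illustrative only: one image, six attribute scores already in [0, 1].
import numpy as np

postprocessor = TableAttribute()
dummy_preds = np.array([[0.9, 0.2, 0.8, 0.7, 0.6, 0.1]])
print(postprocessor(dummy_preds))
# -> [{'attributes': ['Scanned', 'Numerous', 'Black-and-White', 'Clear',
#                     'Without-Obstacles', 'Tilted'], 'output': [1, 0, 1, 1, 1, 0]}]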
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
@@ -136,7 +135,8 @@ def main(config):
            for number, result_dict in enumerate(batch_results):
                if "PersonAttribute" in config[
                        "PostProcess"] or "VehicleAttribute" in config[
                            "PostProcess"] or "TableAttribute" in config[
                                "PostProcess"]:
                    filename = batch_names[number]
                    print("{}:\t {}".format(filename, result_dict))
                else:
@@ -192,7 +192,8 @@ PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/inf
PULC_MODELS = [
    "car_exists", "language_classification", "person_attribute",
    "person_exists", "safety_helmet", "text_image_orientation",
    "textline_orientation", "traffic_sign", "vehicle_attribute",
    "table_attribute"
]

SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
@@ -278,6 +279,7 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
    if "thresh" in kwargs and kwargs[
            "thresh"] and "ThreshOutput" in cfg.PostProcess:
        cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
    if cfg.get("PostProcess"):
        if "Topk" in cfg.PostProcess:
            if "topk" in kwargs and kwargs["topk"]:
@@ -297,7 +299,25 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
            if "type_threshold" in kwargs and kwargs["type_threshold"]:
                cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
                    "type_threshold"]
        if "TableAttribute" in cfg.PostProcess:
            if "source_threshold" in kwargs and kwargs["source_threshold"]:
                cfg.PostProcess.TableAttribute.source_threshold = kwargs[
                    "source_threshold"]
            if "number_threshold" in kwargs and kwargs["number_threshold"]:
                cfg.PostProcess.TableAttribute.number_threshold = kwargs[
                    "number_threshold"]
            if "color_threshold" in kwargs and kwargs["color_threshold"]:
                cfg.PostProcess.TableAttribute.color_threshold = kwargs[
                    "color_threshold"]
            if "clarity_threshold" in kwargs and kwargs["clarity_threshold"]:
                cfg.PostProcess.TableAttribute.clarity_threshold = kwargs[
                    "clarity_threshold"]
            if "obstruction_threshold" in kwargs and kwargs["obstruction_threshold"]:
                cfg.PostProcess.TableAttribute.obstruction_threshold = kwargs[
                    "obstruction_threshold"]
            if "angle_threshold" in kwargs and kwargs["angle_threshold"]:
                cfg.PostProcess.TableAttribute.angle_threshold = kwargs[
                    "angle_threshold"]
        if "save_dir" in kwargs and kwargs["save_dir"]:
            cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
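With the kwargs routing above, the whl-style Python API should accept the per-attribute thresholds directly. A minimal sketch, assuming the package is built from this branch and that PaddleClas forwards extra keyword arguments to init_config as usual (the image path is the validation sample referenced in the configs):

import paddleclas

# The threshold kwargs land in cfg.PostProcess.TableAttribute via init_config.
model = paddleclas.PaddleClas(
    model_name="table_attribute", source_threshold=0.6, angle_threshold=0.4)
results = model.predict(
    input_data="deploy/images/PULC/table_attribute/val_3610.jpg")
print(next(results))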
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 20
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"
  use_multilabel: True

# model architecture
Arch:
  name: "PPLCNet_x1_0"
  pretrained: True
  use_ssld: True
  class_num: 6

# loss function config for training/eval process
Loss:
  Train:
    - MultiLabelLoss:
        weight: 1.0
        weight_ratio: True
        size_sum: True
  Eval:
    - MultiLabelLoss:
        weight: 1.0
        weight_ratio: True
        size_sum: True

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.01
    warmup_epoch: 5
  regularizer:
    name: 'L2'
    coeff: 0.0005

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: MultiLabelDataset
      image_root: "dataset/table_attribute/"
      cls_label_path: "dataset/table_attribute/train_list.txt"
      label_ratio: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: [224, 224]
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: MultiLabelDataset
      image_root: "dataset/table_attribute/"
      cls_label_path: "dataset/table_attribute/val_list.txt"
      label_ratio: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: [224, 224]
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: deploy/images/PULC/table_attribute/val_3610.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        size: [224, 224]
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: TableAttribute
    source_threshold: 0.5
    number_threshold: 0.5
    color_threshold: 0.5
    clarity_threshold: 0.5
    obstruction_threshold: 0.5
    angle_threshold: 0.5

Metric:
  Eval:
    - ATTRMetric:
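A config like the one above is normally launched with the standard PaddleClas training entry point; the config path below is an assumed location for this file, not something given in the commit:

python tools/train.py \
    -c ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml \
    -o Global.device=gpu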
@@ -18,7 +18,7 @@ from . import topk, threshoutput
from .topk import Topk, MultiLabelTopk
from .threshoutput import ThreshOutput
from .attr_rec import VehicleAttribute, PersonAttribute, TableAttribute


def build_postprocess(config):
@@ -71,7 +71,6 @@ class VehicleAttribute(object):
        return batch_res


class PersonAttribute(object):
    def __init__(self,
                 threshold=0.5,
@@ -171,3 +170,58 @@ class PersonAttribute(object):
            batch_res.append({"attributes": label_res, "output": pred_res})
        return batch_res

class TableAttribute(object):
    def __init__(
            self,
            source_threshold=0.5,
            number_threshold=0.5,
            color_threshold=0.5,
            clarity_threshold=0.5,
            obstruction_threshold=0.5,
            angle_threshold=0.5, ):
        self.source_threshold = source_threshold
        self.number_threshold = number_threshold
        self.color_threshold = color_threshold
        self.clarity_threshold = clarity_threshold
        self.obstruction_threshold = obstruction_threshold
        self.angle_threshold = angle_threshold

    def __call__(self, x, file_names=None):
        if isinstance(x, dict):
            x = x['logits']
        assert isinstance(x, paddle.Tensor)
        if file_names is not None:
            assert x.shape[0] == len(file_names)
        x = F.sigmoid(x).numpy()

        # postprocess output of predictor
        batch_res = []
        for idx, res in enumerate(x):
            res = res.tolist()
            source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
            number = 'Little' if res[1] > self.number_threshold else 'Numerous'
            color = 'Black-and-White' if res[
                2] > self.color_threshold else 'Multicolor'
            clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
            obstruction = 'Without-Obstacles' if res[
                4] > self.obstruction_threshold else 'With-Obstacles'
            angle = 'Horizontal' if res[
                5] > self.angle_threshold else 'Tilted'
            label_res = [source, number, color, clarity, obstruction, angle]
            threshold_list = [
                self.source_threshold, self.number_threshold,
                self.color_threshold, self.clarity_threshold,
                self.obstruction_threshold, self.angle_threshold
            ]
            pred_res = (np.array(res) > np.array(threshold_list)
                        ).astype(np.int8).tolist()
            batch_res.append({
                "attributes": label_res,
                "output": pred_res,
                "file_name": file_names[idx]
            })
        return batch_res
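A minimal sketch of how this postprocess might be constructed from the Infer.PostProcess section of the training config via build_postprocess; the module path and the dict literal are assumptions based on the import change shown earlier, and the logits are dummy values:

import paddle
from ppcls.data.postprocess import build_postprocess

# Dict mirroring the Infer.PostProcess section of the training config above.
postprocess_cfg = {"name": "TableAttribute", "source_threshold": 0.5}
postprocess_func = build_postprocess(postprocess_cfg)

dummy_logits = paddle.to_tensor([[2.0, -1.0, 0.5, 1.5, 0.3, -0.7]])  # one image, six heads
print(postprocess_func(dummy_logits, file_names=["val_3610.jpg"]))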