提交 18acc25f 编写于 作者: C cuicheng01

add table_attribute_code

上级 6b218caf
Global:
infer_imgs: "images/PULC/table_attribute/val_3610.jpg"
inference_model_dir: "./models/table_attribute_infer"
batch_size: 1
use_gpu: True
enable_mkldnn: True
cpu_num_threads: 10
benchmark: False
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
PreProcess:
transform_ops:
- ResizeImage:
size: [224, 224]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
channel_num: 3
- ToCHWImage:
PostProcess:
main_indicator: TableAttribute
TableAttribute:
source_threshold: 0.5
number_threshold: 0.5
color_threshold: 0.5
clarity_threshold : 0.5
obstruction_threshold: 0.5
angle_threshold: 0.5
...@@ -320,3 +320,42 @@ class VehicleAttribute(object): ...@@ -320,3 +320,42 @@ class VehicleAttribute(object):
).astype(np.int8).tolist() ).astype(np.int8).tolist()
batch_res.append({"attributes": label_res, "output": pred_res}) batch_res.append({"attributes": label_res, "output": pred_res})
return batch_res return batch_res
class TableAttribute(object):
    """Convert table-attribute predictor scores into readable labels.

    Each prediction row is expected to carry six scores, in this order:
    source, number, color, clarity, obstruction, angle. A score strictly
    above its threshold selects the first label of the pair below,
    otherwise the second:

        source:      'Scanned'           / 'Photo'
        number:      'Little'            / 'Numerous'
        color:       'Black-and-White'   / 'Multicolor'
        clarity:     'Clear'             / 'Blurry'
        obstruction: 'Without-Obstacles' / 'With-Obstacles'
        angle:       'Horizontal'        / 'Tilted'

    Args:
        source_threshold: cut-off for the source score.
        number_threshold: cut-off for the number score.
        color_threshold: cut-off for the color score.
        clarity_threshold: cut-off for the clarity score.
        obstruction_threshold: cut-off for the obstruction score.
        angle_threshold: cut-off for the angle score.
    """

    def __init__(
            self,
            source_threshold=0.5,
            number_threshold=0.5,
            color_threshold=0.5,
            clarity_threshold=0.5,
            obstruction_threshold=0.5,
            angle_threshold=0.5, ):
        self.source_threshold = source_threshold
        self.number_threshold = number_threshold
        self.color_threshold = color_threshold
        self.clarity_threshold = clarity_threshold
        self.obstruction_threshold = obstruction_threshold
        self.angle_threshold = angle_threshold

    def __call__(self, batch_preds, file_names=None):
        """Post-process the output of the predictor.

        Args:
            batch_preds: iterable of per-image score rows; each row must
                provide ``tolist()`` (e.g. a numpy array) and hold six scores.
            file_names: accepted for interface compatibility; not used here.

        Returns:
            A list with one dict per row: ``"attributes"`` holds the six
            label strings and ``"output"`` the matching 0/1 flags.
        """
        # The thresholds do not change between rows, so build the list once.
        threshold_list = [
            self.source_threshold, self.number_threshold,
            self.color_threshold, self.clarity_threshold,
            self.obstruction_threshold, self.angle_threshold
        ]
        batch_res = []
        for res in batch_preds:
            res = res.tolist()
            source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
            number = 'Little' if res[1] > self.number_threshold else 'Numerous'
            color = ('Black-and-White'
                     if res[2] > self.color_threshold else 'Multicolor')
            clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
            # Obstruction and angle must use their own thresholds so that the
            # label always agrees with the 0/1 flag computed below.
            obstruction = ('Without-Obstacles'
                           if res[4] > self.obstruction_threshold else
                           'With-Obstacles')
            angle = 'Horizontal' if res[5] > self.angle_threshold else 'Tilted'
            label_res = [source, number, color, clarity, obstruction, angle]
            pred_res = (np.array(res) > np.array(threshold_list)
                        ).astype(np.int8).tolist()
            batch_res.append({"attributes": label_res, "output": pred_res})
        return batch_res
...@@ -11,17 +11,21 @@ ...@@ -11,17 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import cv2 import cv2
import numpy as np import numpy as np
from paddleclas.deploy.utils import logger, config from utils import logger
from paddleclas.deploy.utils.predictor import Predictor from utils import config
from paddleclas.deploy.utils.get_image_list import get_image_list from utils.predictor import Predictor
from paddleclas.deploy.python.preprocess import create_operators from utils.get_image_list import get_image_list
from paddleclas.deploy.python.postprocess import build_postprocess from python.preprocess import create_operators
from python.postprocess import build_postprocess
class ClsPredictor(Predictor): class ClsPredictor(Predictor):
...@@ -136,7 +140,7 @@ def main(config): ...@@ -136,7 +140,7 @@ def main(config):
for number, result_dict in enumerate(batch_results): for number, result_dict in enumerate(batch_results):
if "PersonAttribute" in config[ if "PersonAttribute" in config[
"PostProcess"] or "VehicleAttribute" in config[ "PostProcess"] or "VehicleAttribute" in config[
"PostProcess"]: "PostProcess"] or "TableAttribute" in config["PostProcess"]:
filename = batch_names[number] filename = batch_names[number]
print("{}:\t {}".format(filename, result_dict)) print("{}:\t {}".format(filename, result_dict))
else: else:
......
...@@ -191,7 +191,8 @@ PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/inf ...@@ -191,7 +191,8 @@ PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/inf
PULC_MODELS = [ PULC_MODELS = [
"car_exists", "language_classification", "person_attribute", "car_exists", "language_classification", "person_attribute",
"person_exists", "safety_helmet", "text_image_orientation", "person_exists", "safety_helmet", "text_image_orientation",
"textline_orientation", "traffic_sign", "vehicle_attribute" "textline_orientation", "traffic_sign", "vehicle_attribute",
"table_attribute"
] ]
...@@ -271,7 +272,25 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs): ...@@ -271,7 +272,25 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
if "type_threshold" in kwargs and kwargs["type_threshold"]: if "type_threshold" in kwargs and kwargs["type_threshold"]:
cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[ cfg.PostProcess.VehicleAttribute.type_threshold = kwargs[
"type_threshold"] "type_threshold"]
if "TableAttribute" in cfg.PostProcess:
    # Forward each user-supplied threshold to the TableAttribute entry of the
    # same name. Writing to VehicleAttribute.color_threshold here (as the
    # copied code did) would leave the table thresholds unchanged.
    if "source_threshold" in kwargs and kwargs["source_threshold"]:
        cfg.PostProcess.TableAttribute.source_threshold = kwargs[
            "source_threshold"]
    if "number_threshold" in kwargs and kwargs["number_threshold"]:
        cfg.PostProcess.TableAttribute.number_threshold = kwargs[
            "number_threshold"]
    if "color_threshold" in kwargs and kwargs["color_threshold"]:
        cfg.PostProcess.TableAttribute.color_threshold = kwargs[
            "color_threshold"]
    if "clarity_threshold" in kwargs and kwargs["clarity_threshold"]:
        cfg.PostProcess.TableAttribute.clarity_threshold = kwargs[
            "clarity_threshold"]
    if "obstruction_threshold" in kwargs and kwargs[
            "obstruction_threshold"]:
        cfg.PostProcess.TableAttribute.obstruction_threshold = kwargs[
            "obstruction_threshold"]
    if "angle_threshold" in kwargs and kwargs["angle_threshold"]:
        cfg.PostProcess.TableAttribute.angle_threshold = kwargs[
            "angle_threshold"]
if "save_dir" in kwargs and kwargs["save_dir"]: if "save_dir" in kwargs and kwargs["save_dir"]:
cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"] cfg.PostProcess.SavePreLabel.save_dir = kwargs["save_dir"]
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 20
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
use_multilabel: True
# model architecture
Arch:
name: "PPLCNet_x1_0"
pretrained: True
use_ssld: True
class_num: 6
# loss function config for training/eval process
Loss:
Train:
- MultiLabelLoss:
weight: 1.0
weight_ratio: True
size_sum: True
Eval:
- MultiLabelLoss:
weight: 1.0
weight_ratio: True
size_sum: True
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.01
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.0005
# data loader for train and eval
DataLoader:
Train:
dataset:
name: MultiLabelDataset
image_root: "dataset/table_attribute/"
cls_label_path: "dataset/table_attribute/train_list.txt"
label_ratio: True
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [224, 224]
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: MultiLabelDataset
image_root: "dataset/table_attribute/"
cls_label_path: "dataset/table_attribute/val_list.txt"
label_ratio: True
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [224, 224]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: deploy/images/PULC/table_attribute/val_3610.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [224, 224]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: TableAttribute
source_threshold: 0.5
number_threshold: 0.5
color_threshold: 0.5
clarity_threshold : 0.5
obstruction_threshold: 0.5
angle_threshold: 0.5
Metric:
Eval:
- ATTRMetric:
...@@ -18,7 +18,7 @@ from . import topk, threshoutput ...@@ -18,7 +18,7 @@ from . import topk, threshoutput
from .topk import Topk, MultiLabelTopk from .topk import Topk, MultiLabelTopk
from .threshoutput import ThreshOutput from .threshoutput import ThreshOutput
from .attr_rec import VehicleAttribute, PersonAttribute from .attr_rec import VehicleAttribute, PersonAttribute, TableAttribute
def build_postprocess(config): def build_postprocess(config):
......
...@@ -171,3 +171,47 @@ class PersonAttribute(object): ...@@ -171,3 +171,47 @@ class PersonAttribute(object):
batch_res.append({"attributes": label_res, "output": pred_res}) batch_res.append({"attributes": label_res, "output": pred_res})
return batch_res return batch_res
class TableAttribute(object):
    """Convert table-attribute logits into readable labels.

    Each row is expected to carry six values, in this order: source, number,
    color, clarity, obstruction, angle. After a sigmoid, a score strictly
    above its threshold selects the first label of the pair below, otherwise
    the second:

        source:      'Scanned'           / 'Photo'
        number:      'Little'            / 'Numerous'
        color:       'Black-and-White'   / 'Multicolor'
        clarity:     'Clear'             / 'Blurry'
        obstruction: 'Without-Obstacles' / 'With-Obstacles'
        angle:       'Horizontal'        / 'Tilted'

    Args:
        source_threshold: cut-off for the source score.
        number_threshold: cut-off for the number score.
        color_threshold: cut-off for the color score.
        clarity_threshold: cut-off for the clarity score.
        obstruction_threshold: cut-off for the obstruction score.
        angle_threshold: cut-off for the angle score.
    """

    def __init__(
            self,
            source_threshold=0.5,
            number_threshold=0.5,
            color_threshold=0.5,
            clarity_threshold=0.5,
            obstruction_threshold=0.5,
            angle_threshold=0.5, ):
        self.source_threshold = source_threshold
        self.number_threshold = number_threshold
        self.color_threshold = color_threshold
        self.clarity_threshold = clarity_threshold
        self.obstruction_threshold = obstruction_threshold
        self.angle_threshold = angle_threshold

    def __call__(self, x, file_names=None):
        """Post-process model output.

        Args:
            x: a ``paddle.Tensor`` of logits, or a dict holding that tensor
                under the ``'logits'`` key.
            file_names: optional sequence with one name per row of ``x``.

        Returns:
            A list with one dict per row: ``"attributes"`` holds the six
            label strings and ``"output"`` the matching 0/1 flags. A
            ``"file_name"`` entry is added only when ``file_names`` is given.
        """
        if isinstance(x, dict):
            x = x['logits']
        assert isinstance(x, paddle.Tensor)
        if file_names is not None:
            assert x.shape[0] == len(file_names)
        x = F.sigmoid(x).numpy()

        # The thresholds do not change between rows, so build the list once.
        threshold_list = [
            self.source_threshold, self.number_threshold,
            self.color_threshold, self.clarity_threshold,
            self.obstruction_threshold, self.angle_threshold
        ]
        batch_res = []
        for idx, res in enumerate(x):
            res = res.tolist()
            source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
            number = 'Little' if res[1] > self.number_threshold else 'Numerous'
            color = ('Black-and-White'
                     if res[2] > self.color_threshold else 'Multicolor')
            clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
            # Obstruction and angle must use their own thresholds so that the
            # label always agrees with the 0/1 flag computed below.
            obstruction = ('Without-Obstacles'
                           if res[4] > self.obstruction_threshold else
                           'With-Obstacles')
            angle = 'Horizontal' if res[5] > self.angle_threshold else 'Tilted'
            label_res = [source, number, color, clarity, obstruction, angle]
            pred_res = (np.array(res) > np.array(threshold_list)
                        ).astype(np.int8).tolist()
            result = {"attributes": label_res, "output": pred_res}
            # file_names defaults to None; indexing it unconditionally would
            # raise TypeError, so attach the name only when one was supplied.
            if file_names is not None:
                result["file_name"] = file_names[idx]
            batch_res.append(result)
        return batch_res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册