未验证 提交 bbdb65ed 编写于 作者: L Liu Yi 提交者: GitHub

Add SMOKE model (#5308)

* add SMOKE

* add deployment

* add pretrained link.

* Update README.md

* Update README.md

* Update README.md

* update config

* add figure

* fix a typo

* delete unused codes

* add reference links

* resolved problems

* change to 2.1
Co-authored-by: Nliuyi22 <liuyi22@baidu.com>
上级 6f331e14
# SMOKE: Single-Stage Monocular 3D Object Detection via Keypoint Estimation
## Requirements
All codes are tested under the following environment:
* CentOS 7.5
* Python 3.7
* PaddlePaddle 2.1.0
* CUDA 10.2
## Preparation
1. PaddlePaddle installation
```bash
conda create -n paddle_latest python=3.7
conda actviate paddle_latest
pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
pip install -r requirement.txt
```
2. Dataset preparations
Please first download the dataset and organize it as following structure:
```
kitti
│──training
│ ├──calib
│ ├──label_2
│ ├──image_2
│ └──ImageSets
└──testing
├──calib
├──image_2
└──ImageSets
```
The make a soft link of kitti dataset and put it under `datasets/` folder.
```bash
mkdir datasets
ln -s path_to_kitti datasets/kitti
```
Note: If you want to use Waymo dataset for training, you should also organize it following the above structure.
3. Compile KITTI evaluation codes
```bash
cd tools/kitti_eval_offline
g++ -O3 -DNDEBUG -o evaluate_object_3d_offline evaluate_object_offline_40p.cpp
```
Note: evaluate\_object\_3d\_40/11p.cpp stands for 40-point/11-point evaluation.
For further details please refer to [KITTI 3D Object Dataset](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) and [Disentangling Monocular 3D Object Detection](https://arxiv.org/abs/1905.12365).
For 11-point evaluation, simply change `evaluate_object_offline_40p.cpp` to `evaluate_object_offline_11p.cpp`.
## Training
Please download the [pre-trained weights](https://bj.bcebos.com/paddleseg/3d/smoke/dla34.pdparams). Put it into ```./pretrained```.
#### Single GPU
```bash
python train.py --config configs/train_val_kitti.yaml --log_iters 100 --save_interval 5000 --num_workers 2
```
#### Multi-GPUs
Take two cards as an example.
```bash
export CUDA_VISIBLE_DEVICES="6, 7" && python -m paddle.distributed.launch train.py --config configs/train_val_kitti.yaml --log_iters 100 --save_interval 5000 --num_workers 2
```
#### VisualDL
Run the following command. If successful, view the training visualization via browser.
```bash
visualdl --logdir ./output
```
## Evaluation
```bash
python val.py --config configs/train_val_kitti.yaml --model_path path-to-model/model.pdparams --num_workers 2
```
The performance on KITTI 3D detection is as follows:
| | Easy | Moderate | Hard |
|-------------|:-----:|:-----------:|:------:|
| Car | 6.51 | 4.98 | 4.63 |
| Pedestrian | 4.44 | 3.73 | 2.99 |
| Cyclist | 1.40 | 0.57 | 0.60 |
The performance on WAYMO 3D detection is as follows:
| | Easy | Moderate | Hard |
|-------------|:-----:|:-----------:|:------:|
| Car | 6.17 | 5.74 | 5.74 |
| Pedestrian | 0.35 | 0.34 | 0.34 |
| Cyclist | 0.54 | 0.53 | 0.53 |
Download the well-trained models here, [smoke-release](https://bj.bcebos.com/paddleseg/3d/smoke/smoke-release.zip).
## Testing
Please download and uncompress above model weights first.
```bash
python test.py --config configs/test_export.yaml --model_path path-to-model/model_waymo.pdparams --input_path examples/0615037.png --output_path paddle.png
```
<img align="center" src="docs/paddle.png" width="750">
## Model Deployment
1. Convert to a static-graph model
```bash
export PYTHONPATH="$PWD"
```
```bash
python deploy/export.py --config configs/test_export.yaml --model_path path-to-model/model_waymo.pdparams
```
Running the above command will generate three files in ```./depoly```, i.e. 1) inference.pdmodel, which maintains model graph/structure, 2) inference.pdiparams, which is well-trained parameters of the model, 3) inference.pdiparams.info, which includes extra meta info of the model.
2. Visualize the model stucture.
```bash
visualdl --model deploy/inference.pdmodel
```
The above command could be a little bit slow. Instead, open the browser first via the following command, and then open the pdmodel locally.
```bash
visualdl
```
Note: If you are using remote server, please specify the ```--host```, e.g. 10.9.189.6
```bash
visualdl --model deploy/inference.pdmodel --host 10.9.189.6
```
3. Python Inference on the converted model.
Now you can run the inference anywhere without the repo. We provive an example for python inference.
```bash
python deploy/infer.py --model_file deploy/inference.pdmodel --params_file deploy/inference.pdiparams --input_path examples/0615037.png --output_path paddle.png
```
<img align="center" src="docs/paddle.png" width="750">
## Reference
> Liu, Zechen, Zizhang Wu, and Roland Tóth. "Smoke: single-stage monocular 3d object detection via keypoint estimation." In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, pp. 996-997. 2020.
batch_size: 8
iters: 70000
train_dataset:
type: KITTI
dataset_root: datasets/kitti/training
transforms:
- type: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
flip_prob: 0.5
aug_prob: 0.3
mode: train
val_dataset:
type: KITTI
dataset_root: datasets/kitti/training
transforms:
- type: Normalize
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
mode: val
optimizer:
type: Adam
lr_scheduler:
type: MultiStepDecay
milestones: [36000, 55000]
learning_rate: 1.25e-4
\ No newline at end of file
_base_: 'train_val_kitti.yaml'
model:
post_process:
type: PostProcessorHm
\ No newline at end of file
_base_: '_base_/kitti.yaml'
model:
type: SMOKE
#pretrained: null
backbone:
type: DLA34
pretrained: "pretrained/dla34.pdparams"
head:
type: SMOKEPredictor
num_classes: 3
reg_heads: 10
reg_channels: [1, 2, 3, 2, 2]
num_chanels: 256
norm_type: "gn"
in_channels: 64
post_process:
type: PostProcessor
depth_ref: [28.01, 16.32]
dim_ref: [[3.88, 1.63, 1.53], [1.78, 1.70, 0.58], [0.88, 1.73, 0.67]]
reg_head: 10
det_threshold: 0.25
max_detection: 50
pred_2d: True
loss:
type: SMOKELossComputation
depth_ref: [28.01, 16.32]
dim_ref: [[3.88, 1.63, 1.53], [1.78, 1.70, 0.58], [0.88, 1.73, 0.67]]
reg_loss: "DisL1"
loss_weight: [1., 10.]
max_objs: 50
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import paddle
from smoke.cvlibs import Config
from smoke.utils import load_pretrained_model
def parse_args():
parser = argparse.ArgumentParser(description='Model Export')
# params of evaluate
parser.add_argument(
"--config", dest="cfg", help="The config file.", required=True, type=str)
parser.add_argument(
'--model_path',
dest='model_path',
help='The path of model for evaluation',
type=str,
required=True)
parser.add_argument(
'--output_dir',
dest='output_dir',
help='The directory saving inference params.',
type=str,
default="./deploy")
return parser.parse_args()
def main(args):
cfg = Config(args.cfg)
model = cfg.model
model.eval()
load_pretrained_model(model, args.model_path)
model = paddle.jit.to_static(model,
input_spec=[
paddle.static.InputSpec(
shape=[1, 3, None, None], dtype="float32",
),
[
paddle.static.InputSpec(
shape=[1, 3, 3], dtype="float32"
),
paddle.static.InputSpec(
shape=[1, 2], dtype="float32"
)
]
]
)
paddle.jit.save(model, os.path.join(args.output_dir, "inference"))
if __name__ == '__main__':
args = parse_args()
main(args)
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import cv2
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from smoke.utils.vis_utils import encode_box3d, draw_box_3d
def get_ratio(ori_img_size, output_size, down_ratio=(4, 4)):
return np.array([[down_ratio[1] * ori_img_size[1] / output_size[1],
down_ratio[0] * ori_img_size[0] / output_size[0]]], np.float32)
def get_img(img_path):
img = cv2.imread(img_path)
ori_img_size = img.shape
img = cv2.resize(img, (960, 640))
output_size = img.shape
img = img/255.0
img = np.subtract(img, np.array([0.485, 0.456, 0.406]))
img = np.true_divide(img, np.array([0.229, 0.224, 0.225]))
img = np.array(img, np.float32)
img = img.transpose(2, 0, 1)
img = img[None,:,:,:]
return img, ori_img_size, output_size
def init_predictor(args):
if args.model_dir is not "":
config = Config(args.model_dir)
else:
config = Config(args.model_file, args.params_file)
config.enable_memory_optim()
if args.use_gpu:
config.enable_use_gpu(1000, 0)
else:
# If not specific mkldnn, you can set the blas thread.
# The thread num should not be greater than the number of cores in the CPU.
config.set_cpu_math_library_num_threads(4)
config.enable_mkldnn()
predictor = create_predictor(config)
return predictor
def run(predictor, img):
# copy img data to input tensor
input_names = predictor.get_input_names()
for i, name in enumerate(input_names):
input_tensor = predictor.get_input_handle(name)
input_tensor.reshape(img[i].shape)
input_tensor.copy_from_cpu(img[i].copy())
# do the inference
predictor.run()
results = []
# get out data from output tensor
output_names = predictor.get_output_names()
for i, name in enumerate(output_names):
output_tensor = predictor.get_output_handle(name)
output_data = output_tensor.copy_to_cpu()
results.append(output_data)
return results
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_file",
type=str,
default="./inference.pdmodel",
help="Model filename, Specify this when your model is a combined model."
)
parser.add_argument(
"--params_file",
type=str,
default="./inference.pdiparams",
help=
"Parameter filename, Specify this when your model is a combined model."
)
parser.add_argument(
"--model_dir",
type=str,
default="",
help=
"Model dir, If you load a non-combined model, specify the directory of the model."
)
parser.add_argument(
'--input_path',
dest='input_path',
help='The image path',
type=str,
required=True)
parser.add_argument(
'--output_path',
dest='output_path',
help='The result path of image',
type=str,
required=True)
parser.add_argument("--use_gpu",
type=int,
default=0,
help="Whether use gpu.")
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
pred = init_predictor(args)
K = np.array([[[2055.56, 0, 939.658], [0, 2055.56, 641.072], [0, 0, 1]]], np.float32)
K_inverse = np.linalg.inv(K)
img_path = args.input_path
img, ori_img_size, output_size = get_img(img_path)
ratio = get_ratio(ori_img_size, output_size)
results = run(pred, [img, K_inverse, ratio])
total_pred = paddle.to_tensor(results[0])
keep_idx = paddle.nonzero(total_pred[:, -1] > 0.25)
total_pred = paddle.gather(total_pred, keep_idx)
if total_pred.shape[0] > 0:
pred_dimensions = total_pred[:, 6:9]
pred_dimensions = pred_dimensions.roll(shifts=1, axis=1)
pred_rotys = total_pred[:, 12]
pred_locations = total_pred[:, 9:12]
bbox_3d = encode_box3d(pred_rotys, pred_dimensions, pred_locations, paddle.to_tensor(K), (1280, 1920))
else:
bbox_3d = total_pred
img_draw = cv2.imread(img_path)
for idx in range(bbox_3d.shape[0]):
bbox = bbox_3d[idx]
bbox = bbox.transpose([1,0]).numpy()
img_draw = draw_box_3d(img_draw, bbox)
cv2.imwrite(args.output_path, img_draw)
put the pretrained model here.
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
from io import BytesIO
import urllib.request
from zipfile import ZipFile
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
if __name__ == "__main__":
file_url = "https://bj.bcebos.com/paddleseg/3d/smoke/dla34.pdparams"
urllib.request.urlretrieve(file_url, "dla34.pdparams")
smoke_model_path = 'https://bj.bcebos.com/paddleseg/3d/smoke/smoke-release.zip'
with urllib.request.urlopen(smoke_model_path) as zipresp:
with ZipFile(BytesIO(zipresp.read())) as zfile:
zfile.extractall(LOCAL_PATH)
\ No newline at end of file
visualdl
opencv-python
scikit-image
filelock
tqdm
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import models, datasets, transforms
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .train import train
from .val import evaluate
from .kitti_eval import kitti_evaluation
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import csv
import logging
import subprocess
import shutil
from smoke.utils.miscellaneous import mkdir
def kitti_evaluation(dataset, predictions, output_dir):
"""Do evaluation by process kitti eval program
Args:
dataset (paddle.io.Dataset): [description]
predictions (Paddle.Tensor): [description]
output_dir (str): path of save prediction
"""
# Clear data dir before do evaluate
if os.path.exists(os.path.join(output_dir, 'data')):
shutil.rmtree(os.path.join(output_dir, 'data'))
predict_folder = os.path.join(output_dir, 'data') # only recognize data
mkdir(predict_folder)
type_id_conversion = getattr(dataset, 'TYPE_ID_CONVERSION')
id_type_conversion = {value:key for key, value in type_id_conversion.items()}
for image_id, prediction in predictions.items():
predict_txt = image_id + '.txt'
predict_txt = os.path.join(predict_folder, predict_txt)
generate_kitti_3d_detection(prediction, predict_txt, id_type_conversion)
output_dir = os.path.abspath(output_dir)
root_dir = os.getcwd()
os.chdir('./tools/kitti_eval_offline')
label_dir = getattr(dataset, 'label_dir')
label_dir = os.path.join(root_dir, label_dir)
if not os.path.isfile('evaluate_object_3d_offline'):
subprocess.Popen('g++ -O3 -DNDEBUG -o evaluate_object_3d_offline evaluate_object_3d_offline.cpp', shell=True)
command = "./evaluate_object_3d_offline {} {}".format(label_dir, output_dir)
os.system(command)
def generate_kitti_3d_detection(prediction, predict_txt, id_type_conversion):
"""write kitti 3d detection result to txt file
Args:
prediction (list[float]): final prediction result
predict_txt (str): path to save the result
"""
with open(predict_txt, 'w', newline='') as f:
w = csv.writer(f, delimiter=' ', lineterminator='\n')
if len(prediction) == 0:
w.writerow([])
else:
for p in prediction:
p = p.round(4)
type = id_type_conversion[int(p[0])]
row = [type, 0, 0] + p[1:].tolist()
w.writerow(row)
check_last_line_break(predict_txt)
def check_last_line_break(predict_txt):
"""check predict last lint
Args:
predict_txt (str): path of predict txt
"""
f = open(predict_txt, 'rb+')
try:
f.seek(-1, os.SEEK_END)
except:
pass
else:
if f.__next__() == b'\n':
f.seek(-1, os.SEEK_END)
f.truncate()
f.close()
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg with minor modifications.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/core/train.py
"""
import os
import time
from collections import deque
import shutil
import paddle
import paddle.nn.functional as F
from visualdl import LogWriter
from smoke.utils import TimeAverager, calculate_eta, logger
def train(model,
train_dataset,
val_dataset=None,
optimizer=None,
loss_computation=None,
save_dir='output',
iters=10000,
batch_size=2,
resume_model=None,
save_interval=1000,
log_iters=10,
num_workers=0,
keep_checkpoint_max=5):
"""
Launch training.
Args:
model(nn.Layer): A sementic segmentation model.
train_dataset (paddle.io.Dataset): Used to read and process training datasets.
val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
optimizer (paddle.optimizer.Optimizer): The optimizer.
loss_computation (nn.Layer): A loss function.
save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
iters (int, optional): How may iters to train the model. Defualt: 10000.
batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
resume_model (str, optional): The path of resume model.
save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
log_iters (int, optional): Display logging information at every log_iters. Default: 10.
num_workers (int, optional): Num workers for data loader. Default: 0.
keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
"""
model.train()
nranks = paddle.distributed.ParallelEnv().nranks
local_rank = paddle.distributed.ParallelEnv().local_rank
start_iter = 0
if resume_model is not None:
start_iter = resume(model, optimizer, resume_model)
if not os.path.isdir(save_dir):
if os.path.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
if nranks > 1:
# Initialize parallel environment if not done.
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
):
paddle.distributed.init_parallel_env()
ddp_model = paddle.DataParallel(model)
else:
ddp_model = paddle.DataParallel(model)
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = paddle.io.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
return_list=True,
)
# VisualDL log
log_writer = LogWriter(save_dir)
avg_loss = 0.0
avg_loss_dict = {}
iters_per_epoch = len(batch_sampler)
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
save_models = deque()
batch_start = time.time()
iter = start_iter
while iter < iters:
for data in loader:
iter += 1
if iter > iters:
break
reader_cost_averager.record(time.time() - batch_start)
images = data[0]
targets = data[1]
if nranks > 1:
predictions = ddp_model(images)
else:
predictions = model(images)
loss_dict = loss_computation(predictions, targets)
loss = sum(loss for loss in loss_dict.values())
loss.backward()
optimizer.step()
lr = optimizer.get_lr()
if isinstance(optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
model.clear_gradients()
avg_loss += loss.numpy()[0] # get the value
if len(avg_loss_dict) == 0:
avg_loss_dict = {k:v.numpy()[0] for k, v in loss_dict.items()}
else:
for key, value in loss_dict.items():
avg_loss_dict[key] += value.numpy()[0]
batch_cost_averager.record(
time.time() - batch_start, num_samples=batch_size)
if (iter) % log_iters == 0 and local_rank == 0:
avg_loss /= log_iters
for key, value in avg_loss_dict.items():
avg_loss_dict[key] /= log_iters
remain_iters = iters - iter
avg_train_batch_cost = batch_cost_averager.get_average()
avg_train_reader_cost = reader_cost_averager.get_average()
eta = calculate_eta(remain_iters, avg_train_batch_cost)
logger.info(
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.5f} | ETA {}"
.format((iter - 1) // iters_per_epoch + 1, iter, iters,
avg_loss, lr, avg_train_batch_cost,
avg_train_reader_cost, eta))
######################### VisualDL Log ##########################
log_writer.add_scalar('Train/loss', avg_loss, iter)
# Record all losses if there are more than 2 losses.
for key, value in avg_loss_dict.items():
log_tag = 'Train/' + key
log_writer.add_scalar(log_tag, value, iter)
log_writer.add_scalar('Train/lr', lr, iter)
log_writer.add_scalar('Train/batch_cost',
avg_train_batch_cost, iter)
log_writer.add_scalar('Train/reader_cost',
avg_train_reader_cost, iter)
#################################################################
avg_loss = 0.0
avg_loss_list = {}
reader_cost_averager.reset()
batch_cost_averager.reset()
if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
current_save_dir = os.path.join(save_dir,
"iter_{}".format(iter))
if not os.path.isdir(current_save_dir):
os.makedirs(current_save_dir)
paddle.save(model.state_dict(),
os.path.join(current_save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(),
os.path.join(current_save_dir, 'model.pdopt'))
save_models.append(current_save_dir)
if len(save_models) > keep_checkpoint_max > 0:
model_to_remove = save_models.popleft()
shutil.rmtree(model_to_remove)
batch_start = time.time()
# Sleep for half a second to let dataloader release resources.
time.sleep(0.5)
log_writer.close()
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg with minor modifications.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/core/train.py
"""
import os
import time
import numpy as np
import paddle
import paddle.nn.functional as F
from smoke.utils import TimeAverager, calculate_eta, logger, progbar
from .kitti_eval import kitti_evaluation
def evaluate(model,
eval_dataset,
num_workers=0,
output_dir="./output",
print_detail=True):
"""
Launch evalution.
Args:
model(nn.Layer): A model.
eval_dataset (paddle.io.Dataset): Used to read and process validation datasets.
num_workers (int, optional): Num workers for data loader. Default: 0.
print_detail (bool, optional): Whether to print detailed information about the evaluation process. Default: True.
Returns:
float: The mIoU of validation datasets.
float: The accuracy of validation datasets.
"""
model.eval()
batch_sampler = paddle.io.BatchSampler(
eval_dataset, batch_size=1, shuffle=False, drop_last=False)
loader = paddle.io.DataLoader(
eval_dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
return_list=True,
)
total_iters = len(loader)
if print_detail:
logger.info(
"Start evaluating (total_samples={}, total_iters={})...".format(
len(eval_dataset), total_iters))
progbar_val = progbar.Progbar(target=total_iters, verbose=1)
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
batch_start = time.time()
predictions = {}
with paddle.no_grad():
for cur_iter, batch in enumerate(loader):
reader_cost_averager.record(time.time() - batch_start)
images, targets, image_ids = batch[0], batch[1], batch[2]
output = model(images, targets)
output = output.numpy()
predictions.update(
{img_id: output for img_id in image_ids})
batch_cost_averager.record(
time.time() - batch_start, num_samples=len(targets))
batch_cost = batch_cost_averager.get_average()
reader_cost = reader_cost_averager.get_average()
if print_detail:
progbar_val.update(cur_iter + 1, [('batch_cost', batch_cost),
('reader cost', reader_cost)])
reader_cost_averager.reset()
batch_cost_averager.reset()
batch_start = time.time()
kitti_evaluation(eval_dataset, predictions, output_dir=output_dir)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import manager
from . import param_init
from .config import Config
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg with minor modifications.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/cvlibs/config.py
"""
import codecs
import os
from typing import Any, Dict, Generic
import paddle
import yaml
from smoke.cvlibs import manager
from smoke.utils import logger
class Config(object):
'''
Training configuration parsing. The only yaml/yml file is supported.
The following hyper-parameters are available in the config file:
batch_size: The number of samples per gpu.
iters: The total training steps.
train_dataset: A training data config including type/data_root/transforms/mode.
For data type, please refer to paddleseg.datasets.
For specific transforms, please refer to paddleseg.transforms.transforms.
val_dataset: A validation data config including type/data_root/transforms/mode.
optimizer: A optimizer config, but currently PaddleSeg only supports sgd with momentum in config file.
In addition, weight_decay could be set as a regularization.
learning_rate: A learning rate config. If decay is configured, learning _rate value is the starting learning rate,
where only poly decay is supported using the config file. In addition, decay power and end_lr are tuned experimentally.
loss: A loss config. Multi-loss config is available. The loss type order is consistent with the seg model outputs,
where the coef term indicates the weight of corresponding loss. Note that the number of coef must be the same as the number of
model outputs, and there could be only one loss type if using the same loss type among the outputs, otherwise the number of
loss type must be consistent with coef.
model: A model config including type/backbone and model-dependent arguments.
For model type, please refer to paddleseg.models.
For backbone, please refer to paddleseg.models.backbones.
Args:
path (str) : The path of config file, supports yaml format only.
Examples:
from paddleseg.cvlibs.config import Config
# Create a cfg object with yaml file path.
cfg = Config(yaml_cfg_path)
# Parsing the argument when its property is used.
train_dataset = cfg.train_dataset
# the argument of model should be parsed after dataset,
# since the model builder uses some properties in dataset.
model = cfg.model
...
'''
def __init__(self,
path: str,
learning_rate: float = None,
batch_size: int = None,
iters: int = None):
if not path:
raise ValueError('Please specify the configuration file path.')
if not os.path.exists(path):
raise FileNotFoundError('File {} does not exist'.format(path))
self._model = None
if path.endswith('yml') or path.endswith('yaml'):
self.dic = self._parse_from_yaml(path)
else:
raise RuntimeError('Config file should in yaml format!')
self.update(
learning_rate=learning_rate, batch_size=batch_size, iters=iters)
def _update_dic(self, dic, base_dic):
"""
Update config from dic based base_dic
"""
base_dic = base_dic.copy()
for key, val in dic.items():
if isinstance(val, dict) and key in base_dic:
base_dic[key] = self._update_dic(val, base_dic[key])
else:
base_dic[key] = val
dic = base_dic
return dic
def _parse_from_yaml(self, path: str):
'''Parse a yaml file and build config'''
with codecs.open(path, 'r', 'utf-8') as file:
dic = yaml.load(file, Loader=yaml.FullLoader)
if '_base_' in dic:
cfg_dir = os.path.dirname(path)
base_path = dic.pop('_base_')
base_path = os.path.join(cfg_dir, base_path)
base_dic = self._parse_from_yaml(base_path)
dic = self._update_dic(dic, base_dic)
return dic
def update(self,
learning_rate: float = None,
batch_size: int = None,
iters: int = None):
'''Update config'''
if learning_rate:
self.dic['lr_scheduler']['learning_rate'] = learning_rate
if batch_size:
self.dic['batch_size'] = batch_size
if iters:
self.dic['iters'] = iters
@property
def batch_size(self):
return self.dic.get('batch_size', 1)
@property
def iters(self):
iters = self.dic.get('iters')
if not iters:
raise RuntimeError('No iters specified in the configuration file.')
return iters
@property
def train_dataset(self):
train_dataset_cfg = self.dic.get('train_dataset', {})
if not train_dataset_cfg:
return None
return self._load_object(train_dataset_cfg)
@property
def val_dataset(self):
val_dataset_cfg = self.dic.get('val_dataset', {})
if not val_dataset_cfg:
return None
return self._load_object(val_dataset_cfg)
@property
def model(self):
model_cfg = self.dic.get('model').copy()
if not model_cfg:
raise RuntimeError('No model specified in the configuration file.')
if not self._model:
self._model = self._load_object(model_cfg)
return self._model
@property
def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler:
if 'lr_scheduler' not in self.dic:
raise RuntimeError(
'No `lr_scheduler` specified in the configuration file.')
params = self.dic.get('lr_scheduler').copy()
if 'type' not in params.keys():
if "learning_rate" in params.keys():
logger.warning(''' No decay config! The fixed learning rate will be used''')
return params["learning_rate"]
else:
raise RuntimeError(
'`lr_scheduler` is not set properlly in the configuration file.')
lr_type = params.pop('type')
return getattr(paddle.optimizer.lr, lr_type)(**params)
@property
def optimizer(self):
if 'lr_scheduler' in self.dic:
lr = self.lr_scheduler
else:
lr = self.learning_rate
args = self.dic.get('optimizer', {}).copy()
optimizer_type = args.pop('type')
return getattr(paddle.optimizer, optimizer_type)(lr, parameters=self.model.parameters(), **args)
@property
def loss(self):
loss_cfg = self.dic.get('loss', {}).copy()
if not loss_cfg:
return None
return self._load_object(loss_cfg)
def _load_component(self, com_name):
com_list = [
manager.MODELS, manager.BACKBONES, manager.DATASETS,
manager.TRANSFORMS, manager.LOSSES, manager.HEADS,
manager.POSTPROCESSORS
]
for com in com_list:
if com_name in com.components_dict:
return com[com_name]
else:
raise RuntimeError(
'The specified component was not found {}.'.format(com_name))
def _load_object(self, cfg):
cfg = cfg.copy()
if 'type' not in cfg:
raise RuntimeError('No object information in {}.'.format(cfg))
component = self._load_component(cfg.pop('type'))
params = {}
for key, val in cfg.items():
if self._is_meta_type(val):
params[key] = self._load_object(val)
elif isinstance(val, list):
params[key] = [
self._load_object(item)
if self._is_meta_type(item) else item for item in val
]
else:
params[key] = val
return component(**params)
def _is_meta_type(self, item):
return isinstance(item, dict) and 'type' in item
def __str__(self):
return yaml.dump(self.dic)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file is modified from PaddleSeg:
https://github.com/PaddlePaddle/PaddleSeg/blob/release/v2.0/paddleseg/cvlibs/manager.py
"""
import inspect
from collections.abc import Sequence
class ComponentManager:
"""
Implement a manager class to add the new component properly.
The component can be added as either class or function type.
Args:
name (str): The name of component.
Returns:
A callable object of ComponentManager.
Examples 1:
from paddleseg.cvlibs.manager import ComponentManager
model_manager = ComponentManager()
class AlexNet: ...
class ResNet: ...
model_manager.add_component(AlexNet)
model_manager.add_component(ResNet)
# Or pass a sequence alliteratively:
model_manager.add_component([AlexNet, ResNet])
print(model_manager.components_dict)
# {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
Examples 2:
# Or an easier way, using it as a Python decorator, while just add it above the class declaration.
from paddleseg.cvlibs.manager import ComponentManager
model_manager = ComponentManager()
@model_manager.add_component
class AlexNet: ...
@model_manager.add_component
class ResNet: ...
print(model_manager.components_dict)
# {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
"""
def __init__(self, name=None):
self._components_dict = dict()
self._name = name
def __len__(self):
return len(self._components_dict)
def __repr__(self):
name_str = self._name if self._name else self.__class__.__name__
return "{}:{}".format(name_str, list(self._components_dict.keys()))
def __getitem__(self, item):
if item not in self._components_dict.keys():
raise KeyError("{} does not exist in availabel {}".format(
item, self))
return self._components_dict[item]
@property
def components_dict(self):
return self._components_dict
@property
def name(self):
return self._name
def _add_single_component(self, component):
"""
Add a single component into the corresponding manager.
Args:
component (function|class): A new component.
Raises:
TypeError: When `component` is neither class nor function.
KeyError: When `component` was added already.
"""
# Currently only support class or function type
if not (inspect.isclass(component) or inspect.isfunction(component)):
raise TypeError(
"Expect class/function type, but received {}".format(
type(component)))
# Obtain the internal name of the component
component_name = component.__name__
# Check whether the component was added already
if component_name in self._components_dict.keys():
raise KeyError("{} exists already!".format(component_name))
else:
# Take the internal name of the component as its key
self._components_dict[component_name] = component
def add_component(self, components):
"""
Add component(s) into the corresponding manager.
Args:
components (function|class|list|tuple): Support four types of components.
Returns:
components (function|class|list|tuple): Same with input components.
"""
# Check whether the type is a sequence
if isinstance(components, Sequence):
for component in components:
self._add_single_component(component)
else:
component = components
self._add_single_component(component)
return components
MODELS = ComponentManager("models")
BACKBONES = ComponentManager("backbones")
HEADS = ComponentManager("heads")
POSTPROCESSORS = ComponentManager("post_processors")
DATASETS = ComponentManager("datasets")
TRANSFORMS = ComponentManager("transforms")
LOSSES = ComponentManager("losses")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/cvlibs/param_init.py
"""
import paddle.nn as nn
def constant_init(param, **kwargs):
"""
Initialize the `param` with constants.
Args:
param (Tensor): Tensor that needs to be initialized.
Examples:
from paddleseg.cvlibs import param_init
import paddle.nn as nn
linear = nn.Linear(2, 4)
param_init.constant_init(linear.weight, value=2.0)
print(linear.weight.numpy())
# result is [[2. 2. 2. 2.], [2. 2. 2. 2.]]
"""
initializer = nn.initializer.Constant(**kwargs)
initializer(param, param.block)
def normal_init(param, **kwargs):
"""
Initialize the `param` with a Normal distribution.
Args:
param (Tensor): Tensor that needs to be initialized.
Examples:
from paddleseg.cvlibs import param_init
import paddle.nn as nn
linear = nn.Linear(2, 4)
param_init.normal_init(linear.weight, loc=0.0, scale=1.0)
"""
initializer = nn.initializer.Normal(**kwargs)
initializer(param, param.block)
def kaiming_normal_init(param, **kwargs):
"""
Initialize the input tensor with Kaiming Normal initialization.
This function implements the `param` initialization from the paper
`Delving Deep into Rectifiers: Surpassing Human-Level Performance on
ImageNet Classification <https://arxiv.org/abs/1502.01852>`
by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
robust initialization method that particularly considers the rectifier
nonlinearities. In case of Uniform distribution, the range is [-x, x], where
.. math::
x = \sqrt{\\frac{6.0}{fan\_in}}
In case of Normal distribution, the mean is 0 and the standard deviation
is
.. math::
\sqrt{\\frac{2.0}{fan\_in}}
Args:
param (Tensor): Tensor that needs to be initialized.
Examples:
from paddleseg.cvlibs import param_init
import paddle.nn as nn
linear = nn.Linear(2, 4)
# uniform is used to decide whether to use uniform or normal distribution
param_init.kaiming_normal_init(linear.weight)
"""
initializer = nn.initializer.KaimingNormal(**kwargs)
initializer(param, param.block)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .kitti import KITTI
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import csv
import logging
import random
import paddle
import numpy as np
from PIL import Image
from smoke.cvlibs import manager
from smoke.transforms import Compose
from smoke.utils.heatmap_coder import (
get_transfrom_matrix,
affine_transform,
gaussian_radius,
draw_umich_gaussian,
encode_label
)
@manager.DATASETS.add_component
class KITTI(paddle.io.Dataset):
"""Parsing KITTI format dataset
Args:
Dataset (class):
"""
def __init__(self, dataset_root, mode="train", transforms=None, flip_prob=0.5, aug_prob=0.3):
super().__init__()
self.TYPE_ID_CONVERSION = {
'Car': 0,
'Cyclist': 1,
'Pedestrian': 2,
}
mode = mode.lower()
self.image_dir = os.path.join(dataset_root, "image_2")
self.label_dir = os.path.join(dataset_root, "label_2")
self.calib_dir = os.path.join(dataset_root, "calib")
if mode.lower() not in ['train', 'val', 'trainval', 'test']:
raise ValueError(
"mode should be 'train', 'val', 'trainval' or 'test', but got {}.".format(
mode))
imageset_txt = os.path.join(dataset_root, "ImageSets", "{}.txt".format(mode))
self.is_train = True if mode in ["train", "trainval"] else False
self.transforms = Compose(transforms)
image_files = []
for line in open(imageset_txt, "r"):
base_name = line.replace("\n", "")
image_name = base_name + ".png"
image_files.append(image_name)
self.image_files = image_files
self.label_files = [i.replace(".png", ".txt") for i in self.image_files]
self.num_samples = len(self.image_files)
self.classes = ("Car", "Cyclist", "Pedestrian")
self.flip_prob = flip_prob if self.is_train else 0.0
self.aug_prob = aug_prob if self.is_train else 0.0
self.shift_scale = (0.2, 0.4)
self.num_classes = len(self.classes)
self.input_width = 1280
self.input_height = 384
self.output_width = self.input_width // 4
self.output_height = self.input_height // 4
self.max_objs = 50
self.logger = logging.getLogger(__name__)
self.logger.info("Initializing KITTI {} set with {} files loaded".format(mode, self.num_samples))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# load default parameter here
original_idx = self.label_files[idx].replace(".txt", "")
img_path = os.path.join(self.image_dir, self.image_files[idx])
img = Image.open(img_path)
anns, K = self.load_annotations(idx)
center = np.array([i / 2 for i in img.size], dtype=np.float32)
size = np.array([i for i in img.size], dtype=np.float32)
"""
resize, horizontal flip, and affine augmentation are performed here.
since it is complicated to compute heatmap w.r.t transform.
"""
flipped = False
if (self.is_train) and (random.random() < self.flip_prob):
flipped = True
img = img.transpose(Image.FLIP_LEFT_RIGHT)
center[0] = size[0] - center[0] - 1
K[0, 2] = size[0] - K[0, 2] - 1
affine = False
if (self.is_train) and (random.random() < self.aug_prob):
affine = True
shift, scale = self.shift_scale[0], self.shift_scale[1]
shift_ranges = np.arange(-shift, shift + 0.1, 0.1)
center[0] += size[0] * random.choice(shift_ranges)
center[1] += size[1] * random.choice(shift_ranges)
scale_ranges = np.arange(1 - scale, 1 + scale + 0.1, 0.1)
size *= random.choice(scale_ranges)
center_size = [center, size]
trans_affine = get_transfrom_matrix(
center_size,
[self.input_width, self.input_height]
)
trans_affine_inv = np.linalg.inv(trans_affine)
img = img.transform(
(self.input_width, self.input_height),
method=Image.AFFINE,
data=trans_affine_inv.flatten()[:6],
resample=Image.BILINEAR,
)
trans_mat = get_transfrom_matrix(
center_size,
[self.output_width, self.output_height]
)
if not self.is_train:
# for inference we parametrize with original size
target = {}
target["image_size"] = size
target["is_train"] = self.is_train
target["trans_mat"] = trans_mat
target["K"] = K
if self.transforms is not None:
img, target = self.transforms(img, target)
return np.array(img), target, original_idx
heat_map = np.zeros([self.num_classes, self.output_height, self.output_width], dtype=np.float32)
regression = np.zeros([self.max_objs, 3, 8], dtype=np.float32)
cls_ids = np.zeros([self.max_objs], dtype=np.int32)
proj_points = np.zeros([self.max_objs, 2], dtype=np.int32)
p_offsets = np.zeros([self.max_objs, 2], dtype=np.float32)
c_offsets = np.zeros([self.max_objs, 2], dtype=np.float32)
dimensions = np.zeros([self.max_objs, 3], dtype=np.float32)
locations = np.zeros([self.max_objs, 3], dtype=np.float32)
rotys = np.zeros([self.max_objs], dtype=np.float32)
reg_mask = np.zeros([self.max_objs], dtype=np.uint8)
flip_mask = np.zeros([self.max_objs], dtype=np.uint8)
bbox2d_size = np.zeros([self.max_objs, 2], dtype=np.float32)
for i, a in enumerate(anns):
if i == self.max_objs:
break
a = a.copy()
cls = a["label"]
locs = np.array(a["locations"])
rot_y = np.array(a["rot_y"])
if flipped:
locs[0] *= -1
rot_y *= -1
point, box2d, box3d = encode_label(
K, rot_y, a["dimensions"], locs
)
if np.all(box2d == 0):
continue
point = affine_transform(point, trans_mat)
box2d[:2] = affine_transform(box2d[:2], trans_mat)
box2d[2:] = affine_transform(box2d[2:], trans_mat)
box2d[[0, 2]] = box2d[[0, 2]].clip(0, self.output_width - 1)
box2d[[1, 3]] = box2d[[1, 3]].clip(0, self.output_height - 1)
h, w = box2d[3] - box2d[1], box2d[2] - box2d[0]
center = np.array([(box2d[0] + box2d[2]) / 2, (box2d[1] + box2d[3]) /2], dtype=np.float32)
if (0 < center[0] < self.output_width) and (0 < center[1] < self.output_height):
point_int = center.astype(np.int32)
p_offset = point - point_int
c_offset = center - point_int
radius = gaussian_radius(h, w)
radius = max(0, int(radius))
heat_map[cls] = draw_umich_gaussian(heat_map[cls], point_int, radius)
cls_ids[i] = cls
regression[i] = box3d
proj_points[i] = point_int
p_offsets[i] = p_offset
c_offsets[i] = c_offset
dimensions[i] = np.array(a["dimensions"])
locations[i] = locs
rotys[i] = rot_y
reg_mask[i] = 1 if not affine else 0
flip_mask[i] = 1 if not affine and flipped else 0
# targets for 2d bbox
bbox2d_size[i, 0] = w
bbox2d_size[i, 1] = h
target = {}
target["image_size"] = np.array(img.size)
target["is_train"] = self.is_train
target["trans_mat"] = trans_mat
target["K"] = K
target["hm"] = heat_map
target["reg"] = regression
target["cls_ids"] = cls_ids
target["proj_p"] = proj_points
target["dimensions"] = dimensions
target["locations"] = locations
target["rotys"] = rotys
target["reg_mask"] = reg_mask
target["flip_mask"] = flip_mask
target["bbox_size"] = bbox2d_size
target["c_offsets"] = c_offsets
if self.transforms is not None:
img, target = self.transforms(img, target)
return np.array(img), target, original_idx
def load_annotations(self, idx):
"""load kitti label by given index
Args:
idx (int): which label to load
Returns:
(list[dict], np.ndarray(float32, 3x3)): labels and camera intrinsic matrix
"""
annotations = []
file_name = self.label_files[idx]
fieldnames = ['type', 'truncated', 'occluded', 'alpha', 'xmin', 'ymin', 'xmax', 'ymax', 'dh', 'dw',
'dl', 'lx', 'ly', 'lz', 'ry']
if self.is_train:
if os.path.exists(os.path.join(self.label_dir, file_name)):
with open(os.path.join(self.label_dir, file_name), 'r') as csv_file:
reader = csv.DictReader(csv_file, delimiter=' ', fieldnames=fieldnames)
for line, row in enumerate(reader):
if (float(row["xmax"]) == 0.) | (float(row["ymax"]) == 0.):
continue
if row["type"] in self.classes:
annotations.append({
"class": row["type"],
"label": self.TYPE_ID_CONVERSION[row["type"]],
"truncation": float(row["truncated"]),
"occlusion": float(row["occluded"]),
"alpha": float(row["alpha"]),
"dimensions": [float(row['dl']), float(row['dh']), float(row['dw'])],
"locations": [float(row['lx']), float(row['ly']), float(row['lz'])],
"rot_y": float(row["ry"])
})
# get camera intrinsic matrix K
with open(os.path.join(self.calib_dir, file_name), 'r') as csv_file:
reader = csv.reader(csv_file, delimiter=' ')
for line, row in enumerate(reader):
if row[0] == 'P2:':
K = row[1:]
K = [float(i) for i in K]
K = np.array(K, dtype=np.float32).reshape(3, 4)
K = K[:3, :3]
break
return annotations, K
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backbones import *
from .losses import *
from .heads import *
from .postprocess import *
from .smoke import SMOKE
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dla import *
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from smoke.models.layers import group_norm
from smoke.cvlibs import manager
from smoke.utils import pretrained_utils
__all__ = [
"DLA", "DLA34"
]
@manager.BACKBONES.add_component
class DLA(nn.Layer):
def __init__(self,
levels,
channels,
block,
down_ratio=4,
last_level=5,
out_channel=0,
norm_type="gn",
pretrained=None):
super().__init__()
self.pretrained = pretrained
assert down_ratio in [2, 4, 8, 16]
self.first_level = int(np.log2(down_ratio))
self.last_level = last_level
norm_func = nn.BatchNorm2D if norm_type == "bn" else group_norm
self.base = DLABase(levels,
channels,
block=eval(block),
norm_func=norm_func)
scales = [2 ** i for i in range(len(channels[self.first_level:]))]
self.dla_up = DLAUp(startp=self.first_level,
channels=channels[self.first_level:],
scales=scales,
norm_func=norm_func)
if out_channel == 0:
out_channel = channels[self.first_level]
up_scales = [2 ** i for i in range(self.last_level - self.first_level)]
self.ida_up = IDAUp(in_channels=channels[self.first_level:self.last_level],
out_channel=out_channel,
up_f=up_scales,
norm_func=norm_func)
self.init_weight()
def forward(self, x):
x = self.base(x)
x = self.dla_up(x)
y = []
iter_levels = range(self.last_level - self.first_level)
for i in iter_levels:
y.append(x[i].clone())
self.ida_up(y, 0, len(y))
return y[-1]
def init_weight(self):
pretrained_utils.load_pretrained_model(self, self.pretrained)
class DLABase(nn.Layer):
"""DLA base module
"""
def __init__(self,
levels,
channels,
block=None,
residual_root=False,
norm_func=None,
):
super().__init__()
self.channels = channels
self.level_length = len(levels)
if block is None:
block = BasicBlock
if norm_func is None:
norm_func = nn.BatchNorm2d
self.base_layer = nn.Sequential(
nn.Conv2D(3,
channels[0],
kernel_size=7,
stride=1,
padding=3,
bias_attr=False),
norm_func(channels[0]),
nn.ReLU()
)
self.level0 = _make_conv_level(in_channels=channels[0],
out_channels=channels[0],
num_convs=levels[0],
norm_func=norm_func)
self.level1 = _make_conv_level(in_channels=channels[0],
out_channels=channels[1],
num_convs=levels[0],
norm_func=norm_func,
stride=2)
self.level2 = Tree(level=levels[2],
block=block,
in_channels=channels[1],
out_channels=channels[2],
norm_func=norm_func,
stride=2,
level_root=False,
root_residual=residual_root)
self.level3 = Tree(level=levels[3],
block=block,
in_channels=channels[2],
out_channels=channels[3],
norm_func=norm_func,
stride=2,
level_root=True,
root_residual=residual_root)
self.level4 = Tree(level=levels[4],
block=block,
in_channels=channels[3],
out_channels=channels[4],
norm_func=norm_func,
stride=2,
level_root=True,
root_residual=residual_root)
self.level5 = Tree(level=levels[5],
block=block,
in_channels=channels[4],
out_channels=channels[5],
norm_func=norm_func,
stride=2,
level_root=True,
root_residual=residual_root)
def forward(self, x):
"""forward
"""
y = []
x = self.base_layer(x)
for i in range(self.level_length):
x = getattr(self, 'level{}'.format(i))(x)
y.append(x)
return y
class DLAUp(nn.Layer):
"""DLA Up module
"""
def __init__(self,
startp,
channels,
scales,
in_channels=None,
norm_func=None):
"""DLA Up module
"""
super(DLAUp, self).__init__()
self.startp = startp
if norm_func is None:
norm_func = nn.BatchNorm2d
if in_channels is None:
in_channels = channels
self.channels = channels
channels = list(channels)
scales = np.array(scales, dtype=int)
for i in range(len(channels) - 1):
j = -i - 2
setattr(self,
'ida_{}'.format(i),
IDAUp(in_channels[j:],
channels[j],
scales[j:] // scales[j],
norm_func))
scales[j + 1:] = scales[j]
in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]
def forward(self, layers):
"""forward
"""
out = [layers[-1]] # start with 32
for i in range(len(layers) - self.startp - 1):
ida = getattr(self, 'ida_{}'.format(i))
ida(layers, len(layers) - i - 2, len(layers))
out.insert(0, layers[-1])
return out
class BasicBlock(nn.Layer):
"""Basic Block
"""
def __init__(self,
in_channels,
out_channels,
norm_func,
stride=1,
dilation=1):
super().__init__()
self.conv1 = nn.Conv2D(in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=dilation,
bias_attr=False,
dilation=dilation)
self.norm1 = norm_func(out_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=dilation,
bias_attr=False,
dilation=dilation
)
self.norm2 = norm_func(out_channels)
def forward(self, x, residual=None):
"""forward
"""
if residual is None:
residual = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out += residual
out = self.relu(out)
return out
class Tree(nn.Layer):
def __init__(self,
level,
block,
in_channels,
out_channels,
norm_func,
stride=1,
level_root=False,
root_dim=0,
root_kernel_size=1,
dilation=1,
root_residual=False
):
super(Tree, self).__init__()
if root_dim == 0:
root_dim = 2 * out_channels
if level_root:
root_dim += in_channels
if level == 1:
self.tree1 = block(in_channels,
out_channels,
norm_func,
stride,
dilation=dilation)
self.tree2 = block(out_channels,
out_channels,
norm_func,
stride=1,
dilation=dilation)
else:
new_level = level - 1
self.tree1 = Tree(new_level,
block,
in_channels,
out_channels,
norm_func,
stride,
root_dim=0,
root_kernel_size=root_kernel_size,
dilation=dilation,
root_residual=root_residual)
self.tree2 = Tree(new_level,
block,
out_channels,
out_channels,
norm_func,
root_dim=root_dim + out_channels,
root_kernel_size=root_kernel_size,
dilation=dilation,
root_residual=root_residual)
if level == 1:
self.root = Root(root_dim,
out_channels,
norm_func,
root_kernel_size,
root_residual)
self.level_root = level_root
self.root_dim = root_dim
self.level = level
self.downsample = None
if stride > 1:
self.downsample = nn.MaxPool2D(stride, stride=stride)
self.project = None
if in_channels != out_channels:
self.project = nn.Sequential(
nn.Conv2D(in_channels,
out_channels,
kernel_size=1,
stride=1,
bias_attr=False),
norm_func(out_channels)
)
def forward(self, x, residual=None, children=None):
"""forward
"""
if children is None:
children = []
if self.downsample:
bottom = self.downsample(x)
else:
bottom = x
if self.project:
residual = self.project(bottom)
else:
residual = bottom
if self.level_root:
children.append(bottom)
x1 = self.tree1(x, residual)
if self.level == 1:
x2 = self.tree2(x1)
x = self.root(x2, x1, *children)
else:
children.append(x1)
x = self.tree2(x1, children=children)
return x
class Root(nn.Layer):
"""Root module
"""
def __init__(self,
in_channels,
out_channels,
norm_func,
kernel_size,
residual):
super(Root, self).__init__()
self.conv = nn.Conv2D(in_channels,
out_channels,
kernel_size=1,
stride=1,
bias_attr=False,
padding=(kernel_size - 1) // 2)
self.norm = norm_func(out_channels)
self.relu = nn.ReLU()
self.residual = residual
def forward(self, *x):
"""forward
"""
children = x
x = self.conv(paddle.concat(x, 1))
x = self.norm(x)
if self.residual:
x += children[0]
x = self.relu(x)
return x
class IDAUp(nn.Layer):
"""IDAUp module
"""
def __init__(self,
in_channels,
out_channel,
up_f, # todo: what is up_f here?
norm_func):
super().__init__()
for i in range(1, len(in_channels)):
in_channel = in_channels[i]
f = int(up_f[i])
#USE_DEFORMABLE_CONV = False
# so far only support normal convolution
proj = NormalConv(in_channel, out_channel, norm_func)
node = NormalConv(out_channel, out_channel, norm_func)
up = nn.Conv2DTranspose(out_channel,
out_channel,
kernel_size=f * 2,
stride=f,
padding=f // 2,
output_padding=0,
groups=out_channel,
bias_attr=False)
# todo: uncommoment later
# _fill_up_weights(up)
setattr(self, 'proj_' + str(i), proj)
setattr(self, 'up_' + str(i), up)
setattr(self, 'node_' + str(i), node)
def forward(self, layers, startp, endp):
"""forward
"""
for i in range(startp + 1, endp):
upsample = getattr(self, 'up_' + str(i - startp))
project = getattr(self, 'proj_' + str(i - startp))
layers[i] = upsample(project(layers[i]))
node = getattr(self, 'node_' + str(i - startp))
layers[i] = node(layers[i] + layers[i - 1])
class NormalConv(nn.Layer):
"""Normal Conv without deformable
"""
def __init__(self,
in_channels,
out_channels,
norm_func):
super(NormalConv, self).__init__()
self.norm = norm_func(out_channels)
self.relu = nn.ReLU()
self.conv = nn.Conv2D(in_channels,
out_channels,
kernel_size=(3, 3),
padding=1)
def forward(self, x):
"""forward
"""
x = self.conv(x)
x = self.norm(x)
x = self.relu(x)
return x
def _make_conv_level(in_channels, out_channels, num_convs, norm_func,
stride=1, dilation=1):
"""
make conv layers based on its number.
"""
layers = []
for i in range(num_convs):
layers.extend([
nn.Conv2D(in_channels, out_channels, kernel_size=3,
stride=stride if i == 0 else 1,
padding=dilation, bias_attr=False, dilation=dilation),
norm_func(out_channels),
nn.ReLU()])
in_channels = out_channels
return nn.Sequential(*layers)
@manager.BACKBONES.add_component
def DLA34(**kwargs):
model = DLA(
levels=[1, 1, 1, 2, 2, 1],
channels=[16, 32, 64, 128, 256, 512],
block="BasicBlock",
**kwargs
)
return model
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .smoke_predictor import SMOKEPredictor
from .smoke_coder import SMOKECoder
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from smoke.ops import gather_op
class SMOKECoder(paddle.nn.Layer):
"""SMOKE Coder class
"""
def __init__(self, depth_ref, dim_ref):
super().__init__()
# self.depth_ref = paddle.to_tensor(depth_ref)
# self.dim_ref = paddle.to_tensor(dim_ref)
self.depth_decoder = DepthDecoder(depth_ref)
self.dimension_decoder = DimensionDecoder(dim_ref)
@staticmethod
def rad_to_matrix(rotys, N):
"""decode rotys to R_matrix
Args:
rotys (Tensor): roty of objects
N (int): num of batch
Returns:
Tensor: R matrix with shape (N, 3, 3)
R = [[cos(r), 0, sin(r)], [0, 1, 0], [-cos(r), 0, sin(r)]]
"""
cos, sin = rotys.cos(), rotys.sin()
i_temp = paddle.to_tensor([[1, 0, 1], [0, 1, 0], [-1, 0, 1]], dtype="float32")
# ry = paddle.reshape(i_temp.tile([N, 1]), (N, -1, 3))
# ry[:, 0, 0] *= cos
# ry[:, 0, 2] *= sin
# ry[:, 2, 0] *= sin
# ry[:, 2, 2] *= cos
# slice bug, so use concat
pos1 = (paddle.ones([N], dtype="float32") * cos).unsqueeze(-1)
pos2 = (paddle.zeros([N], dtype="float32")).unsqueeze(-1)
pos3 = (paddle.ones([N], dtype="float32") * sin).unsqueeze(-1)
pos4 = (paddle.zeros([N], dtype="float32")).unsqueeze(-1)
pos5 = (paddle.ones([N], dtype="float32")).unsqueeze(-1)
pos6 = (paddle.zeros([N], dtype="float32")).unsqueeze(-1)
pos7 = (paddle.ones([N], dtype="float32") * (-sin)).unsqueeze(-1)
pos8 = (paddle.zeros([N], dtype="float32")).unsqueeze(-1)
pos9 = (paddle.ones([N], dtype="float32") * cos).unsqueeze(-1)
ry = paddle.concat([pos1, pos2, pos3, pos4, pos5, pos6, pos7, pos8, pos9], axis=1)
ry = paddle.reshape(ry, [N, 3, 3])
return ry
def encode_box3d(self, rotys, dims, locs):
"""
construct 3d bounding box for each object.
Args:
rotys: rotation in shape N
dims: dimensions of objects
locs: locations of objects
Returns:
"""
if len(rotys.shape) == 2:
rotys = rotys.flatten()
if len(dims.shape) == 3:
dims = paddle.reshape(dims, (-1, 3))
if len(locs.shape) == 3:
locs = paddle.reshape(locs, (-1, 3))
N = rotys.shape[0]
ry = self.rad_to_matrix(rotys, N)
# if test:
# dims.register_hook(lambda grad: print('dims grad', grad.sum()))
# dims = paddle.reshape(dims, (-1, 1)).tile([1, 8])
# dims[::3, :4] = 0.5 * dims[::3, :4]
# dims[1::3, :4] = 0.
# dims[2::3, :4] = 0.5 * dims[2::3, :4]
# dims[::3, 4:] = -0.5 * dims[::3, 4:]
# dims[1::3, 4:] = -dims[1::3, 4:]
# dims[2::3, 4:] = -0.5 * dims[2::3, 4:]
dim_left_1 = (0.5 * dims[:, 0]).unsqueeze(-1)
dim_left_2 = paddle.zeros([dims.shape[0], 1]).astype("float32") #(paddle.zeros_like(dims[:, 1])).unsqueeze(-1)
dim_left_3 = (0.5 * dims[:, 2]).unsqueeze(-1)
dim_left = paddle.concat([dim_left_1, dim_left_2, dim_left_3], axis=1)
dim_left = paddle.reshape(dim_left, (-1, 1)).tile([1, 4])
dim_right_1 = (-0.5 * dims[:, 0]).unsqueeze(-1)
dim_right_2 = (-dims[:, 1]).unsqueeze(-1)
dim_right_3 = (-0.5 * dims[:, 2]).unsqueeze(-1)
dim_right = paddle.concat([dim_right_1, dim_right_2, dim_right_3], axis=1)
dim_right = paddle.reshape(dim_right, (-1, 1)).tile([1, 4])
dims = paddle.concat([dim_left, dim_right], axis=1)
index = paddle.to_tensor([[4, 0, 1, 2, 3, 5, 6, 7],
[4, 5, 0, 1, 6, 7, 2, 3],
[4, 5, 6, 0, 1, 2, 3, 7]]).tile([N, 1])
box_3d_object = gather_op(dims, 1, index)
box_3d = paddle.matmul(ry, paddle.reshape(box_3d_object, (N, 3, -1)))
# box_3d += locs.unsqueeze(-1).repeat(1, 1, 8)
box_3d += locs.unsqueeze(-1).tile((1, 1, 8))
return box_3d
def decode_depth(self, depths_offset):
"""
Transform depth offset to depth
"""
#depth = depths_offset * self.depth_ref[1] + self.depth_ref[0]
#return depth
return self.depth_decoder(depths_offset)
def decode_location(self,
points,
points_offset,
depths,
Ks,
trans_mats):
"""
retrieve objects location in camera coordinate based on projected points
Args:
points: projected points on feature map in (x, y)
points_offset: project points offset in (delata_x, delta_y)
depths: object depth z
Ks: camera intrinsic matrix, shape = [N, 3, 3]
trans_mats: transformation matrix from image to feature map, shape = [N, 3, 3]
Returns:
locations: objects location, shape = [N, 3]
"""
# number of points
N = points_offset.shape[0]
# batch size
N_batch = Ks.shape[0]
batch_id = paddle.arange(N_batch).unsqueeze(1)
# obj_id = batch_id.repeat(1, N // N_batch).flatten()
obj_id = batch_id.tile([1, N // N_batch]).flatten()
# trans_mats_inv = trans_mats.inverse()[obj_id]
# Ks_inv = Ks.inverse()[obj_id]
inv = trans_mats.inverse()
trans_mats_inv = paddle.concat([inv[int(obj_id[i])].unsqueeze(0) for i in range(len(obj_id))])
inv = Ks.inverse()
Ks_inv = paddle.concat([inv[int(obj_id[i])].unsqueeze(0) for i in range(len(obj_id))])
points = paddle.reshape(points, (-1, 2))
assert points.shape[0] == N
# int + float -> int, but float + int -> float
# proj_points = points + points_offset
proj_points = points_offset + points
# transform project points in homogeneous form.
proj_points_extend = paddle.concat(
(proj_points.astype("float32"), paddle.ones((N, 1))), axis=1)
# expand project points as [N, 3, 1]
proj_points_extend = proj_points_extend.unsqueeze(-1)
# transform project points back on image
proj_points_img = paddle.matmul(trans_mats_inv, proj_points_extend)
# with depth
proj_points_img = proj_points_img * paddle.reshape(depths, (N, -1, 1))
# transform image coordinates back to object locations
locations = paddle.matmul(Ks_inv, proj_points_img)
return locations.squeeze(2)
def decode_location_without_transmat(self,
points, points_offset,
depths, Ks, down_ratios=None):
"""
retrieve objects location in camera coordinate based on projected points
Args:
points: projected points on feature map in (x, y)
points_offset: project points offset in (delata_x, delta_y)
depths: object depth z
Ks: camera intrinsic matrix, shape = [N, 3, 3]
trans_mats: transformation matrix from image to feature map, shape = [N, 3, 3]
Returns:
locations: objects location, shape = [N, 3]
"""
if down_ratios is None:
down_ratios = [(1, 1)]
# number of points
N = points_offset.shape[0]
# batch size
N_batch = Ks.shape[0]
#batch_id = paddle.arange(N_batch).unsqueeze(1)
batch_id = paddle.arange(N_batch).reshape((N_batch, 1))
# obj_id = batch_id.repeat(1, N // N_batch).flatten()
obj_id = batch_id.tile([1, N // N_batch]).flatten()
# Ks_inv = Ks[obj_id] pytorch
# Ks_inv = paddle.concat([Ks[int(obj_id[i])].unsqueeze(0) for i in range(len(obj_id))])
length = int(obj_id.shape[0])
ks_v = []
for i in range(length):
ks_v.append(Ks[int(obj_id[i])].unsqueeze(0))
Ks_inv = paddle.concat(ks_v)
down_ratio = down_ratios[0]
points = paddle.reshape(points, (numel_t(points)//2, 2))
proj_points = points + points_offset
# trans point from heatmap to ori image, down_sample * resize_scale
proj_points[:, 0] = down_ratio[0] * proj_points[:, 0]
proj_points[:, 1] = down_ratio[1] * proj_points[:, 1]
# transform project points in homogeneous form.
proj_points_extend = paddle.concat(
[proj_points, paddle.ones((N, 1))], axis=1)
# expand project points as [N, 3, 1]
proj_points_extend = proj_points_extend.unsqueeze(-1)
# with depth
proj_points_img = proj_points_extend * paddle.reshape(depths, (N, numel_t(depths)//N, 1))
# transform image coordinates back to object locations
locations = paddle.matmul(Ks_inv, proj_points_img)
return locations.squeeze(2)
def decode_bbox_2d_without_transmat(self, points, bbox_size, down_ratios=None):
"""get bbox 2d
Args:
points (paddle.Tensor, (50, 2)): 2d center
bbox_size (paddle.Tensor, (50, 2)): 2d bbox height and width
trans_mats (paddle.Tensor, (1, 3, 3)): transformation coord from img to feature map
"""
if down_ratios is None:
down_ratios = [(1, 1)]
# number of points
N = bbox_size.shape[0]
points = paddle.reshape(points, (-1, 2))
assert points.shape[0] == N
box2d = paddle.zeros((N, 4))
down_ratio = down_ratios[0]
box2d[:, 0] = (points[:, 0] - bbox_size[:, 0] / 2)
box2d[:, 1] = (points[:, 1] - bbox_size[:, 1] / 2)
box2d[:, 2] = (points[:, 0] + bbox_size[:, 0] / 2)
box2d[:, 3] = (points[:, 1] + bbox_size[:, 1] / 2)
box2d[:, 0] = down_ratio[0] * box2d[:, 0]
box2d[:, 1] = down_ratio[1] * box2d[:, 1]
box2d[:, 2] = down_ratio[0] * box2d[:, 2]
box2d[:, 3] = down_ratio[1] * box2d[:, 3]
return box2d
def decode_dimension(self, cls_id, dims_offset):
"""
retrieve object dimensions
Args:
cls_id: each object id
dims_offset: dimension offsets, shape = (N, 3)
Returns:
"""
# cls_id = cls_id.flatten().long()
# dims_select = self.dim_ref[cls_id, :]
# cls_id = cls_id.flatten()
# dims_select = paddle.concat([self.dim_ref[int(cls_id[i])].unsqueeze(0) for i in range(len(cls_id))])
# dimensions = dims_offset.exp() * dims_select
# return dimensions
return self.dimension_decoder(cls_id, dims_offset)
def decode_orientation(self, vector_ori, locations, flip_mask=None):
"""
retrieve object orientation
Args:
vector_ori: local orientation in [sin, cos] format
locations: object location
Returns: for training we only need roty
for testing we need both alpha and roty
"""
locations = paddle.reshape(locations, (-1, 3))
rays = paddle.atan(locations[:, 0] / (locations[:, 2] + 1e-7))
alphas = paddle.atan(vector_ori[:, 0] / (vector_ori[:, 1] + 1e-7))
# get cosine value positive and negtive index.
cos_pos_idx = (vector_ori[:, 1] >= 0).nonzero()
cos_neg_idx = (vector_ori[:, 1] < 0).nonzero()
PI = 3.14159
for i in range(cos_pos_idx.shape[0]):
ind = int(cos_pos_idx[i,0])
alphas[ind] = alphas[ind] - PI / 2
for i in range(cos_neg_idx.shape[0]):
ind = int(cos_neg_idx[i,0])
alphas[ind] = alphas[ind] + PI / 2
# alphas[cos_pos_idx] -= PI / 2
# alphas[cos_neg_idx] += PI / 2
# retrieve object rotation y angle.
rotys = alphas + rays
# in training time, it does not matter if angle lies in [-PI, PI]
# it matters at inference time? todo: does it really matter if it exceeds.
larger_idx = (rotys > PI).nonzero()
small_idx = (rotys < -PI).nonzero()
if len(larger_idx) != 0:
for i in range(larger_idx.shape[0]):
ind = int(larger_idx[i,0])
rotys[ind] -= 2 * PI
if len(small_idx) != 0:
for i in range(small_idx.shape[0]):
ind = int(small_idx[i,0])
rotys[ind] += 2 * PI
if flip_mask is not None:
fm = flip_mask.astype("float32").flatten()
rotys_flip = fm * rotys
# rotys_flip_pos_idx = rotys_flip > 0
# rotys_flip_neg_idx = rotys_flip < 0
# rotys_flip[rotys_flip_pos_idx] -= PI
# rotys_flip[rotys_flip_neg_idx] += PI
rotys_flip_pos_idx = (rotys_flip > 0).nonzero()
rotys_flip_neg_idx = (rotys_flip < 0).nonzero()
for i in range(rotys_flip_pos_idx.shape[0]):
ind = int(rotys_flip_pos_idx[i, 0])
rotys_flip[ind] -= PI
for i in range(rotys_flip_neg_idx.shape[0]):
ind = int(rotys_flip_neg_idx[i, 0])
rotys_flip[ind] += PI
rotys_all = fm * rotys_flip + (1 - fm) * rotys
return rotys_all
else:
return rotys, alphas
def decode_bbox_2d(self, points, bbox_size, trans_mats, img_size):
"""get bbox 2d
Args:
points (paddle.Tensor, (50, 2)): 2d center
bbox_size (paddle.Tensor, (50, 2)): 2d bbox height and width
trans_mats (paddle.Tensor, (1, 3, 3)): transformation coord from img to feature map
"""
img_size = img_size.flatten()
# number of points
N = bbox_size.shape[0]
# batch size
N_batch = trans_mats.shape[0]
batch_id = paddle.arange(N_batch).unsqueeze(1)
# obj_id = batch_id.repeat(1, N // N_batch).flatten()
obj_id = batch_id.tile([1, N // N_batch]).flatten()
inv = trans_mats.inverse()
trans_mats_inv = paddle.concat([inv[int(obj_id[i])].unsqueeze(0) for i in range(len(obj_id))])
#trans_mats_inv = trans_mats.inverse()[obj_id]
points = paddle.reshape(points, (-1, 2))
assert points.shape[0] == N
box2d = paddle.zeros([N, 4])
box2d[:, 0] = (points[:, 0] - bbox_size[:, 0] / 2)
box2d[:, 1] = (points[:, 1] - bbox_size[:, 1] / 2)
box2d[:, 2] = (points[:, 0] + bbox_size[:, 0] / 2)
box2d[:, 3] = (points[:, 1] + bbox_size[:, 1] / 2)
# transform project points in homogeneous form.
proj_points_extend_top = paddle.concat(
(box2d[:, :2], paddle.ones([N, 1])), axis=1)
proj_points_extend_bot = paddle.concat(
(box2d[:, 2:], paddle.ones([N, 1])), axis=1)
# expand project points as [N, 3, 1]
proj_points_extend_top = proj_points_extend_top.unsqueeze(-1)
proj_points_extend_bot = proj_points_extend_bot.unsqueeze(-1)
# transform project points back on image
proj_points_img_top = paddle.matmul(trans_mats_inv, proj_points_extend_top)
proj_points_img_bot = paddle.matmul(trans_mats_inv, proj_points_extend_bot)
box2d[:, :2] = proj_points_img_top.squeeze(2)[:, :2]
box2d[:, 2:] = proj_points_img_bot.squeeze(2)[:, :2]
box2d[:, ::2] = box2d[:, ::2].clip(0, img_size[0])
box2d[:, 1::2] = box2d[:, 1::2].clip(0, img_size[1])
return box2d
class DepthDecoder(paddle.nn.Layer):
def __init__(self, depth_ref):
super().__init__()
self.depth_ref = paddle.to_tensor(depth_ref)
def forward(self, depths_offset):
"""
Transform depth offset to depth
"""
depth = depths_offset * self.depth_ref[1] + self.depth_ref[0]
return depth
class DimensionDecoder(paddle.nn.Layer):
def __init__(self, dim_ref):
super().__init__()
self.dim_ref = paddle.to_tensor(dim_ref)
def forward(self, cls_id, dims_offset):
"""
retrieve object dimensions
Args:
cls_id: each object id
dims_offset: dimension offsets, shape = (N, 3)
Returns:
"""
# cls_id = cls_id.flatten().long()
# dims_select = self.dim_ref[cls_id, :]
cls_id = cls_id.flatten()
#dims_select = paddle.concat([self.dim_ref[int(cls_id[i])].unsqueeze(0) for i in range(len(cls_id))])
length = int(cls_id.shape[0])
list_v = []
for i in range(length):
list_v.append(self.dim_ref[int(cls_id[i])].unsqueeze(0))
dims_select = paddle.concat(list_v)
dimensions = dims_offset.exp() * dims_select
return dimensions
def numel_t(var):
from numpy import prod
assert -1 not in var.shape
return prod(var.shape)
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from smoke.models.layers import group_norm, sigmoid_hm
from smoke.cvlibs import manager, param_init
@manager.HEADS.add_component
class SMOKEPredictor(nn.Layer):
"""SMOKE Predictor
"""
def __init__(self,
num_classes=3,
reg_heads=10,
reg_channels=(1, 2, 3, 2, 2),
num_chanels=256,
norm_type="gn",
in_channels=64):
super().__init__()
regression = reg_heads
regression_channels = reg_channels
head_conv = num_chanels
norm_func = nn.BatchNorm2D if norm_type == "bn" else group_norm
assert sum(regression_channels) == regression, \
"the sum of {} must be equal to regression channel of {}".format(
reg_channels, reg_heads
)
self.dim_channel = get_channel_spec(regression_channels, name="dim")
self.ori_channel = get_channel_spec(regression_channels, name="ori")
self.class_head = nn.Sequential(
nn.Conv2D(in_channels,
head_conv,
kernel_size=3,
padding=1,
bias_attr=True),
norm_func(head_conv),
nn.ReLU(),
nn.Conv2D(head_conv,
num_classes,
kernel_size=1,
padding=1 // 2,
bias_attr=True)
)
# todo: what is datafill here
#self.class_head[-1].bias.data.fill_(-2.19)
param_init.constant_init(self.class_head[-1].bias, value=-2.19)
self.regression_head = nn.Sequential(
nn.Conv2D(in_channels,
head_conv,
kernel_size=3,
padding=1,
bias_attr=True),
norm_func(head_conv),
nn.ReLU(),
nn.Conv2D(head_conv,
regression,
kernel_size=1,
padding=1 // 2,
bias_attr=True)
)
#_fill_fc_weights(self.regression_head)
self.init_weight(self.regression_head)
def forward(self, features):
"""predictor forward
Args:
features (paddle.Tensor): smoke backbone output
Returns:
list: sigmoid class heatmap and regression map
"""
head_class = self.class_head(features)
head_regression = self.regression_head(features)
head_class = sigmoid_hm(head_class)
# (N, C, H, W)
# left slice bug
# offset_dims = head_regression[:, self.dim_channel, :, :].clone()
# head_regression[:, self.dim_channel, :, :] = F.sigmoid(offset_dims) - 0.5
# vector_ori = head_regression[:, self.ori_channel, :, :].clone()
# head_regression[:, self.ori_channel, :, :] = F.normalize(vector_ori)
offset_dims = head_regression[:, self.dim_channel, :, :].clone()
head_reg_dim = F.sigmoid(offset_dims) - 0.5
vector_ori = head_regression[:, self.ori_channel, :, :].clone()
head_reg_ori = F.normalize(vector_ori)
head_regression_left = head_regression[:, :self.dim_channel.start, :, :]
head_regression_right = head_regression[:, self.ori_channel.stop:, :, :]
head_regression = paddle.concat([head_regression_left, head_reg_dim, head_reg_ori, head_regression_right], axis=1)
return [head_class, head_regression]
def init_weight(self, block):
for sublayer in block.sublayers():
if isinstance(sublayer, nn.Conv2D):
param_init.constant_init(sublayer.bias, value=0.0)
def get_channel_spec(reg_channels, name):
"""get dim and ori dim
Args:
reg_channels (tuple): regress channels, default(1, 2, 3, 2) for
(depth_offset, keypoint_offset, dims, ori)
name (str): dim or ori
Returns:
slice: for start channel to stop channel
"""
if name == "dim":
s = sum(reg_channels[:2])
e = sum(reg_channels[:3])
elif name == "ori":
s = sum(reg_channels[:3])
e = sum(reg_channels[:4])
return slice(s, e, 1)
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .gn import group_norm
from .layer_libs import sigmoid_hm, nms_hm, select_topk, select_point_of_interest
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def group_norm(out_channels):
"""group normal function
Args:
out_channels (int): out channel nums
Returns:
nn.Module: GroupNorm op
"""
num_groups = 32
if out_channels % 32 == 0:
return nn.GroupNorm(num_groups, out_channels)
else:
return nn.GroupNorm(num_groups // 2, out_channels)
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from smoke.ops import gather_op
def sigmoid_hm(hm_features):
"""sigmoid to headmap
Args:
hm_features (paddle.Tensor): heatmap
Returns:
paddle.Tensor: sigmoid heatmap
"""
x = F.sigmoid(hm_features)
x = x.clip(min=1e-4, max=1 - 1e-4)
return x
def nms_hm(heat_map, kernel=3):
"""Do max_pooling for nms
Args:
heat_map (paddle.Tensor): pred cls heatmap
kernel (int, optional): max_pool kernel size. Defaults to 3.
Returns:
heatmap after nms
"""
pad = (kernel - 1) // 2
hmax = F.max_pool2d(heat_map,
kernel_size=(kernel, kernel),
stride=1,
padding=pad)
eq_index = (hmax == heat_map).astype("float32")
return heat_map * eq_index
def select_topk(heat_map, K=100):
"""
Args:
heat_map: heat_map in [N, C, H, W]
K: top k samples to be selected
score: detection threshold
Returns:
"""
#batch, c, height, width = paddle.shape(heat_map)
batch, c = heat_map.shape[:2]
height = paddle.shape(heat_map)[2]
width = paddle.shape(heat_map)[3]
# First select topk scores in all classes and batchs
# [N, C, H, W] -----> [N, C, H*W]
heat_map = paddle.reshape(heat_map, (batch, c, -1))
# Both in [N, C, K]
topk_scores_all, topk_inds_all = paddle.topk(heat_map, K)
# topk_inds_all = topk_inds_all % (height * width) # todo: this seems redudant
topk_ys = (topk_inds_all // width).astype("float32")
topk_xs = (topk_inds_all % width).astype("float32")
# Select topK examples across channel
# [N, C, K] -----> [N, C*K]
topk_scores_all = paddle.reshape(topk_scores_all, (batch, -1))
# Both in [N, K]
topk_scores, topk_inds = paddle.topk(topk_scores_all, K)
topk_clses = (topk_inds // K).astype("float32")
# First expand it as 3 dimension
topk_inds_all = paddle.reshape(_gather_feat(paddle.reshape(topk_inds_all, (batch, -1, 1)), topk_inds), (batch, K))
topk_ys = paddle.reshape(_gather_feat(paddle.reshape(topk_ys, (batch, -1, 1)), topk_inds), (batch, K))
topk_xs = paddle.reshape(_gather_feat(paddle.reshape(topk_xs, (batch, -1, 1)), topk_inds), (batch, K))
return dict({"topk_score": topk_scores, "topk_inds_all": topk_inds_all,
"topk_clses": topk_clses, "topk_ys": topk_ys, "topk_xs": topk_xs})
def _gather_feat(feat, ind, mask=None):
"""
Select specific indexs on featuremap
Args:
feat: all results in 3 dimensions
ind: positive index
Returns:
"""
channel = feat.shape[-1]
ind = ind.unsqueeze(-1).expand((ind.shape[0], ind.shape[1], channel))
feat = gather_op(feat, 1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, channel)
return feat
def select_point_of_interest(batch, index, feature_maps):
"""
Select POI(point of interest) on feature map
Args:
batch: batch size
index: in point format or index format
feature_maps: regression feature map in [N, C, H, W]
Returns:
"""
w = feature_maps.shape[3]
index_length = len(index.shape)
if index_length == 3:
index = index[:, :, 1] * w + index[:, :, 0]
index = paddle.reshape(index, (batch, -1))
# [N, C, H, W] -----> [N, H, W, C]
feature_maps = paddle.transpose(feature_maps, (0, 2, 3, 1))
channel = feature_maps.shape[-1]
# [N, H, W, C] -----> [N, H*W, C]
feature_maps = paddle.reshape(feature_maps, (batch, -1, channel))
# expand index in channels
index = index.unsqueeze(-1).tile((1, 1, channel))
# select specific features bases on POIs
feature_maps = gather_op(feature_maps, 1, index)
return feature_maps
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .focal_loss import FocalLoss
from .loss import SMOKELossComputation
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
class FocalLoss(nn.Layer):
"""Focal loss class
"""
def __init__(self, alpha=2, beta=4):
super().__init__()
self.alpha = alpha
self.beta = beta
def forward(self, prediction, target):
"""forward
Args:
prediction (paddle.Tensor): model prediction
target (paddle.Tensor): ground truth
Returns:
paddle.Tensor: focal loss
"""
positive_index = (target == 1).astype("float32")
negative_index = (target < 1).astype("float32")
negative_weights = paddle.pow(1 - target, self.beta)
loss = 0.
positive_loss = paddle.log(prediction) \
* paddle.pow(1 - prediction, self.alpha) * positive_index
negative_loss = paddle.log(1 - prediction) \
* paddle.pow(prediction, self.alpha) * negative_weights * negative_index
num_positive = positive_index.sum()
positive_loss = positive_loss.sum()
negative_loss = negative_loss.sum()
if num_positive == 0:
loss -= negative_loss
else:
loss -= (positive_loss + negative_loss) / num_positive
return loss
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import numpy as np
import cv2
import paddle
import paddle.nn as nn
from paddle.nn import functional as F
from smoke.models.losses import FocalLoss
from smoke.models.layers import select_point_of_interest
from smoke.cvlibs import manager
from smoke.models.heads import SMOKECoder
@manager.LOSSES.add_component
class SMOKELossComputation(object):
"""Convert targets and preds to heatmaps&regs, compute
loss with CE and L1
"""
def __init__(self,
depth_ref,
dim_ref,
reg_loss="DisL1",
loss_weight=(1., 10.),
max_objs=50):
self.smoke_coder = SMOKECoder(depth_ref, dim_ref)
self.cls_loss = FocalLoss(alpha=2, beta=4)
self.reg_loss = reg_loss
self.loss_weight = loss_weight
self.max_objs = max_objs
def prepare_targets(self, targets):
"""get heatmaps, regressions and 3D infos from targets
"""
heatmaps = targets["hm"]
regression = targets["reg"]
cls_ids = targets["cls_ids"]
proj_points = targets["proj_p"]
dimensions = targets["dimensions"]
locations = targets["locations"]
rotys = targets["rotys"]
trans_mat = targets["trans_mat"]
K = targets["K"]
reg_mask = targets["reg_mask"]
flip_mask = targets["flip_mask"]
bbox_size = targets["bbox_size"]
c_offsets = targets["c_offsets"]
return heatmaps, regression, dict(cls_ids=cls_ids,
proj_points=proj_points,
dimensions=dimensions,
locations=locations,
rotys=rotys,
trans_mat=trans_mat,
K=K,
reg_mask=reg_mask,
flip_mask=flip_mask,
bbox_size=bbox_size,
c_offsets=c_offsets)
def prepare_predictions(self, targets_variables, pred_regression):
"""decode model predictions
"""
batch, channel = pred_regression.shape[0], pred_regression.shape[1]
targets_proj_points = targets_variables["proj_points"]
# obtain prediction from points of interests
pred_regression_pois = select_point_of_interest(
batch, targets_proj_points, pred_regression
)
pred_regression_pois = paddle.reshape(pred_regression_pois, (-1, channel))
# FIXME: fix hard code here
pred_depths_offset = pred_regression_pois[:, 0]
pred_proj_offsets = pred_regression_pois[:, 1:3]
pred_dimensions_offsets = pred_regression_pois[:, 3:6]
pred_orientation = pred_regression_pois[:, 6:8]
# pred_bboxsize = paddle.zeros_like(pred_regression_pois[:, 6:8])
pred_bboxsize = pred_regression_pois[:, 8:10]
# pred_c_offsets = pred_regression_pois[:, 10:12]
pred_depths = self.smoke_coder.decode_depth(pred_depths_offset)
pred_locations = self.smoke_coder.decode_location(
targets_proj_points,
pred_proj_offsets,
pred_depths,
targets_variables["K"],
targets_variables["trans_mat"]
)
pred_dimensions = self.smoke_coder.decode_dimension(
targets_variables["cls_ids"],
pred_dimensions_offsets,
)
# we need to change center location to bottom location
# bug on left slice
# pred_locations[:, 1] += pred_dimensions[:, 1] / 2
pred_locations_x = (pred_locations[:, 0]).unsqueeze(-1)
pred_locations_y = (pred_locations[:, 1] + pred_dimensions[:, 1] / 2).unsqueeze(-1)
pred_locations_z = (pred_locations[:, 2]).unsqueeze(-1)
pred_locations = paddle.concat([pred_locations_x, pred_locations_y, pred_locations_z], axis=1)
pred_rotys = self.smoke_coder.decode_orientation(
pred_orientation,
targets_variables["locations"],
targets_variables["flip_mask"]
)
if self.reg_loss == "DisL1":
pred_box3d_rotys = self.smoke_coder.encode_box3d(
pred_rotys,
targets_variables["dimensions"],
targets_variables["locations"]
)
pred_box3d_dims = self.smoke_coder.encode_box3d(
targets_variables["rotys"],
pred_dimensions,
targets_variables["locations"]
)
pred_box3d_locs = self.smoke_coder.encode_box3d(
targets_variables["rotys"],
targets_variables["dimensions"],
pred_locations
)
return dict(ori=pred_box3d_rotys,
dim=pred_box3d_dims,
loc=pred_box3d_locs,
bbox=pred_bboxsize,)
# coff=pred_c_offsets)
elif self.reg_loss == "L1":
pred_box_3d = self.smoke_coder.encode_box3d(
pred_rotys,
pred_dimensions,
pred_locations
)
return pred_box_3d
def __call__(self, predictions, targets):
pred_heatmap, pred_regression = predictions[0], predictions[1]
targets_heatmap, targets_regression, targets_variables \
= self.prepare_targets(targets)
predict_boxes3d = self.prepare_predictions(targets_variables, pred_regression)
hm_loss = self.cls_loss(pred_heatmap, targets_heatmap) * self.loss_weight[0]
targets_regression = paddle.reshape(targets_regression, (
-1, targets_regression.shape[2], targets_regression.shape[3]
))
reg_mask = targets_variables["reg_mask"].astype("float32").flatten()
reg_mask = paddle.reshape(reg_mask, (-1, 1, 1))
reg_mask = reg_mask.expand_as(targets_regression)
if self.reg_loss == "DisL1":
reg_loss_ori = F.l1_loss(
predict_boxes3d["ori"] * reg_mask,
targets_regression * reg_mask,
reduction="sum") / (self.loss_weight[1] * self.max_objs)
reg_loss_dim = F.l1_loss(
predict_boxes3d["dim"] * reg_mask,
targets_regression * reg_mask,
reduction="sum") / (self.loss_weight[1] * self.max_objs)
reg_loss_loc = F.l1_loss(
predict_boxes3d["loc"] * reg_mask,
targets_regression * reg_mask,
reduction="sum") / (self.loss_weight[1] * self.max_objs)
reg_loss_size = F.l1_loss(
predict_boxes3d["bbox"],
paddle.reshape(targets_variables["bbox_size"],(-1, targets_variables["bbox_size"].shape[-1])),
reduction="sum") / (self.loss_weight[1] * self.max_objs)
losses = dict(hm_loss=hm_loss,
reg_loss=reg_loss_ori + reg_loss_dim + reg_loss_loc,
size_loss=reg_loss_size)
return losses
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .processor import PostProcessor
from .processorhm import PostProcessorHm
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from smoke.models.layers import nms_hm, select_topk, select_point_of_interest
from smoke.cvlibs import manager
from smoke.models.heads import SMOKECoder
@manager.POSTPROCESSORS.add_component
class PostProcessor(nn.Layer):
def __init__(self,
depth_ref,
dim_ref,
reg_head=10,
det_threshold=0.25,
max_detection=50,
pred_2d=True):
super().__init__()
self.smoke_coder = SMOKECoder(depth_ref, dim_ref)
self.reg_head = reg_head
self.max_detection = max_detection
self.det_threshold = det_threshold
self.pred_2d = pred_2d
def forward(self, predictions, targets):
pred_heatmap, pred_regression = predictions[0], predictions[1]
batch = pred_heatmap.shape[0]
heatmap = nms_hm(pred_heatmap)
topk_dict = select_topk(
heatmap,
K=self.max_detection,
)
scores, indexs = topk_dict["topk_score"], topk_dict["topk_inds_all"]
clses, ys = topk_dict["topk_clses"], topk_dict["topk_ys"]
xs = topk_dict["topk_xs"]
pred_regression = select_point_of_interest(
batch, indexs, pred_regression
)
pred_regression_pois = paddle.reshape(pred_regression, (-1, self.reg_head))
pred_proj_points = paddle.concat([paddle.reshape(xs, (-1, 1)), paddle.reshape(ys, (-1, 1))], axis=1)
# FIXME: fix hard code here
pred_depths_offset = pred_regression_pois[:, 0]
pred_proj_offsets = pred_regression_pois[:, 1:3]
pred_dimensions_offsets = pred_regression_pois[:, 3:6]
pred_orientation = pred_regression_pois[:, 6:8]
pred_bbox_size = pred_regression_pois[:, 8:10]
pred_depths = self.smoke_coder.decode_depth(pred_depths_offset)
pred_locations = self.smoke_coder.decode_location(
pred_proj_points,
pred_proj_offsets,
pred_depths,
targets["K"],
targets["trans_mat"])
pred_dimensions = self.smoke_coder.decode_dimension(
clses,
pred_dimensions_offsets
)
# we need to change center location to bottom location
pred_locations[:, 1] += pred_dimensions[:, 1] / 2
pred_rotys, pred_alphas = self.smoke_coder.decode_orientation(
pred_orientation,
pred_locations
)
if self.pred_2d:
box2d = self.smoke_coder.decode_bbox_2d(pred_proj_points, pred_bbox_size,
targets["trans_mat"],
targets["image_size"])
else:
box2d = paddle.to_tensor([0, 0, 0, 0])
# change variables to the same dimension
clses = paddle.reshape(clses, (-1, 1))
pred_alphas = paddle.reshape(pred_alphas, (-1, 1))
pred_rotys = paddle.reshape(pred_rotys, (-1, 1))
scores = paddle.reshape(scores, (-1, 1))
l, h, w = pred_dimensions.chunk(3, 1)
pred_dimensions = paddle.concat([h, w, l], axis=1)
result = paddle.concat([
clses, pred_alphas, box2d, pred_dimensions, pred_locations, pred_rotys, scores
], axis=1)
keep_idx = result[:, -1] > self.det_threshold
if paddle.sum(keep_idx.astype("int32")) >= 1:
keep_idx = paddle.nonzero(result[:, -1] > self.det_threshold)
result = paddle.gather(result, keep_idx)
else:
result = paddle.to_tensor([])
return result
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from smoke.models.heads import SMOKECoder
from smoke.models.layers import nms_hm, select_topk, select_point_of_interest
from smoke.cvlibs import manager
@manager.POSTPROCESSORS.add_component
class PostProcessorHm(nn.Layer):
def __init__(self,
depth_ref,
dim_ref,
reg_head=10,
det_threshold=0.25,
max_detection=50,
pred_2d=True):
super().__init__()
self.smoke_coder = SMOKECoder(depth_ref, dim_ref)
self.max_detection = max_detection
def forward(self, predictions, cam_info):
pred_heatmap, pred_regression = predictions[0], predictions[1]
batch = pred_heatmap.shape[0]
heatmap = nms_hm(pred_heatmap)
topk_dict = select_topk(
heatmap,
K=self.max_detection,
)
scores, indexs = topk_dict["topk_score"], topk_dict["topk_inds_all"]
clses, ys = topk_dict["topk_clses"], topk_dict["topk_ys"]
xs = topk_dict["topk_xs"]
pred_regression = select_point_of_interest(
batch, indexs, pred_regression
)
# pred_regression_pois = paddle.reshape(pred_regression, (pred_regression.numel()//10, 10))
# pred_proj_points = paddle.concat([paddle.reshape(xs, (xs.numel(), 1)), paddle.reshape(ys, (ys.numel(), 1))], axis=1)
pred_regression_pois = paddle.reshape(pred_regression, (numel_t(pred_regression)//10, 10))
pred_proj_points = paddle.concat([paddle.reshape(xs, (numel_t(xs), 1)), paddle.reshape(ys, (numel_t(ys), 1))], axis=1)
# FIXME: fix hard code here
pred_depths_offset = pred_regression_pois[:, 0]
pred_proj_offsets = pred_regression_pois[:, 1:3]
pred_dimensions_offsets = pred_regression_pois[:, 3:6]
pred_orientation = pred_regression_pois[:, 6:8]
pred_bbox_size = pred_regression_pois[:, 8:10]
pred_depths = self.smoke_coder.decode_depth(pred_depths_offset)
pred_locations = self.smoke_coder.decode_location_without_transmat(
pred_proj_points,
pred_proj_offsets,
pred_depths,
cam_info[0], cam_info[1])
pred_dimensions = self.smoke_coder.decode_dimension(
clses,
pred_dimensions_offsets
)
# we need to change center location to bottom location
pred_locations[:, 1] += pred_dimensions[:, 1] / 2
pred_rotys, pred_alphas = self.smoke_coder.decode_orientation(
pred_orientation,
pred_locations
)
box2d = self.smoke_coder.decode_bbox_2d_without_transmat(pred_proj_points,
pred_bbox_size, cam_info[1])
# change variables to the same dimension
clses = paddle.reshape(clses, (-1, 1))
pred_alphas = paddle.reshape(pred_alphas, (-1, 1))
pred_rotys = paddle.reshape(pred_rotys, (-1, 1))
scores = paddle.reshape(scores, (-1, 1))
l, h, w = pred_dimensions.chunk(3, 1)
pred_dimensions = paddle.concat([h, w, l], axis=1)
result = paddle.concat([
clses, pred_alphas, box2d, pred_dimensions, pred_locations, pred_rotys, scores
], axis=1)
return result
def numel_t(var):
from numpy import prod
assert -1 not in var.shape
return prod(var.shape)
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from smoke.cvlibs import manager
from smoke.utils import logger
@manager.MODELS.add_component
class SMOKE(nn.Layer):
def __init__(self, backbone, head, post_process=None):
super().__init__()
self.backbone = backbone
self.heads = head
self.post_process = post_process
self.init_weight()
def forward(self, images, targets=None):
features = self.backbone(images)
predictions = self.heads(features)
if not self.training:
return self.post_process(predictions, targets)
return predictions
def init_weight(self, bias_lr_factor=2):
for sublayer in self.sublayers():
if hasattr(sublayer, 'bias') and sublayer.bias is not None:
sublayer.bias.optimize_attr['learning_rate'] = bias_lr_factor
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .gather import gather_op
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The same function as torch.gather.
Note that: In PaddlePaddle2.0, paddle.gather is different with torch.gather
"""
import paddle
def gather_op(x, dim, index):
dtype_mapping = {"VarType.INT32": "int32", "VarType.INT64": "int64", "paddle.int32": "int32", "paddle.int64": "int64"}
if dim < 0:
dim += len(x.shape)
x_range = list(range(len(x.shape)))
x_range[0] = dim
x_range[dim] = 0
x_swaped = paddle.transpose(x, perm=x_range)
index_range = list(range(len(index.shape)))
index_range[0] = dim
index_range[dim] = 0
index_swaped = paddle.transpose(index, perm=index_range)
dtype = dtype_mapping[str(index.dtype)]
x_shape = paddle.shape(x_swaped)
index_shape = paddle.shape(index_swaped)
prod = paddle.prod(x_shape, dtype=dtype) / x_shape[0]
x_swaped_flattend = paddle.flatten(x_swaped)
index_swaped_flattend = paddle.flatten(index_swaped)
index_swaped_flattend *= prod
bias = paddle.arange(start=0, end=prod, dtype=dtype)
bias = paddle.reshape(bias, x_shape[1:])
bias = paddle.crop(bias, index_shape[1:])
bias = paddle.flatten(bias)
bias = paddle.tile(bias, [index_shape[0]])
index_swaped_flattend += bias
gathered = paddle.index_select(x_swaped_flattend, index_swaped_flattend)
gathered = paddle.reshape(gathered, index_swaped.shape)
out = paddle.transpose(gathered, perm=x_range)
return out
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import *
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
from PIL import Image, ImageEnhance
from scipy.ndimage.morphology import distance_transform_edt
def normalize(im, mean, std):
im = im.astype(np.float32, copy=False) / 255.0
im -= mean
im /= std
return im
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg with minor modifications.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/transforms/transforms.py
"""
import random
import cv2
import numpy as np
from PIL import Image
from . import functional
from smoke.cvlibs import manager
@manager.TRANSFORMS.add_component
class Compose:
"""
Do transformation on input data with corresponding pre-processing and augmentation operations.
The shape of input data to all operations is [height, width, channels].
Args:
transforms (list): A list contains data pre-processing or augmentation. Empty list means only reading images, no transformation.
to_rgb (bool, optional): If converting image to RGB color space. Default: True.
Raises:
TypeError: When 'transforms' is not a list.
ValueError: when the length of 'transforms' is less than 1.
"""
def __init__(self, transforms, to_rgb=True):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
self.transforms = transforms
self.to_rgb = to_rgb
def __call__(self, im, label=None):
"""
Args:
im (str|np.ndarray): It is either image path or image object.
label (str|np.ndarray): It is either label path or label ndarray.
Returns:
(tuple). A tuple including image, image info, and label after transformation.
"""
if isinstance(im, str):
im = cv2.imread(im).astype('float32')
if isinstance(label, str):
label = np.asarray(Image.open(label))
if im is None:
raise ValueError('Can\'t read The image file {}!'.format(im))
if self.to_rgb:
im = np.array(im)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
for op in self.transforms:
outputs = op(im, label)
im = outputs[0]
if len(outputs) == 2:
label = outputs[1]
im = np.transpose(im, (2, 0, 1))
return (im, label)
@manager.TRANSFORMS.add_component
class Normalize:
"""
Normalize an image.
Args:
mean (list, optional): The mean value of a data set. Default: [0.5, 0.5, 0.5].
std (list, optional): The standard deviation of a data set. Default: [0.5, 0.5, 0.5].
Raises:
ValueError: When mean/std is not list or any value in std is 0.
"""
def __init__(self, mean, std):
self.mean = mean
self.std = std
if not (isinstance(self.mean, (list, tuple))
and isinstance(self.std, (list, tuple))):
raise ValueError(
"{}: input type is invalid. It should be list or tuple".format(
self))
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self))
def __call__(self, im, label=None):
"""
Args:
im (np.ndarray): The Image data.
label (np.ndarray, optional): The label data. Default: None.
Returns:
(tuple). When label is None, it returns (im, ), otherwise it returns (im, label).
"""
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im = functional.normalize(im, mean, std)
if label is None:
return (im, )
else:
return (im, label)
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .timer import TimeAverager, calculate_eta
from .pretrained_utils import load_pretrained_model
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from skimage import transform as trans
def encode_label(K, ry, dims, locs):
"""get bbox 3d and 2d by model output
Args:
K (np.ndarray): camera intrisic matrix
ry (np.ndarray): rotation y
dims (np.ndarray): dimensions
locs (np.ndarray): locations
"""
l, h, w = dims[0], dims[1], dims[2]
x, y, z = locs[0], locs[1], locs[2]
x_corners = [0, l, l, l, l, 0, 0, 0]
y_corners = [0, 0, h, h, 0, 0, h, h]
z_corners = [0, 0, 0, w, w, w, w, 0]
x_corners += - np.float32(l) / 2
y_corners += - np.float32(h)
z_corners += - np.float32(w) / 2
corners_3d = np.array([x_corners, y_corners, z_corners])
rot_mat = np.array([[np.cos(ry), 0, np.sin(ry)],
[0, 1, 0],
[-np.sin(ry), 0, np.cos(ry)]])
corners_3d = np.matmul(rot_mat, corners_3d)
corners_3d += np.array([x, y, z]).reshape([3, 1])
loc_center = np.array([x, y - h / 2, z])
proj_point = np.matmul(K, loc_center)
proj_point = proj_point[:2] / proj_point[2]
corners_2d = np.matmul(K, corners_3d)
corners_2d = corners_2d[:2] / corners_2d[2]
box2d = np.array([min(corners_2d[0]), min(corners_2d[1]),
max(corners_2d[0]), max(corners_2d[1])])
return proj_point, box2d, corners_3d
def get_transfrom_matrix(center_scale, output_size):
"""get transform matrix
"""
center, scale = center_scale[0], center_scale[1]
# todo: further add rot and shift here.
src_w = scale[0]
dst_w = output_size[0]
dst_h = output_size[1]
src_dir = np.array([0, src_w * -0.5])
dst_dir = np.array([0, dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center
src[1, :] = center + src_dir
dst[0, :] = np.array([dst_w * 0.5, dst_h * 0.5])
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
src[2, :] = get_3rd_point(src[0, :], src[1, :])
dst[2, :] = get_3rd_point(dst[0, :], dst[1, :])
get_matrix = trans.estimate_transform("affine", src, dst)
matrix = get_matrix.params
return matrix.astype(np.float32)
def affine_transform(point, matrix):
"""do affine transform to label
"""
point_exd = np.array([point[0], point[1], 1.])
new_point = np.matmul(matrix, point_exd)
return new_point[:2]
def get_3rd_point(point_a, point_b):
"""get 3rd point
"""
d = point_a - point_b
point_c = point_b + np.array([-d[1], d[0]])
return point_c
def gaussian_radius(h, w, thresh_min=0.7):
"""gaussian radius
"""
a1 = 1
b1 = h + w
c1 = h * w * (1 - thresh_min) / (1 + thresh_min)
sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
r1 = (b1 - sq1) / (2 * a1)
a2 = 4
b2 = 2 * (h + w)
c2 = (1 - thresh_min) * w * h
sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
r2 = (b2 - sq2) / (2 * a2)
a3 = 4 * thresh_min
b3 = -2 * thresh_min * (h + w)
c3 = (thresh_min - 1) * w * h
sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
r3 = (b3 + sq3) / (2 * a3)
return min(r1, r2, r3)
def gaussian2D(shape, sigma=1):
"""get 2D gaussian map
"""
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
def draw_umich_gaussian(heatmap, center, radius, k=1):
"""draw umich gaussian
"""
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
return heatmap
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/utils/logger.py
"""
import sys
import time
import paddle
levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
log_level = 2
def log(level=2, message=""):
if paddle.distributed.ParallelEnv().local_rank == 0:
current_time = time.time()
time_array = time.localtime(current_time)
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
if log_level >= level:
print(
"{} [{}]\t{}".format(current_time, levels[level],
message).encode("utf-8").decode("latin1"))
sys.stdout.flush()
def debug(message=""):
log(level=3, message=message)
def info(message=""):
log(level=2, message=message)
def warning(message=""):
log(level=1, message=message)
def error(message=""):
log(level=0, message=message)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import os
def mkdir(path):
"""make new dir
Args:
path (str): path of new dir to make
"""
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg with minor modifications.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/utils/utils.py
"""
import contextlib
import filelock
import math
import os
import tempfile
from urllib.parse import urlparse, unquote
import paddle
from smoke.utils import logger
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
'''Generate a temporary directory'''
directory = seg_env.TMP_HOME if not directory else directory
with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
yield _dir
def load_pretrained_model(model, pretrained_model):
if os.path.exists(pretrained_model):
para_state_dict = paddle.load(pretrained_model)
model_state_dict = model.state_dict()
keys = model_state_dict.keys()
num_params_loaded = 0
for k in keys:
if k not in para_state_dict:
logger.warning("{} is not in pretrained model".format(k))
elif list(para_state_dict[k].shape) != list(
model_state_dict[k].shape):
logger.warning(
"[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
.format(k, para_state_dict[k].shape,
model_state_dict[k].shape))
else:
model_state_dict[k] = para_state_dict[k]
num_params_loaded += 1
model.set_dict(model_state_dict)
logger.info("There are {}/{} variables loaded into {}.".format(
num_params_loaded, len(model_state_dict),
model.__class__.__name__))
else:
raise ValueError(
'The pretrained model directory is not Found: {}'.format(
pretrained_model))
def resume(model, optimizer, resume_model):
if resume_model is not None:
logger.info('Resume model from {}'.format(resume_model))
if os.path.exists(resume_model):
resume_model = os.path.normpath(resume_model)
ckpt_path = os.path.join(resume_model, 'model.pdparams')
para_state_dict = paddle.load(ckpt_path)
ckpt_path = os.path.join(resume_model, 'model.pdopt')
opti_state_dict = paddle.load(ckpt_path)
model.set_state_dict(para_state_dict)
optimizer.set_state_dict(opti_state_dict)
iter = resume_model.split('_')[-1]
iter = int(iter)
return iter
else:
raise ValueError(
'Directory of the model needed to resume is not Found: {}'.
format(resume_model))
else:
logger.info('No model needed to resume.')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg with minor modifications.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/utils/progbar.py
"""
import os
import sys
import time
import numpy as np
class Progbar(object):
"""
Displays a progress bar.
It refers to https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py
Args:
target (int): Total number of steps expected, None if unknown.
width (int): Progress bar width on screen.
verbose (int): Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
stateful_metrics (list|tuple): Iterable of string names of metrics that should *not* be
averaged over time. Metrics in this list will be displayed as-is. All
others will be averaged by the progbar before display.
interval (float): Minimum visual progress update interval (in seconds).
unit_name (str): Display name for step counts (usually "step" or "sample").
"""
def __init__(self,
target,
width=30,
verbose=1,
interval=0.05,
stateful_metrics=None,
unit_name='step'):
self.target = target
self.width = width
self.verbose = verbose
self.interval = interval
self.unit_name = unit_name
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
self._dynamic_display = ((hasattr(sys.stderr, 'isatty')
and sys.stderr.isatty())
or 'ipykernel' in sys.modules
or 'posix' in sys.modules
or 'PYCHARM_HOSTED' in os.environ)
self._total_width = 0
self._seen_so_far = 0
# We use a dict + list to avoid garbage collection
# issues found in OrderedDict
self._values = {}
self._values_order = []
self._start = time.time()
self._last_update = 0
def update(self, current, values=None, finalize=None):
"""
Updates the progress bar.
Args:
current (int): Index of current step.
values (list): List of tuples: `(name, value_for_last_step)`. If `name` is in
`stateful_metrics`, `value_for_last_step` will be displayed as-is.
Else, an average of the metric over time will be displayed.
finalize (bool): Whether this is the last update for the progress bar. If
`None`, defaults to `current >= self.target`.
"""
if finalize is None:
if self.target is None:
finalize = False
else:
finalize = current >= self.target
values = values or []
for k, v in values:
if k not in self._values_order:
self._values_order.append(k)
if k not in self.stateful_metrics:
# In the case that progress bar doesn't have a target value in the first
# epoch, both on_batch_end and on_epoch_end will be called, which will
# cause 'current' and 'self._seen_so_far' to have the same value. Force
# the minimal value to 1 here, otherwise stateful_metric will be 0s.
value_base = max(current - self._seen_so_far, 1)
if k not in self._values:
self._values[k] = [v * value_base, value_base]
else:
self._values[k][0] += v * value_base
self._values[k][1] += value_base
else:
# Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the
# numeric formatting.
self._values[k] = [v, 1]
self._seen_so_far = current
now = time.time()
info = ' - %.0fs' % (now - self._start)
if self.verbose == 1:
if now - self._last_update < self.interval and not finalize:
return
prev_total_width = self._total_width
if self._dynamic_display:
sys.stderr.write('\b' * prev_total_width)
sys.stderr.write('\r')
else:
sys.stderr.write('\n')
if self.target is not None:
numdigits = int(np.log10(self.target)) + 1
bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
prog = float(current) / self.target
prog_width = int(self.width * prog)
if prog_width > 0:
bar += ('=' * (prog_width - 1))
if current < self.target:
bar += '>'
else:
bar += '='
bar += ('.' * (self.width - prog_width))
bar += ']'
else:
bar = '%7d/Unknown' % current
self._total_width = len(bar)
sys.stderr.write(bar)
if current:
time_per_unit = (now - self._start) / current
else:
time_per_unit = 0
if self.target is None or finalize:
if time_per_unit >= 1 or time_per_unit == 0:
info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
elif time_per_unit >= 1e-3:
info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
else:
info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
else:
eta = time_per_unit * (self.target - current)
if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600,
(eta % 3600) // 60, eta % 60)
elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60)
else:
eta_format = '%ds' % eta
info = ' - ETA: %s' % eta_format
for k in self._values_order:
info += ' - %s:' % k
if isinstance(self._values[k], list):
avg = np.mean(
self._values[k][0] / max(1, self._values[k][1]))
if abs(avg) > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
else:
info += ' %s' % self._values[k]
self._total_width += len(info)
if prev_total_width > self._total_width:
info += (' ' * (prev_total_width - self._total_width))
if finalize:
info += '\n'
sys.stderr.write(info)
sys.stderr.flush()
elif self.verbose == 2:
if finalize:
numdigits = int(np.log10(self.target)) + 1
count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
info = count + info
for k in self._values_order:
info += ' - %s:' % k
avg = np.mean(
self._values[k][0] / max(1, self._values[k][1]))
if avg > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
info += '\n'
sys.stderr.write(info)
sys.stderr.flush()
self._last_update = now
def add(self, n, values=None):
self.update(self._seen_so_far + n, values)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copy-paste from PaddleSeg.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/paddleseg/utils/timer.py
"""
import time
class TimeAverager(object):
def __init__(self):
self.reset()
def reset(self):
self._cnt = 0
self._total_time = 0
self._total_samples = 0
def record(self, usetime, num_samples=None):
self._cnt += 1
self._total_time += usetime
if num_samples:
self._total_samples += num_samples
def get_average(self):
if self._cnt == 0:
return 0
return self._total_time / float(self._cnt)
def get_ips_average(self):
if not self._total_samples or self._cnt == 0:
return 0
return float(self._total_samples) / self._total_time
def calculate_eta(remaining_step, speed):
if remaining_step < 0:
remaining_step = 0
remaining_time = int(remaining_step * speed)
result = "{:0>2}:{:0>2}:{:0>2}"
arr = []
for i in range(2, -1, -1):
arr.append(int(remaining_time / 60**i))
remaining_time %= 60**i
return result.format(*arr)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import cv2
import numpy as np
from smoke.ops import gather_op
def get_ratio(ori_img_size, output_size, down_ratio=(4, 4)):
return np.array([[down_ratio[1] * ori_img_size[1] / output_size[1],
down_ratio[0] * ori_img_size[0] / output_size[0]]], np.float32)
def get_img(img_path):
img = cv2.imread(img_path)
ori_img_size = img.shape
img = cv2.resize(img, (960, 640))
output_size = img.shape
img = img/255.0
img = np.subtract(img, np.array([0.485, 0.456, 0.406]))
img = np.true_divide(img, np.array([0.229, 0.224, 0.225]))
img = np.array(img, np.float32)
img = img.transpose(2, 0, 1)
img = img[None,:,:,:]
img = paddle.to_tensor(img)
return img, ori_img_size, output_size
def encode_box3d(rotys, dims, locs, K, image_size):
'''
construct 3d bounding box for each object.
Args:
rotys: rotation in shape N
dims: dimensions of objects
locs: locations of objects
Returns:
box_3d in camera frame, shape(b, 2, 8)
'''
if len(rotys.shape) == 2:
rotys = rotys.flatten()
if len(dims.shape) == 3:
dims = paddle.reshape(dims, (-1, 3))
if len(locs.shape) == 3:
locs = paddle.reshape(locs, (-1, 3))
N = rotys.shape[0]
ry = rad_to_matrix(rotys, N)
dims = paddle.reshape(dims, (-1, 1)).tile((1, 8))
dims[::3, :4], dims[2::3, :4] = 0.5 * dims[::3, :4], 0.5 * dims[2::3, :4]
dims[::3, 4:], dims[2::3, 4:] = -0.5 * dims[::3, 4:], -0.5 * dims[2::3, 4:]
dims[1::3, :4], dims[1::3, 4:] = 0., -dims[1::3, 4:]
index = paddle.to_tensor([[4, 0, 1, 2, 3, 5, 6, 7],
[4, 5, 0, 1, 6, 7, 2, 3],
[4, 5, 6, 0, 1, 2, 3, 7]]).tile((N, 1))
box_3d_object = gather_op(dims, 1, index)
box_3d = paddle.matmul(ry, paddle.reshape(box_3d_object, (N, 3, -1)))
box_3d += locs.unsqueeze(-1).tile((1, 1, 8))
box3d_image = paddle.matmul(K, box_3d)
box3d_image = box3d_image[:, :2, :] / paddle.reshape(box3d_image[:, 2, :], (box_3d.shape[0], 1, box_3d.shape[2]))
box3d_image = box3d_image.astype("int32")
box3d_image = box3d_image.astype("float32")
box3d_image[:, 0] = box3d_image[:, 0].clip(0, image_size[1])
box3d_image[:, 1] = box3d_image[:, 1].clip(0, image_size[0])
return box3d_image
def rad_to_matrix(rotys, N):
cos, sin = rotys.cos(), rotys.sin()
i_temp = paddle.to_tensor([[1, 0, 1],
[0, 1, 0],
[-1, 0, 1]]).astype("float32")
ry = paddle.reshape(i_temp.tile((N, 1)), (N, -1, 3))
ry[:, 0, 0] *= cos
ry[:, 0, 2] *= sin
ry[:, 2, 0] *= sin
ry[:, 2, 2] *= cos
return ry
def draw_box_3d(image, corners, color=None):
''' Draw 3d bounding box in image
corners: (8,2) array of vertices for the 3d box in following order:
'''
# face_idx = [[0, 1, 5, 4],
# [1, 2, 6, 5],
# [2, 3, 7, 6],
# [3, 0, 4, 7]]
if color is None:
color = (0, 0, 255)
face_idx = [[5, 4, 3, 6],
[1, 2, 3, 4],
[1, 0, 7, 2],
[0, 5, 6, 7]]
for ind_f in range(3, -1, -1):
f = face_idx[ind_f]
for j in range(4):
cv2.line(image, (corners[f[j], 0], corners[f[j], 1]),
(corners[f[(j + 1) % 4], 0], corners[f[(j + 1) % 4], 1]), color, 2, lineType=cv2.LINE_AA)
if ind_f == 0:
cv2.line(image, (corners[f[0], 0], corners[f[0], 1]),
(corners[f[2], 0], corners[f[2], 1]), color, 1, lineType=cv2.LINE_AA)
cv2.line(image, (corners[f[1], 0], corners[f[1], 1]),
(corners[f[3], 0], corners[f[3], 1]), color, 1, lineType=cv2.LINE_AA)
return image
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import cv2
import numpy as np
import paddle
from smoke.cvlibs import Config
from smoke.utils import logger, load_pretrained_model
from smoke.utils.vis_utils import get_img, get_ratio, encode_box3d, draw_box_3d
def parse_args():
parser = argparse.ArgumentParser(description='Model test')
# params of evaluate
parser.add_argument(
"--config", dest="cfg", help="The config file.", required=True, type=str)
parser.add_argument(
'--model_path',
dest='model_path',
help='The path of model for evaluation',
type=str,
required=True)
parser.add_argument(
'--input_path',
dest='input_path',
help='The image path',
type=str,
required=True)
parser.add_argument(
'--output_path',
dest='output_path',
help='The result path of image',
type=str,
required=True)
return parser.parse_args()
def main(args):
paddle.set_device("gpu")
cfg = Config(args.cfg)
model = cfg.model
model.eval()
if args.model_path:
load_pretrained_model(model, args.model_path)
logger.info('Loaded trained params of model successfully')
K = np.array([[[2055.56, 0, 939.658], [0, 2055.56, 641.072], [0, 0, 1]]], np.float32)
K_inverse = np.linalg.inv(K)
K_inverse = paddle.to_tensor(K_inverse)
img, ori_img_size, output_size = get_img(args.input_path)
ratio = get_ratio(ori_img_size, output_size)
ratio = paddle.to_tensor(ratio)
cam_info = [K_inverse, ratio]
total_pred = model(img, cam_info)
keep_idx = paddle.nonzero(total_pred[:, -1] > 0.25)
total_pred = paddle.gather(total_pred, keep_idx)
if total_pred.shape[0] > 0:
pred_dimensions = total_pred[:, 6:9]
pred_dimensions = pred_dimensions.roll(shifts=1, axis=1)
pred_rotys = total_pred[:, 12]
pred_locations = total_pred[:, 9:12]
bbox_3d = encode_box3d(pred_rotys, pred_dimensions, pred_locations, paddle.to_tensor(K), (1280, 1920))
else:
bbox_3d = total_pred
img_draw = cv2.imread(args.input_path)
for idx in range(bbox_3d.shape[0]):
bbox = bbox_3d[idx]
bbox = bbox.transpose([1,0]).numpy()
img_draw = draw_box_3d(img_draw, bbox)
cv2.imwrite(args.output_path, img_draw)
if __name__ == '__main__':
args = parse_args()
main(args)
#ifndef MAIL_H
#define MAIL_H
#include <stdio.h>
#include <stdarg.h>
#include <string.h>
class Mail {
public:
Mail (std::string email = "") {
if (email.compare("")) {
mail = popen("/usr/lib/sendmail -t -f noreply@cvlibs.net","w");
fprintf(mail,"To: %s\n", email.c_str());
fprintf(mail,"From: noreply@cvlibs.net\n");
fprintf(mail,"Subject: KITTI Evaluation Benchmark\n");
fprintf(mail,"\n\n");
} else {
mail = 0;
}
}
~Mail() {
if (mail) {
pclose(mail);
}
}
void msg (const char *format, ...) {
va_list args;
va_start(args,format);
if (mail) {
vfprintf(mail,format,args);
fprintf(mail,"\n");
}
vprintf(format,args);
printf("\n");
va_end(args);
}
private:
FILE *mail;
};
#endif
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Copy-paste from PaddleSeg with minor modifications.
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.1/train.py
"""
import argparse
import paddle
from smoke.cvlibs import manager, Config
from smoke.utils import logger
from smoke.core import train
def parse_args():
parser = argparse.ArgumentParser(description='Model training')
# params of training
parser.add_argument(
"--config", dest="cfg", help="The config file.", required=True, type=str)
parser.add_argument(
'--iters',
dest='iters',
help='iters for training',
type=int,
default=None)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size of one gpu or cpu',
type=int,
default=None)
parser.add_argument(
'--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=None)
parser.add_argument(
'--save_interval',
dest='save_interval',
help='How many iters to save a model snapshot once during training.',
type=int,
default=1000)
parser.add_argument(
'--resume_model',
dest='resume_model',
help='The path of resume model',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the model snapshot',
type=str,
default='./output')
parser.add_argument(
'--keep_checkpoint_max',
dest='keep_checkpoint_max',
help='Maximum number of checkpoints to save',
type=int,
default=5)
parser.add_argument(
'--num_workers',
dest='num_workers',
help='Num workers for data loader',
type=int,
default=0)
parser.add_argument(
'--log_iters',
dest='log_iters',
help='Display logging information at every log_iters',
default=10,
type=int)
return parser.parse_args()
def main(args):
paddle.set_device("gpu")
cfg = Config(
args.cfg,
learning_rate=args.learning_rate,
iters=args.iters,
batch_size=args.batch_size)
train_dataset = cfg.train_dataset
if train_dataset is None:
raise RuntimeError(
'The training dataset is not specified in the configuration file.')
elif len(train_dataset) == 0:
raise ValueError(
'The length of train_dataset is 0. Please check if your dataset is valid'
)
val_dataset = None #cfg.val_dataset if args.do_eval else None
losses = cfg.loss
msg = '\n---------------Config Information---------------\n'
msg += str(cfg)
msg += '------------------------------------------------'
logger.info(msg)
train(
cfg.model,
train_dataset,
val_dataset=val_dataset,
optimizer=cfg.optimizer,
loss_computation=cfg.loss,
save_dir=args.save_dir,
iters=cfg.iters,
batch_size=cfg.batch_size,
resume_model=args.resume_model,
save_interval=args.save_interval,
log_iters=args.log_iters,
num_workers=args.num_workers,
keep_checkpoint_max=args.keep_checkpoint_max)
if __name__ == '__main__':
args = parse_args()
main(args)
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册