# BiPointNet
## 1. Introduction
This example demonstrates BiPointNet, a binarization method for point-cloud models (PointNet). BiPointNet introduces Entropy-Maximizing Aggregation to adjust the distribution before aggregation so that information entropy is maximized, and Layer-wise Scale Recovery to efficiently restore feature representation capacity, making it a simple yet effective way to binarize point-cloud models.
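Conceptually, a binarized layer replaces the full-precision multiply with a sign-binarized one and then rescales the output. Below is a minimal sketch of this compute pattern (illustrative only; the actual layers are implemented in `basic.py`):
```python
import paddle

# Binarize activations and weights with sign(), then recover the output scale
# with a layer-wise factor alpha, here estimated from the ratio of standard
# deviations (the same initialization used in basic.py).
a = paddle.randn([8, 64])    # activations
w = paddle.randn([64, 128])  # weights
alpha = (a @ w).std() / (paddle.sign(a) @ paddle.sign(w)).std()
out = (paddle.sign(a) @ paddle.sign(w)) * alpha
print(out.shape)  # [8, 128]
```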
## 2. Benchmark
| Model | Accuracy (%) | Weights |
| ------------- | --------- | --------- |
| PointNet | 89.83 | [PointNet.pdparams](https://bj.bcebos.com/v1/paddle-slim-models/PointNet.pdparams) |
| BiPointNet | 85.86 | [BiPointNet.pdparams](https://bj.bcebos.com/v1/paddle-slim-models/BiPointNet.pdparams) |
## 3. Training and Testing BiPointNet
The overall architecture of BiPointNet is shown in the figure below; see the paper [BiPointNet: Binary Neural Network for Point Clouds](https://arxiv.org/abs/2010.05501) for details.
![arch](arch.png)
### 3.1 Prepare the Environment
- PaddlePaddle >= 2.4 (install it following the [PaddlePaddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
Install paddlepaddle:
```shell
# CPU
pip install paddlepaddle==2.4.2
# GPU (e.g., Ubuntu with CUDA 11.2)
python -m pip install paddlepaddle-gpu==2.4.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
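You can verify the installation with:
```shell
python -c "import paddle; paddle.utils.run_check()"
```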
### 3.2 Prepare the Dataset
This example runs classification experiments on the [ModelNet40](https://modelnet.cs.princeton.edu) dataset.
```shell
wget https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip --no-check-certificate
unzip modelnet40_normal_resampled.zip
```
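After extraction, the scripts expect the layout below (the category folder shown is one example of forty; the index and shape-name files are the ones read by `data.py`):
```
modelnet40_normal_resampled/
├── modelnet40_shape_names.txt
├── modelnet40_train.txt
├── modelnet40_test.txt
├── airplane/
│   ├── airplane_0001.txt
│   └── ...
└── ...
```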
### 3.3 Training
- Train the baseline PointNet:
```shell
export CUDA_VISIBLE_DEVICES=0
python train.py --save_dir 'pointnet'
```
- Train BiPointNet:
```shell
export CUDA_VISIBLE_DEVICES=0
python train.py --save_dir 'bipointnet' --binary
```
### 3.4 Evaluation
- Evaluate the baseline PointNet:
```shell
export CUDA_VISIBLE_DEVICES=0
python test.py --model_path 'PointNet.pdparams'
```
- Evaluate BiPointNet:
```shell
export CUDA_VISIBLE_DEVICES=0
python test.py --model_path 'BiPointNet.pdparams' --binary
```
## Acknowledgements
Thanks to [Macaronlin](https://github.com/Macaronlin) for contributing BiPointNet.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from paddle.autograd import PyLayer
from paddle.nn import functional as F
from paddle.nn.layer import Conv1D
from paddle.nn.layer.common import Linear
class BinaryQuantizer(PyLayer):
    """Binarize a tensor with sign() in the forward pass and backpropagate
    through it with the straight-through estimator (clipped identity)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return paddle.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensor()[0]
        # Straight-through estimator: pass the gradient only where the input
        # lies in (-1, 1). Clone first so grad_output is not modified in place.
        grad_input = grad_output.clone()
        grad_input[input >= 1] = 0
        grad_input[input <= -1] = 0
        return grad_input
class BiLinear(Linear):
def __init__(self,
in_features,
out_features,
weight_attr=None,
bias_attr=None,
name=None):
super(BiLinear, self).__init__(
in_features,
out_features,
weight_attr=weight_attr,
bias_attr=bias_attr,
name=name)
self.scale_weight_init = False
self.scale_weight = paddle.create_parameter(shape=[1], dtype='float32')
def forward(self, input):
ba = input
bw = self.weight
        bw = bw - bw.mean()  # center weights before binarization
        # Layer-wise Scale Recovery: on the first forward pass, estimate the
        # scale factor from the ratio of full-precision to binary output std.
        if not self.scale_weight_init:
scale_weight = F.linear(ba, bw).std() / F.linear(
paddle.sign(ba), paddle.sign(bw)).std()
if paddle.isnan(scale_weight):
scale_weight = bw.std() / paddle.sign(bw).std()
self.scale_weight.set_value(scale_weight)
self.scale_weight_init = True
ba = BinaryQuantizer.apply(ba)
bw = BinaryQuantizer.apply(bw)
bw = bw * self.scale_weight
out = F.linear(x=ba, weight=bw, bias=self.bias, name=self.name)
return out
class BiConv1D(Conv1D):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode='zeros',
weight_attr=None,
bias_attr=None,
data_format="NCL"):
super(BiConv1D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format)
self.scale_weight_init = False
self.scale_weight = paddle.create_parameter(shape=[1], dtype='float32')
def forward(self, input):
ba = input
bw = self.weight
bw = bw - bw.mean()
padding = 0
if self._padding_mode != "zeros":
ba = F.pad(
ba,
self._reversed_padding_repeated_twice,
mode=self._padding_mode,
data_format=self._data_format)
else:
padding = self._padding
        # Layer-wise Scale Recovery: estimate the scale factor on the first forward pass.
        if not self.scale_weight_init:
            full_out = F.conv1d(
                ba, bw, bias=self.bias, padding=padding, stride=self._stride,
                dilation=self._dilation, groups=self._groups,
                data_format=self._data_format)
            bin_out = F.conv1d(
                paddle.sign(ba), paddle.sign(bw), bias=self.bias,
                padding=padding, stride=self._stride, dilation=self._dilation,
                groups=self._groups, data_format=self._data_format)
            scale_weight = full_out.std() / bin_out.std()
if paddle.isnan(scale_weight):
scale_weight = bw.std() / paddle.sign(bw).std()
self.scale_weight.set_value(scale_weight)
self.scale_weight_init = True
ba = BinaryQuantizer.apply(ba)
bw = BinaryQuantizer.apply(bw)
bw = bw * self.scale_weight
return F.conv1d(
ba,
bw,
bias=self.bias,
padding=padding,
stride=self._stride,
dilation=self._dilation,
groups=self._groups,
data_format=self._data_format)
def _to_bi_function(model, fp_layers=[]):
for name, layer in model.named_children():
if id(layer) in fp_layers:
continue
if isinstance(layer, Linear):
new_layer = BiLinear(layer.weight.shape[0], layer.weight.shape[1],
layer._weight_attr, layer._bias_attr,
layer.name)
new_layer.weight = layer.weight
new_layer.bias = layer.bias
model._sub_layers[name] = new_layer
elif isinstance(layer, Conv1D):
new_layer = BiConv1D(layer._in_channels, layer._out_channels,
layer._kernel_size, layer._stride,
layer._padding, layer._dilation, layer._groups,
layer._padding_mode, layer._param_attr,
layer._bias_attr, layer._data_format)
new_layer.weight = layer.weight
new_layer.bias = layer.bias
model._sub_layers[name] = new_layer
elif isinstance(layer, nn.ReLU):
model._sub_layers[name] = nn.Hardtanh()
else:
model._sub_layers[name] = _to_bi_function(layer, fp_layers)
return model
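# Example usage (a minimal sketch mirroring train.py/test.py): convert a model
# in place while keeping selected layers at full precision.
#   model = PointNetClassifier(binary=True)
#   fp_layers = [id(model.feat.conv1), id(model.fc3)]  # keep first/last full-precision
#   model = _to_bi_function(model, fp_layers=fp_layers)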
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import warnings
import numpy as np
import paddle
from paddle.io import Dataset
from tqdm import tqdm
warnings.filterwarnings("ignore")
def normalize_point_cloud(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def random_point_dropout(pc, max_dropout_ratio=0.875):
dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((pc.shape[0])) <= dropout_ratio)[0]
if len(drop_idx) > 0:
pc[drop_idx, :] = pc[0, :] # set to the first point
return pc
def random_scale_point_cloud(data, scale_low=0.8, scale_high=1.25):
scales = np.random.uniform(scale_low, scale_high)
data *= scales
return data
def shift_point_cloud(data, shift_range=0.1):
shifts = np.random.uniform(-shift_range, shift_range, (3))
data += shifts
return data
def jitter_point_cloud(data: np.ndarray, sigma: float=0.02, clip: float=0.05):
assert clip > 0
jittered_data = np.clip(sigma * np.random.randn(*data.shape), -clip, clip)
data = data + jittered_data
return data
def random_rotate_point_cloud(data: np.ndarray):
angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(angle)
sinval = np.sin(angle)
rotation_matrix = np.array(
[[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]],
dtype=data.dtype)
data = data @ rotation_matrix
return data
def farthest_point_sample(point, npoint):
"""
Input:
xyz: pointcloud data, [N, D]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [npoint, D]
"""
N, D = point.shape
xyz = point[:, :3]
centroids = np.zeros((npoint, ))
distance = np.ones((N, )) * 1e10
farthest = np.random.randint(0, N)
for i in range(npoint):
centroids[i] = farthest
centroid = xyz[farthest, :]
dist = np.sum((xyz - centroid)**2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = np.argmax(distance, -1)
point = point[centroids.astype(np.int32)]
return point
class ModelNetDataset(Dataset):
def __init__(
self,
root,
num_point,
use_uniform_sample=False,
use_normals=False,
num_category=40,
split="train",
process_data=False, ):
self.root = root
self.npoints = num_point
self.split = split
self.process_data = process_data
self.uniform = use_uniform_sample
self.use_normals = use_normals
self.num_category = num_category
if self.num_category == 10:
self.catfile = os.path.join(self.root, "modelnet10_shape_names.txt")
else:
self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
shape_ids = {}
if self.num_category == 10:
shape_ids["train"] = [
line.rstrip() for line in
open(os.path.join(self.root, "modelnet10_train.txt"))
]
shape_ids["test"] = [
line.rstrip() for line in
open(os.path.join(self.root, "modelnet10_test.txt"))
]
else:
shape_ids["train"] = [
line.rstrip() for line in
open(os.path.join(self.root, "modelnet40_train.txt"))
]
shape_ids["test"] = [
line.rstrip() for line in
open(os.path.join(self.root, "modelnet40_test.txt"))
]
assert split == "train" or split == "test"
shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
self.datapath = [(shape_names[i], os.path.join(
self.root, shape_names[i], shape_ids[split][i]) + ".txt", )
for i in range(len(shape_ids[split]))]
print("The size of %s data is %d" % (split, len(self.datapath)))
if self.uniform:
self.save_path = os.path.join(
root,
"modelnet%d_%s_%dpts_fps.dat" % (self.num_category, split,
self.npoints), )
else:
self.save_path = os.path.join(
root,
"modelnet%d_%s_%dpts.dat" % (self.num_category, split,
self.npoints), )
if self.process_data:
if not os.path.exists(self.save_path):
print("Processing data %s (only running in the first time)..." %
self.save_path)
self.list_of_points = [None] * len(self.datapath)
self.list_of_labels = [None] * len(self.datapath)
for index in tqdm(
range(len(self.datapath)), total=len(self.datapath)):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(
fn[1], delimiter=",").astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(
point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
self.list_of_points[index] = point_set
self.list_of_labels[index] = cls
with open(self.save_path, "wb") as f:
pickle.dump([self.list_of_points, self.list_of_labels], f)
else:
print("Load processed data from %s..." % self.save_path)
with open(self.save_path, "rb") as f:
self.list_of_points, self.list_of_labels = pickle.load(f)
def __len__(self):
return len(self.datapath)
def _get_item(self, index):
if self.process_data:
point_set, label = self.list_of_points[index], self.list_of_labels[
index]
else:
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
label = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
point_set[:, 0:3] = normalize_point_cloud(point_set[:, 0:3])
if not self.use_normals:
point_set = point_set[:, 0:3]
if self.split == "train":
# point_set[:, 0:3] = random_rotate_point_cloud(point_set[:, 0:3])
point_set[:, 0:3] = jitter_point_cloud(point_set[:, 0:3])
# point_set[:, 0:3] = random_point_dropout(point_set[:, 0:3])
point_set[:, 0:3] = random_scale_point_cloud(point_set[:, 0:3])
point_set[:, 0:3] = shift_point_cloud(point_set[:, 0:3])
return paddle.to_tensor(point_set, dtype=paddle.float32), int(label[0])
def __getitem__(self, index):
return self._get_item(index)
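# Example usage (a minimal sketch matching how train.py builds its loaders):
#   dataset = ModelNetDataset("./modelnet40_normal_resampled", num_point=1024,
#                             split="train")
#   points, label = dataset[0]  # (1024, 3) float32 tensor and an int class id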
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
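# Precomputed offsets, indexed by the number of pooled points N: the maximum of
# N roughly standard-normal activations has a positive mean that grows with N,
# so in binary mode this offset re-centers the max-pooled features before the
# following sign() binarization (part of Entropy-Maximizing Aggregation).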
offset_map = {1024: -3.2041, 2048: -3.4025, 4096: -3.5836}
class TNet(nn.Layer):
def __init__(self, k=64, binary=False):
super(TNet, self).__init__()
self.conv1 = nn.Conv1D(k, 64, 1)
self.conv2 = nn.Conv1D(64, 128, 1)
self.conv3 = nn.Conv1D(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k * k)
self.act_function = nn.ReLU()
self.bn1 = nn.BatchNorm1D(64)
self.bn2 = nn.BatchNorm1D(128)
self.bn3 = nn.BatchNorm1D(1024)
self.bn4 = nn.BatchNorm1D(512)
self.bn5 = nn.BatchNorm1D(256)
self.k = k
self.binary = binary
self.iden = paddle.eye(self.k, self.k, dtype=paddle.float32)
def forward(self, x):
B, D, N = x.shape
x = self.act_function(self.bn1(self.conv1(x)))
x = self.act_function(self.bn2(self.conv2(x)))
if self.binary:
x = self.bn3(self.conv3(x))
x = paddle.max(x, 2, keepdim=True) + offset_map[N]
else:
x = self.act_function(self.bn3(self.conv3(x)))
x = paddle.max(x, 2, keepdim=True)
x = x.reshape((-1, 1024))
x = self.act_function(self.bn4(self.fc1(x)))
x = self.act_function(self.bn5(self.fc2(x)))
x = self.fc3(x)
x = x.reshape((-1, self.k, self.k)) + self.iden
return x
class PointNetEncoder(nn.Layer):
def __init__(self,
global_feat=True,
input_transform=True,
feature_transform=False,
channel=3,
binary=False):
super(PointNetEncoder, self).__init__()
self.global_feat = global_feat
if input_transform:
self.input_transfrom = TNet(k=channel)
else:
self.input_transfrom = lambda x: paddle.eye(
channel, channel, dtype=paddle.float32
)
self.conv1 = nn.Conv1D(channel, 64, 1)
self.conv2 = nn.Conv1D(64, 128, 1)
self.conv3 = nn.Conv1D(128, 1024, 1)
self.bn1 = nn.BatchNorm1D(64)
self.bn2 = nn.BatchNorm1D(128)
self.bn3 = nn.BatchNorm1D(1024)
if feature_transform:
self.feature_transform = TNet(k=64)
else:
self.feature_transform = lambda x: paddle.eye(64, 64, dtype=paddle.float32)
self.act_function = nn.ReLU()
self.binary = binary
def forward(self, x):
x = paddle.transpose(x, (0, 2, 1))
B, D, N = x.shape
trans_input = self.input_transfrom(x)
x = paddle.transpose(x, (0, 2, 1))
if D > 3:
feature = x[:, :, 3:]
x = x[:, :, :3]
x = paddle.bmm(x, trans_input)
if D > 3:
x = paddle.concat([x, feature], axis=2)
x = paddle.transpose(x, (0, 2, 1))
x = self.act_function(self.bn1(self.conv1(x)))
trans_feat = self.feature_transform(x)
x = paddle.transpose(x, (0, 2, 1))
x = paddle.bmm(x, trans_feat)
x = paddle.transpose(x, (0, 2, 1))
pointfeat = x
x = self.act_function(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
if self.binary:
x = paddle.max(x, 2, keepdim=True) + offset_map[N]
else:
x = paddle.max(x, 2, keepdim=True)
x = x.reshape((-1, 1024))
if self.global_feat:
return x, trans_input, trans_feat
else:
            x = x.reshape((-1, 1024, 1)).tile((1, 1, N))
            return paddle.concat([x, pointfeat], axis=1), trans_input, trans_feat
class PointNetClassifier(nn.Layer):
def __init__(self, k=40, normal_channel=False, binary=False):
super(PointNetClassifier, self).__init__()
if normal_channel:
channel = 6
else:
channel = 3
self.feat = PointNetEncoder(
global_feat=True,
input_transform=True,
feature_transform=True,
channel=channel,
binary=binary, )
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k)
self.dropout = nn.Dropout(p=0.4)
self.bn1 = nn.BatchNorm1D(512)
self.bn2 = nn.BatchNorm1D(256)
self.act_function = nn.ReLU()
self.binary = binary
def forward(self, x):
x, trans_input, trans_feat = self.feat(x)
x = self.act_function(self.bn1(self.fc1(x)))
x = self.act_function(self.bn2(self.fc2(x)))
if not self.binary:
x = self.dropout(x)
x = self.fc3(x)
return x, trans_input, trans_feat
class CrossEntropyMatrixRegularization(nn.Layer):
def __init__(self, mat_diff_loss_scale=1e-3):
super(CrossEntropyMatrixRegularization, self).__init__()
self.mat_diff_loss_scale = mat_diff_loss_scale
def forward(self, pred, target, trans_feat=None):
loss = F.cross_entropy(pred, target)
if trans_feat is None:
mat_diff_loss = 0
else:
            mat_diff_loss = self.feature_transform_regularizer(trans_feat)
total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
return total_loss
    def feature_transform_regularizer(self, trans):
        # Penalize || A A^T - I || so the learned feature transform stays
        # close to an orthogonal matrix.
        d = trans.shape[1]
        I = paddle.eye(d)
loss = paddle.mean(
paddle.norm(
paddle.bmm(trans, paddle.transpose(trans, (0, 2, 1))) - I,
axis=(1, 2)))
return loss
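# Example usage (a minimal sketch):
#   model = PointNetClassifier(binary=True)
#   logits, trans_input, trans_feat = model(points)  # points: [B, N, 3]
#   loss = CrossEntropyMatrixRegularization()(logits, labels, trans_feat)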
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle
from paddle.io import DataLoader
from paddle.metric import Accuracy
from data import ModelNetDataset
from model import PointNetClassifier
def parse_args():
parser = argparse.ArgumentParser("Test")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument(
"--num_point", type=int, default=1024, help="point number")
parser.add_argument(
"--num_workers", type=int, default=32, help="num wrokers")
parser.add_argument("--log_freq", type=int, default=10)
parser.add_argument(
"--model_path", type=str, default="./BiPointNet.pdparams")
parser.add_argument(
"--data_dir",
type=str,
default="./modelnet40_normal_resampled", )
parser.add_argument(
"--binary",
action='store_true',
help="whehter to build binary pointnet")
return parser.parse_args()
def test(args):
test_data = ModelNetDataset(
args.data_dir, split="test", num_point=args.num_point)
test_loader = DataLoader(
test_data,
shuffle=False,
num_workers=args.num_workers,
batch_size=args.batch_size, )
model = PointNetClassifier(binary=args.binary)
if args.binary:
import basic
fp_layers = [
id(model.feat.input_transfrom.conv1),
id(model.feat.conv1),
id(model.fc3)
]
model = basic._to_bi_function(model, fp_layers=fp_layers)
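        # The released checkpoint already contains trained scale factors; mark
        # the binarized layers as initialized so forward() does not
        # re-estimate and overwrite them.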
def func(model):
if hasattr(model, "scale_weight_init"):
model.scale_weight_init = True
model.apply(func)
model_state_dict = paddle.load(args.model_path)
model.set_state_dict(model_state_dict)
metrics = Accuracy()
metrics.reset()
model.eval()
    for step, data in enumerate(test_loader):
        x, y = data
        pred, _, _ = model(x)
        correct = metrics.compute(pred, y)
        metrics.update(correct)
        if step % args.log_freq == 0:
            print("Eval iter:", step)
test_acc = metrics.accumulate()
print("Test Accuracy: {}".format(test_acc))
if __name__ == "__main__":
args = parse_args()
print(args)
test(args)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import paddle
from paddle.io import DataLoader
from paddle.metric import Accuracy
from paddle.optimizer import Adam
from paddle.optimizer.lr import CosineAnnealingDecay
from data import ModelNetDataset
from model import CrossEntropyMatrixRegularization, PointNetClassifier
def parse_args():
parser = argparse.ArgumentParser("Train")
parser.add_argument(
"--batch_size", type=int, default=32, help="batch size in training")
parser.add_argument(
"--learning_rate",
type=float,
default=1e-3,
help="learning rate in training")
parser.add_argument(
"--num_point", type=int, default=1024, help="point number")
parser.add_argument(
"--max_epochs", type=int, default=200, help="max epochs")
parser.add_argument(
"--num_workers", type=int, default=32, help="num wrokers")
parser.add_argument(
"--weight_decay", type=float, default=1e-4, help="weight decay")
parser.add_argument("--log_freq", type=int, default=50)
parser.add_argument(
"--pretrained",
type=str,
default='pointnet.pdparams',
help='pretrained model path')
parser.add_argument(
"--save_dir", type=str, default='./save_model', help='save model path')
parser.add_argument(
"--data_dir",
type=str,
default="./modelnet40_normal_resampled",
help='dataset dir')
parser.add_argument(
"--binary",
action='store_true',
help="whehter to build binary pointnet")
return parser.parse_args()
def train(args):
train_data = ModelNetDataset(
args.data_dir, split="train", num_point=args.num_point)
test_data = ModelNetDataset(
args.data_dir, split="test", num_point=args.num_point)
train_loader = DataLoader(
train_data,
shuffle=True,
num_workers=args.num_workers,
batch_size=args.batch_size, )
test_loader = DataLoader(
test_data,
shuffle=False,
num_workers=args.num_workers,
batch_size=args.batch_size, )
model = PointNetClassifier(binary=args.binary)
if args.binary:
import basic
fp_layers = [
id(model.feat.input_transfrom.conv1),
id(model.feat.conv1),
id(model.fc3)
]
model = basic._to_bi_function(model, fp_layers=fp_layers)
print(model)
model_state_dict = paddle.load(path=args.pretrained)
model.set_state_dict(model_state_dict)
scheduler = CosineAnnealingDecay(
learning_rate=args.learning_rate,
T_max=args.max_epochs, )
optimizer = Adam(
learning_rate=scheduler,
parameters=model.parameters(),
weight_decay=args.weight_decay, )
loss_fn = CrossEntropyMatrixRegularization()
metrics = Accuracy()
best_test_acc = 0
for epoch in range(args.max_epochs):
metrics.reset()
model.train()
for batch_id, data in enumerate(train_loader):
x, y = data
pred, trans_input, trans_feat = model(x)
loss = loss_fn(pred, y, trans_feat)
correct = metrics.compute(pred, y)
metrics.update(correct)
loss.backward()
optimizer.step()
optimizer.clear_grad()
if (batch_id + 1) % args.log_freq == 0:
print("Epoch: {}, Batch ID: {}, Loss: {}, ACC: {}".format(
epoch, batch_id + 1, loss.item(), metrics.accumulate()))
scheduler.step()
metrics.reset()
model.eval()
for batch_id, data in enumerate(test_loader):
x, y = data
pred, trans_input, trans_feat = model(x)
correct = metrics.compute(pred, y)
metrics.update(correct)
test_acc = metrics.accumulate()
print("Test epoch: {}, acc is: {}".format(epoch, test_acc))
if test_acc > best_test_acc:
best_test_acc = test_acc
save_path = os.path.join(args.save_dir, 'best_model.pdparams')
paddle.save(model.state_dict(), save_path)
print("Best Test ACC: {}, Model saved in {}".format(
test_acc, save_path))
else:
print("Current Best Test ACC: {}".format(best_test_acc))
if __name__ == "__main__":
args = parse_args()
print(args)
train(args)