diff --git a/example/BiPointNet/README.md b/example/BiPointNet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cb9867fd94ed3f72f8684df061febb80b4eac915
--- /dev/null
+++ b/example/BiPointNet/README.md
@@ -0,0 +1,68 @@
+# BiPointNet
+
+## 1. 简介
+本示例介绍了一种用于点云模型 (PointNet) 的二值化方法(BiPointNet)。BiPointNet 通过引入熵最大化聚合(Entropy-Maximizing Aggregation)来调整聚合前的分布,以获得最大信息熵,并引入分层尺度因子(Layer-wise Scale Recovery)有效恢复特征表达能力,是一种简单而高效的二值化点云模型的方法。
+
+## 2. Benchmark
+
+| 模型 | Accuracy | 权重下载 |
+| ------------- | --------- | --------- |
+| PointNet | 89.83 | [PointNet.pdparams](https://bj.bcebos.com/v1/paddle-slim-models/PointNet.pdparams) |
+| BiPointNet | 85.86 | [BiPointNet.pdparams](https://bj.bcebos.com/v1/paddle-slim-models/BiPointNet.pdparams) |
+
+
+## 3. BiPointNet 的训练及测试
+BiPointNet 整体结构如图所示,详情见论文 [BIPOINTNET: BINARY NEURAL NETWORK FOR POINT CLOUDS](https://arxiv.org/abs/2010.05501)
+
+![arch](arch.png)
+
+### 3.1 准备环境
+- PaddlePaddle >= 2.4 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
+
+安装 paddlepaddle:
+```shell
+# CPU
+pip install paddlepaddle==2.4.2
+# GPU 以Ubuntu、CUDA 11.2为例
+python -m pip install paddlepaddle-gpu==2.4.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+```
+
+### 3.2 准备数据集
+
+本示例在 [ModelNet40](https://modelnet.cs.princeton.edu) 数据集上进行了分类实验。
+
+```shell
+wget https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip --no-check-certificate
+unzip modelnet40_normal_resampled.zip
+```
+
+### 3.3 启动训练
+
+- 训练基准模型 PointNet
+```
+export CUDA_VISIBLE_DEVICES=0
+python train.py --save_dir 'pointnet'
+```
+
+- 训练 BiPointNet
+```
+export CUDA_VISIBLE_DEVICES=0
+python train.py --save_dir 'bipointnet' --binary
+```
+
+### 3.4 验证精度
+
+- 测试基准模型 PointNet
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python test.py --model_path 'PointNet.pdparams'
+```
+
+- 测试 BiPointNet
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python test.py --model_path 'BiPointNet.pdparams' --binary
+```
+
+## 致谢
+感谢 [Macaronlin](https://github.com/Macaronlin) 贡献 BiPointNet。
diff --git a/example/BiPointNet/arch.png b/example/BiPointNet/arch.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3296501706dd5c6fb22f469f46a1e9012c29276
Binary files /dev/null and b/example/BiPointNet/arch.png differ
diff --git a/example/BiPointNet/basic.py b/example/BiPointNet/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c688988e77e58eb99e74a80cd221bad2f8eaf5e
--- /dev/null
+++ b/example/BiPointNet/basic.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import nn
+from paddle.autograd import PyLayer
+from paddle.nn import functional as F
+from paddle.nn.layer import Conv1D
+from paddle.nn.layer.common import Linear
+
+
+class BinaryQuantizer(PyLayer):
+    """Binarize a tensor with ``paddle.sign`` and a clipped pass-through gradient.
+
+    The backward pass forwards the incoming gradient unchanged where the saved
+    input lies strictly inside (-1, 1) and zeroes it everywhere else.
+    """
+
+    @staticmethod
+    def forward(ctx, input):
+        # The real-valued input is needed in backward() to build the mask.
+        ctx.save_for_backward(input)
+        return paddle.sign(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensor()[0]
+        # Work on a copy: masking ``grad_output`` itself would overwrite a
+        # gradient tensor that this function does not own.
+        grad_input = grad_output.clone()
+        grad_input[input >= 1] = 0
+        grad_input[input <= -1] = 0
+        return grad_input
+
+
+class BiLinear(Linear):
+    """``Linear`` layer that binarizes both its input and its weight.
+
+    The weight is mean-centered before ``BinaryQuantizer`` is applied, and the
+    binarized weight is multiplied by the one-element parameter
+    ``scale_weight``. That parameter is set from data on the first forward
+    call made while ``scale_weight_init`` is False.
+    """
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 weight_attr=None,
+                 bias_attr=None,
+                 name=None):
+        super(BiLinear, self).__init__(
+            in_features,
+            out_features,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
+            name=name)
+
+        # Flipped to True after the first forward call. Code that loads
+        # trained weights (test.py) sets it to True beforehand so the loaded
+        # scale is not overwritten.
+        self.scale_weight_init = False
+        self.scale_weight = paddle.create_parameter(shape=[1], dtype='float32')
+
+    def forward(self, input):
+        ba = input
+
+        bw = self.weight
+        bw = bw - bw.mean()
+
+        if not self.scale_weight_init:
+            # Ratio between the spread of the full-precision output and the
+            # spread of the sign-only output for this batch.
+            scale_weight = F.linear(ba, bw).std() / F.linear(
+                paddle.sign(ba), paddle.sign(bw)).std()
+            if paddle.isnan(scale_weight):
+                # Fall back to an estimate that only depends on the weight.
+                scale_weight = bw.std() / paddle.sign(bw).std()
+            self.scale_weight.set_value(scale_weight)
+            self.scale_weight_init = True
+
+        ba = BinaryQuantizer.apply(ba)
+        bw = BinaryQuantizer.apply(bw)
+        bw = bw * self.scale_weight
+
+        out = F.linear(x=ba, weight=bw, bias=self.bias, name=self.name)
+        return out
+
+
+class BiConv1D(Conv1D):
+    """``Conv1D`` layer that binarizes both its input and its weight.
+
+    Mirrors ``BiLinear``: the weight is mean-centered, passed through
+    ``BinaryQuantizer`` and multiplied by the one-element parameter
+    ``scale_weight``, which is set from data on the first forward call made
+    while ``scale_weight_init`` is False.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 padding_mode='zeros',
+                 weight_attr=None,
+                 bias_attr=None,
+                 data_format="NCL"):
+        super(BiConv1D, self).__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation,
+            groups, padding_mode, weight_attr, bias_attr, data_format)
+        # See BiLinear.__init__ for the meaning of these two attributes.
+        self.scale_weight_init = False
+        self.scale_weight = paddle.create_parameter(shape=[1], dtype='float32')
+
+    def forward(self, input):
+        ba = input
+
+        bw = self.weight
+        bw = bw - bw.mean()
+
+        padding = 0
+        if self._padding_mode != "zeros":
+            # Non-zero padding modes are applied explicitly, so the
+            # convolution itself then runs with padding=0.
+            ba = F.pad(
+                ba,
+                self._reversed_padding_repeated_twice,
+                mode=self._padding_mode,
+                data_format=self._data_format)
+        else:
+            padding = self._padding
+
+        # Shared by the calibration convolutions and the real one below.
+        conv_kwargs = dict(
+            bias=self.bias,
+            padding=padding,
+            stride=self._stride,
+            dilation=self._dilation,
+            groups=self._groups,
+            data_format=self._data_format)
+
+        if not self.scale_weight_init:
+            full_precision_std = F.conv1d(ba, bw, **conv_kwargs).std()
+            sign_only_std = F.conv1d(
+                paddle.sign(ba), paddle.sign(bw), **conv_kwargs).std()
+            scale_weight = full_precision_std / sign_only_std
+            if paddle.isnan(scale_weight):
+                # Fall back to an estimate that only depends on the weight.
+                scale_weight = bw.std() / paddle.sign(bw).std()
+
+            self.scale_weight.set_value(scale_weight)
+            self.scale_weight_init = True
+
+        ba = BinaryQuantizer.apply(ba)
+        bw = BinaryQuantizer.apply(bw)
+        bw = bw * self.scale_weight
+
+        return F.conv1d(ba, bw, **conv_kwargs)
+
+
+def _to_bi_function(model, fp_layers=None):
+    """Recursively swap the sub-layers of ``model`` for binary counterparts.
+
+    ``Linear`` becomes ``BiLinear`` and ``Conv1D`` becomes ``BiConv1D`` (both
+    reuse the original weight and bias), ``nn.ReLU`` becomes ``nn.Hardtanh``,
+    and every other sub-layer is processed recursively.
+
+    Args:
+        model: layer whose children are replaced in place.
+        fp_layers: collection of ``id(layer)`` values; a child whose id is
+            listed is left untouched and is not descended into.
+
+    Returns:
+        The same ``model`` object.
+    """
+    # None instead of a shared mutable default list.
+    fp_layers = () if fp_layers is None else fp_layers
+    for name, layer in model.named_children():
+        if id(layer) in fp_layers:
+            continue
+        if isinstance(layer, Linear):
+            new_layer = BiLinear(layer.weight.shape[0], layer.weight.shape[1],
+                                 layer._weight_attr, layer._bias_attr,
+                                 layer.name)
+            new_layer.weight = layer.weight
+            new_layer.bias = layer.bias
+            model._sub_layers[name] = new_layer
+        elif isinstance(layer, Conv1D):
+            new_layer = BiConv1D(layer._in_channels, layer._out_channels,
+                                 layer._kernel_size, layer._stride,
+                                 layer._padding, layer._dilation, layer._groups,
+                                 layer._padding_mode, layer._param_attr,
+                                 layer._bias_attr, layer._data_format)
+            new_layer.weight = layer.weight
+            new_layer.bias = layer.bias
+            model._sub_layers[name] = new_layer
+        elif isinstance(layer, nn.ReLU):
+            model._sub_layers[name] = nn.Hardtanh()
+        else:
+            model._sub_layers[name] = _to_bi_function(layer, fp_layers)
+    return model
diff --git a/example/BiPointNet/data.py b/example/BiPointNet/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..45fb6d9766a786695eebbd8c7ac96ef3966d827a
--- /dev/null
+++ b/example/BiPointNet/data.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pickle
+import warnings
+import numpy as np
+import paddle
+from paddle.io import Dataset
+from tqdm import tqdm
+
+warnings.filterwarnings("ignore")
+
+
+def normalize_point_cloud(pc):
+    """Center ``pc`` on its centroid and scale it so the farthest point has norm 1."""
+    centroid = np.mean(pc, axis=0)
+    pc = pc - centroid
+    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
+    pc = pc / m
+    return pc
+
+
+def random_point_dropout(pc, max_dropout_ratio=0.875):
+    """Overwrite a random subset of rows of ``pc`` with its first row (in place)."""
+    dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875
+    drop_idx = np.where(np.random.random((pc.shape[0])) <= dropout_ratio)[0]
+    if len(drop_idx) > 0:
+        pc[drop_idx, :] = pc[0, :]  # set to the first point
+    return pc
+
+
+def random_scale_point_cloud(data, scale_low=0.8, scale_high=1.25):
+    """Multiply ``data`` in place by one scalar drawn from [scale_low, scale_high)."""
+    scales = np.random.uniform(scale_low, scale_high)
+    data *= scales
+    return data
+
+
+def shift_point_cloud(data, shift_range=0.1):
+    """Add one random 3-vector drawn from [-shift_range, shift_range) to ``data`` in place."""
+    shifts = np.random.uniform(-shift_range, shift_range, (3))
+    data += shifts
+    return data
+
+
+def jitter_point_cloud(data: np.ndarray, sigma: float=0.02, clip: float=0.05):
+    """Return ``data`` plus Gaussian noise of scale ``sigma`` clipped to [-clip, clip]."""
+    assert clip > 0
+    jittered_data = np.clip(sigma * np.random.randn(*data.shape), -clip, clip)
+    data = data + jittered_data
+    return data
+
+
+def random_rotate_point_cloud(data: np.ndarray):
+    """Return ``data`` rotated by a random angle in the plane of its first two columns."""
+    angle = np.random.uniform() * 2 * np.pi
+    cosval = np.cos(angle)
+    sinval = np.sin(angle)
+    rotation_matrix = np.array(
+        [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]],
+        dtype=data.dtype)
+    data = data @ rotation_matrix
+    return data
+
+
+def farthest_point_sample(point, npoint):
+    """Select ``npoint`` rows of ``point`` by iterative farthest-point sampling.
+
+    Args:
+        point: array of shape [N, D]; the first three columns are used as the
+            coordinates for the distance computation.
+        npoint: number of rows to keep.
+
+    Returns:
+        The selected rows of ``point``, shape [npoint, D].
+    """
+    N, D = point.shape
+    xyz = point[:, :3]
+    # Row indices of the selected points, in selection order.
+    centroids = np.zeros((npoint, ), dtype=np.int32)
+    # Squared distance from every point to its nearest already-selected point.
+    distance = np.ones((N, )) * 1e10
+    farthest = np.random.randint(0, N)
+    for i in range(npoint):
+        centroids[i] = farthest
+        centroid = xyz[farthest, :]
+        dist = np.sum((xyz - centroid)**2, -1)
+        mask = dist < distance
+        distance[mask] = dist[mask]
+        farthest = np.argmax(distance, -1)
+    return point[centroids]
+
+
+def _read_lines(path):
+    """Return the lines of ``path`` with trailing whitespace stripped."""
+    with open(path) as f:
+        return [line.rstrip() for line in f]
+
+
+class ModelNetDataset(Dataset):
+    """ModelNet point-cloud classification dataset read from ``root``.
+
+    Args:
+        root: directory holding the ``modelnet*_shape_names.txt``,
+            ``modelnet*_train.txt`` and ``modelnet*_test.txt`` lists and one
+            sub-directory of comma-separated ``.txt`` files per shape name.
+        num_point: number of points kept per sample.
+        use_uniform_sample: use ``farthest_point_sample`` instead of keeping
+            the first ``num_point`` rows.
+        use_normals: keep all columns instead of only the first three.
+        num_category: 10 selects the ModelNet10 lists; any other value selects
+            the ModelNet40 lists.
+        split: ``"train"`` or ``"test"``; augmentation is applied to
+            ``"train"`` only.
+        process_data: load every sample once and cache the result in a pickle
+            file under ``root``.
+    """
+
+    def __init__(self,
+                 root,
+                 num_point,
+                 use_uniform_sample=False,
+                 use_normals=False,
+                 num_category=40,
+                 split="train",
+                 process_data=False):
+        self.root = root
+        self.npoints = num_point
+        self.split = split
+        self.process_data = process_data
+        self.uniform = use_uniform_sample
+        self.use_normals = use_normals
+        self.num_category = num_category
+
+        prefix = "modelnet10" if self.num_category == 10 else "modelnet40"
+        self.catfile = os.path.join(self.root, prefix + "_shape_names.txt")
+
+        self.cat = _read_lines(self.catfile)
+        self.classes = dict(zip(self.cat, range(len(self.cat))))
+
+        shape_ids = {
+            "train":
+            _read_lines(os.path.join(self.root, prefix + "_train.txt")),
+            "test": _read_lines(os.path.join(self.root, prefix + "_test.txt")),
+        }
+
+        assert split == "train" or split == "test"
+        # A shape id is "<shape_name>_<number>"; the shape name itself may
+        # contain underscores, so only the last component is dropped.
+        shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
+        self.datapath = [
+            (shape_name, os.path.join(self.root, shape_name, shape_id) + ".txt")
+            for shape_name, shape_id in zip(shape_names, shape_ids[split])
+        ]
+        print("The size of %s data is %d" % (split, len(self.datapath)))
+
+        if self.uniform:
+            cache_name = "modelnet%d_%s_%dpts_fps.dat" % (
+                self.num_category, split, self.npoints)
+        else:
+            cache_name = "modelnet%d_%s_%dpts.dat" % (self.num_category, split,
+                                                      self.npoints)
+        self.save_path = os.path.join(root, cache_name)
+
+        if self.process_data:
+            if not os.path.exists(self.save_path):
+                print("Processing data %s (only running in the first time)..." %
+                      self.save_path)
+                self.list_of_points = [None] * len(self.datapath)
+                self.list_of_labels = [None] * len(self.datapath)
+
+                for index in tqdm(
+                        range(len(self.datapath)), total=len(self.datapath)):
+                    point_set, cls = self._load_sample(index)
+                    self.list_of_points[index] = point_set
+                    self.list_of_labels[index] = cls
+
+                with open(self.save_path, "wb") as f:
+                    pickle.dump([self.list_of_points, self.list_of_labels], f)
+            else:
+                print("Load processed data from %s..." % self.save_path)
+                # NOTE: pickle.load executes arbitrary code from the file;
+                # only point ``root`` at a cache this class wrote itself.
+                with open(self.save_path, "rb") as f:
+                    self.list_of_points, self.list_of_labels = pickle.load(f)
+
+    def __len__(self):
+        return len(self.datapath)
+
+    def _load_sample(self, index):
+        """Read sample ``index`` from disk and reduce it to ``self.npoints`` rows.
+
+        Returns:
+            ``(point_set, label)`` where ``point_set`` is a float32 array and
+            ``label`` is a one-element int32 array holding the class index.
+        """
+        shape_name, path = self.datapath[index]
+        label = np.array([self.classes[shape_name]]).astype(np.int32)
+        point_set = np.loadtxt(path, delimiter=",").astype(np.float32)
+
+        if self.uniform:
+            point_set = farthest_point_sample(point_set, self.npoints)
+        else:
+            point_set = point_set[0:self.npoints, :]
+        return point_set, label
+
+    def _get_item(self, index):
+        """Return ``(points, label)``: a float32 paddle tensor and a Python int."""
+        if self.process_data:
+            # Copy the cached array: normalization and augmentation below
+            # write into ``point_set`` in place, which would otherwise
+            # accumulate in the cache every time the sample is fetched.
+            point_set = self.list_of_points[index].copy()
+            label = self.list_of_labels[index]
+        else:
+            point_set, label = self._load_sample(index)
+
+        point_set[:, 0:3] = normalize_point_cloud(point_set[:, 0:3])
+        if not self.use_normals:
+            point_set = point_set[:, 0:3]
+
+        if self.split == "train":
+            # Rotation and point dropout are disabled in this example.
+            # point_set[:, 0:3] = random_rotate_point_cloud(point_set[:, 0:3])
+            point_set[:, 0:3] = jitter_point_cloud(point_set[:, 0:3])
+            # point_set[:, 0:3] = random_point_dropout(point_set[:, 0:3])
+            point_set[:, 0:3] = random_scale_point_cloud(point_set[:, 0:3])
+            point_set[:, 0:3] = shift_point_cloud(point_set[:, 0:3])
+
+        return paddle.to_tensor(point_set, dtype=paddle.float32), int(label[0])
+
+    def __getitem__(self, index):
+        return self._get_item(index)
diff --git a/example/BiPointNet/model.py b/example/BiPointNet/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb2dcbdf1aa29a99b06d56b0eedf4706e34ec230
--- /dev/null
+++ b/example/BiPointNet/model.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+# Constant added to the max-pooled feature in binary mode, keyed by the number
+# of points N. Only these three point counts are supported; any other N raises
+# KeyError. Presumably this is the shift used by the Entropy-Maximizing
+# Aggregation mentioned in the README -- TODO confirm against the paper.
+offset_map = {1024: -3.2041, 2048: -3.4025, 4096: -3.5836}
+
+
+class TNet(nn.Layer):
+    """Predict one ``k x k`` transform per sample from a ``[B, k, N]`` input.
+
+    Three 1x1 convolutions, a max over the point axis and three linear layers
+    produce ``k * k`` values that are reshaped to ``[B, k, k]`` and added to
+    the identity matrix.
+    """
+
+    def __init__(self, k=64, binary=False):
+        super(TNet, self).__init__()
+        self.conv1 = nn.Conv1D(k, 64, 1)
+        self.conv2 = nn.Conv1D(64, 128, 1)
+        self.conv3 = nn.Conv1D(128, 1024, 1)
+        self.fc1 = nn.Linear(1024, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, k * k)
+        self.act_function = nn.ReLU()
+
+        self.bn1 = nn.BatchNorm1D(64)
+        self.bn2 = nn.BatchNorm1D(128)
+        self.bn3 = nn.BatchNorm1D(1024)
+        self.bn4 = nn.BatchNorm1D(512)
+        self.bn5 = nn.BatchNorm1D(256)
+
+        self.k = k
+        self.binary = binary
+        self.iden = paddle.eye(self.k, self.k, dtype=paddle.float32)
+
+    def forward(self, x):
+        B, D, N = x.shape
+
+        x = self.act_function(self.bn1(self.conv1(x)))
+        x = self.act_function(self.bn2(self.conv2(x)))
+
+        if self.binary:
+            # Binary mode: no activation before pooling, offset after it.
+            x = self.bn3(self.conv3(x))
+            x = paddle.max(x, 2, keepdim=True) + offset_map[N]
+        else:
+            x = self.act_function(self.bn3(self.conv3(x)))
+            x = paddle.max(x, 2, keepdim=True)
+        x = x.reshape((-1, 1024))
+
+        x = self.act_function(self.bn4(self.fc1(x)))
+        x = self.act_function(self.bn5(self.fc2(x)))
+        x = self.fc3(x)
+
+        x = x.reshape((-1, self.k, self.k)) + self.iden
+        return x
+
+
+class PointNetEncoder(nn.Layer):
+    """PointNet feature extractor for inputs of shape ``[B, N, channel]``.
+
+    Returns ``(features, trans_input, trans_feat)``. With ``global_feat=True``
+    ``features`` is the pooled ``[B, 1024]`` vector; otherwise the pooled
+    vector is repeated for every point and concatenated with the 64-channel
+    per-point features along the channel axis.
+    """
+
+    def __init__(self,
+                 global_feat=True,
+                 input_transform=True,
+                 feature_transform=False,
+                 channel=3,
+                 binary=False):
+        super(PointNetEncoder, self).__init__()
+
+        self.global_feat = global_feat
+        # NOTE: the attribute name is misspelled ("transfrom") but it is
+        # referenced by train.py/test.py and ends up in state-dict keys, so it
+        # must not be renamed.
+        # NOTE(review): ``binary`` is not forwarded to either TNet, so their
+        # binary pooling branch is never taken -- confirm this is intended.
+        # NOTE(review): with input_transform=False or feature_transform=False
+        # the lambdas below return a 2-D matrix, which paddle.bmm in forward()
+        # does not accept -- these configurations look untested.
+        if input_transform:
+            self.input_transfrom = TNet(k=channel)
+        else:
+            self.input_transfrom = lambda x: paddle.eye(
+                channel, channel, dtype=paddle.float32
+            )
+
+        self.conv1 = nn.Conv1D(channel, 64, 1)
+        self.conv2 = nn.Conv1D(64, 128, 1)
+        self.conv3 = nn.Conv1D(128, 1024, 1)
+
+        self.bn1 = nn.BatchNorm1D(64)
+        self.bn2 = nn.BatchNorm1D(128)
+        self.bn3 = nn.BatchNorm1D(1024)
+
+        if feature_transform:
+            self.feature_transform = TNet(k=64)
+        else:
+            self.feature_transform = lambda x: paddle.eye(64, 64, dtype=paddle.float32)
+
+        self.act_function = nn.ReLU()
+        self.binary = binary
+
+    def forward(self, x):
+        # [B, N, D] -> [B, D, N] for the 1-D convolutions.
+        x = paddle.transpose(x, (0, 2, 1))
+        B, D, N = x.shape
+        trans_input = self.input_transfrom(x)
+        x = paddle.transpose(x, (0, 2, 1))
+        if D > 3:
+            # Only the first three columns are transformed; the remaining
+            # columns are re-attached afterwards.
+            feature = x[:, :, 3:]
+            x = x[:, :, :3]
+        x = paddle.bmm(x, trans_input)
+        if D > 3:
+            x = paddle.concat([x, feature], axis=2)
+        x = paddle.transpose(x, (0, 2, 1))
+        x = self.act_function(self.bn1(self.conv1(x)))
+
+        trans_feat = self.feature_transform(x)
+        x = paddle.transpose(x, (0, 2, 1))
+        x = paddle.bmm(x, trans_feat)
+        x = paddle.transpose(x, (0, 2, 1))
+
+        pointfeat = x
+        x = self.act_function(self.bn2(self.conv2(x)))
+        x = self.bn3(self.conv3(x))
+
+        if self.binary:
+            x = paddle.max(x, 2, keepdim=True) + offset_map[N]
+        else:
+            x = paddle.max(x, 2, keepdim=True)
+
+        x = x.reshape((-1, 1024))
+
+        if self.global_feat:
+            return x, trans_input, trans_feat
+        else:
+            # Paddle has no Tensor.repeat / paddle.cat (those are the PyTorch
+            # names); use tile and concat instead.
+            x = paddle.tile(x.reshape((-1, 1024, 1)), (1, 1, N))
+            return paddle.concat([x, pointfeat], axis=1), trans_input, trans_feat
+
+
+class PointNetClassifier(nn.Layer):
+    """PointNet classifier: ``PointNetEncoder`` followed by three linear layers.
+
+    ``forward`` returns ``(logits, trans_input, trans_feat)``. Dropout before
+    the last linear layer is applied only when ``binary`` is False.
+    """
+
+    def __init__(self, k=40, normal_channel=False, binary=False):
+        super(PointNetClassifier, self).__init__()
+        if normal_channel:
+            channel = 6
+        else:
+            channel = 3
+        self.feat = PointNetEncoder(
+            global_feat=True,
+            input_transform=True,
+            feature_transform=True,
+            channel=channel,
+            binary=binary)
+        self.fc1 = nn.Linear(1024, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, k)
+        self.dropout = nn.Dropout(p=0.4)
+        self.bn1 = nn.BatchNorm1D(512)
+        self.bn2 = nn.BatchNorm1D(256)
+        self.act_function = nn.ReLU()
+        self.binary = binary
+
+    def forward(self, x):
+        x, trans_input, trans_feat = self.feat(x)
+        x = self.act_function(self.bn1(self.fc1(x)))
+        x = self.act_function(self.bn2(self.fc2(x)))
+        if not self.binary:
+            x = self.dropout(x)
+        x = self.fc3(x)
+        return x, trans_input, trans_feat
+
+
+class CrossEntropyMatrixRegularization(nn.Layer):
+    """Cross-entropy loss plus a scaled orthogonality penalty on ``trans_feat``."""
+
+    def __init__(self, mat_diff_loss_scale=1e-3):
+        super(CrossEntropyMatrixRegularization, self).__init__()
+        self.mat_diff_loss_scale = mat_diff_loss_scale
+
+    def forward(self, pred, target, trans_feat=None):
+        loss = F.cross_entropy(pred, target)
+
+        if trans_feat is None:
+            mat_diff_loss = 0
+        else:
+            mat_diff_loss = self.feature_transform_reguliarzer(trans_feat)
+
+        total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
+        return total_loss
+
+    def feature_transform_reguliarzer(self, trans):
+        """Return the batch mean of ``||trans @ trans^T - I||`` over axes (1, 2)."""
+        d = trans.shape[1]
+        I = paddle.eye(d)
+        loss = paddle.mean(
+            paddle.norm(
+                paddle.bmm(trans, paddle.transpose(trans, (0, 2, 1))) - I,
+                axis=(1, 2)))
+        return loss
diff --git a/example/BiPointNet/test.py b/example/BiPointNet/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9fb6bccde4c21468cbb4deb505647fe97ba0ec
--- /dev/null
+++ b/example/BiPointNet/test.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import paddle
+from paddle.io import DataLoader
+from paddle.metric import Accuracy
+from data import ModelNetDataset
+from model import PointNetClassifier
+
+
+def parse_args():
+    """Parse the command-line options of the evaluation script."""
+    parser = argparse.ArgumentParser("Test")
+    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
+    parser.add_argument(
+        "--num_point", type=int, default=1024, help="point number")
+    parser.add_argument(
+        "--num_workers", type=int, default=32, help="num wrokers")
+    parser.add_argument("--log_freq", type=int, default=10)
+    parser.add_argument(
+        "--model_path", type=str, default="./BiPointNet.pdparams")
+    parser.add_argument(
+        "--data_dir",
+        type=str,
+        default="./modelnet40_normal_resampled", )
+    parser.add_argument(
+        "--binary",
+        action='store_true',
+        help="whehter to build binary pointnet")
+    return parser.parse_args()
+
+
+def test(args):
+    """Load the weights at ``args.model_path`` and print test-split accuracy."""
+
+    test_data = ModelNetDataset(
+        args.data_dir, split="test", num_point=args.num_point)
+    test_loader = DataLoader(
+        test_data,
+        shuffle=False,
+        num_workers=args.num_workers,
+        batch_size=args.batch_size, )
+
+    model = PointNetClassifier(binary=args.binary)
+    if args.binary:
+        import basic
+        # These three layers are excluded from binarization.
+        fp_layers = [
+            id(model.feat.input_transfrom.conv1),
+            id(model.feat.conv1),
+            id(model.fc3)
+        ]
+        model = basic._to_bi_function(model, fp_layers=fp_layers)
+
+        def func(model):
+            # Mark the scale as initialized so the first forward pass keeps
+            # the value loaded from the checkpoint instead of re-estimating it.
+            if hasattr(model, "scale_weight_init"):
+                model.scale_weight_init = True
+
+        model.apply(func)
+
+    model_state_dict = paddle.load(args.model_path)
+    model.set_state_dict(model_state_dict)
+
+    metrics = Accuracy()
+    metrics.reset()
+    model.eval()
+    # Evaluation needs no gradients; skip building the autograd graph.
+    with paddle.no_grad():
+        for batch_id, data in enumerate(test_loader):
+            x, y = data
+            pred, _, _ = model(x)
+
+            correct = metrics.compute(pred, y)
+            metrics.update(correct)
+            if batch_id % args.log_freq == 0:
+                print("Eval iter:", batch_id)
+    test_acc = metrics.accumulate()
+    print("Test Accuracy: {}".format(test_acc))
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    print(args)
+    test(args)
diff --git
a/example/BiPointNet/train.py b/example/BiPointNet/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..60b0729a9c312c3a51ee34716a0f1952f80c6425
--- /dev/null
+++ b/example/BiPointNet/train.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import paddle
+from paddle.io import DataLoader
+from paddle.metric import Accuracy
+from paddle.optimizer import Adam
+from paddle.optimizer.lr import CosineAnnealingDecay
+from data import ModelNetDataset
+from model import CrossEntropyMatrixRegularization, PointNetClassifier
+
+
+def parse_args():
+    """Parse the command-line options of the training script."""
+    parser = argparse.ArgumentParser("Train")
+    parser.add_argument(
+        "--batch_size", type=int, default=32, help="batch size in training")
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-3,
+        help="learning rate in training")
+    parser.add_argument(
+        "--num_point", type=int, default=1024, help="point number")
+    parser.add_argument(
+        "--max_epochs", type=int, default=200, help="max epochs")
+    parser.add_argument(
+        "--num_workers", type=int, default=32, help="num wrokers")
+    parser.add_argument(
+        "--weight_decay", type=float, default=1e-4, help="weight decay")
+    parser.add_argument("--log_freq", type=int, default=50)
+    parser.add_argument(
+        "--pretrained",
+        type=str,
+        default='pointnet.pdparams',
+        help='pretrained model path')
+    parser.add_argument(
+        "--save_dir", type=str, default='./save_model', help='save model path')
+    parser.add_argument(
+        "--data_dir",
+        type=str,
+        default="./modelnet40_normal_resampled",
+        help='dataset dir')
+    parser.add_argument(
+        "--binary",
+        action='store_true',
+        help="whehter to build binary pointnet")
+    return parser.parse_args()
+
+
+def train(args):
+    """Train the classifier and save the weights with the best test accuracy.
+
+    After every epoch the model is evaluated on the test split; whenever the
+    accuracy improves, the state dict is written to
+    ``<save_dir>/best_model.pdparams``.
+    """
+    train_data = ModelNetDataset(
+        args.data_dir, split="train", num_point=args.num_point)
+    test_data = ModelNetDataset(
+        args.data_dir, split="test", num_point=args.num_point)
+    train_loader = DataLoader(
+        train_data,
+        shuffle=True,
+        num_workers=args.num_workers,
+        batch_size=args.batch_size, )
+    test_loader = DataLoader(
+        test_data,
+        shuffle=False,
+        num_workers=args.num_workers,
+        batch_size=args.batch_size, )
+
+    model = PointNetClassifier(binary=args.binary)
+    if args.binary:
+        import basic
+        # These three layers are excluded from binarization.
+        fp_layers = [
+            id(model.feat.input_transfrom.conv1),
+            id(model.feat.conv1),
+            id(model.fc3)
+        ]
+        model = basic._to_bi_function(model, fp_layers=fp_layers)
+        print(model)
+
+    # The default --pretrained path does not exist before a baseline has been
+    # trained, so a missing file means "start from random initialization"
+    # rather than a crash.
+    if os.path.isfile(args.pretrained):
+        model_state_dict = paddle.load(path=args.pretrained)
+        model.set_state_dict(model_state_dict)
+    else:
+        print("Pretrained weights '{}' not found, training from scratch.".
+              format(args.pretrained))
+
+    scheduler = CosineAnnealingDecay(
+        learning_rate=args.learning_rate,
+        T_max=args.max_epochs, )
+
+    optimizer = Adam(
+        learning_rate=scheduler,
+        parameters=model.parameters(),
+        weight_decay=args.weight_decay, )
+    loss_fn = CrossEntropyMatrixRegularization()
+    metrics = Accuracy()
+
+    best_test_acc = 0
+    for epoch in range(args.max_epochs):
+        metrics.reset()
+        model.train()
+        for batch_id, data in enumerate(train_loader):
+
+            x, y = data
+            pred, trans_input, trans_feat = model(x)
+
+            loss = loss_fn(pred, y, trans_feat)
+
+            correct = metrics.compute(pred, y)
+            metrics.update(correct)
+            loss.backward()
+            optimizer.step()
+            optimizer.clear_grad()
+
+            if (batch_id + 1) % args.log_freq == 0:
+                print("Epoch: {}, Batch ID: {}, Loss: {}, ACC: {}".format(
+                    epoch, batch_id + 1, loss.item(), metrics.accumulate()))
+
+        # The learning rate is stepped once per epoch.
+        scheduler.step()
+
+        metrics.reset()
+        model.eval()
+        # Evaluation needs no gradients; skip building the autograd graph.
+        with paddle.no_grad():
+            for batch_id, data in enumerate(test_loader):
+                x, y = data
+                pred, trans_input, trans_feat = model(x)
+
+                correct = metrics.compute(pred, y)
+                metrics.update(correct)
+        test_acc = metrics.accumulate()
+        print("Test epoch: {}, acc is: {}".format(epoch, test_acc))
+
+        if test_acc > best_test_acc:
+            best_test_acc = test_acc
+            save_path = os.path.join(args.save_dir, 'best_model.pdparams')
+            paddle.save(model.state_dict(), save_path)
+            print("Best Test ACC: {}, Model saved in {}".format(
+                test_acc, save_path))
+        else:
+            print("Current Best Test ACC: {}".format(best_test_acc))
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    print(args)
+    train(args)