diff --git a/.inscode b/.inscode index 3a001eb88671424594301483eb8665b81f6f4826..b338a72a4543ceb48c51c18f29b126bbeb45c899 100644 --- a/.inscode +++ b/.inscode @@ -1,16 +1,4 @@ -run = "pip install -r requirements.txt;python main.py" -language = "python" - -[packager] -AUTO_PIP = true - -[env] -VIRTUAL_ENV = "/root/${PROJECT_DIR}/venv" -PATH = "${VIRTUAL_ENV}/bin:${PATH}" -PYTHONPATH = "$PYTHONHOME/lib/python3.10:${VIRTUAL_ENV}/lib/python3.10/site-packages" -REPLIT_POETRY_PYPI_REPOSITORY = "http://mirrors.csdn.net.cn/repository/csdn-pypi-mirrors/simple" -MPLBACKEND = "TkAgg" -POETRY_CACHE_DIR = "/root/${PROJECT_DIR}/.cache/pypoetry" - -[debugger] -program = "main.py" +run = "cd deeppolar-main && python main.py --test --N 256 --K 37 --kernel_size 16 --test_snr_start -5 --test_snr_end 5 --snr_points 5" +is_gui = false +is_resident = true +is_html = false diff --git a/deeppolar-main/.gitignore b/deeppolar-main/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..68bc17f9ff2104a9d7b6777058bb4c343ca72609 --- /dev/null +++ b/deeppolar-main/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/deeppolar-main/LICENSE b/deeppolar-main/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..025c405a0ba13968400e956a7cd21c020a60f596 --- /dev/null +++ b/deeppolar-main/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Ashwin Hebbar + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_1.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_1.pt new file mode 100644 index 0000000000000000000000000000000000000000..c0ce41dfa608900fbac4f0d33be5ba16661ca554 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_1.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_10.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_10.pt new file mode 100644 index 0000000000000000000000000000000000000000..09523e07e1a7c9c35c6a52987bcdb7ee3d692ad5 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_10.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_11.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_11.pt new file mode 100644 index 0000000000000000000000000000000000000000..7edcc9b0a8c6506d4828d2b75275b3e132f27ab0 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_11.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_12.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_12.pt new file mode 100644 index 0000000000000000000000000000000000000000..60652dceb4a4b18f73190252d10cfef4f6c1fdc7 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_12.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_13.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_13.pt new file mode 100644 index 0000000000000000000000000000000000000000..fd415ec751dc324887188f81ee5766d5432ed3d8 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_13.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_14.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_14.pt new file mode 100644 index 0000000000000000000000000000000000000000..6a0db3800cd96fea807f6c0f964967088873836f Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_14.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_15.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_15.pt new file mode 100644 index 0000000000000000000000000000000000000000..837359e0dec98bfe8e00b8c82bbb1c05547fa7e7 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_15.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_16.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_16.pt new file mode 100644 index 0000000000000000000000000000000000000000..9f5d1b07fbb2536afb34b061385a19b544bac38f Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_16.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_2.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_2.pt new file mode 100644 index 0000000000000000000000000000000000000000..320cd8ab8347d701d9e80c07fa666d25439b4856 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_2.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_3.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_3.pt new file mode 100644 index 0000000000000000000000000000000000000000..34a8f7cf8395fbb0f36bcad2f3c8ffda206f3ad4 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_3.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_4.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_4.pt new file mode 100644 index 0000000000000000000000000000000000000000..9b9f6da03f2ec4678373c09d898a1d3ba1a7a2a2 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_4.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_5.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_5.pt new file mode 100644 index 0000000000000000000000000000000000000000..448901a02639ca73fec74a188774e324ba300c8a Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_5.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_6.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_6.pt new file mode 100644 index 0000000000000000000000000000000000000000..71b611a2f20b58ebb158c9b82cb7f651683f29ce Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_6.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_7.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_7.pt new file mode 100644 index 0000000000000000000000000000000000000000..1743a4efd9bb7863b7b429f6c7a744ecdeb9c574 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_7.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_8.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_8.pt new file mode 100644 index 0000000000000000000000000000000000000000..5d9431e59485caa7ad23ab164e1facb6cef2a945 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_8.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_9.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_9.pt new file mode 100644 index 0000000000000000000000000000000000000000..fcb595ed356bd38542599d82bb3c14ee3b006ab7 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/16_9.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/2_1.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/2_1.pt new file mode 100644 index 0000000000000000000000000000000000000000..74e7e9f4b9d1c9bb99c691194b937405e7b34049 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/2_1.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/2_2.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/2_2.pt new file mode 100644 index 0000000000000000000000000000000000000000..08d405e4e00e213d495314feed964ef584e997dc Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/2_2.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_1.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_1.pt new file mode 100644 index 0000000000000000000000000000000000000000..3b41e250c47b25701ea78e57ed3fad66cdf4f807 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_1.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_2.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_2.pt new file mode 100644 index 0000000000000000000000000000000000000000..75bd221c7f0d8515d5b1fbd92aa72224862841b0 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_2.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_3.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_3.pt new file mode 100644 index 0000000000000000000000000000000000000000..7475acb5ef308a78627f9f83b4402320d0bc3be4 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_3.pt differ diff --git a/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_4.pt b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_4.pt new file mode 100644 index 0000000000000000000000000000000000000000..581ca260dddb8dbfa4eb8a68f7a6d7f5a0fdea97 Binary files /dev/null and b/deeppolar-main/Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu/4_4.pt differ diff --git a/deeppolar-main/README.md b/deeppolar-main/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a38c93794ff87c1673f2df6acd613df2cce99a66 --- /dev/null +++ b/deeppolar-main/README.md @@ -0,0 +1,53 @@ +# DeepPolar codes +Code for "[DeepPolar: Inventing Nonlinear Large-Kernel Polar Codes via Deep Learning](https://arxiv.org/abs/2402.08864)", ICML 2024 + +## Installation + +First, clone the repository to your local machine: + +```bash +git clone https://github.com/hebbarashwin/deeppolar.git +cd deeppolar +``` + +Then, install the required Python packages: + +```bash +pip install -r requirements.txt +``` + +## Usage + +Best results are obtained by pretraining kernels using curriculum training and initializing the network using these pretrained kernels. (training from scratch may work too) +An exemplar kernel has been provided. Command to run: +(You can set --id for different runs.) + +```bash +python -u main.py --N 256 --K 37 -ell 16 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 2000 --enc_train_snr 0 --dec_train_snr -2 --enc_hidden_size 64 --dec_hidden_size 128 --enc_lr 0.0001 --dec_lr 0.0001 --weight_decay 0 --test_snr_start -5 --test_snr_end -1 --snr_points 5 --batch_size 20000 --id run1 --kernel_load_path Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu --gpu -2 +``` + +- `N`, `K`: Code parameters +- `-ell`, Kernel size; \sqrt{N} works best +- `kernel_load_path`: Path to load specific model kernels. (if training from scratch, don't set this flag) +- `enc_train_iters`, `dec_train_iters`: Number of training iterations for the encoder and decoder. +- `full_iters`: Total iterations for full training cycles. +- `id`: Identifier for the run. +- `model_save_per`: Frequency of saving the trained models. +- `gpu` : -2 : cuda, -1 : cpu, 0/1/2/3 : specific gpu + + +The kernels can be pretrained, for example by running +```bash +bash pretrain.sh +``` +(Typically we don't need to train each kernel for as many iterations as this script.) + +Testing + +```bash +python -u main.py --N 256 --K 37 -ell 16 --enc_hidden_size 64 --dec_hidden_size 128 --test_snr_start -5 --test_snr_end -1 --snr_points 5 --test_batch_size 10000 --id run1 --weight_decay 0. --num_errors 100 --test +``` + +(More details will be added soon.) +- Finetuning with increasingly large batch sizes improves high-SNR performance. +- BER gain can be traded off for BLER by finetuning with a BLER surrogate loss. \ No newline at end of file diff --git a/deeppolar-main/deeppolar.py b/deeppolar-main/deeppolar.py new file mode 100644 index 0000000000000000000000000000000000000000..3f94bee2b7a1d21386fc36c0ba077983a0098527 --- /dev/null +++ b/deeppolar-main/deeppolar.py @@ -0,0 +1,892 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from polar import PolarCode, get_frozen +from models import g_Full, f_Full, weights_init +from utils import min_sum_log_sum_exp, min_sum_log_sum_exp_4, countSetBits, log_sum_exp, STEQuantize +from collections import defaultdict +import os + + + +###��������ģ�ͼܹ��Ż��� +#������Ƚ���������ܹ�����Transformer��ResNetģ�飩 +#����ע�������������õز�׽����֮��Ĺ�ϵ +#ʵ�ָ�������ȿ����üܹ� +class TransformerLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.activation = nn.ReLU() + + def forward(self, src): + # src��״: (sequence_length, batch_size, d_model) + src2 = self.self_attn(src, src, src)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src +####ADD down + +class DeepPolar(PolarCode): + def __init__(self, args, device, N, K, ell = 2, infty = 1000., depth_map : defaultdict = None): + # ��ʼ��DeepPolar�� + # args: ����ģ�Ͳ��������ö��� + # device: �����豸����CPU��GPU�� + # N: Polar��ij��� + # K: ��Ϣλ������ + # ell: ÿ����ӿ��С + # infty: �����ֵ�����ڱ�ʾ���ɿ����ŵ� + # depth_map: ÿ����ӿ��С���ֵ� + args.use_transformer = True # �Ƿ�ʹ��Transformer + args.nhead = 4 # Transformerͷ�� + args.dim_feedforward = 2048 # FFNά�� + args.transformer_dropout = 0.1 # Dropout�� + + # ����Transformer���ʼ������ʼ��Transformer����޸� + self.transformer_layers = nn.ModuleDict() + if hasattr(args, 'use_transformer') and args.use_transformer: + for d in range(1, self.n_ell+1): + self.transformer_layers[str(d)] = TransformerLayer( + d_model=self.depth_map[d], + nhead=getattr(args, 'nhead', 4), # Ĭ��4��ͷ + dim_feedforward=getattr(args, 'dim_feedforward', 2048), + dropout=getattr(args, 'transformer_dropout', 0.1) + ).to(device) + # �������������ʼ���޸ģ���Ҫ��DeepPolar��__init__���������Ӷ�������ij�ʼ���� + if args.use_transformer: + self.transformer_layers = nn.ModuleDict() + for d in range(1, self.n_ell+1): + self.transformer_layers[str(d)] = TransformerLayer( + d_model=self.depth_map[d], + nhead=args.nhead, + dim_feedforward=args.dim_feedforward, + dropout=args.dropout_p + ) + + if args.use_attention: + self.attention_layers = nn.ModuleDict() + for d in range(1, self.n_ell+1): + self.attention_layers[str(d)] = nn.MultiheadAttention( + embed_dim=self.depth_map[d], + num_heads=args.nhead, + dropout=args.dropout_p + ) + # ��ȡ����λ����Ϣλ�� + Fr = get_frozen(N, K, self.args.rate_profile) + # ���ø���PolarCode�ij�ʼ������ + super().__init__(n = int(np.log2(N)), K = K, Fr=Fr, infty = infty) + self.N = N # Polar��ij��� + + # ����ṩ��depth_map����ʹ������ȷ��ÿһ����ӿ��С + if depth_map is not None: + # ���depth_map��ֵ�ij˻��Ƿ����N + assert np.prod(list(depth_map.values())) == N + # ���depth_map�ļ��Ƿ��1��ʼ������ij���� + assert min(list(depth_map.keys())) == 1 + assert max(list(depth_map.keys())) <= int(np.log2(N)) + self.ell = None # ��ʹ�ù̶���ellֵ + self.n_ell = len(depth_map.keys()) # ��Ȳ��� + assert max(list(depth_map.keys())) == self.n_ell # ��������Ƿ������Ȳ��� + self.depth_map = depth_map # �洢depth_map + else: + # ���û���ṩdepth_map����ʹ�ù̶���ellֵ + self.ell = ell + self.n_ell = int(np.log(N)/np.log(self.ell)) # ������Ȳ��� + self.depth_map = defaultdict(int) # ����һ��Ĭ���ֵ� + for d in range(1, self.n_ell+1): + self.depth_map[d] = self.ell # Ϊÿһ�������ӿ��С + assert np.prod(list(self.depth_map.values())) == N # ����ӿ��С�ij˻��Ƿ����N + + self.device = device # �洢�����豸 + self.fnet_dict = None # ���ڴ洢�������� + self.gnet_dict = None # ���ڴ洢�������� + self.infty = infty # �洢�����ֵ + + @staticmethod + def get_onehot(actions): + # ������ת��Ϊone-hot���� + # actions: �������� + # ����: one-hot��������� + + inds = (0.5 + 0.5*actions).long() # ������ת��Ϊ���� + if len(actions.shape) == 2: + # ������������Ƕ�ά�ģ�����one-hot���� + return torch.eye(2, device = inds.device)[inds].reshape(actions.shape[0], -1) + elif len(actions.shape) == 3: + # ���������������ά�ģ�����one-hot���� + return torch.eye(2, device = inds.device)[inds].reshape(actions.shape[0], actions.shape[1], -1) + + def define_kernel_nns(self, ell, unfrozen = None, fnet = 'KO', gnet = 'KO', shared = False): + # �������ͱ������� + # ell: ÿ����ӿ��С + # unfrozen: δ�����λλ�� + # fnet: ������������ + # gnet: ������������ + # shared: �Ƿ���������� + + if 'KO' in fnet: + self.fnet_dict = {} # ��ʼ�����������ֵ� + else: + self.fnet_dict = None # �����ʹ�ý������磬����ΪNone + self.shared = shared # �洢�Ƿ������� + + if 'KO' in gnet: + self.gnet_dict = {} # ��ʼ�����������ֵ� + else: + self.gnet_dict = None # �����ʹ�ñ������磬����ΪNone + + dec_hidden_size = self.args.dec_hidden_size # ������������ز��С + enc_hidden_size = self.args.enc_hidden_size # ������������ز��С + depth = 1 # ��ǰ��� + assert len(unfrozen) > 0, "No unfrozen bits!" # ����Ƿ���δ�����λ + + self.fnet_dict[depth] = {} # ��ʼ����ǰ��ȵĽ������� + if fnet == 'KO_parallel' or fnet == 'KO_last_parallel': + bit_position = 0 # ��ǰλλ�� + self.fnet_dict[depth][bit_position] = {} # ��ʼ����ǰλλ�õĽ������� + # �����С + input_size = ell + # �����С + output_size = ell + # ������������ + self.fnet_dict[depth][bit_position] = f_Full(input_size, dec_hidden_size, output_size, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device) + elif 'KO' in fnet: + if shared: + self.fnet_dict[depth] = {} # ��ʼ����ǰ��ȵĽ������� + for current_position in range(ell): + # �������������Ľ������� + self.fnet_dict[depth][current_position] = f_Full(ell + current_position, dec_hidden_size, 1, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device) + else: + bit_position = 0 # ��ǰλλ�� + for current_position in unfrozen: + if not self.fnet_dict[depth].get(bit_position): + self.fnet_dict[depth][bit_position] = {} # ��ʼ����ǰλλ�õĽ������� + input_size = ell + (int(self.args.onehot)+1)*current_position # �����С + # ������������ + self.fnet_dict[depth][bit_position][current_position] = f_Full(input_size, dec_hidden_size, 1, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device) + + if 'KO' in gnet: + self.gnet_dict[depth] = {} # ��ʼ����ǰ��ȵı������� + if shared: + if gnet == 'KO': + # �������������ı������� + self.gnet_dict[depth] = g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, activation = self.args.enc_activation, use_skip = self.args.skip).to(self.device) + else: + bit_position = 0 # ��ǰλλ�� + if gnet == 'KO': + # ������������ + self.gnet_dict[depth][bit_position] = g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, activation = self.args.enc_activation, use_skip = self.args.skip).to(self.device) + + def define_and_load_nns(self, ell, kernel_load_path = None, fnet = 'KO', gnet = 'KO', shared = True, dataparallel = False): + # ���岢���ؽ���ͱ������� + # ell: ÿ����ӿ��С + # kernel_load_path: �������������·�� + # fnet: ������������ + # gnet: ������������ + # shared: �Ƿ���������� + # dataparallel: �Ƿ�ʹ�����ݲ��� + + if 'KO' in fnet: + self.fnet_dict = {} # ��ʼ�����������ֵ� + else: + self.fnet_dict = None # �����ʹ�ý������磬����ΪNone + self.shared = shared # �洢�Ƿ������� + + if 'KO' in gnet: + self.gnet_dict = {} # ��ʼ�����������ֵ� + else: + self.gnet_dict = None # �����ʹ�ñ������磬����ΪNone + + dec_hidden_size = self.args.dec_hidden_size # ������������ز��С + enc_hidden_size = self.args.enc_hidden_size # ������������ز��С + + for depth in range(self.n_ell, 0, -1): + if depth in self.args.polar_depths: + continue # �����ǰ�����ָ��������б��У����� + ell = self.depth_map[depth] # ��ȡ��ǰ��ȵ��ӿ��С + proj_size = np.prod([self.depth_map[d] for d in range(1, depth+1)]) # ���㵱ǰ��ȵ�ͶӰ��С + + if fnet == 'KO_last_parallel' and depth == 1: + self.fnet_dict[depth] = {} # ��ʼ����ǰ��ȵĽ������� + for bit_position in range(self.N // proj_size): + proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size) # ���㵱ǰλλ�õ�ͶӰ + get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj]) # ����ͶӰ�е���Ϣλ���� + num_info_in_proj = get_num_info_proj(proj) # ��ȡͶӰ�е���Ϣλ���� + subproj_len = len(proj) // ell # ������ͶӰ�ij��� + subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)] # �ָ�ͶӰΪ��ͶӰ + num_info_in_subproj = [get_num_info_proj(x) for x in subproj] # ����ÿ����ͶӰ�е���Ϣλ���� + unfrozen = [i for i, x in enumerate(num_info_in_subproj) if x >= 1] # ��ȡδ�������ͶӰλ�� + + input_size = ell # �����С + output_size = ell # �����С + # ������������ + self.fnet_dict[depth][bit_position] = f_Full(input_size, dec_hidden_size, output_size, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device) + + if len(unfrozen) > 0: + if kernel_load_path is not None: + try: + # ���Լ���Ԥѵ���Ľ���������� + ckpt = torch.load(os.path.join(kernel_load_path + '_parallel', f'{ell}_{len(unfrozen)}.pt')) + ckpt_exists = True + except FileNotFoundError: + print(f"Parallel File not found for ell = {ell}, num_unfrozen = {len(unfrozen)}") + ckpt_exists = False + else: + ckpt_exists = False + if ckpt_exists: + # ���ؽ���������� + f_ckpt = ckpt[0][1][0].state_dict() + self.fnet_dict[depth][bit_position].load_state_dict(f_ckpt) + if dataparallel: + # ʹ�����ݲ��� + self.fnet_dict[depth][bit_position] = nn.DataParallel(self.fnet_dict[depth][bit_position]) + elif 'KO' in fnet: + self.fnet_dict[depth] = {} # ��ʼ����ǰ��ȵĽ������� + if shared: + self.fnet_dict[depth] = {} # ��ʼ����ǰ��ȵĽ������� + for current_position in range(ell): + # �������������Ľ������� + self.fnet_dict[depth][current_position] = f_Full(ell + current_position, dec_hidden_size, 1, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device) + if dataparallel: + # ʹ�����ݲ��� + self.fnet_dict[depth][current_position] = nn.DataParallel(self.fnet_dict[depth][current_position]) + else: + for bit_position in range(self.N // proj_size): + proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size) # ���㵱ǰλλ�õ�ͶӰ + get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj]) # ����ͶӰ�е���Ϣλ���� + num_info_in_proj = get_num_info_proj(proj) # ��ȡͶӰ�е���Ϣλ���� + subproj_len = len(proj) // ell # ������ͶӰ�ij��� + subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)] # �ָ�ͶӰΪ��ͶӰ + num_info_in_subproj = [get_num_info_proj(x) for x in subproj] # ����ÿ����ͶӰ�е���Ϣλ���� + unfrozen = [i for i, x in enumerate(num_info_in_subproj) if x >= 1] # ��ȡδ�������ͶӰλ�� + + if len(unfrozen) > 0: + if kernel_load_path is not None: + try: + # ���Լ���Ԥѵ���Ľ���������� + ckpt = torch.load(os.path.join(kernel_load_path, f'{ell}_{len(unfrozen)}.pt')) + ckpt_exists = True + except FileNotFoundError: + print(f"File not found for ell = {ell}, num_unfrozen = {len(unfrozen)}") + ckpt_exists = False + else: + ckpt_exists = False + for current_position in unfrozen: + if not self.fnet_dict[depth].get(bit_position): + self.fnet_dict[depth][bit_position] = {} # ��ʼ����ǰλλ�õĽ������� + input_size = ell + (int(self.args.onehot)+1)*current_position # �����С + output_size = 1 # �����С + # ������������ + self.fnet_dict[depth][bit_position][current_position] = f_Full(input_size, dec_hidden_size, output_size, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device) + if ckpt_exists: + try: + # ���ؽ���������� + f_ckpt = ckpt[0][1][0][current_position].state_dict() + except: + print(unfrozen) + self.fnet_dict[depth][bit_position][current_position].load_state_dict(f_ckpt) + if dataparallel: + # ʹ�����ݲ��� + self.fnet_dict[depth][bit_position][current_position] = nn.DataParallel(self.fnet_dict[depth][bit_position][current_position]) + + if 'KO' in gnet: + self.gnet_dict[depth] = {} # ��ʼ����ǰ��ȵı������� + if shared: + if gnet == 'KO': + if not dataparallel: + # �������������ı������� + self.gnet_dict[depth] = g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, use_skip = self.args.skip).to(self.device) + else: + # ʹ�����ݲ��д������������ı������� + self.gnet_dict[depth] = nn.DataParallel(g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, use_skip = self.args.skip)).to(self.device) + else: + for bit_position in range(self.N // proj_size): + proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size) # ���㵱ǰλλ�õ�ͶӰ + get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj]) # ����ͶӰ�е���Ϣλ���� + num_info_in_proj = get_num_info_proj(proj) # ��ȡͶӰ�е���Ϣλ���� + subproj_len = len(proj) // ell # ������ͶӰ�ij��� + subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)] # �ָ�ͶӰΪ��ͶӰ + num_info_in_subproj = [get_num_info_proj(x) for x in subproj] # ����ÿ����ͶӰ�е���Ϣλ���� + unfrozen = [i for i, x in enumerate(num_info_in_subproj) if x >= 1] + + if num_info_in_proj > 0: + if gnet == 'KO': + self.gnet_dict[depth][bit_position] = g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, activation = self.args.enc_activation, use_skip = self.args.skip).to(self.device) + if kernel_load_path is not None: + try: + ckpt = torch.load(os.path.join(kernel_load_path, f'{ell}_{len(unfrozen)}.pt')) + self.gnet_dict[depth][bit_position].load_state_dict(ckpt[1][1][0].state_dict()) + except FileNotFoundError: + print(f"File not found for ell = {ell}, num_unfrozen = {len(unfrozen)}") + pass + if dataparallel: + self.gnet_dict[depth][bit_position] = nn.DataParallel(self.gnet_dict[depth][bit_position]) + + + # print(f"g : {depth}, {bit_position}, {len(unfrozen)}") + if kernel_load_path is not None: + print("Loaded kernel from ", kernel_load_path) + + def load_nns(self, fnet_dict, gnet_dict = None, shared = False): + self.fnet_dict = fnet_dict + self.gnet_dict = gnet_dict + + for depth in fnet_dict.keys(): + if self.fnet_dict is not None: + for bit_position in self.fnet_dict[depth].keys(): + if not isinstance(self.fnet_dict[depth][bit_position], dict):#shared or self.args.decoder_type == 'KO_parallel' or self.args.decoder_type == 'KO_RNN': + self.fnet_dict[depth][bit_position].to(self.device) + else: + for current_position in self.fnet_dict[depth][bit_position].keys(): + self.fnet_dict[depth][bit_position][current_position].to(self.device) + if gnet_dict is not None: + if shared: + self.gnet_dict[depth].to(self.device) + else: + for bit_position in self.gnet_dict[depth].keys(): + self.gnet_dict[depth][bit_position].to(self.device) + print("NN weights loaded!") + + def load_partial_nns(self, fnet_dict, gnet_dict = None): + + for depth in fnet_dict.keys(): + if fnet_dict is not None: + for bit_position in fnet_dict[depth].keys(): + if isinstance(fnet_dict[depth][bit_position], dict): + for current_position in fnet_dict[depth][bit_position].keys(): + self.fnet_dict[depth][bit_position][current_position] = fnet_dict[depth][bit_position][current_position].to(self.device) + else: + self.fnet_dict[depth][bit_position] = fnet_dict[depth][bit_position].to(self.device) + + if gnet_dict is not None: + for bit_position in gnet_dict[depth].keys(): + self.gnet_dict[depth][bit_position] = gnet_dict[depth][bit_position].to(self.device) + print("NN weights loaded!") + + def kernel_encode(self, ell, gnet, msg_bits, info_positions, binary = False): + input_shape = msg_bits.shape[-1] + assert input_shape <= ell + u = torch.ones(msg_bits.shape[0], self.N, dtype=torch.float).to(self.device) + u[:, info_positions] = msg_bits + output =torch.cat([gnet(u.unsqueeze(1)).squeeze(1), u[:, -1:]], 1) + + power_constrained_u = self.power_constraint(output) + if binary: + stequantize = STEQuantize.apply + power_constrained_u = stequantize(power_constrained_u) + return power_constrained_u + +####�������Ľ��ı������ܹ����滻���е�kernel_encode������ +# ���������ǿ���Ľ��������Ĺ���Լ������ +#��ǿ�ı������������������ӵ�DeepPolar���У���Ϊ���б���������������򲹳䡣 + def enhanced_encode(self, msg_bits, binary=False): + u = torch.ones(msg_bits.shape[0], self.N, dtype=torch.float).to(self.device) + u[:, self.info_positions] = msg_bits + + # ...ǰ��Ĵ���... + + # Ӧ��Transformer�� + chunks = [u[:, i:i+proj_size//ell].unsqueeze(1) for i in + range(bit_position*proj_size, (bit_position+1)*proj_size, proj_size//ell)] + concatenated = torch.cat(chunks, 1) # shape: (batch_size, num_chunks, chunk_size) + + # ת��ΪTransformer��Ҫ����״������ + transformer_input = concatenated.permute(1, 0, 2) # (num_chunks, batch_size, chunk_size) + transformer_out = self.transformer_layers[d](transformer_input) + output = transformer_out.permute(1, 0, 2) # ת����(batch_size, num_chunks, chunk_size) + + # ...����Ĵ���... + + # ����λ�ñ��� + position = torch.arange(self.N).unsqueeze(0).to(self.device) + pos_enc = torch.zeros_like(u) + pos_enc[:, :] = position.float() + u = u + pos_enc + + # �ֲ㴦�� + for d in range(1, self.n_ell+1): + proj_size = np.prod([self.depth_map[dd] for dd in range(1, d+1)]) + ell = self.depth_map[d] + + for bit_position in range(self.N // proj_size): + proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size) + num_info = sum([int(x in self.info_positions) for x in proj]) + + if num_info > 0: + # ʹ��Transformer������Ϣλ + chunks = [u[:, i:i+proj_size//ell].unsqueeze(1) for i in + range(bit_position*proj_size, (bit_position+1)*proj_size, proj_size//ell)] + concatenated = torch.cat(chunks, 1) + + # Ӧ��Transformer�� + transformer_out = self.transformer_layers[d](concatenated.permute(1, 0, 2)) + output = transformer_out.permute(1, 0, 2).reshape(u.shape[0], -1) + + u = torch.cat((u[:, :bit_position*proj_size], output, + u[:, (bit_position+1)*proj_size:]), dim=1) + + # �Ľ��Ĺ���Լ�� + power_constrained_u = self.enhanced_power_constraint(u) + + if binary: + stequantize = STEQuantize.apply + power_constrained_u = stequantize(power_constrained_u) + + return power_constrained_u + + def enhanced_power_constraint(self, codewords): + # ��������Ӧ���ʵ��� + norm = torch.norm(codewords, p=2, dim=1, keepdim=True) + scale = torch.sqrt(self.N) / (norm + 1e-6) + return codewords * scale +####ADD Down + + +#�Ľ��Ľ������ܹ����滻���е�deeppolar_decode������ +#��ǿ�������IJ��д����������Ż���Ϣλ�����߼� +###��ǿ�Ľ�����������ͬ�����ӵ�DeepPolar���У����������н������������棬ͨ����������ʹ�����ֽ��뷽ʽ�� + def enhanced_decode(self, noisy_code): + assert noisy_code.shape[1] == self.N + + # ��ʼ�� + decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device=noisy_code.device) + partial_sums = torch.ones(noisy_code.shape[0], self.n_ell+1, self.N, device=noisy_code.device) + + # ��������Ԥ���� + noisy_code = self.noise_preprocess(noisy_code) + + # �ݹ���� + decoded_llrs, partial_sums = self.enhanced_decode_depth( + noisy_code.unsqueeze(2), self.n_ell, 0, decoded_llrs, partial_sums) + + return decoded_llrs[:, self.info_positions], torch.sign(decoded_llrs[:, self.info_positions]) + + def noise_preprocess(self, noisy_code): + # Ӧ��С�����������Ԥ�������� + if self.args.use_wavelet: + noisy_code = self.wavelet_denoise(noisy_code) + return noisy_code + + def enhanced_decode_depth(self, llrs, depth, bit_position, decoded_llrs, partial_sums): + half_index = np.prod([self.depth_map[d] for d in range(1, depth)]) if depth > 1 else 1 + ell = self.depth_map[depth] + left_bit_position = self.depth_map[depth] * bit_position + + # ��Ϣλ��� + proj_size = np.prod([self.depth_map[d] for d in range(1, depth+1)]) + proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size) + num_info = sum([int(x in self.info_positions) for x in proj]) + + # ʹ��ע����������ǿ���� + dec_chunks = [llrs[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)] + + if num_info > 0 and depth in self.fnet_dict: + # Ӧ��ע�������� + concatenated = torch.cat(dec_chunks, 2) + if self.args.use_attention: + concatenated = self.attention_layers[depth](concatenated) + + # ���н��� + if self.args.parallel_decode: + Lu = self.fnet_dict[depth][bit_position](concatenated) + decoded_llrs[:, left_bit_position + np.array(self.info_positions)] = Lu.squeeze(1) + else: + # ���н��� + for current_position in range(ell): + if current_position > 0: + prev_decoded = partial_sums[:, depth-1, (current_position-1)*half_index:current_position*half_index].unsqueeze(2) + dec_chunks.append(prev_decoded) + + if current_position in [p-left_bit_position for p in self.info_positions if left_bit_position <= p < left_bit_position+proj_size]: + concatenated = torch.cat(dec_chunks, 2) + Lu = self.fnet_dict[depth][bit_position][current_position](concatenated) + decoded_llrs[:, left_bit_position + current_position] = Lu.squeeze(2).squeeze(1) + + # �ݹ鴦�� + if depth > 1: + for current_position in range(ell): + bit_position_offset = left_bit_position + current_position + if bit_position_offset in self.info_positions: + decoded_llrs, partial_sums = self.enhanced_decode_depth( + llrs[:, current_position*half_index:(current_position+1)*half_index], + depth-1, bit_position_offset, decoded_llrs, partial_sums) + + return decoded_llrs, partial_sums +######ADD Down + + def deeppolar_encode(self, msg_bits, binary = False): + u = torch.ones(msg_bits.shape[0], self.N, dtype=torch.float).to(self.device) + u[:, self.info_positions] = msg_bits + for d in range(1, self.n_ell+1): + # num_bits = self.ell**(d-1) + num_bits = np.prod([self.depth_map[dd] for dd in range(1, d)]) if d > 1 else 1 + # proj_size = self.ell**(d) + proj_size = np.prod([self.depth_map[dd] for dd in range(1, d+1)]) + ell = self.depth_map[d] + for bit_position, i in enumerate(np.arange(0, self.N, ell*num_bits)): + + # [u v] encoded to [(u xor v),v)] + proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size) + get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj]) + num_info_in_proj = get_num_info_proj(proj) + + subproj_len = len(proj) // ell + subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)] + num_info_in_subproj = [get_num_info_proj(x) for x in subproj] + num_nonzero_subproj = sum([int(x != 0) for x in num_info_in_subproj]) + + if num_info_in_proj > 0: + info_bits_present = True + else: + info_bits_present = False + if d in self.args.polar_depths: + info_bits_present = False + + enc_chunks = [] + ell = self.depth_map[d] + for j in range(ell): + chunk = u[:, i + j*num_bits:i + (j+1)*num_bits].unsqueeze(2).clone() + enc_chunks.append(chunk) + if info_bits_present: + concatenated_chunks = torch.cat(enc_chunks, 2) + if self.shared: + output = torch.cat([self.gnet_dict[d](concatenated_chunks), u[:, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2) + else: + output = torch.cat([self.gnet_dict[d][bit_position](concatenated_chunks), u[:, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2) + output = output.permute(0,2,1).reshape(msg_bits.shape[0], -1, 1).squeeze(2) + + else: + output = self.encode_chunks_plotkin(enc_chunks, ell) + u = torch.cat((u[:, :i], output, u[:, i + ell*num_bits:]), dim=1) + + power_constrained_u = self.power_constraint(u) + if binary: + stequantize = STEQuantize.apply + power_constrained_u = stequantize(power_constrained_u) + return power_constrained_u + + def power_constraint(self, codewords): + return F.normalize(codewords, p=2, dim=1)*np.sqrt(self.N) + + def encode_chunks_plotkin(self, enc_chunks, ell = None): + + # message shape is (batch, k) + # BPSK convention : 0 -> +1, 1 -> -1 + # Therefore, xor(a, b) = a*b + + # to change for other kernels + + if ell is None: + ell = self.ell + assert len(enc_chunks) == ell + chunk_size = enc_chunks[0].shape[1] + batch_size = enc_chunks[0].shape[0] + + u = torch.cat(enc_chunks, 1).squeeze(2) + n = int(np.log2(ell)) + + for d in range(0, n): + num_bits = 2**d * chunk_size + for i in np.arange(0, chunk_size*ell, 2*num_bits): + # [u v] encoded to [(u,v) xor v] + u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1) + return u + + def deeppolar_parallel_decode(self, noisy_code): + # Successive cancellation decoder for polar codes + assert noisy_code.shape[1] == self.N + + depth = self.n_ell + + decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device = noisy_code.device) + # function is recursively called (DFS) + # arguments: Beliefs at the input of node (LLRs at top node), depth of children, bit_position (zero at top node) + decoded_llrs = self.KO_parallel_decode_depth(noisy_code.unsqueeze(2), depth, 0, decoded_llrs) + decoded_llrs = decoded_llrs[:, self.info_positions] + return decoded_llrs, torch.sign(decoded_llrs) + + def deeppolar_parallel_decode_depth(self, llrs, depth, bit_position, decoded_llrs): + # Function to call recursively, for SC decoder + + # half_index = self.ell ** (depth - 1) + half_index = np.prod([self.depth_map[d] for d in range(1, depth)]) if depth > 1 else 1 + ell = self.depth_map[depth] + left_bit_position = self.depth_map[depth] * bit_position + + # Check if >1 information bits are present in the current projection. If not, don't use NNs - use polar encoding and minsum SC decoding. + # proj_size = self.ell**(depth) + proj_size = np.prod([self.depth_map[d] for d in range(1, depth+1)]) + + proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size) + get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj]) + get_info_proj = lambda proj : [x for x in proj if x in self.info_positions] + + num_info_in_proj = get_num_info_proj(proj) + info_in_proj = get_info_proj(proj) + + subproj_len = len(proj) // ell + subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)] + num_info_in_subproj = [get_num_info_proj(x) for x in subproj] + num_nonzero_subproj = sum([int(x != 0) for x in num_info_in_subproj]) + unfrozen = np.array([i for i, x in enumerate(num_info_in_subproj) if x >= 1]) + + dec_chunks = torch.cat([llrs[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)], 2) + Lu = self.fnet_dict[depth][bit_position](dec_chunks) + + if depth == 1: + u = torch.tanh(Lu/2) + decoded_llrs[:, left_bit_position + unfrozen] = Lu.squeeze(1) + else: + for index, current_position in enumerate(unfrozen): + bit_position_offset = left_bit_position + current_position + decoded_llrs = self.deeppolar_parallel_decode_depth(Lu[:, :, index:index+1], depth-1, bit_position_offset, decoded_llrs) + + return decoded_llrs + + def deeppolar_decode(self, noisy_code): + assert noisy_code.shape[1] == self.N + + depth = self.n_ell + + decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device = noisy_code.device) + + # don't want to go into useless frozen subtrees. + partial_sums = torch.ones(noisy_code.shape[0], self.n_ell+1, self.N, device=noisy_code.device) + + # function is recursively called (DFS) + # arguments: Beliefs at the input of node (LLRs at top node), depth of children, bit_position (zero at top node) + + decoded_llrs, partial_sums = self.deeppolar_decode_depth(noisy_code.unsqueeze(2), depth, 0, decoded_llrs, partial_sums) + decoded_llrs = decoded_llrs[:, self.info_positions] + + return decoded_llrs, torch.sign(decoded_llrs) + + def deeppolar_decode_depth(self, llrs, depth, bit_position, decoded_llrs, partial_sums): + # Function to call recursively, for SC decoder + + # half_index = self.ell ** (depth - 1) + half_index = np.prod([self.depth_map[d] for d in range(1, depth)]) if depth > 1 else 1 + ell = self.depth_map[depth] + left_bit_position = self.depth_map[depth] * bit_position + + # Check if >1 information bits are present in the current projection. If not, don't use NNs - use polar encoding and minsum SC decoding. + # proj_size = self.ell**(depth) + # size of the projection of tht subtree + proj_size = np.prod([self.depth_map[d] for d in range(1, depth+1)]) + + # This chunk - finds infrozen positions in this kernel. + proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size) + get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj]) + get_info_proj = lambda proj : [x for x in proj if x in self.info_positions] + + num_info_in_proj = get_num_info_proj(proj) + info_in_proj = get_info_proj(proj) + + subproj_len = len(proj) // ell + subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)] + num_info_in_subproj = [get_num_info_proj(x) for x in subproj] + num_nonzero_subproj = sum([int(x != 0) for x in num_info_in_subproj]) + unfrozen = np.array([i for i, x in enumerate(num_info_in_subproj) if x >= 1]) + + if num_nonzero_subproj > 0: + info_bits_present = True + else: + info_bits_present = False + + if depth in self.args.polar_depths: + info_bits_present = False + + # This will be input to decoder + dec_chunks = [llrs[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)] + # n = 2 tree case + if depth == 1: + if self.args.decoder_type == 'KO_last_parallel': + concatenated_chunks = torch.cat(dec_chunks, 2) + Lu = self.fnet_dict[depth][bit_position](concatenated_chunks)[:, 0, unfrozen] + u_hat = torch.tanh(Lu/2) + decoded_llrs[:, left_bit_position + unfrozen] = Lu + partial_sums[:, depth-1, left_bit_position + unfrozen] = u_hat + + else: + for current_position in range(ell): + bit_position_offset = left_bit_position + current_position + if current_position > 0: + # I am adding previously decoded bits . (either onehot or normal) + if self.args.onehot: + prev_decoded = get_onehot(partial_sums[:, depth-1, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).sign()).detach().clone() + else: + prev_decoded = partial_sums[:, depth-1, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).clone() + dec_chunks.append(prev_decoded) + + if bit_position_offset in self.frozen_positions: # frozen + # don't update decoded llrs. It already has ones*prior. + # actually don't need this. can skip. + partial_sums[:, depth-1, bit_position_offset] = torch.ones_like(partial_sums[:, depth-1, bit_position_offset]) + else: # information bit + # This is the decoding. + concatenated_chunks = torch.cat(dec_chunks, 2) + if self.shared: + Lu = self.fnet_dict[depth][current_position](concatenated_chunks) + else: + Lu = self.fnet_dict[depth][bit_position][current_position](concatenated_chunks) + + u_hat = torch.tanh(Lu/2).squeeze(2) + decoded_llrs[:, bit_position_offset] = Lu.squeeze(2).squeeze(1) + partial_sums[:, depth-1, bit_position_offset] = u_hat.squeeze(1) + + # Encoding back the decoded bits - for higher layers. + # # Compute decoded codeword + i = left_bit_position * half_index + # num_bits = self.ell**(depth-1) + num_bits = 1 + + enc_chunks = [] + for j in range(ell): + chunk = torch.sign(partial_sums[:, depth-1, i + j*num_bits:i + (j+1)*num_bits]).unsqueeze(2).detach().clone() + enc_chunks.append(chunk) + if info_bits_present: + concatenated_chunks = torch.cat(enc_chunks, 2) + if 'KO' in self.args.encoder_type: + if self.shared: + output = torch.cat([self.gnet_dict[depth](concatenated_chunks), partial_sums[:, depth-1, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2) + else: + # bit position of the previous depth. + output = torch.cat([self.gnet_dict[depth][bit_position](concatenated_chunks), partial_sums[:, depth-1, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2) + output = output.permute(0,2,1).reshape(llrs.shape[0], -1, 1).squeeze(2) + else: + output = self.encode_chunks_plotkin(enc_chunks, ell) + else: + output = self.encode_chunks_plotkin(enc_chunks, ell) + partial_sums[:, depth, i : i + num_bits*ell] = output.clone() + + return decoded_llrs, partial_sums + + # General case + else: + for current_position in range(ell): + bit_position_offset = left_bit_position + current_position + + if current_position > 0: + if self.args.onehot: + prev_decoded = get_onehot(partial_sums[:, depth-1, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).sign()).detach().clone() + else: + prev_decoded = partial_sums[:, depth-1, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).clone() + dec_chunks.append(prev_decoded) + concatenated_chunks = torch.cat(dec_chunks, 2) + + if current_position in unfrozen: + # General decoding .... + # add the decoded bit here + if self.shared: + Lu = self.fnet_dict[depth][current_position](concatenated_chunks).squeeze(2) + else: + # if current_position == 0: + # Lu = self.fnet_dict[depth][bit_position][current_position](llrs) + # else: + Lu = self.fnet_dict[depth][bit_position][current_position](concatenated_chunks) + decoded_llrs, partial_sums = self.deeppolar_decode_depth(Lu, depth-1, bit_position_offset, decoded_llrs, partial_sums) + else: + Lu = self.infty*torch.ones_like(llrs) + + + # Compute decoded codeword + if depth < self.n_ell : + i = left_bit_position * half_index + # num_bits = self.ell**(depth-1) + num_bits = np.prod([self.depth_map[d] for d in range(1, depth)]) + enc_chunks = [] + for j in range(ell): + chunk = torch.sign(partial_sums[:, depth-1, i + j*num_bits:i + (j+1)*num_bits]).unsqueeze(2).detach().clone() + enc_chunks.append(chunk) + if info_bits_present: + concatenated_chunks = torch.cat(enc_chunks, 2) + if 'KO' in self.args.encoder_type: + if self.shared: + output = torch.cat([self.gnet_dict[depth](concatenated_chunks), partial_sums[:, depth-1, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2) + else: + # bit position of the previous depth. + output = torch.cat([self.gnet_dict[depth][bit_position](concatenated_chunks), partial_sums[:, depth-1, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2) + output = output.permute(0,2,1).reshape(llrs.shape[0], -1, 1).squeeze(2) + else: + output = self.encode_chunks_plotkin(enc_chunks, ell) + else: + output = self.encode_chunks_plotkin(enc_chunks, ell) + partial_sums[:, depth, i : i + num_bits*ell] = output.clone() + + return decoded_llrs, partial_sums + else: # encoding not required for last level - we have already decoded all bits. + return decoded_llrs, partial_sums + + + def kernel_decode(self, ell, fnet_dict, noisy_code, info_positions = None): + input_shape = noisy_code.shape[-1] + noisy_code = noisy_code.unsqueeze(2) + assert input_shape == ell + u = torch.ones(noisy_code.shape[0], self.N, dtype=torch.float).to(self.device) + decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device = noisy_code.device) + half_index = 1 + dec_chunks = [noisy_code[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)] + + for current_position in range(ell): + if current_position > 0: + if self.args.onehot: + prev_decoded = get_onehot(u[:, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).clone().sign()).detach().clone() + else: + prev_decoded = u[:, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).clone() + dec_chunks.append(prev_decoded) + if current_position in info_positions: + if current_position in info_positions: + concatenated_chunks = torch.cat(dec_chunks, 2) + Lu = fnet_dict[current_position](concatenated_chunks) + decoded_llrs[:, current_position] = Lu.squeeze(2).squeeze(1) + u_hat = torch.tanh(Lu/2).squeeze(2) + u[:, current_position] = u_hat.squeeze(1) + return decoded_llrs[:, info_positions], u[:, info_positions] + + def kernel_parallel_decode(self, ell, fnet_dict, noisy_code, info_positions = None): + input_shape = noisy_code.shape[-1] + noisy_code = noisy_code.unsqueeze(2) + assert input_shape == ell + u = torch.ones(noisy_code.shape[0], self.N, dtype=torch.float).to(self.device) + decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device = noisy_code.device) + half_index = 1 + dec_chunks = torch.cat([noisy_code[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)], 2) + + decoded_llrs = fnet_dict(dec_chunks).squeeze(1) + u = torch.tanh(decoded_llrs/2).squeeze(1) + return decoded_llrs[:, info_positions], u[:, info_positions] + + +# ѵ�������Ż����� +# �γ�ѧϰ�������ӱ��븴�ӶȺ�����ˮƽ +# ��Ͼ���ѵ����ʹ��torch.cuda.amp����ѵ�� +# �Ľ�����ʧ��������Ͻ����غ��Զ�����ʧ + def enhanced_loss_function(pred_llrs, true_bits, snr): + # ������������ʧ + ce_loss = F.binary_cross_entropy_with_logits(pred_llrs, (true_bits+1)/2) + + # �Զ���ɿ�����ʧ + reliability = torch.sigmoid(torch.abs(pred_llrs)) + rel_loss = -torch.mean(reliability) + + # SNR����Ӧ��Ȩ + snr_weight = 1.0 / (1.0 + torch.exp(-0.1*(snr-10))) + total_loss = ce_loss + snr_weight * rel_loss + + return total_loss +# ѵ�������Ż��� +# ʵ�ֿγ�ѧϰ���� +# ���Ӹ����ӵ����򻯼��� +# �Ľ���ʧ���� +##ADD down \ No newline at end of file diff --git a/deeppolar-main/figures/256_37_improved_bler.pdf b/deeppolar-main/figures/256_37_improved_bler.pdf new file mode 100644 index 0000000000000000000000000000000000000000..42438ddfe15c5835d6dbafd673544a982dcab9b4 --- /dev/null +++ b/deeppolar-main/figures/256_37_improved_bler.pdf @@ -0,0 +1 @@ +deeppolar-main/figures/256_37_improved_bler.pdf \ No newline at end of file diff --git a/deeppolar-main/figures/256_37_improved_highsnr.pdf b/deeppolar-main/figures/256_37_improved_highsnr.pdf new file mode 100644 index 0000000000000000000000000000000000000000..90e75ae0be2223269e48d7edcddd149be332458e Binary files /dev/null and b/deeppolar-main/figures/256_37_improved_highsnr.pdf differ diff --git a/deeppolar-main/figures/256_37_list_comparison.pdf b/deeppolar-main/figures/256_37_list_comparison.pdf new file mode 100644 index 0000000000000000000000000000000000000000..d1f65a4e5935d6cf3f120b321fc6eaafff30a4c5 Binary files /dev/null and b/deeppolar-main/figures/256_37_list_comparison.pdf differ diff --git a/deeppolar-main/figures/ablation/bers_256_37_decoder.pdf b/deeppolar-main/figures/ablation/bers_256_37_decoder.pdf new file mode 100644 index 0000000000000000000000000000000000000000..58153c19fc2565c4e59db40c3c9b40d94ba04411 Binary files /dev/null and b/deeppolar-main/figures/ablation/bers_256_37_decoder.pdf differ diff --git a/deeppolar-main/figures/ablation/bers_256_37_fcnn.pdf b/deeppolar-main/figures/ablation/bers_256_37_fcnn.pdf new file mode 100644 index 0000000000000000000000000000000000000000..8b2b7436bb5a7e75d67c076bfd1549ed60d3d64a Binary files /dev/null and b/deeppolar-main/figures/ablation/bers_256_37_fcnn.pdf differ diff --git a/deeppolar-main/figures/ablation/bers_256_37_hidden_size.pdf b/deeppolar-main/figures/ablation/bers_256_37_hidden_size.pdf new file mode 100644 index 0000000000000000000000000000000000000000..3625d75d6d0e55b62521c3615e64e4c7be0f2c17 Binary files /dev/null and b/deeppolar-main/figures/ablation/bers_256_37_hidden_size.pdf differ diff --git a/deeppolar-main/figures/ablation/blers_256_37_hidden_size.pdf b/deeppolar-main/figures/ablation/blers_256_37_hidden_size.pdf new file mode 100644 index 0000000000000000000000000000000000000000..b75630e1668d8dff6c4fe235563e1e2b009ba918 Binary files /dev/null and b/deeppolar-main/figures/ablation/blers_256_37_hidden_size.pdf differ diff --git a/deeppolar-main/figures/ablation/blers_256_37_training.pdf b/deeppolar-main/figures/ablation/blers_256_37_training.pdf new file mode 100644 index 0000000000000000000000000000000000000000..cc7f74a7306affbb6c5ffb7b275175bf21762eb5 Binary files /dev/null and b/deeppolar-main/figures/ablation/blers_256_37_training.pdf differ diff --git a/deeppolar-main/figures/binary/bers_256_37_list_comp_binary.pdf b/deeppolar-main/figures/binary/bers_256_37_list_comp_binary.pdf new file mode 100644 index 0000000000000000000000000000000000000000..f4cffd3746aeae7f6be6f676bab96b794e782667 Binary files /dev/null and b/deeppolar-main/figures/binary/bers_256_37_list_comp_binary.pdf differ diff --git a/deeppolar-main/figures/binary/blers_256_28_binary.pdf b/deeppolar-main/figures/binary/blers_256_28_binary.pdf new file mode 100644 index 0000000000000000000000000000000000000000..e4bba4a8e549acf16056b238bc92a00c82a7bb59 Binary files /dev/null and b/deeppolar-main/figures/binary/blers_256_28_binary.pdf differ diff --git a/deeppolar-main/figures/binary/blers_256_37_binary.pdf b/deeppolar-main/figures/binary/blers_256_37_binary.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a6f51c2e65e998f562882fbb15ab165395126693 Binary files /dev/null and b/deeppolar-main/figures/binary/blers_256_37_binary.pdf differ diff --git a/deeppolar-main/figures/binary/blers_256_64_binary.pdf b/deeppolar-main/figures/binary/blers_256_64_binary.pdf new file mode 100644 index 0000000000000000000000000000000000000000..097fa18699d0674d209950c9c18f954e989e0dd5 Binary files /dev/null and b/deeppolar-main/figures/binary/blers_256_64_binary.pdf differ diff --git a/deeppolar-main/figures/non-awgn/256_37_bursty_adaptivity_sigma10.pdf b/deeppolar-main/figures/non-awgn/256_37_bursty_adaptivity_sigma10.pdf new file mode 100644 index 0000000000000000000000000000000000000000..b9e63c909d4a158ca2fb286279c1d05841b36029 Binary files /dev/null and b/deeppolar-main/figures/non-awgn/256_37_bursty_adaptivity_sigma10.pdf differ diff --git a/deeppolar-main/figures/non-awgn/256_37_fading_adaptivity.pdf b/deeppolar-main/figures/non-awgn/256_37_fading_adaptivity.pdf new file mode 100644 index 0000000000000000000000000000000000000000..49d13ddcdbb1723c46ed76c2f50651c22d1dd772 Binary files /dev/null and b/deeppolar-main/figures/non-awgn/256_37_fading_adaptivity.pdf differ diff --git a/deeppolar-main/main.py b/deeppolar-main/main.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe23789a92dbf44e25fa5807835b5650c6a503f --- /dev/null +++ b/deeppolar-main/main.py @@ -0,0 +1,590 @@ +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data + +import os +import time +import matplotlib +matplotlib.use('AGG') +import matplotlib.pyplot as plt + +from deeppolar import DeepPolar +from polar import PolarCode, get_frozen +from trainer import train, deeppolar_full_test +from trainer_utils import save_model, plot_stuff +from collections import defaultdict +from itertools import combinations +from utils import snr_db2sigma, pairwise_distances + +import random +import numpy as np +from tqdm import tqdm +import sys +import csv + + +# set up matplotlib +is_ipython = 'inline' in matplotlib.get_backend() +if is_ipython: + from IPython import display + +plt.ion() + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def parse_int_list(input_string): + """Converts a comma-separated string into a list of integers.""" + try: + if len(input_string) == 0 : + return [] + return [int(item) for item in input_string.split(',')] + except ValueError: + raise argparse.ArgumentTypeError(f"List must contain integers, got '{input_string}'") + +# NN definition - define_kernel_nns for kernel, define_and_load_nns for general NN +# Encoding kernel_encode or deeppolar_encode +# Decoding kernel_decode, or deeppolar_decode + +def get_args(): + parser = argparse.ArgumentParser(description='DeepPolar codes') + + # General parameters + parser.add_argument('--id', type=str, default=None, help='ID: optional, to run multiple runs of same hyperparameters') #Will make a folder like init_932 , etc. + parser.add_argument('--test', dest = 'test', default=False, action='store_true', help='Testing?') + parser.add_argument('--pairwise', dest = 'pairwise', default=False, action='store_true', help='Plot codeword pairwise distances') + parser.add_argument('--epos', dest = 'epos', default=False, action='store_true', help='Plot error positions') + parser.add_argument('--only_args', dest = 'only_args', default=False, action='store_true', help='Helper to load functions on jupyter') + parser.add_argument('--gpu', type=int, default=-2, help='gpus used for training - e.g 0,1,3. -2 for cuda, -1 for cpu') + parser.add_argument('--seed', type=int, default=None, help='random seed') + parser.add_argument('--anomaly', dest = 'anomaly', default=False, action='store_true', help='enable anomaly detection') + parser.add_argument("--dataparallel", type=str2bool, nargs='?', + const=True, default=False, + help="Use Dataparallel") + + # Code parameters + parser.add_argument('--N', type=int, default=256, help = 'Block length')#, choices=[4, 8, 16, 32, 64, 128], help='Polar code parameter N') + parser.add_argument('--K', type=int, default=37, help = 'Message size')#, choices= [3, 4, 8, 16, 32, 64], help='Polar code parameter K') + parser.add_argument('--rate_profile', type=str, default='polar', choices=['RM', 'polar', 'sorted', 'last', 'rev_polar', 'custom'], help='Polar rate profiling') + # parser.add_argument('--target_K', type=int, default=None)#, choices= [3, 4, 8, 16, 32, 64], help='Polar code parameter K') + parser.add_argument('-ell', '--kernel_size', type=int, default=16, help = 'Kernel size') + parser.add_argument('--polar_depths', type=parse_int_list, default = '',help='A comma-separated list of integers') + parser.add_argument('--last_ell', type=int, default=None, help='use kernel last_ell last layer') + parser.add_argument('--infty', type=float, default=1000., help = 'Infinity value (used for frozen position LLR in polar dec)') + parser.add_argument('--lse', type=str, default='minsum', choices=['minsum', 'lse'], help='LSE function for polar SC decoder') + parser.add_argument('--hard_decision', dest = 'hard_decision', default=False, action='store_true', help='polar code sc decoding hard decision?') + + # DeepPolar parameters + parser.add_argument('--encoder_type', type=str, default='KO', choices=['KO', 'scaled', 'polar'], help='Type of encoding') + parser.add_argument('--decoder_type', type=str, default='KO', choices=['KO', 'SC', 'KO_parallel', 'KO_last_parallel'], help='Type of encoding') + parser.add_argument('--enc_activation', type=str, default='selu', choices=['selu', 'leaky_relu', 'gelu', 'silu', 'elu', 'mish', 'identity'], help='Activation function') + parser.add_argument('--dec_activation', type=str, default='selu', choices=['selu', 'leaky_relu', 'gelu', 'silu', 'elu', 'mish', 'identity'], help='Activation function') + parser.add_argument('--dropout_p', type=float, default=0.) + parser.add_argument('--dec_hidden_size', type=int, default=128, help='neural network size') + parser.add_argument('--enc_hidden_size', type=int, default=64, help='neural network size') + parser.add_argument('-fd', '--f_depth', type=int, default=3, help='decoder neural network depth') + parser.add_argument('-gd', '--g_depth', type=int, default=3, help='encoder neural network depth') + parser.add_argument('-gsd', '--g_skip_depth', type=int, default=1, help='encoder neural network depth') + parser.add_argument('-gsl', '--g_skip_layer', type=int, default=1, help='encoder neural network depth') + parser.add_argument("--onehot", type=str2bool, nargs='?', + const=True, default=False, + help="Use onehot representation of prev_decoded_bits?") + parser.add_argument("--shared", type=str2bool, nargs='?', + const=True, default=False, + help="Share weights across depth?") + parser.add_argument("--skip", type=str2bool, nargs='?', + const=True, default=True, + help="Use skip") + parser.add_argument("--use_norm", type=str2bool, nargs='?', + const=True, default=False, + help="Use norm") + parser.add_argument("--binary", type=str2bool, nargs='?', + const=True, default=False, + help="") + + # Training parameters + parser.add_argument('-fi', '--full_iters', type=int, default=20000, help='full iterations') + parser.add_argument('-ei', '--enc_train_iters', type=int, default=20, help='encoder iterations') #50 + parser.add_argument('-di', '--dec_train_iters', type=int, default=200, help='decoder iterations') #500 + + parser.add_argument('--enc_train_snr', type=float, default=0., help='snr at enc are trained') + parser.add_argument('--dec_train_snr', type=float, default=-2., help='snr at dec are trained') + + + parser.add_argument('--initialization', type=str, default='random', choices=['random', 'zeros'], help='initialization') + parser.add_argument('--optim', type=str, default='Adam', choices=['Adam', 'RMS', 'SGD', 'AdamW'], help='optimizer type') + parser.add_argument('--weight_decay', type=float, default=0.0) + parser.add_argument('--loss', type=str, default='BCE', choices=['MSE', 'BCE', 'BCE_reg', 'L1', 'huber', 'focal', 'BCE_bler'], help='loss function') + parser.add_argument('--dec_lr', type=float, default=0.0003, help='Decoder Learning rate') + parser.add_argument('--enc_lr', type=float, default=0.0003, help='Encoder Learning rate') + + parser.add_argument('--regularizer', type=str, default=None, choices=['std', 'max_deviation','polar'], help='regularize the kernel pretraining') + parser.add_argument('-rw', '--regularizer_weight', type=float, default=0.001) + + parser.add_argument('--scheduler', type=str, default=None, choices = ['reduce', '1cycle'],help='size of the batches') + parser.add_argument('--scheduler_patience', type=int, default=None, help='size of the batches') + + parser.add_argument('--small_batch_size', type=int, default=20000, help='size of the batches') + parser.add_argument('--batch_size', type=int, default=20000, help='size of the batches') + parser.add_argument("--batch_schedule", type=str2bool, nargs='?', + const=True, default=False, + help="Batch scheduler") + parser.add_argument('--batch_patience', type=int, default=50, help='patience') + parser.add_argument('--batch_factor', type=int, default=2, help='patience') + parser.add_argument('--min_batch_size', type=int, default=500, help='patience') + parser.add_argument('--max_batch_size', type=int, default=50000, help='patience') + + + parser.add_argument('--noise_type', type=str, default='awgn', choices=['fading', 'awgn', 'radar'], help='loss function') + parser.add_argument('--radar_power', type=float, default=None, help='snr at dec are trained') + parser.add_argument('--radar_prob', type=float, default=0.1, help='snr at dec are trained') + + + # TESTING parameters + parser.add_argument('--model_save_per', type=int, default=100, help='num of episodes after which model is saved') + parser.add_argument('--test_snr_start', type=float, default=-5., help='testing snr start') + parser.add_argument('--test_snr_end', type=float, default=-1., help='testing snr end') + parser.add_argument('--snr_points', type=int, default=5, help='testing snr num points') + parser.add_argument('--test_batch_size', type=int, default=10000, help='size of the batches') + parser.add_argument('--num_errors', type=int, default=100, help='Test until _ block errors') + parser.add_argument('--model_iters', type=int, default=None, help='by default load final model, option to load a model of x episodes') + parser.add_argument('--test_load_path', type=str, default=None, help='load test model given path') + parser.add_argument('--save_path', type=str, default=None, help='save name') + parser.add_argument('--load_path', type=str, default=None, help='load name') + parser.add_argument('--kernel_load_path', type=str, default=None, help='load name') + parser.add_argument("--no_fig", type=str2bool, nargs='?', + const=True, default=False, + help="Plot fig?") + + args = parser.parse_args() + + if args.small_batch_size > args.batch_size: + args.small_batch_size = args.batch_size + + return args + +if __name__ == '__main__': + + args = get_args() + if not args.test: + print(args) + + if args.anomaly: + torch.autograd.set_detect_anomaly(True) + + if torch.cuda.is_available() and args.gpu != -1: + if args.gpu == -2: + device = torch.device("cuda") + else: + device = torch.device("cuda:{0}".format(args.gpu)) + else: + if args.gpu != 1: + print(f"GPU device {args.gpu if args.gpu != -2 else ''} not found.") + device = torch.device("cpu") + + if args.seed is not None: + torch.manual_seed(args.seed) + + ID = str(np.random.randint(100000, 999999)) if args.id is None else args.id + + if args.save_path is not None: + results_save_path = args.save_path + else: + if args.encoder_type == 'polar': + results_save_path = './Polar_Results/Polar({0},{1})/Scheme_{2}/KO_Decoder/{3}'.format(args.K, args.N, args.rate_profile, ID) + elif 'KO' in args.encoder_type: + if args.decoder_type == 'KO_last_parallel': + dec = '_lp' + else: + dec = '' + results_save_path = f"./Polar_Results/Polar_{args.kernel_size}({args.N},{args.K})/Scheme_{args.rate_profile}/{args.encoder_type}_Encoder{dec}_Decoder/{ID}" + elif args.encoder_type == 'scaled': + if args.decoder_type == 'SC': + results_save_path = './Polar_Results/Polar({0},{1})/Scheme_{2}/Scaled_Decoder/{3}'.format(args.K, args.N, args.rate_profile, ID) + else: + results_save_path = './Polar_Results/Polar({0},{1})/Scheme_{2}/KO_Scaled_Decoder/{3}'.format(args.K, args.N, args.rate_profile, ID) + + + ############ + ## Polar Code parameters + ############ + K = args.K + N = args.N + + ############### + ### Polar code + ############## + + ### Encoder + + if args.last_ell is not None: + depth_map = defaultdict(int) + n = int(np.log2(args.N // args.last_ell) // np.log2(args.kernel_size)) + for d in range(1, n+1): + depth_map[d] = args.kernel_size + depth_map[n+1] = args.last_ell + assert np.prod(list(depth_map.values())) == args.N + polar = DeepPolar(args, device, args.N, args.K, infty = args.infty, depth_map = depth_map) + else: + polar = DeepPolar(args, device, args.N, args.K, args.kernel_size, args.infty) + + info_inds = polar.info_positions + frozen_inds = polar.frozen_positions + + print("Frozen positions : {}".format(frozen_inds)) + if args.only_args: + print("Loaded args. Exiting") + sys.exit() + ############## + ### Neural networks + ############## + ell = args.kernel_size + if args.N == ell: # Kernel pre-training + polar.define_kernel_nns(ell = args.kernel_size, unfrozen = polar.info_positions, fnet = args.decoder_type, gnet = args.encoder_type, shared = args.shared) + elif args.N > ell: # Initialize full network with pretrained kernels + polar.define_and_load_nns(ell = args.kernel_size, kernel_load_path=args.kernel_load_path, fnet = args.decoder_type, gnet = args.encoder_type, shared = args.shared, dataparallel=args.dataparallel) + + if args.binary: + args.load_path = os.path.join(results_save_path, 'Models/fnet_gnet_final.pt') + assert os.path.exists(args.load_path), "Model does not exist!!" + results_save_path = os.path.join(results_save_path, 'Binary') + os.makedirs(results_save_path, exist_ok=True) + os.makedirs(results_save_path +'/Models', exist_ok=True) + + if args.load_path is not None: + if args.test: + if args.test_load_path is None: + print("WARNING : have you used load_path instead of test_load_path?") + else: + checkpoint1 = torch.load(args.load_path , map_location=lambda storage, loc: storage) + fnet_dict = checkpoint1[0] + gnet_dict = checkpoint1[1] + + polar.load_partial_nns(fnet_dict, gnet_dict) + print("Loaded nets from {}".format(args.load_path)) + + if 'KO' in args.decoder_type: + dec_params = [] + for i in polar.fnet_dict.keys(): + for j in polar.fnet_dict[i].keys(): + if isinstance(polar.fnet_dict[i][j], dict): + for k in polar.fnet_dict[i][j].keys(): + dec_params += list(polar.fnet_dict[i][j][k].parameters()) + else: + dec_params += list(polar.fnet_dict[i][j].parameters()) + elif args.decoder_type == 'RNN': + dec_params = polar.fnet_dict.parameters() + else: + args.dec_train_iters = 0 + + if 'KO' in args.encoder_type: + enc_params = [] + if args.shared: + for i in polar.gnet_dict.keys(): + enc_params += list(polar.gnet_dict[i].parameters()) + else: + for i in polar.gnet_dict.keys(): + for j in polar.gnet_dict[i].keys(): + enc_params += list(polar.gnet_dict[i][j].parameters()) + elif args.encoder_type == 'scaled': + enc_params = [polar.a] + enc_optimizer = optim.Adam(enc_params, lr = args.enc_lr) + else: + args.enc_train_iters = 0 + + if args.dec_train_iters > 0: + if args.optim == 'Adam': + dec_optimizer = optim.Adam(dec_params, lr = args.dec_lr, weight_decay = args.weight_decay)#, momentum=0.9, nesterov=True) #, amsgrad=True) + elif args.optim == 'SGD': + dec_optimizer = optim.SGD(dec_params, lr = args.dec_lr, weight_decay = args.weight_decay)#, momentum=0.9, nesterov=True) #, amsgrad=True) + elif args.optim == 'RMS': + dec_optimizer = optim.RMSprop(dec_params, lr = args.dec_lr, weight_decay = args.weight_decay)#, momentum=0.9, nesterov=True) #, amsgrad=True) + if args.scheduler == 'reduce': + dec_scheduler = optim.lr_scheduler.ReduceLROnPlateau(dec_optimizer, 'min', patience = args.scheduler_patience) + elif args.scheduler == '1cycle': + dec_scheduler = optim.lr_scheduler.OneCycleLR(dec_optimizer, max_lr = args.dec_lr, total_steps=args.dec_train_iters*args.full_iters) + else: + dec_scheduler = None + + if args.enc_train_iters > 0: + enc_optimizer = optim.Adam(enc_params, lr = args.enc_lr)#, momentum=0.9, nesterov=True) #, amsgrad=True) + if args.scheduler == 'reduce': + enc_scheduler = optim.lr_scheduler.ReduceLROnPlateau(enc_optimizer, 'min', patience = args.scheduler_patience) + elif args.scheduler == '1cycle': + enc_scheduler = optim.lr_scheduler.OneCycleLR(enc_optimizer, max_lr = args.enc_lr, total_steps=args.enc_train_iters*args.full_iters) + else: + enc_scheduler = None + + + + + if 'BCE' in args.loss: + criterion = nn.BCEWithLogitsLoss() + elif args.loss == 'L1': + criterion = nn.L1Loss() + elif args.loss == 'huber': + criterion = nn.HuberLoss() + else: + criterion = nn.MSELoss() + + info_positions = polar.info_positions + if not args.test: + bers_enc = [] + losses_enc = [] + bers_dec = [] + losses_dec = [] + train_ber_dec = 0. + train_ber_enc = 0. + loss_dec = 0. + loss_enc = 0. + # val_bers = [] + os.makedirs(results_save_path, exist_ok=True) + os.makedirs(results_save_path +'/Models', exist_ok=True) + + # Create CSV at the beginning of training + save_path_id = random.randint(100000, 999999) + with open(os.path.join(results_save_path, f'training_results_{save_path_id}.csv'), 'w', newline='') as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow(['Step', 'Loss', 'BER']) + + # save args in a json file + + + ############## + ### Optimizers + ############## + + print("Need to save for:", args.model_save_per) + if not args.batch_schedule: + batch_size = args.batch_size + else: + batch_size = args.min_batch_size + best_batch_ber = 10. + best_batch_iter = 0 + try: + best_ber = 10. + for iter in range(1, args.full_iters + 1): + start_time = time.time() + + if not args.batch_schedule: + batch_size = args.batch_size + elif batch_size != args.max_batch_size: + if iter - best_batch_iter > args.batch_patience: + batch_size = min(batch_size * 2, args.max_batch_size) + print(f"Increased batch size to {batch_size}") + best_batch_ber = train_ber_enc + best_batch_iter = iter + if 'KO' in args.decoder_type or args.decoder_type == 'RNN': + # Train decoder + loss_dec, train_ber_dec = train(args, polar, dec_optimizer, dec_scheduler if args.scheduler == '1cycle' else None, batch_size, args.dec_train_snr, args.dec_train_iters, criterion, device, info_positions, binary = args.binary, noise_type = args.noise_type) + if args.scheduler_patience is not None: + dec_scheduler.step(loss_dec) + bers_dec.append(train_ber_dec) + losses_dec.append(loss_dec) + if 'KO' in args.encoder_type: + # Train encoder + loss_enc, train_ber_enc = train(args, polar, enc_optimizer, enc_scheduler if args.scheduler == '1cycle' else None, batch_size, args.enc_train_snr, args.enc_train_iters, criterion, device, info_positions, binary = args.binary, noise_type = args.noise_type) + if args.scheduler_patience is not None: + enc_scheduler.step(loss_enc) + bers_enc.append(train_ber_enc) + losses_enc.append(loss_enc) + + + if args.batch_schedule and train_ber_enc < best_batch_ber: + best_batch_ber = train_ber_enc + best_batch_iter = iter + print(f'Best BER {best_batch_ber} at {best_batch_iter}') + + # Save to CSV + with open(os.path.join(results_save_path, f'training_results_{save_path_id}.csv'), 'a', newline='') as csvfile: + csvwriter = csv.writer(csvfile) + csvwriter.writerow([iter, loss_enc, train_ber_enc, loss_dec, train_ber_dec]) + if iter % 10 == 1: + print(f"[{iter}/{args.full_iters}] At {args.dec_train_snr} dB, Train Loss: {loss_dec} Train BER {train_ber_dec}, \ + \n [{iter}/{args.full_iters}] At {args.enc_train_snr} dB, Train Loss: {loss_enc} Train BER {train_ber_enc}") + print("Time for one full iteration is {0:.4f} minutes. save ID = {1}".format((time.time() - start_time)/60, ID)) + if iter % args.model_save_per == 0 or iter == 1: + if train_ber_enc < best_ber: + best_ber = train_ber_enc + best = True + else: + best = False + save_model(polar, iter, results_save_path, best = best) + plot_stuff(bers_enc, losses_enc, bers_dec, losses_dec, results_save_path) + save_model(polar, iter, results_save_path) + plot_stuff(bers_enc, losses_enc, bers_dec, losses_dec, results_save_path) + + except KeyboardInterrupt: + + save_model(polar, iter, results_save_path) + plot_stuff(bers_enc, losses_enc, bers_dec, losses_dec, results_save_path) + + print("Exited and saved") + + + print("TESTING") + times = [] + results_load_path = results_save_path + + + if args.model_iters is not None: + checkpoint1 = torch.load(results_save_path +'/Models/fnet_gnet_{}.pt'.format(args.model_iters), map_location=lambda storage, loc: storage) + elif args.test_load_path is not None: + checkpoint1 = torch.load(args.test_load_path , map_location=lambda storage, loc: storage) + else: + checkpoint1 = torch.load(results_load_path +'/Models/fnet_gnet_final.pt', map_location=lambda storage, loc: storage) + + fnet_dict = checkpoint1[0] + gnet_dict = checkpoint1[1] + + polar.load_nns(fnet_dict, gnet_dict, shared = args.shared) + + if args.snr_points == 1 and args.test_snr_start == args.test_snr_end: + snr_range = [args.test_snr_start] + else: + snrs_interval = (args.test_snr_end - args.test_snr_start)* 1.0 / (args.snr_points-1) + snr_range = [snrs_interval* item + args.test_snr_start for item in range(args.snr_points)] + + start_time = time.time() + + # For polar code testing. + args2 = argparse.Namespace(**vars(args)) + args2.ell = 2 + Frozen = get_frozen(N, K, args2.rate_profile) + Frozen.sort() + polar_l_2 = PolarCode(int(np.log2(N)), args.K, Fr=Frozen, infty = args.infty, hard_decision=args.hard_decision) + + + if args.pairwise: + codebook_size = 1000 + all_msg_bits = 2 * (torch.rand(codebook_size, args.K, device = device) < 0.5).float() - 1 + deeppolar_codebook = polar.deeppolar_encode(all_msg_bits) + polar_codebook = polar_l_2.encode_plotkin(all_msg_bits) + gaussian_codebook = F.normalize(torch.randn(codebook_size, args.N), p=2, dim=1)*np.sqrt(args.N) + + from scipy import stats + w_statistic_deeppolar, p_value_deeppolar = stats.shapiro(deeppolar_codebook.detach().cpu().numpy()) + w_statistic_gaussian, p_value_gaussian = stats.shapiro(gaussian_codebook.detach().cpu().numpy()) + w_statistic_polar, p_value_polar = stats.shapiro(polar_codebook.detach().cpu().numpy()) + + print(f"Deeppolar Shapiro test W = {w_statistic_deeppolar}, p-value = {p_value_deeppolar}") + print(f"Gaussian Shapiro test W = {w_statistic_gaussian}, p-value = {p_value_gaussian}") + print(f"Polar Shapiro test W = {w_statistic_polar}, p-value = {p_value_polar}") + + dists_deeppolar, md_deeppolar = pairwise_distances(deeppolar_codebook) + dists_polar, md_polar = pairwise_distances(polar_codebook) + dists_gaussian, md_gaussian = pairwise_distances(gaussian_codebook) + + # Function to calculate and plot PDF + def plot_pdf(data, label, bins=30, alpha=0.5): + counts, bin_edges = np.histogram(data, bins=bins, density=True) + bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 + plt.plot(bin_centers, counts, label=label, alpha=alpha) + + # Plotting PDF for each list + plt.figure() + plot_pdf(dists_deeppolar, 'Neural', 300) + # plot_pdf(dists_polar, 'Polar', 300) + plot_pdf(dists_gaussian, 'Gaussian', 300) + + # Adding labels and title + plt.xlabel('Value') + plt.ylabel('Probability Density') + plt.title(f'Pairwise Distances - N = {args.N}, K = {args.K}') + plt.legend() + + # Show the plot + plt.savefig(os.path.join(results_save_path, f"hists_N{args.N}_K{args.K}_{args.id}_2.pdf")) + + if args.epos: + from collections import OrderedDict, Counter + + def get_epos(k1, k2): + # return counter for bit ocations of first-errors + bb = torch.ne(k1.cpu().sign(), k2.cpu().sign()) + # inds = torch.nonzero(bb)[:, 1].numpy() + idx = [] + for ii in range(bb.shape[0]): + try: + iii = list(bb.cpu().float().numpy()[ii]).index(1) + idx.append(iii) + except: + pass + counter = Counter(idx) + ordered_counter = OrderedDict(sorted(counter.items())) + return ordered_counter + + with torch.no_grad(): + for (k, msg_bits) in enumerate(Test_Data_Generator): + msg_bits = msg_bits.to(device) + polar_code = polar_l_2.encode_plotkin(msg_bits) + noisy_code = polar.channel(polar_code, args.dec_train_snr) + noise = noisy_code - polar_code + deeppolar_code = polar.deeppolar_encode(msg_bits) + noisy_deeppolar_code = deeppolar_code + noise + SC_llrs, decoded_SC_msg_bits = polar_l_2.sc_decode_new(noisy_code, args.dec_train_snr) + deeppolar_llrs, decoded_deeppolar_msg_bits = polar.deeppolar_decode(noisy_deeppolar_code) + + if k == 0: + epos_deeppolar = get_epos(msg_bits, decoded_deeppolar_msg_bits.sign()) + epos_SC = get_epos(msg_bits, decoded_SC_msg_bits.sign()) + else: + epos_deeppolar1 = get_epos(msg_bits, decoded_deeppolar_msg_bits.sign()) + epos_SC1 = get_epos(msg_bits, decoded_SC_msg_bits.sign()) + epos_deeppolar = epos_deeppolar + epos_deeppolar1 + epos_SC = epos_SC + epos_SC1 + + print(f"epos_deeppolar: {epos_deeppolar}") + print(f"EPOS_SC: {epos_SC}") + + + + start = time.time() + bers_SC_test, blers_SC_test, bers_deeppolar_test, blers_deeppolar_test = deeppolar_full_test(args, polar_l_2, polar, snr_range, device, info_positions, binary = args.binary, noise_type = args.noise_type, num_errors = args.num_errors) + print("Test SNRs : {}\n".format(snr_range)) + print(f"Test Sigmas : {[snr_db2sigma(s) for s in snr_range]}\n") + print("BERs of DeepPolar: {0}".format(bers_deeppolar_test)) + print("BERs of SC decoding: {0}".format(bers_SC_test)) + print("BLERs of DeepPolar: {0}".format(blers_deeppolar_test)) + print("BLERs of SC decoding: {0}".format(blers_SC_test)) + print(f"time = {(time.time() - start)/60} minutes") + ## BER + plt.figure(figsize = (12,8)) + + ok = 0 + plt.semilogy(snr_range, bers_deeppolar_test, label="DeepPolar", marker='*', linewidth=1.5) + + plt.semilogy(snr_range, bers_SC_test, label="SC decoder", marker='^', linewidth=1.5) + + ## BLER + plt.semilogy(snr_range, blers_deeppolar_test, label="DeepPolar (BLER)", marker='*', linewidth=1.5, linestyle='dashed') + + plt.semilogy(snr_range, blers_SC_test, label="SC decoder (BLER)", marker='^', linewidth=1.5, linestyle='dashed') + + plt.grid() + plt.xlabel("SNR (dB)", fontsize=16) + plt.ylabel("Error Rate", fontsize=16) + if args.enc_train_iters > 0: + plt.title("PolarC({2}, {3}): DeepPolar trained at Dec_SNR = {0} dB, Enc_SNR = {1}dB".format(args.dec_train_snr, args.enc_train_snr, args.K,args.N)) + else: + plt.title("Polar({1}, {2}): DeepPolar trained at Dec_SNR = {0} dB".format(args.dec_train_snr, args.K,args.N)) + plt.legend(prop={'size': 15}) + if args.test_load_path is not None: + os.makedirs('Polar_Results/figures', exist_ok=True) + fig_save_path = 'Polar_Results/figures/new_plot_DeepPolar.pdf' + else: + fig_save_path = results_load_path + f"/Step_{args.model_iters if args.model_iters is not None else 'final'}{'_binary' if args.binary else ''}.pdf" + if not args.no_fig: + plt.savefig(fig_save_path) + + plt.close() diff --git a/deeppolar-main/models.py b/deeppolar-main/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1f34bcbb8944258e4e44255531bfc0c602155645 --- /dev/null +++ b/deeppolar-main/models.py @@ -0,0 +1,326 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import time +from math import sqrt + +# Add TT Linear layer implementation +class TTLinear(nn.Module): + def __init__(self, in_features, out_features, ranks, activation=None, bias=True): + super(TTLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.ranks = ranks + self.activation = activation + + # Factorize input and output dimensions + self.in_shape, self.out_shape = self.factorize_dims(in_features, out_features) + + # Create TT cores + self.cores = nn.ParameterList() + for i in range(len(self.in_shape)): + if i == 0: + input_rank = 1 + else: + input_rank = ranks[i-1] + + if i == len(self.in_shape)-1: + output_rank = 1 + else: + output_rank = ranks[i] + + core = nn.Parameter(torch.randn( + input_rank, self.in_shape[i], self.out_shape[i], output_rank + ) * sqrt(2.0 / (input_rank * self.in_shape[i]))) + self.cores.append(core) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter('bias', None) + + def factorize_dims(self, in_dim, out_dim): + # Simple factorization - can be customized + factors = [] + d = in_dim + while d > 1: + for i in range(int(sqrt(d)), 0, -1): + if d % i == 0: + factors.append(i) + factors.append(d // i) + break + d = factors.pop() + in_shape = factors + + factors = [] + d = out_dim + while d > 1: + for i in range(int(sqrt(d)), 0, -1): + if d % i == 0: + factors.append(i) + factors.append(d // i) + break + d = factors.pop() + out_shape = factors + + return in_shape, out_shape + + def forward(self, x): + batch_size = x.size(0) + + # Reshape input to tensor form + x = x.view(batch_size, *self.in_shape) + + # Contract input with first core + res = torch.einsum('bi,rivo->bvo', x, self.cores[0]) + + # Contract with remaining cores + for i in range(1, len(self.cores)): + res = torch.einsum('bvi,rivo->bvo', res, self.cores[i]) + + # Reshape to output dimension + res = res.contiguous().view(batch_size, self.out_features) + + if self.bias is not None: + res = res + self.bias + + if self.activation is not None: + res = self.activation(res) + + return res + + +# Modify get_activation_fn to include TTLinear option +def get_activation_fn(activation): + if activation == 'tanh': + return F.tanh + elif activation == 'elu': + return F.elu + elif activation == 'relu': + return F.relu + elif activation == 'selu': + return F.selu + elif activation == 'sigmoid': + return F.sigmoid + elif activation == 'gelu': + return F.gelu + elif activation == 'silu': + return F.silu + elif activation == 'mish': + return F.mish + elif activation == 'linear': + return nn.Identity() + else: + raise NotImplementedError(f'Activation function {activation} not implemented') + +# Modify g_Full to use TTLinear +class g_Full(nn.Module): + def __init__(self, input_size, hidden_size, output_size, depth=3, skip_depth=1, + skip_layer=1, ell=2, activation='selu', use_skip=False, augment=False, + use_tt=False, tt_ranks=None): + super(g_Full, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.depth = depth + self.ell = ell + self.ell_input_size = input_size//self.ell + self.augment = augment + self.activation_fn = get_activation_fn(activation) + self.skip_depth = skip_depth + self.skip_layer = skip_layer + self.use_skip = use_skip + self.use_tt = use_tt + self.tt_ranks = tt_ranks if tt_ranks is not None else [8, 8] # Default ranks + + if self.use_skip: + if self.use_tt: + self.skip = nn.ModuleList([TTLinear(self.input_size + self.output_size, + self.hidden_size, + ranks=self.tt_ranks, + activation=self.activation_fn)]) + self.skip.extend([TTLinear(self.hidden_size, self.hidden_size, + ranks=self.tt_ranks, + activation=self.activation_fn) + for ii in range(1, self.skip_depth)]) + else: + self.skip = nn.ModuleList([nn.Linear(self.input_size + self.output_size, + self.hidden_size, bias=True)]) + self.skip.extend([nn.Linear(self.hidden_size, self.hidden_size, bias=True) + for ii in range(1, self.skip_depth)]) + + if self.use_tt: + self.linears = nn.ModuleList([TTLinear(self.input_size, self.hidden_size, + ranks=self.tt_ranks, + activation=self.activation_fn)]) + self.linears.extend([TTLinear(self.hidden_size, self.hidden_size, + ranks=self.tt_ranks, + activation=self.activation_fn) + for ii in range(1, self.depth)]) + self.linears.append(TTLinear(self.hidden_size, self.output_size, + ranks=self.tt_ranks)) + else: + self.linears = nn.ModuleList([nn.Linear(self.input_size, self.hidden_size, bias=True)]) + self.linears.extend([nn.Linear(self.hidden_size, self.hidden_size, bias=True) + for ii in range(1, self.depth)]) + self.linears.append(nn.Linear(self.hidden_size, self.output_size, bias=True)) + + def __init__(self, input_size, hidden_size, output_size, depth=3, skip_depth = 1, skip_layer = 1, ell = 2, activation = 'selu', use_skip = False, augment = False): + super(g_Full, self).__init__() + + self.input_size = input_size + + + self.hidden_size = hidden_size + self.output_size = output_size + + self.depth = depth + self.ell = ell + self.ell_input_size = input_size//self.ell + + self.augment = augment + self.activation_fn = get_activation_fn(activation) + self.skip_depth = skip_depth + self.skip_layer = skip_layer + + self.use_skip = use_skip + if self.use_skip: + self.skip = nn.ModuleList([nn.Linear(self.input_size + self.output_size, self.hidden_size, bias=True)]) + self.skip.extend([nn.Linear(self.hidden_size, self.hidden_size, bias=True) for ii in range(1, self.skip_depth)]) + + self.linears = nn.ModuleList([nn.Linear(self.input_size, self.hidden_size, bias=True)]) + self.linears.extend([nn.Linear(self.hidden_size, self.hidden_size, bias=True) for ii in range(1, self.depth)]) + self.linears.append(nn.Linear(self.hidden_size, self.output_size, bias=True)) + + + @staticmethod + def get_augment(msg, ell): + u = msg.clone() + n = int(np.log2(ell)) + for d in range(0, n): + num_bits = 2**d + for i in np.arange(0, ell, 2*num_bits): + # [u v] encoded to [u xor(u,v)] + if len(u.shape) == 2: + u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1) + elif len(u.shape) == 3: + u = torch.cat((u[:, :, :i], u[:, :, i:i+num_bits].clone() * u[:, :, i+num_bits: i+2*num_bits], u[:, :, i+num_bits:]), dim=2) + + # u[:, i:i+num_bits] = u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits].clone + if len(u.shape) == 3: + return u[:, :, :-1] + elif len(u.shape) == 2: + return u[:, :-1] + + def forward(self, y): + + x = y.clone() + for ii, layer in enumerate(self.linears): + if ii != self.depth: + x = self.activation_fn(layer(x)) + if self.use_skip and ii == self.skip_layer: + if len(x.shape) == 3: + skip_input = torch.cat([y, g_Full.get_augment(y, self.ell)], dim = 2) + elif len(x.shape) == 2: + skip_input = torch.cat([y, g_Full.get_augment(y, self.ell)], dim = 1) + for jj, skip_layer in enumerate(self.skip): + skip_input = self.activation_fn(skip_layer(skip_input)) + x = x + skip_input + else: + x = layer(x) + if self.augment: + x = x + g_Full.get_augment(y, self.ell) + return x + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.01) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(0.0, 0.01) + m.bias.data.fill_(0) + elif classname.find('Linear') != -1: + m.weight.data.normal_(0.0, 0.01) + try: + m.bias.data.fill_(0.) + except: + pass + +# Modify f_Full to use TTLinear +class f_Full(nn.Module): + def __init__(self, input_size, hidden_size, output_size, dropout_p=0., + activation='selu', depth=3, use_norm=False, use_tt=False, tt_ranks=None): + super(f_Full, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.depth = depth + self.use_norm = use_norm + self.activation_fn = get_activation_fn(activation) + self.use_tt = use_tt + self.tt_ranks = tt_ranks if tt_ranks is not None else [8, 8] # Default ranks + + if self.use_tt: + self.linears = nn.ModuleList([TTLinear(self.input_size, self.hidden_size, + ranks=self.tt_ranks, + activation=self.activation_fn)]) + for ii in range(1, self.depth): + self.linears.append(TTLinear(self.hidden_size, self.hidden_size, + ranks=self.tt_ranks, + activation=self.activation_fn)) + self.linears.append(TTLinear(self.hidden_size, self.output_size, + ranks=self.tt_ranks)) + else: + self.linears = nn.ModuleList([nn.Linear(self.input_size, self.hidden_size, bias=True)]) + for ii in range(1, self.depth): + self.linears.append(nn.Linear(self.hidden_size, self.hidden_size, bias=True)) + self.linears.append(nn.Linear(self.hidden_size, self.output_size, bias=True)) + + if self.use_norm: + self.norms = nn.ModuleList([nn.LayerNorm(self.hidden_size) + for _ in range(self.depth)]) + def __init__(self, input_size, hidden_size, output_size, dropout_p = 0., activation = 'selu', depth=3, use_norm = False): + super(f_Full, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.depth = depth + self.use_norm = use_norm + + self.activation_fn = get_activation_fn(activation) + + self.linears = nn.ModuleList([nn.Linear(self.input_size, self.hidden_size, bias=True)]) + if self.use_norm: + self.norms = nn.ModuleList([nn.LayerNorm(self.hidden_size)]) + for ii in range(1, self.depth): + self.linears.append(nn.Linear(self.hidden_size, self.hidden_size, bias=True)) + if self.use_norm: + self.norms.append(nn.LayerNorm(self.hidden_size)) + self.linears.append(nn.Linear(self.hidden_size, self.output_size, bias=True)) + + def forward(self, y, aug = None): + + x = y.clone() + for ii, layer in enumerate(self.linears): + if ii != self.depth: + x = layer(x) + if not hasattr(self, 'use_norm') or not self.use_norm: + pass + else: + x = self.norms[ii](x) + x = self.activation_fn(x) + else: + x = layer(x) + return x + +def get_onehot(actions): + inds = (0.5 + 0.5*actions).long() + return torch.eye(2, device = inds.device)[inds].reshape(actions.shape[0], -1) diff --git a/deeppolar-main/polar.py b/deeppolar-main/polar.py new file mode 100644 index 0000000000000000000000000000000000000000..0eaa32c2f4dcbcd18f1bff87a4cf76e7f58c3ab8 --- /dev/null +++ b/deeppolar-main/polar.py @@ -0,0 +1,547 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +import matplotlib +matplotlib.use('AGG') +import matplotlib.pyplot as plt +import pickle +import os +import argparse +import sys +import time +from collections import namedtuple + +from utils import snr_db2sigma, min_sum_log_sum_exp, log_sum_exp, errors_ber, errors_bler, corrupt_signal, countSetBits + + +def get_args(): + parser = argparse.ArgumentParser(description='(N,K) Polar code') + + parser.add_argument('--N', type=int, default=4, help='Polar code parameter N') + parser.add_argument('--K', type=int, default=3, help='Polar code parameter K') + parser.add_argument('--rate_profile', type=str, default='polar', choices=['RM', 'polar', 'sorted', 'sorted_last', 'rev_polar'], help='Polar rate profiling') + parser.add_argument('--hard_decision', dest = 'hard_decision', default=False, action='store_true') + parser.add_argument('--only_args', dest = 'only_args', default=False, action='store_true') + parser.add_argument('--list_size', type=int, default=1, help='SC List size') + parser.add_argument('--crc_len', type=int, default='0', choices=[0, 3, 8, 16], help='CRC length') + + parser.add_argument('--batch_size', type=int, default=10000, help='size of the batches') + parser.add_argument('--test_ratio', type = float, default = 1, help = 'Number of test samples x batch_size') + parser.add_argument('--test_snr_start', type=float, default=-2., help='testing snr start') + parser.add_argument('--test_snr_end', type=float, default=4., help='testing snr end') + parser.add_argument('--snr_points', type=int, default=7, help='testing snr num points') + args = parser.parse_args() + + return args + +class PolarCode: + + def __init__(self, n, K, Fr = None, rs = None, use_cuda = True, infty = 1000., hard_decision = False, lse = 'lse'): + + assert n>=1 + self.n = n + self.N = 2**n + self.K = K + self.G2 = np.array([[1,1],[0,1]]) + self.G = np.array([1]) + for i in range(n): + self.G = np.kron(self.G, self.G2) + self.G = torch.from_numpy(self.G).float() + self.device = torch.device("cuda" if use_cuda else "cpu") + self.infty = infty + self.hard_decision = hard_decision + self.lse = lse + + if Fr is not None: + assert len(Fr) == self.N - self.K + self.frozen_positions = Fr + self.unsorted_frozen_positions = self.frozen_positions + self.frozen_positions.sort() + + self.info_positions = np.array(list(set(self.frozen_positions) ^ set(np.arange(self.N)))) + self.unsorted_info_positions = self.info_positions + self.info_positions.sort() + + else: + if rs is None: + # in increasing order of reliability + self.reliability_seq = np.arange(1023, -1, -1) + self.rs = self.reliability_seq[self.reliability_seq +1, 1 -> -1 + # Therefore, xor(a, b) = a*b + if custom_info_positions is not None: + info_positions = custom_info_positions + else: + info_positions = self.info_positions + u = torch.ones(message.shape[0], self.N, dtype=torch.float).to(message.device) + u[:, info_positions] = message + + for d in range(0, self.n): + num_bits = 2**d + for i in np.arange(0, self.N, 2*num_bits): + # [u v] encoded to [u xor(u,v)] + u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1) + # u[:, i:i+num_bits] = u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits].clone + if scaling is not None: + u = (scaling * np.sqrt(self.N)*u)/torch.norm(scaling) + return u + + def channel(self, code, snr, noise_type = 'awgn', vv =5.0, radar_power = 20.0, radar_prob = 5e-2): + if noise_type != "bsc": + sigma = snr_db2sigma(snr) + else: + sigma = snr + + r = corrupt_signal(code, sigma, noise_type, vv, radar_power, radar_prob) + + return r + + def define_partial_arrays(self, llrs): + # Initialize arrays to store llrs and partial_sums useful to compute the partial successive cancellation process. + llr_array = torch.zeros(llrs.shape[0], self.n+1, self.N, device=llrs.device) + llr_array[:, self.n] = llrs + partial_sums = torch.zeros(llrs.shape[0], self.n+1, self.N, device=llrs.device) + return llr_array, partial_sums + + + def updateLLR(self, leaf_position, llrs, partial_llrs = None, prior = None): + + #START + depth = self.n + decoded_bits = partial_llrs[:,0].clone() + if prior is None: + prior = torch.zeros(self.N) #priors + llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, partial_llrs, depth, 0, leaf_position, prior, decoded_bits) + return llrs, decoded_bits + + + def partial_decode(self, llrs, partial_llrs, depth, bit_position, leaf_position, prior, decoded_bits=None): + # Function to call recursively, for partial SC decoder. + # We are assuming that u_0, u_1, .... , u_{leaf_position -1} bits are known. + # Partial sums computes the sums got through Plotkin encoding operations of known bits, to avoid recomputation. + # this function is implemented for rate 1 (not accounting for frozen bits in polar SC decoding) + + # print("DEPTH = {}, bit_position = {}".format(depth, bit_position)) + half_index = 2 ** (depth - 1) + leaf_position_at_depth = leaf_position // 2**(depth-1) # will tell us whether left_child or right_child + + # n = 2 tree case + if depth == 1: + # Left child + left_bit_position = 2*bit_position + if leaf_position_at_depth > left_bit_position: + u_hat = partial_llrs[:, depth-1, left_bit_position:left_bit_position+1] + elif leaf_position_at_depth == left_bit_position: + if self.lse == 'minsum': + Lu = min_sum_log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True) + elif self.lse == 'lse': + Lu = log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True) + # Lu = log_sum_avoid_zero_NaN(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True) + #print(Lu.device, prior.device, torch.ones_like(Lu).device) + llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] = Lu + prior[left_bit_position]*torch.ones_like(Lu) + if self.hard_decision: + u_hat = torch.sign(Lu) + else: + u_hat = torch.tanh(Lu/2) + + decoded_bits[:, left_bit_position] = u_hat.squeeze(1) + + return llrs, partial_llrs, decoded_bits + + # Right child + right_bit_position = 2*bit_position + 1 + if leaf_position_at_depth > right_bit_position: + pass + elif leaf_position_at_depth == right_bit_position: + Lv = u_hat * llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index] + llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index] + llrs[:, depth-1, right_bit_position*half_index:(right_bit_position+1)*half_index] = Lv + prior[right_bit_position] * torch.ones_like(Lv) + if self.hard_decision: + v_hat = torch.sign(Lv) + else: + v_hat = torch.tanh(Lv/2) + decoded_bits[:, right_bit_position] = v_hat.squeeze(1) + return llrs, partial_llrs, decoded_bits + + # General case + else: + # LEFT CHILD + # Find likelihood of (u xor v) xor (v) = u + # Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1)) + + left_bit_position = 2*bit_position + if leaf_position_at_depth > left_bit_position: + Lu = llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] + u_hat = partial_llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] + else: + if self.lse == 'minsum': + Lu = min_sum_log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]) + elif self.lse == 'lse': + # Lu = log_sum_avoid_zero_NaN(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]) + Lu = log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]) + + llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] = Lu + llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, partial_llrs, depth-1, left_bit_position, leaf_position, prior, decoded_bits) + + return llrs, partial_llrs, decoded_bits + + # RIGHT CHILD + right_bit_position = 2*bit_position + 1 + + Lv = u_hat * llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index] + llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index] + llrs[:, depth-1, right_bit_position*half_index:(right_bit_position+1)*half_index] = Lv + llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, partial_llrs, depth-1, right_bit_position, leaf_position, prior, decoded_bits) + + return llrs, partial_llrs, decoded_bits + + def updatePartialSums(self, leaf_position, decoded_bits, partial_llrs): + + u = decoded_bits.clone() + u[:, leaf_position+1:] = 0 + + for d in range(0, self.n): + partial_llrs[:, d] = u + num_bits = 2**d + for i in np.arange(0, self.N, 2*num_bits): + # [u v] encoded to [u xor(u,v)] + u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1) + partial_llrs[:, self.n] = u + return partial_llrs + + def sc_decode_new(self, corrupted_codewords, snr, use_gt = None, channel = 'awgn'): + + assert channel in ['awgn', 'bsc'] + + if channel == 'awgn': + noise_sigma = snr_db2sigma(snr) + llrs = (2/noise_sigma**2)*corrupted_codewords + elif channel == 'bsc': + # snr refers to transition prob + p = (torch.ones(1)*(snr + 1e-9)).to(corrupted_codewords.device) + llrs = (torch.clip(torch.log((1 - p) / p), -10000, 10000) * (corrupted_codewords + 1) - torch.clip(torch.log(p / (1-p)), -10000, 10000) * (corrupted_codewords - 1))/2 + + # step-wise implementation using updateLLR and updatePartialSums + + priors = torch.zeros(self.N) + priors[self.frozen_positions] = self.infty + + u_hat = torch.zeros(corrupted_codewords.shape[0], self.N, device=corrupted_codewords.device) + llr_array, partial_llrs = self.define_partial_arrays(llrs) + for ii in range(self.N): + #start = time.time() + llr_array , decoded_bits = self.updateLLR(ii, llr_array.clone(), partial_llrs, priors) + #print('SC update : {}'.format(time.time() - start), corrupted_codewords.shape[0]) + if use_gt is None: + u_hat[:, ii] = torch.sign(llr_array[:, 0, ii]) + else: + u_hat[:, ii] = use_gt[:, ii] + #start = time.time() + partial_llrs = self.updatePartialSums(ii, u_hat, partial_llrs) + #print('SC partial: {}s, {}', time.time() - start, 'frozen' if ii in self.frozen_positions else 'info') + decoded_bits = u_hat[:, self.info_positions] + return llr_array[:, 0, :].clone(), decoded_bits + + def get_CRC(self, message): + + # need to optimize. + # inout message should be int + + padded_bits = torch.cat([message, torch.zeros(self.CRC_len).int().to(message.device)]) + while len(padded_bits[0:self.K_minus_CRC].nonzero()): + cur_shift = (padded_bits != 0).int().argmax(0) + padded_bits[cur_shift: cur_shift + self.CRC_len + 1] = padded_bits[cur_shift: cur_shift + self.CRC_len + 1] ^ self.CRC_polynomials[self.CRC_len].to(message.device) + + return padded_bits[self.K_minus_CRC:] + + def CRC_check(self, message): + + # need to optimize. + # input message should be int + + padded_bits = message + while len(padded_bits[0:self.K_minus_CRC].nonzero()): + cur_shift = (padded_bits != 0).int().argmax(0) + padded_bits[cur_shift: cur_shift + polar.CRC_len + 1] ^= self.CRC_polynomials[self.CRC_len].to(message.device) + + if padded_bits[self.K_minus_CRC:].sum()>0: + return 0 + else: + return 1 + + def encode_with_crc(self, message, CRC_len): + self.CRC_len = CRC_len + self.K_minus_CRC = self.K - CRC_len + + if CRC_len == 0: + return self.encode_plotkin(message) + else: + crcs = 1-2*torch.vstack([self.get_CRC((0.5+0.5*message[jj]).int()) for jj in range(message.shape[0])]) + encoded = self.encode_plotkin(torch.cat([message, crcs], 1)) + + return encoded + +def get_frozen(N, K, rate_profile, target_K = None): + n = int(np.log2(N)) + if rate_profile == 'polar': + # computed for SNR = 0 + if n == 5: + rs = np.array([31, 30, 29, 27, 23, 15, 28, 26, 25, 22, 21, 14, 19, 13, 11, 24, 7, 20, 18, 12, 17, 10, 9, 6, 5, 3, 16, 8, 4, 2, 1, 0]) + + elif n == 4: + rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 5, 3, 8, 4, 2, 1, 0]) + + # for RM :( + # rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 3, 5, 8, 4, 2, 1, 0]) + + elif n == 3: + rs = np.array([7, 6, 5, 3, 4, 2, 1, 0]) + elif n == 2: + rs = np.array([3, 2, 1, 0]) + elif n<9: + rs = np.array([256 ,255 ,252 ,254 ,248 ,224 ,240 ,192 ,128 ,253 ,244 ,251 ,250 ,239 ,238 ,247 ,246 ,223 ,222 ,232 ,216 ,236 ,220 ,188 ,208 ,184 ,191 ,190 ,176 ,127 ,126 ,124 ,120 ,249 ,245 ,243 ,242 ,160 ,231 ,230 ,237 ,235 ,234 ,112 ,228 ,221 ,219 ,218 ,212 ,215 ,214 ,189 ,187 ,96 ,186 ,207 ,206 ,183 ,182 ,204 ,180 ,200 ,64 ,175 ,174 ,172 ,125 ,123 ,122 ,119 ,159 ,118 ,158 ,168 ,241 ,116 ,111 ,233 ,156 ,110 ,229 ,227 ,217 ,108 ,213 ,152 ,226 ,95 ,211 ,94 ,205 ,185 ,104 ,210 ,203 ,181 ,92 ,144 ,202 ,179 ,199 ,173 ,178 ,63 ,198 ,121 ,171 ,88 ,62 ,117 ,170 ,196 ,157 ,167 ,60 ,115 ,155 ,109 ,166 ,80 ,114 ,154 ,107 ,56 ,225 ,151 ,164 ,106 ,93 ,150 ,209 ,103 ,91 ,143 ,201 ,102 ,48 ,148 ,177 ,90 ,142 ,197 ,87 ,100 ,61 ,169 ,195 ,140 ,86 ,59 ,32 ,165 ,194 ,113 ,79 ,58 ,153 ,84 ,136 ,55 ,163 ,78 ,105 ,149 ,162 ,54 ,76 ,101 ,47 ,147 ,89 ,52 ,141 ,99 ,46 ,146 ,72 ,85 ,139 ,98 ,31 ,44 ,193 ,138 ,57 ,83 ,30 ,135 ,77 ,40 ,82 ,134 ,161 ,28 ,53 ,75 ,132 ,24 ,51 ,74 ,45 ,145 ,71 ,50 ,16 ,97 ,70 ,43 ,137 ,68 ,42 ,29 ,39 ,81 ,27 ,133 ,38 ,26 ,36 ,131 ,23 ,73 ,22 ,130 ,49 ,15 ,20 ,69 ,14 ,12 ,67 ,41 ,8 ,66 ,37 ,25 ,35 ,34 ,21 ,129 ,19 ,13 ,18 ,11 ,10 ,7 ,65 ,6 ,4 ,33 ,17 ,9 ,5 ,3 ,2 ,1 ]) - 1 + else: + rs = np.array([1023, 1022, 1021, 1019, 1015, 1007, 1020, 991, 1018, 1017, 1014, + 1006, 895, 1013, 1011, 959, 1005, 990, 1003, 989, 767, 1016, + 999, 1012, 987, 958, 983, 957, 1010, 1004, 955, 1009, 894, + 975, 893, 1002, 951, 1001, 988, 511, 766, 998, 891, 943, + 986, 997, 985, 887, 956, 765, 995, 927, 982, 981, 879, + 954, 974, 763, 953, 979, 510, 1008, 759, 863, 950, 892, + 1000, 973, 949, 509, 890, 971, 996, 942, 751, 984, 889, + 507, 947, 831, 886, 967, 941, 764, 926, 980, 994, 939, + 885, 993, 735, 878, 925, 503, 762, 883, 978, 935, 703, + 495, 952, 877, 761, 972, 923, 977, 948, 758, 862, 875, + 919, 970, 757, 861, 508, 969, 750, 946, 479, 888, 639, + 871, 911, 830, 940, 859, 755, 966, 945, 749, 506, 884, + 938, 965, 829, 734, 924, 855, 505, 747, 963, 937, 882, + 934, 827, 733, 447, 992, 847, 876, 501, 921, 702, 494, + 881, 760, 743, 933, 502, 918, 874, 922, 823, 731, 499, + 860, 756, 931, 701, 873, 493, 727, 917, 870, 976, 815, + 910, 383, 968, 478, 858, 754, 699, 491, 869, 944, 748, + 638, 915, 477, 719, 909, 964, 255, 799, 504, 857, 854, + 753, 828, 746, 695, 487, 907, 637, 867, 853, 475, 936, + 962, 446, 732, 826, 745, 846, 500, 825, 903, 687, 932, + 635, 471, 445, 742, 880, 498, 730, 851, 822, 382, 920, + 845, 741, 443, 700, 729, 631, 492, 872, 961, 726, 821, + 930, 497, 381, 843, 463, 916, 739, 671, 623, 490, 929, + 439, 814, 819, 868, 752, 914, 698, 725, 839, 856, 476, + 813, 718, 908, 486, 723, 866, 489, 607, 431, 697, 379, + 811, 798, 913, 575, 717, 254, 694, 636, 474, 807, 715, + 906, 797, 693, 865, 960, 852, 744, 634, 473, 795, 905, + 485, 415, 483, 470, 444, 375, 850, 740, 686, 902, 824, + 691, 253, 711, 633, 844, 685, 630, 901, 367, 791, 928, + 728, 820, 849, 783, 670, 899, 738, 842, 683, 247, 469, + 441, 442, 462, 251, 737, 438, 467, 351, 629, 841, 724, + 679, 669, 496, 461, 818, 380, 437, 627, 622, 459, 378, + 239, 488, 667, 838, 430, 484, 812, 621, 319, 817, 435, + 377, 696, 722, 912, 606, 810, 864, 716, 837, 721, 714, + 809, 796, 455, 472, 619, 835, 692, 663, 223, 414, 904, + 427, 806, 482, 632, 713, 690, 848, 605, 373, 252, 794, + 429, 710, 684, 615, 805, 900, 655, 468, 366, 603, 413, + 574, 481, 371, 250, 793, 466, 423, 374, 689, 628, 440, + 365, 709, 789, 803, 411, 573, 682, 249, 460, 790, 668, + 599, 350, 707, 246, 681, 465, 571, 626, 436, 407, 782, + 191, 127, 363, 620, 666, 458, 245, 349, 677, 434, 678, + 591, 787, 399, 457, 359, 238, 625, 840, 567, 736, 665, + 428, 376, 781, 898, 618, 675, 318, 454, 662, 243, 897, + 347, 836, 816, 720, 433, 604, 617, 779, 808, 661, 834, + 712, 804, 833, 559, 237, 453, 426, 222, 317, 775, 372, + 343, 412, 235, 543, 614, 451, 425, 422, 613, 370, 221, + 315, 480, 335, 659, 654, 364, 190, 369, 248, 653, 688, + 231, 410, 602, 611, 802, 792, 421, 651, 601, 598, 708, + 311, 219, 572, 597, 788, 570, 409, 590, 362, 801, 680, + 464, 406, 419, 348, 647, 786, 215, 589, 706, 361, 676, + 566, 189, 595, 244, 569, 303, 405, 358, 456, 346, 398, + 565, 242, 126, 705, 780, 587, 624, 664, 236, 187, 357, + 432, 785, 558, 674, 207, 403, 397, 452, 345, 563, 778, + 241, 316, 342, 616, 660, 557, 125, 234, 183, 287, 355, + 583, 673, 395, 424, 314, 220, 777, 341, 612, 658, 123, + 175, 774, 555, 233, 334, 542, 450, 313, 391, 230, 652, + 368, 218, 339, 600, 119, 333, 657, 610, 773, 541, 310, + 420, 159, 229, 650, 551, 596, 609, 408, 217, 449, 188, + 309, 214, 331, 111, 539, 360, 771, 649, 302, 418, 594, + 896, 227, 404, 646, 186, 588, 832, 568, 213, 417, 301, + 307, 356, 402, 800, 564, 327, 95, 206, 240, 535, 593, + 645, 586, 344, 396, 185, 401, 211, 354, 299, 585, 286, + 562, 643, 182, 205, 124, 232, 285, 295, 181, 556, 582, + 527, 394, 340, 63, 203, 561, 353, 448, 122, 283, 393, + 581, 554, 174, 390, 704, 312, 338, 228, 179, 784, 199, + 553, 121, 173, 389, 540, 579, 332, 118, 672, 550, 337, + 158, 279, 271, 416, 216, 308, 387, 538, 549, 226, 330, + 776, 171, 212, 117, 110, 329, 656, 157, 772, 306, 326, + 225, 167, 115, 537, 534, 184, 109, 300, 547, 305, 210, + 155, 533, 325, 352, 608, 400, 298, 204, 94, 648, 284, + 209, 151, 180, 107, 770, 297, 392, 323, 592, 202, 644, + 93, 294, 178, 103, 143, 282, 62, 336, 201, 120, 172, + 198, 769, 584, 91, 388, 293, 177, 526, 278, 281, 642, + 525, 531, 61, 170, 116, 197, 87, 156, 277, 114, 560, + 169, 59, 291, 580, 275, 523, 641, 270, 195, 552, 519, + 166, 224, 578, 108, 269, 79, 154, 113, 548, 577, 536, + 328, 55, 106, 165, 153, 150, 386, 208, 324, 546, 385, + 267, 47, 92, 163, 296, 304, 105, 102, 149, 263, 532, + 322, 292, 545, 90, 200, 31, 321, 530, 142, 176, 147, + 101, 141, 196, 524, 529, 290, 89, 280, 60, 86, 99, + 139, 168, 58, 522, 276, 85, 194, 289, 78, 135, 112, + 521, 57, 83, 54, 518, 274, 268, 768, 164, 77, 152, + 193, 53, 162, 104, 517, 273, 266, 75, 46, 148, 51, + 640, 100, 45, 576, 161, 265, 262, 71, 146, 30, 140, + 88, 515, 98, 43, 29, 261, 145, 138, 84, 259, 39, + 97, 27, 56, 82, 137, 76, 384, 134, 23, 52, 133, + 320, 15, 73, 50, 81, 131, 44, 70, 544, 192, 528, + 288, 520, 160, 272, 74, 49, 516, 42, 69, 28, 144, + 41, 67, 96, 514, 38, 264, 260, 136, 22, 25, 37, + 80, 513, 26, 258, 35, 132, 21, 257, 72, 14, 48, + 13, 19, 130, 68, 40, 11, 512, 66, 129, 7, 36, + 24, 34, 256, 20, 65, 33, 12, 128, 18, 10, 17, + 6, 9, 64, 5, 3, 32, 16, 8, 4, 2, 1, + 0]) + rs = rs[rs 0.5).float() + codes = polar.encode_plotkin(msg_bits) + + + for snr_ind, snr in enumerate(snr_range): + + # codes_G = polar.encode_G(msg_bits_bpsk) + noisy_code = polar.channel(codes, snr) + noise = noisy_code - codes + + SC_llrs, decoded_SC_msg_bits = polar.sc_decode_new(noisy_code, snr) + ber_SC = errors_ber(msg_bits, decoded_SC_msg_bits.sign()).item() + bler_SC = errors_bler(msg_bits, decoded_SC_msg_bits.sign()).item() + + bers_SC[snr_ind] += ber_SC/args.test_ratio + blers_SC[snr_ind] += bler_SC/args.test_ratio + + print("Test SNRs : ", snr_range) + print("BERs of SC: {0}".format(bers_SC)) + print("BLERs of SC: {0}".format(blers_SC)) diff --git a/deeppolar-main/pretrain.sh b/deeppolar-main/pretrain.sh new file mode 100644 index 0000000000000000000000000000000000000000..034faae7358eccabd0bb3f2c06d7e0f228756982 --- /dev/null +++ b/deeppolar-main/pretrain.sh @@ -0,0 +1,17 @@ +python -u main.py --N 16 --K 1 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 500 --batch_size 20000 --enc_train_snr -1 --dec_train_snr -3 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -7 --test_snr_end 0 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_1_normal_polar_eh64_dh128_selu_new --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 2 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 500 --batch_size 20000 --enc_train_snr 2 --dec_train_snr 0 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -4 --test_snr_end 3 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_2_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_1_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 3 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 3 --dec_train_snr 1 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -3 --test_snr_end 4 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_3_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_2_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 4 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 3 --dec_train_snr 2 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -3 --test_snr_end 4 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_4_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_3_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 5 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 4 --dec_train_snr 2.5 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -2 --test_snr_end 5 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_5_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_4_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 6 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 5 --dec_train_snr 4 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -1 --test_snr_end 6 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_6_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_5_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 7 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 6 --dec_train_snr 4 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 0 --test_snr_end 7 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_7_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_6_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 8 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 6 --dec_train_snr 5 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 0 --test_snr_end 7 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_8_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_7_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 9 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 6.5 --dec_train_snr 5 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 1 --test_snr_end 8 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_9_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_8_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 10 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 7 --dec_train_snr 6 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 1 --test_snr_end 8 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_10_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_9_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 11 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 7 --dec_train_snr 6 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 1 --test_snr_end 8 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_11_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_10_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 12 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 8.5 --dec_train_snr 7 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 3 --test_snr_end 10 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_12_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_11_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 13 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 9 --dec_train_snr 8 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 3 --test_snr_end 10 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_13_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_12_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 14 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 9.5 --dec_train_snr 8 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 4 --test_snr_end 11 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_14_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_13_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 15 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 10 --dec_train_snr 9 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 4 --test_snr_end 11 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_15_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_14_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +python -u main.py --N 16 --K 16 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 12 --dec_train_snr 11 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 6 --test_snr_end 13 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_16_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_15_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05 +bash copy_files.sh 16 normal_polar_eh64_dh128_selu_new \ No newline at end of file diff --git a/deeppolar-main/requirements.txt b/deeppolar-main/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca66570458f234322cf4ef2b19c06f55222d92e4 --- /dev/null +++ b/deeppolar-main/requirements.txt @@ -0,0 +1,88 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile requirements.in +# +argparse==1.4.0 + # via -r requirements.in +contourpy==1.2.1 + # via matplotlib +cycler==0.12.1 + # via matplotlib +filelock==3.14.0 + # via + # torch + # triton +fonttools==4.51.0 + # via matplotlib +fsspec==2024.3.1 + # via torch +jinja2==3.1.4 + # via torch +kiwisolver==1.4.5 + # via matplotlib +markupsafe==2.1.5 + # via jinja2 +matplotlib==3.8.4 + # via -r requirements.in +mpmath==1.3.0 + # via sympy +networkx==3.3 + # via torch +numpy==1.26.4 + # via + # -r requirements.in + # contourpy + # matplotlib +nvidia-cublas-cu12==12.1.3.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==8.9.2.26 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.20.5 + # via torch +nvidia-nvjitlink-cu12==12.4.127 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +packaging==24.0 + # via matplotlib +pillow==10.3.0 + # via matplotlib +pyparsing==3.1.2 + # via matplotlib +python-dateutil==2.9.0.post0 + # via matplotlib +six==1.16.0 + # via python-dateutil +sympy==1.12 + # via torch +torch==2.3.0 + # via -r requirements.in +tqdm==4.66.4 + # via -r requirements.in +triton==2.3.0 + # via torch +typing-extensions==4.11.0 + # via torch diff --git a/deeppolar-main/trainer.py b/deeppolar-main/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..1119e2e0f1cb31c859f148b486ce1c3738fffb58 --- /dev/null +++ b/deeppolar-main/trainer.py @@ -0,0 +1,151 @@ +import torch +import torch.nn.functional as F +import numpy as np +from utils import errors_ber, errors_bler, dec2bitarray, snr_db2sigma +import time + + +def train(args, polar, optimizer, scheduler, batch_size, train_snr, train_iters, criterion, device, info_positions, binary = False, noise_type = 'awgn'): + + if args.N == polar.ell: + assert len(info_positions) == args.K + kernel = True + else: + kernel = False + + for iter in range(train_iters): + if batch_size > args.small_batch_size: + small_batch_size = args.small_batch_size + else: + small_batch_size = batch_size + + num_batches = batch_size // small_batch_size + for ii in range(num_batches): + msg_bits = 1 - 2*(torch.rand(small_batch_size, args.K) > 0.5).float().to(device) + if args.encoder_type == 'polar': + codes = polar.encode_plotkin(msg_bits) + elif 'KO' in args.encoder_type: + if kernel: + codes = polar.kernel_encode(args.kernel_size, polar.gnet_dict[1][0], msg_bits, info_positions, binary = binary) + else: + codes = polar.deeppolar_encode(msg_bits, binary = binary) + + noisy_codes = polar.channel(codes, train_snr, noise_type) + + if 'KO' in args.decoder_type: + if kernel: + if args.decoder_type == 'KO_parallel': + decoded_llrs, decoded_bits = polar.kernel_parallel_decode(args.kernel_size, polar.fnet_dict[1][0], noisy_codes, info_positions) + else: + decoded_llrs, decoded_bits = polar.kernel_decode(args.kernel_size, polar.fnet_dict[1][0], noisy_codes, info_positions) + else: + decoded_llrs, decoded_bits = polar.deeppolar_decode(noisy_codes) + elif args.decoder_type == 'SC': + decoded_llrs, decoded_bits = polar.sc_decode_new(noisy_codes, train_snr) + + if 'BCE' in args.loss or args.loss == 'focal': + loss = criterion(decoded_llrs, 0.5 * msg_bits.to(polar.device) + 0.5) + else: + loss = criterion(torch.tanh(0.5*decoded_llrs), msg_bits.to(polar.device)) + + if args.regularizer == 'std': + if args.K == 1: + loss += args.regularizer_weight * torch.std(codes, dim=1).mean() + elif args.K == 2: + loss += args.regularizer_weight * (0.5*torch.std(codes[:, ::2], dim=1).mean() + .5*torch.std(codes[:, 1::2], dim=1).mean()) + elif args.regularizer == 'max_deviation': + if args.K == 1: + loss += args.regularizer_weight * torch.amax(torch.abs(codes - codes.mean(dim=1, keepdim=True)), dim=1).mean() + elif args.K == 2: + loss += args.regularizer_weight * (0.5*torch.amax(torch.abs(codes[:, ::2] - codes[:, ::2].mean(dim=1, keepdim=True)), dim=1).mean() + .5*torch.amax(torch.abs(codes[:, 1::2] - codes[:, 1::2].mean(dim=1, keepdim=True)), dim=1).mean()) + elif args.regularizer == 'polar': + loss += args.regularizer_weight * F.mse_loss(codes, polar.encode_plotkin(msg_bits)) + + loss = loss/num_batches + loss.backward() + optimizer.step() + if scheduler is not None: + scheduler.step() + optimizer.zero_grad() + train_ber = errors_ber(decoded_bits.sign(), msg_bits.to(polar.device)).item() + + return loss.item(), train_ber + +def deeppolar_full_test(args, polar, KO, snr_range, device, info_positions, binary=False, num_errors=100, noise_type = 'awgn'): + bers_KO_test = [0. for _ in snr_range] + blers_KO_test = [0. for _ in snr_range] + + bers_SC_test = [0. for _ in snr_range] + blers_SC_test = [0. for _ in snr_range] + + kernel = args.N == KO.ell + + print(f"TESTING until {num_errors} block errors") + for snr_ind, snr in enumerate(snr_range): + total_block_errors_SC = 0 + total_block_errors_KO = 0 + batches_processed = 0 + + sigma = snr_db2sigma(snr) # Assuming SNR is given in dB and noise variance is derived from it + + try: + while min(total_block_errors_SC, total_block_errors_KO) <= num_errors: + msg_bits = 2 * (torch.rand(args.test_batch_size, args.K) < 0.5).float() - 1 + msg_bits = msg_bits.to(device) + polar_code = polar.encode_plotkin(msg_bits) + + if 'KO' in args.encoder_type: + if kernel: + KO_polar_code = KO.kernel_encode(args.kernel_size, KO.gnet_dict[1][0], msg_bits, info_positions, binary=binary) + else: + KO_polar_code = KO.deeppolar_encode(msg_bits, binary=binary) + + noisy_code = polar.channel(polar_code, snr, noise_type) + noise = noisy_code - polar_code + noisy_KO_code = KO_polar_code + noise if 'KO' in args.encoder_type else noisy_code + + SC_llrs, decoded_SC_msg_bits = polar.sc_decode_new(noisy_code, snr) + ber_SC = errors_ber(msg_bits, decoded_SC_msg_bits.sign()).item() + bler_SC = errors_bler(msg_bits, decoded_SC_msg_bits.sign()).item() + total_block_errors_SC += int(bler_SC*args.test_batch_size) + if 'KO' in args.decoder_type: + if kernel: + if args.decoder_type == 'KO_parallel': + KO_llrs, decoded_KO_msg_bits = KO.kernel_parallel_decode(args.kernel_size, KO.fnet_dict[1][0], noisy_KO_code, info_positions) + else: + KO_llrs, decoded_KO_msg_bits = KO.kernel_decode(args.kernel_size, KO.fnet_dict[1][0], noisy_KO_code, info_positions) + else: + KO_llrs, decoded_KO_msg_bits = KO.deeppolar_decode(noisy_KO_code) + else: # if SC is also used for KO + KO_llrs, decoded_KO_msg_bits = KO.sc_decode_new(noisy_KO_code, snr) + + ber_KO = errors_ber(msg_bits, decoded_KO_msg_bits.sign()).item() + bler_KO = errors_bler(msg_bits, decoded_KO_msg_bits.sign()).item() + total_block_errors_KO += int(bler_KO*args.test_batch_size) + + batches_processed += 1 + + # Update accumulative results for logging + bers_KO_test[snr_ind] += ber_KO + bers_SC_test[snr_ind] += ber_SC + blers_KO_test[snr_ind] += bler_KO + blers_SC_test[snr_ind] += bler_SC + + # Real-time logging for progress, updating in-place + print(f"SNR: {snr} dB, Sigma: {sigma:.5f}, SC_BER: {bers_SC_test[snr_ind]/batches_processed:.6f}, SC_BLER: {blers_SC_test[snr_ind]/batches_processed:.6f}, KO_BER: {bers_KO_test[snr_ind]/batches_processed:.6f}, KO_BLER: {blers_KO_test[snr_ind]/batches_processed:.6f}, Batches: {batches_processed}", end='\r') + + except KeyboardInterrupt: + # print("\nInterrupted by user. Finalizing current SNR...") + pass + + # Normalize cumulative metrics by the number of processed batches for accuracy + bers_KO_test[snr_ind] /= (batches_processed + 0.00000001) + bers_SC_test[snr_ind] /= (batches_processed + 0.00000001) + blers_KO_test[snr_ind] /= (batches_processed + 0.00000001) + blers_SC_test[snr_ind] /= (batches_processed + 0.00000001) + print(f"SNR: {snr} dB, Sigma: {sigma:.5f}, SC_BER: {bers_SC_test[snr_ind]:.6f}, SC_BLER: {blers_SC_test[snr_ind]:.6f}, KO_BER: {bers_KO_test[snr_ind]:.6f}, KO_BLER: {blers_KO_test[snr_ind]:.6f}") + + return bers_SC_test, blers_SC_test, bers_KO_test, blers_KO_test + + + diff --git a/deeppolar-main/trainer_utils.py b/deeppolar-main/trainer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..14e2663e9955a8cc3a80c38ac4b9b078b8a7c916 --- /dev/null +++ b/deeppolar-main/trainer_utils.py @@ -0,0 +1,48 @@ +import torch +from utils import moving_average +import matplotlib.pyplot as plt +import os + +def save_model(polar, iter, results_save_path, best = False): + torch.save([polar.fnet_dict, polar.gnet_dict, polar.depth_map], os.path.join(results_save_path, 'Models/fnet_gnet_{}.pt'.format(iter))) + if iter > 1: + torch.save([polar.fnet_dict, polar.gnet_dict, polar.depth_map], os.path.join(results_save_path, 'Models/fnet_gnet_{}.pt'.format('final'))) + if best: + torch.save([polar.fnet_dict, polar.gnet_dict, polar.depth_map], os.path.join(results_save_path, 'Models/fnet_gnet_{}.pt'.format('best'))) + +def plot_stuff(bers_enc, losses_enc, bers_dec, losses_dec, results_save_path): + plt.figure() + plt.plot(bers_enc, label = 'BER') + plt.plot(moving_average(bers_enc, n=10), label = 'BER moving avg') + plt.yscale('log') + plt.legend(loc='best') + plt.title('Training BER ENC') + plt.savefig(os.path.join(results_save_path,'training_ber_enc.png')) + plt.close() + + plt.figure() + plt.plot(losses_enc, label = 'Losses') + plt.plot(moving_average(losses_enc, n=10), label='Losses moving avg') + plt.yscale('log') + plt.legend(loc='best') + plt.title('Training loss ENC') + plt.savefig(os.path.join(results_save_path ,'training_losses_enc.png')) + plt.close() + + plt.figure() + plt.plot(bers_dec, label = 'BER') + plt.plot(moving_average(bers_dec, n=10), label = 'BER moving avg') + plt.yscale('log') + plt.legend(loc='best') + plt.title('Training BER DEC') + plt.savefig(os.path.join(results_save_path,'training_ber_dec.png')) + plt.close() + + plt.figure() + plt.plot(losses_dec, label = 'Losses') + plt.plot(moving_average(losses_dec, n=10), label='Losses moving avg') + plt.yscale('log') + plt.legend(loc='best') + plt.title('Training loss DEC') + plt.savefig(os.path.join(results_save_path ,'training_losses_dec.png')) + plt.close() \ No newline at end of file diff --git a/deeppolar-main/utils.py b/deeppolar-main/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0369f86c7343cde843095ac7eb2d7f6bcea38bf1 --- /dev/null +++ b/deeppolar-main/utils.py @@ -0,0 +1,170 @@ +import torch +import torch.nn.functional as F +from torch.distributions import Normal, StudentT +import numpy as np +from itertools import combinations + +def snr_db2sigma(train_snr): + return 10**(-train_snr*1.0/20) + +def moving_average(a, n=3) : + ret = np.cumsum(a, dtype=float) + ret[n:] = ret[n:] - ret[:-n] + return ret[n - 1:] / n + +def errors_ber(y_true, y_pred, mask=None): + if mask == None: + mask=torch.ones(y_true.size(),device=y_true.device) + y_true = y_true.view(y_true.shape[0], -1, 1) + y_pred = y_pred.view(y_pred.shape[0], -1, 1) + mask = mask.view(mask.shape[0], -1, 1) + myOtherTensor = (mask*torch.ne(torch.round(y_true), torch.round(y_pred))).float() + res = sum(sum(myOtherTensor))/(torch.sum(mask)) + return res + +def errors_bler(y_true, y_pred, get_pos = False): + y_true = y_true.view(y_true.shape[0], -1, 1) + y_pred = y_pred.view(y_pred.shape[0], -1, 1) + + decoded_bits = torch.round(y_pred).cpu() + X_test = torch.round(y_true).cpu() + tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) + tp0 = tp0.detach().cpu().numpy() + bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0]) + + if not get_pos: + return bler_err_rate + else: + err_pos = list(np.nonzero((np.sum(tp0,axis=1)>0).astype(int))[0]) + return bler_err_rate, err_pos + +def corrupt_signal(input_signal, sigma = 1.0, noise_type = 'awgn', vv =5.0, radar_power = 20.0, radar_prob = 0.05): + + data_shape = input_signal.shape # input_signal has to be a numpy array. + assert noise_type in ['bsc', 'awgn', 'fading', 'radar', 't-dist', 'isi_perfect', 'isi_uncertain'], "Invalid noise type" + device = input_signal.device + if noise_type == 'awgn': + dist = Normal(torch.tensor([0.0], device = device), torch.tensor([sigma], device = device)) + noise = dist.sample(input_signal.shape).squeeze() + corrupted_signal = input_signal + noise + + elif noise_type == 'fading': + fading_h = torch.sqrt(torch.randn_like(input_signal)**2 + torch.randn_like(input_signal)**2)/np.sqrt(3.14/2.0) + noise = sigma * torch.randn_like(input_signal) # Define noise + corrupted_signal = fading_h *(input_signal) + noise + + elif noise_type == 'radar': + add_pos = np.random.choice([0.0, 1.0], data_shape, + p=[1 - radar_prob, radar_prob]) + + corrupted_signal = radar_power* np.random.standard_normal( size = data_shape ) * add_pos + noise = sigma * torch.randn_like(input_signal) +\ + torch.from_numpy(corrupted_signal).float().to(input_signal.device) + corrupted_signal = input_signal + noise + + elif noise_type == 't-dist': + dist = StudentT(torch.tensor([vv], device = device)) + noise = sigma* dist.sample(input_signal.shape).squeeze() + corrupted_signal = input_signal + noise + + return corrupted_signal + +def snr_db2sigma(train_snr): + return 10**(-train_snr*1.0/20) + +def min_sum_log_sum_exp(x, y): + + log_sum_ms = torch.min(torch.abs(x), torch.abs(y))*torch.sign(x)*torch.sign(y) + return log_sum_ms + +def min_sum_log_sum_exp_4(x_1, x_2, x_3, x_4): + return min_sum_log_sum_exp(min_sum_log_sum_exp(x_1, x_2), min_sum_log_sum_exp(x_3, x_4)) + +def log_sum_exp(x, y): + def log_sum_exp_(LLR_vector): + + sum_vector = LLR_vector.sum(dim=1, keepdim=True) + sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1) + + return torch.logsumexp(sum_concat, dim=1)- torch.logsumexp(LLR_vector, dim=1) + + + Lv = log_sum_exp_(torch.cat([x.unsqueeze(2), y.unsqueeze(2)], dim=2).permute(0, 2, 1)) + return Lv + +def dec2bitarray(in_number, bit_width): + """ + Converts a positive integer to NumPy array of the specified size containing + bits (0 and 1). + Parameters + ---------- + in_number : int + Positive integer to be converted to a bit array. + bit_width : int + Size of the output bit array. + Returns + ------- + bitarray : 1D ndarray of ints + Array containing the binary representation of the input decimal. + """ + + binary_string = bin(in_number) + length = len(binary_string) + bitarray = np.zeros(bit_width, 'int') + for i in range(length-2): + bitarray[bit_width-i-1] = int(binary_string[length-i-1]) + + return bitarray + + +def countSetBits(n): + + count = 0 + while (n): + n &= (n-1) + count+= 1 + + return count + +class STEQuantize(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs, enc_quantize_level = 2, enc_value_limit = 1.0, enc_grad_limit = 0.01, enc_clipping = 'both'): + + ctx.save_for_backward(inputs) + assert enc_clipping in ['both', 'inputs'] + ctx.enc_clipping = enc_clipping + ctx.enc_value_limit = enc_value_limit + ctx.enc_quantize_level = enc_quantize_level + ctx.enc_grad_limit = enc_grad_limit + + x_lim_abs = enc_value_limit + x_lim_range = 2.0 * x_lim_abs + x_input_norm = torch.clamp(inputs, -x_lim_abs, x_lim_abs) + + if enc_quantize_level == 2: + outputs_int = torch.sign(x_input_norm) + else: + outputs_int = torch.round((x_input_norm +x_lim_abs) * ((enc_quantize_level - 1.0)/x_lim_range)) * x_lim_range/(enc_quantize_level - 1.0) - x_lim_abs + + return outputs_int + + @staticmethod + def backward(ctx, grad_output): + if ctx.enc_clipping in ['inputs', 'both']: + input, = ctx.saved_tensors + grad_output[input>ctx.enc_value_limit]=0 + grad_output[input<-ctx.enc_value_limit]=0 + + if ctx.enc_clipping in ['gradient', 'both']: + grad_output = torch.clamp(grad_output, -ctx.enc_grad_limit, ctx.enc_grad_limit) + grad_input = grad_output.clone() + + return grad_input, None + + +def pairwise_distances(codebook): + dists = [] + for row1, row2 in combinations(codebook, 2): + distance = (row1-row2).pow(2).sum() + dists.append(np.sqrt(distance.item())) + return dists, np.min(dists) \ No newline at end of file