Commit 8e7948b2 authored by: R root

Tue Jul 1 18:12:01 CST 2025 inscode

Parent 0a575cc3
run = "pip install -r requirements.txt;python main.py"
language = "python"
[packager]
AUTO_PIP = true
[env]
VIRTUAL_ENV = "/root/${PROJECT_DIR}/venv"
PATH = "${VIRTUAL_ENV}/bin:${PATH}"
PYTHONPATH = "$PYTHONHOME/lib/python3.10:${VIRTUAL_ENV}/lib/python3.10/site-packages"
REPLIT_POETRY_PYPI_REPOSITORY = "http://mirrors.csdn.net.cn/repository/csdn-pypi-mirrors/simple"
MPLBACKEND = "TkAgg"
POETRY_CACHE_DIR = "/root/${PROJECT_DIR}/.cache/pypoetry"
[debugger]
program = "main.py"
run = "cd deeppolar-main && python main.py --test --N 256 --K 37 --kernel_size 16 --test_snr_start -5 --test_snr_end 5 --snr_points 5"
is_gui = false
is_resident = true
is_html = false
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
MIT License
Copyright (c) 2024 Ashwin Hebbar
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# DeepPolar codes
Code for "[DeepPolar: Inventing Nonlinear Large-Kernel Polar Codes via Deep Learning](https://arxiv.org/abs/2402.08864)", ICML 2024
## Installation
First, clone the repository to your local machine:
```bash
git clone https://github.com/hebbarashwin/deeppolar.git
cd deeppolar
```
Then, install the required Python packages:
```bash
pip install -r requirements.txt
```
## Usage
Best results are obtained by pretraining the kernels via curriculum training and initializing the full network from these pretrained kernels (training from scratch may also work).
An example pretrained kernel is provided. Command to run (you can set `--id` to distinguish different runs):
```bash
python -u main.py --N 256 --K 37 -ell 16 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 2000 --enc_train_snr 0 --dec_train_snr -2 --enc_hidden_size 64 --dec_hidden_size 128 --enc_lr 0.0001 --dec_lr 0.0001 --weight_decay 0 --test_snr_start -5 --test_snr_end -1 --snr_points 5 --batch_size 20000 --id run1 --kernel_load_path Polar_Results/curriculum/final_kernels/16_normal_polar_eh64_dh128_selu --gpu -2
```
- `N`, `K`: Code parameters
- `-ell`: Kernel size; $\sqrt{N}$ works best
- `kernel_load_path`: Path to load pretrained kernel models (omit this flag when training from scratch)
- `enc_train_iters`, `dec_train_iters`: Number of training iterations for the encoder and decoder.
- `full_iters`: Total iterations for full training cycles.
- `id`: Identifier for the run.
- `model_save_per`: Frequency of saving the trained models.
- `gpu`: `-2` for CUDA, `-1` for CPU, `0`/`1`/`2`/`3` for a specific GPU
The kernels can be pretrained, for example by running
```bash
bash pretrain.sh
```
(Typically, each kernel does not need to be trained for as many iterations as this script uses.)
### Testing
```bash
python -u main.py --N 256 --K 37 -ell 16 --enc_hidden_size 64 --dec_hidden_size 128 --test_snr_start -5 --test_snr_end -1 --snr_points 5 --test_batch_size 10000 --id run1 --weight_decay 0. --num_errors 100 --test
```
(More details will be added soon.)
- Finetuning with increasingly large batch sizes improves high-SNR performance.
- BER gain can be traded off for BLER by finetuning with a BLER surrogate loss.
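A minimal sketch of one possible BLER surrogate loss for such finetuning (the soft-max-over-bits formulation below is an illustrative assumption, not the exact loss used in the paper):

```python
import torch

def bler_surrogate_loss(pred_llrs, msg_bits, temperature=1.0):
    """Soft block-error surrogate: a block is in error if its worst bit is in error.

    pred_llrs: (batch, K) decoder LLRs; msg_bits: (batch, K) in {-1, +1} (BPSK convention).
    """
    # per-bit soft error probability under the BPSK convention (+1 <-> 0, -1 <-> 1)
    bit_err = torch.sigmoid(-pred_llrs * msg_bits / temperature)
    # smooth maximum over the bits of each block approximates the block-error indicator
    block_err = torch.logsumexp(bit_err * 50.0, dim=1) / 50.0
    return block_err.mean()
```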
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from polar import PolarCode, get_frozen
from models import g_Full, f_Full, weights_init
from utils import min_sum_log_sum_exp, min_sum_log_sum_exp_4, countSetBits, log_sum_exp, STEQuantize
from collections import defaultdict
import os
### Model-architecture optimization directions:
# Introduce more advanced network architectures (e.g. Transformer or ResNet blocks)
# Add attention mechanisms to better capture relationships between bits
# Explore deeper, but still trainable, architectures
class TransformerLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = nn.ReLU()
def forward(self, src):
        # src shape: (sequence_length, batch_size, d_model)
src2 = self.self_attn(src, src, src)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
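# Illustrative usage sketch for the TransformerLayer above (an added example for clarity,
# not part of the original training flow): the layer expects input of shape
# (sequence_length, batch_size, d_model) and returns a tensor of the same shape.
def _transformer_layer_example():
    layer = TransformerLayer(d_model=16, nhead=4, dim_feedforward=64, dropout=0.0)
    src = torch.randn(8, 32, 16)  # (sequence_length=8, batch_size=32, d_model=16)
    out = layer(src)
    assert out.shape == src.shape
    return out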
####ADD down
class DeepPolar(PolarCode):
def __init__(self, args, device, N, K, ell = 2, infty = 1000., depth_map : defaultdict = None):
        # Initialize the DeepPolar code.
        # args: configuration object holding the model hyperparameters
        # device: compute device (CPU or GPU)
        # N: Polar code block length
        # K: number of information bits
        # ell: kernel (sub-block) size of each level
        # infty: large LLR value used for frozen positions (the unreliable bit-channels)
        # depth_map: dictionary giving the kernel size at every depth
        self.args = args
        args.use_transformer = True  # whether to use the Transformer layers
        args.nhead = 4  # number of attention heads
        args.dim_feedforward = 2048  # feed-forward width
        args.transformer_dropout = 0.1  # dropout rate
        args.use_attention = getattr(args, 'use_attention', False)  # attention in the decoder, off unless set elsewhere
        # The optional Transformer / attention layers are constructed at the end of
        # __init__, after self.n_ell and self.depth_map have been set (see below).
        # Get the frozen positions for this rate profile
        Fr = get_frozen(N, K, self.args.rate_profile)
        # Call the parent PolarCode initializer
        super().__init__(n = int(np.log2(N)), K = K, Fr=Fr, infty = infty)
        self.N = N  # block length
        # If a depth_map is provided, use it to set the kernel size of every level
        if depth_map is not None:
            # the product of the per-level kernel sizes must equal N
            assert np.prod(list(depth_map.values())) == N
            # keys must start at 1 and not exceed the code depth
            assert min(list(depth_map.keys())) == 1
            assert max(list(depth_map.keys())) <= int(np.log2(N))
            self.ell = None  # no single fixed kernel size
            self.n_ell = len(depth_map.keys())  # number of levels
            assert max(list(depth_map.keys())) == self.n_ell  # largest key equals the number of levels
            self.depth_map = depth_map  # store the depth_map
        else:
            # No depth_map given: use the fixed kernel size ell at every level
            self.ell = ell
            self.n_ell = int(np.log(N)/np.log(self.ell))  # number of levels
            self.depth_map = defaultdict(int)
            for d in range(1, self.n_ell+1):
                self.depth_map[d] = self.ell  # kernel size of each level
            assert np.prod(list(self.depth_map.values())) == N  # product of kernel sizes must equal N
        self.device = device  # compute device
        self.fnet_dict = None  # decoder (f) networks
        self.gnet_dict = None  # encoder (g) networks
        self.infty = infty  # large LLR magnitude used for frozen positions
        # Build the optional Transformer / attention layers now that n_ell and depth_map exist.
        # Note: d_model must match the chunk width actually fed to the layer at that depth.
        if args.use_transformer:
            self.transformer_layers = nn.ModuleDict()
            for d in range(1, self.n_ell+1):
                self.transformer_layers[str(d)] = TransformerLayer(
                    d_model=self.depth_map[d],
                    nhead=args.nhead,
                    dim_feedforward=args.dim_feedforward,
                    dropout=args.transformer_dropout
                ).to(device)
        if args.use_attention:
            self.attention_layers = nn.ModuleDict()
            for d in range(1, self.n_ell+1):
                self.attention_layers[str(d)] = nn.MultiheadAttention(
                    embed_dim=self.depth_map[d],
                    num_heads=args.nhead,
                    dropout=args.dropout_p
                ).to(device)
@staticmethod
def get_onehot(actions):
        # Convert BPSK actions (+1/-1) to a one-hot encoding
        # actions: input tensor of +1/-1 values
        # returns: one-hot encoded tensor
        inds = (0.5 + 0.5*actions).long()  # map {-1, +1} to integer indices {0, 1}
if len(actions.shape) == 2:
            # 2-D input: build the one-hot encoding
return torch.eye(2, device = inds.device)[inds].reshape(actions.shape[0], -1)
elif len(actions.shape) == 3:
            # 3-D input: build the one-hot encoding
return torch.eye(2, device = inds.device)[inds].reshape(actions.shape[0], actions.shape[1], -1)
def define_kernel_nns(self, ell, unfrozen = None, fnet = 'KO', gnet = 'KO', shared = False):
        # Define the decoder (f) and encoder (g) networks for a single kernel
        # ell: kernel size
        # unfrozen: positions of the unfrozen (information) bits
        # fnet: decoder network type
        # gnet: encoder network type
        # shared: whether the networks are shared across positions
        if 'KO' in fnet:
            self.fnet_dict = {}  # initialize the decoder network dictionary
        else:
            self.fnet_dict = None  # no neural decoder is used
        self.shared = shared  # store whether the networks are shared
        if 'KO' in gnet:
            self.gnet_dict = {}  # initialize the encoder network dictionary
        else:
            self.gnet_dict = None  # no neural encoder is used
        dec_hidden_size = self.args.dec_hidden_size  # decoder hidden layer size
        enc_hidden_size = self.args.enc_hidden_size  # encoder hidden layer size
        depth = 1  # current depth (kernel pretraining works on a single level)
        assert len(unfrozen) > 0, "No unfrozen bits!"  # there must be at least one unfrozen bit
        self.fnet_dict[depth] = {}  # decoder networks at this depth
if fnet == 'KO_parallel' or fnet == 'KO_last_parallel':
            bit_position = 0  # current bit position
            self.fnet_dict[depth][bit_position] = {}  # placeholder, replaced by the network below
            # input size
            input_size = ell
            # output size
            output_size = ell
            # build the decoder network
self.fnet_dict[depth][bit_position] = f_Full(input_size, dec_hidden_size, output_size, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device)
elif 'KO' in fnet:
if shared:
                self.fnet_dict[depth] = {}  # decoder networks at this depth
                for current_position in range(ell):
                    # one shared decoder network per position within the kernel
self.fnet_dict[depth][current_position] = f_Full(ell + current_position, dec_hidden_size, 1, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device)
else:
                bit_position = 0  # current bit position
for current_position in unfrozen:
if not self.fnet_dict[depth].get(bit_position):
                        self.fnet_dict[depth][bit_position] = {}  # decoder networks for this bit position
                    input_size = ell + (int(self.args.onehot)+1)*current_position  # input size
                    # build the decoder network
self.fnet_dict[depth][bit_position][current_position] = f_Full(input_size, dec_hidden_size, 1, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device)
if 'KO' in gnet:
            self.gnet_dict[depth] = {}  # encoder networks at this depth
            if shared:
                if gnet == 'KO':
                    # one encoder network shared across kernels
self.gnet_dict[depth] = g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, activation = self.args.enc_activation, use_skip = self.args.skip).to(self.device)
else:
                bit_position = 0  # current bit position
                if gnet == 'KO':
                    # build the encoder network
self.gnet_dict[depth][bit_position] = g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, activation = self.args.enc_activation, use_skip = self.args.skip).to(self.device)
def define_and_load_nns(self, ell, kernel_load_path = None, fnet = 'KO', gnet = 'KO', shared = True, dataparallel = False):
        # Define the decoder (f) and encoder (g) networks and optionally load pretrained kernels
        # ell: kernel size of each level
        # kernel_load_path: path to the pretrained kernel networks
        # fnet: decoder network type
        # gnet: encoder network type
        # shared: whether the networks are shared across positions
        # dataparallel: whether to wrap the networks in nn.DataParallel
if 'KO' in fnet:
            self.fnet_dict = {}  # initialize the decoder network dictionary
        else:
            self.fnet_dict = None  # no neural decoder is used
        self.shared = shared  # store whether the networks are shared
        if 'KO' in gnet:
            self.gnet_dict = {}  # initialize the encoder network dictionary
        else:
            self.gnet_dict = None  # no neural encoder is used
        dec_hidden_size = self.args.dec_hidden_size  # decoder hidden layer size
        enc_hidden_size = self.args.enc_hidden_size  # encoder hidden layer size
for depth in range(self.n_ell, 0, -1):
if depth in self.args.polar_depths:
                continue  # skip levels that are decoded with plain polar SC
            ell = self.depth_map[depth]  # kernel size at this depth
            proj_size = np.prod([self.depth_map[d] for d in range(1, depth+1)])  # size of the projection at this depth
if fnet == 'KO_last_parallel' and depth == 1:
                self.fnet_dict[depth] = {}  # decoder networks at this depth
                for bit_position in range(self.N // proj_size):
                    proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size)  # indices covered by this bit position's projection
                    get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj])  # count the information bits in a projection
                    num_info_in_proj = get_num_info_proj(proj)  # number of information bits in this projection
                    subproj_len = len(proj) // ell  # length of each sub-projection
                    subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)]  # split the projection into sub-projections
                    num_info_in_subproj = [get_num_info_proj(x) for x in subproj]  # information bits per sub-projection
                    unfrozen = [i for i, x in enumerate(num_info_in_subproj) if x >= 1]  # sub-projections that contain unfrozen bits
                    input_size = ell  # input size
                    output_size = ell  # output size
                    # build the decoder network
self.fnet_dict[depth][bit_position] = f_Full(input_size, dec_hidden_size, output_size, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device)
if len(unfrozen) > 0:
if kernel_load_path is not None:
try:
                                # try to load the pretrained decoder network
ckpt = torch.load(os.path.join(kernel_load_path + '_parallel', f'{ell}_{len(unfrozen)}.pt'))
ckpt_exists = True
except FileNotFoundError:
print(f"Parallel File not found for ell = {ell}, num_unfrozen = {len(unfrozen)}")
ckpt_exists = False
else:
ckpt_exists = False
if ckpt_exists:
                            # load the decoder network weights
f_ckpt = ckpt[0][1][0].state_dict()
self.fnet_dict[depth][bit_position].load_state_dict(f_ckpt)
if dataparallel:
                        # wrap in DataParallel
self.fnet_dict[depth][bit_position] = nn.DataParallel(self.fnet_dict[depth][bit_position])
elif 'KO' in fnet:
                self.fnet_dict[depth] = {}  # decoder networks at this depth
                if shared:
                    self.fnet_dict[depth] = {}  # (re-)initialized for the shared case
for current_position in range(ell):
                        # one shared decoder network per position within the kernel
self.fnet_dict[depth][current_position] = f_Full(ell + current_position, dec_hidden_size, 1, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device)
if dataparallel:
                            # wrap in DataParallel
self.fnet_dict[depth][current_position] = nn.DataParallel(self.fnet_dict[depth][current_position])
else:
for bit_position in range(self.N // proj_size):
                        proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size)  # indices covered by this bit position's projection
                        get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj])  # count the information bits in a projection
                        num_info_in_proj = get_num_info_proj(proj)  # number of information bits in this projection
                        subproj_len = len(proj) // ell  # length of each sub-projection
                        subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)]  # split the projection into sub-projections
                        num_info_in_subproj = [get_num_info_proj(x) for x in subproj]  # information bits per sub-projection
                        unfrozen = [i for i, x in enumerate(num_info_in_subproj) if x >= 1]  # sub-projections that contain unfrozen bits
if len(unfrozen) > 0:
if kernel_load_path is not None:
try:
                                    # try to load the pretrained decoder network
ckpt = torch.load(os.path.join(kernel_load_path, f'{ell}_{len(unfrozen)}.pt'))
ckpt_exists = True
except FileNotFoundError:
print(f"File not found for ell = {ell}, num_unfrozen = {len(unfrozen)}")
ckpt_exists = False
else:
ckpt_exists = False
for current_position in unfrozen:
if not self.fnet_dict[depth].get(bit_position):
                                    self.fnet_dict[depth][bit_position] = {}  # decoder networks for this bit position
                                input_size = ell + (int(self.args.onehot)+1)*current_position  # input size
                                output_size = 1  # output size
                                # build the decoder network
self.fnet_dict[depth][bit_position][current_position] = f_Full(input_size, dec_hidden_size, output_size, activation = self.args.dec_activation, dropout_p = self.args.dropout_p, depth = self.args.f_depth, use_norm = self.args.use_norm).to(self.device)
if ckpt_exists:
                                    try:
                                        # load the decoder network weights for this position
                                        f_ckpt = ckpt[0][1][0][current_position].state_dict()
                                        self.fnet_dict[depth][bit_position][current_position].load_state_dict(f_ckpt)
                                    except (KeyError, IndexError):
                                        # the checkpoint does not contain this position; keep the fresh weights
                                        print(unfrozen)
if dataparallel:
                                    # wrap in DataParallel
self.fnet_dict[depth][bit_position][current_position] = nn.DataParallel(self.fnet_dict[depth][bit_position][current_position])
if 'KO' in gnet:
                self.gnet_dict[depth] = {}  # encoder networks at this depth
if shared:
if gnet == 'KO':
if not dataparallel:
                            # one encoder network shared across kernels
self.gnet_dict[depth] = g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, use_skip = self.args.skip).to(self.device)
else:
                            # shared encoder network wrapped in DataParallel
self.gnet_dict[depth] = nn.DataParallel(g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, use_skip = self.args.skip)).to(self.device)
else:
for bit_position in range(self.N // proj_size):
                        proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size)  # indices covered by this bit position's projection
                        get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj])  # count the information bits in a projection
                        num_info_in_proj = get_num_info_proj(proj)  # number of information bits in this projection
                        subproj_len = len(proj) // ell  # length of each sub-projection
                        subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)]  # split the projection into sub-projections
                        num_info_in_subproj = [get_num_info_proj(x) for x in subproj]  # information bits per sub-projection
unfrozen = [i for i, x in enumerate(num_info_in_subproj) if x >= 1]
if num_info_in_proj > 0:
if gnet == 'KO':
self.gnet_dict[depth][bit_position] = g_Full(ell, enc_hidden_size, ell-1, depth = self.args.g_depth, skip_depth = self.args.g_skip_depth, skip_layer = self.args.g_skip_layer, ell = ell, activation = self.args.enc_activation, use_skip = self.args.skip).to(self.device)
if kernel_load_path is not None:
try:
ckpt = torch.load(os.path.join(kernel_load_path, f'{ell}_{len(unfrozen)}.pt'))
self.gnet_dict[depth][bit_position].load_state_dict(ckpt[1][1][0].state_dict())
except FileNotFoundError:
print(f"File not found for ell = {ell}, num_unfrozen = {len(unfrozen)}")
pass
if dataparallel:
self.gnet_dict[depth][bit_position] = nn.DataParallel(self.gnet_dict[depth][bit_position])
# print(f"g : {depth}, {bit_position}, {len(unfrozen)}")
if kernel_load_path is not None:
print("Loaded kernel from ", kernel_load_path)
def load_nns(self, fnet_dict, gnet_dict = None, shared = False):
self.fnet_dict = fnet_dict
self.gnet_dict = gnet_dict
for depth in fnet_dict.keys():
if self.fnet_dict is not None:
for bit_position in self.fnet_dict[depth].keys():
if not isinstance(self.fnet_dict[depth][bit_position], dict):#shared or self.args.decoder_type == 'KO_parallel' or self.args.decoder_type == 'KO_RNN':
self.fnet_dict[depth][bit_position].to(self.device)
else:
for current_position in self.fnet_dict[depth][bit_position].keys():
self.fnet_dict[depth][bit_position][current_position].to(self.device)
if gnet_dict is not None:
if shared:
self.gnet_dict[depth].to(self.device)
else:
for bit_position in self.gnet_dict[depth].keys():
self.gnet_dict[depth][bit_position].to(self.device)
print("NN weights loaded!")
def load_partial_nns(self, fnet_dict, gnet_dict = None):
for depth in fnet_dict.keys():
if fnet_dict is not None:
for bit_position in fnet_dict[depth].keys():
if isinstance(fnet_dict[depth][bit_position], dict):
for current_position in fnet_dict[depth][bit_position].keys():
self.fnet_dict[depth][bit_position][current_position] = fnet_dict[depth][bit_position][current_position].to(self.device)
else:
self.fnet_dict[depth][bit_position] = fnet_dict[depth][bit_position].to(self.device)
if gnet_dict is not None:
for bit_position in gnet_dict[depth].keys():
self.gnet_dict[depth][bit_position] = gnet_dict[depth][bit_position].to(self.device)
print("NN weights loaded!")
def kernel_encode(self, ell, gnet, msg_bits, info_positions, binary = False):
input_shape = msg_bits.shape[-1]
assert input_shape <= ell
u = torch.ones(msg_bits.shape[0], self.N, dtype=torch.float).to(self.device)
u[:, info_positions] = msg_bits
        output = torch.cat([gnet(u.unsqueeze(1)).squeeze(1), u[:, -1:]], 1)
power_constrained_u = self.power_constraint(output)
if binary:
stequantize = STEQuantize.apply
power_constrained_u = stequantize(power_constrained_u)
return power_constrained_u
    #### An enhanced encoder architecture intended as an alternative to the existing kernel_encode method
    # The main improvements are a stronger encoder and an improved power-constraint scheme.
    # The enhanced encoder is added to the DeepPolar class alongside the existing encoders; it complements rather than replaces them.
def enhanced_encode(self, msg_bits, binary=False):
u = torch.ones(msg_bits.shape[0], self.N, dtype=torch.float).to(self.device)
u[:, self.info_positions] = msg_bits
        # (The per-depth Transformer processing is done inside the loop below.)
        # Add a simple position encoding
position = torch.arange(self.N).unsqueeze(0).to(self.device)
pos_enc = torch.zeros_like(u)
pos_enc[:, :] = position.float()
u = u + pos_enc
        # Process the codeword level by level
for d in range(1, self.n_ell+1):
proj_size = np.prod([self.depth_map[dd] for dd in range(1, d+1)])
ell = self.depth_map[d]
for bit_position in range(self.N // proj_size):
proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size)
num_info = sum([int(x in self.info_positions) for x in proj])
if num_info > 0:
                    # use the Transformer layer to process this block
chunks = [u[:, i:i+proj_size//ell].unsqueeze(1) for i in
range(bit_position*proj_size, (bit_position+1)*proj_size, proj_size//ell)]
concatenated = torch.cat(chunks, 1)
                    # apply the Transformer layer for this depth
                    transformer_out = self.transformer_layers[str(d)](concatenated.permute(1, 0, 2))
output = transformer_out.permute(1, 0, 2).reshape(u.shape[0], -1)
u = torch.cat((u[:, :bit_position*proj_size], output,
u[:, (bit_position+1)*proj_size:]), dim=1)
        # improved power constraint
power_constrained_u = self.enhanced_power_constraint(u)
if binary:
stequantize = STEQuantize.apply
power_constrained_u = stequantize(power_constrained_u)
return power_constrained_u
def enhanced_power_constraint(self, codewords):
        # adaptive per-codeword power normalization
        norm = torch.norm(codewords, p=2, dim=1, keepdim=True)
        scale = np.sqrt(self.N) / (norm + 1e-6)
return codewords * scale
####ADD Down
    # An enhanced decoder architecture intended as an alternative to the existing deeppolar_decode method
    # It strengthens the parallel decoding path and refines the information-bit decoding logic.
    ### The enhanced decoder is likewise added to the DeepPolar class and coexists with the existing decoders; a configuration flag selects which decoding path is used.
def enhanced_decode(self, noisy_code):
assert noisy_code.shape[1] == self.N
        # initialization
decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device=noisy_code.device)
partial_sums = torch.ones(noisy_code.shape[0], self.n_ell+1, self.N, device=noisy_code.device)
        # noise preprocessing
noisy_code = self.noise_preprocess(noisy_code)
        # recursive decoding
decoded_llrs, partial_sums = self.enhanced_decode_depth(
noisy_code.unsqueeze(2), self.n_ell, 0, decoded_llrs, partial_sums)
return decoded_llrs[:, self.info_positions], torch.sign(decoded_llrs[:, self.info_positions])
def noise_preprocess(self, noisy_code):
        # optionally apply wavelet denoising (or another preprocessing step)
        if getattr(self.args, 'use_wavelet', False) and hasattr(self, 'wavelet_denoise'):
noisy_code = self.wavelet_denoise(noisy_code)
return noisy_code
def enhanced_decode_depth(self, llrs, depth, bit_position, decoded_llrs, partial_sums):
half_index = np.prod([self.depth_map[d] for d in range(1, depth)]) if depth > 1 else 1
ell = self.depth_map[depth]
left_bit_position = self.depth_map[depth] * bit_position
        # count the information bits in this projection
proj_size = np.prod([self.depth_map[d] for d in range(1, depth+1)])
proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size)
num_info = sum([int(x in self.info_positions) for x in proj])
        # use attention to enhance the chunk features
dec_chunks = [llrs[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)]
if num_info > 0 and depth in self.fnet_dict:
            # apply the attention layer
            concatenated = torch.cat(dec_chunks, 2)
            if self.args.use_attention:
                concatenated = self.attention_layers[str(depth)](concatenated, concatenated, concatenated)[0]
            # parallel decoding
            if getattr(self.args, 'parallel_decode', False):
Lu = self.fnet_dict[depth][bit_position](concatenated)
decoded_llrs[:, left_bit_position + np.array(self.info_positions)] = Lu.squeeze(1)
else:
                # sequential (successive) decoding
for current_position in range(ell):
if current_position > 0:
prev_decoded = partial_sums[:, depth-1, (current_position-1)*half_index:current_position*half_index].unsqueeze(2)
dec_chunks.append(prev_decoded)
if current_position in [p-left_bit_position for p in self.info_positions if left_bit_position <= p < left_bit_position+proj_size]:
concatenated = torch.cat(dec_chunks, 2)
Lu = self.fnet_dict[depth][bit_position][current_position](concatenated)
decoded_llrs[:, left_bit_position + current_position] = Lu.squeeze(2).squeeze(1)
        # recurse into the child nodes
if depth > 1:
for current_position in range(ell):
bit_position_offset = left_bit_position + current_position
if bit_position_offset in self.info_positions:
decoded_llrs, partial_sums = self.enhanced_decode_depth(
llrs[:, current_position*half_index:(current_position+1)*half_index],
depth-1, bit_position_offset, decoded_llrs, partial_sums)
return decoded_llrs, partial_sums
######ADD Down
def deeppolar_encode(self, msg_bits, binary = False):
u = torch.ones(msg_bits.shape[0], self.N, dtype=torch.float).to(self.device)
u[:, self.info_positions] = msg_bits
for d in range(1, self.n_ell+1):
# num_bits = self.ell**(d-1)
num_bits = np.prod([self.depth_map[dd] for dd in range(1, d)]) if d > 1 else 1
# proj_size = self.ell**(d)
proj_size = np.prod([self.depth_map[dd] for dd in range(1, d+1)])
ell = self.depth_map[d]
for bit_position, i in enumerate(np.arange(0, self.N, ell*num_bits)):
                # [u v] encoded to [(u xor v), v]
proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size)
get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj])
num_info_in_proj = get_num_info_proj(proj)
subproj_len = len(proj) // ell
subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)]
num_info_in_subproj = [get_num_info_proj(x) for x in subproj]
num_nonzero_subproj = sum([int(x != 0) for x in num_info_in_subproj])
if num_info_in_proj > 0:
info_bits_present = True
else:
info_bits_present = False
if d in self.args.polar_depths:
info_bits_present = False
enc_chunks = []
ell = self.depth_map[d]
for j in range(ell):
chunk = u[:, i + j*num_bits:i + (j+1)*num_bits].unsqueeze(2).clone()
enc_chunks.append(chunk)
if info_bits_present:
concatenated_chunks = torch.cat(enc_chunks, 2)
if self.shared:
output = torch.cat([self.gnet_dict[d](concatenated_chunks), u[:, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2)
else:
output = torch.cat([self.gnet_dict[d][bit_position](concatenated_chunks), u[:, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2)
output = output.permute(0,2,1).reshape(msg_bits.shape[0], -1, 1).squeeze(2)
else:
output = self.encode_chunks_plotkin(enc_chunks, ell)
u = torch.cat((u[:, :i], output, u[:, i + ell*num_bits:]), dim=1)
power_constrained_u = self.power_constraint(u)
if binary:
stequantize = STEQuantize.apply
power_constrained_u = stequantize(power_constrained_u)
return power_constrained_u
def power_constraint(self, codewords):
return F.normalize(codewords, p=2, dim=1)*np.sqrt(self.N)
def encode_chunks_plotkin(self, enc_chunks, ell = None):
# message shape is (batch, k)
# BPSK convention : 0 -> +1, 1 -> -1
# Therefore, xor(a, b) = a*b
# to change for other kernels
if ell is None:
ell = self.ell
assert len(enc_chunks) == ell
chunk_size = enc_chunks[0].shape[1]
batch_size = enc_chunks[0].shape[0]
u = torch.cat(enc_chunks, 1).squeeze(2)
n = int(np.log2(ell))
for d in range(0, n):
num_bits = 2**d * chunk_size
for i in np.arange(0, chunk_size*ell, 2*num_bits):
                # [u v] encoded to [(u xor v), v]
u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1)
return u
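    # Illustrative check (an added sketch, not part of the original training flow): in the
    # BPSK convention used above, 0 -> +1 and 1 -> -1, so xor(a, b) corresponds to the
    # elementwise product a*b. For an ell = 2 kernel, [u, v] is therefore encoded to [u*v, v].
    @staticmethod
    def _plotkin_bpsk_example():
        u = torch.tensor([[+1., -1.]])  # bits (0, 1) in BPSK form
        v = torch.tensor([[-1., -1.]])  # bits (1, 1) in BPSK form
        encoded = torch.cat([u * v, v], dim=1)  # [(u xor v), v] under the product rule
        return encoded  # expected: [[-1., +1., -1., -1.]]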
def deeppolar_parallel_decode(self, noisy_code):
# Successive cancellation decoder for polar codes
assert noisy_code.shape[1] == self.N
depth = self.n_ell
decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device = noisy_code.device)
# function is recursively called (DFS)
# arguments: Beliefs at the input of node (LLRs at top node), depth of children, bit_position (zero at top node)
        decoded_llrs = self.deeppolar_parallel_decode_depth(noisy_code.unsqueeze(2), depth, 0, decoded_llrs)
decoded_llrs = decoded_llrs[:, self.info_positions]
return decoded_llrs, torch.sign(decoded_llrs)
def deeppolar_parallel_decode_depth(self, llrs, depth, bit_position, decoded_llrs):
# Function to call recursively, for SC decoder
# half_index = self.ell ** (depth - 1)
half_index = np.prod([self.depth_map[d] for d in range(1, depth)]) if depth > 1 else 1
ell = self.depth_map[depth]
left_bit_position = self.depth_map[depth] * bit_position
# Check if >1 information bits are present in the current projection. If not, don't use NNs - use polar encoding and minsum SC decoding.
# proj_size = self.ell**(depth)
proj_size = np.prod([self.depth_map[d] for d in range(1, depth+1)])
proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size)
get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj])
get_info_proj = lambda proj : [x for x in proj if x in self.info_positions]
num_info_in_proj = get_num_info_proj(proj)
info_in_proj = get_info_proj(proj)
subproj_len = len(proj) // ell
subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)]
num_info_in_subproj = [get_num_info_proj(x) for x in subproj]
num_nonzero_subproj = sum([int(x != 0) for x in num_info_in_subproj])
unfrozen = np.array([i for i, x in enumerate(num_info_in_subproj) if x >= 1])
dec_chunks = torch.cat([llrs[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)], 2)
Lu = self.fnet_dict[depth][bit_position](dec_chunks)
if depth == 1:
u = torch.tanh(Lu/2)
decoded_llrs[:, left_bit_position + unfrozen] = Lu.squeeze(1)
else:
for index, current_position in enumerate(unfrozen):
bit_position_offset = left_bit_position + current_position
decoded_llrs = self.deeppolar_parallel_decode_depth(Lu[:, :, index:index+1], depth-1, bit_position_offset, decoded_llrs)
return decoded_llrs
def deeppolar_decode(self, noisy_code):
assert noisy_code.shape[1] == self.N
depth = self.n_ell
decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device = noisy_code.device)
# don't want to go into useless frozen subtrees.
partial_sums = torch.ones(noisy_code.shape[0], self.n_ell+1, self.N, device=noisy_code.device)
# function is recursively called (DFS)
# arguments: Beliefs at the input of node (LLRs at top node), depth of children, bit_position (zero at top node)
decoded_llrs, partial_sums = self.deeppolar_decode_depth(noisy_code.unsqueeze(2), depth, 0, decoded_llrs, partial_sums)
decoded_llrs = decoded_llrs[:, self.info_positions]
return decoded_llrs, torch.sign(decoded_llrs)
def deeppolar_decode_depth(self, llrs, depth, bit_position, decoded_llrs, partial_sums):
# Function to call recursively, for SC decoder
# half_index = self.ell ** (depth - 1)
half_index = np.prod([self.depth_map[d] for d in range(1, depth)]) if depth > 1 else 1
ell = self.depth_map[depth]
left_bit_position = self.depth_map[depth] * bit_position
# Check if >1 information bits are present in the current projection. If not, don't use NNs - use polar encoding and minsum SC decoding.
# proj_size = self.ell**(depth)
        # size of the projection of the subtree
proj_size = np.prod([self.depth_map[d] for d in range(1, depth+1)])
        # This chunk finds the unfrozen positions in this kernel.
proj = np.arange(bit_position*proj_size, (bit_position+1)*proj_size)
get_num_info_proj = lambda proj : sum([int(x in self.info_positions) for x in proj])
get_info_proj = lambda proj : [x for x in proj if x in self.info_positions]
num_info_in_proj = get_num_info_proj(proj)
info_in_proj = get_info_proj(proj)
subproj_len = len(proj) // ell
subproj = [proj[i:i+subproj_len] for i in range(0, len(proj), subproj_len)]
num_info_in_subproj = [get_num_info_proj(x) for x in subproj]
num_nonzero_subproj = sum([int(x != 0) for x in num_info_in_subproj])
unfrozen = np.array([i for i, x in enumerate(num_info_in_subproj) if x >= 1])
if num_nonzero_subproj > 0:
info_bits_present = True
else:
info_bits_present = False
if depth in self.args.polar_depths:
info_bits_present = False
# This will be input to decoder
dec_chunks = [llrs[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)]
# n = 2 tree case
if depth == 1:
if self.args.decoder_type == 'KO_last_parallel':
concatenated_chunks = torch.cat(dec_chunks, 2)
Lu = self.fnet_dict[depth][bit_position](concatenated_chunks)[:, 0, unfrozen]
u_hat = torch.tanh(Lu/2)
decoded_llrs[:, left_bit_position + unfrozen] = Lu
partial_sums[:, depth-1, left_bit_position + unfrozen] = u_hat
else:
for current_position in range(ell):
bit_position_offset = left_bit_position + current_position
if current_position > 0:
# I am adding previously decoded bits . (either onehot or normal)
if self.args.onehot:
                            prev_decoded = self.get_onehot(partial_sums[:, depth-1, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).sign()).detach().clone()
else:
prev_decoded = partial_sums[:, depth-1, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).clone()
dec_chunks.append(prev_decoded)
if bit_position_offset in self.frozen_positions: # frozen
# don't update decoded llrs. It already has ones*prior.
# actually don't need this. can skip.
partial_sums[:, depth-1, bit_position_offset] = torch.ones_like(partial_sums[:, depth-1, bit_position_offset])
else: # information bit
# This is the decoding.
concatenated_chunks = torch.cat(dec_chunks, 2)
if self.shared:
Lu = self.fnet_dict[depth][current_position](concatenated_chunks)
else:
Lu = self.fnet_dict[depth][bit_position][current_position](concatenated_chunks)
u_hat = torch.tanh(Lu/2).squeeze(2)
decoded_llrs[:, bit_position_offset] = Lu.squeeze(2).squeeze(1)
partial_sums[:, depth-1, bit_position_offset] = u_hat.squeeze(1)
# Encoding back the decoded bits - for higher layers.
# # Compute decoded codeword
i = left_bit_position * half_index
# num_bits = self.ell**(depth-1)
num_bits = 1
enc_chunks = []
for j in range(ell):
chunk = torch.sign(partial_sums[:, depth-1, i + j*num_bits:i + (j+1)*num_bits]).unsqueeze(2).detach().clone()
enc_chunks.append(chunk)
if info_bits_present:
concatenated_chunks = torch.cat(enc_chunks, 2)
if 'KO' in self.args.encoder_type:
if self.shared:
output = torch.cat([self.gnet_dict[depth](concatenated_chunks), partial_sums[:, depth-1, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2)
else:
# bit position of the previous depth.
output = torch.cat([self.gnet_dict[depth][bit_position](concatenated_chunks), partial_sums[:, depth-1, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2)
output = output.permute(0,2,1).reshape(llrs.shape[0], -1, 1).squeeze(2)
else:
output = self.encode_chunks_plotkin(enc_chunks, ell)
else:
output = self.encode_chunks_plotkin(enc_chunks, ell)
partial_sums[:, depth, i : i + num_bits*ell] = output.clone()
return decoded_llrs, partial_sums
# General case
else:
for current_position in range(ell):
bit_position_offset = left_bit_position + current_position
if current_position > 0:
if self.args.onehot:
                        prev_decoded = self.get_onehot(partial_sums[:, depth-1, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).sign()).detach().clone()
else:
prev_decoded = partial_sums[:, depth-1, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).clone()
dec_chunks.append(prev_decoded)
concatenated_chunks = torch.cat(dec_chunks, 2)
if current_position in unfrozen:
# General decoding ....
# add the decoded bit here
if self.shared:
Lu = self.fnet_dict[depth][current_position](concatenated_chunks).squeeze(2)
else:
# if current_position == 0:
# Lu = self.fnet_dict[depth][bit_position][current_position](llrs)
# else:
Lu = self.fnet_dict[depth][bit_position][current_position](concatenated_chunks)
decoded_llrs, partial_sums = self.deeppolar_decode_depth(Lu, depth-1, bit_position_offset, decoded_llrs, partial_sums)
else:
Lu = self.infty*torch.ones_like(llrs)
# Compute decoded codeword
if depth < self.n_ell :
i = left_bit_position * half_index
# num_bits = self.ell**(depth-1)
num_bits = np.prod([self.depth_map[d] for d in range(1, depth)])
enc_chunks = []
for j in range(ell):
chunk = torch.sign(partial_sums[:, depth-1, i + j*num_bits:i + (j+1)*num_bits]).unsqueeze(2).detach().clone()
enc_chunks.append(chunk)
if info_bits_present:
concatenated_chunks = torch.cat(enc_chunks, 2)
if 'KO' in self.args.encoder_type:
if self.shared:
output = torch.cat([self.gnet_dict[depth](concatenated_chunks), partial_sums[:, depth-1, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2)
else:
# bit position of the previous depth.
output = torch.cat([self.gnet_dict[depth][bit_position](concatenated_chunks), partial_sums[:, depth-1, i + (ell-1)*num_bits:i + (ell)*num_bits].unsqueeze(2)], dim=2)
output = output.permute(0,2,1).reshape(llrs.shape[0], -1, 1).squeeze(2)
else:
output = self.encode_chunks_plotkin(enc_chunks, ell)
else:
output = self.encode_chunks_plotkin(enc_chunks, ell)
partial_sums[:, depth, i : i + num_bits*ell] = output.clone()
return decoded_llrs, partial_sums
else: # encoding not required for last level - we have already decoded all bits.
return decoded_llrs, partial_sums
def kernel_decode(self, ell, fnet_dict, noisy_code, info_positions = None):
input_shape = noisy_code.shape[-1]
noisy_code = noisy_code.unsqueeze(2)
assert input_shape == ell
u = torch.ones(noisy_code.shape[0], self.N, dtype=torch.float).to(self.device)
decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device = noisy_code.device)
half_index = 1
dec_chunks = [noisy_code[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)]
for current_position in range(ell):
if current_position > 0:
if self.args.onehot:
                    prev_decoded = self.get_onehot(u[:, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).clone().sign()).detach().clone()
else:
prev_decoded = u[:, (current_position -1)*half_index:(current_position)*half_index].unsqueeze(2).clone()
dec_chunks.append(prev_decoded)
            if current_position in info_positions:
concatenated_chunks = torch.cat(dec_chunks, 2)
Lu = fnet_dict[current_position](concatenated_chunks)
decoded_llrs[:, current_position] = Lu.squeeze(2).squeeze(1)
u_hat = torch.tanh(Lu/2).squeeze(2)
u[:, current_position] = u_hat.squeeze(1)
return decoded_llrs[:, info_positions], u[:, info_positions]
def kernel_parallel_decode(self, ell, fnet_dict, noisy_code, info_positions = None):
input_shape = noisy_code.shape[-1]
noisy_code = noisy_code.unsqueeze(2)
assert input_shape == ell
u = torch.ones(noisy_code.shape[0], self.N, dtype=torch.float).to(self.device)
decoded_llrs = self.infty*torch.ones(noisy_code.shape[0], self.N, device = noisy_code.device)
half_index = 1
dec_chunks = torch.cat([noisy_code[:, (j)*half_index:(j+1)*half_index].clone() for j in range(ell)], 2)
decoded_llrs = fnet_dict(dec_chunks).squeeze(1)
u = torch.tanh(decoded_llrs/2).squeeze(1)
return decoded_llrs[:, info_positions], u[:, info_positions]
# Training-process optimization suggestions:
# Curriculum learning: gradually increase the code complexity and the noise level
# Mixed-precision training: use torch.cuda.amp to speed up training
# Improved loss function: combine cross-entropy with a custom reliability loss
def enhanced_loss_function(pred_llrs, true_bits, snr):
    # binary cross-entropy loss on the predicted LLRs
    ce_loss = F.binary_cross_entropy_with_logits(pred_llrs, (true_bits+1)/2)
    # custom reliability loss: encourage confident (large-magnitude) LLRs
    reliability = torch.sigmoid(torch.abs(pred_llrs))
    rel_loss = -torch.mean(reliability)
    # SNR-adaptive weighting of the reliability term (snr is a plain float)
    snr_weight = 1.0 / (1.0 + np.exp(-0.1*(snr-10)))
    total_loss = ce_loss + snr_weight * rel_loss
    return total_loss
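# Minimal sketch of the mixed-precision training step suggested above, using torch.cuda.amp.
# The model/optimizer/batch names here are placeholder assumptions, not a copy of the
# original training loop in trainer.py.
def amp_train_step(model, optimizer, msg_bits, noisy_code, snr, scaler=None):
    # scaler should normally be a torch.cuda.amp.GradScaler created once outside the loop
    if scaler is None:
        scaler = torch.cuda.amp.GradScaler()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        pred_llrs, _ = model.enhanced_decode(noisy_code)  # assumes the enhanced decoder above
        loss = enhanced_loss_function(pred_llrs, msg_bits, snr)
    scaler.scale(loss).backward()  # scaled backward pass avoids fp16 underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item(), scaler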
# Further training-process optimizations:
# implement a curriculum-learning schedule
# add richer regularization terms
# refine the loss function
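# Minimal sketch of a curriculum schedule for the suggestion above: start training at an easy
# (high) SNR and anneal towards the target training SNR. The function name and the linear
# annealing rule are illustrative assumptions, not part of the original code.
def curriculum_train_snr(iteration, total_iters, start_snr=3.0, target_snr=-2.0):
    # fraction of training completed, clipped to [0, 1]
    frac = min(max(iteration / max(total_iters, 1), 0.0), 1.0)
    # linearly interpolate from the easy SNR to the harder target SNR
    return start_snr + frac * (target_snr - start_snr)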
##ADD down
\ No newline at end of file
deeppolar-main/figures/256_37_improved_bler.pdf
\ No newline at end of file
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import os
import time
import matplotlib
matplotlib.use('AGG')
import matplotlib.pyplot as plt
from deeppolar import DeepPolar
from polar import PolarCode, get_frozen
from trainer import train, deeppolar_full_test
from trainer_utils import save_model, plot_stuff
from collections import defaultdict
from itertools import combinations
from utils import snr_db2sigma, pairwise_distances
import random
import numpy as np
from tqdm import tqdm
import sys
import csv
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
from IPython import display
plt.ion()
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_int_list(input_string):
"""Converts a comma-separated string into a list of integers."""
try:
if len(input_string) == 0 :
return []
return [int(item) for item in input_string.split(',')]
except ValueError:
raise argparse.ArgumentTypeError(f"List must contain integers, got '{input_string}'")
# NN definition - define_kernel_nns for kernel, define_and_load_nns for general NN
# Encoding kernel_encode or deeppolar_encode
# Decoding kernel_decode, or deeppolar_decode
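# Minimal sketch of how those pieces fit together for the full (N > ell) code path. It mirrors
# what main.py/trainer.py do; the helper name and the AWGN noise line are illustrative
# assumptions, not an exact copy of the training loop.
def _deeppolar_roundtrip_example(polar, args, device):
    # define the networks and load pretrained kernels (general NN path)
    polar.define_and_load_nns(ell=args.kernel_size, kernel_load_path=args.kernel_load_path,
                              fnet=args.decoder_type, gnet=args.encoder_type,
                              shared=args.shared, dataparallel=args.dataparallel)
    # encode a random message with the neural encoder (BPSK: 0 -> +1, 1 -> -1)
    msg_bits = 1 - 2 * torch.randint(0, 2, (4, args.K), dtype=torch.float, device=device)
    codeword = polar.deeppolar_encode(msg_bits)
    # pass through an AWGN channel at snr_db = args.dec_train_snr
    sigma = snr_db2sigma(args.dec_train_snr)
    noisy_code = codeword + sigma * torch.randn_like(codeword)
    # decode with the neural successive-cancellation decoder
    llrs, decoded_bits = polar.deeppolar_decode(noisy_code)
    return decoded_bits  # compare against msg_bits to compute the BER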
def get_args():
parser = argparse.ArgumentParser(description='DeepPolar codes')
# General parameters
parser.add_argument('--id', type=str, default=None, help='ID: optional, to run multiple runs of same hyperparameters') #Will make a folder like init_932 , etc.
parser.add_argument('--test', dest = 'test', default=False, action='store_true', help='Testing?')
parser.add_argument('--pairwise', dest = 'pairwise', default=False, action='store_true', help='Plot codeword pairwise distances')
parser.add_argument('--epos', dest = 'epos', default=False, action='store_true', help='Plot error positions')
parser.add_argument('--only_args', dest = 'only_args', default=False, action='store_true', help='Helper to load functions on jupyter')
parser.add_argument('--gpu', type=int, default=-2, help='gpus used for training - e.g 0,1,3. -2 for cuda, -1 for cpu')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument('--anomaly', dest = 'anomaly', default=False, action='store_true', help='enable anomaly detection')
parser.add_argument("--dataparallel", type=str2bool, nargs='?',
const=True, default=False,
help="Use Dataparallel")
# Code parameters
parser.add_argument('--N', type=int, default=256, help = 'Block length')#, choices=[4, 8, 16, 32, 64, 128], help='Polar code parameter N')
parser.add_argument('--K', type=int, default=37, help = 'Message size')#, choices= [3, 4, 8, 16, 32, 64], help='Polar code parameter K')
parser.add_argument('--rate_profile', type=str, default='polar', choices=['RM', 'polar', 'sorted', 'last', 'rev_polar', 'custom'], help='Polar rate profiling')
# parser.add_argument('--target_K', type=int, default=None)#, choices= [3, 4, 8, 16, 32, 64], help='Polar code parameter K')
parser.add_argument('-ell', '--kernel_size', type=int, default=16, help = 'Kernel size')
parser.add_argument('--polar_depths', type=parse_int_list, default = '',help='A comma-separated list of integers')
parser.add_argument('--last_ell', type=int, default=None, help='use kernel last_ell last layer')
parser.add_argument('--infty', type=float, default=1000., help = 'Infinity value (used for frozen position LLR in polar dec)')
parser.add_argument('--lse', type=str, default='minsum', choices=['minsum', 'lse'], help='LSE function for polar SC decoder')
parser.add_argument('--hard_decision', dest = 'hard_decision', default=False, action='store_true', help='polar code sc decoding hard decision?')
# DeepPolar parameters
parser.add_argument('--encoder_type', type=str, default='KO', choices=['KO', 'scaled', 'polar'], help='Type of encoding')
    parser.add_argument('--decoder_type', type=str, default='KO', choices=['KO', 'SC', 'KO_parallel', 'KO_last_parallel'], help='Type of decoding')
parser.add_argument('--enc_activation', type=str, default='selu', choices=['selu', 'leaky_relu', 'gelu', 'silu', 'elu', 'mish', 'identity'], help='Activation function')
parser.add_argument('--dec_activation', type=str, default='selu', choices=['selu', 'leaky_relu', 'gelu', 'silu', 'elu', 'mish', 'identity'], help='Activation function')
parser.add_argument('--dropout_p', type=float, default=0.)
parser.add_argument('--dec_hidden_size', type=int, default=128, help='neural network size')
parser.add_argument('--enc_hidden_size', type=int, default=64, help='neural network size')
parser.add_argument('-fd', '--f_depth', type=int, default=3, help='decoder neural network depth')
parser.add_argument('-gd', '--g_depth', type=int, default=3, help='encoder neural network depth')
    parser.add_argument('-gsd', '--g_skip_depth', type=int, default=1, help='encoder skip-connection depth')
    parser.add_argument('-gsl', '--g_skip_layer', type=int, default=1, help='encoder skip-connection layer')
parser.add_argument("--onehot", type=str2bool, nargs='?',
const=True, default=False,
help="Use onehot representation of prev_decoded_bits?")
parser.add_argument("--shared", type=str2bool, nargs='?',
const=True, default=False,
help="Share weights across depth?")
parser.add_argument("--skip", type=str2bool, nargs='?',
const=True, default=True,
help="Use skip")
parser.add_argument("--use_norm", type=str2bool, nargs='?',
const=True, default=False,
help="Use norm")
parser.add_argument("--binary", type=str2bool, nargs='?',
const=True, default=False,
help="")
# Training parameters
parser.add_argument('-fi', '--full_iters', type=int, default=20000, help='full iterations')
parser.add_argument('-ei', '--enc_train_iters', type=int, default=20, help='encoder iterations') #50
parser.add_argument('-di', '--dec_train_iters', type=int, default=200, help='decoder iterations') #500
parser.add_argument('--enc_train_snr', type=float, default=0., help='snr at enc are trained')
parser.add_argument('--dec_train_snr', type=float, default=-2., help='snr at dec are trained')
parser.add_argument('--initialization', type=str, default='random', choices=['random', 'zeros'], help='initialization')
parser.add_argument('--optim', type=str, default='Adam', choices=['Adam', 'RMS', 'SGD', 'AdamW'], help='optimizer type')
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--loss', type=str, default='BCE', choices=['MSE', 'BCE', 'BCE_reg', 'L1', 'huber', 'focal', 'BCE_bler'], help='loss function')
parser.add_argument('--dec_lr', type=float, default=0.0003, help='Decoder Learning rate')
parser.add_argument('--enc_lr', type=float, default=0.0003, help='Encoder Learning rate')
parser.add_argument('--regularizer', type=str, default=None, choices=['std', 'max_deviation','polar'], help='regularize the kernel pretraining')
parser.add_argument('-rw', '--regularizer_weight', type=float, default=0.001)
    parser.add_argument('--scheduler', type=str, default=None, choices = ['reduce', '1cycle'], help='learning-rate scheduler')
    parser.add_argument('--scheduler_patience', type=int, default=None, help='scheduler patience')
parser.add_argument('--small_batch_size', type=int, default=20000, help='size of the batches')
parser.add_argument('--batch_size', type=int, default=20000, help='size of the batches')
parser.add_argument("--batch_schedule", type=str2bool, nargs='?',
const=True, default=False,
help="Batch scheduler")
parser.add_argument('--batch_patience', type=int, default=50, help='patience')
    parser.add_argument('--batch_factor', type=int, default=2, help='batch size increase factor')
    parser.add_argument('--min_batch_size', type=int, default=500, help='minimum batch size')
    parser.add_argument('--max_batch_size', type=int, default=50000, help='maximum batch size')
    parser.add_argument('--noise_type', type=str, default='awgn', choices=['fading', 'awgn', 'radar'], help='channel noise type')
    parser.add_argument('--radar_power', type=float, default=None, help='radar noise power')
    parser.add_argument('--radar_prob', type=float, default=0.1, help='radar noise probability')
# TESTING parameters
parser.add_argument('--model_save_per', type=int, default=100, help='num of episodes after which model is saved')
parser.add_argument('--test_snr_start', type=float, default=-5., help='testing snr start')
parser.add_argument('--test_snr_end', type=float, default=-1., help='testing snr end')
parser.add_argument('--snr_points', type=int, default=5, help='testing snr num points')
parser.add_argument('--test_batch_size', type=int, default=10000, help='size of the batches')
parser.add_argument('--num_errors', type=int, default=100, help='Test until _ block errors')
parser.add_argument('--model_iters', type=int, default=None, help='by default load final model, option to load a model of x episodes')
parser.add_argument('--test_load_path', type=str, default=None, help='load test model given path')
parser.add_argument('--save_path', type=str, default=None, help='save name')
parser.add_argument('--load_path', type=str, default=None, help='load name')
parser.add_argument('--kernel_load_path', type=str, default=None, help='load name')
parser.add_argument("--no_fig", type=str2bool, nargs='?',
const=True, default=False,
help="Plot fig?")
args = parser.parse_args()
if args.small_batch_size > args.batch_size:
args.small_batch_size = args.batch_size
return args
if __name__ == '__main__':
args = get_args()
if not args.test:
print(args)
if args.anomaly:
torch.autograd.set_detect_anomaly(True)
if torch.cuda.is_available() and args.gpu != -1:
if args.gpu == -2:
device = torch.device("cuda")
else:
device = torch.device("cuda:{0}".format(args.gpu))
else:
        if args.gpu != -1:
print(f"GPU device {args.gpu if args.gpu != -2 else ''} not found.")
device = torch.device("cpu")
if args.seed is not None:
torch.manual_seed(args.seed)
ID = str(np.random.randint(100000, 999999)) if args.id is None else args.id
if args.save_path is not None:
results_save_path = args.save_path
else:
if args.encoder_type == 'polar':
results_save_path = './Polar_Results/Polar({0},{1})/Scheme_{2}/KO_Decoder/{3}'.format(args.K, args.N, args.rate_profile, ID)
elif 'KO' in args.encoder_type:
if args.decoder_type == 'KO_last_parallel':
dec = '_lp'
else:
dec = ''
results_save_path = f"./Polar_Results/Polar_{args.kernel_size}({args.N},{args.K})/Scheme_{args.rate_profile}/{args.encoder_type}_Encoder{dec}_Decoder/{ID}"
elif args.encoder_type == 'scaled':
if args.decoder_type == 'SC':
results_save_path = './Polar_Results/Polar({0},{1})/Scheme_{2}/Scaled_Decoder/{3}'.format(args.K, args.N, args.rate_profile, ID)
else:
results_save_path = './Polar_Results/Polar({0},{1})/Scheme_{2}/KO_Scaled_Decoder/{3}'.format(args.K, args.N, args.rate_profile, ID)
############
## Polar Code parameters
############
K = args.K
N = args.N
###############
### Polar code
##############
### Encoder
if args.last_ell is not None:
depth_map = defaultdict(int)
n = int(np.log2(args.N // args.last_ell) // np.log2(args.kernel_size))
for d in range(1, n+1):
depth_map[d] = args.kernel_size
depth_map[n+1] = args.last_ell
assert np.prod(list(depth_map.values())) == args.N
polar = DeepPolar(args, device, args.N, args.K, infty = args.infty, depth_map = depth_map)
else:
polar = DeepPolar(args, device, args.N, args.K, args.kernel_size, args.infty)
info_inds = polar.info_positions
frozen_inds = polar.frozen_positions
print("Frozen positions : {}".format(frozen_inds))
if args.only_args:
print("Loaded args. Exiting")
sys.exit()
##############
### Neural networks
##############
ell = args.kernel_size
if args.N == ell: # Kernel pre-training
polar.define_kernel_nns(ell = args.kernel_size, unfrozen = polar.info_positions, fnet = args.decoder_type, gnet = args.encoder_type, shared = args.shared)
elif args.N > ell: # Initialize full network with pretrained kernels
polar.define_and_load_nns(ell = args.kernel_size, kernel_load_path=args.kernel_load_path, fnet = args.decoder_type, gnet = args.encoder_type, shared = args.shared, dataparallel=args.dataparallel)
if args.binary:
args.load_path = os.path.join(results_save_path, 'Models/fnet_gnet_final.pt')
assert os.path.exists(args.load_path), "Model does not exist!!"
results_save_path = os.path.join(results_save_path, 'Binary')
os.makedirs(results_save_path, exist_ok=True)
os.makedirs(results_save_path +'/Models', exist_ok=True)
if args.load_path is not None:
if args.test:
if args.test_load_path is None:
print("WARNING : have you used load_path instead of test_load_path?")
else:
checkpoint1 = torch.load(args.load_path , map_location=lambda storage, loc: storage)
fnet_dict = checkpoint1[0]
gnet_dict = checkpoint1[1]
polar.load_partial_nns(fnet_dict, gnet_dict)
print("Loaded nets from {}".format(args.load_path))
if 'KO' in args.decoder_type:
dec_params = []
for i in polar.fnet_dict.keys():
for j in polar.fnet_dict[i].keys():
if isinstance(polar.fnet_dict[i][j], dict):
for k in polar.fnet_dict[i][j].keys():
dec_params += list(polar.fnet_dict[i][j][k].parameters())
else:
dec_params += list(polar.fnet_dict[i][j].parameters())
elif args.decoder_type == 'RNN':
dec_params = polar.fnet_dict.parameters()
else:
args.dec_train_iters = 0
if 'KO' in args.encoder_type:
enc_params = []
if args.shared:
for i in polar.gnet_dict.keys():
enc_params += list(polar.gnet_dict[i].parameters())
else:
for i in polar.gnet_dict.keys():
for j in polar.gnet_dict[i].keys():
enc_params += list(polar.gnet_dict[i][j].parameters())
elif args.encoder_type == 'scaled':
enc_params = [polar.a]
enc_optimizer = optim.Adam(enc_params, lr = args.enc_lr)
else:
args.enc_train_iters = 0
if args.dec_train_iters > 0:
if args.optim == 'Adam':
dec_optimizer = optim.Adam(dec_params, lr = args.dec_lr, weight_decay = args.weight_decay)#, momentum=0.9, nesterov=True) #, amsgrad=True)
elif args.optim == 'SGD':
dec_optimizer = optim.SGD(dec_params, lr = args.dec_lr, weight_decay = args.weight_decay)#, momentum=0.9, nesterov=True) #, amsgrad=True)
elif args.optim == 'RMS':
dec_optimizer = optim.RMSprop(dec_params, lr = args.dec_lr, weight_decay = args.weight_decay)#, momentum=0.9, nesterov=True) #, amsgrad=True)
if args.scheduler == 'reduce':
dec_scheduler = optim.lr_scheduler.ReduceLROnPlateau(dec_optimizer, 'min', patience = args.scheduler_patience)
elif args.scheduler == '1cycle':
dec_scheduler = optim.lr_scheduler.OneCycleLR(dec_optimizer, max_lr = args.dec_lr, total_steps=args.dec_train_iters*args.full_iters)
else:
dec_scheduler = None
if args.enc_train_iters > 0:
enc_optimizer = optim.Adam(enc_params, lr = args.enc_lr)#, momentum=0.9, nesterov=True) #, amsgrad=True)
if args.scheduler == 'reduce':
enc_scheduler = optim.lr_scheduler.ReduceLROnPlateau(enc_optimizer, 'min', patience = args.scheduler_patience)
elif args.scheduler == '1cycle':
enc_scheduler = optim.lr_scheduler.OneCycleLR(enc_optimizer, max_lr = args.enc_lr, total_steps=args.enc_train_iters*args.full_iters)
else:
enc_scheduler = None
if 'BCE' in args.loss:
criterion = nn.BCEWithLogitsLoss()
elif args.loss == 'L1':
criterion = nn.L1Loss()
elif args.loss == 'huber':
criterion = nn.HuberLoss()
else:
criterion = nn.MSELoss()
info_positions = polar.info_positions
if not args.test:
bers_enc = []
losses_enc = []
bers_dec = []
losses_dec = []
train_ber_dec = 0.
train_ber_enc = 0.
loss_dec = 0.
loss_enc = 0.
# val_bers = []
os.makedirs(results_save_path, exist_ok=True)
os.makedirs(results_save_path +'/Models', exist_ok=True)
# Create CSV at the beginning of training
save_path_id = random.randint(100000, 999999)
with open(os.path.join(results_save_path, f'training_results_{save_path_id}.csv'), 'w', newline='') as csvfile:
csvwriter = csv.writer(csvfile)
csvwriter.writerow(['Step', 'Enc_Loss', 'Enc_BER', 'Dec_Loss', 'Dec_BER'])
# save args in a json file
##############
### Training
##############
print("Need to save for:", args.model_save_per)
if not args.batch_schedule:
batch_size = args.batch_size
else:
batch_size = args.min_batch_size
best_batch_ber = 10.
best_batch_iter = 0
try:
best_ber = 10.
for iter in range(1, args.full_iters + 1):
start_time = time.time()
if not args.batch_schedule:
batch_size = args.batch_size
elif batch_size != args.max_batch_size:
if iter - best_batch_iter > args.batch_patience:
batch_size = min(batch_size * 2, args.max_batch_size)
print(f"Increased batch size to {batch_size}")
best_batch_ber = train_ber_enc
best_batch_iter = iter
if 'KO' in args.decoder_type or args.decoder_type == 'RNN':
# Train decoder
loss_dec, train_ber_dec = train(args, polar, dec_optimizer, dec_scheduler if args.scheduler == '1cycle' else None, batch_size, args.dec_train_snr, args.dec_train_iters, criterion, device, info_positions, binary = args.binary, noise_type = args.noise_type)
if args.scheduler == 'reduce':
dec_scheduler.step(loss_dec)
bers_dec.append(train_ber_dec)
losses_dec.append(loss_dec)
if 'KO' in args.encoder_type:
# Train encoder
loss_enc, train_ber_enc = train(args, polar, enc_optimizer, enc_scheduler if args.scheduler == '1cycle' else None, batch_size, args.enc_train_snr, args.enc_train_iters, criterion, device, info_positions, binary = args.binary, noise_type = args.noise_type)
if args.scheduler == 'reduce':
enc_scheduler.step(loss_enc)
bers_enc.append(train_ber_enc)
losses_enc.append(loss_enc)
if args.batch_schedule and train_ber_enc < best_batch_ber:
best_batch_ber = train_ber_enc
best_batch_iter = iter
print(f'Best BER {best_batch_ber} at {best_batch_iter}')
# Save to CSV
with open(os.path.join(results_save_path, f'training_results_{save_path_id}.csv'), 'a', newline='') as csvfile:
csvwriter = csv.writer(csvfile)
csvwriter.writerow([iter, loss_enc, train_ber_enc, loss_dec, train_ber_dec])
if iter % 10 == 1:
print(f"[{iter}/{args.full_iters}] At {args.dec_train_snr} dB, Train Loss: {loss_dec} Train BER {train_ber_dec}, \
\n [{iter}/{args.full_iters}] At {args.enc_train_snr} dB, Train Loss: {loss_enc} Train BER {train_ber_enc}")
print("Time for one full iteration is {0:.4f} minutes. save ID = {1}".format((time.time() - start_time)/60, ID))
if iter % args.model_save_per == 0 or iter == 1:
if train_ber_enc < best_ber:
best_ber = train_ber_enc
best = True
else:
best = False
save_model(polar, iter, results_save_path, best = best)
plot_stuff(bers_enc, losses_enc, bers_dec, losses_dec, results_save_path)
save_model(polar, iter, results_save_path)
plot_stuff(bers_enc, losses_enc, bers_dec, losses_dec, results_save_path)
except KeyboardInterrupt:
save_model(polar, iter, results_save_path)
plot_stuff(bers_enc, losses_enc, bers_dec, losses_dec, results_save_path)
print("Exited and saved")
print("TESTING")
times = []
results_load_path = results_save_path
if args.model_iters is not None:
checkpoint1 = torch.load(results_save_path +'/Models/fnet_gnet_{}.pt'.format(args.model_iters), map_location=lambda storage, loc: storage)
elif args.test_load_path is not None:
checkpoint1 = torch.load(args.test_load_path , map_location=lambda storage, loc: storage)
else:
checkpoint1 = torch.load(results_load_path +'/Models/fnet_gnet_final.pt', map_location=lambda storage, loc: storage)
fnet_dict = checkpoint1[0]
gnet_dict = checkpoint1[1]
polar.load_nns(fnet_dict, gnet_dict, shared = args.shared)
if args.snr_points == 1 and args.test_snr_start == args.test_snr_end:
snr_range = [args.test_snr_start]
else:
snrs_interval = (args.test_snr_end - args.test_snr_start)* 1.0 / (args.snr_points-1)
snr_range = [snrs_interval* item + args.test_snr_start for item in range(args.snr_points)]
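# Example (assumed values): test_snr_start=-2, test_snr_end=4, snr_points=7
# gives snr_range = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0] (evenly spaced,
# endpoints included).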
start_time = time.time()
# For polar code testing.
args2 = argparse.Namespace(**vars(args))
args2.ell = 2
Frozen = get_frozen(N, K, args2.rate_profile)
Frozen.sort()
polar_l_2 = PolarCode(int(np.log2(N)), args.K, Fr=Frozen, infty = args.infty, hard_decision=args.hard_decision)
if args.pairwise:
codebook_size = 1000
all_msg_bits = 2 * (torch.rand(codebook_size, args.K, device = device) < 0.5).float() - 1
deeppolar_codebook = polar.deeppolar_encode(all_msg_bits)
polar_codebook = polar_l_2.encode_plotkin(all_msg_bits)
gaussian_codebook = F.normalize(torch.randn(codebook_size, args.N), p=2, dim=1)*np.sqrt(args.N)
from scipy import stats
w_statistic_deeppolar, p_value_deeppolar = stats.shapiro(deeppolar_codebook.detach().cpu().numpy())
w_statistic_gaussian, p_value_gaussian = stats.shapiro(gaussian_codebook.detach().cpu().numpy())
w_statistic_polar, p_value_polar = stats.shapiro(polar_codebook.detach().cpu().numpy())
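# The Shapiro-Wilk statistics above measure how Gaussian the codeword entries are:
# W close to 1 (equivalently a large p-value) indicates a Gaussian-like codebook.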
print(f"Deeppolar Shapiro test W = {w_statistic_deeppolar}, p-value = {p_value_deeppolar}")
print(f"Gaussian Shapiro test W = {w_statistic_gaussian}, p-value = {p_value_gaussian}")
print(f"Polar Shapiro test W = {w_statistic_polar}, p-value = {p_value_polar}")
dists_deeppolar, md_deeppolar = pairwise_distances(deeppolar_codebook)
dists_polar, md_polar = pairwise_distances(polar_codebook)
dists_gaussian, md_gaussian = pairwise_distances(gaussian_codebook)
# Function to calculate and plot PDF
def plot_pdf(data, label, bins=30, alpha=0.5):
counts, bin_edges = np.histogram(data, bins=bins, density=True)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
plt.plot(bin_centers, counts, label=label, alpha=alpha)
# Plotting PDF for each list
plt.figure()
plot_pdf(dists_deeppolar, 'Neural', 300)
# plot_pdf(dists_polar, 'Polar', 300)
plot_pdf(dists_gaussian, 'Gaussian', 300)
# Adding labels and title
plt.xlabel('Value')
plt.ylabel('Probability Density')
plt.title(f'Pairwise Distances - N = {args.N}, K = {args.K}')
plt.legend()
# Show the plot
plt.savefig(os.path.join(results_save_path, f"hists_N{args.N}_K{args.K}_{args.id}_2.pdf"))
if args.epos:
from collections import OrderedDict, Counter
def get_epos(k1, k2):
# return a counter of the info-bit locations of first errors
bb = torch.ne(k1.cpu().sign(), k2.cpu().sign())
# inds = torch.nonzero(bb)[:, 1].numpy()
idx = []
for ii in range(bb.shape[0]):
try:
iii = list(bb.cpu().float().numpy()[ii]).index(1)
idx.append(iii)
except ValueError:
    # no bit error in this block
    pass
# return a Counter so the per-batch tallies can be accumulated with '+'
return Counter(idx)
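# A result such as Counter({0: 12, 3: 5}) (illustrative numbers) means that in 12
# blocks the first bit error was at info position 0 and in 5 blocks at position 3.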
with torch.no_grad():
for (k, msg_bits) in enumerate(Test_Data_Generator):
msg_bits = msg_bits.to(device)
polar_code = polar_l_2.encode_plotkin(msg_bits)
noisy_code = polar.channel(polar_code, args.dec_train_snr)
noise = noisy_code - polar_code
deeppolar_code = polar.deeppolar_encode(msg_bits)
noisy_deeppolar_code = deeppolar_code + noise
SC_llrs, decoded_SC_msg_bits = polar_l_2.sc_decode_new(noisy_code, args.dec_train_snr)
deeppolar_llrs, decoded_deeppolar_msg_bits = polar.deeppolar_decode(noisy_deeppolar_code)
if k == 0:
epos_deeppolar = get_epos(msg_bits, decoded_deeppolar_msg_bits.sign())
epos_SC = get_epos(msg_bits, decoded_SC_msg_bits.sign())
else:
epos_deeppolar1 = get_epos(msg_bits, decoded_deeppolar_msg_bits.sign())
epos_SC1 = get_epos(msg_bits, decoded_SC_msg_bits.sign())
epos_deeppolar = epos_deeppolar + epos_deeppolar1
epos_SC = epos_SC + epos_SC1
print(f"epos_deeppolar: {epos_deeppolar}")
print(f"EPOS_SC: {epos_SC}")
start = time.time()
bers_SC_test, blers_SC_test, bers_deeppolar_test, blers_deeppolar_test = deeppolar_full_test(args, polar_l_2, polar, snr_range, device, info_positions, binary = args.binary, noise_type = args.noise_type, num_errors = args.num_errors)
print("Test SNRs : {}\n".format(snr_range))
print(f"Test Sigmas : {[snr_db2sigma(s) for s in snr_range]}\n")
print("BERs of DeepPolar: {0}".format(bers_deeppolar_test))
print("BERs of SC decoding: {0}".format(bers_SC_test))
print("BLERs of DeepPolar: {0}".format(blers_deeppolar_test))
print("BLERs of SC decoding: {0}".format(blers_SC_test))
print(f"time = {(time.time() - start)/60} minutes")
## BER
plt.figure(figsize = (12,8))
ok = 0
plt.semilogy(snr_range, bers_deeppolar_test, label="DeepPolar", marker='*', linewidth=1.5)
plt.semilogy(snr_range, bers_SC_test, label="SC decoder", marker='^', linewidth=1.5)
## BLER
plt.semilogy(snr_range, blers_deeppolar_test, label="DeepPolar (BLER)", marker='*', linewidth=1.5, linestyle='dashed')
plt.semilogy(snr_range, blers_SC_test, label="SC decoder (BLER)", marker='^', linewidth=1.5, linestyle='dashed')
plt.grid()
plt.xlabel("SNR (dB)", fontsize=16)
plt.ylabel("Error Rate", fontsize=16)
if args.enc_train_iters > 0:
plt.title("PolarC({2}, {3}): DeepPolar trained at Dec_SNR = {0} dB, Enc_SNR = {1}dB".format(args.dec_train_snr, args.enc_train_snr, args.K,args.N))
else:
plt.title("Polar({1}, {2}): DeepPolar trained at Dec_SNR = {0} dB".format(args.dec_train_snr, args.K,args.N))
plt.legend(prop={'size': 15})
if args.test_load_path is not None:
os.makedirs('Polar_Results/figures', exist_ok=True)
fig_save_path = 'Polar_Results/figures/new_plot_DeepPolar.pdf'
else:
fig_save_path = results_load_path + f"/Step_{args.model_iters if args.model_iters is not None else 'final'}{'_binary' if args.binary else ''}.pdf"
if not args.no_fig:
plt.savefig(fig_save_path)
plt.close()
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from math import sqrt
# Add TT Linear layer implementation
class TTLinear(nn.Module):
def __init__(self, in_features, out_features, ranks, activation=None, bias=True):
super(TTLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.ranks = ranks
self.activation = activation
# Factorize input and output dimensions
self.in_shape, self.out_shape = self.factorize_dims(in_features, out_features)
# Create TT cores
self.cores = nn.ParameterList()
for i in range(len(self.in_shape)):
if i == 0:
input_rank = 1
else:
input_rank = ranks[i-1]
if i == len(self.in_shape)-1:
output_rank = 1
else:
output_rank = ranks[i]
core = nn.Parameter(torch.randn(
input_rank, self.in_shape[i], self.out_shape[i], output_rank
) * sqrt(2.0 / (input_rank * self.in_shape[i])))
self.cores.append(core)
if bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.register_parameter('bias', None)
def factorize_dims(self, in_dim, out_dim):
    # Simple factorization - split each dimension into its most balanced pair of
    # factors (e.g. 64 -> [8, 8], 128 -> [8, 16]); can be customized.
    def split(d):
        for i in range(int(sqrt(d)), 0, -1):
            if d % i == 0:
                return [i, d // i]
        return [1, d]
    in_shape = split(in_dim)
    out_shape = split(out_dim)
    return in_shape, out_shape
def forward(self, x):
    # Contract the TT cores into the full (in_features, out_features) weight.
    # Core k has shape (r_{k-1}, in_shape[k], out_shape[k], r_k) with r_0 = r_d = 1.
    weight = self.cores[0].squeeze(0)  # (i_1, o_1, r_1)
    for core in self.cores[1:]:
        # (I, O, r) x (r, i, o, r') -> (I*i, O*o, r')
        weight = torch.einsum('abr,rcds->acbds', weight, core)
        d_in, i_k, d_out, o_k, r_next = weight.shape
        weight = weight.reshape(d_in * i_k, d_out * o_k, r_next)
    weight = weight.squeeze(-1)  # (in_features, out_features)
    res = x.reshape(-1, self.in_features) @ weight
    res = res.reshape(*x.shape[:-1], self.out_features)
    if self.bias is not None:
        res = res + self.bias
    if self.activation is not None:
        res = self.activation(res)
    return res
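# Illustrative usage of TTLinear (hypothetical sizes, kept as a comment so the
# module has no import-time side effects):
#   tt = TTLinear(in_features=64, out_features=128, ranks=[8, 8])
#   out = tt(torch.randn(32, 64))   # out.shape == (32, 128)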
# Activation function lookup, used by both the dense and TT layers
def get_activation_fn(activation):
if activation == 'tanh':
return F.tanh
elif activation == 'elu':
return F.elu
elif activation == 'relu':
return F.relu
elif activation == 'selu':
return F.selu
elif activation == 'sigmoid':
return F.sigmoid
elif activation == 'gelu':
return F.gelu
elif activation == 'silu':
return F.silu
elif activation == 'mish':
return F.mish
elif activation == 'linear':
return nn.Identity()
else:
raise NotImplementedError(f'Activation function {activation} not implemented')
# Modify g_Full to use TTLinear
class g_Full(nn.Module):
def __init__(self, input_size, hidden_size, output_size, depth=3, skip_depth=1,
skip_layer=1, ell=2, activation='selu', use_skip=False, augment=False,
use_tt=False, tt_ranks=None):
super(g_Full, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.depth = depth
self.ell = ell
self.ell_input_size = input_size//self.ell
self.augment = augment
self.activation_fn = get_activation_fn(activation)
self.skip_depth = skip_depth
self.skip_layer = skip_layer
self.use_skip = use_skip
self.use_tt = use_tt
self.tt_ranks = tt_ranks if tt_ranks is not None else [8, 8] # Default ranks
if self.use_skip:
if self.use_tt:
self.skip = nn.ModuleList([TTLinear(self.input_size + self.output_size,
self.hidden_size,
ranks=self.tt_ranks,
activation=self.activation_fn)])
self.skip.extend([TTLinear(self.hidden_size, self.hidden_size,
ranks=self.tt_ranks,
activation=self.activation_fn)
for ii in range(1, self.skip_depth)])
else:
self.skip = nn.ModuleList([nn.Linear(self.input_size + self.output_size,
self.hidden_size, bias=True)])
self.skip.extend([nn.Linear(self.hidden_size, self.hidden_size, bias=True)
for ii in range(1, self.skip_depth)])
if self.use_tt:
self.linears = nn.ModuleList([TTLinear(self.input_size, self.hidden_size,
ranks=self.tt_ranks,
activation=self.activation_fn)])
self.linears.extend([TTLinear(self.hidden_size, self.hidden_size,
ranks=self.tt_ranks,
activation=self.activation_fn)
for ii in range(1, self.depth)])
self.linears.append(TTLinear(self.hidden_size, self.output_size,
ranks=self.tt_ranks))
else:
self.linears = nn.ModuleList([nn.Linear(self.input_size, self.hidden_size, bias=True)])
self.linears.extend([nn.Linear(self.hidden_size, self.hidden_size, bias=True)
for ii in range(1, self.depth)])
self.linears.append(nn.Linear(self.hidden_size, self.output_size, bias=True))
@staticmethod
def get_augment(msg, ell):
u = msg.clone()
n = int(np.log2(ell))
for d in range(0, n):
num_bits = 2**d
for i in np.arange(0, ell, 2*num_bits):
# [u v] encoded to [u xor(u,v)]
if len(u.shape) == 2:
u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1)
elif len(u.shape) == 3:
u = torch.cat((u[:, :, :i], u[:, :, i:i+num_bits].clone() * u[:, :, i+num_bits: i+2*num_bits], u[:, :, i+num_bits:]), dim=2)
# u[:, i:i+num_bits] = u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits].clone
if len(u.shape) == 3:
return u[:, :, :-1]
elif len(u.shape) == 2:
return u[:, :-1]
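# Example (BPSK convention, ell=2): msg = [[+1., -1.]] is Plotkin-mapped to
# [[+1.*-1., -1.]] = [[-1., -1.]]; dropping the last position, get_augment
# returns [[-1.]].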
def forward(self, y):
x = y.clone()
for ii, layer in enumerate(self.linears):
if ii != self.depth:
x = self.activation_fn(layer(x))
if self.use_skip and ii == self.skip_layer:
if len(x.shape) == 3:
skip_input = torch.cat([y, g_Full.get_augment(y, self.ell)], dim = 2)
elif len(x.shape) == 2:
skip_input = torch.cat([y, g_Full.get_augment(y, self.ell)], dim = 1)
for jj, skip_layer in enumerate(self.skip):
skip_input = self.activation_fn(skip_layer(skip_input))
x = x + skip_input
else:
x = layer(x)
if self.augment:
x = x + g_Full.get_augment(y, self.ell)
return x
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.01)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(0.0, 0.01)
m.bias.data.fill_(0)
elif classname.find('Linear') != -1:
m.weight.data.normal_(0.0, 0.01)
try:
m.bias.data.fill_(0.)
except:
pass
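# Typical usage (illustrative sizes): initialize all Linear/Conv/BatchNorm layers
# of a freshly built network with small Gaussian weights:
#   net = g_Full(input_size=32, hidden_size=64, output_size=16)
#   net.apply(weights_init)   # nn.Module.apply visits every submodule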
# Modify f_Full to use TTLinear
class f_Full(nn.Module):
def __init__(self, input_size, hidden_size, output_size, dropout_p=0.,
activation='selu', depth=3, use_norm=False, use_tt=False, tt_ranks=None):
super(f_Full, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.depth = depth
self.use_norm = use_norm
self.activation_fn = get_activation_fn(activation)
self.use_tt = use_tt
self.tt_ranks = tt_ranks if tt_ranks is not None else [8, 8] # Default ranks
if self.use_tt:
self.linears = nn.ModuleList([TTLinear(self.input_size, self.hidden_size,
ranks=self.tt_ranks,
activation=self.activation_fn)])
for ii in range(1, self.depth):
self.linears.append(TTLinear(self.hidden_size, self.hidden_size,
ranks=self.tt_ranks,
activation=self.activation_fn))
self.linears.append(TTLinear(self.hidden_size, self.output_size,
ranks=self.tt_ranks))
else:
self.linears = nn.ModuleList([nn.Linear(self.input_size, self.hidden_size, bias=True)])
for ii in range(1, self.depth):
self.linears.append(nn.Linear(self.hidden_size, self.hidden_size, bias=True))
self.linears.append(nn.Linear(self.hidden_size, self.output_size, bias=True))
if self.use_norm:
self.norms = nn.ModuleList([nn.LayerNorm(self.hidden_size)
for _ in range(self.depth)])
def forward(self, y, aug = None):
x = y.clone()
for ii, layer in enumerate(self.linears):
if ii != self.depth:
x = layer(x)
if getattr(self, 'use_norm', False):
    x = self.norms[ii](x)
x = self.activation_fn(x)
else:
x = layer(x)
return x
def get_onehot(actions):
inds = (0.5 + 0.5*actions).long()
return torch.eye(2, device = inds.device)[inds].reshape(actions.shape[0], -1)
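# Example: actions = [[-1., +1.]] maps to indices [[0, 1]], so get_onehot returns
# [[1., 0., 0., 1.]]; each +/-1 entry becomes a length-2 one-hot, flattened per row.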
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib
matplotlib.use('AGG')
import matplotlib.pyplot as plt
import pickle
import os
import argparse
import sys
import time
from collections import namedtuple
from utils import snr_db2sigma, min_sum_log_sum_exp, log_sum_exp, errors_ber, errors_bler, corrupt_signal, countSetBits
def get_args():
parser = argparse.ArgumentParser(description='(N,K) Polar code')
parser.add_argument('--N', type=int, default=4, help='Polar code parameter N')
parser.add_argument('--K', type=int, default=3, help='Polar code parameter K')
parser.add_argument('--rate_profile', type=str, default='polar', choices=['RM', 'polar', 'sorted', 'sorted_last', 'rev_polar'], help='Polar rate profiling')
parser.add_argument('--hard_decision', dest = 'hard_decision', default=False, action='store_true')
parser.add_argument('--only_args', dest = 'only_args', default=False, action='store_true')
parser.add_argument('--list_size', type=int, default=1, help='SC List size')
parser.add_argument('--crc_len', type=int, default=0, choices=[0, 3, 8, 16], help='CRC length')
parser.add_argument('--batch_size', type=int, default=10000, help='size of the batches')
parser.add_argument('--test_ratio', type = float, default = 1, help = 'Number of test samples x batch_size')
parser.add_argument('--test_snr_start', type=float, default=-2., help='testing snr start')
parser.add_argument('--test_snr_end', type=float, default=4., help='testing snr end')
parser.add_argument('--snr_points', type=int, default=7, help='testing snr num points')
args = parser.parse_args()
return args
class PolarCode:
def __init__(self, n, K, Fr = None, rs = None, use_cuda = True, infty = 1000., hard_decision = False, lse = 'lse'):
assert n>=1
self.n = n
self.N = 2**n
self.K = K
self.G2 = np.array([[1,1],[0,1]])
self.G = np.array([1])
for i in range(n):
self.G = np.kron(self.G, self.G2)
self.G = torch.from_numpy(self.G).float()
self.device = torch.device("cuda" if use_cuda else "cpu")
self.infty = infty
self.hard_decision = hard_decision
self.lse = lse
if Fr is not None:
assert len(Fr) == self.N - self.K
self.frozen_positions = Fr
self.unsorted_frozen_positions = self.frozen_positions
self.frozen_positions.sort()
self.info_positions = np.array(list(set(self.frozen_positions) ^ set(np.arange(self.N))))
self.unsorted_info_positions = self.info_positions
self.info_positions.sort()
else:
if rs is None:
# most reliable positions first (i.e. decreasing order of reliability)
self.reliability_seq = np.arange(1023, -1, -1)
self.rs = self.reliability_seq[self.reliability_seq<self.N]
else:
self.reliability_seq = rs
self.rs = self.reliability_seq[self.reliability_seq<self.N]
assert len(self.rs) == self.N
# best K bits
self.info_positions = self.rs[:self.K]
self.unsorted_info_positions = self.reliability_seq[self.reliability_seq<self.N][:self.K]
self.info_positions.sort()
self.unsorted_info_positions=np.flip(self.unsorted_info_positions)
# worst N-K bits
self.frozen_positions = self.rs[self.K:]
self.unsorted_frozen_positions = self.rs[self.K:]
self.frozen_positions.sort()
self.CRC_polynomials = {
3: torch.Tensor([1, 0, 1, 1]).int(),
8: torch.Tensor([1, 1, 1, 0, 1, 0, 1, 0, 1]).int(),
16: torch.Tensor([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]).int(),
}
def get_G(self, ell):
n = int(np.log2(ell))
G = np.array([1])
for i in range(n):
G = np.kron(G, self.G2)
return G
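# Example: get_G(4) = kron(G2, G2) =
#   [[1, 1, 1, 1],
#    [0, 1, 0, 1],
#    [0, 0, 1, 1],
#    [0, 0, 0, 1]]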
def encode_plotkin(self, message, scaling = None, custom_info_positions = None):
# message shape is (batch, k)
# BPSK convention : 0 -> +1, 1 -> -1
# Therefore, xor(a, b) = a*b
if custom_info_positions is not None:
info_positions = custom_info_positions
else:
info_positions = self.info_positions
u = torch.ones(message.shape[0], self.N, dtype=torch.float).to(message.device)
u[:, info_positions] = message
for d in range(0, self.n):
num_bits = 2**d
for i in np.arange(0, self.N, 2*num_bits):
# [u v] encoded to [u xor(u,v)]
u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1)
# u[:, i:i+num_bits] = u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits].clone
if scaling is not None:
u = (scaling * np.sqrt(self.N)*u)/torch.norm(scaling)
return u
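# Worked example (N=2, info position 1, frozen position 0): message [[-1.]]
# gives u = [[+1., -1.]], and the single Plotkin stage maps it to
# [[+1.*-1., -1.]] = [[-1., -1.]].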
def channel(self, code, snr, noise_type = 'awgn', vv =5.0, radar_power = 20.0, radar_prob = 5e-2):
if noise_type != "bsc":
sigma = snr_db2sigma(snr)
else:
sigma = snr
r = corrupt_signal(code, sigma, noise_type, vv, radar_power, radar_prob)
return r
def define_partial_arrays(self, llrs):
# Initialize arrays to store llrs and partial_sums useful to compute the partial successive cancellation process.
llr_array = torch.zeros(llrs.shape[0], self.n+1, self.N, device=llrs.device)
llr_array[:, self.n] = llrs
partial_sums = torch.zeros(llrs.shape[0], self.n+1, self.N, device=llrs.device)
return llr_array, partial_sums
def updateLLR(self, leaf_position, llrs, partial_llrs = None, prior = None):
#START
depth = self.n
decoded_bits = partial_llrs[:,0].clone()
if prior is None:
prior = torch.zeros(self.N) #priors
llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, partial_llrs, depth, 0, leaf_position, prior, decoded_bits)
return llrs, decoded_bits
def partial_decode(self, llrs, partial_llrs, depth, bit_position, leaf_position, prior, decoded_bits=None):
# Function to call recursively, for partial SC decoder.
# We are assuming that u_0, u_1, .... , u_{leaf_position -1} bits are known.
# Partial sums computes the sums got through Plotkin encoding operations of known bits, to avoid recomputation.
# this function is implemented for rate 1 (not accounting for frozen bits in polar SC decoding)
# print("DEPTH = {}, bit_position = {}".format(depth, bit_position))
half_index = 2 ** (depth - 1)
leaf_position_at_depth = leaf_position // 2**(depth-1) # will tell us whether left_child or right_child
# n = 2 tree case
if depth == 1:
# Left child
left_bit_position = 2*bit_position
if leaf_position_at_depth > left_bit_position:
u_hat = partial_llrs[:, depth-1, left_bit_position:left_bit_position+1]
elif leaf_position_at_depth == left_bit_position:
if self.lse == 'minsum':
Lu = min_sum_log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True)
elif self.lse == 'lse':
Lu = log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True)
# Lu = log_sum_avoid_zero_NaN(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]).sum(dim=1, keepdim=True)
#print(Lu.device, prior.device, torch.ones_like(Lu).device)
llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] = Lu + prior[left_bit_position]*torch.ones_like(Lu)
if self.hard_decision:
u_hat = torch.sign(Lu)
else:
u_hat = torch.tanh(Lu/2)
decoded_bits[:, left_bit_position] = u_hat.squeeze(1)
return llrs, partial_llrs, decoded_bits
# Right child
right_bit_position = 2*bit_position + 1
if leaf_position_at_depth > right_bit_position:
pass
elif leaf_position_at_depth == right_bit_position:
Lv = u_hat * llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index] + llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]
llrs[:, depth-1, right_bit_position*half_index:(right_bit_position+1)*half_index] = Lv + prior[right_bit_position] * torch.ones_like(Lv)
if self.hard_decision:
v_hat = torch.sign(Lv)
else:
v_hat = torch.tanh(Lv/2)
decoded_bits[:, right_bit_position] = v_hat.squeeze(1)
return llrs, partial_llrs, decoded_bits
# General case
else:
# LEFT CHILD
# Find likelihood of (u xor v) xor (v) = u
# Lu = log_sum_exp(torch.cat([llrs[:, :half_index].unsqueeze(2), llrs[:, half_index:].unsqueeze(2)], dim=2).permute(0, 2, 1))
left_bit_position = 2*bit_position
if leaf_position_at_depth > left_bit_position:
Lu = llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index]
u_hat = partial_llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index]
else:
if self.lse == 'minsum':
Lu = min_sum_log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index])
elif self.lse == 'lse':
# Lu = log_sum_avoid_zero_NaN(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index])
Lu = log_sum_exp(llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index], llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index])
llrs[:, depth-1, left_bit_position*half_index:(left_bit_position+1)*half_index] = Lu
llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, partial_llrs, depth-1, left_bit_position, leaf_position, prior, decoded_bits)
return llrs, partial_llrs, decoded_bits
# RIGHT CHILD
right_bit_position = 2*bit_position + 1
Lv = u_hat * llrs[:, depth, left_bit_position*half_index:(left_bit_position+1)*half_index] + llrs[:,depth, (left_bit_position+1)*half_index:(left_bit_position+2)*half_index]
llrs[:, depth-1, right_bit_position*half_index:(right_bit_position+1)*half_index] = Lv
llrs, partial_llrs, decoded_bits = self.partial_decode(llrs, partial_llrs, depth-1, right_bit_position, leaf_position, prior, decoded_bits)
return llrs, partial_llrs, decoded_bits
def updatePartialSums(self, leaf_position, decoded_bits, partial_llrs):
u = decoded_bits.clone()
u[:, leaf_position+1:] = 0
for d in range(0, self.n):
partial_llrs[:, d] = u
num_bits = 2**d
for i in np.arange(0, self.N, 2*num_bits):
# [u v] encoded to [u xor(u,v)]
u = torch.cat((u[:, :i], u[:, i:i+num_bits].clone() * u[:, i+num_bits: i+2*num_bits], u[:, i+num_bits:]), dim=1)
partial_llrs[:, self.n] = u
return partial_llrs
def sc_decode_new(self, corrupted_codewords, snr, use_gt = None, channel = 'awgn'):
assert channel in ['awgn', 'bsc']
if channel == 'awgn':
noise_sigma = snr_db2sigma(snr)
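# For BPSK over an AWGN channel with noise std sigma, the channel LLR of each
# received sample y is log p(y|x=+1)/p(y|x=-1) = 2*y/sigma^2.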
llrs = (2/noise_sigma**2)*corrupted_codewords
elif channel == 'bsc':
# snr refers to transition prob
p = (torch.ones(1)*(snr + 1e-9)).to(corrupted_codewords.device)
llrs = (torch.clip(torch.log((1 - p) / p), -10000, 10000) * (corrupted_codewords + 1) - torch.clip(torch.log(p / (1-p)), -10000, 10000) * (corrupted_codewords - 1))/2
# step-wise implementation using updateLLR and updatePartialSums
priors = torch.zeros(self.N)
priors[self.frozen_positions] = self.infty
u_hat = torch.zeros(corrupted_codewords.shape[0], self.N, device=corrupted_codewords.device)
llr_array, partial_llrs = self.define_partial_arrays(llrs)
for ii in range(self.N):
#start = time.time()
llr_array , decoded_bits = self.updateLLR(ii, llr_array.clone(), partial_llrs, priors)
#print('SC update : {}'.format(time.time() - start), corrupted_codewords.shape[0])
if use_gt is None:
u_hat[:, ii] = torch.sign(llr_array[:, 0, ii])
else:
u_hat[:, ii] = use_gt[:, ii]
#start = time.time()
partial_llrs = self.updatePartialSums(ii, u_hat, partial_llrs)
#print('SC partial: {}s, {}', time.time() - start, 'frozen' if ii in self.frozen_positions else 'info')
decoded_bits = u_hat[:, self.info_positions]
return llr_array[:, 0, :].clone(), decoded_bits
def get_CRC(self, message):
# need to optimize.
# inout message should be int
padded_bits = torch.cat([message, torch.zeros(self.CRC_len).int().to(message.device)])
while len(padded_bits[0:self.K_minus_CRC].nonzero()):
cur_shift = (padded_bits != 0).int().argmax(0)
padded_bits[cur_shift: cur_shift + self.CRC_len + 1] = padded_bits[cur_shift: cur_shift + self.CRC_len + 1] ^ self.CRC_polynomials[self.CRC_len].to(message.device)
return padded_bits[self.K_minus_CRC:]
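# Worked example (CRC_len=3, polynomial 1011): message bits [1, 0, 1] are padded
# to [1, 0, 1, 0, 0, 0]; XOR-ing 1011 at shift 0 leaves [0, 0, 0, 1, 0, 0], so the
# returned CRC is [1, 0, 0] (the remainder of x^5 + x^3 modulo x^3 + x + 1).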
def CRC_check(self, message):
# need to optimize.
# input message should be int
padded_bits = message
while len(padded_bits[0:self.K_minus_CRC].nonzero()):
cur_shift = (padded_bits != 0).int().argmax(0)
padded_bits[cur_shift: cur_shift + self.CRC_len + 1] ^= self.CRC_polynomials[self.CRC_len].to(message.device)
if padded_bits[self.K_minus_CRC:].sum()>0:
return 0
else:
return 1
def encode_with_crc(self, message, CRC_len):
self.CRC_len = CRC_len
self.K_minus_CRC = self.K - CRC_len
if CRC_len == 0:
return self.encode_plotkin(message)
else:
crcs = 1-2*torch.vstack([self.get_CRC((0.5+0.5*message[jj]).int()) for jj in range(message.shape[0])])
encoded = self.encode_plotkin(torch.cat([message, crcs], 1))
return encoded
def get_frozen(N, K, rate_profile, target_K = None):
n = int(np.log2(N))
if rate_profile == 'polar':
# computed for SNR = 0
if n == 5:
rs = np.array([31, 30, 29, 27, 23, 15, 28, 26, 25, 22, 21, 14, 19, 13, 11, 24, 7, 20, 18, 12, 17, 10, 9, 6, 5, 3, 16, 8, 4, 2, 1, 0])
elif n == 4:
rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 5, 3, 8, 4, 2, 1, 0])
# for RM :(
# rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 3, 5, 8, 4, 2, 1, 0])
elif n == 3:
rs = np.array([7, 6, 5, 3, 4, 2, 1, 0])
elif n == 2:
rs = np.array([3, 2, 1, 0])
elif n<9:
rs = np.array([256 ,255 ,252 ,254 ,248 ,224 ,240 ,192 ,128 ,253 ,244 ,251 ,250 ,239 ,238 ,247 ,246 ,223 ,222 ,232 ,216 ,236 ,220 ,188 ,208 ,184 ,191 ,190 ,176 ,127 ,126 ,124 ,120 ,249 ,245 ,243 ,242 ,160 ,231 ,230 ,237 ,235 ,234 ,112 ,228 ,221 ,219 ,218 ,212 ,215 ,214 ,189 ,187 ,96 ,186 ,207 ,206 ,183 ,182 ,204 ,180 ,200 ,64 ,175 ,174 ,172 ,125 ,123 ,122 ,119 ,159 ,118 ,158 ,168 ,241 ,116 ,111 ,233 ,156 ,110 ,229 ,227 ,217 ,108 ,213 ,152 ,226 ,95 ,211 ,94 ,205 ,185 ,104 ,210 ,203 ,181 ,92 ,144 ,202 ,179 ,199 ,173 ,178 ,63 ,198 ,121 ,171 ,88 ,62 ,117 ,170 ,196 ,157 ,167 ,60 ,115 ,155 ,109 ,166 ,80 ,114 ,154 ,107 ,56 ,225 ,151 ,164 ,106 ,93 ,150 ,209 ,103 ,91 ,143 ,201 ,102 ,48 ,148 ,177 ,90 ,142 ,197 ,87 ,100 ,61 ,169 ,195 ,140 ,86 ,59 ,32 ,165 ,194 ,113 ,79 ,58 ,153 ,84 ,136 ,55 ,163 ,78 ,105 ,149 ,162 ,54 ,76 ,101 ,47 ,147 ,89 ,52 ,141 ,99 ,46 ,146 ,72 ,85 ,139 ,98 ,31 ,44 ,193 ,138 ,57 ,83 ,30 ,135 ,77 ,40 ,82 ,134 ,161 ,28 ,53 ,75 ,132 ,24 ,51 ,74 ,45 ,145 ,71 ,50 ,16 ,97 ,70 ,43 ,137 ,68 ,42 ,29 ,39 ,81 ,27 ,133 ,38 ,26 ,36 ,131 ,23 ,73 ,22 ,130 ,49 ,15 ,20 ,69 ,14 ,12 ,67 ,41 ,8 ,66 ,37 ,25 ,35 ,34 ,21 ,129 ,19 ,13 ,18 ,11 ,10 ,7 ,65 ,6 ,4 ,33 ,17 ,9 ,5 ,3 ,2 ,1 ]) - 1
else:
rs = np.array([1023, 1022, 1021, 1019, 1015, 1007, 1020, 991, 1018, 1017, 1014,
1006, 895, 1013, 1011, 959, 1005, 990, 1003, 989, 767, 1016,
999, 1012, 987, 958, 983, 957, 1010, 1004, 955, 1009, 894,
975, 893, 1002, 951, 1001, 988, 511, 766, 998, 891, 943,
986, 997, 985, 887, 956, 765, 995, 927, 982, 981, 879,
954, 974, 763, 953, 979, 510, 1008, 759, 863, 950, 892,
1000, 973, 949, 509, 890, 971, 996, 942, 751, 984, 889,
507, 947, 831, 886, 967, 941, 764, 926, 980, 994, 939,
885, 993, 735, 878, 925, 503, 762, 883, 978, 935, 703,
495, 952, 877, 761, 972, 923, 977, 948, 758, 862, 875,
919, 970, 757, 861, 508, 969, 750, 946, 479, 888, 639,
871, 911, 830, 940, 859, 755, 966, 945, 749, 506, 884,
938, 965, 829, 734, 924, 855, 505, 747, 963, 937, 882,
934, 827, 733, 447, 992, 847, 876, 501, 921, 702, 494,
881, 760, 743, 933, 502, 918, 874, 922, 823, 731, 499,
860, 756, 931, 701, 873, 493, 727, 917, 870, 976, 815,
910, 383, 968, 478, 858, 754, 699, 491, 869, 944, 748,
638, 915, 477, 719, 909, 964, 255, 799, 504, 857, 854,
753, 828, 746, 695, 487, 907, 637, 867, 853, 475, 936,
962, 446, 732, 826, 745, 846, 500, 825, 903, 687, 932,
635, 471, 445, 742, 880, 498, 730, 851, 822, 382, 920,
845, 741, 443, 700, 729, 631, 492, 872, 961, 726, 821,
930, 497, 381, 843, 463, 916, 739, 671, 623, 490, 929,
439, 814, 819, 868, 752, 914, 698, 725, 839, 856, 476,
813, 718, 908, 486, 723, 866, 489, 607, 431, 697, 379,
811, 798, 913, 575, 717, 254, 694, 636, 474, 807, 715,
906, 797, 693, 865, 960, 852, 744, 634, 473, 795, 905,
485, 415, 483, 470, 444, 375, 850, 740, 686, 902, 824,
691, 253, 711, 633, 844, 685, 630, 901, 367, 791, 928,
728, 820, 849, 783, 670, 899, 738, 842, 683, 247, 469,
441, 442, 462, 251, 737, 438, 467, 351, 629, 841, 724,
679, 669, 496, 461, 818, 380, 437, 627, 622, 459, 378,
239, 488, 667, 838, 430, 484, 812, 621, 319, 817, 435,
377, 696, 722, 912, 606, 810, 864, 716, 837, 721, 714,
809, 796, 455, 472, 619, 835, 692, 663, 223, 414, 904,
427, 806, 482, 632, 713, 690, 848, 605, 373, 252, 794,
429, 710, 684, 615, 805, 900, 655, 468, 366, 603, 413,
574, 481, 371, 250, 793, 466, 423, 374, 689, 628, 440,
365, 709, 789, 803, 411, 573, 682, 249, 460, 790, 668,
599, 350, 707, 246, 681, 465, 571, 626, 436, 407, 782,
191, 127, 363, 620, 666, 458, 245, 349, 677, 434, 678,
591, 787, 399, 457, 359, 238, 625, 840, 567, 736, 665,
428, 376, 781, 898, 618, 675, 318, 454, 662, 243, 897,
347, 836, 816, 720, 433, 604, 617, 779, 808, 661, 834,
712, 804, 833, 559, 237, 453, 426, 222, 317, 775, 372,
343, 412, 235, 543, 614, 451, 425, 422, 613, 370, 221,
315, 480, 335, 659, 654, 364, 190, 369, 248, 653, 688,
231, 410, 602, 611, 802, 792, 421, 651, 601, 598, 708,
311, 219, 572, 597, 788, 570, 409, 590, 362, 801, 680,
464, 406, 419, 348, 647, 786, 215, 589, 706, 361, 676,
566, 189, 595, 244, 569, 303, 405, 358, 456, 346, 398,
565, 242, 126, 705, 780, 587, 624, 664, 236, 187, 357,
432, 785, 558, 674, 207, 403, 397, 452, 345, 563, 778,
241, 316, 342, 616, 660, 557, 125, 234, 183, 287, 355,
583, 673, 395, 424, 314, 220, 777, 341, 612, 658, 123,
175, 774, 555, 233, 334, 542, 450, 313, 391, 230, 652,
368, 218, 339, 600, 119, 333, 657, 610, 773, 541, 310,
420, 159, 229, 650, 551, 596, 609, 408, 217, 449, 188,
309, 214, 331, 111, 539, 360, 771, 649, 302, 418, 594,
896, 227, 404, 646, 186, 588, 832, 568, 213, 417, 301,
307, 356, 402, 800, 564, 327, 95, 206, 240, 535, 593,
645, 586, 344, 396, 185, 401, 211, 354, 299, 585, 286,
562, 643, 182, 205, 124, 232, 285, 295, 181, 556, 582,
527, 394, 340, 63, 203, 561, 353, 448, 122, 283, 393,
581, 554, 174, 390, 704, 312, 338, 228, 179, 784, 199,
553, 121, 173, 389, 540, 579, 332, 118, 672, 550, 337,
158, 279, 271, 416, 216, 308, 387, 538, 549, 226, 330,
776, 171, 212, 117, 110, 329, 656, 157, 772, 306, 326,
225, 167, 115, 537, 534, 184, 109, 300, 547, 305, 210,
155, 533, 325, 352, 608, 400, 298, 204, 94, 648, 284,
209, 151, 180, 107, 770, 297, 392, 323, 592, 202, 644,
93, 294, 178, 103, 143, 282, 62, 336, 201, 120, 172,
198, 769, 584, 91, 388, 293, 177, 526, 278, 281, 642,
525, 531, 61, 170, 116, 197, 87, 156, 277, 114, 560,
169, 59, 291, 580, 275, 523, 641, 270, 195, 552, 519,
166, 224, 578, 108, 269, 79, 154, 113, 548, 577, 536,
328, 55, 106, 165, 153, 150, 386, 208, 324, 546, 385,
267, 47, 92, 163, 296, 304, 105, 102, 149, 263, 532,
322, 292, 545, 90, 200, 31, 321, 530, 142, 176, 147,
101, 141, 196, 524, 529, 290, 89, 280, 60, 86, 99,
139, 168, 58, 522, 276, 85, 194, 289, 78, 135, 112,
521, 57, 83, 54, 518, 274, 268, 768, 164, 77, 152,
193, 53, 162, 104, 517, 273, 266, 75, 46, 148, 51,
640, 100, 45, 576, 161, 265, 262, 71, 146, 30, 140,
88, 515, 98, 43, 29, 261, 145, 138, 84, 259, 39,
97, 27, 56, 82, 137, 76, 384, 134, 23, 52, 133,
320, 15, 73, 50, 81, 131, 44, 70, 544, 192, 528,
288, 520, 160, 272, 74, 49, 516, 42, 69, 28, 144,
41, 67, 96, 514, 38, 264, 260, 136, 22, 25, 37,
80, 513, 26, 258, 35, 132, 21, 257, 72, 14, 48,
13, 19, 130, 68, 40, 11, 512, 66, 129, 7, 36,
24, 34, 256, 20, 65, 33, 12, 128, 18, 10, 17,
6, 9, 64, 5, 3, 32, 16, 8, 4, 2, 1,
0])
rs = rs[rs<N]
Fr = rs[K:].copy()
Fr.sort()
elif rate_profile == 'RM':
rmweight = np.array([countSetBits(i) for i in range(N)])
Fr = np.argsort(rmweight)[:-K]
Fr.sort()
elif rate_profile == 'sorted':
if n == 5:
rs = np.array([31, 30, 29, 27, 23, 15, 28, 26, 25, 22, 21, 14, 19, 13, 11, 24, 7, 20, 18, 12, 17, 10, 9, 6, 5, 3, 16, 8, 4, 2, 1, 0])
elif n == 4:
rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 5, 3, 8, 4, 2, 1, 0])
elif n == 3:
rs = np.array([7, 6, 5, 3, 4, 2, 1, 0])
elif n == 2:
rs = np.array([3, 2, 1, 0])
rs = np.array([256 ,255 ,252 ,254 ,248 ,224 ,240 ,192 ,128 ,253 ,244 ,251 ,250 ,239 ,238 ,247 ,246 ,223 ,222 ,232 ,216 ,236 ,220 ,188 ,208 ,184 ,191 ,190 ,176 ,127 ,126 ,124 ,120 ,249 ,245 ,243 ,242 ,160 ,231 ,230 ,237 ,235 ,234 ,112 ,228 ,221 ,219 ,218 ,212 ,215 ,214 ,189 ,187 ,96 ,186 ,207 ,206 ,183 ,182 ,204 ,180 ,200 ,64 ,175 ,174 ,172 ,125 ,123 ,122 ,119 ,159 ,118 ,158 ,168 ,241 ,116 ,111 ,233 ,156 ,110 ,229 ,227 ,217 ,108 ,213 ,152 ,226 ,95 ,211 ,94 ,205 ,185 ,104 ,210 ,203 ,181 ,92 ,144 ,202 ,179 ,199 ,173 ,178 ,63 ,198 ,121 ,171 ,88 ,62 ,117 ,170 ,196 ,157 ,167 ,60 ,115 ,155 ,109 ,166 ,80 ,114 ,154 ,107 ,56 ,225 ,151 ,164 ,106 ,93 ,150 ,209 ,103 ,91 ,143 ,201 ,102 ,48 ,148 ,177 ,90 ,142 ,197 ,87 ,100 ,61 ,169 ,195 ,140 ,86 ,59 ,32 ,165 ,194 ,113 ,79 ,58 ,153 ,84 ,136 ,55 ,163 ,78 ,105 ,149 ,162 ,54 ,76 ,101 ,47 ,147 ,89 ,52 ,141 ,99 ,46 ,146 ,72 ,85 ,139 ,98 ,31 ,44 ,193 ,138 ,57 ,83 ,30 ,135 ,77 ,40 ,82 ,134 ,161 ,28 ,53 ,75 ,132 ,24 ,51 ,74 ,45 ,145 ,71 ,50 ,16 ,97 ,70 ,43 ,137 ,68 ,42 ,29 ,39 ,81 ,27 ,133 ,38 ,26 ,36 ,131 ,23 ,73 ,22 ,130 ,49 ,15 ,20 ,69 ,14 ,12 ,67 ,41 ,8 ,66 ,37 ,25 ,35 ,34 ,21 ,129 ,19 ,13 ,18 ,11 ,10 ,7 ,65 ,6 ,4 ,33 ,17 ,9 ,5 ,3 ,2 ,1 ]) - 1
rs = rs[rs<N]
first_inds = rs[:K].copy()
first_inds.sort()
rs[:K] = first_inds
Fr = rs[K:].copy()
Fr.sort()
elif rate_profile == 'sorted_last':
if n == 5:
rs = np.array([31, 30, 29, 27, 23, 15, 28, 26, 25, 22, 21, 14, 19, 13, 11, 24, 7, 20, 18, 12, 17, 10, 9, 6, 5, 3, 16, 8, 4, 2, 1, 0])
elif n == 4:
rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 5, 3, 8, 4, 2, 1, 0])
elif n == 3:
rs = np.array([7, 6, 5, 3, 4, 2, 1, 0])
elif n == 2:
rs = np.array([3, 2, 1, 0])
rs = np.array([256 ,255 ,252 ,254 ,248 ,224 ,240 ,192 ,128 ,253 ,244 ,251 ,250 ,239 ,238 ,247 ,246 ,223 ,222 ,232 ,216 ,236 ,220 ,188 ,208 ,184 ,191 ,190 ,176 ,127 ,126 ,124 ,120 ,249 ,245 ,243 ,242 ,160 ,231 ,230 ,237 ,235 ,234 ,112 ,228 ,221 ,219 ,218 ,212 ,215 ,214 ,189 ,187 ,96 ,186 ,207 ,206 ,183 ,182 ,204 ,180 ,200 ,64 ,175 ,174 ,172 ,125 ,123 ,122 ,119 ,159 ,118 ,158 ,168 ,241 ,116 ,111 ,233 ,156 ,110 ,229 ,227 ,217 ,108 ,213 ,152 ,226 ,95 ,211 ,94 ,205 ,185 ,104 ,210 ,203 ,181 ,92 ,144 ,202 ,179 ,199 ,173 ,178 ,63 ,198 ,121 ,171 ,88 ,62 ,117 ,170 ,196 ,157 ,167 ,60 ,115 ,155 ,109 ,166 ,80 ,114 ,154 ,107 ,56 ,225 ,151 ,164 ,106 ,93 ,150 ,209 ,103 ,91 ,143 ,201 ,102 ,48 ,148 ,177 ,90 ,142 ,197 ,87 ,100 ,61 ,169 ,195 ,140 ,86 ,59 ,32 ,165 ,194 ,113 ,79 ,58 ,153 ,84 ,136 ,55 ,163 ,78 ,105 ,149 ,162 ,54 ,76 ,101 ,47 ,147 ,89 ,52 ,141 ,99 ,46 ,146 ,72 ,85 ,139 ,98 ,31 ,44 ,193 ,138 ,57 ,83 ,30 ,135 ,77 ,40 ,82 ,134 ,161 ,28 ,53 ,75 ,132 ,24 ,51 ,74 ,45 ,145 ,71 ,50 ,16 ,97 ,70 ,43 ,137 ,68 ,42 ,29 ,39 ,81 ,27 ,133 ,38 ,26 ,36 ,131 ,23 ,73 ,22 ,130 ,49 ,15 ,20 ,69 ,14 ,12 ,67 ,41 ,8 ,66 ,37 ,25 ,35 ,34 ,21 ,129 ,19 ,13 ,18 ,11 ,10 ,7 ,65 ,6 ,4 ,33 ,17 ,9 ,5 ,3 ,2 ,1 ]) - 1
rs = rs[rs<N]
first_inds = rs[:K].copy()
first_inds.sort()
rs[:K] = first_inds[::-1]
Fr = rs[K:].copy()
Fr.sort()
elif rate_profile == 'rev_polar':
if n == 5:
rs = np.array([31, 30, 29, 27, 23, 15, 28, 26, 25, 22, 21, 14, 19, 13, 11, 24, 7, 20, 18, 12, 17, 10, 9, 6, 5, 3, 16, 8, 4, 2, 1, 0])
elif n == 4:
rs = np.array([15, 14, 13, 11, 7, 12, 10, 9, 6, 5, 3, 8, 4, 2, 1, 0])
elif n == 3:
rs = np.array([7, 6, 5, 3, 4, 2, 1, 0])
elif n == 2:
rs = np.array([3, 2, 1, 0])
rs = np.array([256 ,255 ,252 ,254 ,248 ,224 ,240 ,192 ,128 ,253 ,244 ,251 ,250 ,239 ,238 ,247 ,246 ,223 ,222 ,232 ,216 ,236 ,220 ,188 ,208 ,184 ,191 ,190 ,176 ,127 ,126 ,124 ,120 ,249 ,245 ,243 ,242 ,160 ,231 ,230 ,237 ,235 ,234 ,112 ,228 ,221 ,219 ,218 ,212 ,215 ,214 ,189 ,187 ,96 ,186 ,207 ,206 ,183 ,182 ,204 ,180 ,200 ,64 ,175 ,174 ,172 ,125 ,123 ,122 ,119 ,159 ,118 ,158 ,168 ,241 ,116 ,111 ,233 ,156 ,110 ,229 ,227 ,217 ,108 ,213 ,152 ,226 ,95 ,211 ,94 ,205 ,185 ,104 ,210 ,203 ,181 ,92 ,144 ,202 ,179 ,199 ,173 ,178 ,63 ,198 ,121 ,171 ,88 ,62 ,117 ,170 ,196 ,157 ,167 ,60 ,115 ,155 ,109 ,166 ,80 ,114 ,154 ,107 ,56 ,225 ,151 ,164 ,106 ,93 ,150 ,209 ,103 ,91 ,143 ,201 ,102 ,48 ,148 ,177 ,90 ,142 ,197 ,87 ,100 ,61 ,169 ,195 ,140 ,86 ,59 ,32 ,165 ,194 ,113 ,79 ,58 ,153 ,84 ,136 ,55 ,163 ,78 ,105 ,149 ,162 ,54 ,76 ,101 ,47 ,147 ,89 ,52 ,141 ,99 ,46 ,146 ,72 ,85 ,139 ,98 ,31 ,44 ,193 ,138 ,57 ,83 ,30 ,135 ,77 ,40 ,82 ,134 ,161 ,28 ,53 ,75 ,132 ,24 ,51 ,74 ,45 ,145 ,71 ,50 ,16 ,97 ,70 ,43 ,137 ,68 ,42 ,29 ,39 ,81 ,27 ,133 ,38 ,26 ,36 ,131 ,23 ,73 ,22 ,130 ,49 ,15 ,20 ,69 ,14 ,12 ,67 ,41 ,8 ,66 ,37 ,25 ,35 ,34 ,21 ,129 ,19 ,13 ,18 ,11 ,10 ,7 ,65 ,6 ,4 ,33 ,17 ,9 ,5 ,3 ,2 ,1 ]) - 1
rs = rs[rs<N]
first_inds = rs[:target_K].copy()
rs[:target_K] = first_inds[::-1]
Fr = rs[K:].copy()
Fr.sort()
return Fr
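# Example: get_frozen(8, 4, 'polar') uses rs = [7, 6, 5, 3, 4, 2, 1, 0]; the K=4
# most reliable positions are {7, 6, 5, 3}, so the returned frozen positions are
# [0, 1, 2, 4].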
if __name__ == '__main__':
args = get_args()
n = int(np.log2(args.N))
Fr = get_frozen(args.N, args.K, args.rate_profile)
polar = PolarCode(n, args.K, Fr = Fr, hard_decision=True)
# Multiple SNRs:
if args.snr_points == 1 and args.test_snr_start == args.test_snr_end:
snr_range = [args.test_snr_start]
else:
snrs_interval = (args.test_snr_end - args.test_snr_start)* 1.0 / (args.snr_points-1)
snr_range = [snrs_interval* item + args.test_snr_start for item in range(args.snr_points)]
if args.only_args:
print("Loaded args. Exiting")
sys.exit()
bers_SC = [0. for ii in snr_range]
blers_SC = [0. for ii in snr_range]
for r in range(int(args.test_ratio)):
msg_bits = 1 - 2*(torch.rand(args.batch_size, args.K) > 0.5).float()
codes = polar.encode_plotkin(msg_bits)
for snr_ind, snr in enumerate(snr_range):
# codes_G = polar.encode_G(msg_bits_bpsk)
noisy_code = polar.channel(codes, snr)
noise = noisy_code - codes
SC_llrs, decoded_SC_msg_bits = polar.sc_decode_new(noisy_code, snr)
ber_SC = errors_ber(msg_bits, decoded_SC_msg_bits.sign()).item()
bler_SC = errors_bler(msg_bits, decoded_SC_msg_bits.sign()).item()
bers_SC[snr_ind] += ber_SC/args.test_ratio
blers_SC[snr_ind] += bler_SC/args.test_ratio
print("Test SNRs : ", snr_range)
print("BERs of SC: {0}".format(bers_SC))
print("BLERs of SC: {0}".format(blers_SC))
python -u main.py --N 16 --K 1 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 500 --batch_size 20000 --enc_train_snr -1 --dec_train_snr -3 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -7 --test_snr_end 0 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_1_normal_polar_eh64_dh128_selu_new --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 2 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 500 --batch_size 20000 --enc_train_snr 2 --dec_train_snr 0 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -4 --test_snr_end 3 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_2_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_1_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 3 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 3 --dec_train_snr 1 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -3 --test_snr_end 4 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_3_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_2_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 4 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 3 --dec_train_snr 2 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -3 --test_snr_end 4 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_4_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_3_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 5 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 4 --dec_train_snr 2.5 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -2 --test_snr_end 5 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_5_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_4_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 6 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 5 --dec_train_snr 4 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start -1 --test_snr_end 6 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_6_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_5_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 7 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 6 --dec_train_snr 4 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 0 --test_snr_end 7 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_7_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_6_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 8 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 6 --dec_train_snr 5 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 0 --test_snr_end 7 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_8_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_7_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 9 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 6.5 --dec_train_snr 5 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 1 --test_snr_end 8 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_9_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_8_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 10 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 7 --dec_train_snr 6 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 1 --test_snr_end 8 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_10_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_9_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 11 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 7 --dec_train_snr 6 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 1 --test_snr_end 8 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_11_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_10_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 12 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 8.5 --dec_train_snr 7 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 3 --test_snr_end 10 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_12_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_11_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 13 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 9 --dec_train_snr 8 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 3 --test_snr_end 10 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_13_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_12_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 14 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 9.5 --dec_train_snr 8 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 4 --test_snr_end 11 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_14_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_13_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 15 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 10 --dec_train_snr 9 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 4 --test_snr_end 11 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_15_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_14_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
python -u main.py --N 16 --K 16 --model_save_per 100 --enc_train_iters 20 --dec_train_iters 200 --full_iters 1000 --batch_size 20000 --enc_train_snr 12 --dec_train_snr 11 --enc_lr 0.0003 --dec_lr 0.0003 --num_errors 10 --test_snr_start 6 --test_snr_end 13 --snr_points 8 -ell 16 --encoder_type KO --enc_activation selu --dec_activation selu --dec_hidden_size 128 --enc_hidden_size 64 --save_path Polar_Results/curriculum/16_16_normal_polar_eh64_dh128_selu_new --load_path Polar_Results/curriculum/16_15_normal_polar_eh64_dh128_selu_new/Models/fnet_gnet_final.pt --gpu 0 --regularizer polar --regularizer_weight 0.05
bash copy_files.sh 16 normal_polar_eh64_dh128_selu_new
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile requirements.in
#
argparse==1.4.0
# via -r requirements.in
contourpy==1.2.1
# via matplotlib
cycler==0.12.1
# via matplotlib
filelock==3.14.0
# via
# torch
# triton
fonttools==4.51.0
# via matplotlib
fsspec==2024.3.1
# via torch
jinja2==3.1.4
# via torch
kiwisolver==1.4.5
# via matplotlib
markupsafe==2.1.5
# via jinja2
matplotlib==3.8.4
# via -r requirements.in
mpmath==1.3.0
# via sympy
networkx==3.3
# via torch
numpy==1.26.4
# via
# -r requirements.in
# contourpy
# matplotlib
nvidia-cublas-cu12==12.1.3.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
packaging==24.0
# via matplotlib
pillow==10.3.0
# via matplotlib
pyparsing==3.1.2
# via matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
six==1.16.0
# via python-dateutil
sympy==1.12
# via torch
torch==2.3.0
# via -r requirements.in
tqdm==4.66.4
# via -r requirements.in
triton==2.3.0
# via torch
typing-extensions==4.11.0
# via torch
import torch
import torch.nn.functional as F
import numpy as np
from utils import errors_ber, errors_bler, dec2bitarray, snr_db2sigma
import time
def train(args, polar, optimizer, scheduler, batch_size, train_snr, train_iters, criterion, device, info_positions, binary = False, noise_type = 'awgn'):
if args.N == polar.ell:
assert len(info_positions) == args.K
kernel = True
else:
kernel = False
for iter in range(train_iters):
if batch_size > args.small_batch_size:
small_batch_size = args.small_batch_size
else:
small_batch_size = batch_size
num_batches = batch_size // small_batch_size
for ii in range(num_batches):
msg_bits = 1 - 2*(torch.rand(small_batch_size, args.K) > 0.5).float().to(device)
if args.encoder_type == 'polar':
codes = polar.encode_plotkin(msg_bits)
elif 'KO' in args.encoder_type:
if kernel:
codes = polar.kernel_encode(args.kernel_size, polar.gnet_dict[1][0], msg_bits, info_positions, binary = binary)
else:
codes = polar.deeppolar_encode(msg_bits, binary = binary)
noisy_codes = polar.channel(codes, train_snr, noise_type)
if 'KO' in args.decoder_type:
if kernel:
if args.decoder_type == 'KO_parallel':
decoded_llrs, decoded_bits = polar.kernel_parallel_decode(args.kernel_size, polar.fnet_dict[1][0], noisy_codes, info_positions)
else:
decoded_llrs, decoded_bits = polar.kernel_decode(args.kernel_size, polar.fnet_dict[1][0], noisy_codes, info_positions)
else:
decoded_llrs, decoded_bits = polar.deeppolar_decode(noisy_codes)
elif args.decoder_type == 'SC':
decoded_llrs, decoded_bits = polar.sc_decode_new(noisy_codes, train_snr)
if 'BCE' in args.loss or args.loss == 'focal':
loss = criterion(decoded_llrs, 0.5 * msg_bits.to(polar.device) + 0.5)
else:
loss = criterion(torch.tanh(0.5*decoded_llrs), msg_bits.to(polar.device))
if args.regularizer == 'std':
if args.K == 1:
loss += args.regularizer_weight * torch.std(codes, dim=1).mean()
elif args.K == 2:
loss += args.regularizer_weight * (0.5*torch.std(codes[:, ::2], dim=1).mean() + .5*torch.std(codes[:, 1::2], dim=1).mean())
elif args.regularizer == 'max_deviation':
if args.K == 1:
loss += args.regularizer_weight * torch.amax(torch.abs(codes - codes.mean(dim=1, keepdim=True)), dim=1).mean()
elif args.K == 2:
loss += args.regularizer_weight * (0.5*torch.amax(torch.abs(codes[:, ::2] - codes[:, ::2].mean(dim=1, keepdim=True)), dim=1).mean() + .5*torch.amax(torch.abs(codes[:, 1::2] - codes[:, 1::2].mean(dim=1, keepdim=True)), dim=1).mean())
elif args.regularizer == 'polar':
loss += args.regularizer_weight * F.mse_loss(codes, polar.encode_plotkin(msg_bits))
loss = loss/num_batches
loss.backward()
optimizer.step()
if scheduler is not None:
scheduler.step()
optimizer.zero_grad()
train_ber = errors_ber(decoded_bits.sign(), msg_bits.to(polar.device)).item()
return loss.item(), train_ber
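# A minimal sketch of how train() is driven per outer iteration. The real driver is
# main.py (not part of this file), so the way the optimizers are built over
# polar.gnet_dict / polar.fnet_dict below is an assumption:
def _example_training_loop(args, polar, criterion, device, info_positions):
    enc_params = [p for net in polar.gnet_dict[1].values() for p in net.parameters()]
    dec_params = [p for net in polar.fnet_dict[1].values() for p in net.parameters()]
    enc_opt = torch.optim.Adam(enc_params, lr=args.enc_lr)
    dec_opt = torch.optim.Adam(dec_params, lr=args.dec_lr)
    for _ in range(args.full_iters):
        # Decoder phase: many iterations at the (lower) decoder training SNR.
        train(args, polar, dec_opt, None, args.batch_size, args.dec_train_snr,
              args.dec_train_iters, criterion, device, info_positions)
        # Encoder phase: fewer iterations at the (higher) encoder training SNR.
        train(args, polar, enc_opt, None, args.batch_size, args.enc_train_snr,
              args.enc_train_iters, criterion, device, info_positions)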
def deeppolar_full_test(args, polar, KO, snr_range, device, info_positions, binary=False, num_errors=100, noise_type = 'awgn'):
bers_KO_test = [0. for _ in snr_range]
blers_KO_test = [0. for _ in snr_range]
bers_SC_test = [0. for _ in snr_range]
blers_SC_test = [0. for _ in snr_range]
kernel = args.N == KO.ell
print(f"TESTING until {num_errors} block errors")
for snr_ind, snr in enumerate(snr_range):
total_block_errors_SC = 0
total_block_errors_KO = 0
batches_processed = 0
sigma = snr_db2sigma(snr) # Assuming SNR is given in dB and noise variance is derived from it
try:
while min(total_block_errors_SC, total_block_errors_KO) <= num_errors:
msg_bits = 2 * (torch.rand(args.test_batch_size, args.K) < 0.5).float() - 1
msg_bits = msg_bits.to(device)
polar_code = polar.encode_plotkin(msg_bits)
if 'KO' in args.encoder_type:
if kernel:
KO_polar_code = KO.kernel_encode(args.kernel_size, KO.gnet_dict[1][0], msg_bits, info_positions, binary=binary)
else:
KO_polar_code = KO.deeppolar_encode(msg_bits, binary=binary)
noisy_code = polar.channel(polar_code, snr, noise_type)
noise = noisy_code - polar_code
noisy_KO_code = KO_polar_code + noise if 'KO' in args.encoder_type else noisy_code
SC_llrs, decoded_SC_msg_bits = polar.sc_decode_new(noisy_code, snr)
ber_SC = errors_ber(msg_bits, decoded_SC_msg_bits.sign()).item()
bler_SC = errors_bler(msg_bits, decoded_SC_msg_bits.sign()).item()
total_block_errors_SC += int(bler_SC*args.test_batch_size)
if 'KO' in args.decoder_type:
if kernel:
if args.decoder_type == 'KO_parallel':
KO_llrs, decoded_KO_msg_bits = KO.kernel_parallel_decode(args.kernel_size, KO.fnet_dict[1][0], noisy_KO_code, info_positions)
else:
KO_llrs, decoded_KO_msg_bits = KO.kernel_decode(args.kernel_size, KO.fnet_dict[1][0], noisy_KO_code, info_positions)
else:
KO_llrs, decoded_KO_msg_bits = KO.deeppolar_decode(noisy_KO_code)
else: # if SC is also used for KO
KO_llrs, decoded_KO_msg_bits = KO.sc_decode_new(noisy_KO_code, snr)
ber_KO = errors_ber(msg_bits, decoded_KO_msg_bits.sign()).item()
bler_KO = errors_bler(msg_bits, decoded_KO_msg_bits.sign()).item()
total_block_errors_KO += int(bler_KO*args.test_batch_size)
batches_processed += 1
# Update accumulative results for logging
bers_KO_test[snr_ind] += ber_KO
bers_SC_test[snr_ind] += ber_SC
blers_KO_test[snr_ind] += bler_KO
blers_SC_test[snr_ind] += bler_SC
# Real-time logging for progress, updating in-place
print(f"SNR: {snr} dB, Sigma: {sigma:.5f}, SC_BER: {bers_SC_test[snr_ind]/batches_processed:.6f}, SC_BLER: {blers_SC_test[snr_ind]/batches_processed:.6f}, KO_BER: {bers_KO_test[snr_ind]/batches_processed:.6f}, KO_BLER: {blers_KO_test[snr_ind]/batches_processed:.6f}, Batches: {batches_processed}", end='\r')
except KeyboardInterrupt:
# print("\nInterrupted by user. Finalizing current SNR...")
pass
# Normalize cumulative metrics by the number of processed batches for accuracy
bers_KO_test[snr_ind] /= (batches_processed + 0.00000001)
bers_SC_test[snr_ind] /= (batches_processed + 0.00000001)
blers_KO_test[snr_ind] /= (batches_processed + 0.00000001)
blers_SC_test[snr_ind] /= (batches_processed + 0.00000001)
print(f"SNR: {snr} dB, Sigma: {sigma:.5f}, SC_BER: {bers_SC_test[snr_ind]:.6f}, SC_BLER: {blers_SC_test[snr_ind]:.6f}, KO_BER: {bers_KO_test[snr_ind]:.6f}, KO_BLER: {blers_KO_test[snr_ind]:.6f}")
return bers_SC_test, blers_SC_test, bers_KO_test, blers_KO_test
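# Sketch of a typical call site for the test routine above, using the SNR sweep flags
# from the training commands (--test_snr_start / --test_snr_end / --snr_points); the
# actual call lives in main.py, which is not reproduced here:
def _example_evaluation(args, polar, KO, device, info_positions):
    snr_range = np.linspace(args.test_snr_start, args.test_snr_end, args.snr_points)
    bers_SC, blers_SC, bers_KO, blers_KO = deeppolar_full_test(
        args, polar, KO, snr_range, device, info_positions, num_errors=args.num_errors)
    for snr, bler_sc, bler_ko in zip(snr_range, blers_SC, blers_KO):
        print(f"{snr:.2f} dB: SC BLER {bler_sc:.2e}, KO BLER {bler_ko:.2e}")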
import torch
from utils import moving_average
import matplotlib.pyplot as plt
import os
def save_model(polar, iter, results_save_path, best = False):
torch.save([polar.fnet_dict, polar.gnet_dict, polar.depth_map], os.path.join(results_save_path, 'Models/fnet_gnet_{}.pt'.format(iter)))
if iter > 1:
torch.save([polar.fnet_dict, polar.gnet_dict, polar.depth_map], os.path.join(results_save_path, 'Models/fnet_gnet_{}.pt'.format('final')))
if best:
torch.save([polar.fnet_dict, polar.gnet_dict, polar.depth_map], os.path.join(results_save_path, 'Models/fnet_gnet_{}.pt'.format('best')))
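# A checkpoint written above is a plain list [fnet_dict, gnet_dict, depth_map], so it
# can be restored with a matching torch.load (sketch; map_location='cpu' is only needed
# when loading on a machine without the original GPU):
def load_checkpoint(path):
    fnet_dict, gnet_dict, depth_map = torch.load(path, map_location='cpu')
    return fnet_dict, gnet_dict, depth_map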
def plot_stuff(bers_enc, losses_enc, bers_dec, losses_dec, results_save_path):
plt.figure()
plt.plot(bers_enc, label = 'BER')
plt.plot(moving_average(bers_enc, n=10), label = 'BER moving avg')
plt.yscale('log')
plt.legend(loc='best')
plt.title('Training BER ENC')
plt.savefig(os.path.join(results_save_path,'training_ber_enc.png'))
plt.close()
plt.figure()
plt.plot(losses_enc, label = 'Losses')
plt.plot(moving_average(losses_enc, n=10), label='Losses moving avg')
plt.yscale('log')
plt.legend(loc='best')
plt.title('Training loss ENC')
plt.savefig(os.path.join(results_save_path ,'training_losses_enc.png'))
plt.close()
plt.figure()
plt.plot(bers_dec, label = 'BER')
plt.plot(moving_average(bers_dec, n=10), label = 'BER moving avg')
plt.yscale('log')
plt.legend(loc='best')
plt.title('Training BER DEC')
plt.savefig(os.path.join(results_save_path,'training_ber_dec.png'))
plt.close()
plt.figure()
plt.plot(losses_dec, label = 'Losses')
plt.plot(moving_average(losses_dec, n=10), label='Losses moving avg')
plt.yscale('log')
plt.legend(loc='best')
plt.title('Training loss DEC')
plt.savefig(os.path.join(results_save_path ,'training_losses_dec.png'))
plt.close()
import torch
import torch.nn.functional as F
from torch.distributions import Normal, StudentT
import numpy as np
from itertools import combinations
def snr_db2sigma(train_snr):
return 10**(-train_snr*1.0/20)
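# With unit-power BPSK symbols, snr_db = 10*log10(1/sigma**2), so the mapping above
# gives e.g. sigma = 1.0 at 0 dB and sigma = 0.1 at 20 dB (illustration only):
def _snr_db2sigma_example():
    return snr_db2sigma(0.0), snr_db2sigma(20.0)  # (1.0, 0.1)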
def moving_average(a, n=3) :
ret = np.cumsum(a, dtype=float)
ret[n:] = ret[n:] - ret[:-n]
return ret[n - 1:] / n
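# Quick illustration of the smoothing helper above (added for documentation only):
def _moving_average_example():
    # For n=3 each output element averages three consecutive inputs, so the result
    # is two elements shorter than the input: array([2., 3., 4.]).
    return moving_average([1, 2, 3, 4, 5], n=3)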
def errors_ber(y_true, y_pred, mask=None):
    if mask is None:
mask=torch.ones(y_true.size(),device=y_true.device)
y_true = y_true.view(y_true.shape[0], -1, 1)
y_pred = y_pred.view(y_pred.shape[0], -1, 1)
mask = mask.view(mask.shape[0], -1, 1)
myOtherTensor = (mask*torch.ne(torch.round(y_true), torch.round(y_pred))).float()
    res = torch.sum(myOtherTensor) / torch.sum(mask)
return res
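# Illustration of the bit-error-rate helper above on ±1-valued sequences (added for
# documentation only): one mismatch out of four bits gives BER 0.25.
def _errors_ber_example():
    y_true = torch.tensor([[1., -1., 1., -1.]])
    y_pred = torch.tensor([[1., -1., -1., -1.]])
    return errors_ber(y_true, y_pred)  # tensor(0.2500)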
def errors_bler(y_true, y_pred, get_pos = False):
y_true = y_true.view(y_true.shape[0], -1, 1)
y_pred = y_pred.view(y_pred.shape[0], -1, 1)
decoded_bits = torch.round(y_pred).cpu()
X_test = torch.round(y_true).cpu()
tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]])
tp0 = tp0.detach().cpu().numpy()
bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0])
if not get_pos:
return bler_err_rate
else:
err_pos = list(np.nonzero((np.sum(tp0,axis=1)>0).astype(int))[0])
return bler_err_rate, err_pos
def corrupt_signal(input_signal, sigma = 1.0, noise_type = 'awgn', vv =5.0, radar_power = 20.0, radar_prob = 0.05):
data_shape = input_signal.shape # input_signal has to be a numpy array.
assert noise_type in ['bsc', 'awgn', 'fading', 'radar', 't-dist', 'isi_perfect', 'isi_uncertain'], "Invalid noise type"
device = input_signal.device
if noise_type == 'awgn':
dist = Normal(torch.tensor([0.0], device = device), torch.tensor([sigma], device = device))
noise = dist.sample(input_signal.shape).squeeze()
corrupted_signal = input_signal + noise
elif noise_type == 'fading':
fading_h = torch.sqrt(torch.randn_like(input_signal)**2 + torch.randn_like(input_signal)**2)/np.sqrt(3.14/2.0)
noise = sigma * torch.randn_like(input_signal) # Define noise
corrupted_signal = fading_h *(input_signal) + noise
elif noise_type == 'radar':
add_pos = np.random.choice([0.0, 1.0], data_shape,
p=[1 - radar_prob, radar_prob])
corrupted_signal = radar_power* np.random.standard_normal( size = data_shape ) * add_pos
noise = sigma * torch.randn_like(input_signal) +\
torch.from_numpy(corrupted_signal).float().to(input_signal.device)
corrupted_signal = input_signal + noise
elif noise_type == 't-dist':
dist = StudentT(torch.tensor([vv], device = device))
noise = sigma* dist.sample(input_signal.shape).squeeze()
corrupted_signal = input_signal + noise
return corrupted_signal
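# Typical use of the channel helper above: AWGN at 0 dB (sigma = 1.0) on BPSK symbols;
# 'radar' and 't-dist' follow the same call pattern with their extra keyword arguments.
# Illustration only; note that 'bsc' and the 'isi_*' types pass the assert but have no
# branch below.
def _corrupt_signal_example():
    x = torch.ones(4, 8)  # a batch of all-(+1) BPSK codewords
    return corrupt_signal(x, sigma=snr_db2sigma(0.0), noise_type='awgn')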
def min_sum_log_sum_exp(x, y):
log_sum_ms = torch.min(torch.abs(x), torch.abs(y))*torch.sign(x)*torch.sign(y)
return log_sum_ms
def min_sum_log_sum_exp_4(x_1, x_2, x_3, x_4):
return min_sum_log_sum_exp(min_sum_log_sum_exp(x_1, x_2), min_sum_log_sum_exp(x_3, x_4))
def log_sum_exp(x, y):
def log_sum_exp_(LLR_vector):
sum_vector = LLR_vector.sum(dim=1, keepdim=True)
sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1)
return torch.logsumexp(sum_concat, dim=1)- torch.logsumexp(LLR_vector, dim=1)
Lv = log_sum_exp_(torch.cat([x.unsqueeze(2), y.unsqueeze(2)], dim=2).permute(0, 2, 1))
return Lv
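# min_sum_log_sum_exp above is the standard min-sum approximation of the exact
# f-function that log_sum_exp computes for SC decoding:
#     f(x, y) = log((1 + e^(x+y)) / (e^x + e^y)) ≈ sign(x) * sign(y) * min(|x|, |y|).
# A tiny comparison (illustration only):
def _f_function_example():
    x = torch.tensor([[2.0, -1.5]])
    y = torch.tensor([[0.5, 3.0]])
    exact = log_sum_exp(x, y)           # ≈ [[0.38, -1.31]]
    approx = min_sum_log_sum_exp(x, y)  # =  [[0.50, -1.50]]
    return exact, approx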
def dec2bitarray(in_number, bit_width):
"""
Converts a positive integer to NumPy array of the specified size containing
bits (0 and 1).
Parameters
----------
in_number : int
Positive integer to be converted to a bit array.
bit_width : int
Size of the output bit array.
Returns
-------
bitarray : 1D ndarray of ints
Array containing the binary representation of the input decimal.
"""
binary_string = bin(in_number)
length = len(binary_string)
bitarray = np.zeros(bit_width, 'int')
for i in range(length-2):
bitarray[bit_width-i-1] = int(binary_string[length-i-1])
return bitarray
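# Example of the helper above (added for documentation only): the binary expansion is
# written MSB-first into a fixed-width array, e.g. 5 with bit_width 4 -> [0, 1, 0, 1].
def _dec2bitarray_example():
    return dec2bitarray(5, 4)  # array([0, 1, 0, 1])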
def countSetBits(n):
count = 0
while (n):
n &= (n-1)
count+= 1
return count
class STEQuantize(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, enc_quantize_level = 2, enc_value_limit = 1.0, enc_grad_limit = 0.01, enc_clipping = 'both'):
ctx.save_for_backward(inputs)
assert enc_clipping in ['both', 'inputs']
ctx.enc_clipping = enc_clipping
ctx.enc_value_limit = enc_value_limit
ctx.enc_quantize_level = enc_quantize_level
ctx.enc_grad_limit = enc_grad_limit
x_lim_abs = enc_value_limit
x_lim_range = 2.0 * x_lim_abs
x_input_norm = torch.clamp(inputs, -x_lim_abs, x_lim_abs)
if enc_quantize_level == 2:
outputs_int = torch.sign(x_input_norm)
else:
outputs_int = torch.round((x_input_norm +x_lim_abs) * ((enc_quantize_level - 1.0)/x_lim_range)) * x_lim_range/(enc_quantize_level - 1.0) - x_lim_abs
return outputs_int
@staticmethod
def backward(ctx, grad_output):
if ctx.enc_clipping in ['inputs', 'both']:
input, = ctx.saved_tensors
grad_output[input>ctx.enc_value_limit]=0
grad_output[input<-ctx.enc_value_limit]=0
if ctx.enc_clipping in ['gradient', 'both']:
grad_output = torch.clamp(grad_output, -ctx.enc_grad_limit, ctx.enc_grad_limit)
grad_input = grad_output.clone()
return grad_input, None
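# Sketch of how the straight-through quantizer above can be applied to encoder outputs
# (illustrative; the actual call site is not part of this file). Forward maps values to
# ±1 for the default 2-level setting; backward passes the incoming gradient through,
# zeroing it where |input| exceeds enc_value_limit and clamping it to ±enc_grad_limit:
def _ste_quantize_example():
    x = torch.tensor([[0.3, -1.7, 0.9]], requires_grad=True)
    y = STEQuantize.apply(x)  # -> [[1., -1., 1.]]
    y.sum().backward()        # x.grad == [[0.01, 0.0, 0.01]] with the default limits
    return y, x.grad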
def pairwise_distances(codebook):
dists = []
for row1, row2 in combinations(codebook, 2):
distance = (row1-row2).pow(2).sum()
dists.append(np.sqrt(distance.item()))
return dists, np.min(dists)
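# Example of the codebook-distance helper above (added for documentation only): the two
# rows below differ by 2 in each of the four positions, so the only pairwise distance,
# and hence the minimum, is sqrt(4 * 2**2) = 4.0.
def _pairwise_distances_example():
    codebook = torch.tensor([[1., 1., 1., 1.], [-1., -1., -1., -1.]])
    return pairwise_distances(codebook)  # ([4.0], 4.0)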