提交 b08384cd 编写于 作者: H Hui Zhang

using conv1d to do fft

上级 4ad885f8
此差异已折叠。
......@@ -165,9 +165,13 @@ class STFT(torch.nn.Module):
# self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable)
# Applying window functions to the Fourier kernels
window_mask = torch.tensor(window_mask)
wsin = kernel_sin * window_mask
wcos = kernel_cos * window_mask
if window:
window_mask = torch.tensor(window_mask)
wsin = kernel_sin * window_mask
wcos = kernel_cos * window_mask
else:
wsin = kernel_sin
wcos = kernel_cos
if self.trainable==False:
self.register_buffer('wsin', wsin)
......@@ -179,7 +183,6 @@ class STFT(torch.nn.Module):
self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos)
# Prepare the shape of window mask so that it can be used later in inverse
self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))
......
......@@ -2,29 +2,26 @@ import setuptools
import codecs
import os.path
with open("README.md", "r") as fh:
long_description = fh.read()
def read(rel_path):
here = os.path.abspath(os.path.dirname(__file__))
with codecs.open(os.path.join(here, rel_path), 'r') as fp:
return fp.read()
return fp.read()
def get_version(rel_path):
for line in read(rel_path).splitlines():
if line.startswith('__version__'):
delim = '"' if '"' in line else "'"
return line.split(delim)[1]
else:
raise RuntimeError("Unable to find version string.")
raise RuntimeError("Unable to find version string.")
setuptools.setup(
name="nnAudio", # Replace with your own username
version=get_version("nnAudio/__init__.py"),
author="KinWaiCheuk",
author_email="u3500684@connect.hku.hk",
description="A fast GPU audio processing toolbox with 1D convolutional neural network",
long_description=long_description,
long_description='',
long_description_content_type="text/markdown",
url="https://github.com/KinWaiCheuk/nnAudio",
packages=setuptools.find_packages(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册