未验证 提交 02537195 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #668 from PaddlePaddle/feat

audio feature
此差异已折叠。
...@@ -165,9 +165,13 @@ class STFT(torch.nn.Module): ...@@ -165,9 +165,13 @@ class STFT(torch.nn.Module):
# self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable) # self.kernel_cos = torch.nn.Parameter(self.kernel_cos, requires_grad=self.trainable)
# Applying window functions to the Fourier kernels # Applying window functions to the Fourier kernels
window_mask = torch.tensor(window_mask) if window:
wsin = kernel_sin * window_mask window_mask = torch.tensor(window_mask)
wcos = kernel_cos * window_mask wsin = kernel_sin * window_mask
wcos = kernel_cos * window_mask
else:
wsin = kernel_sin
wcos = kernel_cos
if self.trainable==False: if self.trainable==False:
self.register_buffer('wsin', wsin) self.register_buffer('wsin', wsin)
...@@ -179,7 +183,6 @@ class STFT(torch.nn.Module): ...@@ -179,7 +183,6 @@ class STFT(torch.nn.Module):
self.register_parameter('wsin', wsin) self.register_parameter('wsin', wsin)
self.register_parameter('wcos', wcos) self.register_parameter('wcos', wcos)
# Prepare the shape of window mask so that it can be used later in inverse # Prepare the shape of window mask so that it can be used later in inverse
self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1)) self.register_buffer('window_mask', window_mask.unsqueeze(0).unsqueeze(-1))
......
...@@ -2,29 +2,26 @@ import setuptools ...@@ -2,29 +2,26 @@ import setuptools
import codecs import codecs
import os.path import os.path
with open("README.md", "r") as fh:
long_description = fh.read()
def read(rel_path): def read(rel_path):
here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))
with codecs.open(os.path.join(here, rel_path), 'r') as fp: with codecs.open(os.path.join(here, rel_path), 'r') as fp:
return fp.read() return fp.read()
def get_version(rel_path): def get_version(rel_path):
for line in read(rel_path).splitlines(): for line in read(rel_path).splitlines():
if line.startswith('__version__'): if line.startswith('__version__'):
delim = '"' if '"' in line else "'" delim = '"' if '"' in line else "'"
return line.split(delim)[1] return line.split(delim)[1]
else: else:
raise RuntimeError("Unable to find version string.") raise RuntimeError("Unable to find version string.")
setuptools.setup( setuptools.setup(
name="nnAudio", # Replace with your own username name="nnAudio", # Replace with your own username
version=get_version("nnAudio/__init__.py"), version=get_version("nnAudio/__init__.py"),
author="KinWaiCheuk", author="KinWaiCheuk",
author_email="u3500684@connect.hku.hk", author_email="u3500684@connect.hku.hk",
description="A fast GPU audio processing toolbox with 1D convolutional neural network", description="A fast GPU audio processing toolbox with 1D convolutional neural network",
long_description=long_description, long_description='',
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/KinWaiCheuk/nnAudio", url="https://github.com/KinWaiCheuk/nnAudio",
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册