Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c049a6b4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c049a6b4
编写于
3月 28, 2022
作者:
K
KP
提交者:
GitHub
3月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add window computation in stft op. (#40987)
上级
b6661d3a
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
47 addition
and
14 deletion
+47
-14
paddle/fluid/operators/stft_op.cc
paddle/fluid/operators/stft_op.cc
+10
-0
paddle/fluid/operators/stft_op.h
paddle/fluid/operators/stft_op.h
+29
-12
python/paddle/fluid/tests/unittests/test_stft_op.py
python/paddle/fluid/tests/unittests/test_stft_op.py
+8
-2
未找到文件。
paddle/fluid/operators/stft_op.cc
浏览文件 @
c049a6b4
...
@@ -30,6 +30,8 @@ class StftOp : public framework::OperatorWithKernel {
...
@@ -30,6 +30,8 @@ class StftOp : public framework::OperatorWithKernel {
const
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
int
x_rank
=
x_dims
.
size
();
const
int
x_rank
=
x_dims
.
size
();
const
auto
window_dims
=
ctx
->
GetInputDim
(
"Window"
);
const
int
window_size
=
window_dims
[
0
];
const
bool
onesided
=
ctx
->
Attrs
().
Get
<
bool
>
(
"onesided"
);
const
bool
onesided
=
ctx
->
Attrs
().
Get
<
bool
>
(
"onesided"
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
...
@@ -43,6 +45,12 @@ class StftOp : public framework::OperatorWithKernel {
...
@@ -43,6 +45,12 @@ class StftOp : public framework::OperatorWithKernel {
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Attribute(hop_length) should be greater than 0, but got %s."
,
"Attribute(hop_length) should be greater than 0, but got %s."
,
hop_length
));
hop_length
));
PADDLE_ENFORCE_EQ
(
window_size
,
n_fft
,
platform
::
errors
::
InvalidArgument
(
"Input(Window) of StftOp should be equal with n_fft %s, "
"but got %s."
,
n_fft
,
window_size
));
int
seq_length
=
x_dims
[
x_rank
-
1
];
int
seq_length
=
x_dims
[
x_rank
-
1
];
int
n_frames
=
1
+
(
seq_length
-
n_fft
)
/
hop_length
;
int
n_frames
=
1
+
(
seq_length
-
n_fft
)
/
hop_length
;
...
@@ -77,6 +85,7 @@ class StftOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -77,6 +85,7 @@ class StftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"X"
,
"Input waveforms with shape (N, T)"
);
AddInput
(
"X"
,
"Input waveforms with shape (N, T)"
);
AddInput
(
"Window"
,
"Input window with shape (n_fft,)"
);
AddOutput
(
"Out"
,
AddOutput
(
"Out"
,
"The complex STFT output tensor with shape (N, n_fft, "
"The complex STFT output tensor with shape (N, n_fft, "
"num_frames) or (N, n_fft/2 + 1, num_frames)"
);
"num_frames) or (N, n_fft/2 + 1, num_frames)"
);
...
@@ -101,6 +110,7 @@ class StftGradOpMaker : public framework::SingleGradOpMaker<T> {
...
@@ -101,6 +110,7 @@ class StftGradOpMaker : public framework::SingleGradOpMaker<T> {
void
Apply
(
GradOpPtr
<
T
>
grad_op
)
const
override
{
void
Apply
(
GradOpPtr
<
T
>
grad_op
)
const
override
{
grad_op
->
SetType
(
"stft_grad"
);
grad_op
->
SetType
(
"stft_grad"
);
grad_op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad_op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad_op
->
SetInput
(
"Window"
,
this
->
Input
(
"Window"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
this
->
Attrs
());
grad_op
->
SetAttrMap
(
this
->
Attrs
());
...
...
paddle/fluid/operators/stft_op.h
浏览文件 @
c049a6b4
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/frame_op.h"
#include "paddle/fluid/operators/frame_op.h"
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/fluid/operators/spectral_op.h"
...
@@ -36,6 +37,7 @@ class StftKernel : public framework::OpKernel<T> {
...
@@ -36,6 +37,7 @@ class StftKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
C
=
paddle
::
platform
::
complex
<
T
>
;
using
C
=
paddle
::
platform
::
complex
<
T
>
;
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
window
=
ctx
.
Input
<
Tensor
>
(
"Window"
);
Tensor
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
Tensor
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
out
->
mutable_data
<
C
>
(
ctx
.
GetPlace
());
out
->
mutable_data
<
C
>
(
ctx
.
GetPlace
());
...
@@ -62,6 +64,12 @@ class StftKernel : public framework::OpKernel<T> {
...
@@ -62,6 +64,12 @@ class StftKernel : public framework::OpKernel<T> {
FrameFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
x
,
&
frames
,
seq_length
,
n_fft
,
FrameFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
x
,
&
frames
,
seq_length
,
n_fft
,
n_frames
,
hop_length
,
/*is_grad*/
false
);
n_frames
,
hop_length
,
/*is_grad*/
false
);
// Window
Tensor
frames_w
;
frames_w
.
mutable_data
<
T
>
(
frames_dims
,
ctx
.
GetPlace
());
ElementwiseComputeEx
<
MulFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
&
frames
,
window
,
axes
.
back
(),
MulFunctor
<
T
>
(),
&
frames_w
);
// FFTR2C
// FFTR2C
FFTNormMode
normalization
;
FFTNormMode
normalization
;
if
(
normalized
)
{
if
(
normalized
)
{
...
@@ -72,14 +80,15 @@ class StftKernel : public framework::OpKernel<T> {
...
@@ -72,14 +80,15 @@ class StftKernel : public framework::OpKernel<T> {
FFTR2CFunctor
<
DeviceContext
,
T
,
C
>
fft_r2c_func
;
FFTR2CFunctor
<
DeviceContext
,
T
,
C
>
fft_r2c_func
;
if
(
onesided
)
{
if
(
onesided
)
{
fft_r2c_func
(
dev_ctx
,
&
frames
,
out
,
axes
,
normalization
,
true
);
fft_r2c_func
(
dev_ctx
,
&
frames
_w
,
out
,
axes
,
normalization
,
true
);
}
else
{
}
else
{
framework
::
DDim
onesided_dims
(
out
->
dims
());
framework
::
DDim
onesided_dims
(
out
->
dims
());
const
int64_t
onesided_axis_size
=
out
->
dims
().
at
(
axes
.
back
())
/
2
+
1
;
const
int64_t
onesided_axis_size
=
out
->
dims
().
at
(
axes
.
back
())
/
2
+
1
;
onesided_dims
.
at
(
axes
.
back
())
=
onesided_axis_size
;
onesided_dims
.
at
(
axes
.
back
())
=
onesided_axis_size
;
Tensor
onesided_out
;
Tensor
onesided_out
;
onesided_out
.
mutable_data
<
C
>
(
onesided_dims
,
ctx
.
GetPlace
());
onesided_out
.
mutable_data
<
C
>
(
onesided_dims
,
ctx
.
GetPlace
());
fft_r2c_func
(
dev_ctx
,
&
frames
,
&
onesided_out
,
axes
,
normalization
,
true
);
fft_r2c_func
(
dev_ctx
,
&
frames_w
,
&
onesided_out
,
axes
,
normalization
,
true
);
fill_conj
<
DeviceContext
,
C
>
(
dev_ctx
,
&
onesided_out
,
out
,
axes
);
fill_conj
<
DeviceContext
,
C
>
(
dev_ctx
,
&
onesided_out
,
out
,
axes
);
}
}
}
}
...
@@ -92,6 +101,7 @@ class StftGradKernel : public framework::OpKernel<T> {
...
@@ -92,6 +101,7 @@ class StftGradKernel : public framework::OpKernel<T> {
using
C
=
paddle
::
platform
::
complex
<
T
>
;
using
C
=
paddle
::
platform
::
complex
<
T
>
;
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
const
Tensor
*
window
=
ctx
.
Input
<
Tensor
>
(
"Window"
);
const
auto
*
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
const
auto
*
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -107,15 +117,15 @@ class StftGradKernel : public framework::OpKernel<T> {
...
@@ -107,15 +117,15 @@ class StftGradKernel : public framework::OpKernel<T> {
const
int
seq_length
=
dx
->
dims
()[
dx_rank
-
1
];
const
int
seq_length
=
dx
->
dims
()[
dx_rank
-
1
];
std
::
vector
<
int64_t
>
axes
=
{
1
};
std
::
vector
<
int64_t
>
axes
=
{
1
};
Tensor
d_frames
;
Tensor
d_frames
_w
;
framework
::
DDim
d_frames_dims
(
dy
->
dims
());
framework
::
DDim
d_frames_dims
(
dy
->
dims
());
d_frames_dims
.
at
(
axes
.
back
())
=
n_fft
;
d_frames_dims
.
at
(
axes
.
back
())
=
n_fft
;
d_frames
.
mutable_data
<
T
>
(
d_frames_dims
,
ctx
.
GetPlace
());
d_frames
_w
.
mutable_data
<
T
>
(
d_frames_dims
,
ctx
.
GetPlace
());
Tensor
complex_d_frames
;
Tensor
complex_d_frames
_w
;
complex_d_frames
.
mutable_data
<
C
>
(
d_frames_dims
,
ctx
.
GetPlace
());
complex_d_frames
_w
.
mutable_data
<
C
>
(
d_frames_dims
,
ctx
.
GetPlace
());
// dy -> d_frames
// dy -> d_frames
_w
FFTNormMode
normalization
;
FFTNormMode
normalization
;
if
(
normalized
)
{
if
(
normalized
)
{
normalization
=
get_norm_from_string
(
"ortho"
,
true
);
normalization
=
get_norm_from_string
(
"ortho"
,
true
);
...
@@ -125,7 +135,8 @@ class StftGradKernel : public framework::OpKernel<T> {
...
@@ -125,7 +135,8 @@ class StftGradKernel : public framework::OpKernel<T> {
FFTC2CFunctor
<
DeviceContext
,
C
,
C
>
fft_c2c_func
;
FFTC2CFunctor
<
DeviceContext
,
C
,
C
>
fft_c2c_func
;
if
(
!
onesided
)
{
if
(
!
onesided
)
{
fft_c2c_func
(
dev_ctx
,
dy
,
&
complex_d_frames
,
axes
,
normalization
,
false
);
fft_c2c_func
(
dev_ctx
,
dy
,
&
complex_d_frames_w
,
axes
,
normalization
,
false
);
}
else
{
}
else
{
Tensor
full_dy
;
Tensor
full_dy
;
full_dy
.
mutable_data
<
C
>
(
d_frames_dims
,
ctx
.
GetPlace
());
full_dy
.
mutable_data
<
C
>
(
d_frames_dims
,
ctx
.
GetPlace
());
...
@@ -139,13 +150,19 @@ class StftGradKernel : public framework::OpKernel<T> {
...
@@ -139,13 +150,19 @@ class StftGradKernel : public framework::OpKernel<T> {
phi
::
funcs
::
PaddingFunctor
<
DeviceContext
,
C
>
(
phi
::
funcs
::
PaddingFunctor
<
DeviceContext
,
C
>
(
rank
,
ctx
.
template
device_context
<
DeviceContext
>(),
pads
,
rank
,
ctx
.
template
device_context
<
DeviceContext
>(),
pads
,
static_cast
<
C
>
(
0
),
*
dy
,
&
full_dy
);
static_cast
<
C
>
(
0
),
*
dy
,
&
full_dy
);
fft_c2c_func
(
dev_ctx
,
&
full_dy
,
&
complex_d_frames
,
axes
,
normalization
,
fft_c2c_func
(
dev_ctx
,
&
full_dy
,
&
complex_d_frames
_w
,
axes
,
normalization
,
false
);
false
);
}
}
framework
::
TransComplexToReal
(
framework
::
TransComplexToReal
(
framework
::
TransToProtoVarType
(
d_frames
.
dtype
()),
framework
::
TransToProtoVarType
(
d_frames_w
.
dtype
()),
framework
::
TransToProtoVarType
(
complex_d_frames
.
dtype
()),
framework
::
TransToProtoVarType
(
complex_d_frames_w
.
dtype
()),
complex_d_frames
,
&
d_frames
);
complex_d_frames_w
,
&
d_frames_w
);
// d_frames_w -> d_frames
Tensor
d_frames
;
d_frames
.
mutable_data
<
T
>
(
d_frames_dims
,
ctx
.
GetPlace
());
ElementwiseComputeEx
<
MulFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
&
d_frames_w
,
window
,
axes
.
back
(),
MulFunctor
<
T
>
(),
&
d_frames
);
// d_frames -> dx
// d_frames -> dx
FrameFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
&
d_frames
,
dx
,
seq_length
,
n_fft
,
FrameFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
&
d_frames
,
dx
,
seq_length
,
n_fft
,
...
...
python/paddle/fluid/tests/unittests/test_stft_op.py
浏览文件 @
c049a6b4
...
@@ -43,8 +43,10 @@ def frame_from_librosa(x, frame_length, hop_length, axis=-1):
...
@@ -43,8 +43,10 @@ def frame_from_librosa(x, frame_length, hop_length, axis=-1):
return
as_strided
(
x
,
shape
=
shape
,
strides
=
strides
)
return
as_strided
(
x
,
shape
=
shape
,
strides
=
strides
)
def
stft_np
(
x
,
n_fft
,
hop_length
,
**
kwargs
):
def
stft_np
(
x
,
window
,
n_fft
,
hop_length
,
**
kwargs
):
frames
=
frame_from_librosa
(
x
,
n_fft
,
hop_length
)
frames
=
frame_from_librosa
(
x
,
n_fft
,
hop_length
)
frames
=
np
.
multiply
(
frames
.
transpose
([
0
,
2
,
1
]),
window
).
transpose
(
[
0
,
2
,
1
])
res
=
np
.
fft
.
rfft
(
frames
,
axis
=
1
)
res
=
np
.
fft
.
rfft
(
frames
,
axis
=
1
)
return
res
return
res
...
@@ -55,8 +57,12 @@ class TestStftOp(OpTest):
...
@@ -55,8 +57,12 @@ class TestStftOp(OpTest):
self
.
shape
,
self
.
type
,
self
.
attrs
=
self
.
initTestCase
()
self
.
shape
,
self
.
type
,
self
.
attrs
=
self
.
initTestCase
()
self
.
inputs
=
{
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
size
=
self
.
shape
).
astype
(
self
.
type
),
'X'
:
np
.
random
.
random
(
size
=
self
.
shape
).
astype
(
self
.
type
),
'Window'
:
np
.
hamming
(
self
.
attrs
[
'n_fft'
]).
astype
(
self
.
type
),
}
self
.
outputs
=
{
'Out'
:
stft_np
(
x
=
self
.
inputs
[
'X'
],
window
=
self
.
inputs
[
'Window'
],
**
self
.
attrs
)
}
}
self
.
outputs
=
{
'Out'
:
stft_np
(
x
=
self
.
inputs
[
'X'
],
**
self
.
attrs
)}
def
initTestCase
(
self
):
def
initTestCase
(
self
):
input_shape
=
(
2
,
100
)
input_shape
=
(
2
,
100
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录