Unverified commit c049a6b4, authored by K KP, committed by GitHub

Add window computation in stft op. (#40987)

Parent: b6661d3a
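For context, this change threads a per-frame window through the STFT pipeline: frame the signal, multiply each frame by the window, then run an FFT over each frame. A minimal NumPy sketch of the intended math (illustrative only; windowed_stft and its argument names are not part of this commit, and it mirrors stft_np from the updated test below):

import numpy as np

def windowed_stft(x, window, n_fft, hop_length, onesided=True):
    # Frame the last axis: (N, T) -> (N, n_fft, num_frames)
    n_frames = 1 + (x.shape[-1] - n_fft) // hop_length
    frames = np.stack(
        [x[:, i * hop_length:i * hop_length + n_fft] for i in range(n_frames)],
        axis=2)
    # The step this commit adds: scale every frame by the window
    frames = frames * window[None, :, None]
    fft = np.fft.rfft if onesided else np.fft.fft
    return fft(frames, axis=1)  # (N, n_fft/2 + 1, num_frames) when onesided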
......@@ -30,6 +30,8 @@ class StftOp : public framework::OperatorWithKernel {
const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();
+    const auto window_dims = ctx->GetInputDim("Window");
+    const int window_size = window_dims[0];
const bool onesided = ctx->Attrs().Get<bool>("onesided");
PADDLE_ENFORCE_EQ(
......@@ -43,6 +45,12 @@ class StftOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"Attribute(hop_length) should be greater than 0, but got %s.",
hop_length));
+    PADDLE_ENFORCE_EQ(
+        window_size, n_fft,
+        platform::errors::InvalidArgument(
+            "The size of Input(Window) of StftOp should be equal to "
+            "n_fft %s, but got %s.",
+            n_fft, window_size));
int seq_length = x_dims[x_rank - 1];
int n_frames = 1 + (seq_length - n_fft) / hop_length;
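As a worked example of the shape inference above: with seq_length = 100, n_fft = 50, and hop_length = 10, the op produces n_frames = 1 + (100 - 50) / 10 = 6 frames, and the new check requires the window to hold exactly 50 samples.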
......@@ -77,6 +85,7 @@ class StftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input waveforms with shape (N, T)");
AddInput("Window", "Input window with shape (n_fft,)");
AddOutput("Out",
"The complex STFT output tensor with shape (N, n_fft, "
"num_frames) or (N, n_fft/2 + 1, num_frames)");
......@@ -101,6 +110,7 @@ class StftGradOpMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("stft_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Window", this->Input("Window"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
......
......@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/frame_op.h"
#include "paddle/fluid/operators/spectral_op.h"
......@@ -36,6 +37,7 @@ class StftKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
using C = paddle::platform::complex<T>;
const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* window = ctx.Input<Tensor>("Window");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<C>(ctx.GetPlace());
......@@ -62,6 +64,12 @@ class StftKernel : public framework::OpKernel<T> {
FrameFunctor<DeviceContext, T>()(dev_ctx, x, &frames, seq_length, n_fft,
n_frames, hop_length, /*is_grad*/ false);
+    // Window
+    Tensor frames_w;
+    frames_w.mutable_data<T>(frames_dims, ctx.GetPlace());
+    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+        ctx, &frames, window, axes.back(), MulFunctor<T>(), &frames_w);
// FFTR2C
FFTNormMode normalization;
if (normalized) {
......@@ -72,14 +80,15 @@ class StftKernel : public framework::OpKernel<T> {
FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
if (onesided) {
-      fft_r2c_func(dev_ctx, &frames, out, axes, normalization, true);
+      fft_r2c_func(dev_ctx, &frames_w, out, axes, normalization, true);
} else {
framework::DDim onesided_dims(out->dims());
const int64_t onesided_axis_size = out->dims().at(axes.back()) / 2 + 1;
onesided_dims.at(axes.back()) = onesided_axis_size;
Tensor onesided_out;
onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace());
-      fft_r2c_func(dev_ctx, &frames, &onesided_out, axes, normalization, true);
+      fft_r2c_func(dev_ctx, &frames_w, &onesided_out, axes, normalization,
+                   true);
fill_conj<DeviceContext, C>(dev_ctx, &onesided_out, out, axes);
}
}
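The ElementwiseComputeEx<MulFunctor<T>, ...> call introduced above multiplies the (N, n_fft, num_frames) frames by the (n_fft,) window, broadcast along axes.back() (axis 1), before the R2C FFT. Roughly, in NumPy terms (a sketch, not the Paddle API):

# frames: (N, n_fft, num_frames), window: (n_fft,)
frames_w = frames * window.reshape(1, -1, 1)  # broadcast along axis 1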
......@@ -92,6 +101,7 @@ class StftGradKernel : public framework::OpKernel<T> {
using C = paddle::platform::complex<T>;
auto& dev_ctx = ctx.device_context<DeviceContext>();
+    const Tensor* window = ctx.Input<Tensor>("Window");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
......@@ -107,15 +117,15 @@ class StftGradKernel : public framework::OpKernel<T> {
const int seq_length = dx->dims()[dx_rank - 1];
std::vector<int64_t> axes = {1};
-    Tensor d_frames;
+    Tensor d_frames_w;
    framework::DDim d_frames_dims(dy->dims());
    d_frames_dims.at(axes.back()) = n_fft;
-    d_frames.mutable_data<T>(d_frames_dims, ctx.GetPlace());
+    d_frames_w.mutable_data<T>(d_frames_dims, ctx.GetPlace());
-    Tensor complex_d_frames;
-    complex_d_frames.mutable_data<C>(d_frames_dims, ctx.GetPlace());
+    Tensor complex_d_frames_w;
+    complex_d_frames_w.mutable_data<C>(d_frames_dims, ctx.GetPlace());
-    // dy -> d_frames
+    // dy -> d_frames_w
FFTNormMode normalization;
if (normalized) {
normalization = get_norm_from_string("ortho", true);
......@@ -125,7 +135,8 @@ class StftGradKernel : public framework::OpKernel<T> {
FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
if (!onesided) {
-      fft_c2c_func(dev_ctx, dy, &complex_d_frames, axes, normalization, false);
+      fft_c2c_func(dev_ctx, dy, &complex_d_frames_w, axes, normalization,
+                   false);
} else {
Tensor full_dy;
full_dy.mutable_data<C>(d_frames_dims, ctx.GetPlace());
......@@ -139,13 +150,19 @@ class StftGradKernel : public framework::OpKernel<T> {
phi::funcs::PaddingFunctor<DeviceContext, C>(
rank, ctx.template device_context<DeviceContext>(), pads,
static_cast<C>(0), *dy, &full_dy);
-      fft_c2c_func(dev_ctx, &full_dy, &complex_d_frames, axes, normalization,
+      fft_c2c_func(dev_ctx, &full_dy, &complex_d_frames_w, axes, normalization,
false);
}
framework::TransComplexToReal(
-        framework::TransToProtoVarType(d_frames.dtype()),
-        framework::TransToProtoVarType(complex_d_frames.dtype()),
-        complex_d_frames, &d_frames);
+        framework::TransToProtoVarType(d_frames_w.dtype()),
+        framework::TransToProtoVarType(complex_d_frames_w.dtype()),
+        complex_d_frames_w, &d_frames_w);
+
+    // d_frames_w -> d_frames
+    Tensor d_frames;
+    d_frames.mutable_data<T>(d_frames_dims, ctx.GetPlace());
+    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+        ctx, &d_frames_w, window, axes.back(), MulFunctor<T>(), &d_frames);
// d_frames -> dx
FrameFunctor<DeviceContext, T>()(dev_ctx, &d_frames, dx, seq_length, n_fft,
......
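The backward pass mirrors the forward chain. Since the forward computes frames_w = frames * window elementwise, the chain rule gives d_frames = d_frames_w * window, which is exactly the second ElementwiseComputeEx call above; FrameFunctor with /*is_grad*/ true then overlap-adds d_frames back into dx. A NumPy sketch of the window step (illustrative names only):

# d_frames_w: (N, n_fft, num_frames) gradient w.r.t. the windowed frames
d_frames = d_frames_w * window.reshape(1, -1, 1)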
......@@ -43,8 +43,10 @@ def frame_from_librosa(x, frame_length, hop_length, axis=-1):
return as_strided(x, shape=shape, strides=strides)
-def stft_np(x, n_fft, hop_length, **kwargs):
+def stft_np(x, window, n_fft, hop_length, **kwargs):
    frames = frame_from_librosa(x, n_fft, hop_length)
+    frames = np.multiply(frames.transpose([0, 2, 1]), window).transpose(
+        [0, 2, 1])
res = np.fft.rfft(frames, axis=1)
return res
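A quick, hypothetical call to the reference above (values chosen only for illustration; they are not the test's actual attrs):

x = np.random.randn(2, 100)
win = np.hamming(50)  # window length must equal n_fft
out = stft_np(x=x, window=win, n_fft=50, hop_length=10)
print(out.shape)  # (2, 26, 6): (N, n_fft // 2 + 1, num_frames)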
......@@ -55,8 +57,12 @@ class TestStftOp(OpTest):
self.shape, self.type, self.attrs = self.initTestCase()
self.inputs = {
'X': np.random.random(size=self.shape).astype(self.type),
+            'Window': np.hamming(self.attrs['n_fft']).astype(self.type),
        }
-        self.outputs = {'Out': stft_np(x=self.inputs['X'], **self.attrs)}
+        self.outputs = {
+            'Out': stft_np(
+                x=self.inputs['X'], window=self.inputs['Window'], **self.attrs)
+        }
def initTestCase(self):
input_shape = (2, 100)
......