diff --git a/paddle/fluid/operators/stft_op.cc b/paddle/fluid/operators/stft_op.cc index ecbd9edd87dc6c12eda76b76ca9239b79d3aec9c..7d4103ddf3859c6fda71c08395207da1d987a933 100644 --- a/paddle/fluid/operators/stft_op.cc +++ b/paddle/fluid/operators/stft_op.cc @@ -30,6 +30,8 @@ class StftOp : public framework::OperatorWithKernel { const auto x_dims = ctx->GetInputDim("X"); const int x_rank = x_dims.size(); + const auto window_dims = ctx->GetInputDim("Window"); + const int window_size = window_dims[0]; const bool onesided = ctx->Attrs().Get("onesided"); PADDLE_ENFORCE_EQ( @@ -43,6 +45,12 @@ class StftOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "Attribute(hop_length) should be greater than 0, but got %s.", hop_length)); + PADDLE_ENFORCE_EQ( + window_size, n_fft, + platform::errors::InvalidArgument( + "Input(Window) of StftOp should be equal with n_fft %s, " + "but got %s.", + n_fft, window_size)); int seq_length = x_dims[x_rank - 1]; int n_frames = 1 + (seq_length - n_fft) / hop_length; @@ -77,6 +85,7 @@ class StftOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "Input waveforms with shape (N, T)"); + AddInput("Window", "Input window with shape (n_fft,)"); AddOutput("Out", "The complex STFT output tensor with shape (N, n_fft, " "num_frames) or (N, n_fft/2 + 1, num_frames)"); @@ -101,6 +110,7 @@ class StftGradOpMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr grad_op) const override { grad_op->SetType("stft_grad"); grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput("Window", this->Input("Window")); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetAttrMap(this->Attrs()); diff --git a/paddle/fluid/operators/stft_op.h b/paddle/fluid/operators/stft_op.h index 4f0746ee143f9b53214b2b3ebb81a571bb008908..e75c59232bcaebe0b594b951c9ceb758f308a49a 100644 --- a/paddle/fluid/operators/stft_op.h +++ b/paddle/fluid/operators/stft_op.h @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/frame_op.h" #include "paddle/fluid/operators/spectral_op.h" @@ -36,6 +37,7 @@ class StftKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { using C = paddle::platform::complex; const Tensor* x = ctx.Input("X"); + const Tensor* window = ctx.Input("Window"); Tensor* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); @@ -62,6 +64,12 @@ class StftKernel : public framework::OpKernel { FrameFunctor()(dev_ctx, x, &frames, seq_length, n_fft, n_frames, hop_length, /*is_grad*/ false); + // Window + Tensor frames_w; + frames_w.mutable_data(frames_dims, ctx.GetPlace()); + ElementwiseComputeEx, DeviceContext, T>( + ctx, &frames, window, axes.back(), MulFunctor(), &frames_w); + // FFTR2C FFTNormMode normalization; if (normalized) { @@ -72,14 +80,15 @@ class StftKernel : public framework::OpKernel { FFTR2CFunctor fft_r2c_func; if (onesided) { - fft_r2c_func(dev_ctx, &frames, out, axes, normalization, true); + fft_r2c_func(dev_ctx, &frames_w, out, axes, normalization, true); } else { framework::DDim onesided_dims(out->dims()); const int64_t onesided_axis_size = out->dims().at(axes.back()) / 2 + 1; onesided_dims.at(axes.back()) = onesided_axis_size; Tensor onesided_out; onesided_out.mutable_data(onesided_dims, ctx.GetPlace()); - fft_r2c_func(dev_ctx, &frames, &onesided_out, axes, normalization, true); + fft_r2c_func(dev_ctx, &frames_w, &onesided_out, axes, normalization, + true); fill_conj(dev_ctx, &onesided_out, out, axes); } } @@ -92,6 +101,7 @@ class StftGradKernel : public framework::OpKernel { using C = paddle::platform::complex; auto& dev_ctx = ctx.device_context(); + const Tensor* window = ctx.Input("Window"); const auto* dy = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); dx->mutable_data(ctx.GetPlace()); @@ -107,15 +117,15 @@ class StftGradKernel : public framework::OpKernel { const int seq_length = dx->dims()[dx_rank - 1]; std::vector axes = {1}; - Tensor d_frames; + Tensor d_frames_w; framework::DDim d_frames_dims(dy->dims()); d_frames_dims.at(axes.back()) = n_fft; - d_frames.mutable_data(d_frames_dims, ctx.GetPlace()); + d_frames_w.mutable_data(d_frames_dims, ctx.GetPlace()); - Tensor complex_d_frames; - complex_d_frames.mutable_data(d_frames_dims, ctx.GetPlace()); + Tensor complex_d_frames_w; + complex_d_frames_w.mutable_data(d_frames_dims, ctx.GetPlace()); - // dy -> d_frames + // dy -> d_frames_w FFTNormMode normalization; if (normalized) { normalization = get_norm_from_string("ortho", true); @@ -125,7 +135,8 @@ class StftGradKernel : public framework::OpKernel { FFTC2CFunctor fft_c2c_func; if (!onesided) { - fft_c2c_func(dev_ctx, dy, &complex_d_frames, axes, normalization, false); + fft_c2c_func(dev_ctx, dy, &complex_d_frames_w, axes, normalization, + false); } else { Tensor full_dy; full_dy.mutable_data(d_frames_dims, ctx.GetPlace()); @@ -139,13 +150,19 @@ class StftGradKernel : public framework::OpKernel { phi::funcs::PaddingFunctor( rank, ctx.template device_context(), pads, static_cast(0), *dy, &full_dy); - fft_c2c_func(dev_ctx, &full_dy, &complex_d_frames, axes, normalization, + fft_c2c_func(dev_ctx, &full_dy, &complex_d_frames_w, axes, normalization, false); } framework::TransComplexToReal( - framework::TransToProtoVarType(d_frames.dtype()), - framework::TransToProtoVarType(complex_d_frames.dtype()), - complex_d_frames, &d_frames); + framework::TransToProtoVarType(d_frames_w.dtype()), + framework::TransToProtoVarType(complex_d_frames_w.dtype()), + complex_d_frames_w, &d_frames_w); + + // d_frames_w -> d_frames + Tensor d_frames; + d_frames.mutable_data(d_frames_dims, ctx.GetPlace()); + ElementwiseComputeEx, DeviceContext, T>( + ctx, &d_frames_w, window, axes.back(), MulFunctor(), &d_frames); // d_frames -> dx FrameFunctor()(dev_ctx, &d_frames, dx, seq_length, n_fft, diff --git a/python/paddle/fluid/tests/unittests/test_stft_op.py b/python/paddle/fluid/tests/unittests/test_stft_op.py index 64b8084a1651f156dfdd12606df81e69dfa256ec..f228c148d6e177085f9b684b046f6a20c48ee7b8 100644 --- a/python/paddle/fluid/tests/unittests/test_stft_op.py +++ b/python/paddle/fluid/tests/unittests/test_stft_op.py @@ -43,8 +43,10 @@ def frame_from_librosa(x, frame_length, hop_length, axis=-1): return as_strided(x, shape=shape, strides=strides) -def stft_np(x, n_fft, hop_length, **kwargs): +def stft_np(x, window, n_fft, hop_length, **kwargs): frames = frame_from_librosa(x, n_fft, hop_length) + frames = np.multiply(frames.transpose([0, 2, 1]), window).transpose( + [0, 2, 1]) res = np.fft.rfft(frames, axis=1) return res @@ -55,8 +57,12 @@ class TestStftOp(OpTest): self.shape, self.type, self.attrs = self.initTestCase() self.inputs = { 'X': np.random.random(size=self.shape).astype(self.type), + 'Window': np.hamming(self.attrs['n_fft']).astype(self.type), + } + self.outputs = { + 'Out': stft_np( + x=self.inputs['X'], window=self.inputs['Window'], **self.attrs) } - self.outputs = {'Out': stft_np(x=self.inputs['X'], **self.attrs)} def initTestCase(self): input_shape = (2, 100)