Unverified commit c049a6b4 authored by KP, committed by GitHub

Add window computation in stft op. (#40987)

Parent b6661d3a
@@ -30,6 +30,8 @@ class StftOp : public framework::OperatorWithKernel {
     const auto x_dims = ctx->GetInputDim("X");
     const int x_rank = x_dims.size();
+    const auto window_dims = ctx->GetInputDim("Window");
+    const int window_size = window_dims[0];
     const bool onesided = ctx->Attrs().Get<bool>("onesided");

     PADDLE_ENFORCE_EQ(
@@ -43,6 +45,12 @@ class StftOp : public framework::OperatorWithKernel {
         platform::errors::InvalidArgument(
             "Attribute(hop_length) should be greater than 0, but got %s.",
             hop_length));
+    PADDLE_ENFORCE_EQ(
+        window_size, n_fft,
+        platform::errors::InvalidArgument(
+            "Input(Window) of StftOp should be equal to n_fft %s, "
+            "but got %s.",
+            n_fft, window_size));

     int seq_length = x_dims[x_rank - 1];
     int n_frames = 1 + (seq_length - n_fft) / hop_length;
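
Reviewer note: the new check pins Window to length n_fft, and with that the output shape follows from the existing frame-count formula. A minimal Python sketch of the resulting shape logic (the function name and the example values n_fft=50, hop_length=10 are illustrative, not taken from this diff):

def stft_out_shape(x_shape, window_size, n_fft, hop_length, onesided=True):
    assert window_size == n_fft  # the new PADDLE_ENFORCE_EQ above
    seq_length = x_shape[-1]
    n_frames = 1 + (seq_length - n_fft) // hop_length
    freq_bins = n_fft // 2 + 1 if onesided else n_fft
    return x_shape[:-1] + (freq_bins, n_frames)

print(stft_out_shape((2, 100), 50, 50, 10))  # (2, 26, 6)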
@@ -77,6 +85,7 @@ class StftOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "Input waveforms with shape (N, T)");
+    AddInput("Window", "Input window with shape (n_fft,)");
     AddOutput("Out",
               "The complex STFT output tensor with shape (N, n_fft, "
               "num_frames) or (N, n_fft/2 + 1, num_frames)");
@@ -101,6 +110,7 @@ class StftGradOpMaker : public framework::SingleGradOpMaker<T> {
   void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("stft_grad");
     grad_op->SetInput("X", this->Input("X"));
+    grad_op->SetInput("Window", this->Input("Window"));
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
......
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/frame_op.h"
 #include "paddle/fluid/operators/spectral_op.h"
@@ -36,6 +37,7 @@ class StftKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using C = paddle::platform::complex<T>;
     const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* window = ctx.Input<Tensor>("Window");
     Tensor* out = ctx.Output<Tensor>("Out");
     out->mutable_data<C>(ctx.GetPlace());
@@ -62,6 +64,12 @@ class StftKernel : public framework::OpKernel<T> {
     FrameFunctor<DeviceContext, T>()(dev_ctx, x, &frames, seq_length, n_fft,
                                      n_frames, hop_length, /*is_grad*/ false);

+    // Window
+    Tensor frames_w;
+    frames_w.mutable_data<T>(frames_dims, ctx.GetPlace());
+    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+        ctx, &frames, window, axes.back(), MulFunctor<T>(), &frames_w);
+
     // FFTR2C
     FFTNormMode normalization;
     if (normalized) {
@@ -72,14 +80,15 @@ class StftKernel : public framework::OpKernel<T> {
     FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;
     if (onesided) {
-      fft_r2c_func(dev_ctx, &frames, out, axes, normalization, true);
+      fft_r2c_func(dev_ctx, &frames_w, out, axes, normalization, true);
     } else {
       framework::DDim onesided_dims(out->dims());
       const int64_t onesided_axis_size = out->dims().at(axes.back()) / 2 + 1;
       onesided_dims.at(axes.back()) = onesided_axis_size;
       Tensor onesided_out;
       onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace());
-      fft_r2c_func(dev_ctx, &frames, &onesided_out, axes, normalization, true);
+      fft_r2c_func(dev_ctx, &frames_w, &onesided_out, axes, normalization,
+                   true);
       fill_conj<DeviceContext, C>(dev_ctx, &onesided_out, out, axes);
     }
   }
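
Reviewer note: in numpy terms the forward path now frames the signal, multiplies each frame by the window (broadcast over axis 1, which is axes.back(), as ElementwiseComputeEx does), then runs a real-to-complex FFT. A rough reference, assuming frames are laid out as (N, n_fft, n_frames); this is a sketch, not the kernel's exact semantics:

import numpy as np

def stft_forward_ref(x, window, n_fft, hop_length, onesided=True):
    n_frames = 1 + (x.shape[-1] - n_fft) // hop_length
    # FrameFunctor: slice overlapping frames, (N, T) -> (N, n_fft, n_frames).
    frames = np.stack(
        [x[:, i * hop_length:i * hop_length + n_fft] for i in range(n_frames)],
        axis=-1)
    # ElementwiseComputeEx<MulFunctor> with axis = axes.back(): broadcast
    # the (n_fft,) window over axis 1.
    frames_w = frames * window[None, :, None]
    out = np.fft.rfft(frames_w, axis=1)  # FFTR2C
    if not onesided:
        # fill_conj: rebuild the full spectrum by conjugate symmetry.
        tail = np.conj(out[:, 1:n_fft - n_fft // 2, :])[:, ::-1, :]
        out = np.concatenate([out, tail], axis=1)
    return out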
@@ -92,6 +101,7 @@ class StftGradKernel : public framework::OpKernel<T> {
     using C = paddle::platform::complex<T>;
     auto& dev_ctx = ctx.device_context<DeviceContext>();

+    const Tensor* window = ctx.Input<Tensor>("Window");
     const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     dx->mutable_data<T>(ctx.GetPlace());
@@ -107,15 +117,15 @@ class StftGradKernel : public framework::OpKernel<T> {
     const int seq_length = dx->dims()[dx_rank - 1];
     std::vector<int64_t> axes = {1};

-    Tensor d_frames;
+    Tensor d_frames_w;
     framework::DDim d_frames_dims(dy->dims());
     d_frames_dims.at(axes.back()) = n_fft;
-    d_frames.mutable_data<T>(d_frames_dims, ctx.GetPlace());
+    d_frames_w.mutable_data<T>(d_frames_dims, ctx.GetPlace());

-    Tensor complex_d_frames;
-    complex_d_frames.mutable_data<C>(d_frames_dims, ctx.GetPlace());
+    Tensor complex_d_frames_w;
+    complex_d_frames_w.mutable_data<C>(d_frames_dims, ctx.GetPlace());

-    // dy -> d_frames
+    // dy -> d_frames_w
     FFTNormMode normalization;
     if (normalized) {
       normalization = get_norm_from_string("ortho", true);
@@ -125,7 +135,8 @@ class StftGradKernel : public framework::OpKernel<T> {
     FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;
     if (!onesided) {
-      fft_c2c_func(dev_ctx, dy, &complex_d_frames, axes, normalization, false);
+      fft_c2c_func(dev_ctx, dy, &complex_d_frames_w, axes, normalization,
+                   false);
     } else {
       Tensor full_dy;
       full_dy.mutable_data<C>(d_frames_dims, ctx.GetPlace());
@@ -139,13 +150,19 @@ class StftGradKernel : public framework::OpKernel<T> {
       phi::funcs::PaddingFunctor<DeviceContext, C>(
           rank, ctx.template device_context<DeviceContext>(), pads,
           static_cast<C>(0), *dy, &full_dy);
-      fft_c2c_func(dev_ctx, &full_dy, &complex_d_frames, axes, normalization,
+      fft_c2c_func(dev_ctx, &full_dy, &complex_d_frames_w, axes, normalization,
                    false);
     }
     framework::TransComplexToReal(
-        framework::TransToProtoVarType(d_frames.dtype()),
-        framework::TransToProtoVarType(complex_d_frames.dtype()),
-        complex_d_frames, &d_frames);
+        framework::TransToProtoVarType(d_frames_w.dtype()),
+        framework::TransToProtoVarType(complex_d_frames_w.dtype()),
+        complex_d_frames_w, &d_frames_w);
+
+    // d_frames_w -> d_frames
+    Tensor d_frames;
+    d_frames.mutable_data<T>(d_frames_dims, ctx.GetPlace());
+    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
+        ctx, &d_frames_w, window, axes.back(), MulFunctor<T>(), &d_frames);

     // d_frames -> dx
     FrameFunctor<DeviceContext, T>()(dev_ctx, &d_frames, dx, seq_length, n_fft,
......
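
Reviewer note: the gradient chain mirrors the forward pass: dy is inverse-transformed to d_frames_w (real part kept), multiplied by the same window (chain rule through frames * window), then overlap-added back onto dx. A hedged numpy sketch; the onesided padding of dy and the exact FFT normalization are simplified assumptions:

import numpy as np

def stft_backward_ref(dy_full, window, seq_length, hop_length):
    n, n_fft, n_frames = dy_full.shape
    # FFTC2CFunctor(forward=false) + TransComplexToReal: the adjoint of an
    # unnormalized DFT is n_fft * ifft, real part kept (assumed norm mode).
    d_frames_w = np.real(np.fft.ifft(dy_full, axis=1)) * n_fft
    # d_frames_w -> d_frames: multiply by the window.
    d_frames = d_frames_w * window[None, :, None]
    # d_frames -> dx: overlap-add, the adjoint of framing (is_grad=true).
    dx = np.zeros((n, seq_length), dtype=d_frames.dtype)
    for i in range(n_frames):
        dx[:, i * hop_length:i * hop_length + n_fft] += d_frames[:, :, i]
    return dx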
@@ -43,8 +43,10 @@ def frame_from_librosa(x, frame_length, hop_length, axis=-1):
     return as_strided(x, shape=shape, strides=strides)


-def stft_np(x, n_fft, hop_length, **kwargs):
+def stft_np(x, window, n_fft, hop_length, **kwargs):
     frames = frame_from_librosa(x, n_fft, hop_length)
+    frames = np.multiply(frames.transpose([0, 2, 1]), window).transpose(
+        [0, 2, 1])
     res = np.fft.rfft(frames, axis=1)
     return res
@@ -55,8 +57,12 @@ class TestStftOp(OpTest):
         self.shape, self.type, self.attrs = self.initTestCase()
         self.inputs = {
             'X': np.random.random(size=self.shape).astype(self.type),
+            'Window': np.hamming(self.attrs['n_fft']).astype(self.type),
         }
-        self.outputs = {'Out': stft_np(x=self.inputs['X'], **self.attrs)}
+        self.outputs = {
+            'Out': stft_np(
+                x=self.inputs['X'], window=self.inputs['Window'], **self.attrs)
+        }

     def initTestCase(self):
         input_shape = (2, 100)
......
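
Reviewer note: if the Python-side API is wired to this op, the numpy reference can also be checked end to end. Treat paddle.signal.stft and its center/onesided arguments as assumptions about the surrounding codebase, not something this diff shows:

import numpy as np
import paddle

x = np.random.random((2, 100)).astype('float64')
window = np.hamming(50).astype('float64')
ref = stft_np(x, window=window, n_fft=50, hop_length=10)
out = paddle.signal.stft(paddle.to_tensor(x), n_fft=50, hop_length=10,
                         window=paddle.to_tensor(window),
                         center=False, onesided=True)  # assumed signature
np.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5)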