// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/frame_op.h"
#include "paddle/fluid/operators/spectral_op.h"
#include "paddle/phi/kernels/funcs/padding.h"  // phi::funcs::PaddingFunctor

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class StftKernel : public framework::OpKernel<T> {
 public:
  /*
    Batch Signals (N, T) -> Frames (N, n_fft, num_frames) -> FFTR2C ->
    (N, n_fft/2 + 1, num_frames) or (N, n_fft, num_frames)
  */
  void Compute(const framework::ExecutionContext& ctx) const override {
    using C = paddle::platform::complex<T>;
    const Tensor* x = ctx.Input<Tensor>("X");
    Tensor* out = ctx.Output<Tensor>("Out");
    out->mutable_data<C>(ctx.GetPlace());

    const size_t x_rank = x->dims().size();
    const size_t out_rank = out->dims().size();

    const int n_fft = ctx.Attr<int>("n_fft");
    const int hop_length = ctx.Attr<int>("hop_length");
    const bool normalized = ctx.Attr<bool>("normalized");
    const bool onesided = ctx.Attr<bool>("onesided");

    const int n_frames = out->dims()[out_rank - 1];
    const int seq_length = x->dims()[x_rank - 1];

    auto& dev_ctx = ctx.device_context<DeviceContext>();

    // The FFT is taken along axis 1, i.e. within each frame.
    std::vector<int64_t> axes = {1};

    // Frame: slice the signal into overlapping frames of length n_fft,
    // spaced hop_length samples apart.
    Tensor frames;
    framework::DDim frames_dims(out->dims());
    frames_dims.at(axes.back()) = n_fft;
    frames.mutable_data<T>(frames_dims, ctx.GetPlace());
    FrameFunctor<DeviceContext, T>()(dev_ctx, x, &frames, seq_length, n_fft,
                                     n_frames, hop_length, /*is_grad*/ false);

    // FFTR2C: real-to-complex FFT over each frame.
    FFTNormMode normalization;
    if (normalized) {
      normalization = get_norm_from_string("ortho", true);
    } else {
      normalization = get_norm_from_string("backward", true);
    }
    FFTR2CFunctor<DeviceContext, T, C> fft_r2c_func;

    if (onesided) {
      fft_r2c_func(dev_ctx, &frames, out, axes, normalization, true);
    } else {
      // Compute only the non-redundant half, then reconstruct the full
      // spectrum from conjugate symmetry.
      framework::DDim onesided_dims(out->dims());
      const int64_t onesided_axis_size = out->dims().at(axes.back()) / 2 + 1;
      onesided_dims.at(axes.back()) = onesided_axis_size;
      Tensor onesided_out;
      onesided_out.mutable_data<C>(onesided_dims, ctx.GetPlace());
      fft_r2c_func(dev_ctx, &frames, &onesided_out, axes, normalization,
                   true);
      fill_conj<DeviceContext, C>(dev_ctx, &onesided_out, out, axes);
    }
  }
};
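// Illustrative shape walk-through (assumes the usual non-centered framing,
// which is computed by the caller, not by this kernel): for a batch x of
// shape (N=8, T=16000) with n_fft=512 and hop_length=160, the frame stage
// produces
//   num_frames = 1 + (T - n_fft) / hop_length = 1 + (16000 - 512) / 160 = 97,
// so frames has shape (8, 512, 97). With onesided=true, Out has shape
// (8, 512/2 + 1, 97) = (8, 257, 97); with onesided=false the full
// (8, 512, 97) spectrum is recovered via fill_conj.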
template <typename DeviceContext, typename T>
class StftGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using C = paddle::platform::complex<T>;
    auto& dev_ctx = ctx.device_context<DeviceContext>();

    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    dx->mutable_data<T>(ctx.GetPlace());

    const size_t dy_rank = dy->dims().size();
    const size_t dx_rank = dx->dims().size();

    const int n_fft = ctx.Attr<int>("n_fft");
    const int hop_length = ctx.Attr<int>("hop_length");
    const bool normalized = ctx.Attr<bool>("normalized");
    const bool onesided = ctx.Attr<bool>("onesided");
    const int n_frames = dy->dims()[dy_rank - 1];
    const int seq_length = dx->dims()[dx_rank - 1];

    std::vector<int64_t> axes = {1};
    Tensor d_frames;
    framework::DDim d_frames_dims(dy->dims());
    d_frames_dims.at(axes.back()) = n_fft;
    d_frames.mutable_data<T>(d_frames_dims, ctx.GetPlace());

    Tensor complex_d_frames;
    complex_d_frames.mutable_data<C>(d_frames_dims, ctx.GetPlace());

    // dy -> d_frames: apply a C2C FFT to the output gradient to get the
    // gradient w.r.t. the frames.
    FFTNormMode normalization;
    if (normalized) {
      normalization = get_norm_from_string("ortho", true);
    } else {
      normalization = get_norm_from_string("backward", true);
    }
    FFTC2CFunctor<DeviceContext, C, C> fft_c2c_func;

    if (!onesided) {
      fft_c2c_func(dev_ctx, dy, &complex_d_frames, axes, normalization,
                   false);
    } else {
      // A onesided dy only carries gradients for the non-redundant bins;
      // zero-pad it back to the full n_fft bins before the C2C transform.
      Tensor full_dy;
      full_dy.mutable_data<C>(d_frames_dims, ctx.GetPlace());
      auto zero_length = static_cast<int>(full_dy.dims().at(axes.back()) -
                                          dy->dims().at(axes.back()));
      auto rank = dy->dims().size();

      std::vector<int> pads(rank * 2, 0);
      pads[axes.back() * 2 + 1] = zero_length;

      phi::funcs::PaddingFunctor<DeviceContext, C>(
          rank, ctx.template device_context<DeviceContext>(), pads,
          static_cast<C>(0), *dy, &full_dy);
      fft_c2c_func(dev_ctx, &full_dy, &complex_d_frames, axes, normalization,
                   false);
    }
    // Keep only the real part, since the forward input frames are real.
    framework::TransComplexToReal(
        framework::TransToProtoVarType(d_frames.dtype()),
        framework::TransToProtoVarType(complex_d_frames.dtype()),
        complex_d_frames, &d_frames);

    // d_frames -> dx: overlap-add the frame gradients back into the signal.
    FrameFunctor<DeviceContext, T>()(dev_ctx, &d_frames, dx, seq_length,
                                     n_fft, n_frames, hop_length,
                                     /*is_grad*/ true);
  }
};

}  // namespace operators
}  // namespace paddle
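
// For reference, a minimal sketch of what the forward kernel computes for a
// single signal with normalized=false and onesided=true. This is an
// illustration only, not part of the operator: it uses std::complex and a
// naive O(n^2) DFT in place of FrameFunctor/FFTR2CFunctor, and, like this
// kernel, applies no window (any windowing is assumed to happen before the
// op is invoked).
//
//   #include <cmath>
//   #include <complex>
//   #include <vector>
//
//   std::vector<std::complex<double>> NaiveStft(
//       const std::vector<double>& x, int n_fft, int hop_length) {
//     const double kPi = 3.14159265358979323846;
//     const int n_frames =
//         1 + (static_cast<int>(x.size()) - n_fft) / hop_length;
//     const int n_bins = n_fft / 2 + 1;  // onesided spectrum
//     std::vector<std::complex<double>> out(n_bins * n_frames);
//     for (int m = 0; m < n_frames; ++m) {    // frame index
//       for (int k = 0; k < n_bins; ++k) {    // frequency bin
//         std::complex<double> acc(0.0, 0.0);
//         for (int t = 0; t < n_fft; ++t) {   // sample within the frame
//           acc += x[m * hop_length + t] *
//                  std::polar(1.0, -2.0 * kPi * k * t / n_fft);
//         }
//         out[k * n_frames + m] = acc;        // layout: (n_bins, n_frames)
//       }
//     }
//     return out;
//   }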