From 28b4b2f76c511f11a03bf9bf65e157673fffcb88 Mon Sep 17 00:00:00 2001 From: Charles-hit <56987902+Charles-hit@users.noreply.github.com> Date: Thu, 28 Jul 2022 13:03:17 +0800 Subject: [PATCH] Move frame kernel to phi (#44615) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Move frame OP to phi, add frame OP yaml config and supplement the unit test * add header file for in_dygraph_mode * Modify variable names and reuse UnchangedInferMeta for FrameGradInferMeta * move seq2col to phi --- paddle/fluid/operators/frame_op.cc | 134 +------ paddle/fluid/operators/frame_op.cu | 43 --- paddle/fluid/operators/frame_op.h | 362 ------------------ paddle/fluid/operators/overlap_add_op.h | 26 +- paddle/fluid/operators/stft_op.h | 34 +- paddle/phi/api/yaml/legacy_api.yaml | 9 + paddle/phi/api/yaml/legacy_backward.yaml | 10 + paddle/phi/infermeta/unary.cc | 84 ++++ paddle/phi/infermeta/unary.h | 7 + paddle/phi/kernels/cpu/frame_grad_kernel.cc | 31 ++ paddle/phi/kernels/cpu/frame_kernel.cc | 31 ++ paddle/phi/kernels/frame_grad_kernel.h | 30 ++ paddle/phi/kernels/frame_kernel.h | 27 ++ paddle/phi/kernels/funcs/frame_functor.h | 60 +++ .../math => phi/kernels/funcs}/seq2col.h | 10 +- paddle/phi/kernels/gpu/frame_grad_kernel.cu | 32 ++ paddle/phi/kernels/gpu/frame_kernel.cu | 32 ++ .../phi/kernels/impl/frame_grad_kernel_impl.h | 135 +++++++ paddle/phi/kernels/impl/frame_kernel_impl.h | 144 +++++++ paddle/phi/ops/compat/frame_sig.cc | 28 ++ .../fluid/tests/unittests/test_frame_op.py | 5 +- python/paddle/signal.py | 8 +- 22 files changed, 721 insertions(+), 561 deletions(-) delete mode 100644 paddle/fluid/operators/frame_op.cu delete mode 100644 paddle/fluid/operators/frame_op.h create mode 100644 paddle/phi/kernels/cpu/frame_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/frame_kernel.cc create mode 100644 paddle/phi/kernels/frame_grad_kernel.h create mode 100644 paddle/phi/kernels/frame_kernel.h create mode 100644 paddle/phi/kernels/funcs/frame_functor.h rename paddle/{fluid/operators/math => phi/kernels/funcs}/seq2col.h (97%) create mode 100644 paddle/phi/kernels/gpu/frame_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/frame_kernel.cu create mode 100644 paddle/phi/kernels/impl/frame_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/frame_kernel_impl.h create mode 100644 paddle/phi/ops/compat/frame_sig.cc diff --git a/paddle/fluid/operators/frame_op.cc b/paddle/fluid/operators/frame_op.cc index 45a6bc9994..8acb372a55 100644 --- a/paddle/fluid/operators/frame_op.cc +++ b/paddle/fluid/operators/frame_op.cc @@ -12,7 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License.
-#include "paddle/fluid/operators/frame_op.h" +#include "paddle/phi/core/enforce.h" + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -21,89 +27,6 @@ class FrameOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame"); - - const int frame_length = ctx->Attrs().Get("frame_length"); - const int hop_length = ctx->Attrs().Get("hop_length"); - const int axis = ctx->Attrs().Get("axis"); - - const auto x_dims = ctx->GetInputDim("X"); - const int x_rank = x_dims.size(); - - PADDLE_ENFORCE_GE( - x_rank, - 1, - platform::errors::InvalidArgument( - "Input(X) of FrameOp should be a tensor which contains " - "at least 1 dimension, but got rank %s.", - x_rank)); - PADDLE_ENFORCE_GT(hop_length, - 0, - platform::errors::InvalidArgument( - "Attribute(hop_length) of FrameOp should be greater " - "than 0, but got %s.", - hop_length)); - PADDLE_ENFORCE_EQ( - (axis == 0 || axis == -1), - true, - platform::errors::InvalidArgument( - "Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis)); - - std::vector output_shape; - int seq_length; - int n_frames; - - int start_axis; - int end_axis; - - if (axis == 0) { - seq_length = x_dims[0]; - start_axis = 1; - end_axis = x_rank - 1; - } else { - seq_length = x_dims[x_rank - 1]; - start_axis = 0; - end_axis = x_rank - 2; - } - - bool contain_unknown_dim = phi::contain_unknown_dim(x_dims); - bool check = ctx->IsRuntime() || !contain_unknown_dim; - if (check) { - PADDLE_ENFORCE_LE(frame_length, - seq_length, - platform::errors::InvalidArgument( - "Attribute(frame_length) of FrameOp should be less " - "equal than sequence length, but got (%s) > (%s).", - frame_length, - seq_length)); - } - - // It won't go into for loop when x_rank == 1U. - for (int i = start_axis; i <= end_axis; i++) { - output_shape.push_back(x_dims[i]); - } - - if (seq_length == -1) { - n_frames = -1; - } else { - n_frames = 1 + (seq_length - frame_length) / hop_length; - } - - if (axis == 0) { - // (n_frames, frame_length, ...) 
- output_shape.insert(output_shape.begin(), frame_length); - output_shape.insert(output_shape.begin(), n_frames); - } else { - // (..., frame_length, n_frames) - output_shape.push_back(frame_length); - output_shape.push_back(n_frames); - } - - ctx->SetOutputDim("Out", phi::make_ddim(output_shape)); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -136,17 +59,6 @@ class FrameOpMaker : public framework::OpProtoAndCheckerMaker { class FrameOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame_grad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "frame_grad"); - const auto x_dims = ctx->GetInputDim("X"); - if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - } - } protected: framework::OpKernelType GetExpectedKernelType( @@ -160,7 +72,6 @@ template class FrameOpGradMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; - void Apply(GradOpPtr retv) const override { retv->SetType("frame_grad"); retv->SetInput("X", this->Input("X")); @@ -175,28 +86,19 @@ class FrameOpGradMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(frame, + FrameInferShapeFunctor, + PD_INFER_META(phi::FrameInferMeta)); + +DECLARE_INFER_SHAPE_FUNCTOR(frame_grad, + FrameGradInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); + REGISTER_OPERATOR(frame, ops::FrameOp, ops::FrameOpMaker, ops::FrameOpGradMaker, - ops::FrameOpGradMaker); - -REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad); - -REGISTER_OP_CPU_KERNEL( - frame, - ops::FrameKernel, - ops::FrameKernel, - ops::FrameKernel, - ops::FrameKernel, - ops::FrameKernel>, - ops::FrameKernel>); + ops::FrameOpGradMaker, + FrameInferShapeFunctor); -REGISTER_OP_CPU_KERNEL( - frame_grad, - ops::FrameGradKernel, - ops::FrameGradKernel, - ops::FrameGradKernel, - ops::FrameGradKernel, - ops::FrameGradKernel>, - ops::FrameGradKernel>); +REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad, FrameGradInferShapeFunctor); diff --git a/paddle/fluid/operators/frame_op.cu b/paddle/fluid/operators/frame_op.cu deleted file mode 100644 index 33766b6621..0000000000 --- a/paddle/fluid/operators/frame_op.cu +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/operators/frame_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - frame, - ops::FrameKernel, - ops::FrameKernel, - ops::FrameKernel, - ops::FrameKernel, - ops::FrameKernel, - ops::FrameKernel>, - ops::FrameKernel>); - -REGISTER_OP_CUDA_KERNEL( - frame_grad, - ops::FrameGradKernel, - ops::FrameGradKernel, - ops::FrameGradKernel, - ops::FrameGradKernel, - ops::FrameGradKernel, - ops::FrameGradKernel>, - ops::FrameGradKernel>); diff --git a/paddle/fluid/operators/frame_op.h b/paddle/fluid/operators/frame_op.h deleted file mode 100644 index 36c73d0e62..0000000000 --- a/paddle/fluid/operators/frame_op.h +++ /dev/null @@ -1,362 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/seq2col.h" -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -template -struct FrameFunctor { - void operator()(const DeviceContext& dev_ctx, - const Tensor* input, - Tensor* output, - size_t seq_length, - size_t frame_length, - size_t n_frames, - size_t hop_length, - bool is_grad = false) const { - auto numel = output->numel(); - const auto* input_data = input->data(); - auto* output_data = output->data(); - - platform::ForRange for_range(dev_ctx, numel); - if (!is_grad) { - math::Seq2ColFunctor functor(input_data, - output_data, - seq_length, - frame_length, - n_frames, - hop_length); - for_range(functor); - } else { - math::Col2SeqFunctor functor(input_data, - output_data, - seq_length, - frame_length, - n_frames, - hop_length); - for_range(functor); - } - } -}; - -template -class FrameKernel : public framework::OpKernel { - public: - /* - Frame kernel slices frames from input sequences. The main steps as follows: - - - Case 1 - input dims == 1: - - axis is -1: Call a FrameFunctor to compute directly. - - axis is 0: Transpose output firstly, and then it falls into - case axis is -1. Finally, it restores the dims of - output tensor. - - - Case 2 - input dims == 2: - - axis is -1: Call a FrameFunctor to compute directly. - - axis is 0: Transpose both input and output firstly, and then it falls - into case axis is -1. Finally, it restores the dims of - output tensor. - - - Case 3 - input dims > 2: - Flatten the input and output to 2D and 3D respectively so that it - falls into Case 2. Finally, it restores the dims of output tensor. 
- */ - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* x = ctx.Input("X"); - Tensor* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - const size_t x_rank = x->dims().size(); - const size_t out_rank = out->dims().size(); - - const int frame_length = ctx.Attr("frame_length"); - const int hop_length = ctx.Attr("hop_length"); - const int axis = ctx.Attr("axis"); - const int n_frames = - (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1]; - const int seq_length = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1]; - - auto& dev_ctx = ctx.device_context(); - - // When the number of input dims is larger than 2, it needs to copy - // from x to resize input into 2d and output into 3d. Morevoer, output - // dims will be restored at the last step. - Tensor x_(x->type()); - x_ = *x; - - framework::DDim preserved_dims; - if (x_rank > 2) { - // Save dims used to flatten both input and output tensors and restore - // output tensor. - framework::DDim x_resized_dims; - framework::DDim out_resized_dims; - if (axis == 0) { - preserved_dims = phi::slice_ddim(x_.dims(), 1, x_rank); - x_resized_dims = {seq_length, phi::product(preserved_dims)}; - out_resized_dims = { - n_frames, frame_length, phi::product(preserved_dims)}; - } else { - preserved_dims = phi::slice_ddim(x_.dims(), 0, x_rank - 1); - x_resized_dims = {phi::product(preserved_dims), seq_length}; - out_resized_dims = { - phi::product(preserved_dims), frame_length, n_frames}; - } - x_.Resize(x_resized_dims); - out->Resize(out_resized_dims); - } - - Tensor trans_x(x_.type()); - Tensor trans_out(out->type()); - - // Transpose input and output in case that axis is 0. - if (axis == 0) { - if (x_rank == 1U) { - trans_x = x_; - - std::vector perm_out{1, 0}; - auto out_dims_vec = phi::vectorize(out->dims()); - for (int i = 0; i < out->dims().size(); ++i) { - out_dims_vec[i] = out->dims()[perm_out[i]]; - } - trans_out.Resize(phi::make_ddim(out_dims_vec)); - trans_out.mutable_data(ctx.GetPlace()); - TransCompute( - perm_out.size(), dev_ctx, *out, &trans_out, perm_out); - } else { - std::vector perm_x{1, 0}; - auto x_dims_vec = phi::vectorize(x_.dims()); - for (int i = 0; i < x_.dims().size(); ++i) { - x_dims_vec[i] = x_.dims()[perm_x[i]]; - } - trans_x.Resize(phi::make_ddim(x_dims_vec)); - trans_x.mutable_data(ctx.GetPlace()); - TransCompute( - perm_x.size(), dev_ctx, x_, &trans_x, perm_x); - - std::vector perm_out{2, 1, 0}; - auto out_dims_vec = phi::vectorize(out->dims()); - for (int i = 0; i < out->dims().size(); ++i) { - out_dims_vec[i] = out->dims()[perm_out[i]]; - } - trans_out.Resize(phi::make_ddim(out_dims_vec)); - trans_out.mutable_data(ctx.GetPlace()); - TransCompute( - perm_out.size(), dev_ctx, *out, &trans_out, perm_out); - } - } else { - trans_x = x_; - trans_out = *out; - } - - FrameFunctor()(dev_ctx, - &trans_x, - &trans_out, - seq_length, - frame_length, - n_frames, - hop_length, - /*is_grad*/ false); - - // Transpose output in case axis is 0. - if (axis == 0) { - if (x_rank == 1U) { - std::vector perm_out{1, 0}; - TransCompute( - perm_out.size(), dev_ctx, trans_out, out, perm_out); - } else { - std::vector perm_out{2, 1, 0}; - TransCompute( - perm_out.size(), dev_ctx, trans_out, out, perm_out); - } - } - - // Restore output dims when the number of dims is larger than 2. - if (x_rank > 2) { - std::vector restored_out_shape; - for (int i = 0; i < preserved_dims.size(); i++) { - restored_out_shape.push_back(preserved_dims[i]); - } - - if (axis == 0) { - // (n_frames, frame_length, ...) 
- restored_out_shape.insert(restored_out_shape.begin(), frame_length); - restored_out_shape.insert(restored_out_shape.begin(), n_frames); - } else { - // (..., frame_length, n_frames) - restored_out_shape.push_back(frame_length); - restored_out_shape.push_back(n_frames); - } - - out->Resize(phi::make_ddim(restored_out_shape)); - } - } -}; - -template -class FrameGradKernel : public framework::OpKernel { - public: - /* - Frame gradient kernel accumulate gradient `d_x` from `d_out`. The - main steps as follows: - - - Case 1 - d_x dims == 1: - - axis is -1: Call a FrameFunctor to compute directly. Notes that - `is_grad` is set to true to select gradient data functor. - - axis is 0: Transpose `d_out` firstly, and then it falls into - case axis is -1. - - - Case 2 - d_x dims == 2: - - axis is -1: Call a FrameFunctor to compute directly. - - axis is 0: Transpose both `d_x` and `d_out` firstly, and then it - falls into case axis is -1. Finally, it restores the - dims of `d_x`. - - - Case 3 - d_x dims > 2: - Flatten the `d_x` and `d_out` to 2D and 3D respectively so that it - falls into Case 2. Finally, it restores the dims of `d_x` tensor. - */ - void Compute(const framework::ExecutionContext& ctx) const { - const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); - Tensor* d_x = ctx.Output(framework::GradVarName("X")); - d_x->mutable_data(ctx.GetPlace()); - const size_t d_out_rank = d_out->dims().size(); - const size_t d_x_rank = d_x->dims().size(); - - const int frame_length = ctx.Attr("frame_length"); - const int hop_length = ctx.Attr("hop_length"); - const int axis = ctx.Attr("axis"); - const int n_frames = - (axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1]; - const int seq_length = - (axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1]; - - auto& dev_ctx = ctx.device_context(); - - Tensor d_out_(d_out->type()); - d_out_ = *d_out; - - framework::DDim preserved_dims; - if (d_x_rank > 2) { - // Save dims used to flatten both input and output tensors and restore - // output tensor. - framework::DDim d_x_resized_dims; - framework::DDim d_out_resized_dims; - if (axis == 0) { - preserved_dims = phi::slice_ddim(d_x->dims(), 1, d_x_rank); - d_x_resized_dims = {seq_length, phi::product(preserved_dims)}; - d_out_resized_dims = { - n_frames, frame_length, phi::product(preserved_dims)}; - } else { - preserved_dims = phi::slice_ddim(d_x->dims(), 0, d_x_rank - 1); - d_x_resized_dims = {phi::product(preserved_dims), seq_length}; - d_out_resized_dims = { - phi::product(preserved_dims), frame_length, n_frames}; - } - d_x->Resize(d_x_resized_dims); - d_out_.Resize(d_out_resized_dims); - } - - Tensor trans_d_x(d_x->type()); - Tensor trans_d_out(d_out_.type()); - - // Transpose input and output in case that axis is 0. 
- if (axis == 0) { - if (d_x_rank == 1U) { - trans_d_x = *d_x; - - std::vector perm_d_out{1, 0}; - auto d_out_dims_vec = phi::vectorize(d_out_.dims()); - for (int i = 0; i < d_out_.dims().size(); ++i) { - d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]]; - } - trans_d_out.Resize(phi::make_ddim(d_out_dims_vec)); - trans_d_out.mutable_data(ctx.GetPlace()); - TransCompute( - perm_d_out.size(), dev_ctx, d_out_, &trans_d_out, perm_d_out); - } else { - std::vector perm_d_x{1, 0}; - auto d_x_dims_vec = phi::vectorize(d_x->dims()); - for (int i = 0; i < d_x->dims().size(); ++i) { - d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]]; - } - trans_d_x.Resize(phi::make_ddim(d_x_dims_vec)); - trans_d_x.mutable_data(ctx.GetPlace()); - TransCompute( - perm_d_x.size(), dev_ctx, *d_x, &trans_d_x, perm_d_x); - - std::vector perm_d_out{2, 1, 0}; - auto d_out_dims_vec = phi::vectorize(d_out_.dims()); - for (int i = 0; i < d_out_.dims().size(); ++i) { - d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]]; - } - trans_d_out.Resize(phi::make_ddim(d_out_dims_vec)); - trans_d_out.mutable_data(ctx.GetPlace()); - TransCompute( - perm_d_out.size(), dev_ctx, d_out_, &trans_d_out, perm_d_out); - } - } else { - trans_d_x = *d_x; - trans_d_out = d_out_; - } - - FrameFunctor()(dev_ctx, - &trans_d_out, - &trans_d_x, - seq_length, - frame_length, - n_frames, - hop_length, - /*is_grad*/ true); - - // Transpose output in case axis is 0. - if (axis == 0 && d_x_rank > 1U) { - std::vector perm_d_x{1, 0}; - TransCompute( - perm_d_x.size(), dev_ctx, trans_d_x, d_x, perm_d_x); - } - - // Restore output dims when the number of dims is larger than 2. - if (d_x_rank > 2) { - std::vector restored_d_x_shape; - for (int i = 0; i < preserved_dims.size(); i++) { - restored_d_x_shape.push_back(preserved_dims[i]); - } - - if (axis == 0) { - // (seq_length, ...) 
- restored_d_x_shape.insert(restored_d_x_shape.begin(), seq_length); - } else { - // (..., seq_length) - restored_d_x_shape.push_back(seq_length); - } - - d_x->Resize(phi::make_ddim(restored_d_x_shape)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/overlap_add_op.h b/paddle/fluid/operators/overlap_add_op.h index 54d2278d8c..b8008871d2 100644 --- a/paddle/fluid/operators/overlap_add_op.h +++ b/paddle/fluid/operators/overlap_add_op.h @@ -18,11 +18,11 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/math/seq2col.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/seq2col.h" namespace paddle { namespace operators { @@ -44,20 +44,20 @@ struct OverlapAddFunctor { platform::ForRange<DeviceContext> for_range(dev_ctx, numel); if (!is_grad) { - math::Col2SeqFunctor<T> functor(input_data, - output_data, - seq_length, - frame_length, - n_frames, - hop_length); + phi::funcs::Col2SeqFunctor<T> functor(input_data, + output_data, + seq_length, + frame_length, + n_frames, + hop_length); for_range(functor); } else { - math::Seq2ColFunctor<T> functor(input_data, - output_data, - seq_length, - frame_length, - n_frames, - hop_length); + phi::funcs::Seq2ColFunctor<T> functor(input_data, + output_data, + seq_length, + frame_length, + n_frames, + hop_length); for_range(functor); } } diff --git a/paddle/fluid/operators/stft_op.h b/paddle/fluid/operators/stft_op.h index c65c24748c..bbd9b13769 100644 --- a/paddle/fluid/operators/stft_op.h +++ b/paddle/fluid/operators/stft_op.h @@ -18,8 +18,8 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/operators/frame_op.h" #include "paddle/fluid/operators/spectral_op.h" +#include "paddle/phi/kernels/funcs/frame_functor.h" namespace paddle { namespace operators { @@ -60,14 +60,14 @@ class StftKernel : public framework::OpKernel<T> { framework::DDim frames_dims(out->dims()); frames_dims.at(axes.back()) = n_fft; frames.mutable_data<T>(frames_dims, ctx.GetPlace()); - FrameFunctor<DeviceContext, T>()(dev_ctx, - x, - &frames, - seq_length, - n_fft, - n_frames, - hop_length, - /*is_grad*/ false); + phi::funcs::FrameFunctor<DeviceContext, T>()(dev_ctx, + x, + &frames, + seq_length, + n_fft, + n_frames, + hop_length, + /*is_grad*/ false); // Window Tensor frames_w; @@ -175,14 +175,14 @@ class StftGradKernel : public framework::OpKernel<T> { ctx, &d_frames_w, window, axes.back(), MulFunctor<C>(), &d_frames); // d_frames -> dx - FrameFunctor<DeviceContext, T>()(dev_ctx, - &d_frames, - dx, - seq_length, - n_fft, - n_frames, - hop_length, - /*is_grad*/ true); + phi::funcs::FrameFunctor<DeviceContext, T>()(dev_ctx, + &d_frames, + dx, + seq_length, + n_fft, + n_frames, + hop_length, + /*is_grad*/ true); } }; diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 6a5eb3b832..6d5e87bd79 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -850,6 +850,15 @@ func : fmin backward : fmin_grad +- api : frame + args : (Tensor x, int frame_length, int hop_length, int axis) + output : Tensor(out) + infer_meta : + func : FrameInferMeta + kernel : + func : frame + backward : frame_grad + - api : frobenius_norm args : (Tensor x, int64_t[] axis,
bool keep_dim, bool reduce_all) output : Tensor(out) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index a97fab73cd..804a653348 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -824,6 +824,16 @@ kernel : func : fmin_grad +- backward_api : frame_grad + forward : frame(Tensor x, int frame_length, int hop_length, int axis) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int frame_length, int hop_length, int axis) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : frame_grad + - backward_api : frobenius_norm_grad forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 3b31b165b4..dbeba144b5 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -875,6 +875,90 @@ void FlipInferMeta(const MetaTensor& x, out->share_lod(x); } +void FrameInferMeta(const MetaTensor& x, + int frame_length, + int hop_length, + int axis, + MetaTensor* out, + MetaConfig config) { + PADDLE_ENFORCE_NOT_NULL(out, + phi::errors::InvalidArgument( + "Output(Out) of FrameOp should not be null.")); + const auto x_dims = x.dims(); + const int x_rank = x_dims.size(); + + PADDLE_ENFORCE_GE(x_rank, + 1, + phi::errors::InvalidArgument( + "Input(X) of FrameOp should be a tensor which contains " + "at least 1 dimension, but got rank %s.", + x_rank)); + PADDLE_ENFORCE_GT(hop_length, + 0, + phi::errors::InvalidArgument( + "Attribute(hop_length) of FrameOp should be greater " + "than 0, but got %s.", + hop_length)); + PADDLE_ENFORCE_EQ( + (axis == 0 || axis == -1), + true, + phi::errors::InvalidArgument( + "Attribute(axis) of FrameOp should be 0 or -1, but got %s.", axis)); + + std::vector<int64_t> output_shape; + int seq_length; + int n_frames; + + int start_axis; + int end_axis; + + if (axis == 0) { + seq_length = x_dims[0]; + start_axis = 1; + end_axis = x_rank - 1; + } else { + seq_length = x_dims[x_rank - 1]; + start_axis = 0; + end_axis = x_rank - 2; + } + + bool contain_unknown_dim = phi::contain_unknown_dim(x_dims); + bool check = config.is_runtime || !contain_unknown_dim; + if (check) { + PADDLE_ENFORCE_LE(frame_length, + seq_length, + phi::errors::InvalidArgument( + "Attribute(frame_length) of FrameOp should be less " + "than or equal to sequence length, but got (%s) > (%s).", + frame_length, + seq_length)); + } + + // It won't go into the for loop when x_rank == 1U. + for (int i = start_axis; i <= end_axis; i++) { + output_shape.push_back(x_dims[i]); + } + + if (seq_length == -1) { + n_frames = -1; + } else { + n_frames = 1 + (seq_length - frame_length) / hop_length; + } + + if (axis == 0) { + // (n_frames, frame_length, ...)
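+ // Worked example (illustrative values, not from the patch): for x_dims = [1024, 8] with axis = 0, frame_length = 512 and hop_length = 256, n_frames = 1 + (1024 - 512) / 256 = 3, so output_shape becomes [3, 512, 8].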
+ output_shape.insert(output_shape.begin(), frame_length); + output_shape.insert(output_shape.begin(), n_frames); + } else { + // (..., frame_length, n_frames) + output_shape.push_back(frame_length); + output_shape.push_back(n_frames); + } + + out->set_dims(phi::make_ddim(output_shape)); + out->set_dtype(x.dtype()); +} + void FullBatchSizeLikeInferMeta(const MetaTensor& x, const std::vector<int>& shape, const Scalar& val, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index c1db2561f0..e7d04cb998 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -130,6 +130,13 @@ void FlipInferMeta(const MetaTensor& x, const std::vector<int>& axis, MetaTensor* out); +void FrameInferMeta(const MetaTensor& x, + int frame_length, + int hop_length, + int axis, + MetaTensor* out, + MetaConfig = MetaConfig()); + void FullBatchSizeLikeInferMeta(const MetaTensor& x, const std::vector<int>& shape, const Scalar& val, diff --git a/paddle/phi/kernels/cpu/frame_grad_kernel.cc b/paddle/phi/kernels/cpu/frame_grad_kernel.cc new file mode 100644 index 0000000000..d4772b176a --- /dev/null +++ b/paddle/phi/kernels/cpu/frame_grad_kernel.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/frame_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/frame_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(frame_grad, + CPU, + ALL_LAYOUT, + phi::FrameGradKernel, + int, + int64_t, + float, + double, + phi::dtype::complex<float>, + phi::dtype::complex<double>) {} diff --git a/paddle/phi/kernels/cpu/frame_kernel.cc b/paddle/phi/kernels/cpu/frame_kernel.cc new file mode 100644 index 0000000000..708ceddbc1 --- /dev/null +++ b/paddle/phi/kernels/cpu/frame_kernel.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
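+ // Registers the forward frame kernel for CPU; the device-independent implementation is shared with the GPU registration below via impl/frame_kernel_impl.h.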
+ +#include "paddle/phi/kernels/frame_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/frame_kernel_impl.h" + +PD_REGISTER_KERNEL(frame, + CPU, + ALL_LAYOUT, + phi::FrameKernel, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/frame_grad_kernel.h b/paddle/phi/kernels/frame_grad_kernel.h new file mode 100644 index 0000000000..eaca698d76 --- /dev/null +++ b/paddle/phi/kernels/frame_grad_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void FrameGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + int frame_length, + int hop_length, + int axis, + DenseTensor* dx); + +} // namespace phi diff --git a/paddle/phi/kernels/frame_kernel.h b/paddle/phi/kernels/frame_kernel.h new file mode 100644 index 0000000000..66a6cff347 --- /dev/null +++ b/paddle/phi/kernels/frame_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +template +void FrameKernel(const Context& dev_ctx, + const DenseTensor& x, + int frame_length, + int hop_length, + int axis, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/funcs/frame_functor.h b/paddle/phi/kernels/funcs/frame_functor.h new file mode 100644 index 0000000000..88efc86718 --- /dev/null +++ b/paddle/phi/kernels/funcs/frame_functor.h @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/seq2col.h" + +namespace phi { +namespace funcs { + +template <typename Context, typename T> +struct FrameFunctor { + void operator()(const Context& dev_ctx, + const DenseTensor* input, + DenseTensor* output, + size_t seq_length, + size_t frame_length, + size_t n_frames, + size_t hop_length, + bool is_grad = false) const { + auto numel = output->numel(); + const auto* input_data = input->data<T>(); + auto* output_data = output->data<T>(); + + phi::funcs::ForRange<Context> for_range(dev_ctx, numel); + if (!is_grad) { + phi::funcs::Seq2ColFunctor<T> functor(input_data, + output_data, + seq_length, + frame_length, + n_frames, + hop_length); + for_range(functor); + } else { + phi::funcs::Col2SeqFunctor<T> functor(input_data, + output_data, + seq_length, + frame_length, + n_frames, + hop_length); + for_range(functor); + } + } +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/seq2col.h b/paddle/phi/kernels/funcs/seq2col.h similarity index 97% rename from paddle/fluid/operators/math/seq2col.h rename to paddle/phi/kernels/funcs/seq2col.h index eca587e368..b757f8403d 100644 --- a/paddle/fluid/operators/math/seq2col.h +++ b/paddle/phi/kernels/funcs/seq2col.h @@ -14,9 +14,8 @@ #pragma once -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template <typename T> struct Seq2ColFunctor { @@ -189,6 +188,5 @@ struct Col2SeqFunctor { size_t hop_length_; }; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/frame_grad_kernel.cu b/paddle/phi/kernels/gpu/frame_grad_kernel.cu new file mode 100644 index 0000000000..7deb9ff04c --- /dev/null +++ b/paddle/phi/kernels/gpu/frame_grad_kernel.cu @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/frame_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/frame_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(frame_grad, + GPU, + ALL_LAYOUT, + phi::FrameGradKernel, + int, + int64_t, + float, + double, + phi::dtype::float16, + phi::dtype::complex<float>, + phi::dtype::complex<double>) {} diff --git a/paddle/phi/kernels/gpu/frame_kernel.cu b/paddle/phi/kernels/gpu/frame_kernel.cu new file mode 100644 index 0000000000..2506cd714b --- /dev/null +++ b/paddle/phi/kernels/gpu/frame_kernel.cu @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/frame_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/frame_kernel_impl.h" + +PD_REGISTER_KERNEL(frame, + GPU, + ALL_LAYOUT, + phi::FrameKernel, + int, + int64_t, + float, + double, + phi::dtype::float16, + phi::dtype::complex<float>, + phi::dtype::complex<double>) {} diff --git a/paddle/phi/kernels/impl/frame_grad_kernel_impl.h b/paddle/phi/kernels/impl/frame_grad_kernel_impl.h new file mode 100644 index 0000000000..9d1bfe453d --- /dev/null +++ b/paddle/phi/kernels/impl/frame_grad_kernel_impl.h @@ -0,0 +1,135 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/funcs/frame_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { +template <typename T, typename Context> +void FrameGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + int frame_length, + int hop_length, + int axis, + DenseTensor* dx) { + dev_ctx.template Alloc<T>(dx); + const size_t dout_rank = dout.dims().size(); + const size_t dx_rank = dx->dims().size(); + const int n_frames = + (axis == 0) ? dout.dims()[0] : dout.dims()[dout_rank - 1]; + const int seq_length = (axis == 0) ? dx->dims()[0] : dx->dims()[dx_rank - 1]; + DenseTensor dout_tmp = dout; + + DDim preserved_dims; + if (dx_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + DDim dx_resized_dims; + DDim dout_resized_dims; + if (axis == 0) { + preserved_dims = phi::slice_ddim(dx->dims(), 1, dx_rank); + dx_resized_dims = {seq_length, phi::product(preserved_dims)}; + dout_resized_dims = { + n_frames, frame_length, phi::product(preserved_dims)}; + } else { + preserved_dims = phi::slice_ddim(dx->dims(), 0, dx_rank - 1); + dx_resized_dims = {phi::product(preserved_dims), seq_length}; + dout_resized_dims = { + phi::product(preserved_dims), frame_length, n_frames}; + } + dx->Resize(dx_resized_dims); + dout_tmp.Resize(dout_resized_dims); + } + + DenseTensor trans_dx; + DenseTensor trans_dout; + + // Transpose input and output in case that axis is 0.
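+ // E.g. (illustrative): with axis == 0 and dx_rank > 1, dout_tmp of shape (n_frames, frame_length, prod(preserved_dims)) is transposed with perm {2, 1, 0} to (prod, frame_length, n_frames), the layout Col2SeqFunctor expects.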
+ if (axis == 0) { + if (dx_rank == 1U) { + trans_dx = *dx; + + std::vector<int> perm_dout{1, 0}; + auto dout_dims_vec = phi::vectorize(dout_tmp.dims()); + for (int i = 0; i < dout_tmp.dims().size(); ++i) { + dout_dims_vec[i] = dout_tmp.dims()[perm_dout[i]]; + } + trans_dout.Resize(phi::make_ddim(dout_dims_vec)); + dev_ctx.template Alloc<T>(&trans_dout); + phi::funcs::TransCompute<Context, T>( + perm_dout.size(), dev_ctx, dout_tmp, &trans_dout, perm_dout); + } else { + std::vector<int> perm_dx{1, 0}; + auto dx_dims_vec = phi::vectorize(dx->dims()); + for (int i = 0; i < dx->dims().size(); ++i) { + dx_dims_vec[i] = dx->dims()[perm_dx[i]]; + } + trans_dx.Resize(phi::make_ddim(dx_dims_vec)); + dev_ctx.template Alloc<T>(&trans_dx); + phi::funcs::TransCompute<Context, T>( + perm_dx.size(), dev_ctx, *dx, &trans_dx, perm_dx); + + std::vector<int> perm_dout{2, 1, 0}; + auto dout_dims_vec = phi::vectorize(dout_tmp.dims()); + for (int i = 0; i < dout_tmp.dims().size(); ++i) { + dout_dims_vec[i] = dout_tmp.dims()[perm_dout[i]]; + } + trans_dout.Resize(phi::make_ddim(dout_dims_vec)); + dev_ctx.template Alloc<T>(&trans_dout); + phi::funcs::TransCompute<Context, T>( + perm_dout.size(), dev_ctx, dout_tmp, &trans_dout, perm_dout); + } + } else { + trans_dx = *dx; + trans_dout = dout_tmp; + } + + phi::funcs::FrameFunctor<Context, T>()(dev_ctx, + &trans_dout, + &trans_dx, + seq_length, + frame_length, + n_frames, + hop_length, + /*is_grad*/ true); + + // Transpose output in case axis is 0. + if (axis == 0 && dx_rank > 1U) { + std::vector<int> perm_dx{1, 0}; + phi::funcs::TransCompute<Context, T>( + perm_dx.size(), dev_ctx, trans_dx, dx, perm_dx); + } + + // Restore output dims when the number of dims is larger than 2. + if (dx_rank > 2) { + std::vector<int64_t> restored_dx_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_dx_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (seq_length, ...) + restored_dx_shape.insert(restored_dx_shape.begin(), seq_length); + } else { + // (..., seq_length) + restored_dx_shape.push_back(seq_length); + } + + dx->Resize(phi::make_ddim(restored_dx_shape)); + } +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/frame_kernel_impl.h b/paddle/phi/kernels/impl/frame_kernel_impl.h new file mode 100644 index 0000000000..b6a0b2ab6a --- /dev/null +++ b/paddle/phi/kernels/impl/frame_kernel_impl.h @@ -0,0 +1,144 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/funcs/frame_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { +template <typename T, typename Context> +void FrameKernel(const Context& dev_ctx, + const DenseTensor& x, + int frame_length, + int hop_length, + int axis, + DenseTensor* out) { + dev_ctx.template Alloc<T>(out); + const size_t x_rank = x.dims().size(); + const size_t out_rank = out->dims().size(); + const int n_frames = (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1]; + const int seq_length = (axis == 0) ?
x.dims()[0] : x.dims()[x_rank - 1]; + // When the number of input dims is larger than 2, it needs to copy + // from x to resize input into 2d and output into 3d. Moreover, output + // dims will be restored at the last step. + DenseTensor x_tmp = x; + + DDim preserved_dims; + if (x_rank > 2) { + // Save dims used to flatten both input and output tensors and restore + // output tensor. + DDim x_resized_dims; + DDim out_resized_dims; + if (axis == 0) { + preserved_dims = phi::slice_ddim(x_tmp.dims(), 1, x_rank); + x_resized_dims = {seq_length, phi::product(preserved_dims)}; + out_resized_dims = {n_frames, frame_length, phi::product(preserved_dims)}; + } else { + preserved_dims = phi::slice_ddim(x_tmp.dims(), 0, x_rank - 1); + x_resized_dims = {phi::product(preserved_dims), seq_length}; + out_resized_dims = {phi::product(preserved_dims), frame_length, n_frames}; + } + x_tmp.Resize(x_resized_dims); + out->Resize(out_resized_dims); + } + + DenseTensor trans_x; + DenseTensor trans_out; + + // Transpose input and output in case that axis is 0. + if (axis == 0) { + if (x_rank == 1U) { + trans_x = x_tmp; + + std::vector<int> perm_out{1, 0}; + auto out_dims_vec = phi::vectorize(out->dims()); + for (int i = 0; i < out->dims().size(); ++i) { + out_dims_vec[i] = out->dims()[perm_out[i]]; + } + trans_out.Resize(phi::make_ddim(out_dims_vec)); + + dev_ctx.template Alloc<T>(&trans_out); + phi::funcs::TransCompute<Context, T>( + perm_out.size(), dev_ctx, *out, &trans_out, perm_out); + } else { + std::vector<int> perm_x{1, 0}; + auto x_dims_vec = phi::vectorize(x_tmp.dims()); + for (int i = 0; i < x_tmp.dims().size(); ++i) { + x_dims_vec[i] = x_tmp.dims()[perm_x[i]]; + } + trans_x.Resize(phi::make_ddim(x_dims_vec)); + dev_ctx.template Alloc<T>(&trans_x); + phi::funcs::TransCompute<Context, T>( + perm_x.size(), dev_ctx, x_tmp, &trans_x, perm_x); + + std::vector<int> perm_out{2, 1, 0}; + auto out_dims_vec = phi::vectorize(out->dims()); + for (int i = 0; i < out->dims().size(); ++i) { + out_dims_vec[i] = out->dims()[perm_out[i]]; + } + trans_out.Resize(phi::make_ddim(out_dims_vec)); + dev_ctx.template Alloc<T>(&trans_out); + phi::funcs::TransCompute<Context, T>( + perm_out.size(), dev_ctx, *out, &trans_out, perm_out); + } + } else { + trans_x = x_tmp; + trans_out = *out; + } + + phi::funcs::FrameFunctor<Context, T>()(dev_ctx, + &trans_x, + &trans_out, + seq_length, + frame_length, + n_frames, + hop_length, + /*is_grad*/ false); + + // Transpose output in case axis is 0. + if (axis == 0) { + if (x_rank == 1U) { + std::vector<int> perm_out{1, 0}; + funcs::TransCompute<Context, T>( + perm_out.size(), dev_ctx, trans_out, out, perm_out); + } else { + std::vector<int> perm_out{2, 1, 0}; + funcs::TransCompute<Context, T>( + perm_out.size(), dev_ctx, trans_out, out, perm_out); + } + } + + // Restore output dims when the number of dims is larger than 2. + if (x_rank > 2) { + std::vector<int64_t> restored_out_shape; + for (int i = 0; i < preserved_dims.size(); i++) { + restored_out_shape.push_back(preserved_dims[i]); + } + + if (axis == 0) { + // (n_frames, frame_length, ...)
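+ // E.g. (illustrative): an input of shape (1024, 4, 6) framed along axis 0 with frame_length = 512 and hop_length = 256 was flattened to (1024, 24) and computed as (3, 512, 24); here the output is restored to (3, 512, 4, 6).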
+ restored_out_shape.insert(restored_out_shape.begin(), frame_length); + restored_out_shape.insert(restored_out_shape.begin(), n_frames); + } else { + // (..., frame_length, n_frames) + restored_out_shape.push_back(frame_length); + restored_out_shape.push_back(n_frames); + } + + out->Resize(phi::make_ddim(restored_out_shape)); + } +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/frame_sig.cc b/paddle/phi/ops/compat/frame_sig.cc new file mode 100644 index 0000000000..cbe24095b0 --- /dev/null +++ b/paddle/phi/ops/compat/frame_sig.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FrameGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("frame_grad", + {"X", "Out@GRAD"}, + {"frame_length", "hop_length", "axis"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(frame_grad, phi::FrameGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_frame_op.py b/python/paddle/fluid/tests/unittests/test_frame_op.py index 528446f3eb..7f38b95266 100644 --- a/python/paddle/fluid/tests/unittests/test_frame_op.py +++ b/python/paddle/fluid/tests/unittests/test_frame_op.py @@ -47,6 +47,7 @@ class TestFrameOp(OpTest): def setUp(self): self.op_type = "frame" + self.python_api = paddle.signal.frame self.shape, self.type, self.attrs = self.initTestCase() self.inputs = { 'X': np.random.random(size=self.shape).astype(self.type), @@ -67,12 +68,12 @@ class TestFrameOp(OpTest): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_eager=True) paddle.disable_static() def test_check_grad_normal(self): paddle.enable_static() - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) paddle.disable_static() diff --git a/python/paddle/signal.py b/python/paddle/signal.py index 6725373d05..fdca681fab 100644 --- a/python/paddle/signal.py +++ b/python/paddle/signal.py @@ -21,7 +21,8 @@ from .fft import fft_r2c, fft_c2r, fft_c2c from .fluid.data_feeder import check_variable_and_dtype from .fluid.framework import _non_static_mode from .fluid.layer_helper import LayerHelper -from . import _C_ops +from paddle import _C_ops +from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph __all__ = [ 'stft', @@ -127,7 +128,10 @@ def frame(x, frame_length, hop_length, axis=-1, name=None): op_type = 'frame' - if _non_static_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_frame(x, frame_length, hop_length, axis) + + if _in_legacy_dygraph(): attrs = ('frame_length', frame_length, 'hop_length', hop_length, 'axis', axis) op = getattr(_C_ops, op_type) -- GitLab
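Usage sketch (illustrative, not part of the patch): with this change applied, paddle.signal.frame runs through the new phi kernel in eager mode via _C_ops.final_state_frame. Shapes follow FrameInferMeta:

import paddle

x = paddle.randn([8, 1024])  # batch of 8 signals, seq_length = 1024
out = paddle.signal.frame(x, frame_length=512, hop_length=256, axis=-1)
# axis == -1 -> (..., frame_length, n_frames) with
# n_frames = 1 + (1024 - 512) // 256 = 3
print(out.shape)  # [8, 512, 3]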