diff --git a/paddle/fluid/operators/overlap_add_op.cc b/paddle/fluid/operators/overlap_add_op.cc index 108c2df4cd2e1cf954499f60049e893cb7ce1cc5..4ead216135762629a94a32eed217a3e4e3f0e98f 100644 --- a/paddle/fluid/operators/overlap_add_op.cc +++ b/paddle/fluid/operators/overlap_add_op.cc @@ -12,7 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/overlap_add_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -21,93 +25,6 @@ class OverlapAddOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "overlap_add"); - - const int hop_length = ctx->Attrs().Get("hop_length"); - const int axis = ctx->Attrs().Get("axis"); - - const auto x_dims = ctx->GetInputDim("X"); - const int x_rank = x_dims.size(); - - PADDLE_ENFORCE_GE( - x_rank, - 2, - platform::errors::InvalidArgument( - "Input(X) of OverlapAddOp should be a tensor which contains " - "at least 2 dimensions, but got rank %s.", - x_rank)); - - PADDLE_ENFORCE_GT( - hop_length, - 0, - platform::errors::InvalidArgument( - "Attribute(hop_length) of OverlapAddOp should be greater " - "than 0, but got %s.", - hop_length)); - - PADDLE_ENFORCE_EQ( - (axis == 0 || axis == -1), - true, - platform::errors::InvalidArgument( - "Attribute(axis) of OverlapAddOp should 0 or -1, but got %s.", - axis)); - - std::vector output_shape; - int n_frames; - int frame_length; - int seq_length; - - int start_axis; - int end_axis; - if (axis == 0) { - n_frames = x_dims[0]; - frame_length = x_dims[1]; - start_axis = 2; - end_axis = x_rank - 1; - } else { - n_frames = x_dims[x_rank - 1]; - frame_length = x_dims[x_rank - 2]; - start_axis = 0; - end_axis = x_rank - 3; - } - - bool contain_unknown_dim = phi::contain_unknown_dim(x_dims); - bool check = ctx->IsRuntime() || !contain_unknown_dim; - if (check) { - PADDLE_ENFORCE_LE( - hop_length, - frame_length, - platform::errors::InvalidArgument( - "Attribute(hop_length) of OverlapAddOp should be less or equal " - "than frame_length, but got hop_length(%s) > frame_length(%s).", - hop_length, - frame_length)); - } - - if (n_frames == -1) { - seq_length = -1; - } else { - seq_length = (n_frames - 1) * hop_length + frame_length; - } - - // It won't go into for loop when x_rank == 2U. - for (int i = start_axis; i <= end_axis; i++) { - output_shape.push_back(x_dims[i]); - } - - if (axis == 0) { - // (seq_length, ...) 
- output_shape.insert(output_shape.begin(), seq_length); - } else { - // (..., seq_length) - output_shape.push_back(seq_length); - } - - ctx->SetOutputDim("Out", phi::make_ddim(output_shape)); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -137,17 +54,6 @@ class OverlapAddOpMaker : public framework::OpProtoAndCheckerMaker { class OverlapAddOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add_grad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "overlap_add_grad"); - const auto x_dims = ctx->GetInputDim("X"); - if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - } - } protected: framework::OpKernelType GetExpectedKernelType( @@ -176,30 +82,21 @@ class OverlapAddOpGradMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(overlap_add, + OverlapAddInferShapeFunctor, + PD_INFER_META(phi::OverlapAddInferMeta)); + +DECLARE_INFER_SHAPE_FUNCTOR(overlap_add_grad, + OverlapAddGradInferShapeFunctor, + PD_INFER_META(phi::OverlapAddGradInferMeta)); + REGISTER_OPERATOR(overlap_add, ops::OverlapAddOp, ops::OverlapAddOpMaker, ops::OverlapAddOpGradMaker, - ops::OverlapAddOpGradMaker); - -REGISTER_OPERATOR(overlap_add_grad, ops::OverlapAddOpGrad); - -REGISTER_OP_CPU_KERNEL( - overlap_add, - ops::OverlapAddKernel, - ops::OverlapAddKernel, - ops::OverlapAddKernel, - ops::OverlapAddKernel, - ops::OverlapAddKernel>, - ops::OverlapAddKernel>); + ops::OverlapAddOpGradMaker, + OverlapAddInferShapeFunctor); -REGISTER_OP_CPU_KERNEL( - overlap_add_grad, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel>, - ops::OverlapAddGradKernel>); +REGISTER_OPERATOR(overlap_add_grad, + ops::OverlapAddOpGrad, + OverlapAddGradInferShapeFunctor); diff --git a/paddle/fluid/operators/overlap_add_op.cu b/paddle/fluid/operators/overlap_add_op.cu deleted file mode 100644 index 2b7935e0191b7ae3cec8aff44236fd386e374261..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/overlap_add_op.cu +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/operators/overlap_add_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - overlap_add, - ops::OverlapAddKernel, - ops::OverlapAddKernel, - ops::OverlapAddKernel, - ops::OverlapAddKernel, - ops::OverlapAddKernel, - ops::OverlapAddKernel>, - ops::OverlapAddKernel>); - -REGISTER_OP_CUDA_KERNEL( - overlap_add_grad, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel, - ops::OverlapAddGradKernel>, - ops::OverlapAddGradKernel>); diff --git a/paddle/fluid/operators/overlap_add_op.h b/paddle/fluid/operators/overlap_add_op.h deleted file mode 100644 index b8008871d208f518dddd22b586cad0ef0b5f5721..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/overlap_add_op.h +++ /dev/null @@ -1,322 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/seq2col.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -template -struct OverlapAddFunctor { - void operator()(const DeviceContext& dev_ctx, - const Tensor* input, - Tensor* output, - size_t seq_length, - size_t frame_length, - size_t n_frames, - size_t hop_length, - bool is_grad = false) const { - auto numel = output->numel(); - const auto* input_data = input->data(); - auto* output_data = output->data(); - - platform::ForRange for_range(dev_ctx, numel); - if (!is_grad) { - phi::funcs::Col2SeqFunctor functor(input_data, - output_data, - seq_length, - frame_length, - n_frames, - hop_length); - for_range(functor); - } else { - phi::funcs::Seq2ColFunctor functor(input_data, - output_data, - seq_length, - frame_length, - n_frames, - hop_length); - for_range(functor); - } - } -}; - -template -class OverlapAddKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const Tensor* x = ctx.Input("X"); - Tensor* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - const size_t x_rank = x->dims().size(); - const size_t out_rank = out->dims().size(); - - const int hop_length = ctx.Attr("hop_length"); - const int axis = ctx.Attr("axis"); - const int n_frames = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1]; - const int frame_length = (axis == 0) ? x->dims()[1] : x->dims()[x_rank - 2]; - const int seq_length = - (axis == 0) ? 
out->dims()[0] : out->dims()[out_rank - 1]; - - auto& dev_ctx = ctx.device_context(); - - Tensor x_(x->type()); - x_ = *x; - - framework::DDim preserved_dims; - if (out_rank > 2) { - // Save dims used to flatten both input and output tensors and restore - // output tensor. - framework::DDim x_resized_dims; - framework::DDim out_resized_dims; - if (axis == 0) { - preserved_dims = phi::slice_ddim(out->dims(), 1, out_rank); - x_resized_dims = {n_frames, frame_length, phi::product(preserved_dims)}; - out_resized_dims = {seq_length, phi::product(preserved_dims)}; - } else { - preserved_dims = phi::slice_ddim(out->dims(), 0, out_rank - 1); - x_resized_dims = {phi::product(preserved_dims), frame_length, n_frames}; - out_resized_dims = {phi::product(preserved_dims), seq_length}; - } - x_.Resize(x_resized_dims); - out->Resize(out_resized_dims); - } - - Tensor trans_x(x_.type()); - Tensor trans_out(out->type()); - - // Transpose input and output in case that axis is 0. - if (axis == 0) { - if (out_rank == 1U) { - trans_out = *out; - - std::vector perm_x{1, 0}; - auto x_dims_vec = phi::vectorize(x_.dims()); - for (int i = 0; i < x_.dims().size(); ++i) { - x_dims_vec[i] = x_.dims()[perm_x[i]]; - } - trans_x.Resize(phi::make_ddim(x_dims_vec)); - trans_x.mutable_data(ctx.GetPlace()); - TransCompute( - perm_x.size(), dev_ctx, x_, &trans_x, perm_x); - } else { - std::vector perm_out{1, 0}; - auto out_dims_vec = phi::vectorize(out->dims()); - for (int i = 0; i < out->dims().size(); ++i) { - out_dims_vec[i] = out->dims()[perm_out[i]]; - } - trans_out.Resize(phi::make_ddim(out_dims_vec)); - trans_out.mutable_data(ctx.GetPlace()); - TransCompute( - perm_out.size(), dev_ctx, *out, &trans_out, perm_out); - - std::vector perm_x{2, 1, 0}; - auto x_dims_vec = phi::vectorize(x_.dims()); - for (int i = 0; i < x_.dims().size(); ++i) { - x_dims_vec[i] = x_.dims()[perm_x[i]]; - } - trans_x.Resize(phi::make_ddim(x_dims_vec)); - trans_x.mutable_data(ctx.GetPlace()); - TransCompute( - perm_x.size(), dev_ctx, x_, &trans_x, perm_x); - } - } else { - trans_x = x_; - trans_out = *out; - } - - OverlapAddFunctor()(dev_ctx, - &trans_x, - &trans_out, - seq_length, - frame_length, - n_frames, - hop_length, - /*is_grad*/ false); - - // Transpose output in case axis is 0. - if (axis == 0 && out_rank > 1U) { - std::vector perm_out{1, 0}; - TransCompute( - perm_out.size(), dev_ctx, trans_out, out, perm_out); - } - - // Restore output dims when the number of dims is larger than 2. - if (out_rank > 2) { - std::vector restored_out_shape; - for (int i = 0; i < preserved_dims.size(); i++) { - restored_out_shape.push_back(preserved_dims[i]); - } - - if (axis == 0) { - // (seq_length, ...) - restored_out_shape.insert(restored_out_shape.begin(), seq_length); - } else { - // (..., seq_length) - restored_out_shape.push_back(seq_length); - } - - out->Resize(phi::make_ddim(restored_out_shape)); - } - } -}; - -template -class OverlapAddGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); - Tensor* d_x = ctx.Output(framework::GradVarName("X")); - d_x->mutable_data(ctx.GetPlace()); - const size_t d_out_rank = d_out->dims().size(); - const size_t d_x_rank = d_x->dims().size(); - - const int hop_length = ctx.Attr("hop_length"); - const int axis = ctx.Attr("axis"); - const int n_frames = - (axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1]; - const int frame_length = - (axis == 0) ? 
d_x->dims()[1] : d_x->dims()[d_x_rank - 2]; - const int seq_length = - (axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1]; - - auto& dev_ctx = ctx.device_context(); - - // When the number of input dims is larger than 2, it needs to copy - // from x to resize input into 2d and output into 3d. Morevoer, output - // dims will be restored at the last step. - Tensor d_out_(d_out->type()); - d_out_ = *d_out; - - framework::DDim preserved_dims; - if (d_out_rank > 2) { - // Save dims used to flatten both input and output tensors and restore - // output tensor. - framework::DDim d_x_resized_dims; - framework::DDim d_out_resized_dims; - if (axis == 0) { - preserved_dims = phi::slice_ddim(d_out_.dims(), 1, d_out_rank); - d_x_resized_dims = { - n_frames, frame_length, phi::product(preserved_dims)}; - d_out_resized_dims = {seq_length, phi::product(preserved_dims)}; - } else { - preserved_dims = phi::slice_ddim(d_out_.dims(), 0, d_out_rank - 1); - d_x_resized_dims = { - phi::product(preserved_dims), frame_length, n_frames}; - d_out_resized_dims = {phi::product(preserved_dims), seq_length}; - } - d_x->Resize(d_x_resized_dims); - d_out_.Resize(d_out_resized_dims); - } - - Tensor trans_d_x(d_x->type()); - Tensor trans_d_out(d_out_.type()); - - // Transpose input and output in case that axis is 0. - if (axis == 0) { - if (d_out_rank == 1U) { - trans_d_out = d_out_; - - std::vector perm_d_x{1, 0}; - auto d_x_dims_vec = phi::vectorize(d_x->dims()); - for (int i = 0; i < d_x->dims().size(); ++i) { - d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]]; - } - trans_d_x.Resize(phi::make_ddim(d_x_dims_vec)); - trans_d_x.mutable_data(ctx.GetPlace()); - TransCompute( - perm_d_x.size(), dev_ctx, *d_x, &trans_d_x, perm_d_x); - } else { - std::vector perm_d_out{1, 0}; - auto d_out_dims_vec = phi::vectorize(d_out_.dims()); - for (int i = 0; i < d_out_.dims().size(); ++i) { - d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]]; - } - trans_d_out.Resize(phi::make_ddim(d_out_dims_vec)); - trans_d_out.mutable_data(ctx.GetPlace()); - TransCompute( - perm_d_out.size(), dev_ctx, d_out_, &trans_d_out, perm_d_out); - - std::vector perm_d_x{2, 1, 0}; - auto d_x_dims_vec = phi::vectorize(d_x->dims()); - for (int i = 0; i < d_x->dims().size(); ++i) { - d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]]; - } - trans_d_x.Resize(phi::make_ddim(d_x_dims_vec)); - trans_d_x.mutable_data(ctx.GetPlace()); - TransCompute( - perm_d_x.size(), dev_ctx, *d_x, &trans_d_x, perm_d_x); - } - } else { - trans_d_x = *d_x; - trans_d_out = d_out_; - } - - OverlapAddFunctor()(dev_ctx, - &trans_d_out, - &trans_d_x, - seq_length, - frame_length, - n_frames, - hop_length, - /*is_grad*/ true); - - // Transpose output in case axis is 0. - if (axis == 0) { - if (d_out_rank == 1U) { - std::vector perm_d_x{1, 0}; - TransCompute( - perm_d_x.size(), dev_ctx, trans_d_x, d_x, perm_d_x); - } else { - std::vector perm_d_x{2, 1, 0}; - TransCompute( - perm_d_x.size(), dev_ctx, trans_d_x, d_x, perm_d_x); - } - } - - // Restore output dims when the number of dims is larger than 2. - if (d_out_rank > 2) { - std::vector restored_d_x_shape; - for (int i = 0; i < preserved_dims.size(); i++) { - restored_d_x_shape.push_back(preserved_dims[i]); - } - - if (axis == 0) { - // (n_frames, frame_length, ...) 
- restored_d_x_shape.insert(restored_d_x_shape.begin(), frame_length); - restored_d_x_shape.insert(restored_d_x_shape.begin(), n_frames); - } else { - // (..., frame_length, n_frames) - restored_d_x_shape.push_back(frame_length); - restored_d_x_shape.push_back(n_frames); - } - - d_x->Resize(phi::make_ddim(restored_d_x_shape)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 25cdd37ddea9d90d3b5757eee2330c34cccd5490..bb232f6212c39617cff279b93c90d44c1a28ea3f 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2648,3 +2648,13 @@ kernel: func: eig backward: eig_grad + +# overlap_add +- api: overlap_add + args: (Tensor x, int hop_length, int axis) + output: Tensor + infer_meta: + func: OverlapAddInferMeta + kernel: + func: overlap_add + backward: overlap_add_grad diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index c00d9fd9a627b1a4a650447c32f6c9b1885c1306..b44417050783e688442c08d1c722994d229a4dee 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -914,10 +914,10 @@ forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out) args : (Tensor x, Tensor grid, Tensor out_grad, str mode, str padding_mode, bool align_corners) output : Tensor(x_grad), Tensor(grid_grad) - infer_meta : + infer_meta : func : GeneralBinaryGradInferMeta param : [x, grid] - kernel : + kernel : func : grid_sample_grad data_type : x @@ -1552,6 +1552,16 @@ kernel : func : norm_grad +- backward_api : overlap_add_grad + forward : overlap_add(Tensor x, int hop_length, int axis) -> Tensor(out) + args : (Tensor x, Tensor out_grad, int hop_length, int axis) + output : Tensor(x_grad) + infer_meta : + func : OverlapAddGradInferMeta + kernel : + func : overlap_add_grad + data_type : x + - backward_api : p_norm_grad forward : p_norm(Tensor x, float porder, int axis, float epsilon, bool keepdim, bool asvector=false) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, float porder, int axis, float epsilon, bool keepdim, bool asvector) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 26e578107206e84da4750b21e83b21b79f7adef2..2c2484da35d0d3dc3427c94dbaafc873bc3975fc 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -609,6 +609,18 @@ void NllLossGradInferMeta(const MetaTensor& x, } } +void OverlapAddGradInferMeta(const MetaTensor& x, + const MetaTensor& out_grad, + int hop_length, + int axis, + MetaTensor* x_grad) { + const auto x_dims = x.dims(); + if (x_grad != nullptr) { + x_grad->set_dims(x_dims); + x_grad->set_dtype(x.dtype()); + } +} + void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad, int downscale_factor, const std::string& data_format, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index bc89d84cf2203cdd7d19bff1b9698f38fff3925f..add2f8945dd9f6c1258fdf321a362151114e2e29 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -262,6 +262,12 @@ void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad, const std::string& data_format, MetaTensor* x_grad); +void OverlapAddGradInferMeta(const MetaTensor& x, + const MetaTensor& out_grad, + int hop_length, + int axis, + MetaTensor* x_grad); + void PsroiPoolGradInferMeta(const MetaTensor& x, const MetaTensor& rois, const MetaTensor& 
rois_num,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 868f563afe1b06c680c116171a5643f41f4ad367..9bb8c156d9016465c01e538d1e9e91fc1f7939f9 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1725,6 +1725,90 @@ void NormInferMeta(const MetaTensor& x,
   }
 }
 
+void OverlapAddInferMeta(const MetaTensor& x,
+                         int hop_length,
+                         int axis,
+                         MetaTensor* out,
+                         MetaConfig config) {
+  const auto x_dims = x.dims();
+  const int x_rank = x_dims.size();
+
+  PADDLE_ENFORCE_GE(
+      x_rank,
+      2,
+      errors::InvalidArgument(
+          "Input(X) of OverlapAddOp should be a tensor which contains "
+          "at least 2 dimensions, but got rank %s.",
+          x_rank));
+
+  PADDLE_ENFORCE_GT(
+      hop_length,
+      0,
+      errors::InvalidArgument(
+          "Attribute(hop_length) of OverlapAddOp should be greater "
+          "than 0, but got %s.",
+          hop_length));
+
+  PADDLE_ENFORCE_EQ(
+      (axis == 0 || axis == -1),
+      true,
+      errors::InvalidArgument(
+          "Attribute(axis) of OverlapAddOp should be 0 or -1, but got %s.",
+          axis));
+
+  std::vector<int64_t> output_shape;
+  int n_frames;
+  int frame_length;
+  int seq_length;
+
+  int start_axis;
+  int end_axis;
+  if (axis == 0) {
+    n_frames = x_dims[0];
+    frame_length = x_dims[1];
+    start_axis = 2;
+    end_axis = x_rank - 1;
+  } else {
+    n_frames = x_dims[x_rank - 1];
+    frame_length = x_dims[x_rank - 2];
+    start_axis = 0;
+    end_axis = x_rank - 3;
+  }
+
+  bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
+  bool check = config.is_runtime || !contain_unknown_dim;
+  if (check) {
+    PADDLE_ENFORCE_LE(
+        hop_length,
+        frame_length,
+        errors::InvalidArgument(
+            "Attribute(hop_length) of OverlapAddOp should be less than or "
+            "equal to frame_length, but got hop_length(%s) > "
+            "frame_length(%s).",
+            hop_length,
+            frame_length));
+  }
+
+  if (n_frames == -1) {
+    seq_length = -1;
+  } else {
+    seq_length = (n_frames - 1) * hop_length + frame_length;
+  }
+
+  // It won't go into the for loop when x_rank == 2U.
+  for (int i = start_axis; i <= end_axis; i++) {
+    output_shape.push_back(x_dims[i]);
+  }
+
+  if (axis == 0) {
+    // (seq_length, ...)
+    output_shape.insert(output_shape.begin(), seq_length);
+  } else {
+    // (..., seq_length)
+    output_shape.push_back(seq_length);
+  }
+
+  out->set_dims(phi::make_ddim(output_shape));
+}
+
 void PadInferMeta(const MetaTensor& input,
                   const std::vector<int>& paddings,
                   float pad_value,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index c44705fbfa3958f610895c3b7beb769647498c24..f17ab48f0fae6dc6ae2d3953e8f852e65767dd58 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -235,6 +235,12 @@ void NormInferMeta(const MetaTensor& x,
                    MetaTensor* out,
                    MetaTensor* norm);
 
+void OverlapAddInferMeta(const MetaTensor& x,
+                         int hop_length,
+                         int axis,
+                         MetaTensor* out,
+                         MetaConfig config = MetaConfig());
+
 void PadInferMeta(const MetaTensor& input,
                   const std::vector<int>& paddings,
                   float pad_value,
@@ -542,4 +548,5 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
                              MetaTensor* out);
 
 void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out);
+
 }  // namespace phi
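
Note: the shape rule OverlapAddInferMeta encodes above is easy to check in isolation. A minimal sketch of the length arithmetic in plain Python (helper name hypothetical, not part of this patch): consecutive frames overlap by frame_length - hop_length samples, so only the first frame contributes its full length.

    def overlap_add_seq_length(n_frames, frame_length, hop_length):
        # An unknown dim (-1) propagates, mirroring the InferMeta branch.
        if n_frames == -1:
            return -1
        return (n_frames - 1) * hop_length + frame_length

    # e.g. 4 frames of length 8 with hop 2 -> 3 * 2 + 8 = 14 samples
    assert overlap_add_seq_length(4, 8, 2) == 14
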
diff --git a/paddle/phi/kernels/cpu/overlap_add_grad_kernel.cc b/paddle/phi/kernels/cpu/overlap_add_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..adb4cda0be00556e83a47fb50f64bf12b58fa1ac
--- /dev/null
+++ b/paddle/phi/kernels/cpu/overlap_add_grad_kernel.cc
@@ -0,0 +1,164 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/overlap_add_grad_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/overlap_add_functor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void OverlapAddGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& out_grad,
+                          int hop_length,
+                          int axis,
+                          DenseTensor* x_grad) {
+  dev_ctx.template Alloc<T>(x_grad);
+  const size_t out_grad_rank = out_grad.dims().size();
+  const size_t x_grad_rank = x_grad->dims().size();
+
+  const int n_frames =
+      (axis == 0) ? x_grad->dims()[0] : x_grad->dims()[x_grad_rank - 1];
+  const int frame_length =
+      (axis == 0) ? x_grad->dims()[1] : x_grad->dims()[x_grad_rank - 2];
+  const int seq_length =
+      (axis == 0) ? out_grad.dims()[0] : out_grad.dims()[out_grad_rank - 1];
+
+  // When the number of dims is larger than 2, the input is copied so that it
+  // can be resized into 2-D and the output into 3-D. Moreover, the output
+  // dims will be restored at the last step.
+  DenseTensor out_grad_(out_grad.type());
+  out_grad_ = out_grad;
+
+  phi::DDim preserved_dims;
+  if (out_grad_rank > 2) {
+    // Save dims used to flatten both input and output tensors and restore
+    // the output tensor.
+    phi::DDim x_grad_resized_dims;
+    phi::DDim out_grad_resized_dims;
+    if (axis == 0) {
+      preserved_dims = phi::slice_ddim(out_grad_.dims(), 1, out_grad_rank);
+      x_grad_resized_dims = {
+          n_frames, frame_length, phi::product(preserved_dims)};
+      out_grad_resized_dims = {seq_length, phi::product(preserved_dims)};
+    } else {
+      preserved_dims = phi::slice_ddim(out_grad_.dims(), 0, out_grad_rank - 1);
+      x_grad_resized_dims = {
+          phi::product(preserved_dims), frame_length, n_frames};
+      out_grad_resized_dims = {phi::product(preserved_dims), seq_length};
+    }
+    x_grad->Resize(x_grad_resized_dims);
+    out_grad_.Resize(out_grad_resized_dims);
+  }
+
+  DenseTensor trans_x_grad(x_grad->type());
+  DenseTensor trans_out_grad(out_grad_.type());
+
+  // Transpose input and output in case that axis is 0.
+  if (axis == 0) {
+    if (out_grad_rank == 1U) {
+      trans_out_grad = out_grad_;
+
+      std::vector<int> perm_x_grad{1, 0};
+      auto x_grad_dims_vec = phi::vectorize(x_grad->dims());
+      for (int i = 0; i < x_grad->dims().size(); ++i) {
+        x_grad_dims_vec[i] = x_grad->dims()[perm_x_grad[i]];
+      }
+      trans_x_grad.Resize(phi::make_ddim(x_grad_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_x_grad);
+      phi::funcs::TransCompute<Context, T>(
+          perm_x_grad.size(), dev_ctx, *x_grad, &trans_x_grad, perm_x_grad);
+    } else {
+      std::vector<int> perm_d_out{1, 0};
+      auto out_grad_dims_vec = phi::vectorize(out_grad_.dims());
+      for (int i = 0; i < out_grad_.dims().size(); ++i) {
+        out_grad_dims_vec[i] = out_grad_.dims()[perm_d_out[i]];
+      }
+      trans_out_grad.Resize(phi::make_ddim(out_grad_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_out_grad);
+      phi::funcs::TransCompute<Context, T>(
+          perm_d_out.size(), dev_ctx, out_grad_, &trans_out_grad, perm_d_out);
+
+      std::vector<int> perm_x_grad{2, 1, 0};
+      auto x_grad_dims_vec = phi::vectorize(x_grad->dims());
+      for (int i = 0; i < x_grad->dims().size(); ++i) {
+        x_grad_dims_vec[i] = x_grad->dims()[perm_x_grad[i]];
+      }
+      trans_x_grad.Resize(phi::make_ddim(x_grad_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_x_grad);
+      phi::funcs::TransCompute<Context, T>(
+          perm_x_grad.size(), dev_ctx, *x_grad, &trans_x_grad, perm_x_grad);
+    }
+  } else {
+    trans_x_grad = *x_grad;
+    trans_out_grad = out_grad_;
+  }
+
+  OverlapAddFunctor<Context, T>()(dev_ctx,
+                                  &trans_out_grad,
+                                  &trans_x_grad,
+                                  seq_length,
+                                  frame_length,
+                                  n_frames,
+                                  hop_length,
+                                  /*is_grad*/ true);
+
+  // Transpose output in case axis is 0.
+  if (axis == 0) {
+    if (out_grad_rank == 1U) {
+      std::vector<int> perm_x_grad{1, 0};
+      phi::funcs::TransCompute<Context, T>(
+          perm_x_grad.size(), dev_ctx, trans_x_grad, x_grad, perm_x_grad);
+    } else {
+      std::vector<int> perm_x_grad{2, 1, 0};
+      phi::funcs::TransCompute<Context, T>(
+          perm_x_grad.size(), dev_ctx, trans_x_grad, x_grad, perm_x_grad);
+    }
+  }
+
+  // Restore output dims when the number of dims is larger than 2.
+  if (out_grad_rank > 2) {
+    std::vector<int64_t> restored_x_grad_shape;
+    for (int i = 0; i < preserved_dims.size(); i++) {
+      restored_x_grad_shape.push_back(preserved_dims[i]);
+    }
+
+    if (axis == 0) {
+      // (n_frames, frame_length, ...)
+      restored_x_grad_shape.insert(restored_x_grad_shape.begin(),
+                                   frame_length);
+      restored_x_grad_shape.insert(restored_x_grad_shape.begin(), n_frames);
+    } else {
+      // (..., frame_length, n_frames)
+      restored_x_grad_shape.push_back(frame_length);
+      restored_x_grad_shape.push_back(n_frames);
+    }
+
+    x_grad->Resize(phi::make_ddim(restored_x_grad_shape));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(overlap_add_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::OverlapAddGradKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
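
Note on the backward pass registered above: the forward op scatter-adds each frame into the sequence, and each input element lands in exactly one output position, so the gradient w.r.t. a frame element is simply the matching element of out_grad. That is why the kernel reruns the shared functor with is_grad=true, selecting the gather (Seq2Col) path. A rough NumPy reference for the flattened (n_frames, frame_length), axis=0 case (helper name hypothetical):

    import numpy as np

    def overlap_add_grad_ref(d_out, n_frames, frame_length, hop_length):
        # Pure gather: no summation, since the forward never writes one
        # input element to more than one output position.
        d_x = np.empty((n_frames, frame_length), dtype=d_out.dtype)
        for f in range(n_frames):
            d_x[f] = d_out[f * hop_length:f * hop_length + frame_length]
        return d_x
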
diff --git a/paddle/phi/kernels/cpu/overlap_add_kernel.cc b/paddle/phi/kernels/cpu/overlap_add_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7609ed837ae074f3b26dbe51e7d3571103b6fe23
--- /dev/null
+++ b/paddle/phi/kernels/cpu/overlap_add_kernel.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/overlap_add_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/overlap_add_functor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void OverlapAddKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      int hop_length,
+                      int axis,
+                      DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  const size_t x_rank = x.dims().size();
+  const size_t out_rank = out->dims().size();
+
+  const int n_frames = (axis == 0) ? x.dims()[0] : x.dims()[x_rank - 1];
+  const int frame_length = (axis == 0) ? x.dims()[1] : x.dims()[x_rank - 2];
+  const int seq_length =
+      (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
+
+  DenseTensor x_(x.type());
+  x_ = x;
+
+  phi::DDim preserved_dims;
+  if (out_rank > 2) {
+    // Save dims used to flatten both input and output tensors and restore
+    // the output tensor.
+    phi::DDim x_resized_dims;
+    phi::DDim out_resized_dims;
+    if (axis == 0) {
+      preserved_dims = phi::slice_ddim(out->dims(), 1, out_rank);
+      x_resized_dims = {n_frames, frame_length, phi::product(preserved_dims)};
+      out_resized_dims = {seq_length, phi::product(preserved_dims)};
+    } else {
+      preserved_dims = phi::slice_ddim(out->dims(), 0, out_rank - 1);
+      x_resized_dims = {phi::product(preserved_dims), frame_length, n_frames};
+      out_resized_dims = {phi::product(preserved_dims), seq_length};
+    }
+    x_.Resize(x_resized_dims);
+    out->Resize(out_resized_dims);
+  }
+
+  DenseTensor trans_x(x_.type());
+  DenseTensor trans_out(out->type());
+
+  // Transpose input and output in case that axis is 0.
+  if (axis == 0) {
+    if (out_rank == 1U) {
+      trans_out = *out;
+
+      std::vector<int> perm_x{1, 0};
+      auto x_dims_vec = phi::vectorize(x_.dims());
+      for (int i = 0; i < x_.dims().size(); ++i) {
+        x_dims_vec[i] = x_.dims()[perm_x[i]];
+      }
+      trans_x.Resize(phi::make_ddim(x_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_x);
+      phi::funcs::TransCompute<Context, T>(
+          perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
+    } else {
+      std::vector<int> perm_out{1, 0};
+      auto out_dims_vec = phi::vectorize(out->dims());
+      for (int i = 0; i < out->dims().size(); ++i) {
+        out_dims_vec[i] = out->dims()[perm_out[i]];
+      }
+      trans_out.Resize(phi::make_ddim(out_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_out);
+      phi::funcs::TransCompute<Context, T>(
+          perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
+
+      std::vector<int> perm_x{2, 1, 0};
+      auto x_dims_vec = phi::vectorize(x_.dims());
+      for (int i = 0; i < x_.dims().size(); ++i) {
+        x_dims_vec[i] = x_.dims()[perm_x[i]];
+      }
+      trans_x.Resize(phi::make_ddim(x_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_x);
+      phi::funcs::TransCompute<Context, T>(
+          perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
+    }
+  } else {
+    trans_x = x_;
+    trans_out = *out;
+  }
+
+  OverlapAddFunctor<Context, T>()(dev_ctx,
+                                  &trans_x,
+                                  &trans_out,
+                                  seq_length,
+                                  frame_length,
+                                  n_frames,
+                                  hop_length,
+                                  /*is_grad*/ false);
+
+  // Transpose output in case axis is 0.
+  if (axis == 0 && out_rank > 1U) {
+    std::vector<int> perm_out{1, 0};
+    phi::funcs::TransCompute<Context, T>(
+        perm_out.size(), dev_ctx, trans_out, out, perm_out);
+  }
+
+  // Restore output dims when the number of dims is larger than 2.
+  if (out_rank > 2) {
+    std::vector<int64_t> restored_out_shape;
+    for (int i = 0; i < preserved_dims.size(); i++) {
+      restored_out_shape.push_back(preserved_dims[i]);
+    }
+
+    if (axis == 0) {
+      // (seq_length, ...)
+      restored_out_shape.insert(restored_out_shape.begin(), seq_length);
+    } else {
+      // (..., seq_length)
+      restored_out_shape.push_back(seq_length);
+    }
+
+    out->Resize(phi::make_ddim(restored_out_shape));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(overlap_add,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::OverlapAddKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
diff --git a/paddle/phi/kernels/funcs/overlap_add_functor.h b/paddle/phi/kernels/funcs/overlap_add_functor.h
new file mode 100644
index 0000000000000000000000000000000000000000..df7505325145a31d5a2027b0a76b8db04bad658b
--- /dev/null
+++ b/paddle/phi/kernels/funcs/overlap_add_functor.h
@@ -0,0 +1,58 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/seq2col.h"
+
+namespace phi {
+
+template <typename Context, typename T>
+struct OverlapAddFunctor {
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* input,
+                  DenseTensor* output,
+                  size_t seq_length,
+                  size_t frame_length,
+                  size_t n_frames,
+                  size_t hop_length,
+                  bool is_grad = false) const {
+    auto numel = output->numel();
+    const auto* input_data = input->data<T>();
+    auto* output_data = output->data<T>();
+
+    phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+    if (!is_grad) {
+      phi::funcs::Col2SeqFunctor<T> functor(input_data,
+                                            output_data,
+                                            seq_length,
+                                            frame_length,
+                                            n_frames,
+                                            hop_length);
+      for_range(functor);
+    } else {
+      phi::funcs::Seq2ColFunctor<T> functor(input_data,
+                                            output_data,
+                                            seq_length,
+                                            frame_length,
+                                            n_frames,
+                                            hop_length);
+      for_range(functor);
+    }
+  }
+};
+
+}  // namespace phi
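
Note: the functor above pairs the two index mappings. Col2SeqFunctor scatter-adds frames into the sequence for the forward pass; Seq2ColFunctor gathers them back for the gradient. A NumPy sketch of what the forward accumulation computes, for the flattened (n_frames, frame_length), axis=0 layout the kernels reduce to (helper name hypothetical):

    import numpy as np

    def overlap_add_ref(x, hop_length):
        # x: (n_frames, frame_length); overlapping regions accumulate,
        # hence the name overlap_add.
        n_frames, frame_length = x.shape
        seq = np.zeros((n_frames - 1) * hop_length + frame_length, x.dtype)
        for f in range(n_frames):
            seq[f * hop_length:f * hop_length + frame_length] += x[f]
        return seq
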
diff --git a/paddle/phi/kernels/gpu/overlap_add_grad_kernel.cu b/paddle/phi/kernels/gpu/overlap_add_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4957ade2c37bc98fcbfb44820ac171dc4eb3493e
--- /dev/null
+++ b/paddle/phi/kernels/gpu/overlap_add_grad_kernel.cu
@@ -0,0 +1,165 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/overlap_add_grad_kernel.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/overlap_add_functor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void OverlapAddGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& out_grad,
+                          int hop_length,
+                          int axis,
+                          DenseTensor* x_grad) {
+  dev_ctx.template Alloc<T>(x_grad);
+  const size_t out_grad_rank = out_grad.dims().size();
+  const size_t x_grad_rank = x_grad->dims().size();
+
+  const int n_frames =
+      (axis == 0) ? x_grad->dims()[0] : x_grad->dims()[x_grad_rank - 1];
+  const int frame_length =
+      (axis == 0) ? x_grad->dims()[1] : x_grad->dims()[x_grad_rank - 2];
+  const int seq_length =
+      (axis == 0) ? out_grad.dims()[0] : out_grad.dims()[out_grad_rank - 1];
+
+  // When the number of dims is larger than 2, the input is copied so that it
+  // can be resized into 2-D and the output into 3-D. Moreover, the output
+  // dims will be restored at the last step.
+  DenseTensor out_grad_(out_grad.type());
+  out_grad_ = out_grad;
+
+  phi::DDim preserved_dims;
+  if (out_grad_rank > 2) {
+    // Save dims used to flatten both input and output tensors and restore
+    // the output tensor.
+    phi::DDim x_grad_resized_dims;
+    phi::DDim out_grad_resized_dims;
+    if (axis == 0) {
+      preserved_dims = phi::slice_ddim(out_grad_.dims(), 1, out_grad_rank);
+      x_grad_resized_dims = {
+          n_frames, frame_length, phi::product(preserved_dims)};
+      out_grad_resized_dims = {seq_length, phi::product(preserved_dims)};
+    } else {
+      preserved_dims = phi::slice_ddim(out_grad_.dims(), 0, out_grad_rank - 1);
+      x_grad_resized_dims = {
+          phi::product(preserved_dims), frame_length, n_frames};
+      out_grad_resized_dims = {phi::product(preserved_dims), seq_length};
+    }
+    x_grad->Resize(x_grad_resized_dims);
+    out_grad_.Resize(out_grad_resized_dims);
+  }
+
+  DenseTensor trans_x_grad(x_grad->type());
+  DenseTensor trans_out_grad(out_grad_.type());
+
+  // Transpose input and output in case that axis is 0.
+  if (axis == 0) {
+    if (out_grad_rank == 1U) {
+      trans_out_grad = out_grad_;
+
+      std::vector<int> perm_x_grad{1, 0};
+      auto x_grad_dims_vec = phi::vectorize(x_grad->dims());
+      for (int i = 0; i < x_grad->dims().size(); ++i) {
+        x_grad_dims_vec[i] = x_grad->dims()[perm_x_grad[i]];
+      }
+      trans_x_grad.Resize(phi::make_ddim(x_grad_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_x_grad);
+      phi::funcs::TransCompute<Context, T>(
+          perm_x_grad.size(), dev_ctx, *x_grad, &trans_x_grad, perm_x_grad);
+    } else {
+      std::vector<int> perm_d_out{1, 0};
+      auto out_grad_dims_vec = phi::vectorize(out_grad_.dims());
+      for (int i = 0; i < out_grad_.dims().size(); ++i) {
+        out_grad_dims_vec[i] = out_grad_.dims()[perm_d_out[i]];
+      }
+      trans_out_grad.Resize(phi::make_ddim(out_grad_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_out_grad);
+      phi::funcs::TransCompute<Context, T>(
+          perm_d_out.size(), dev_ctx, out_grad_, &trans_out_grad, perm_d_out);
+
+      std::vector<int> perm_x_grad{2, 1, 0};
+      auto x_grad_dims_vec = phi::vectorize(x_grad->dims());
+      for (int i = 0; i < x_grad->dims().size(); ++i) {
+        x_grad_dims_vec[i] = x_grad->dims()[perm_x_grad[i]];
+      }
+      trans_x_grad.Resize(phi::make_ddim(x_grad_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_x_grad);
+      phi::funcs::TransCompute<Context, T>(
+          perm_x_grad.size(), dev_ctx, *x_grad, &trans_x_grad, perm_x_grad);
+    }
+  } else {
+    trans_x_grad = *x_grad;
+    trans_out_grad = out_grad_;
+  }
+
+  OverlapAddFunctor<Context, T>()(dev_ctx,
+                                  &trans_out_grad,
+                                  &trans_x_grad,
+                                  seq_length,
+                                  frame_length,
+                                  n_frames,
+                                  hop_length,
+                                  /*is_grad*/ true);
+
+  // Transpose output in case axis is 0.
+  if (axis == 0) {
+    if (out_grad_rank == 1U) {
+      std::vector<int> perm_x_grad{1, 0};
+      phi::funcs::TransCompute<Context, T>(
+          perm_x_grad.size(), dev_ctx, trans_x_grad, x_grad, perm_x_grad);
+    } else {
+      std::vector<int> perm_x_grad{2, 1, 0};
+      phi::funcs::TransCompute<Context, T>(
+          perm_x_grad.size(), dev_ctx, trans_x_grad, x_grad, perm_x_grad);
+    }
+  }
+
+  // Restore output dims when the number of dims is larger than 2.
+  if (out_grad_rank > 2) {
+    std::vector<int64_t> restored_x_grad_shape;
+    for (int i = 0; i < preserved_dims.size(); i++) {
+      restored_x_grad_shape.push_back(preserved_dims[i]);
+    }
+
+    if (axis == 0) {
+      // (n_frames, frame_length, ...)
+      restored_x_grad_shape.insert(restored_x_grad_shape.begin(),
+                                   frame_length);
+      restored_x_grad_shape.insert(restored_x_grad_shape.begin(), n_frames);
+    } else {
+      // (..., frame_length, n_frames)
+      restored_x_grad_shape.push_back(frame_length);
+      restored_x_grad_shape.push_back(n_frames);
+    }
+
+    x_grad->Resize(phi::make_ddim(restored_x_grad_shape));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(overlap_add_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::OverlapAddGradKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/overlap_add_kernel.cu b/paddle/phi/kernels/gpu/overlap_add_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..47989ebcda816514b3370dca074cd8a153564d0d
--- /dev/null
+++ b/paddle/phi/kernels/gpu/overlap_add_kernel.cu
@@ -0,0 +1,149 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/overlap_add_kernel.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/overlap_add_functor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void OverlapAddKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      int hop_length,
+                      int axis,
+                      DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  const size_t x_rank = x.dims().size();
+  const size_t out_rank = out->dims().size();
+
+  const int n_frames = (axis == 0) ? x.dims()[0] : x.dims()[x_rank - 1];
+  const int frame_length = (axis == 0) ? x.dims()[1] : x.dims()[x_rank - 2];
+  const int seq_length =
+      (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
+
+  DenseTensor x_(x.type());
+  x_ = x;
+
+  phi::DDim preserved_dims;
+  if (out_rank > 2) {
+    // Save dims used to flatten both input and output tensors and restore
+    // the output tensor.
+    phi::DDim x_resized_dims;
+    phi::DDim out_resized_dims;
+    if (axis == 0) {
+      preserved_dims = phi::slice_ddim(out->dims(), 1, out_rank);
+      x_resized_dims = {n_frames, frame_length, phi::product(preserved_dims)};
+      out_resized_dims = {seq_length, phi::product(preserved_dims)};
+    } else {
+      preserved_dims = phi::slice_ddim(out->dims(), 0, out_rank - 1);
+      x_resized_dims = {phi::product(preserved_dims), frame_length, n_frames};
+      out_resized_dims = {phi::product(preserved_dims), seq_length};
+    }
+    x_.Resize(x_resized_dims);
+    out->Resize(out_resized_dims);
+  }
+
+  DenseTensor trans_x(x_.type());
+  DenseTensor trans_out(out->type());
+
+  // Transpose input and output in case that axis is 0.
+  if (axis == 0) {
+    if (out_rank == 1U) {
+      trans_out = *out;
+
+      std::vector<int> perm_x{1, 0};
+      auto x_dims_vec = phi::vectorize(x_.dims());
+      for (int i = 0; i < x_.dims().size(); ++i) {
+        x_dims_vec[i] = x_.dims()[perm_x[i]];
+      }
+      trans_x.Resize(phi::make_ddim(x_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_x);
+      phi::funcs::TransCompute<Context, T>(
+          perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
+    } else {
+      std::vector<int> perm_out{1, 0};
+      auto out_dims_vec = phi::vectorize(out->dims());
+      for (int i = 0; i < out->dims().size(); ++i) {
+        out_dims_vec[i] = out->dims()[perm_out[i]];
+      }
+      trans_out.Resize(phi::make_ddim(out_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_out);
+      phi::funcs::TransCompute<Context, T>(
+          perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
+
+      std::vector<int> perm_x{2, 1, 0};
+      auto x_dims_vec = phi::vectorize(x_.dims());
+      for (int i = 0; i < x_.dims().size(); ++i) {
+        x_dims_vec[i] = x_.dims()[perm_x[i]];
+      }
+      trans_x.Resize(phi::make_ddim(x_dims_vec));
+      dev_ctx.template Alloc<T>(&trans_x);
+      phi::funcs::TransCompute<Context, T>(
+          perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
+    }
+  } else {
+    trans_x = x_;
+    trans_out = *out;
+  }
+
+  OverlapAddFunctor<Context, T>()(dev_ctx,
+                                  &trans_x,
+                                  &trans_out,
+                                  seq_length,
+                                  frame_length,
+                                  n_frames,
+                                  hop_length,
+                                  /*is_grad*/ false);
+
+  // Transpose output in case axis is 0.
+  if (axis == 0 && out_rank > 1U) {
+    std::vector<int> perm_out{1, 0};
+    phi::funcs::TransCompute<Context, T>(
+        perm_out.size(), dev_ctx, trans_out, out, perm_out);
+  }
+
+  // Restore output dims when the number of dims is larger than 2.
+  if (out_rank > 2) {
+    std::vector<int64_t> restored_out_shape;
+    for (int i = 0; i < preserved_dims.size(); i++) {
+      restored_out_shape.push_back(preserved_dims[i]);
+    }
+
+    if (axis == 0) {
+      // (seq_length, ...)
+      restored_out_shape.insert(restored_out_shape.begin(), seq_length);
+    } else {
+      // (..., seq_length)
+      restored_out_shape.push_back(seq_length);
+    }
+
+    out->Resize(phi::make_ddim(restored_out_shape));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(overlap_add,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::OverlapAddKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
diff --git a/paddle/phi/kernels/overlap_add_grad_kernel.h b/paddle/phi/kernels/overlap_add_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..62fcce2e44f15db2a0cc243b71f699e50239ee16
--- /dev/null
+++ b/paddle/phi/kernels/overlap_add_grad_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+template <typename T, typename Context>
+void OverlapAddGradKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& out_grad,
+                          int hop_length,
+                          int axis,
+                          DenseTensor* x_grad);
+}  // namespace phi
diff --git a/paddle/phi/kernels/overlap_add_kernel.h b/paddle/phi/kernels/overlap_add_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..46df3c8e097d67cb041e071e9220d6537fcc354f
--- /dev/null
+++ b/paddle/phi/kernels/overlap_add_kernel.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+template <typename T, typename Context>
+void OverlapAddKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      int hop_length,
+                      int axis,
+                      DenseTensor* out);
+}  // namespace phi
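
Note: the GPU registrations mirror the CPU ones, with float16 added to the dtype list. Since overlap_add is linear, the forward and backward sketches given earlier can be cross-checked with the adjoint identity <forward(x), d_out> == <x, grad(d_out)>; a quick probe reusing those two hypothetical helpers:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 8))           # 4 frames of length 8
    d_out = rng.standard_normal((4 - 1) * 2 + 8)
    lhs = np.dot(overlap_add_ref(x, 2), d_out)
    rhs = np.sum(x * overlap_add_grad_ref(d_out, 4, 8, 2))
    assert np.isclose(lhs, rhs)
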
diff --git a/paddle/phi/ops/compat/overlap_add_sig.cc b/paddle/phi/ops/compat/overlap_add_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c694b97f8bb0a5cc88178e360c4e2f6d0cb5c106
--- /dev/null
+++ b/paddle/phi/ops/compat/overlap_add_sig.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature OverlapAddGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("overlap_add_grad",
+                         {"X", "Out@GRAD"},
+                         {"hop_length", "axis"},
+                         {"X@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(overlap_add_grad,
+                           phi::OverlapAddGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_overlap_add_op.py b/python/paddle/fluid/tests/unittests/test_overlap_add_op.py
index e04db251de6d2cd819bb5786a5336765f3ed41f8..f3815ba496566b85df5b9b5833901dc499911a95 100644
--- a/python/paddle/fluid/tests/unittests/test_overlap_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_overlap_add_op.py
@@ -73,6 +73,7 @@ class TestOverlapAddOp(OpTest):
 
     def setUp(self):
         self.op_type = "overlap_add"
+        self.python_api = paddle.signal.overlap_add
         self.shape, self.type, self.attrs = self.initTestCase()
         self.inputs = {
             'X': np.random.random(size=self.shape).astype(self.type),
@@ -90,12 +91,12 @@
 
     def test_check_output(self):
         paddle.enable_static()
-        self.check_output()
+        self.check_output(check_eager=True)
         paddle.disable_static()
 
     def test_check_grad_normal(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_eager=True)
         paddle.disable_static()
 
diff --git a/python/paddle/signal.py b/python/paddle/signal.py
index fdca681fab7dfd0641bbffa86127535ba0292300..5c0d1d5edb8216857f5a5a8ad1a259dda1ef817b 100644
--- a/python/paddle/signal.py
+++ b/python/paddle/signal.py
@@ -217,7 +217,9 @@ def overlap_add(x, hop_length, axis=-1, name=None):
 
     op_type = 'overlap_add'
 
-    if _non_static_mode():
+    if in_dygraph_mode():
+        out = _C_ops.final_state_overlap_add(x, hop_length, axis)
+    elif paddle.in_dynamic_mode():
         attrs = ('hop_length', hop_length, 'axis', axis)
         op = getattr(_C_ops, op_type)
         out = op(x, *attrs)
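
Note: with the eager branch above in place, paddle.signal.overlap_add dispatches to the new phi kernels in dygraph mode. A small usage check consistent with the shapes discussed earlier (axis defaults to -1, so frames sit along the last dim):

    import paddle

    x = paddle.ones([8, 4])  # (frame_length=8, n_frames=4)
    out = paddle.signal.overlap_add(x, hop_length=2)
    print(out.shape)         # [14] == (4 - 1) * 2 + 8
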