未验证 提交 2a8219c1 编写于 作者: L levi131 提交者: GitHub

migrate overlap_add and overlap_add_grad op (#44739)

* update code format

* add ymal and test

* update for comments
上级 1d79f1f7
......@@ -12,7 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/overlap_add_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -21,93 +25,6 @@ class OverlapAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "overlap_add");
const int hop_length = ctx->Attrs().Get<int>("hop_length");
const int axis = ctx->Attrs().Get<int>("axis");
const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_rank,
2,
platform::errors::InvalidArgument(
"Input(X) of OverlapAddOp should be a tensor which contains "
"at least 2 dimensions, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(
hop_length,
0,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1),
true,
platform::errors::InvalidArgument(
"Attribute(axis) of OverlapAddOp should 0 or -1, but got %s.",
axis));
std::vector<int64_t> output_shape;
int n_frames;
int frame_length;
int seq_length;
int start_axis;
int end_axis;
if (axis == 0) {
n_frames = x_dims[0];
frame_length = x_dims[1];
start_axis = 2;
end_axis = x_rank - 1;
} else {
n_frames = x_dims[x_rank - 1];
frame_length = x_dims[x_rank - 2];
start_axis = 0;
end_axis = x_rank - 3;
}
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(
hop_length,
frame_length,
platform::errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length,
frame_length));
}
if (n_frames == -1) {
seq_length = -1;
} else {
seq_length = (n_frames - 1) * hop_length + frame_length;
}
// It won't go into for loop when x_rank == 2U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
output_shape.insert(output_shape.begin(), seq_length);
} else {
// (..., seq_length)
output_shape.push_back(seq_length);
}
ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -137,17 +54,6 @@ class OverlapAddOpMaker : public framework::OpProtoAndCheckerMaker {
class OverlapAddOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "overlap_add_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"overlap_add_grad");
const auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -176,30 +82,21 @@ class OverlapAddOpGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(overlap_add,
OverlapAddInferShapeFunctor,
PD_INFER_META(phi::OverlapAddInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(overlap_add_grad,
OverlapAddGradInferShapeFunctor,
PD_INFER_META(phi::OverlapAddGradInferMeta));
REGISTER_OPERATOR(overlap_add,
ops::OverlapAddOp,
ops::OverlapAddOpMaker,
ops::OverlapAddOpGradMaker<paddle::framework::OpDesc>,
ops::OverlapAddOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(overlap_add_grad, ops::OverlapAddOpGrad);
REGISTER_OP_CPU_KERNEL(
overlap_add,
ops::OverlapAddKernel<phi::CPUContext, int>,
ops::OverlapAddKernel<phi::CPUContext, int64_t>,
ops::OverlapAddKernel<phi::CPUContext, float>,
ops::OverlapAddKernel<phi::CPUContext, double>,
ops::OverlapAddKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::OverlapAddKernel<phi::CPUContext, paddle::platform::complex<double>>);
ops::OverlapAddOpGradMaker<paddle::imperative::OpBase>,
OverlapAddInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
overlap_add_grad,
ops::OverlapAddGradKernel<phi::CPUContext, int>,
ops::OverlapAddGradKernel<phi::CPUContext, int64_t>,
ops::OverlapAddGradKernel<phi::CPUContext, float>,
ops::OverlapAddGradKernel<phi::CPUContext, double>,
ops::OverlapAddGradKernel<phi::CPUContext,
paddle::platform::complex<float>>,
ops::OverlapAddGradKernel<phi::CPUContext,
paddle::platform::complex<double>>);
REGISTER_OPERATOR(overlap_add_grad,
ops::OverlapAddOpGrad,
OverlapAddGradInferShapeFunctor);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/overlap_add_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
overlap_add,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, int>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, float>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext, double>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
overlap_add_grad,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::OverlapAddGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/seq2col.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
struct OverlapAddFunctor {
void operator()(const DeviceContext& dev_ctx,
const Tensor* input,
Tensor* output,
size_t seq_length,
size_t frame_length,
size_t n_frames,
size_t hop_length,
bool is_grad = false) const {
auto numel = output->numel();
const auto* input_data = input->data<T>();
auto* output_data = output->data<T>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
if (!is_grad) {
phi::funcs::Col2SeqFunctor<T> functor(input_data,
output_data,
seq_length,
frame_length,
n_frames,
hop_length);
for_range(functor);
} else {
phi::funcs::Seq2ColFunctor<T> functor(input_data,
output_data,
seq_length,
frame_length,
n_frames,
hop_length);
for_range(functor);
}
}
};
template <typename DeviceContext, typename T>
class OverlapAddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
const size_t x_rank = x->dims().size();
const size_t out_rank = out->dims().size();
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1];
const int frame_length = (axis == 0) ? x->dims()[1] : x->dims()[x_rank - 2];
const int seq_length =
(axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
Tensor x_(x->type());
x_ = *x;
framework::DDim preserved_dims;
if (out_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim x_resized_dims;
framework::DDim out_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(out->dims(), 1, out_rank);
x_resized_dims = {n_frames, frame_length, phi::product(preserved_dims)};
out_resized_dims = {seq_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(out->dims(), 0, out_rank - 1);
x_resized_dims = {phi::product(preserved_dims), frame_length, n_frames};
out_resized_dims = {phi::product(preserved_dims), seq_length};
}
x_.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
Tensor trans_x(x_.type());
Tensor trans_out(out->type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (out_rank == 1U) {
trans_out = *out;
std::vector<int> perm_x{1, 0};
auto x_dims_vec = phi::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(phi::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
} else {
std::vector<int> perm_out{1, 0};
auto out_dims_vec = phi::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(phi::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
std::vector<int> perm_x{2, 1, 0};
auto x_dims_vec = phi::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(phi::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
}
} else {
trans_x = x_;
trans_out = *out;
}
OverlapAddFunctor<DeviceContext, T>()(dev_ctx,
&trans_x,
&trans_out,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0 && out_rank > 1U) {
std::vector<int> perm_out{1, 0};
TransCompute<DeviceContext, T>(
perm_out.size(), dev_ctx, trans_out, out, perm_out);
}
// Restore output dims when the number of dims is larger than 2.
if (out_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_out_shape.push_back(seq_length);
}
out->Resize(phi::make_ddim(restored_out_shape));
}
}
};
template <typename DeviceContext, typename T>
class OverlapAddGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
const size_t d_out_rank = d_out->dims().size();
const size_t d_x_rank = d_x->dims().size();
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1];
const int frame_length =
(axis == 0) ? d_x->dims()[1] : d_x->dims()[d_x_rank - 2];
const int seq_length =
(axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
// When the number of input dims is larger than 2, it needs to copy
// from x to resize input into 2d and output into 3d. Morevoer, output
// dims will be restored at the last step.
Tensor d_out_(d_out->type());
d_out_ = *d_out;
framework::DDim preserved_dims;
if (d_out_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim d_x_resized_dims;
framework::DDim d_out_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(d_out_.dims(), 1, d_out_rank);
d_x_resized_dims = {
n_frames, frame_length, phi::product(preserved_dims)};
d_out_resized_dims = {seq_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(d_out_.dims(), 0, d_out_rank - 1);
d_x_resized_dims = {
phi::product(preserved_dims), frame_length, n_frames};
d_out_resized_dims = {phi::product(preserved_dims), seq_length};
}
d_x->Resize(d_x_resized_dims);
d_out_.Resize(d_out_resized_dims);
}
Tensor trans_d_x(d_x->type());
Tensor trans_d_out(d_out_.type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (d_out_rank == 1U) {
trans_d_out = d_out_;
std::vector<int> perm_d_x{1, 0};
auto d_x_dims_vec = phi::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(phi::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_d_x.size(), dev_ctx, *d_x, &trans_d_x, perm_d_x);
} else {
std::vector<int> perm_d_out{1, 0};
auto d_out_dims_vec = phi::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(phi::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_d_out.size(), dev_ctx, d_out_, &trans_d_out, perm_d_out);
std::vector<int> perm_d_x{2, 1, 0};
auto d_x_dims_vec = phi::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(phi::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_d_x.size(), dev_ctx, *d_x, &trans_d_x, perm_d_x);
}
} else {
trans_d_x = *d_x;
trans_d_out = d_out_;
}
OverlapAddFunctor<DeviceContext, T>()(dev_ctx,
&trans_d_out,
&trans_d_x,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0) {
if (d_out_rank == 1U) {
std::vector<int> perm_d_x{1, 0};
TransCompute<DeviceContext, T>(
perm_d_x.size(), dev_ctx, trans_d_x, d_x, perm_d_x);
} else {
std::vector<int> perm_d_x{2, 1, 0};
TransCompute<DeviceContext, T>(
perm_d_x.size(), dev_ctx, trans_d_x, d_x, perm_d_x);
}
}
// Restore output dims when the number of dims is larger than 2.
if (d_out_rank > 2) {
std::vector<int64_t> restored_d_x_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_d_x_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_d_x_shape.insert(restored_d_x_shape.begin(), frame_length);
restored_d_x_shape.insert(restored_d_x_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_d_x_shape.push_back(frame_length);
restored_d_x_shape.push_back(n_frames);
}
d_x->Resize(phi::make_ddim(restored_d_x_shape));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -2648,3 +2648,13 @@
kernel:
func: eig
backward: eig_grad
# overlap_add
- api: overlap_add
args: (Tensor x, int hop_length, int axis)
output: Tensor
infer_meta:
func: OverlapAddInferMeta
kernel:
func: overlap_add
backward: overlap_add_grad
......@@ -914,10 +914,10 @@
forward : grid_sample (Tensor x, Tensor grid, str mode, str padding_mode, bool align_corners) -> Tensor(out)
args : (Tensor x, Tensor grid, Tensor out_grad, str mode, str padding_mode, bool align_corners)
output : Tensor(x_grad), Tensor(grid_grad)
infer_meta :
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, grid]
kernel :
kernel :
func : grid_sample_grad
data_type : x
......@@ -1552,6 +1552,16 @@
kernel :
func : norm_grad
- backward_api : overlap_add_grad
forward : overlap_add(Tensor x, int hop_length, int axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int hop_length, int axis)
output : Tensor(x_grad)
infer_meta :
func : OverlapAddGradInferMeta
kernel :
func : overlap_add_grad
data_type : x
- backward_api : p_norm_grad
forward : p_norm(Tensor x, float porder, int axis, float epsilon, bool keepdim, bool asvector=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, float porder, int axis, float epsilon, bool keepdim, bool asvector)
......
......@@ -609,6 +609,18 @@ void NllLossGradInferMeta(const MetaTensor& x,
}
}
void OverlapAddGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
int hop_length,
int axis,
MetaTensor* x_grad) {
const auto x_dims = x.dims();
if (x_grad != nullptr) {
x_grad->set_dims(x_dims);
x_grad->set_dtype(x.dtype());
}
}
void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad,
int downscale_factor,
const std::string& data_format,
......
......@@ -262,6 +262,12 @@ void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad,
const std::string& data_format,
MetaTensor* x_grad);
void OverlapAddGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
int hop_length,
int axis,
MetaTensor* x_grad);
void PsroiPoolGradInferMeta(const MetaTensor& x,
const MetaTensor& rois,
const MetaTensor& rois_num,
......
......@@ -1725,6 +1725,90 @@ void NormInferMeta(const MetaTensor& x,
}
}
void OverlapAddInferMeta(const MetaTensor& x,
int hop_length,
int axis,
MetaTensor* out,
MetaConfig config) {
const auto x_dims = x.dims();
const int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_rank,
2,
errors::InvalidArgument(
"Input(X) of OverlapAddOp should be a tensor which contains "
"at least 2 dimensions, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(
hop_length,
0,
errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1),
true,
errors::InvalidArgument(
"Attribute(axis) of OverlapAddOp should 0 or -1, but got %s.", axis));
std::vector<int64_t> output_shape;
int n_frames;
int frame_length;
int seq_length;
int start_axis;
int end_axis;
if (axis == 0) {
n_frames = x_dims[0];
frame_length = x_dims[1];
start_axis = 2;
end_axis = x_rank - 1;
} else {
n_frames = x_dims[x_rank - 1];
frame_length = x_dims[x_rank - 2];
start_axis = 0;
end_axis = x_rank - 3;
}
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = config.is_runtime || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(
hop_length,
frame_length,
errors::InvalidArgument(
"Attribute(hop_length) of OverlapAddOp should be less or equal "
"than frame_length, but got hop_length(%s) > frame_length(%s).",
hop_length,
frame_length));
}
if (n_frames == -1) {
seq_length = -1;
} else {
seq_length = (n_frames - 1) * hop_length + frame_length;
}
// It won't go into for loop when x_rank == 2U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
output_shape.insert(output_shape.begin(), seq_length);
} else {
// (..., seq_length)
output_shape.push_back(seq_length);
}
out->set_dims(phi::make_ddim(output_shape));
}
void PadInferMeta(const MetaTensor& input,
const std::vector<int>& paddings,
float pad_value,
......
......@@ -235,6 +235,12 @@ void NormInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* norm);
void OverlapAddInferMeta(const MetaTensor& x,
int hop_length,
int axis,
MetaTensor* out,
MetaConfig config = MetaConfig());
void PadInferMeta(const MetaTensor& input,
const std::vector<int>& paddings,
float pad_value,
......@@ -542,4 +548,5 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
MetaTensor* out);
void IdentityLossInferMeta(const MetaTensor& x, int reduction, MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/overlap_add_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/overlap_add_functor.h"
namespace phi {
template <typename T, typename Context>
void OverlapAddGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int hop_length,
int axis,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
const size_t out_grad_rank = out_grad.dims().size();
const size_t x_grad_rank = x_grad->dims().size();
const int n_frames =
(axis == 0) ? x_grad->dims()[0] : x_grad->dims()[x_grad_rank - 1];
const int frame_length =
(axis == 0) ? x_grad->dims()[1] : x_grad->dims()[x_grad_rank - 2];
const int seq_length =
(axis == 0) ? out_grad.dims()[0] : out_grad.dims()[out_grad_rank - 1];
// When the number of input dims is larger than 2, it needs to copy
// from x to resize input into 2d and output into 3d. Morevoer, output
// dims will be restored at the last step.
DenseTensor out_grad_(out_grad.type());
out_grad_ = out_grad;
phi::DDim preserved_dims;
if (out_grad_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
phi::DDim x_grad_resized_dims;
phi::DDim out_grad_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(out_grad_.dims(), 1, out_grad_rank);
x_grad_resized_dims = {
n_frames, frame_length, phi::product(preserved_dims)};
out_grad_resized_dims = {seq_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(out_grad_.dims(), 0, out_grad_rank - 1);
x_grad_resized_dims = {
phi::product(preserved_dims), frame_length, n_frames};
out_grad_resized_dims = {phi::product(preserved_dims), seq_length};
}
x_grad->Resize(x_grad_resized_dims);
out_grad_.Resize(out_grad_resized_dims);
}
DenseTensor trans_x_grad(x_grad->type());
DenseTensor trans_out_grad(out_grad_.type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (out_grad_rank == 1U) {
trans_out_grad = out_grad_;
std::vector<int> perm_x_grad{1, 0};
auto x_grad_dims_vec = phi::vectorize(x_grad->dims());
for (int i = 0; i < x_grad->dims().size(); ++i) {
x_grad_dims_vec[i] = x_grad->dims()[perm_x_grad[i]];
}
trans_x_grad.Resize(phi::make_ddim(x_grad_dims_vec));
dev_ctx.template Alloc<T>(&trans_x_grad);
phi::funcs::TransCompute<Context, T>(
perm_x_grad.size(), dev_ctx, *x_grad, &trans_x_grad, perm_x_grad);
} else {
std::vector<int> perm_d_out{1, 0};
auto out_grad_dims_vec = phi::vectorize(out_grad_.dims());
for (int i = 0; i < out_grad_.dims().size(); ++i) {
out_grad_dims_vec[i] = out_grad_.dims()[perm_d_out[i]];
}
trans_out_grad.Resize(phi::make_ddim(out_grad_dims_vec));
dev_ctx.template Alloc<T>(&trans_out_grad);
phi::funcs::TransCompute<Context, T>(
perm_d_out.size(), dev_ctx, out_grad_, &trans_out_grad, perm_d_out);
std::vector<int> perm_x_grad{2, 1, 0};
auto x_grad_dims_vec = phi::vectorize(x_grad->dims());
for (int i = 0; i < x_grad->dims().size(); ++i) {
x_grad_dims_vec[i] = x_grad->dims()[perm_x_grad[i]];
}
trans_x_grad.Resize(phi::make_ddim(x_grad_dims_vec));
dev_ctx.template Alloc<T>(&trans_x_grad);
phi::funcs::TransCompute<Context, T>(
perm_x_grad.size(), dev_ctx, *x_grad, &trans_x_grad, perm_x_grad);
}
} else {
trans_x_grad = *x_grad;
trans_out_grad = out_grad_;
}
OverlapAddFunctor<Context, T>()(dev_ctx,
&trans_out_grad,
&trans_x_grad,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0) {
if (out_grad_rank == 1U) {
std::vector<int> perm_x_grad{1, 0};
phi::funcs::TransCompute<Context, T>(
perm_x_grad.size(), dev_ctx, trans_x_grad, x_grad, perm_x_grad);
} else {
std::vector<int> perm_x_grad{2, 1, 0};
phi::funcs::TransCompute<Context, T>(
perm_x_grad.size(), dev_ctx, trans_x_grad, x_grad, perm_x_grad);
}
}
// Restore output dims when the number of dims is larger than 2.
if (out_grad_rank > 2) {
std::vector<int64_t> restored_x_grad_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_x_grad_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_x_grad_shape.insert(restored_x_grad_shape.begin(), frame_length);
restored_x_grad_shape.insert(restored_x_grad_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_x_grad_shape.push_back(frame_length);
restored_x_grad_shape.push_back(n_frames);
}
x_grad->Resize(phi::make_ddim(restored_x_grad_shape));
}
}
} // namespace phi
PD_REGISTER_KERNEL(overlap_add_grad,
CPU,
ALL_LAYOUT,
phi::OverlapAddGradKernel,
int,
int64_t,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/overlap_add_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/overlap_add_functor.h"
namespace phi {
template <typename T, typename Context>
void OverlapAddKernel(const Context& dev_ctx,
const DenseTensor& x,
int hop_length,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
const size_t x_rank = x.dims().size();
const size_t out_rank = out->dims().size();
const int n_frames = (axis == 0) ? x.dims()[0] : x.dims()[x_rank - 1];
const int frame_length = (axis == 0) ? x.dims()[1] : x.dims()[x_rank - 2];
const int seq_length =
(axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
// auto& dev_ctx = ctx.device_context<Context>();
DenseTensor x_(x.type());
x_ = x;
phi::DDim preserved_dims;
if (out_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
phi::DDim x_resized_dims;
phi::DDim out_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(out->dims(), 1, out_rank);
x_resized_dims = {n_frames, frame_length, phi::product(preserved_dims)};
out_resized_dims = {seq_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(out->dims(), 0, out_rank - 1);
x_resized_dims = {phi::product(preserved_dims), frame_length, n_frames};
out_resized_dims = {phi::product(preserved_dims), seq_length};
}
x_.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
DenseTensor trans_x(x_.type());
DenseTensor trans_out(out->type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (out_rank == 1U) {
trans_out = *out;
std::vector<int> perm_x{1, 0};
auto x_dims_vec = phi::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(phi::make_ddim(x_dims_vec));
dev_ctx.template Alloc<T>(&trans_x);
phi::funcs::TransCompute<Context, T>(
perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
} else {
std::vector<int> perm_out{1, 0};
auto out_dims_vec = phi::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(phi::make_ddim(out_dims_vec));
dev_ctx.template Alloc<T>(&trans_out);
phi::funcs::TransCompute<Context, T>(
perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
std::vector<int> perm_x{2, 1, 0};
auto x_dims_vec = phi::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(phi::make_ddim(x_dims_vec));
dev_ctx.template Alloc<T>(&trans_x);
phi::funcs::TransCompute<Context, T>(
perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
}
} else {
trans_x = x_;
trans_out = *out;
}
OverlapAddFunctor<Context, T>()(dev_ctx,
&trans_x,
&trans_out,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0 && out_rank > 1U) {
std::vector<int> perm_out{1, 0};
phi::funcs::TransCompute<Context, T>(
perm_out.size(), dev_ctx, trans_out, out, perm_out);
}
// Restore output dims when the number of dims is larger than 2.
if (out_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_out_shape.push_back(seq_length);
}
out->Resize(phi::make_ddim(restored_out_shape));
}
}
} // namespace phi
PD_REGISTER_KERNEL(overlap_add,
CPU,
ALL_LAYOUT,
phi::OverlapAddKernel,
int,
int64_t,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/seq2col.h"
namespace phi {
template <typename Context, typename T>
struct OverlapAddFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor* input,
DenseTensor* output,
size_t seq_length,
size_t frame_length,
size_t n_frames,
size_t hop_length,
bool is_grad = false) const {
auto numel = output->numel();
const auto* input_data = input->data<T>();
auto* output_data = output->data<T>();
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
if (!is_grad) {
phi::funcs::Col2SeqFunctor<T> functor(input_data,
output_data,
seq_length,
frame_length,
n_frames,
hop_length);
for_range(functor);
} else {
phi::funcs::Seq2ColFunctor<T> functor(input_data,
output_data,
seq_length,
frame_length,
n_frames,
hop_length);
for_range(functor);
}
}
};
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/overlap_add_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/overlap_add_functor.h"
namespace phi {
template <typename T, typename Context>
void OverlapAddGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int hop_length,
int axis,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
const size_t out_grad_rank = out_grad.dims().size();
const size_t x_grad_rank = x_grad->dims().size();
const int n_frames =
(axis == 0) ? x_grad->dims()[0] : x_grad->dims()[x_grad_rank - 1];
const int frame_length =
(axis == 0) ? x_grad->dims()[1] : x_grad->dims()[x_grad_rank - 2];
const int seq_length =
(axis == 0) ? out_grad.dims()[0] : out_grad.dims()[out_grad_rank - 1];
// When the number of input dims is larger than 2, it needs to copy
// from x to resize input into 2d and output into 3d. Morevoer, output
// dims will be restored at the last step.
DenseTensor out_grad_(out_grad.type());
out_grad_ = out_grad;
phi::DDim preserved_dims;
if (out_grad_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
phi::DDim x_grad_resized_dims;
phi::DDim out_grad_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(out_grad_.dims(), 1, out_grad_rank);
x_grad_resized_dims = {
n_frames, frame_length, phi::product(preserved_dims)};
out_grad_resized_dims = {seq_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(out_grad_.dims(), 0, out_grad_rank - 1);
x_grad_resized_dims = {
phi::product(preserved_dims), frame_length, n_frames};
out_grad_resized_dims = {phi::product(preserved_dims), seq_length};
}
x_grad->Resize(x_grad_resized_dims);
out_grad_.Resize(out_grad_resized_dims);
}
DenseTensor trans_x_grad(x_grad->type());
DenseTensor trans_out_grad(out_grad_.type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (out_grad_rank == 1U) {
trans_out_grad = out_grad_;
std::vector<int> perm_x_grad{1, 0};
auto x_grad_dims_vec = phi::vectorize(x_grad->dims());
for (int i = 0; i < x_grad->dims().size(); ++i) {
x_grad_dims_vec[i] = x_grad->dims()[perm_x_grad[i]];
}
trans_x_grad.Resize(phi::make_ddim(x_grad_dims_vec));
dev_ctx.template Alloc<T>(&trans_x_grad);
phi::funcs::TransCompute<Context, T>(
perm_x_grad.size(), dev_ctx, *x_grad, &trans_x_grad, perm_x_grad);
} else {
std::vector<int> perm_d_out{1, 0};
auto out_grad_dims_vec = phi::vectorize(out_grad_.dims());
for (int i = 0; i < out_grad_.dims().size(); ++i) {
out_grad_dims_vec[i] = out_grad_.dims()[perm_d_out[i]];
}
trans_out_grad.Resize(phi::make_ddim(out_grad_dims_vec));
dev_ctx.template Alloc<T>(&trans_out_grad);
phi::funcs::TransCompute<Context, T>(
perm_d_out.size(), dev_ctx, out_grad_, &trans_out_grad, perm_d_out);
std::vector<int> perm_x_grad{2, 1, 0};
auto x_grad_dims_vec = phi::vectorize(x_grad->dims());
for (int i = 0; i < x_grad->dims().size(); ++i) {
x_grad_dims_vec[i] = x_grad->dims()[perm_x_grad[i]];
}
trans_x_grad.Resize(phi::make_ddim(x_grad_dims_vec));
dev_ctx.template Alloc<T>(&trans_x_grad);
phi::funcs::TransCompute<Context, T>(
perm_x_grad.size(), dev_ctx, *x_grad, &trans_x_grad, perm_x_grad);
}
} else {
trans_x_grad = *x_grad;
trans_out_grad = out_grad_;
}
OverlapAddFunctor<Context, T>()(dev_ctx,
&trans_out_grad,
&trans_x_grad,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0) {
if (out_grad_rank == 1U) {
std::vector<int> perm_x_grad{1, 0};
phi::funcs::TransCompute<Context, T>(
perm_x_grad.size(), dev_ctx, trans_x_grad, x_grad, perm_x_grad);
} else {
std::vector<int> perm_x_grad{2, 1, 0};
phi::funcs::TransCompute<Context, T>(
perm_x_grad.size(), dev_ctx, trans_x_grad, x_grad, perm_x_grad);
}
}
// Restore output dims when the number of dims is larger than 2.
if (out_grad_rank > 2) {
std::vector<int64_t> restored_x_grad_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_x_grad_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_x_grad_shape.insert(restored_x_grad_shape.begin(), frame_length);
restored_x_grad_shape.insert(restored_x_grad_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_x_grad_shape.push_back(frame_length);
restored_x_grad_shape.push_back(n_frames);
}
x_grad->Resize(phi::make_ddim(restored_x_grad_shape));
}
}
} // namespace phi
PD_REGISTER_KERNEL(overlap_add_grad,
GPU,
ALL_LAYOUT,
phi::OverlapAddGradKernel,
int,
int64_t,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/overlap_add_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/overlap_add_functor.h"
namespace phi {
template <typename T, typename Context>
void OverlapAddKernel(const Context& dev_ctx,
const DenseTensor& x,
int hop_length,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
const size_t x_rank = x.dims().size();
const size_t out_rank = out->dims().size();
const int n_frames = (axis == 0) ? x.dims()[0] : x.dims()[x_rank - 1];
const int frame_length = (axis == 0) ? x.dims()[1] : x.dims()[x_rank - 2];
const int seq_length =
(axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
// auto& dev_ctx = ctx.device_context<Context>();
DenseTensor x_(x.type());
x_ = x;
phi::DDim preserved_dims;
if (out_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
phi::DDim x_resized_dims;
phi::DDim out_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(out->dims(), 1, out_rank);
x_resized_dims = {n_frames, frame_length, phi::product(preserved_dims)};
out_resized_dims = {seq_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(out->dims(), 0, out_rank - 1);
x_resized_dims = {phi::product(preserved_dims), frame_length, n_frames};
out_resized_dims = {phi::product(preserved_dims), seq_length};
}
x_.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
DenseTensor trans_x(x_.type());
DenseTensor trans_out(out->type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (out_rank == 1U) {
trans_out = *out;
std::vector<int> perm_x{1, 0};
auto x_dims_vec = phi::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(phi::make_ddim(x_dims_vec));
dev_ctx.template Alloc<T>(&trans_x);
phi::funcs::TransCompute<Context, T>(
perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
} else {
std::vector<int> perm_out{1, 0};
auto out_dims_vec = phi::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(phi::make_ddim(out_dims_vec));
dev_ctx.template Alloc<T>(&trans_out);
phi::funcs::TransCompute<Context, T>(
perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
std::vector<int> perm_x{2, 1, 0};
auto x_dims_vec = phi::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(phi::make_ddim(x_dims_vec));
dev_ctx.template Alloc<T>(&trans_x);
phi::funcs::TransCompute<Context, T>(
perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
}
} else {
trans_x = x_;
trans_out = *out;
}
OverlapAddFunctor<Context, T>()(dev_ctx,
&trans_x,
&trans_out,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0 && out_rank > 1U) {
std::vector<int> perm_out{1, 0};
phi::funcs::TransCompute<Context, T>(
perm_out.size(), dev_ctx, trans_out, out, perm_out);
}
// Restore output dims when the number of dims is larger than 2.
if (out_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_out_shape.push_back(seq_length);
}
out->Resize(phi::make_ddim(restored_out_shape));
}
}
} // namespace phi
PD_REGISTER_KERNEL(overlap_add,
GPU,
ALL_LAYOUT,
phi::OverlapAddKernel,
int,
int64_t,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void OverlapAddGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
int hop_length,
int axis,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void OverlapAddKernel(const Context& dev_ctx,
const DenseTensor& x,
int hop_length,
int axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature OverlapAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("overlap_add_grad",
{"X", "Out@GRAD"},
{"hop_length", "axis"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(overlap_add_grad,
phi::OverlapAddGradOpArgumentMapping);
......@@ -73,6 +73,7 @@ class TestOverlapAddOp(OpTest):
def setUp(self):
self.op_type = "overlap_add"
self.python_api = paddle.signal.overlap_add
self.shape, self.type, self.attrs = self.initTestCase()
self.inputs = {
'X': np.random.random(size=self.shape).astype(self.type),
......@@ -90,12 +91,12 @@ class TestOverlapAddOp(OpTest):
def test_check_output(self):
paddle.enable_static()
self.check_output()
self.check_output(check_eager=True)
paddle.disable_static()
def test_check_grad_normal(self):
paddle.enable_static()
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)
paddle.disable_static()
......
......@@ -217,7 +217,9 @@ def overlap_add(x, hop_length, axis=-1, name=None):
op_type = 'overlap_add'
if _non_static_mode():
if in_dygraph_mode():
out = _C_ops.final_state_overlap_add(x, hop_length, axis)
elif paddle.in_dynamic_mode():
attrs = ('hop_length', hop_length, 'axis', axis)
op = getattr(_C_ops, op_type)
out = op(x, *attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册