未验证 提交 28b4b2f7 编写于 作者: C Charles-hit 提交者: GitHub

Move frame kernel to phi (#44615)

* Move frame OP to phi、add frame OP yaml config and supplement single test

* add Header file of in_dygraph_mode

* Modify variable name and FrameGradInferMeta multiplex UnchangedInferMeta

* move seq2col to phi
上级 511a2c1c
...@@ -12,7 +12,13 @@ ...@@ -12,7 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/frame_op.h" #include "paddle/phi/core/enforce.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,89 +27,6 @@ class FrameOp : public framework::OperatorWithKernel { ...@@ -21,89 +27,6 @@ class FrameOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "frame");
const int frame_length = ctx->Attrs().Get<int>("frame_length");
const int hop_length = ctx->Attrs().Get<int>("hop_length");
const int axis = ctx->Attrs().Get<int>("axis");
const auto x_dims = ctx->GetInputDim("X");
const int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(
x_rank,
1,
platform::errors::InvalidArgument(
"Input(X) of FrameOp should be a tensor which contains "
"at least 1 dimension, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(hop_length,
0,
platform::errors::InvalidArgument(
"Attribute(hop_length) of FrameOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1),
true,
platform::errors::InvalidArgument(
"Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis));
std::vector<int64_t> output_shape;
int seq_length;
int n_frames;
int start_axis;
int end_axis;
if (axis == 0) {
seq_length = x_dims[0];
start_axis = 1;
end_axis = x_rank - 1;
} else {
seq_length = x_dims[x_rank - 1];
start_axis = 0;
end_axis = x_rank - 2;
}
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(frame_length,
seq_length,
platform::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length,
seq_length));
}
// It won't go into for loop when x_rank == 1U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}
if (seq_length == -1) {
n_frames = -1;
} else {
n_frames = 1 + (seq_length - frame_length) / hop_length;
}
if (axis == 0) {
// (n_frames, frame_length, ...)
output_shape.insert(output_shape.begin(), frame_length);
output_shape.insert(output_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
output_shape.push_back(frame_length);
output_shape.push_back(n_frames);
}
ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -136,17 +59,6 @@ class FrameOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -136,17 +59,6 @@ class FrameOpMaker : public framework::OpProtoAndCheckerMaker {
class FrameOpGrad : public framework::OperatorWithKernel { class FrameOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "frame_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
"Out@GRAD",
"frame_grad");
const auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
...@@ -160,7 +72,6 @@ template <typename T> ...@@ -160,7 +72,6 @@ template <typename T>
class FrameOpGradMaker : public framework::SingleGradOpMaker<T> { class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> retv) const override { void Apply(GradOpPtr<T> retv) const override {
retv->SetType("frame_grad"); retv->SetType("frame_grad");
retv->SetInput("X", this->Input("X")); retv->SetInput("X", this->Input("X"));
...@@ -175,28 +86,19 @@ class FrameOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -175,28 +86,19 @@ class FrameOpGradMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(frame,
FrameInferShapeFunctor,
PD_INFER_META(phi::FrameInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(frame_grad,
FrameGradInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(frame, REGISTER_OPERATOR(frame,
ops::FrameOp, ops::FrameOp,
ops::FrameOpMaker, ops::FrameOpMaker,
ops::FrameOpGradMaker<paddle::framework::OpDesc>, ops::FrameOpGradMaker<paddle::framework::OpDesc>,
ops::FrameOpGradMaker<paddle::imperative::OpBase>); ops::FrameOpGradMaker<paddle::imperative::OpBase>,
FrameInferShapeFunctor);
REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad);
REGISTER_OP_CPU_KERNEL(
frame,
ops::FrameKernel<phi::CPUContext, int>,
ops::FrameKernel<phi::CPUContext, int64_t>,
ops::FrameKernel<phi::CPUContext, float>,
ops::FrameKernel<phi::CPUContext, double>,
ops::FrameKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::FrameKernel<phi::CPUContext, paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL( REGISTER_OPERATOR(frame_grad, ops::FrameOpGrad, FrameGradInferShapeFunctor);
frame_grad,
ops::FrameGradKernel<phi::CPUContext, int>,
ops::FrameGradKernel<phi::CPUContext, int64_t>,
ops::FrameGradKernel<phi::CPUContext, float>,
ops::FrameGradKernel<phi::CPUContext, double>,
ops::FrameGradKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::FrameGradKernel<phi::CPUContext, paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/frame_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
frame,
ops::FrameKernel<paddle::platform::CUDADeviceContext, int>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, float>,
ops::FrameKernel<paddle::platform::CUDADeviceContext, double>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FrameKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
frame_grad,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FrameGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/seq2col.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
struct FrameFunctor {
void operator()(const DeviceContext& dev_ctx,
const Tensor* input,
Tensor* output,
size_t seq_length,
size_t frame_length,
size_t n_frames,
size_t hop_length,
bool is_grad = false) const {
auto numel = output->numel();
const auto* input_data = input->data<T>();
auto* output_data = output->data<T>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
if (!is_grad) {
math::Seq2ColFunctor<T> functor(input_data,
output_data,
seq_length,
frame_length,
n_frames,
hop_length);
for_range(functor);
} else {
math::Col2SeqFunctor<T> functor(input_data,
output_data,
seq_length,
frame_length,
n_frames,
hop_length);
for_range(functor);
}
}
};
template <typename DeviceContext, typename T>
class FrameKernel : public framework::OpKernel<T> {
public:
/*
Frame kernel slices frames from input sequences. The main steps as follows:
- Case 1 - input dims == 1:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose output firstly, and then it falls into
case axis is -1. Finally, it restores the dims of
output tensor.
- Case 2 - input dims == 2:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose both input and output firstly, and then it falls
into case axis is -1. Finally, it restores the dims of
output tensor.
- Case 3 - input dims > 2:
Flatten the input and output to 2D and 3D respectively so that it
falls into Case 2. Finally, it restores the dims of output tensor.
*/
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
const size_t x_rank = x->dims().size();
const size_t out_rank = out->dims().size();
const int frame_length = ctx.Attr<int>("frame_length");
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
const int seq_length = (axis == 0) ? x->dims()[0] : x->dims()[x_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
// When the number of input dims is larger than 2, it needs to copy
// from x to resize input into 2d and output into 3d. Morevoer, output
// dims will be restored at the last step.
Tensor x_(x->type());
x_ = *x;
framework::DDim preserved_dims;
if (x_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim x_resized_dims;
framework::DDim out_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(x_.dims(), 1, x_rank);
x_resized_dims = {seq_length, phi::product(preserved_dims)};
out_resized_dims = {
n_frames, frame_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(x_.dims(), 0, x_rank - 1);
x_resized_dims = {phi::product(preserved_dims), seq_length};
out_resized_dims = {
phi::product(preserved_dims), frame_length, n_frames};
}
x_.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
Tensor trans_x(x_.type());
Tensor trans_out(out->type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (x_rank == 1U) {
trans_x = x_;
std::vector<int> perm_out{1, 0};
auto out_dims_vec = phi::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(phi::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
} else {
std::vector<int> perm_x{1, 0};
auto x_dims_vec = phi::vectorize(x_.dims());
for (int i = 0; i < x_.dims().size(); ++i) {
x_dims_vec[i] = x_.dims()[perm_x[i]];
}
trans_x.Resize(phi::make_ddim(x_dims_vec));
trans_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_x.size(), dev_ctx, x_, &trans_x, perm_x);
std::vector<int> perm_out{2, 1, 0};
auto out_dims_vec = phi::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(phi::make_ddim(out_dims_vec));
trans_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
}
} else {
trans_x = x_;
trans_out = *out;
}
FrameFunctor<DeviceContext, T>()(dev_ctx,
&trans_x,
&trans_out,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0) {
if (x_rank == 1U) {
std::vector<int> perm_out{1, 0};
TransCompute<DeviceContext, T>(
perm_out.size(), dev_ctx, trans_out, out, perm_out);
} else {
std::vector<int> perm_out{2, 1, 0};
TransCompute<DeviceContext, T>(
perm_out.size(), dev_ctx, trans_out, out, perm_out);
}
}
// Restore output dims when the number of dims is larger than 2.
if (x_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), frame_length);
restored_out_shape.insert(restored_out_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_out_shape.push_back(frame_length);
restored_out_shape.push_back(n_frames);
}
out->Resize(phi::make_ddim(restored_out_shape));
}
}
};
template <typename DeviceContext, typename T>
class FrameGradKernel : public framework::OpKernel<T> {
public:
/*
Frame gradient kernel accumulate gradient `d_x` from `d_out`. The
main steps as follows:
- Case 1 - d_x dims == 1:
- axis is -1: Call a FrameFunctor to compute directly. Notes that
`is_grad` is set to true to select gradient data functor.
- axis is 0: Transpose `d_out` firstly, and then it falls into
case axis is -1.
- Case 2 - d_x dims == 2:
- axis is -1: Call a FrameFunctor to compute directly.
- axis is 0: Transpose both `d_x` and `d_out` firstly, and then it
falls into case axis is -1. Finally, it restores the
dims of `d_x`.
- Case 3 - d_x dims > 2:
Flatten the `d_x` and `d_out` to 2D and 3D respectively so that it
falls into Case 2. Finally, it restores the dims of `d_x` tensor.
*/
void Compute(const framework::ExecutionContext& ctx) const {
const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
const size_t d_out_rank = d_out->dims().size();
const size_t d_x_rank = d_x->dims().size();
const int frame_length = ctx.Attr<int>("frame_length");
const int hop_length = ctx.Attr<int>("hop_length");
const int axis = ctx.Attr<int>("axis");
const int n_frames =
(axis == 0) ? d_out->dims()[0] : d_out->dims()[d_out_rank - 1];
const int seq_length =
(axis == 0) ? d_x->dims()[0] : d_x->dims()[d_x_rank - 1];
auto& dev_ctx = ctx.device_context<DeviceContext>();
Tensor d_out_(d_out->type());
d_out_ = *d_out;
framework::DDim preserved_dims;
if (d_x_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
framework::DDim d_x_resized_dims;
framework::DDim d_out_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(d_x->dims(), 1, d_x_rank);
d_x_resized_dims = {seq_length, phi::product(preserved_dims)};
d_out_resized_dims = {
n_frames, frame_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(d_x->dims(), 0, d_x_rank - 1);
d_x_resized_dims = {phi::product(preserved_dims), seq_length};
d_out_resized_dims = {
phi::product(preserved_dims), frame_length, n_frames};
}
d_x->Resize(d_x_resized_dims);
d_out_.Resize(d_out_resized_dims);
}
Tensor trans_d_x(d_x->type());
Tensor trans_d_out(d_out_.type());
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (d_x_rank == 1U) {
trans_d_x = *d_x;
std::vector<int> perm_d_out{1, 0};
auto d_out_dims_vec = phi::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(phi::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_d_out.size(), dev_ctx, d_out_, &trans_d_out, perm_d_out);
} else {
std::vector<int> perm_d_x{1, 0};
auto d_x_dims_vec = phi::vectorize(d_x->dims());
for (int i = 0; i < d_x->dims().size(); ++i) {
d_x_dims_vec[i] = d_x->dims()[perm_d_x[i]];
}
trans_d_x.Resize(phi::make_ddim(d_x_dims_vec));
trans_d_x.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_d_x.size(), dev_ctx, *d_x, &trans_d_x, perm_d_x);
std::vector<int> perm_d_out{2, 1, 0};
auto d_out_dims_vec = phi::vectorize(d_out_.dims());
for (int i = 0; i < d_out_.dims().size(); ++i) {
d_out_dims_vec[i] = d_out_.dims()[perm_d_out[i]];
}
trans_d_out.Resize(phi::make_ddim(d_out_dims_vec));
trans_d_out.mutable_data<T>(ctx.GetPlace());
TransCompute<DeviceContext, T>(
perm_d_out.size(), dev_ctx, d_out_, &trans_d_out, perm_d_out);
}
} else {
trans_d_x = *d_x;
trans_d_out = d_out_;
}
FrameFunctor<DeviceContext, T>()(dev_ctx,
&trans_d_out,
&trans_d_x,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0 && d_x_rank > 1U) {
std::vector<int> perm_d_x{1, 0};
TransCompute<DeviceContext, T>(
perm_d_x.size(), dev_ctx, trans_d_x, d_x, perm_d_x);
}
// Restore output dims when the number of dims is larger than 2.
if (d_x_rank > 2) {
std::vector<int64_t> restored_d_x_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_d_x_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_d_x_shape.insert(restored_d_x_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_d_x_shape.push_back(seq_length);
}
d_x->Resize(phi::make_ddim(restored_d_x_shape));
}
}
};
} // namespace operators
} // namespace paddle
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/seq2col.h"
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/seq2col.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -44,7 +44,7 @@ struct OverlapAddFunctor { ...@@ -44,7 +44,7 @@ struct OverlapAddFunctor {
platform::ForRange<DeviceContext> for_range(dev_ctx, numel); platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
if (!is_grad) { if (!is_grad) {
math::Col2SeqFunctor<T> functor(input_data, phi::funcs::Col2SeqFunctor<T> functor(input_data,
output_data, output_data,
seq_length, seq_length,
frame_length, frame_length,
...@@ -52,7 +52,7 @@ struct OverlapAddFunctor { ...@@ -52,7 +52,7 @@ struct OverlapAddFunctor {
hop_length); hop_length);
for_range(functor); for_range(functor);
} else { } else {
math::Seq2ColFunctor<T> functor(input_data, phi::funcs::Seq2ColFunctor<T> functor(input_data,
output_data, output_data,
seq_length, seq_length,
frame_length, frame_length,
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/frame_op.h"
#include "paddle/fluid/operators/spectral_op.h" #include "paddle/fluid/operators/spectral_op.h"
#include "paddle/phi/kernels/funcs/frame_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -60,7 +60,7 @@ class StftKernel : public framework::OpKernel<T> { ...@@ -60,7 +60,7 @@ class StftKernel : public framework::OpKernel<T> {
framework::DDim frames_dims(out->dims()); framework::DDim frames_dims(out->dims());
frames_dims.at(axes.back()) = n_fft; frames_dims.at(axes.back()) = n_fft;
frames.mutable_data<T>(frames_dims, ctx.GetPlace()); frames.mutable_data<T>(frames_dims, ctx.GetPlace());
FrameFunctor<DeviceContext, T>()(dev_ctx, phi::funcs::FrameFunctor<DeviceContext, T>()(dev_ctx,
x, x,
&frames, &frames,
seq_length, seq_length,
...@@ -175,7 +175,7 @@ class StftGradKernel : public framework::OpKernel<T> { ...@@ -175,7 +175,7 @@ class StftGradKernel : public framework::OpKernel<T> {
ctx, &d_frames_w, window, axes.back(), MulFunctor<T>(), &d_frames); ctx, &d_frames_w, window, axes.back(), MulFunctor<T>(), &d_frames);
// d_frames -> dx // d_frames -> dx
FrameFunctor<DeviceContext, T>()(dev_ctx, phi::funcs::FrameFunctor<DeviceContext, T>()(dev_ctx,
&d_frames, &d_frames,
dx, dx,
seq_length, seq_length,
......
...@@ -850,6 +850,15 @@ ...@@ -850,6 +850,15 @@
func : fmin func : fmin
backward : fmin_grad backward : fmin_grad
- api : frame
args : (Tensor x, int frame_length, int hop_length, int axis)
output : Tensor(out)
infer_meta :
func : FrameInferMeta
kernel :
func : frame
backward : frame_grad
- api : frobenius_norm - api : frobenius_norm
args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all)
output : Tensor(out) output : Tensor(out)
......
...@@ -824,6 +824,16 @@ ...@@ -824,6 +824,16 @@
kernel : kernel :
func : fmin_grad func : fmin_grad
- backward_api : frame_grad
forward : frame(Tensor x, int frame_length, int hop_length, int axis) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int frame_length, int hop_length, int axis)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : frame_grad
- backward_api : frobenius_norm_grad - backward_api : frobenius_norm_grad
forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out) forward : frobenius_norm(Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all) args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis, bool keep_dim, bool reduce_all)
......
...@@ -875,6 +875,90 @@ void FlipInferMeta(const MetaTensor& x, ...@@ -875,6 +875,90 @@ void FlipInferMeta(const MetaTensor& x,
out->share_lod(x); out->share_lod(x);
} }
void FrameInferMeta(const MetaTensor& x,
int frame_length,
int hop_length,
int axis,
MetaTensor* out,
MetaConfig config) {
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"Output(Out) of FrameOp should not be null."));
const auto x_dims = x.dims();
const int x_rank = x_dims.size();
PADDLE_ENFORCE_GE(x_rank,
1,
phi::errors::InvalidArgument(
"Input(X) of FrameOp should be a tensor which contains "
"at least 1 dimension, but got rank %s.",
x_rank));
PADDLE_ENFORCE_GT(hop_length,
0,
phi::errors::InvalidArgument(
"Attribute(hop_length) of FrameOp should be greater "
"than 0, but got %s.",
hop_length));
PADDLE_ENFORCE_EQ(
(axis == 0 || axis == -1),
true,
phi::errors::InvalidArgument(
"Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis));
std::vector<int64_t> output_shape;
int seq_length;
int n_frames;
int start_axis;
int end_axis;
if (axis == 0) {
seq_length = x_dims[0];
start_axis = 1;
end_axis = x_rank - 1;
} else {
seq_length = x_dims[x_rank - 1];
start_axis = 0;
end_axis = x_rank - 2;
}
bool contain_unknown_dim = phi::contain_unknown_dim(x_dims);
bool check = config.is_runtime || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_LE(frame_length,
seq_length,
phi::errors::InvalidArgument(
"Attribute(frame_length) of FrameOp should be less "
"equal than sequence length, but got (%s) > (%s).",
frame_length,
seq_length));
}
// It won't go into for loop when x_rank == 1U.
for (int i = start_axis; i <= end_axis; i++) {
output_shape.push_back(x_dims[i]);
}
if (seq_length == -1) {
n_frames = -1;
} else {
n_frames = 1 + (seq_length - frame_length) / hop_length;
}
if (axis == 0) {
// (n_frames, frame_length, ...)
output_shape.insert(output_shape.begin(), frame_length);
output_shape.insert(output_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
output_shape.push_back(frame_length);
output_shape.push_back(n_frames);
}
out->set_dims(phi::make_ddim(output_shape));
out->set_dtype(x.dtype());
}
void FullBatchSizeLikeInferMeta(const MetaTensor& x, void FullBatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape, const std::vector<int>& shape,
const Scalar& val, const Scalar& val,
......
...@@ -130,6 +130,13 @@ void FlipInferMeta(const MetaTensor& x, ...@@ -130,6 +130,13 @@ void FlipInferMeta(const MetaTensor& x,
const std::vector<int>& axis, const std::vector<int>& axis,
MetaTensor* out); MetaTensor* out);
void FrameInferMeta(const MetaTensor& x,
int frame_length,
int hop_length,
int axis,
MetaTensor* out,
MetaConfig = MetaConfig());
void FullBatchSizeLikeInferMeta(const MetaTensor& x, void FullBatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape, const std::vector<int>& shape,
const Scalar& val, const Scalar& val,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/frame_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/frame_grad_kernel_impl.h"
PD_REGISTER_KERNEL(frame_grad,
CPU,
ALL_LAYOUT,
phi::FrameGradKernel,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/frame_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/frame_kernel_impl.h"
PD_REGISTER_KERNEL(frame,
CPU,
ALL_LAYOUT,
phi::FrameKernel,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FrameGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
int frame_length,
int hop_length,
int axis,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FrameKernel(const Context& dev_ctx,
const DenseTensor& x,
int frame_length,
int hop_length,
int axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/seq2col.h"
namespace phi {
namespace funcs {
template <typename Context, typename T>
struct FrameFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor* input,
DenseTensor* output,
size_t seq_length,
size_t frame_length,
size_t n_frames,
size_t hop_length,
bool is_grad = false) const {
auto numel = output->numel();
const auto* input_data = input->data<T>();
auto* output_data = output->data<T>();
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
if (!is_grad) {
phi::funcs::Seq2ColFunctor<T> functor(input_data,
output_data,
seq_length,
frame_length,
n_frames,
hop_length);
for_range(functor);
} else {
phi::funcs::Col2SeqFunctor<T> functor(input_data,
output_data,
seq_length,
frame_length,
n_frames,
hop_length);
for_range(functor);
}
}
};
} // namespace funcs
} // namespace phi
...@@ -14,9 +14,8 @@ ...@@ -14,9 +14,8 @@
#pragma once #pragma once
namespace paddle { namespace phi {
namespace operators { namespace funcs {
namespace math {
template <typename T> template <typename T>
struct Seq2ColFunctor { struct Seq2ColFunctor {
...@@ -189,6 +188,5 @@ struct Col2SeqFunctor { ...@@ -189,6 +188,5 @@ struct Col2SeqFunctor {
size_t hop_length_; size_t hop_length_;
}; };
} // namespace math } // namespace funcs
} // namespace operators } // namespace phi
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/frame_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/frame_grad_kernel_impl.h"
PD_REGISTER_KERNEL(frame_grad,
GPU,
ALL_LAYOUT,
phi::FrameGradKernel,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/frame_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/frame_kernel_impl.h"
PD_REGISTER_KERNEL(frame,
GPU,
ALL_LAYOUT,
phi::FrameKernel,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/frame_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void FrameGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
int frame_length,
int hop_length,
int axis,
DenseTensor* dx) {
dev_ctx.template Alloc<T>(dx);
const size_t dout_rank = dout.dims().size();
const size_t dx_rank = dx->dims().size();
const int n_frames =
(axis == 0) ? dout.dims()[0] : dout.dims()[dout_rank - 1];
const int seq_length = (axis == 0) ? dx->dims()[0] : dx->dims()[dx_rank - 1];
DenseTensor dout_tmp = dout;
DDim preserved_dims;
if (dx_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
DDim dx_resized_dims;
DDim dout_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(dx->dims(), 1, dx_rank);
dx_resized_dims = {seq_length, phi::product(preserved_dims)};
dout_resized_dims = {
n_frames, frame_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(dx->dims(), 0, dx_rank - 1);
dx_resized_dims = {phi::product(preserved_dims), seq_length};
dout_resized_dims = {
phi::product(preserved_dims), frame_length, n_frames};
}
dx->Resize(dx_resized_dims);
dout_tmp.Resize(dout_resized_dims);
}
DenseTensor trans_dx;
DenseTensor trans_dout;
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (dx_rank == 1U) {
trans_dx = *dx;
std::vector<int> perm_dout{1, 0};
auto dout_dims_vec = phi::vectorize(dout_tmp.dims());
for (int i = 0; i < dout_tmp.dims().size(); ++i) {
dout_dims_vec[i] = dout_tmp.dims()[perm_dout[i]];
}
trans_dout.Resize(phi::make_ddim(dout_dims_vec));
dev_ctx.template Alloc<T>(&trans_dout);
phi::funcs::TransCompute<Context, T>(
perm_dout.size(), dev_ctx, dout_tmp, &trans_dout, perm_dout);
} else {
std::vector<int> perm_dx{1, 0};
auto dx_dims_vec = phi::vectorize(dx->dims());
for (int i = 0; i < dx->dims().size(); ++i) {
dx_dims_vec[i] = dx->dims()[perm_dx[i]];
}
trans_dx.Resize(phi::make_ddim(dx_dims_vec));
dev_ctx.template Alloc<T>(&trans_dx);
phi::funcs::TransCompute<Context, T>(
perm_dx.size(), dev_ctx, *dx, &trans_dx, perm_dx);
std::vector<int> perm_dout{2, 1, 0};
auto dout_dims_vec = phi::vectorize(dout_tmp.dims());
for (int i = 0; i < dout_tmp.dims().size(); ++i) {
dout_dims_vec[i] = dout_tmp.dims()[perm_dout[i]];
}
trans_dout.Resize(phi::make_ddim(dout_dims_vec));
dev_ctx.template Alloc<T>(&trans_dout);
phi::funcs::TransCompute<Context, T>(
perm_dout.size(), dev_ctx, dout_tmp, &trans_dout, perm_dout);
}
} else {
trans_dx = *dx;
trans_dout = dout_tmp;
}
phi::funcs::FrameFunctor<Context, T>()(dev_ctx,
&trans_dout,
&trans_dx,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ true);
// Transpose output in case axis is 0.
if (axis == 0 && dx_rank > 1U) {
std::vector<int> perm_dx{1, 0};
phi::funcs::TransCompute<Context, T>(
perm_dx.size(), dev_ctx, trans_dx, dx, perm_dx);
}
// Restore output dims when the number of dims is larger than 2.
if (dx_rank > 2) {
std::vector<int64_t> restored_dx_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_dx_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (seq_length, ...)
restored_dx_shape.insert(restored_dx_shape.begin(), seq_length);
} else {
// (..., seq_length)
restored_dx_shape.push_back(seq_length);
}
dx->Resize(phi::make_ddim(restored_dx_shape));
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/frame_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void FrameKernel(const Context& dev_ctx,
const DenseTensor& x,
int frame_length,
int hop_length,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
const size_t x_rank = x.dims().size();
const size_t out_rank = out->dims().size();
const int n_frames = (axis == 0) ? out->dims()[0] : out->dims()[out_rank - 1];
const int seq_length = (axis == 0) ? x.dims()[0] : x.dims()[x_rank - 1];
// When the number of input dims is larger than 2, it needs to copy
// from x to resize input into 2d and output into 3d. Morevoer, output
// dims will be restored at the last step.
DenseTensor x_tmp = x;
DDim preserved_dims;
if (x_rank > 2) {
// Save dims used to flatten both input and output tensors and restore
// output tensor.
DDim x_resized_dims;
DDim out_resized_dims;
if (axis == 0) {
preserved_dims = phi::slice_ddim(x_tmp.dims(), 1, x_rank);
x_resized_dims = {seq_length, phi::product(preserved_dims)};
out_resized_dims = {n_frames, frame_length, phi::product(preserved_dims)};
} else {
preserved_dims = phi::slice_ddim(x_tmp.dims(), 0, x_rank - 1);
x_resized_dims = {phi::product(preserved_dims), seq_length};
out_resized_dims = {phi::product(preserved_dims), frame_length, n_frames};
}
x_tmp.Resize(x_resized_dims);
out->Resize(out_resized_dims);
}
DenseTensor trans_x;
DenseTensor trans_out;
// Transpose input and output in case that axis is 0.
if (axis == 0) {
if (x_rank == 1U) {
trans_x = x_tmp;
std::vector<int> perm_out{1, 0};
auto out_dims_vec = phi::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(phi::make_ddim(out_dims_vec));
dev_ctx.template Alloc<T>(&trans_out);
phi::funcs::TransCompute<Context, T>(
perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
} else {
std::vector<int> perm_x{1, 0};
auto x_dims_vec = phi::vectorize(x_tmp.dims());
for (int i = 0; i < x_tmp.dims().size(); ++i) {
x_dims_vec[i] = x_tmp.dims()[perm_x[i]];
}
trans_x.Resize(phi::make_ddim(x_dims_vec));
dev_ctx.template Alloc<T>(&trans_x);
phi::funcs::TransCompute<Context, T>(
perm_x.size(), dev_ctx, x_tmp, &trans_x, perm_x);
std::vector<int> perm_out{2, 1, 0};
auto out_dims_vec = phi::vectorize(out->dims());
for (int i = 0; i < out->dims().size(); ++i) {
out_dims_vec[i] = out->dims()[perm_out[i]];
}
trans_out.Resize(phi::make_ddim(out_dims_vec));
dev_ctx.template Alloc<T>(&trans_out);
phi::funcs::TransCompute<Context, T>(
perm_out.size(), dev_ctx, *out, &trans_out, perm_out);
}
} else {
trans_x = x_tmp;
trans_out = *out;
}
phi::funcs::FrameFunctor<Context, T>()(dev_ctx,
&trans_x,
&trans_out,
seq_length,
frame_length,
n_frames,
hop_length,
/*is_grad*/ false);
// Transpose output in case axis is 0.
if (axis == 0) {
if (x_rank == 1U) {
std::vector<int> perm_out{1, 0};
funcs::TransCompute<Context, T>(
perm_out.size(), dev_ctx, trans_out, out, perm_out);
} else {
std::vector<int> perm_out{2, 1, 0};
funcs::TransCompute<Context, T>(
perm_out.size(), dev_ctx, trans_out, out, perm_out);
}
}
// Restore output dims when the number of dims is larger than 2.
if (x_rank > 2) {
std::vector<int64_t> restored_out_shape;
for (int i = 0; i < preserved_dims.size(); i++) {
restored_out_shape.push_back(preserved_dims[i]);
}
if (axis == 0) {
// (n_frames, frame_length, ...)
restored_out_shape.insert(restored_out_shape.begin(), frame_length);
restored_out_shape.insert(restored_out_shape.begin(), n_frames);
} else {
// (..., frame_length, n_frames)
restored_out_shape.push_back(frame_length);
restored_out_shape.push_back(n_frames);
}
out->Resize(phi::make_ddim(restored_out_shape));
}
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature FrameGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("frame_grad",
{"X", "Out@GRAD"},
{"frame_length", "hop_length", "axis"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(frame_grad, phi::FrameGradOpArgumentMapping);
...@@ -47,6 +47,7 @@ class TestFrameOp(OpTest): ...@@ -47,6 +47,7 @@ class TestFrameOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "frame" self.op_type = "frame"
self.python_api = paddle.signal.frame
self.shape, self.type, self.attrs = self.initTestCase() self.shape, self.type, self.attrs = self.initTestCase()
self.inputs = { self.inputs = {
'X': np.random.random(size=self.shape).astype(self.type), 'X': np.random.random(size=self.shape).astype(self.type),
...@@ -67,12 +68,12 @@ class TestFrameOp(OpTest): ...@@ -67,12 +68,12 @@ class TestFrameOp(OpTest):
def test_check_output(self): def test_check_output(self):
paddle.enable_static() paddle.enable_static()
self.check_output() self.check_output(check_eager=True)
paddle.disable_static() paddle.disable_static()
def test_check_grad_normal(self): def test_check_grad_normal(self):
paddle.enable_static() paddle.enable_static()
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_eager=True)
paddle.disable_static() paddle.disable_static()
......
...@@ -21,7 +21,8 @@ from .fft import fft_r2c, fft_c2r, fft_c2c ...@@ -21,7 +21,8 @@ from .fft import fft_r2c, fft_c2r, fft_c2c
from .fluid.data_feeder import check_variable_and_dtype from .fluid.data_feeder import check_variable_and_dtype
from .fluid.framework import _non_static_mode from .fluid.framework import _non_static_mode
from .fluid.layer_helper import LayerHelper from .fluid.layer_helper import LayerHelper
from . import _C_ops from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
__all__ = [ __all__ = [
'stft', 'stft',
...@@ -127,7 +128,10 @@ def frame(x, frame_length, hop_length, axis=-1, name=None): ...@@ -127,7 +128,10 @@ def frame(x, frame_length, hop_length, axis=-1, name=None):
op_type = 'frame' op_type = 'frame'
if _non_static_mode(): if in_dygraph_mode():
return _C_ops.final_state_frame(x, frame_length, hop_length, axis)
if _in_legacy_dygraph():
attrs = ('frame_length', frame_length, 'hop_length', hop_length, 'axis', attrs = ('frame_length', frame_length, 'hop_length', hop_length, 'axis',
axis) axis)
op = getattr(_C_ops, op_type) op = getattr(_C_ops, op_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册