未验证 提交 9121115b 编写于 作者: X xiongkun 提交者: GitHub

[phi] transfer unsqueeze to phi (#40596)

* transfer unsqueeze to phi

* fix conflict

* add squeeze

* add infershape

* fix xpu and npu error
上级 13c99434
...@@ -19,7 +19,9 @@ limitations under the License. */ ...@@ -19,7 +19,9 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -113,13 +115,13 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -113,13 +115,13 @@ class SqueezeOp : public framework::OperatorWithKernel {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
//#ifdef PADDLE_WITH_MKLDNN // #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
// } // }
//#endif // #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -140,13 +142,13 @@ class SqueezeGradOp : public framework::OperatorWithKernel { ...@@ -140,13 +142,13 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
//#ifdef PADDLE_WITH_MKLDNN // #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
// } // }
//#endif // #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -201,53 +203,18 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -201,53 +203,18 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
class Squeeze2Op : public framework::OperatorWithKernel { class Squeeze2Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Squeeze2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Squeeze2");
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(), 6,
platform::errors::InvalidArgument(
"The dimensions of Input(X) "
"should be in the range of [1, 6] (Eigen limit)."
"But received X's dimensions = %d, X's shape = [%s].",
x_dims.size(), x_dims));
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto out_dims = GetOutputShape(axes, x_dims, false);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", "Out");
}
if (!ctx->HasOutput("XShape")) return;
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
//#ifdef PADDLE_WITH_MKLDNN // #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
// } // }
//#endif // #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -287,13 +254,13 @@ class Squeeze2GradOp : public framework::OperatorWithKernel { ...@@ -287,13 +254,13 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
//#ifdef PADDLE_WITH_MKLDNN // #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
// } // }
//#endif // #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -365,6 +332,10 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X"); ...@@ -365,6 +332,10 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X");
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(squeeze2, SqueezeInferShapeFunctor,
PD_INFER_META(phi::SqueezeInferMeta));
REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker, REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops::SqueezeGradOpMaker<paddle::framework::OpDesc>, ops::SqueezeGradOpMaker<paddle::framework::OpDesc>,
ops::SqueezeGradOpMaker<paddle::imperative::OpBase>); ops::SqueezeGradOpMaker<paddle::imperative::OpBase>);
...@@ -376,7 +347,7 @@ REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ...@@ -376,7 +347,7 @@ REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp,
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker, REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>, ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>,
ops::Squeeze2GradOpMaker<paddle::imperative::OpBase>, ops::Squeeze2GradOpMaker<paddle::imperative::OpBase>,
ops::SqueezeInplaceInferer); ops::SqueezeInplaceInferer, SqueezeInferShapeFunctor);
REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp, REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
ops::Squeeze2DoubleGradOpMaker<paddle::framework::OpDesc>, ops::Squeeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::Squeeze2DoubleGradOpMaker<paddle::imperative::OpBase>, ops::Squeeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
...@@ -411,34 +382,3 @@ REGISTER_OP_CPU_KERNEL( ...@@ -411,34 +382,3 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<double>>, paddle::platform::complex<double>>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>); paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, bool>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
...@@ -46,33 +46,3 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -46,33 +46,3 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, bool>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
...@@ -18,7 +18,9 @@ limitations under the License. */ ...@@ -18,7 +18,9 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -251,19 +253,6 @@ class UnsqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -251,19 +253,6 @@ class UnsqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
class Unsqueeze2Op : public UnsqueezeOp { class Unsqueeze2Op : public UnsqueezeOp {
public: public:
using UnsqueezeOp::UnsqueezeOp; using UnsqueezeOp::UnsqueezeOp;
void InferShape(framework::InferShapeContext *ctx) const override {
UnsqueezeOp::InferShape(ctx);
const auto &x_dims = ctx->GetInputDim("X");
if (!ctx->HasOutput("XShape")) return;
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
ctx->SetOutputDim("XShape", phi::make_ddim(xshape_dims));
ctx->ShareLoD("X", /*->*/ "XShape");
}
}; };
class Unsqueeze2OpMaker : public UnsqueezeOpMaker { class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
...@@ -339,10 +328,14 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X"); ...@@ -339,10 +328,14 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2, Unsqueeze2InferShapeFunctor,
PD_INFER_META(phi::UnsqueezeInferMeta));
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>, ops::UnsqueezeGradOpMaker<paddle::framework::OpDesc>,
ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>); ops::UnsqueezeGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
ops::UnsqueezeDoubleGradOpMaker<paddle::framework::OpDesc>, ops::UnsqueezeDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::UnsqueezeDoubleGradOpMaker<paddle::imperative::OpBase>, ops::UnsqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
...@@ -351,7 +344,8 @@ REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, ...@@ -351,7 +344,8 @@ REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker, REGISTER_OPERATOR(unsqueeze2, ops::Unsqueeze2Op, ops::Unsqueeze2OpMaker,
ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>, ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
ops::Unsqueeze2GradOpMaker<paddle::imperative::OpBase>, ops::Unsqueeze2GradOpMaker<paddle::imperative::OpBase>,
ops::UnsqueezeInplaceInferer); Unsqueeze2InferShapeFunctor, ops::UnsqueezeInplaceInferer);
REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp, REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
ops::Unsqueeze2DoubleGradOpMaker<paddle::framework::OpDesc>, ops::Unsqueeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::Unsqueeze2DoubleGradOpMaker<paddle::imperative::OpBase>, ops::Unsqueeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
...@@ -388,34 +382,3 @@ REGISTER_OP_CPU_KERNEL( ...@@ -388,34 +382,3 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<double>>, paddle::platform::complex<double>>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>); paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
...@@ -50,37 +50,3 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -50,37 +50,3 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
unsqueeze2,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
...@@ -42,6 +42,10 @@ const std::unordered_set<std::string> deprecated_op_names({"diag", ...@@ -42,6 +42,10 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"flatten_grad", "flatten_grad",
"isinf", "isinf",
"isnan", "isnan",
"unsqueeze",
"unsqueeze_grad",
"squeeze",
"squeeze_grad",
"isfinite", "isfinite",
"matmul", "matmul",
"matmul_grad", "matmul_grad",
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h" #include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
namespace phi { namespace phi {
...@@ -1497,6 +1498,40 @@ void SplitInferMeta(const MetaTensor& x, ...@@ -1497,6 +1498,40 @@ void SplitInferMeta(const MetaTensor& x,
} }
} }
void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* xshape,
MetaTensor* out) {
const auto& x_dims = x.dims();
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(),
6,
phi::errors::InvalidArgument(
"The dimensions of Input(X) "
"should be in the range of [1, 6] (Eigen limit)."
"But received X's dimensions = %d, X's shape = [%s].",
x_dims.size(),
x_dims));
auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, false);
out->set_dims(out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
out->share_lod(x);
}
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
xshape->set_dtype(x.dtype());
out->set_dtype(x.dtype());
}
/* Why not use SumRawInferMeta directly? /* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of api.yaml Because we need make InferMetaFunction's args follow the design of api.yaml
*/ */
...@@ -1982,6 +2017,41 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -1982,6 +2017,41 @@ void UnfoldInferMeta(const MetaTensor& x,
out->set_dims(phi::make_ddim(out_dims)); out->set_dims(phi::make_ddim(out_dims));
} }
void UnsqueezeInferMeta(const MetaTensor& x,
const ScalarArray& axes,
MetaTensor* xshape,
MetaTensor* out) {
const auto& x_dims = x.dims();
// Validity Check: input tensor dims (<6).
PADDLE_ENFORCE_LE(x_dims.size(),
6,
phi::errors::InvalidArgument(
"Invalid "
"dimensions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit)"));
if (!axes.GetData().empty()) {
std::vector<int32_t> tmp;
tmp.reserve(axes.GetData().size());
std::for_each(axes.GetData().begin(),
axes.GetData().end(),
[&tmp](const int64_t& t) { tmp.push_back(t); });
auto out_dims = funcs::GetUnsqueezeShape(tmp, x_dims);
out->set_dims(out_dims);
if (x_dims[0] == out_dims[0]) {
out->share_lod(x);
}
}
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < x_dims.size(); ++i) {
xshape_dims[i + 1] = x_dims[i];
}
xshape->set_dims(phi::make_ddim(xshape_dims));
xshape->share_lod(x);
out->set_dtype(x.dtype());
xshape->set_dtype(x.dtype());
}
void OneHotRawInferMeta(const MetaTensor& x, void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth, int32_t depth,
DataType dtype, DataType dtype,
...@@ -1992,7 +2062,6 @@ void OneHotRawInferMeta(const MetaTensor& x, ...@@ -1992,7 +2062,6 @@ void OneHotRawInferMeta(const MetaTensor& x,
x_dims.size(), x_dims.size(),
1, 1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1.")); phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));
auto out_dims_vec = phi::vectorize(x_dims); auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth); out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec); auto out_dims = phi::make_ddim(out_dims_vec);
......
...@@ -229,6 +229,11 @@ void SplitInferMeta(const MetaTensor& x_meta, ...@@ -229,6 +229,11 @@ void SplitInferMeta(const MetaTensor& x_meta,
std::vector<MetaTensor*> out, std::vector<MetaTensor*> out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void SqueezeInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
MetaTensor* xshape,
MetaTensor* out);
void SumInferMeta(const MetaTensor& x, void SumInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
DataType dtype, DataType dtype,
...@@ -290,6 +295,11 @@ void UnfoldInferMeta(const MetaTensor& x, ...@@ -290,6 +295,11 @@ void UnfoldInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void UnsqueezeInferMeta(const MetaTensor& x,
const ScalarArray& axes,
MetaTensor* xshape,
MetaTensor* out);
void OneHotRawInferMeta(const MetaTensor& x, void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth, int32_t depth,
DataType dtype, DataType dtype,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/squeeze_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h"
PD_REGISTER_KERNEL(squeeze_grad,
CPU,
ALL_LAYOUT,
phi::SqueezeGradKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/squeeze_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squeeze_kernel_impl.h"
PD_REGISTER_KERNEL(squeeze,
CPU,
ALL_LAYOUT,
phi::SqueezeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unsqueeze_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unsqueeze_grad_kernel_impl.h"
PD_REGISTER_KERNEL(unsqueeze_grad,
CPU,
ALL_LAYOUT,
phi::UnsqueezeGradKernel,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unsqueeze_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unsqueeze_kernel_impl.h"
PD_REGISTER_KERNEL(unsqueeze,
CPU,
ALL_LAYOUT,
phi::UnsqueezeKernel,
float,
double,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -21,6 +21,118 @@ ...@@ -21,6 +21,118 @@
namespace phi { namespace phi {
namespace funcs { namespace funcs {
inline DDim GetOutputSqueezeShape(const std::vector<int> squeeze_dims,
const DDim& in_dims,
bool is_runtime) {
size_t num_squeeze_dims = squeeze_dims.size();
std::vector<bool> should_squeeze(in_dims.size(), false);
// Mark dimensions need to be squeezed.
if (num_squeeze_dims == 0) {
for (int i = 0; i < in_dims.size(); ++i) {
if (in_dims[i] == 1) {
should_squeeze[i] = true;
}
}
} else {
for (size_t i = 0; i < num_squeeze_dims; ++i) {
int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims.size()
: squeeze_dims[i];
PADDLE_ENFORCE_GE(
current,
0,
phi::errors::InvalidArgument(
"Each axis in Attr(axes) should be in the range of [%d, %d]"
"But current axis is:%d, input tensor's shape = [%s].",
-in_dims.size(),
in_dims.size() - 1,
current,
in_dims));
PADDLE_ENFORCE_LT(
current,
in_dims.size(),
phi::errors::InvalidArgument(
"Each axis in Attr(axes) should be in the range of [%d, %d]"
"But current axis is:%d, input tensor's shape = [%s].",
-in_dims.size(),
in_dims.size() - 1,
current,
in_dims));
if (!should_squeeze[current]) {
if (is_runtime) {
// At run time, dim of 1 is allowed to squeeze
if (in_dims[current] == 1) {
should_squeeze[current] = true;
}
} else {
// At compile time, dim of -1 or 1 is allowed to squeeze
if (in_dims[current] == 1 || in_dims[current] == -1) {
should_squeeze[current] = true;
}
}
}
}
}
// Make output dimensions
std::vector<int64_t> output_shape;
for (int i = 0; i < in_dims.size(); ++i) {
if (!should_squeeze[i]) {
output_shape.push_back(in_dims[i]);
}
}
return phi::make_ddim(output_shape);
}
inline DDim GetUnsqueezeShape(const std::vector<int> unsqz_dims,
const DDim& in_dims) {
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
// Validity Check: rank range.
PADDLE_ENFORCE_LE(
output_size,
6,
phi::errors::InvalidArgument("The output "
"tensor's rank should be less than 6."));
for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Vaildity Check: the axis bound
PADDLE_ENFORCE_GE(
cur,
0,
phi::errors::InvalidArgument("The insert dimension value should "
"not be less than 0"));
PADDLE_ENFORCE_LE(cur,
cur_output_size,
phi::errors::InvalidArgument(
"The insert dimension value shoule not be larger "
"than the dimension size of input tensor"));
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
output_shape[i] = 0;
}
}
output_shape[cur] = 1;
// Add the output size.
cur_output_size++;
}
// Make output shape
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
}
return phi::make_ddim(output_shape);
}
inline const DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) { inline const DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) {
// don't copy data, only change the dims // don't copy data, only change the dims
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/squeeze_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squeeze_grad_kernel_impl.h"
PD_REGISTER_KERNEL(squeeze_grad,
GPU,
ALL_LAYOUT,
phi::SqueezeGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/squeeze_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squeeze_kernel_impl.h"
PD_REGISTER_KERNEL(squeeze,
GPU,
ALL_LAYOUT,
phi::SqueezeKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16,
bool,
int,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unsqueeze_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unsqueeze_grad_kernel_impl.h"
PD_REGISTER_KERNEL(unsqueeze_grad,
GPU,
ALL_LAYOUT,
phi::UnsqueezeGradKernel,
phi::dtype::bfloat16,
phi::dtype::float16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unsqueeze_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unsqueeze_kernel_impl.h"
PD_REGISTER_KERNEL(unsqueeze,
GPU,
ALL_LAYOUT,
phi::UnsqueezeKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
int16_t,
uint8_t,
int8_t,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void SqueezeGradKernel(const Context& dev_ctx,
const DenseTensor& xshape,
const DenseTensor& dout,
const std::vector<int>& axes,
DenseTensor* dx) {
auto xshape_dims = xshape.dims();
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
dev_ctx.template Alloc<T>(dx);
phi::Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
dx->Resize(x_dims);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
namespace phi {
template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* xshape,
DenseTensor* out) {
auto x_dims = x.dims();
auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true);
dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void UnsqueezeGradKernel(const Context& dev_ctx,
const DenseTensor& x_shape,
const DenseTensor& dout,
DenseTensor* dx) {
auto xshape_dims = x_shape.dims();
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
dev_ctx.template Alloc<T>(dx);
phi::Copy(dev_ctx, dout, dev_ctx.GetPlace(), true, dx);
dx->Resize(x_dims);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
namespace phi {
template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& axes,
DenseTensor* xshape,
DenseTensor* out) {
auto x_dims = x.dims();
auto out_dims = out->dims();
if (axes.FromTensor()) {
std::vector<int32_t> tmp;
tmp.reserve(axes.GetData().size());
std::for_each(axes.GetData().begin(),
axes.GetData().end(),
[&tmp](const int64_t& t) { tmp.push_back(t); });
out_dims = funcs::GetUnsqueezeShape(tmp, x_dims);
}
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims.
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SqueezeGradKernel(const Context& dev_ctx,
const DenseTensor& xshape,
const DenseTensor& dout,
const std::vector<int>& axes,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
DenseTensor* xshape,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void UnsqueezeGradKernel(const Context& dev_ctx,
const DenseTensor& x_shape,
const DenseTensor& dout,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void UnsqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& axes,
DenseTensor* xshape,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("squeeze", {"X"}, {"axes"}, {"XShape", "Out"});
}
KernelSignature SqueezeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("squeeze_grad",
{"XShape", GradVarName("Out")},
{"axes"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(squeeze2, squeeze);
PD_REGISTER_BASE_KERNEL_NAME(squeeze2_grad, squeeze_grad);
PD_REGISTER_ARG_MAPPING_FN(squeeze2, phi::SqueezeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(squeeze2_grad, phi::SqueezeGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature(
"unsqueeze", {"X"}, {"AxesTensorList"}, {"XShape", "Out"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze", {"X"}, {"AxesTensor"}, {"XShape", "Out"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"XShape", "Out"});
}
}
KernelSignature UnsqueezeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"unsqueeze_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(unsqueeze2, unsqueeze);
PD_REGISTER_BASE_KERNEL_NAME(unsqueeze2_grad, unsqueeze_grad);
PD_REGISTER_ARG_MAPPING_FN(unsqueeze2, phi::UnsqueezeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(unsqueeze2_grad,
phi::UnsqueezeGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册