Unverified commit bcf67536 authored by Wang Xin, committed by GitHub

static graph autogen code support for pad3d op (#53733)

* static graph autogen code support for pad3d op

* bug fixed

* add ut for pad3d mkldnn op

* fix coverage

* fix bug

* fix bug

* Delete test_pad3d_mkldnn_op.py
Parent 1ef0de81
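For context, a minimal static-graph sketch of what this migration must keep working; paddle.nn.functional.pad is the public entry point that lowers to the pad3d op (the shapes and feed values here are illustrative assumptions, not part of the commit):

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[1, 1, 1, 2, 3], dtype="float32")
    # pad = [left, right, top, bottom, front, back]; a 5-D input with a
    # 6-element pad list lowers to the pad3d op exercised by this commit.
    out = F.pad(x, pad=[2, 2, 1, 1, 0, 0], mode="constant", value=0.0,
                data_format="NCDHW")

exe = paddle.static.Executor(paddle.CPUPlace())
feed_x = np.arange(1, 7, dtype="float32").reshape(1, 1, 1, 2, 3)
res, = exe.run(main_prog, feed={"x": feed_x}, fetch_list=[out])
print(res.shape)  # (1, 1, 1, 4, 7)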
@@ -195,6 +195,25 @@ phi::KernelKey GetMatrixNmsExpectedKernelType(
platform::CPUPlace());
}
phi::KernelKey GetPad3dExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
auto input_data_type = op_ptr->IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// only constant mode and non-blocked layouts are supported for oneDNN
if (op_ptr->CanMKLDNNBeUsed(ctx, input_data_type) &&
ctx.Attr<std::string>("mode") == "constant" &&
ctx.Input<phi::DenseTensor>("X")
->mem_desc()
.data.format_desc.blocking.inner_nblks == 0) {
return phi::KernelKey(phi::Backend::ONEDNN,
phi::DataLayout::ONEDNN,
phi::TransToPhiDataType(input_data_type));
}
#endif
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetYoloLossExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
......
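The hunk above gates the oneDNN path; a minimal restatement of that dispatch rule in Python terms (a readability sketch, not Paddle API):

# oneDNN is chosen only when it can be used at all, the mode is constant,
# and the input's memory descriptor is non-blocked; anything else falls
# back to the default kernel key for the current place.
def pick_pad3d_backend(can_use_mkldnn, mode, inner_nblks):
    if can_use_mkldnn and mode == "constant" and inner_nblks == 0:
        return "ONEDNN"
    return "DEFAULT_FOR_PLACE"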
@@ -52,6 +52,10 @@ phi::KernelKey GetMatrixNmsExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetPad3dExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetUniqueExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
class Pad3dOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// only constant mode and non-blocked layouts are supported for oneDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
ctx.Attr<std::string>("mode") == "constant" &&
ctx.Input<phi::DenseTensor>("X")
->mem_desc()
.data.format_desc.blocking.inner_nblks == 0) {
return phi::KernelKey(phi::Backend::ONEDNN,
phi::DataLayout::ONEDNN,
phi::TransToPhiDataType(input_data_type));
}
#endif
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
return phi::KernelKey(tensor.place(),
phi::StringToDataLayout(data_format),
expected_kernel_type.dtype());
}
#endif
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
class Pad3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input of pad3d op. "
"The input should be a 5-D tensor with formate NCDHW or NDHWC.");
AddOutput("Out",
"The output of pad3d op. "
"A tensor with the same shape as X.");
AddInput("Paddings",
"A 1-D tensor to describe the padding rules."
"paddings=[0, 1, 2, 3, 4, 5] means "
"padding 0 column to left, 1 column to right, "
"2 row to top, 3 row to bottom, 4 depth to front "
"and 5 depth to back. Size of paddings must be 6.")
.AsDispensable();
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>) "
"A list<int> to describe the padding rules."
"paddings=[0, 1, 2, 3, 4, 5] means "
"padding 0 column to left, 1 column to right, "
"2 row to top, 3 row to bottom, 4 depth to front "
"and 5 depth to back. Size of paddings must be 6.");
AddAttr<float>("value",
"(float, default 0.0) "
"The value to fill the padded areas in constant mode.")
.SetDefault(0.0f);
AddAttr<std::string>(
"mode",
"(string, default constant) "
"Four modes: constant(default), reflect, replicate, circular.")
.SetDefault("constant");
AddAttr<std::string>(
"data_format",
"(string, default NCDHW) "
"An optional string from: \"NDHWC\", \"NCDHW\". "
"Defaults to \"NCDHW\". Specify the data format of the input data.")
.SetDefault("NCDHW");
AddComment(R"DOC(
Pad3d Operator.
Pad 3-d images according to 'paddings' and 'mode'.
If mode is 'reflect', paddings[0] and paddings[1] must be no greater
than width-1. The height and depth dimensions are subject to the same condition.
Given that X is a channel of image from input:
X = [[[[[1, 2, 3],
[4, 5, 6]]]]]
Case 0:
paddings = [2, 2, 1, 1, 0, 0],
mode = 'constant'
pad_value = 0
Out = [[[[[0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 2. 3. 0. 0.]
[0. 0. 4. 5. 6. 0. 0.]
[0. 0. 0. 0. 0. 0. 0.]]]]]
Case 1:
paddings = [2, 2, 1, 1, 0, 0],
mode = 'reflect'
Out = [[[[[6. 5. 4. 5. 6. 5. 4.]
[3. 2. 1. 2. 3. 2. 1.]
[6. 5. 4. 5. 6. 5. 4.]
[3. 2. 1. 2. 3. 2. 1.]]]]]
Case 2:
paddings = [2, 2, 1, 1, 0, 0],
mode = 'replicate'
Out = [[[[[1. 1. 1. 2. 3. 3. 3.]
[1. 1. 1. 2. 3. 3. 3.]
[4. 4. 4. 5. 6. 6. 6.]
[4. 4. 4. 5. 6. 6. 6.]]]]]
Case 3:
paddings = [2, 2, 1, 1, 0, 0],
mode = 'circular'
Out = [[[[[5. 6. 4. 5. 6. 4. 5.]
[2. 3. 1. 2. 3. 1. 2.]
[5. 6. 4. 5. 6. 4. 5.]
[2. 3. 1. 2. 3. 1. 2.]]]]]
)DOC");
}
};
class Pad3dOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad3d@Grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"Pad3d@Grad");
auto x_dims = ctx->GetInputDim("X");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
template <typename T>
class Pad3dOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> bind) const override {
bind->SetInput("X", this->Input("X"));
if (this->HasInput("Paddings")) {
bind->SetInput("Paddings", this->Input("Paddings"));
}
bind->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
bind->SetAttrMap(this->Attrs());
bind->SetType("pad3d_grad");
}
};
template <typename T>
class Pad3dOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
if (this->HasInput("Paddings")) {
grad_op->SetInput("Paddings", this->Input("Paddings"));
}
grad_op->SetType("pad3d");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad3dOpGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(pad3d,
Pad3dInferShapeFunctor,
PD_INFER_META(phi::Pad3dInferMeta));
REGISTER_OPERATOR(pad3d,
ops::Pad3dOp,
ops::Pad3dOpMaker,
ops::Pad3dOpGradMaker<paddle::framework::OpDesc>,
ops::Pad3dOpGradMaker<paddle::imperative::OpBase>,
Pad3dInferShapeFunctor);
REGISTER_OPERATOR(pad3d_grad,
ops::Pad3dOpGrad,
ops::Pad3dOpDoubleGradMaker<paddle::framework::OpDesc>,
ops::Pad3dOpDoubleGradMaker<paddle::imperative::OpBase>,
ops::Pad3dOpGradNoNeedBufferVarsInferer);
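The DOC cases above are easy to check in eager mode once the op is in place; a minimal sketch reproducing Case 1 (reflect), assuming only the public paddle.nn.functional.pad API:

import paddle
import paddle.nn.functional as F

# The 5-D input from the DOC string: one channel of shape [1, 1, 1, 2, 3].
x = paddle.to_tensor([[[[[1., 2., 3.],
                         [4., 5., 6.]]]]])
out = F.pad(x, pad=[2, 2, 1, 1, 0, 0], mode="reflect", data_format="NCDHW")
print(out.numpy())
# [[[[[6. 5. 4. 5. 6. 5. 4.]
#     [3. 2. 1. 2. 3. 2. 1.]
#     [6. 5. 4. 5. 6. 5. 4.]
#     [3. 2. 1. 2. 3. 2. 1.]]]]]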
@@ -1378,6 +1378,27 @@
kernel :
func : p_norm_grad
- backward_op : pad3d_double_grad
forward : pad3d_grad(Tensor x, Tensor grad_out, IntArray paddings, str mode="constant", float pad_value=0.0, str data_format="NCDHW") -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray paddings, str mode, float pad_value, str data_format)
output : Tensor(grad_out_grad)
infer_meta :
func : Pad3dInferMeta
kernel :
func : pad3d
- backward_op : pad3d_grad
forward : pad3d(Tensor x, IntArray paddings, str mode="constant", float pad_value=0.0, str data_format="NCDHW") -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray paddings, str mode, float pad_value, str data_format)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : pad3d_grad
no_need_buffer : x
backward : pad3d_double_grad
- backward_op : pixel_shuffle_grad
forward : pixel_shuffle (Tensor x, int upscale_factor=1, str data_format="NCHW") -> Tensor(out)
args : (Tensor out_grad, int upscale_factor, str data_format)
......
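The pad3d_double_grad entry above reuses func : pad3d because padding is linear: the gradient of grad_x with respect to grad_out is the pad operation itself, so grad_out_grad is simply pad3d applied to grad_x_grad. A hedged eager-mode sketch of that identity (shapes are illustrative):

import paddle
import paddle.nn.functional as F

# grad_out_grad = pad3d(grad_x_grad) with the same paddings/mode/data_format,
# which is exactly what "kernel : func : pad3d" in pad3d_double_grad encodes.
grad_x_grad = paddle.randn([1, 1, 1, 2, 3])
grad_out_grad = F.pad(grad_x_grad, pad=[2, 2, 1, 1, 0, 0],
                      mode="constant", value=0.0, data_format="NCDHW")
print(grad_out_grad.shape)  # [1, 1, 1, 4, 7]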
@@ -592,27 +592,6 @@
kernel :
func : norm_grad
- backward_op : pad3d_double_grad
forward : pad3d_grad(Tensor x, Tensor grad_out, IntArray paddings, str mode, float pad_value, str data_format) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray paddings, str mode, float pad_value, str data_format)
output : Tensor(grad_out_grad)
infer_meta :
func : Pad3dInferMeta
kernel :
func : pad3d
- backward_op : pad3d_grad
forward : pad3d(Tensor x, IntArray paddings, str mode, float pad_value, str data_format) -> Tensor(out)
args : (Tensor x, Tensor out_grad, IntArray paddings, str mode, float pad_value, str data_format)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : pad3d_grad
no_need_buffer : x
backward : pad3d_double_grad
- backward_op : pad_double_grad
forward : pad_grad(Tensor x, Tensor grad_out, int[] paddings, Scalar pad_value) -> Tensor(grad_x)
args : (Tensor grad_x_grad, int[] paddings, Scalar pad_value)
......
@@ -771,15 +771,6 @@
func : pad
backward : pad_grad
- op : pad3d
args : (Tensor x, IntArray paddings, str mode, float pad_value, str data_format)
output : Tensor(out)
infer_meta :
func : Pad3dInferMeta
kernel :
func : pad3d
backward : pad3d_grad
- op : pool2d
args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm)
output : Tensor(out)
......
@@ -1741,7 +1741,17 @@
attrs : [bool use_mkldnn = false]
- op : pad3d
backward : pad3d_grad
backward : pad3d_grad, pad3d_double_grad
inputs :
x : X
outputs :
out : Out
int_array:
paddings :
data_type : int
tensor_name : Paddings
attrs :
pad_value : value
extra :
attrs : [bool use_mkldnn = false]
......
@@ -1639,6 +1639,15 @@
func : p_norm
backward : p_norm_grad
- op : pad3d
args : (Tensor x, IntArray paddings, str mode = "constant", float pad_value = 0.0, str data_format = "NCDHW")
output : Tensor(out)
infer_meta :
func : Pad3dInferMeta
kernel :
func : pad3d
backward : pad3d_grad
- op : pixel_shuffle
args : (Tensor x, int upscale_factor=1, str data_format="NCHW")
output : Tensor
......
@@ -20,6 +20,24 @@
namespace phi {
KernelKey Pad3dGetKernelTypeForVar(const GetKernelTypeForVarContext* ctx) {
const DenseTensor& tensor = ctx->GetTensor();
const KernelKey& expected_kernel_type = ctx->GetKernelKey();
const AttributeMap& attrs = ctx->GetAttrs();
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto it = attrs.find("data_format");
const std::string data_format = PADDLE_GET_CONST(std::string, it->second);
return phi::KernelKey(tensor.place(),
phi::StringToDataLayout(data_format),
expected_kernel_type.dtype());
}
#endif
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
template <typename T, typename Context>
void Pad3dKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -38,4 +56,6 @@ PD_REGISTER_KERNEL(pad3d,
phi::Pad3dKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float) {}
float) {
kernel->get_kerneltype_forvar_fn_ = phi::Pad3dGetKernelTypeForVar;
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature Pad3dOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Paddings")) {
return KernelSignature(
"pad3d", {"X"}, {"Paddings", "mode", "value", "data_format"}, {"Out"});
}
return KernelSignature(
"pad3d", {"X"}, {"paddings", "mode", "value", "data_format"}, {"Out"});
}
KernelSignature Pad3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Paddings")) {
return KernelSignature("pad3d_grad",
{"X", "Out@GRAD"},
{"Paddings", "mode", "value", "data_format"},
{"X@GRAD"});
}
return KernelSignature("pad3d_grad",
{"X", "Out@GRAD"},
{"paddings", "mode", "value", "data_format"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(pad3d, phi::Pad3dOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pad3d_grad, phi::Pad3dGradOpArgumentMapping);
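The two KernelSignature branches above select between paddings supplied as a Tensor input (Paddings) and as a plain attribute (paddings); both reach the kernel as an IntArray. A sketch of the two call forms, assuming paddle.nn.functional.pad accepts either a list or a Tensor for pad:

import paddle
import paddle.nn.functional as F

x = paddle.randn([1, 1, 1, 2, 3])

# pad as a plain list -> the "paddings" attribute branch of the mapping
out_attr = F.pad(x, pad=[2, 2, 1, 1, 0, 0], mode="constant", value=0.0)

# pad as a Tensor -> the "Paddings" input branch of the mapping
pad_t = paddle.to_tensor([2, 2, 1, 1, 0, 0], dtype="int32")
out_tensor = F.pad(x, pad=pad_t, mode="constant", value=0.0)

print(bool((out_attr == out_tensor).all()))  # True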