Unverified commit 1075d35d, authored by H huangjiyi, committed by GitHub

Support code generation for op conv2d_transpose, conv3d_transpose, depthwise_conv2d_transpose (#54242)
Parent 587f66ee
......@@ -459,7 +459,6 @@ if(WITH_MKLDNN)
set(TEST_CONV_BN_PASS_DEPS
conv_bn_fuse_pass
graph_to_program_pass
conv_transpose_op
batch_norm_op
generated_op
generated_static_op
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
using DataLayout = phi::DataLayout;
phi::KernelKey ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return phi::KernelKey(data_type, ctx.GetPlace());
}
phi::KernelKey ConvTransposeOp::GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only the input requires reshaping; the weights and
// the bias already have their shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = phi::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for this
// op. Treat it as NCHW (the default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
#endif
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
void Conv2DTransposeOpMaker::Make() {
AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator. "
"The format of input tensor is NCHW or NHWC. Where N is batch size, "
"C is the number of input channels, H is the height of the feature, "
"and W is the width of the feature.");
AddInput(
"Filter",
"(Tensor) The filter tensor of convolution transpose operator. "
"The format of the filter tensor is MCHW, where M is the number of "
"input feature channels, C is the number of "
"output feature channels,"
"H is the height of the filter, and W is the width of the filter. "
"We enforce groups number == 1 in the convolution transpose scenario.");
AddInput("Bias",
"(Tensor) Bias to be added to each output of filter application."
"The format of output tensor is X (one-dimensional) of size equal"
"to the number of output channels. Only used with MKL-DNN.")
.AsDispensable()
.AsExtra();
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator. "
"The format of output tensor is the same as input tensor.");
AddAttr<std::vector<int>>("output_padding",
"(vector<int> default: []), Additional size added "
"to one side of each dimension in the output "
"shape")
.SetDefault({});
AddAttr<std::vector<int>>("output_size",
"(vector<int> default: []), the "
"size of the output tensor")
.SetDefault({})
.SupportTensor();
AddAttr<int>("groups",
"(int default:1), the groups number of the convolution "
"transpose operator. ")
.SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of convolution "
"transpose operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>(
"strides",
"(vector<int> default:{1, 1}), the strides(h_stride, w_stride) of "
"convolution transpose operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>(
"paddings",
"(vector<int> default:{0, 0}), the paddings(h_pad, w_pad) of convolution "
"transpose operator.")
.SetDefault({0, 0});
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Specify that the data format of the input and output data is "
"channel_first or channel_last.")
.SetDefault("NCHW");
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
AddComment(R"DOC(
Convolution2D Transpose Operator.
The convolution transpose operation calculates the output based on the input, the filter,
and the strides, paddings, dilations and groups parameters. The size of each dimension of these
parameters is checked during infer-shape.
Input(Input) and output(Output) are in NCHW or NHWC format, where N is the batch size, C is the
number of channels, H is the height of the feature, and W is the width of the feature.
Filter(Input) is in MCHW format, where M is the number of input feature channels,
C is the number of output feature channels, H is the height of the filter,
and W is the width of the filter.
The strides and paddings parameters each contain two elements, representing height
and width, respectively.
The input(X) size and output(Out) size may be different.
For example:
Input:
Input shape: $(N, C_{in}, H_{in}, W_{in})$
Filter shape: $(C_{in}, C_{out}, H_f, W_f)$
Output:
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
$$
H_{out} = (H_{in} - 1) * strides[0] - pad_height_top - pad_height_bottom + dilations[0] * (H_f - 1) + 1 \\
W_{out} = (W_{in} - 1) * strides[1] - pad_width_left - pad_width_right + dilations[1] * (W_f - 1) + 1
$$
)DOC");
}
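// A quick numeric check of the output-size formula above (illustrative values,
// not taken from this change): with H_in = 4, strides[0] = 2,
// pad_height_top = pad_height_bottom = 1, dilations[0] = 1 and H_f = 3:
//   H_out = (4 - 1) * 2 - 1 - 1 + 1 * (3 - 1) + 1 = 7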
void Conv3DTransposeOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW or NDHWC. Where N is batch "
"size, C is the number of channels, D is the depth of the feature, "
"H is the height of the feature, and W is the width of the feature.");
AddInput("Filter",
"(Tensor) The filter tensor of convolution transpose operator."
"The format of the filter tensor is MCDHW, where M is the number of "
"input feature channels, C is the number of "
"output feature channels, D "
"is the depth of the filter, H is the height of the filter, and "
"W is the width of the filter."
"We enforce groups number == 1 and padding == 0 in "
"the convolution3d transpose scenario.");
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator."
"The format of output tensor is the same as input tensor."
"Where N is batch size, C is "
"the number of channels, D is the depth of the feature, H is the "
"height of the feature, and W is the width of the feature.");
AddAttr<std::vector<int>>("output_padding",
"(vector<int> default: []), Additional size added "
"to one side of each dimension in the output "
"shape")
.SetDefault({});
AddAttr<std::vector<int>>("output_size",
"(vector<int> default: []), the "
"size of the output tensor")
.SetDefault({});
AddAttr<std::vector<int>>(
"dilations",
"(vector<int> default:{1, 1, 1}), the "
"dilations(d_dilation,h_dilation, w_dilation) of convolution "
"transpose operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1, 1}), the "
"strides{d_stride, h_stride, w_stride} of "
"convolution transpose operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>("paddings",
"(vector<int> default:{0, 0, 0}), paddings(d_pad, "
"h_pad, w_pad) of convolution transpose operator.")
.SetDefault({0, 0, 0});
AddAttr<int>("groups",
"(int default:1), the groups number of the convolution3d "
"transpose operator. ")
.SetDefault(1);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Specify that the data format of the input and output data is "
"channel_first or channel_last.")
.SetDefault("NCHW");
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
AddComment(R"DOC(
Convolution3D Transpose Operator.
The convolution transpose operation calculates the output based on the input, the filter,
and the strides, paddings, dilations and groups parameters. The size of each dimension of these
parameters is checked during infer-shape.
Input(Input) and output(Output) are in NCDHW or NDHWC format, where N is the batch size, C is the
number of channels, D is the depth of the feature, H is the height of the feature,
and W is the width of the feature.
Filter(Input) is in MCDHW format, where M is the number of input feature channels,
C is the number of output feature channels, D is the depth of the filter, H is the
height of the filter, and W is the width of the filter.
The strides and paddings parameters each contain three elements, representing
depth, height and width, respectively.
The input(X) size and output(Out) size may be different.
Example:
Input:
Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$
Filter shape: $(C_{in}, C_{out}, D_f, H_f, W_f)$
Output:
Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
Where
$$
D_{out} = (D_{in} - 1) * strides[0] - pad_depth_front - pad_depth_back + dilations[0] * (D_f - 1) + 1 \\
H_{out} = (H_{in} - 1) * strides[1] - pad_height_top - pad_height_bottom + dilations[1] * (H_f - 1) + 1 \\
W_{out} = (W_{in} - 1) * strides[2] - pad_width_left - pad_width_right + dilations[2] * (W_f - 1) + 1
$$
)DOC");
}
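// The same kind of check applies to the 3D formula above (illustrative values,
// not taken from this change): with D_in = 4, strides[0] = 2,
// pad_depth_front = pad_depth_back = 1, dilations[0] = 1 and D_f = 3,
//   D_out = (4 - 1) * 2 - 1 - 1 + 1 * (3 - 1) + 1 = 7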
phi::KernelKey ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return phi::KernelKey(data_type, ctx.GetPlace());
}
template <typename T>
class ConvTransposeGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
if (this->HasInput("Bias")) {
op->SetInput("Bias", this->Input("Bias"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetAttrMap(this->Attrs());
}
};
/*
* Inputs: I, W, dO, ddI, ddW
* Outputs: ddO, dW, dI
*/
template <typename T>
class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
// I, W, dO, ddI, ddW
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
op->SetInput("DDFilter",
this->OutputGrad(framework::GradVarName("Filter")));
// ddO, dI, dW
// Unlike the grad op, the double grad op does not use name@GRAD@GRAD
// as the key for its inputs and outputs.
auto ddx = this->OutputGrad(framework::GradVarName("Input"));
auto ddw = this->OutputGrad(framework::GradVarName("Filter"));
op->SetOutput("DDOutput",
ddx.empty()
? this->EmptyInputGrad()
: this->InputGrad(framework::GradVarName("Output")));
op->SetOutput(
"DFilter",
ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Filter"));
op->SetOutput(
"DInput",
ddw.empty() ? this->EmptyInputGrad() : this->InputGrad("Input"));
op->SetAttrMap(this->Attrs());
}
};
phi::KernelKey ConvTransposeOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return phi::KernelKey(data_type, ctx.GetPlace());
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// conv2d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(conv2d_transpose,
Conv2dTranposeInferShapeFunctor,
PD_INFER_META(phi::Conv2dTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(conv2d_transpose_grad,
Conv2dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::Conv2dTransposeGradInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(
conv2d_transpose_grad_grad,
Conv2dTranposeDoubleGradInferShapeFunctor,
PD_INFER_META(phi::Conv2dTransposeDoubleGradInferMeta));
REGISTER_OPERATOR(conv2d_transpose,
ops::ConvTransposeOp,
ops::Conv2DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
Conv2dTranposeInferShapeFunctor);
REGISTER_OPERATOR(conv2d_transpose_grad,
ops::ConvTransposeOpGrad,
ops::ConvTransposeDoubleGradMaker<paddle::framework::OpDesc>,
ops::ConvTransposeDoubleGradMaker<paddle::imperative::OpBase>,
Conv2dTranposeGradInferShapeFunctor);
REGISTER_OPERATOR(conv2d_transpose_grad_grad,
ops::ConvTransposeOpDoubleGrad,
Conv2dTranposeDoubleGradInferShapeFunctor);
// conv3d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(conv3d_transpose,
Conv3dTranposeInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(conv3d_transpose_grad,
Conv3dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::ConvTransposeGradInferMeta));
REGISTER_OPERATOR(conv3d_transpose,
ops::ConvTransposeOp,
ops::Conv3DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
Conv3dTranposeInferShapeFunctor);
REGISTER_OPERATOR(conv3d_transpose_grad,
ops::ConvTransposeOpGrad,
Conv3dTranposeGradInferShapeFunctor);
// depthwise conv2d_transpose
DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d_transpose,
DepthWiseConv2dTranposeInferShapeFunctor,
PD_INFER_META(phi::Conv2dTransposeInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d_transpose_grad,
DepthWiseConv2dTranposeGradInferShapeFunctor,
PD_INFER_META(phi::Conv2dTransposeGradInferMeta));
REGISTER_OPERATOR(depthwise_conv2d_transpose,
ops::ConvTransposeOp,
ops::Conv2DTransposeOpMaker,
ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>,
DepthWiseConv2dTranposeInferShapeFunctor);
REGISTER_OPERATOR(depthwise_conv2d_transpose_grad,
ops::ConvTransposeOpGrad,
DepthWiseConv2dTranposeGradInferShapeFunctor);
REGISTER_OP_VERSION(conv_transpose)
.AddCheckpoint(
R"ROC(
Upgrade conv_transpose to add a new attribute [output_padding].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
std::vector<int>{}));
REGISTER_OP_VERSION(conv2d_transpose)
.AddCheckpoint(
R"ROC(
Upgrade conv2d transpose to add a new attribute [output_padding].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
std::vector<int>{}))
.AddCheckpoint(
R"ROC(
Upgrade conv2d transpose to add new attributes [force_fp32_output, mkldnn_data_type].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("force_fp32_output",
"Force BF16 kernel output FP32, only used in MKL-DNN BF16",
false)
.NewAttr("mkldnn_data_type",
"Data type of mkldnn kernel",
"float32"));
REGISTER_OP_VERSION(conv3d_transpose)
.AddCheckpoint(
R"ROC(
Upgrade conv3d transpose to add a new attribute [output_padding].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
std::vector<int>{}));
REGISTER_OP_VERSION(depthwise_conv2d_transpose)
.AddCheckpoint(
R"ROC(
Upgrade depthwise conv2d transpose to add a new attribute [output_padding].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
std::vector<int>{}));
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
// Define Op classes in .h file so that other conv transpose
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
class ConvTransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override;
};
class ConvTransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class ConvTransposeOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
} // namespace paddle
......@@ -318,6 +318,10 @@ def add_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
for out_item in forward_op_item['outputs']:
if out_item['name'] in op_args['extra']['outputs']:
out_item['is_extra'] = True
if 'extra' in op_args and 'inputs' in op_args['extra']:
for input_item in forward_op_item['inputs']:
if input_item['name'] in op_args['extra']['inputs']:
input_item['is_extra'] = True
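# Inputs listed under extra->inputs in op_compat.yaml (e.g. "bias" for
# conv2d_transpose in this change) are flagged here so that the generated
# OpMaker can emit them with .AsExtra().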
key_set = ['inputs', 'attrs', 'outputs']
args_map = {}
......
......@@ -38,6 +38,10 @@ AddInput({{name| to_opmaker_name}}, "({{typename}}), input {{i}} of {{op_name}}
.AsDispensable()
{%- endif %}
{%- if "is_extra" in input and input["is_extra"] %}
.AsExtra()
{%- endif %}
{%- endmacro %}
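{# With the op_compat entry "extra: inputs: [bias]" for conv2d_transpose in this
   change, the generated AddInput("Bias", ...) would thus carry an additional .AsExtra() call. #}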
{# add output, it could be duplicable or intermediate, however, optional output is not supported #}
......
......@@ -86,8 +86,6 @@ register_unity_group(
real_op.cc
sync_batch_norm_op.cc
top_k_op.cc
conv_op.cc
conv_transpose_op.cc
gru_unit_op.cc)
register_unity_group(
cc
......@@ -343,14 +341,8 @@ register_unity_group(
run_program_op.cc
softmax_with_cross_entropy_op.cc
warpctc_op.cc)
register_unity_group(
cc
conv_op.cu.cc
lstm_op.cu.cc
rnn_op.cu.cc
split_op.cu.cc
assign_value_op.cu.cc
warpctc_op.cu.cc)
register_unity_group(cc lstm_op.cu.cc rnn_op.cu.cc split_op.cu.cc
assign_value_op.cu.cc warpctc_op.cu.cc)
register_unity_group(
cu
addmm_op.cu
......@@ -374,9 +366,7 @@ register_unity_group(
register_unity_group(
cu
center_loss_op.cu
conv_op.cu
conv_transpose_cudnn_op.cu
conv_transpose_op.cu
cos_sim_op.cu
crop_op.cu
conj_op.cu
......
......@@ -372,6 +372,16 @@
data_type : input
backward : conv3d_double_grad
- backward_op : conv3d_transpose_grad
forward : conv3d_transpose(Tensor x, Tensor filter, int[] strides={1, 1, 1}, int[] paddings={0, 0, 0}, int[] output_padding={}, int[] output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1, 1}, str data_format="NCHW") -> Tensor(out)
args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(x_grad), Tensor(filter_grad)
infer_meta :
func : ConvTransposeGradInferMeta
kernel :
func : conv3d_transpose_grad
data_type : x
- backward_op : cos_double_grad
forward : cos_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_out, Tensor grad_x_grad)
......
......@@ -146,26 +146,19 @@
func : Conv2dTransposeDoubleGradInferMeta
kernel :
func : conv2d_transpose_double_grad
data_type : x
- backward_op : conv2d_transpose_grad
forward : conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
forward : conv2d_transpose(Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") -> Tensor(out)
args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(x_grad), Tensor(filter_grad)
infer_meta :
func : Conv2dTransposeGradInferMeta
kernel :
func : conv2d_transpose_grad
data_type : x
backward : conv2d_transpose_double_grad
- backward_op : conv3d_transpose_grad
forward : conv3d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(x_grad), Tensor(filter_grad)
infer_meta :
func : ConvTransposeGradInferMeta
kernel :
func : conv3d_transpose_grad
- backward_op : cumsum_grad
forward : cumsum(Tensor x, Scalar axis, bool flatten, bool exclusive, bool reverse) -> Tensor(out)
args : (Tensor x, Tensor out_grad, Scalar axis, bool flatten, bool exclusive, bool reverse)
......@@ -190,13 +183,14 @@
optional : mask
- backward_op : depthwise_conv2d_transpose_grad
forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") -> Tensor(out)
args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(x_grad), Tensor(filter_grad)
infer_meta :
func : Conv2dTransposeGradInferMeta
kernel :
func : depthwise_conv2d_transpose_grad
data_type : x
- backward_op : divide_double_grad
forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
......
......@@ -161,23 +161,15 @@
backward : concat_grad
- op : conv2d_transpose
args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
args : (Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW")
output : Tensor(out)
infer_meta :
func : Conv2dTransposeInferMeta
kernel :
func : conv2d_transpose
data_type : x
backward : conv2d_transpose_grad
- op : conv3d_transpose
args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(out)
infer_meta :
func : ConvTransposeInferMeta
kernel :
func : conv3d_transpose
backward : conv3d_transpose_grad
- op : copy_to
args : (Tensor x, Place place, bool blocking)
output : Tensor(out)
......@@ -215,12 +207,13 @@
backward : deformable_conv_grad
- op : depthwise_conv2d_transpose
args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
args : (Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW")
output : Tensor(out)
infer_meta :
func : Conv2dTransposeInferMeta
kernel :
func : depthwise_conv2d_transpose
data_type : x
backward : depthwise_conv2d_transpose_grad
- op : distribute_fpn_proposals
......
......@@ -544,8 +544,17 @@
int workspace_size_MB = phi::backends::gpu::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]
- op : conv2d_transpose
backward : conv2d_transpose_grad
backward : conv2d_transpose_grad, conv2d_transpose_double_grad (conv2d_transpose_grad_grad)
inputs :
{x : Input, filter : Filter, bias : Bias}
outputs :
out : Output
int_array :
output_size :
data_type : int
support_tensor : true
extra :
inputs : [bias]
attrs : [bool is_test = false, bool use_cudnn = true, bool use_mkldnn = false, bool force_fp32_output = false,
str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f,
......@@ -567,6 +576,10 @@
- op : conv3d_transpose
backward : conv3d_transpose_grad
inputs :
{x : Input, filter : Filter}
outputs :
out : Output
extra :
attrs : [bool use_cudnn = true, bool use_mkldnn = false, int workspace_size_MB = phi::backends::gpu::GetDefaultConvWorkspaceSizeLimitMB()]
......@@ -675,7 +688,16 @@
- op : depthwise_conv2d_transpose
backward : depthwise_conv2d_transpose_grad
inputs :
{x : Input, filter : Filter, bias: Bias}
outputs :
out : Output
int_array :
output_size :
data_type : int
support_tensor : true
extra :
inputs : [bias]
attrs : [bool is_test = false, bool use_cudnn = false, bool use_mkldnn = false, bool force_fp32_output = false,
str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f,
......
......@@ -99,6 +99,22 @@
gradient accumulation.
default : "false"
- op : conv2d_transpose
version :
- checkpoint : Upgrade conv_transpose to add a new attribute [output_padding].
action :
- add_attr : output_padding
comment : In order to add additional size to one side of each dimension in the output.
default : "std::vector<int>{}"
- checkpoint : Upgrade conv2d transpose to add new attributes [force_fp32_output, mkldnn_data_type].
action :
- add_attr : force_fp32_output
comment : Force BF16 kernel output FP32, only used in MKL-DNN BF16.
default : "false"
- add_attr : mkldnn_data_type
comment : Data type of mkldnn kernel.
default : "\"float32\""
- op : conv3d
version :
- checkpoint : Upgrade conv3d, add a new attribute [use_addto].
......@@ -108,6 +124,22 @@
gradient accumulation.
default : "false"
- op : conv3d_transpose
version :
- checkpoint : Upgrade conv_transpose to add a new attribute [output_padding].
action :
- add_attr : output_padding
comment : In order to add additional size to one side of each dimension in the output.
default : "std::vector<int>{}"
- op : conv_transpose
version :
- checkpoint : Upgrade conv_transpose to add a new attribute [output_padding].
action :
- add_attr : output_padding
comment : In order to add additional size to one side of each dimension in the output.
default : "std::vector<int>{}"
- op : depthwise_conv2d
version :
- checkpoint : Upgrade depthwise_conv2d, add a new attribute [use_addto].
......@@ -117,6 +149,14 @@
gradient accumulation.
default : "false"
- op : depthwise_conv2d_transpose
version :
- checkpoint : Upgrade conv_transpose to add a new attribute [output_padding].
action :
- add_attr : output_padding
comment : In order to add additional size to one side of each dimension in the output.
default : "std::vector<int>{}"
- op : embedding
version :
- checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims]
......
......@@ -484,6 +484,16 @@
func : conv3d
backward : conv3d_grad
- op : conv3d_transpose
args : (Tensor x, Tensor filter, int[] strides={1, 1, 1}, int[] paddings={0, 0, 0}, int[] output_padding={}, int[] output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1, 1}, str data_format="NCHW")
output : Tensor(out)
infer_meta :
func : ConvTransposeInferMeta
kernel :
func : conv3d_transpose
data_type : x
backward : conv3d_transpose_grad
- op : cos
args : (Tensor x)
output : Tensor
......
......@@ -27,6 +27,30 @@
composite: assign_grad(out_grad, x_grad)
invoke : assign(out_grad)
- backward_op : conv2d_transpose_double_grad
forward : conv2d_transpose_grad(Tensor x, Tensor filter, Tensor bias, Tensor grad_out, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(grad_x), Tensor(grad_filter)
args : (Tensor x, Tensor filter, Tensor grad_out, Tensor grad_x_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(x_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : Conv2dTransposeDoubleGradInferMeta
kernel :
func : conv2d_transpose_double_grad
data_type : x
- backward_op : conv2d_transpose_grad
forward : conv2d_transpose(Tensor x, Tensor filter, Tensor bias, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") -> Tensor(out)
args : (Tensor x, Tensor filter, Tensor bias, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(x_grad), Tensor(filter_grad)
infer_meta :
func : Conv2dTransposeGradInferMeta
param : [x, filter, out_grad, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
kernel :
func : conv2d_transpose_grad
param : [x, filter, out_grad, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
data_type : x
optional : bias
backward : conv2d_transpose_double_grad
- backward_op : deformable_conv_grad
forward : deformable_conv (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides={1, 1}, int[] paddings={0, 0}, int[] dilations={1, 1}, int deformable_groups=1, int groups=1, int im2col_step=64) -> Tensor(out)
args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, Tensor out_grad, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
......@@ -37,6 +61,19 @@
func : deformable_conv_grad
data_type : x
- backward_op : depthwise_conv2d_transpose_grad
forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, Tensor bias, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") -> Tensor(out)
args : (Tensor x, Tensor filter, Tensor bias, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(x_grad), Tensor(filter_grad)
infer_meta :
func : Conv2dTransposeGradInferMeta
param : [x, filter, out_grad, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
kernel :
func : depthwise_conv2d_transpose_grad
param : [x, filter, out_grad, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
data_type : x
optional : bias
- backward_op : embedding_grad
forward : embedding (Tensor x, Tensor weight, int64_t padding_idx=-1) -> Tensor(out)
args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1)
......
......@@ -100,6 +100,19 @@
func : broadcast
param: [x, root]
- op : conv2d_transpose
args : (Tensor x, Tensor filter, Tensor bias, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW")
output : Tensor(out)
infer_meta :
func : Conv2dTransposeInferMeta
param : [x, filter, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
kernel :
func : conv2d_transpose
param : [x, filter, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
data_type : x
optional : bias
backward : conv2d_transpose_grad
- op : decode_jpeg
args : (Tensor x, str mode = "unchanged")
output : Tensor(out)
......@@ -120,6 +133,19 @@
data_type : x
backward : deformable_conv_grad
- op : depthwise_conv2d_transpose
args : (Tensor x, Tensor filter, Tensor bias, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW")
output : Tensor(out)
infer_meta :
func : Conv2dTransposeInferMeta
param : [x, filter, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
kernel :
func : depthwise_conv2d_transpose
param : [x, filter, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format]
data_type : x
optional : bias
backward : depthwise_conv2d_transpose_grad
- op : embedding
args : (Tensor x, Tensor weight, int64_t padding_idx=-1)
output : Tensor
......
......@@ -16,6 +16,7 @@
#include "paddle/phi/backends/onednn/onednn_helper.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/compat/get_kerneltype_forvar_utils.h"
#include "paddle/phi/core/expect.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
......@@ -419,6 +420,30 @@ void Conv2dTransposeKernel(const Context& dev_ctx,
}
}
KernelKey ConvTransposeGetKernelTypeForVar(
const GetKernelTypeForVarContext* ctx) {
const std::string& var_name = ctx->GetVarName();
const DenseTensor& tensor = ctx->GetTensor();
const KernelKey& expected_kernel_type = ctx->GetKernelKey();
const AttributeMap& attrs = ctx->GetAttrs();
// Only the input requires reshaping; the weights and
// the bias already have their shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto it = attrs.find("data_format");
const std::string data_format = PADDLE_GET_CONST(std::string, it->second);
auto dl = phi::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for this
// op. Treat it as NCHW (the default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
} // namespace phi
PD_REGISTER_KERNEL(conv2d_transpose,
......@@ -426,4 +451,6 @@ PD_REGISTER_KERNEL(conv2d_transpose,
ONEDNN,
phi::Conv2dTransposeKernel,
float,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->get_kerneltype_forvar_fn_ = phi::ConvTransposeGetKernelTypeForVar;
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature Conv2dTransposeOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv2d_transpose",
{"Input", "Filter"},
{"strides",
"paddings",
"output_padding",
"output_size",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Output"});
}
KernelSignature Conv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv2d_transpose_grad",
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
"output_size",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv2dTransposeDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv2d_transpose_double_grad",
{"Input", "Filter", "DOutput", "DDInput", "DDFilter"},
{"strides",
"paddings",
"output_padding",
"output_size",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"DInput", "DFilter", "DDOutput"});
}
KernelSignature Conv3dTransposeOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv3d_transpose",
{"Input", "Filter"},
{"strides",
"paddings",
"output_padding",
"output_size",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Output"});
}
KernelSignature Conv3dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv3d_transpose_grad",
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
"output_size",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature DepthwiseConv2dTransposeOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("depthwise_conv2d_transpose",
{"Input", "Filter"},
{"strides",
"paddings",
"output_padding",
"output_size",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Output"});
}
KernelSignature DepthwiseConv2dTransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("depthwise_conv2d_transpose_grad",
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"output_padding",
"output_size",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Input@GRAD", "Filter@GRAD"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(conv2d_transpose_grad_grad,
conv2d_transpose_double_grad);
PD_REGISTER_ARG_MAPPING_FN(conv2d_transpose,
phi::Conv2dTransposeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_transpose_grad,
phi::Conv2dTransposeGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_transpose_grad_grad,
phi::Conv2dTransposeDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv3d_transpose,
phi::Conv3dTransposeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv3d_transpose_grad,
phi::Conv3dTransposeGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(depthwise_conv2d_transpose,
phi::DepthwiseConv2dTransposeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(depthwise_conv2d_transpose_grad,
phi::DepthwiseConv2dTransposeGradOpArgumentMapping);