未验证 提交 f3eccb3f 编写于 作者: H huangjiyi 提交者: GitHub

Support static graph code generation for conv2d, conv3d, depthwise_conv2d (#54201)

* update

* update cmake

* update

* update

* update

* update

* Revert "update cmake"

This reverts commit 1e1dc1b2bc9967b725201272607f939260070fd4.

* update

* update

* update

* update
上级 81c13b86
......@@ -456,7 +456,6 @@ if(WITH_MKLDNN)
set(TEST_CONV_BN_PASS_DEPS
conv_bn_fuse_pass
graph_to_program_pass
conv_op
conv_transpose_op
batch_norm_op
generated_op
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
std::vector<int64_t> ConvOp::ComputeOutputShape(
framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int dilation_size = dilations.size();
for (int i = 0; i < dilation_size; ++i) {
PADDLE_ENFORCE_GT(
dilations[i],
0,
platform::errors::InvalidArgument(
"The dilation of Op(Conv) should be larget than 0, but received "
"dilation is %d.",
dilations[i]));
}
const std::string data_format = ctx->Attrs().Get<std::string>("data_format");
// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");
PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5,
true,
platform::errors::InvalidArgument(
"The input of Op(Conv) should be a 4-D or 5-D Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(),
in_dims));
PADDLE_ENFORCE_EQ(
in_dims.size(),
filter_dims.size(),
platform::errors::InvalidArgument(
"The input's dimension and filter's dimension of "
"Op(Conv) should be equal. But received: the input's shape is [%s], "
"the input's dimension is %d; the filter's shape is [%s], "
"the filter's dimension is %d.",
in_dims,
in_dims.size(),
filter_dims,
filter_dims.size()));
int stride_size = strides.size();
for (int i = 0; i < stride_size; ++i) {
PADDLE_ENFORCE_GT(
strides[i],
0,
platform::errors::InvalidArgument(
"The stride of Op(Conv) should be larget than 0, but received "
"stride is %d.",
strides[i]));
}
int in_sub_stride_size = in_dims.size() - stride_size;
PADDLE_ENFORCE_EQ(
in_dims.size(),
strides.size() + 2U,
platform::errors::InvalidArgument(
"The difference of input's dimension and Attr(strides)'s "
"length must be euqal to 2 for Op(Conv). "
"But received: input's dimension is %d, input's shape is [%s]; "
"Attr(stride)'s length is %d, Attr(stride) is [%s]; "
"difference of input's dimention and Attr(strides)'s length = %u.",
in_dims.size(),
in_dims,
strides.size(),
phi::make_ddim(strides),
in_sub_stride_size));
const auto input_channels =
channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
PADDLE_ENFORCE_EQ(
input_channels,
filter_dims[1] * groups,
platform::errors::InvalidArgument(
"The number of input's channels should be equal to filter's channels "
"* groups for Op(Conv). But received: the input's channels is %d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d, the data_format is %s. "
"The error may come from wrong data_format setting.",
input_channels,
in_dims,
filter_dims[1],
filter_dims,
groups,
data_format));
PADDLE_ENFORCE_EQ(
filter_dims[0] % groups,
0,
platform::errors::InvalidArgument(
"The number of output's channels (filter's first dimension) of "
"Op(Conv) should be divided by groups. But received: "
"the output channels is %d, the filter's shape is [%s], "
"the groups is %d.",
filter_dims[0],
filter_dims,
groups));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GT(
filter_dims[0],
0,
platform::errors::InvalidArgument(
"the size of filter at axis 0 should be greater than 0"));
}
framework::DDim in_data_dims;
if (channel_last) {
in_data_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else {
in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
}
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (!channel_last) {
output_shape.push_back(filter_dims[0]);
}
for (int i = 0; i < in_data_dims.size(); ++i) {
if ((!ctx->IsRuntime()) &&
(in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1);
} else {
output_shape.push_back(ConvOutputSize(in_data_dims[i],
filter_data_dims[i],
dilations[i],
paddings[2 * i],
paddings[2 * i + 1],
strides[i]));
}
}
if (channel_last) {
output_shape.push_back(filter_dims[0]);
}
return output_shape;
}
phi::KernelKey ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
// todo enable data layout when it's ready
// (https://github.com/PaddlePaddle/Paddle/pull/20042)
if (input_data_type != framework::proto::VarType::INT8 &&
input_data_type != framework::proto::VarType::UINT8 &&
input_data_type != framework::proto::VarType::BF16) {
auto filter_data_type = framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("Filter")->dtype());
PADDLE_ENFORCE_EQ(
input_data_type,
filter_data_type,
platform::errors::InvalidArgument(
"input and filter data type should be consistent, "
"but received input data type is %s and filter type "
"is %s",
paddle::framework::DataTypeToString(input_data_type),
paddle::framework::DataTypeToString(filter_data_type)));
}
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey ConvOp::GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = phi::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for conv
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
#endif
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
void Conv2DOpMaker::Make() {
AddInput("Input",
"(Tensor) The input tensor of convolution operator. "
"The format of input tensor is NCHW or NHWC, where N is batch size, "
"C is the "
"number of channels, H is the height of the feature, "
"and W is the width of the feature.");
AddInput("Filter",
"(Tensor) The filter tensor of convolution operator. "
"The format of the filter tensor is MCHW, where M is the number of "
"output image channels, C is the number of input image channels, "
"H is the height of the filter, and W is the width of the filter. "
"If the groups attribute is greater than 1, C equals the number of "
"input image channels divided by the groups.");
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"It has same data fromat and data type as the Input.");
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(vector<int> default:{0, 0}), the "
"paddings(pad_height_top, pad_height_bottom, "
"pad_width_left, pad_wifth_right) of "
"convolution operator.")
.SetDefault({0, 0});
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
AddAttr<int>(
"groups",
"(int default:1), the groups number of the convolution operator. "
"According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
"when group=2, the first half of the filters is only connected to the "
"first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("NCHW");
// TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC(
Convolution Operator.
The convolution operation calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch
size, C is the number of channels, H is the height of the feature, and W is
the width of the feature.
Filters(Input) is MCHW format format. Where M is the number of output image channels, C is
the number of input image channels, H is the height of the filter, and W
is the width of the filter.
Parameters(strides, paddings, dilations) are two elements. These two elements represent
height and width, respectively.
The input(X) size and output(Out) size may be different.
Example:
Input:
Input shape: $(N, C_{in}, H_{in}, W_{in})$
Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
Output:
Output shape: $(N, C_{out}, H_{out}, W_{out})$
Where
$$
H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
$$
)DOC");
Apply();
}
class DepthwiseConv2DOpMaker : public Conv2DOpMaker {
protected:
void Apply() override {
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false)
.AsExtra();
}
};
void Conv3DOpMaker::Make() {
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
"The format of input tensor is NCDHW or NDHWC. Where N is batch size, C "
"is the "
"number of channels, D is the depth of the feature, H is the height of "
"the feature, "
"and W is the width of the feature.");
AddInput("Filter",
"(Tensor) The filter tensor of convolution operator. "
"The format of the filter tensor is MCDHW, where M is the number of "
"output image channels, C is the number of input image channels, "
"D is the depth of the filter, H is the height of the filter, and W "
"is the width of the filter."
"If the groups attribute is greater than 1, C equals the number of "
"input image channels divided by the groups.");
AddOutput("Output",
"(Tensor) The output tensor of convolution operator."
"It has same data fromat and data type as the Input.");
AddAttr<std::vector<int>>("strides",
"(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of "
"convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>, default:{0, 0, 0}), the "
"paddings(pad_depth_front, pad_depth_back, pad_height_top, "
"pad_height_bottom, pad_width_left, pad_width_right) of convolution "
"operator.")
.SetDefault({0, 0, 0});
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
AddAttr<int>(
"groups",
"(int default:1), the groups number of the convolution operator. "
"According to grouped convolution in Alex Krizhevsky's Deep CNN paper: "
"when group=2, the first half of the filters is only connected to the "
"first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1, 1}), the "
"dilations(d_dilation, h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<std::string>(
"data_format",
"(string, default NCDHW) Only used in "
"An optional string from: \"NDHWC\", \"NCDHW\". "
"Defaults to \"NDHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("NCDHW");
AddComment(R"DOC(
Convolution3D Operator.
The convolution operation calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCDHW or NDHWC format, where N is batch
size, C is the number of channels,D is the depth of the feature, H is the height of
the feature, and W is the width of the feature.
Filters(Input) is MCDHW format, where M is the number of output image channels,
C is the number of input image channels, D is the depth of the filter,
H is the height of the filter, and W is the width of the filter.
Parameters(strides, paddings, dilations) are three elements. These three elements
represent depth, height and width, respectively.
The input(X) size and output(Out) size may be different.
Example:
Input:
Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$
Filter shape: $(C_{out}, C_{in}, D_f, H_f, W_f)$
Output:
Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$
Where
$$
D_{out}= \frac{(D_{in} + pad_depth_front + pad_depth_back - (dilations[0] * (D_f - 1) + 1))}{ strides[0]}+ 1 \\
H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[1] * (H_f - 1) + 1))}{ strides[1]}+ 1 \\
W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1
$$
)DOC");
Apply();
}
void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
}
if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
}
}
phi::KernelKey ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return phi::KernelKey(data_type, ctx.GetPlace());
}
phi::KernelKey ConvOpGrad::GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if (((var_name == "Input") ||
(var_name == framework::GradVarName("Output"))) &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = phi::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
#endif
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
template <typename T>
class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
if (this->HasInput("Bias")) {
op->SetInput("Bias", this->Input("Bias"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
}
op->SetAttrMap(this->Attrs());
}
};
template <typename T>
class Conv3DGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
if (this->HasInput("ResidualData")) {
op->SetInput("ResidualData", this->Input("ResidualData"));
}
op->SetAttrMap(this->Attrs());
}
};
/*
* Inputs: I, W, dO, ddI, ddW
* Outputs: ddO, dW, dI
*/
template <typename T>
class Conv2DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
// I, W, dO, ddI, ddW
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
op->SetInput("DDFilter",
this->OutputGrad(framework::GradVarName("Filter")));
// ddO, dI, dW
// Unlike grad op, double grad op does not use name@GRAD@GRAD
// as key of ops' inputs and outputs.
auto ddx = this->OutputGrad(framework::GradVarName("Input"));
auto ddw = this->OutputGrad(framework::GradVarName("Filter"));
op->SetOutput("DDOutput",
ddx.empty()
? this->EmptyInputGrad()
: this->InputGrad(framework::GradVarName("Output")));
op->SetOutput(
"DFilter",
ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Filter"));
op->SetOutput(
"DInput",
ddw.empty() ? this->EmptyInputGrad() : this->InputGrad("Input"));
op->SetAttrMap(this->Attrs());
}
};
/*
* Inputs: I, W, dO, ddI, ddW
* Outputs: ddO, dW, dI
*/
template <typename T>
class Conv3DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
// I, W, dO, ddI, ddW
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
op->SetInput("DDFilter",
this->OutputGrad(framework::GradVarName("Filter")));
auto ddx = this->OutputGrad(framework::GradVarName("Input"));
auto ddw = this->OutputGrad(framework::GradVarName("Filter"));
op->SetOutput("DDOutput",
ddx.empty()
? this->EmptyInputGrad()
: this->InputGrad(framework::GradVarName("Output")));
op->SetOutput(
"DFilter",
ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Filter"));
op->SetOutput(
"DInput",
ddw.empty() ? this->EmptyInputGrad() : this->InputGrad("Input"));
op->SetAttrMap(this->Attrs());
}
};
void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
auto x_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("Filter");
auto do_dims = ctx->GetInputDim("DOutput");
if (ctx->HasOutput("DDOutput") &&
(ctx->HasInput("DDInput") || (ctx->HasInput("DDFilter")))) {
ctx->SetOutputDim("DDOutput", do_dims);
}
if (ctx->HasOutput("DFilter") && ctx->HasInput("DDInput")) {
ctx->SetOutputDim("DFilter", w_dims);
}
if (ctx->HasOutput("DInput") && ctx->HasInput("DDFilter")) {
ctx->SetOutputDim("DInput", x_dims);
}
}
phi::KernelKey ConvOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return phi::KernelKey(data_type, ctx.GetPlace());
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d,
ops::ConvOp,
ops::Conv2DOpMaker,
ops::ConvOpInferVarType,
ops::Conv2DGradMaker<paddle::framework::OpDesc>,
ops::Conv2DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv2d_grad,
ops::ConvOpGrad,
ops::Conv2DDoubleGradMaker<paddle::framework::OpDesc>,
ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad);
// depthwise convolution op
REGISTER_OPERATOR(depthwise_conv2d,
ops::ConvOp,
ops::DepthwiseConv2DOpMaker,
ops::ConvOpInferVarType,
ops::Conv2DGradMaker<paddle::framework::OpDesc>,
ops::Conv2DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_grad,
ops::ConvOpGrad,
ops::Conv2DDoubleGradMaker<paddle::framework::OpDesc>,
ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(depthwise_conv2d_grad_grad, ops::ConvOpDoubleGrad);
REGISTER_OPERATOR(conv3d,
ops::ConvOp,
ops::Conv3DOpMaker,
ops::ConvOpInferVarType,
ops::Conv3DGradMaker<paddle::framework::OpDesc>,
ops::Conv3DGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv3d_grad,
ops::ConvOpGrad,
ops::Conv3DDoubleGradMaker<paddle::framework::OpDesc>,
ops::Conv3DDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(conv3d_grad_grad, ops::ConvOpDoubleGrad);
REGISTER_OP_VERSION(conv2d).AddCheckpoint(
R"ROC(
Upgrade conv2d, add a new attribute [use_addto].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_addto",
"In order to support new feature (inplace addto strategy) for "
"gradient accumulation.",
false));
REGISTER_OP_VERSION(depthwise_conv2d)
.AddCheckpoint(
R"ROC(
Upgrade depthwise_conv2d, add a new attribute [use_addto].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_addto",
"In order to support new feature (inplace addto strategy) for "
"gradient accumulation.",
false));
REGISTER_OP_VERSION(conv3d).AddCheckpoint(
R"ROC(
Upgrade conv3d, add a new attribute [use_addto].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_addto",
"In order to support new feature (inplace addto strategy) for "
"gradient accumulation.",
false));
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/layout_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/im2col.h"
#include "paddle/phi/kernels/funcs/vol2col.h"
namespace paddle {
namespace operators {
// Base convolution operator definitions for other conv
// like operators to reuse the implementation.
inline int ConvOutputSize(
int input_size, int filter_size, int dilation, int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
PADDLE_ENFORCE_GT(
output_size,
0,
platform::errors::InvalidArgument(
"The output's size is expected to be greater than 0. "
"But received: output's size is %d. The output's size is computed by "
"((input_size + 2 * padding - (dilation * (filter_size - 1) + 1)) / "
"stride + 1), where input_size is %d, padding is %d, "
"filter_size is %d, dilation is %d, stride is %d.",
output_size,
input_size,
padding,
filter_size,
dilation,
stride));
return output_size;
}
inline int ConvOutputSize(int input_size,
int filter_size,
int dilation,
int padding_1,
int padding_2,
int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1;
PADDLE_ENFORCE_GT(
output_size,
0,
platform::errors::InvalidArgument(
"The output's size is expected to be greater than 0. "
"But received: output's size is %d. The output's size is computed by "
"((input_size + padding_1 + padding_2 - (dilation * (filter_size - "
"1) + 1)) / stride + 1), where input_size is %d, padding is "
"(%d, %d), filter_size is %d, dilation is %d, stride is %d.",
output_size,
input_size,
padding_1,
padding_2,
filter_size,
dilation,
stride));
return output_size;
}
template <typename T = int>
inline void UpdatePaddingAndDilation(std::vector<T>* paddings,
std::vector<T>* dilation,
const std::string padding_algorithm,
const framework::DDim data_dims,
const std::vector<T>& strides,
const std::vector<T>& ksize) {
// set padding size == data_dims.size() * 2
auto data_shape = phi::vectorize<T>(data_dims);
if (static_cast<int>(paddings->size()) == data_dims.size()) {
for (int i = 0; i < data_dims.size(); ++i) {
T copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
}
} else {
PADDLE_ENFORCE_EQ(
data_dims.size() * 2,
paddings->size(),
platform::errors::InvalidArgument(
"Attribute padding's size should be the same or twice as the "
"input's dimension. "
"But received: padding's size is %d, padding is [%s]; input's "
"dimension is %d, input's shape is [%s].",
paddings->size(),
phi::make_ddim(*paddings),
data_dims.size(),
data_dims));
}
// when padding_algorithm is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (int i = 0; i < data_dims.size(); ++i) {
T out_size = (data_dims[i] + strides[i] - 1) / strides[i];
T pad_sum =
std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i],
static_cast<T>(0));
T pad_0 = pad_sum / 2;
T pad_1 = pad_sum - pad_0;
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilation->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
}
inline bool IsExpand(const std::vector<int64_t>& filter_dim,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations) {
bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
for (size_t j = 0; j < strides.size(); ++j) {
filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
strides_1 = strides_1 && (strides[j] == 1);
padding_0 = padding_0 && (paddings[j] == 0);
dilation_1 = dilation_1 && (dilations[j] == 1);
}
if (paddings.size() != strides.size()) {
for (size_t j = 0; j < paddings.size(); ++j) {
padding_0 = padding_0 && (paddings[j] == 0);
}
}
return !(filter_1 && strides_1 && padding_0 && dilation_1);
}
// Define Op classes in .h file so that other conv
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() final;
protected:
virtual void Apply() {}
};
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() final;
protected:
virtual void Apply() {}
};
class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{
{"Input", /*->*/ "Output"}};
return m;
}
};
class ConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
std::vector<int64_t> output_shape = ComputeOutputShape(ctx);
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Conv");
ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
}
protected:
std::vector<int64_t> ComputeOutputShape(
framework::InferShapeContext* ctx) const;
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override;
};
class ConvOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override;
};
class ConvOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
} // namespace paddle
......@@ -15,13 +15,42 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/generator/get_expected_kernel_func.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace paddle {
namespace operators {
inline int ConvOutputSize(int input_size,
int filter_size,
int dilation,
int padding_1,
int padding_2,
int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
int output_size = (input_size + padding_1 + padding_2 - dkernel) / stride + 1;
PADDLE_ENFORCE_GT(
output_size,
0,
platform::errors::InvalidArgument(
"The output's size is expected to be greater than 0. "
"But received: output's size is %d. The output's size is computed by "
"((input_size + padding_1 + padding_2 - (dilation * (filter_size - "
"1) + 1)) / stride + 1), where input_size is %d, padding is "
"(%d, %d), filter_size is %d, dilation is %d, stride is %d.",
output_size,
input_size,
padding_1,
padding_2,
filter_size,
dilation,
stride));
return output_size;
}
// This fused conv follows the equation:
// y = act ( alpha1 * conv(x) + alpha2 * z + bias ).
// here, y is Output,
......@@ -30,9 +59,36 @@ namespace operators {
// bias is Bias
// When `split_channels` is set, y will be split into multiple outputs,
// each output has split_channels[i] number of channels.
class Conv2DFusionOpMaker : public Conv2DOpMaker {
class Conv2DFusionOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "(Tensor), input 0 of conv2d op.");
AddInput("Filter", "(Tensor), input 1 of conv2d op.");
AddOutput("Output", "(Tensor), output 0 of conv2d op.");
AddAttr<std::vector<int>>("strides",
"(std::vector<int>), attribute 0 for conv2d op.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(std::vector<int>), attribute 1 for conv2d op.")
.SetDefault({0, 0});
AddAttr<std::string>("padding_algorithm",
"(std::string), attribute 2 for conv2d op.")
.SetDefault("EXPLICIT");
AddAttr<std::vector<int>>("dilations",
"(std::vector<int>), attribute 3 for conv2d op.")
.SetDefault({1, 1});
AddAttr<int>("groups", "(int), attribute 4 for conv2d op.").SetDefault(1);
AddAttr<std::string>("data_format",
"(std::string), attribute 5 for conv2d op.")
.SetDefault("NCHW");
AddComment(R"DOC(
TODO: Documentation of conv2d op.
)DOC");
Apply();
}
protected:
void Apply() override {
void Apply() {
AddInput("Bias",
"(Tensor) Bias to be added to each output of filter application."
"The format of output tensor is X (one-dimensional) of size equal"
......@@ -73,9 +129,9 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker {
}
};
class Conv2DFusionOp : public operators::ConvOp {
class Conv2DFusionOp : public framework::OperatorWithKernel {
public:
using operators::ConvOp::ConvOp;
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
......@@ -275,7 +331,7 @@ class Conv2DFusionOp : public operators::ConvOp {
}
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(
phi::UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
......@@ -301,6 +357,11 @@ class Conv2DFusionOp : public operators::ConvOp {
return output_shape;
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetConvExpectedKernelType(ctx, this);
}
};
// TODO(qingqing): add gradient operator for conv2d_fusion
......@@ -313,7 +374,6 @@ REGISTER_OPERATOR(
conv2d_fusion,
ops::Conv2DFusionOp,
ops::Conv2DFusionOpMaker,
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......@@ -323,6 +383,5 @@ REGISTER_OPERATOR(
conv2d_fusion_cutlass,
ops::Conv2DFusionOp,
ops::Conv2DFusionOpMaker,
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......@@ -15,14 +15,44 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/generator/get_expected_kernel_func.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class FusedConvOpMaker : public Conv2DOpMaker {
class FusedConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "(Tensor), input 0 of conv2d op.");
AddInput("Filter", "(Tensor), input 1 of conv2d op.");
AddOutput("Output", "(Tensor), output 0 of conv2d op.");
AddAttr<std::vector<int>>("strides",
"(std::vector<int>), attribute 0 for conv2d op.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"(std::vector<int>), attribute 1 for conv2d op.")
.SetDefault({0, 0});
AddAttr<std::string>("padding_algorithm",
"(std::string), attribute 2 for conv2d op.")
.SetDefault("EXPLICIT");
AddAttr<std::vector<int>>("dilations",
"(std::vector<int>), attribute 3 for conv2d op.")
.SetDefault({1, 1});
AddAttr<int>("groups", "(int), attribute 4 for conv2d op.").SetDefault(1);
AddAttr<std::string>("data_format",
"(std::string), attribute 5 for conv2d op.")
.SetDefault("NCHW");
AddComment(R"DOC(
TODO: Documentation of conv2d op.
)DOC");
Apply();
}
protected:
void Apply() override {
void Apply() {
AddInput("Bias",
"(Tensor) Bias to be added to each output of filter application."
"The format of output tensor is X (one-dimensional) of size equal"
......@@ -84,25 +114,43 @@ $$
}
};
class FusedConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetConvExpectedKernelType(ctx, this);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(fused_conv2d,
FusedConv2DInferShapeFunctor,
PD_INFER_META(phi::FusedConvInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(fused_conv3d,
FusedConv3DInferShapeFunctor,
PD_INFER_META(phi::FusedConvInferMeta));
// fused_conv2d is only used for onednn inference.
REGISTER_OPERATOR(
fused_conv2d,
ops::ConvOp,
ops::FusedConvOp,
ops::FusedConvOpMaker,
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
FusedConv2DInferShapeFunctor);
// fused_conv3d is only used for onednn inference.
REGISTER_OPERATOR(
fused_conv3d,
ops::ConvOp,
ops::FusedConvOp,
ops::FusedConvOpMaker,
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
FusedConv3DInferShapeFunctor);
......@@ -14,7 +14,6 @@
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
......
......@@ -326,5 +326,31 @@ phi::KernelKey GetLayerNormExpectedKernelType(
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
phi::KernelKey GetConvExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
auto input_data_type = op_ptr->IndicateVarDataType(ctx, "Input");
// todo enable data layout when it's ready
// (https://github.com/PaddlePaddle/Paddle/pull/20042)
if (input_data_type != framework::proto::VarType::INT8 &&
input_data_type != framework::proto::VarType::UINT8 &&
input_data_type != framework::proto::VarType::BF16) {
auto filter_data_type = framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("Filter")->dtype());
PADDLE_ENFORCE_EQ(
input_data_type,
filter_data_type,
platform::errors::InvalidArgument(
"input and filter data type should be consistent, "
"but received input data type is %s and filter type "
"is %s",
paddle::framework::DataTypeToString(input_data_type),
paddle::framework::DataTypeToString(filter_data_type)));
}
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
} // namespace operators
} // namespace paddle
......@@ -80,5 +80,9 @@ phi::KernelKey GetLayerNormExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
phi::KernelKey GetConvExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
} // namespace operators
} // namespace paddle
......@@ -324,6 +324,54 @@
output : Tensor(x_grad)
invoke : conj(out_grad)
- backward_op : conv2d_grad
forward : conv2d (Tensor input, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, str padding_algorithm="EXPLICIT", int[] dilations={1, 1}, int groups=1, str data_format="NCHW") -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [input, filter]
kernel :
func : conv2d_grad
data_type : input
backward : conv2d_grad_grad
- backward_op : conv2d_grad_grad
forward : conv2d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param: [input, filter, grad_out]
kernel :
func : conv2d_double_grad
data_type : input
optional : grad_input_grad, grad_filter_grad
- backward_op : conv3d_double_grad
forward : conv3d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param: [input, filter, grad_out]
kernel :
func : conv3d_double_grad
data_type : input
optional : grad_input_grad, grad_filter_grad
- backward_op : conv3d_grad
forward : conv3d (Tensor input, Tensor filter, int[] strides={1, 1, 1}, int[] paddings={0, 0, 0}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1, 1}, str data_format="NCDHW") -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [input, filter]
kernel :
func : conv3d_grad
data_type : input
backward : conv3d_double_grad
- backward_op : cos_double_grad
forward : cos_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_out, Tensor grad_x_grad)
......@@ -415,6 +463,30 @@
kernel :
func : cumprod_grad
- backward_op : depthwise_conv2d_double_grad
forward : depthwise_conv2d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param: [input, filter, grad_out]
kernel :
func : depthwise_conv2d_double_grad
data_type : input
optional : grad_input_grad, grad_filter_grad
- backward_op : depthwise_conv2d_grad
forward : depthwise_conv2d (Tensor input, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [input, filter]
kernel :
func : depthwise_conv2d_grad
data_type : input
backward : depthwise_conv2d_double_grad
- backward_op : det_grad
forward : det (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad)
......
......@@ -137,28 +137,6 @@
no_need_buffer : x
backward : concat_double_grad
- backward_op : conv2d_grad
forward : conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [input, filter]
kernel :
func : conv2d_grad
backward : conv2d_grad_grad
- backward_op : conv2d_grad_grad
forward : conv2d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param: [input, filter, grad_out]
kernel :
func : conv2d_double_grad
optional : grad_input_grad, grad_filter_grad
- backward_op : conv2d_transpose_double_grad
forward : conv2d_transpose_grad(Tensor x, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(grad_x), Tensor(grad_filter)
args : (Tensor x, Tensor filter, Tensor grad_out, Tensor grad_x_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
......@@ -178,28 +156,6 @@
func : conv2d_transpose_grad
backward : conv2d_transpose_double_grad
- backward_op : conv3d_double_grad
forward : conv3d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param: [input, filter, grad_out]
kernel :
func : conv3d_double_grad
optional : grad_input_grad, grad_filter_grad
- backward_op : conv3d_grad
forward : conv3d (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [input, filter]
kernel :
func : conv3d_grad
backward : conv3d_double_grad
- backward_op : conv3d_transpose_grad
forward : conv3d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
......@@ -232,29 +188,6 @@
data_type : x
optional : mask
- backward_op : depthwise_conv2d_double_grad
forward : depthwise_conv2d_grad (Tensor input, Tensor filter, Tensor grad_out, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(grad_input), Tensor(grad_filter)
args : (Tensor input, Tensor filter, Tensor grad_out, Tensor grad_input_grad, Tensor grad_filter_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(input_grad), Tensor(filter_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param: [input, filter, grad_out]
kernel :
func : depthwise_conv2d_double_grad
optional : grad_input_grad, grad_filter_grad
- backward_op : depthwise_conv2d_grad
forward : depthwise_conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(input_grad), Tensor(filter_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [input, filter]
kernel :
func : depthwise_conv2d_grad
param : [input, filter, out_grad, strides, paddings, padding_algorithm, groups, dilations, data_format]
backward : depthwise_conv2d_double_grad
- backward_op : depthwise_conv2d_transpose_grad
forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
......
......@@ -172,15 +172,6 @@
func : concat
backward : concat_grad
- op : conv2d
args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int[] dilations, int groups, str data_format)
output : Tensor
infer_meta :
func : ConvInferMeta
kernel :
func : conv2d
backward : conv2d_grad
- op : conv2d_transpose
args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(out)
......@@ -190,15 +181,6 @@
func : conv2d_transpose
backward : conv2d_transpose_grad
- op : conv3d
args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor
infer_meta :
func : Conv3DInferMeta
kernel :
func : conv3d
backward : conv3d_grad
- op : conv3d_transpose
args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(out)
......@@ -244,17 +226,6 @@
optional : mask
backward : deformable_conv_grad
- op : depthwise_conv2d
args : (Tensor x, Tensor filter, int[] strides, int[] paddings, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(out)
infer_meta :
func : DepthwiseConvInferMeta
param : [x, filter, strides, paddings, padding_algorithm, groups, dilations, data_format]
kernel :
func : depthwise_conv2d
param : [x, filter, strides, paddings, padding_algorithm, groups, dilations, data_format]
backward : depthwise_conv2d_grad
- op : depthwise_conv2d_transpose
args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
output : Tensor(out)
......
......@@ -479,11 +479,17 @@
out : Out
- op : conv2d
backward : conv2d_grad
backward : conv2d_grad, conv2d_grad_grad
inputs :
{input : Input, filter : Filter}
outputs :
out : Output
extra :
attrs : [bool is_test = false, bool use_cudnn = true, bool use_mkldnn = false, bool use_addto = false,
bool force_fp32_output = false,
int workspace_size_MB = phi::backends::gpu::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]
get_expected_kernel_type :
conv2d : GetConvExpectedKernelType
- op : conv2d_fusion
extra :
......@@ -503,12 +509,18 @@
int workspace_size_MB = phi::backends::gpu::GetDefaultConvWorkspaceSizeLimitMB()]
- op : conv3d
backward : conv3d_grad
backward : conv3d_grad, conv3d_double_grad (conv3d_grad_grad)
inputs :
{input : Input, filter : Filter}
outputs :
out : Output
extra :
attrs : [bool is_test = false, bool use_cudnn = true, bool use_mkldnn = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f,
bool use_addto = false, bool fuse_residual_connection = false, bool force_fp32_output = false,
int workspace_size_MB = phi::backends::gpu::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]
get_expected_kernel_type :
conv3d : GetConvExpectedKernelType
- op : conv3d_transpose
backward : conv3d_transpose_grad
......@@ -603,14 +615,20 @@
out : Output
- op : depthwise_conv2d
backward : depthwise_conv2d_grad
backward : depthwise_conv2d_grad, depthwise_conv2d_double_grad (depthwise_conv2d_grad_grad)
inputs :
{input : Input, filter : Filter}
outputs :
out : Output
extra :
attrs : [bool is_test = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
attrs : [bool is_test = false, bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false,
bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false,
str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, bool use_addto = false,
bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f,
float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false,
int workspace_size_MB = phi::backends::gpu::GetDefaultConvWorkspaceSizeLimitMB(), bool exhaustive_search = false]
get_expected_kernel_type :
depthwise_conv2d : GetConvExpectedKernelType
- op : depthwise_conv2d_transpose
backward : depthwise_conv2d_transpose_grad
......
......@@ -90,6 +90,33 @@
of each place to be compatible with before.
default : -1
- op : conv2d
version :
- checkpoint : Upgrade conv2d, add a new attribute [use_addto].
action :
- add_attr : use_addto
comment : In order to support new feature (inplace addto strategy) for
gradient accumulation.
default : "false"
- op : conv3d
version :
- checkpoint : Upgrade conv3d, add a new attribute [use_addto].
action :
- add_attr : use_addto
comment : In order to support new feature (inplace addto strategy) for
gradient accumulation.
default : "false"
- op : depthwise_conv2d
version :
- checkpoint : Upgrade depthwise_conv2d, add a new attribute [use_addto].
action :
- add_attr : use_addto
comment : In order to support new feature (inplace addto strategy) for
gradient accumulation.
default : "false"
- op : embedding
version :
- checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims]
......
......@@ -454,6 +454,24 @@
func : conj
backward : conj_grad
- op : conv2d
args : (Tensor input, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, str padding_algorithm="EXPLICIT", int[] dilations={1, 1}, int groups=1, str data_format="NCHW")
output : Tensor
infer_meta :
func : ConvInferMeta
kernel :
func : conv2d
backward : conv2d_grad
- op : conv3d
args : (Tensor input, Tensor filter, int[] strides={1, 1, 1}, int[] paddings={0, 0, 0}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1, 1}, str data_format="NCDHW")
output : Tensor
infer_meta :
func : Conv3DInferMeta
kernel :
func : conv3d
backward : conv3d_grad
- op : cos
args : (Tensor x)
output : Tensor
......@@ -513,6 +531,15 @@
func : cumprod
backward : cumprod_grad
- op : depthwise_conv2d
args : (Tensor input, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW")
output : Tensor(out)
infer_meta :
func : DepthwiseConvInferMeta
kernel :
func : depthwise_conv2d
backward : depthwise_conv2d_grad
- op : det
args : (Tensor x)
output : Tensor
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
......@@ -3320,6 +3321,34 @@ void FusedAdamInferMeta(
}
}
void FusedConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const MetaTensor& bias,
const MetaTensor& residual_param,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
const std::string& mkldnn_data_type,
const std::string& fuse_activation,
bool fuse_residual_conn,
bool force_fp32_output,
MetaTensor* out,
MetaConfig config) {
ConvInferMeta(input,
filter,
strides,
paddings,
padding_algorithm,
dilations,
groups,
data_format,
out,
config);
}
void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
......
......@@ -619,6 +619,23 @@ void FusedAdamInferMeta(
std::vector<MetaTensor*> beta2_pows_out,
std::vector<MetaTensor*> master_params_out);
void FusedConvInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const MetaTensor& bias,
const MetaTensor& residual_param,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
const std::vector<int>& dilations,
int groups,
const std::string& data_format,
const std::string& mkldnn_data_type,
const std::string& fuse_activation,
bool fuse_residual_conn,
bool force_fp32_output,
MetaTensor* out,
MetaConfig config);
void MoeInferMeta(const MetaTensor& x,
const MetaTensor& gate,
const MetaTensor& bmm0,
......
......@@ -133,6 +133,29 @@ void FusedConv3DKernel(const Context& dev_ctx,
out);
}
KernelKey ConvGetKernelTypeForVar(const GetKernelTypeForVarContext* ctx) {
const std::string& var_name = ctx->GetVarName();
const DenseTensor& tensor = ctx->GetTensor();
const KernelKey& expected_kernel_type = ctx->GetKernelKey();
const AttributeMap& attrs = ctx->GetAttrs();
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto it = attrs.find("data_format");
const std::string data_format = PADDLE_GET_CONST(std::string, it->second);
auto dl = phi::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for conv
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
} // namespace fusion
} // namespace phi
......@@ -143,7 +166,11 @@ PD_REGISTER_KERNEL(fused_conv2d,
float,
phi::dtype::bfloat16,
uint8_t,
int8_t) {}
int8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::fusion::ConvGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(
fused_conv3d, OneDNN, ONEDNN, phi::fusion::FusedConv3DKernel, float) {}
fused_conv3d, OneDNN, ONEDNN, phi::fusion::FusedConv3DKernel, float) {
kernel->get_kerneltype_forvar_fn_ = phi::fusion::ConvGetKernelTypeForVar;
}
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/conv_grad_kernel.h"
#include "paddle/phi/core/compat/get_kerneltype_forvar_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
......@@ -235,6 +236,29 @@ void Conv3DGradKernel(const Context& dev_ctx,
filter_grad);
}
KernelKey ConvGradGetKernelTypeForVar(const GetKernelTypeForVarContext* ctx) {
const std::string& var_name = ctx->GetVarName();
const DenseTensor& tensor = ctx->GetTensor();
const KernelKey& expected_kernel_type = ctx->GetKernelKey();
const AttributeMap& attrs = ctx->GetAttrs();
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if (((var_name == "Input") || (var_name == "Output@GRAD")) &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto it = attrs.find("data_format");
const std::string data_format = PADDLE_GET_CONST(std::string, it->second);
auto dl = phi::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
} // namespace phi
PD_REGISTER_KERNEL(conv2d_grad,
......@@ -242,13 +266,19 @@ PD_REGISTER_KERNEL(conv2d_grad,
ONEDNN,
phi::ConvGradKernel,
float,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->get_kerneltype_forvar_fn_ = phi::ConvGradGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(depthwise_conv2d_grad,
OneDNN,
ONEDNN,
phi::DepthwiseConvGradKernel,
float,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->get_kerneltype_forvar_fn_ = phi::ConvGradGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(conv3d_grad, OneDNN, ONEDNN, phi::Conv3DGradKernel, float) {}
PD_REGISTER_KERNEL(conv3d_grad, OneDNN, ONEDNN, phi::Conv3DGradKernel, float) {
kernel->get_kerneltype_forvar_fn_ = phi::ConvGradGetKernelTypeForVar;
}
......@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/core/compat/get_kerneltype_forvar_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
......@@ -111,6 +112,29 @@ void Conv3DKernel(const Context& dev_ctx,
out);
}
KernelKey ConvGetKernelTypeForVar(const GetKernelTypeForVarContext* ctx) {
const std::string& var_name = ctx->GetVarName();
const DenseTensor& tensor = ctx->GetTensor();
const KernelKey& expected_kernel_type = ctx->GetKernelKey();
const AttributeMap& attrs = ctx->GetAttrs();
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto it = attrs.find("data_format");
const std::string data_format = PADDLE_GET_CONST(std::string, it->second);
auto dl = phi::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for conv
// op. Treat this as NCHW (default data_format value)
if (dl != phi::DataLayout::kAnyLayout) {
return phi::KernelKey(tensor.place(), dl, expected_kernel_type.dtype());
}
}
return phi::KernelKey(
tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
} // namespace phi
PD_REGISTER_KERNEL(conv2d,
......@@ -120,7 +144,9 @@ PD_REGISTER_KERNEL(conv2d,
float,
phi::dtype::bfloat16,
uint8_t,
int8_t) {}
int8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::ConvGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(depthwise_conv2d,
OneDNN,
......@@ -129,6 +155,10 @@ PD_REGISTER_KERNEL(depthwise_conv2d,
float,
phi::dtype::bfloat16,
uint8_t,
int8_t) {}
int8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::ConvGetKernelTypeForVar;
}
PD_REGISTER_KERNEL(conv3d, OneDNN, ONEDNN, phi::Conv3DKernel, float) {}
PD_REGISTER_KERNEL(conv3d, OneDNN, ONEDNN, phi::Conv3DKernel, float) {
kernel->get_kerneltype_forvar_fn_ = phi::ConvGetKernelTypeForVar;
}
......@@ -16,45 +16,6 @@
namespace phi {
KernelSignature Conv2dOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv2d",
{"Input", "Filter"},
{"strides",
"paddings",
"padding_algorithm",
"dilations",
"groups",
"data_format"},
{"Output"});
}
KernelSignature Conv2dGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv2d_grad",
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
"dilations",
"groups",
"data_format"},
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv2dDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv2d_double_grad",
{"Input", "Filter", "DOutput", "DDInput", "DDFilter"},
{"strides",
"paddings",
"padding_algorithm",
"dilations",
"groups",
"data_format"},
{"DInput", "DFilter", "DDOutput"});
}
KernelSignature Conv2dFusionArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv2d_fusion_cutlass",
......@@ -71,9 +32,5 @@ KernelSignature Conv2dFusionArgumentMapping(
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(conv2d, phi::Conv2dOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_fusion_cutlass,
phi::Conv2dFusionArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_grad, phi::Conv2dGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_grad_grad,
phi::Conv2dDoubleGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature Conv3dOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv3d",
{"Input", "Filter"},
{
"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format",
},
{"Output"});
}
KernelSignature Conv3dGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv3d_grad",
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature Conv3dDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("conv3d_double_grad",
{"Input", "Filter", "DOutput", "DDInput", "DDFilter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"DInput", "DFilter", "DDOutput"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(conv3d_grad_grad, conv3d_double_grad);
PD_REGISTER_ARG_MAPPING_FN(conv3d, phi::Conv3dOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv3d_grad, phi::Conv3dGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv3d_grad_grad,
phi::Conv3dDoubleGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DepthwiseConv2dOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("depthwise_conv2d",
{"Input", "Filter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Output"});
}
KernelSignature DepthwiseConv2dGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("depthwise_conv2d_grad",
{"Input", "Filter", "Output@GRAD"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Input@GRAD", "Filter@GRAD"});
}
KernelSignature DepthwiseConv2dDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("depthwise_conv2d_double_grad",
{"Input", "Filter", "DOutput", "DDInput", "DDFilter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"DInput", "DFilter", "DDOutput"});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(depthwise_conv2d_grad_grad,
depthwise_conv2d_double_grad);
PD_REGISTER_ARG_MAPPING_FN(depthwise_conv2d,
phi::DepthwiseConv2dOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(depthwise_conv2d_grad,
phi::DepthwiseConv2dGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(depthwise_conv2d_grad_grad,
phi::DepthwiseConv2dDoubleGradOpArgumentMapping);
......@@ -43,7 +43,7 @@ if(WITH_GPU OR WITH_ROCM)
cc_test(
test_cudnn_norm_conv
SRCS cudnn_norm_conv_test.cc
DEPS conv_op
DEPS generated_op
depthwise_conv
tensor
op_registry
......
......@@ -16,7 +16,6 @@ set(TEST_MKLDNN_CACHING_DEPS
elementwise_mul_op
elementwise_add_op
activation_op
conv_op
phi
scope
device_context
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册