Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1e6137b5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1e6137b5
编写于
7月 07, 2022
作者:
Z
zhangyikun02
提交者:
GitHub
7月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add resnet_basic_block for kunlun, test=kunlun (#43949)
上级
48abaec6
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
607 addition
and
1 deletion
+607
-1
paddle/fluid/operators/fused/CMakeLists.txt
paddle/fluid/operators/fused/CMakeLists.txt
+6
-1
paddle/fluid/operators/fused/resnet_basic_block_op.cc
paddle/fluid/operators/fused/resnet_basic_block_op.cc
+576
-0
paddle/fluid/pybind/op_function_generator.h
paddle/fluid/pybind/op_function_generator.h
+25
-0
未找到文件。
paddle/fluid/operators/fused/CMakeLists.txt
浏览文件 @
1e6137b5
...
...
@@ -26,12 +26,17 @@ register_operators(
fused_bias_dropout_residual_layer_norm_op
resnet_unit_op
fused_gemm_epilogue_op
fused_gate_attention_op
)
fused_gate_attention_op
resnet_basic_block_op
)
# fusion_gru_op does not have CUDA kernel
op_library
(
fusion_gru_op
)
op_library
(
fusion_lstm_op
)
if
(
WITH_XPU
)
op_library
(
resnet_basic_block_op
)
endif
()
if
(
WITH_GPU OR WITH_ROCM
)
# fused_bn_activation_op needs cudnn 7.4.1 above
# HIP not support bn act fuse in MIOPEN
...
...
paddle/fluid/operators/fused/resnet_basic_block_op.cc
0 → 100644
浏览文件 @
1e6137b5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/api/all.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
ResNetBasicBlockOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
// Check input
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Filter1"
),
"Input"
,
"Filter1"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale1"
),
"Input"
,
"Scale1"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias1"
),
"Input"
,
"Bias1"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Mean1"
),
"Input"
,
"Mean1"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Var1"
),
"Input"
,
"Var1"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Filter2"
),
"Input"
,
"Filter2"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale2"
),
"Input"
,
"Scale2"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias2"
),
"Input"
,
"Bias2"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Mean2"
),
"Input"
,
"Mean2"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Var2"
),
"Input"
,
"Var2"
,
"ResNetBasicBlockOp"
);
bool
has_shortcut
=
ctx
->
Attrs
().
Get
<
bool
>
(
"has_shortcut"
);
if
(
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Filter3"
),
"Input"
,
"Filter3"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale3"
),
"Input"
,
"Scale3"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias3"
),
"Input"
,
"Bias3"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Mean3"
),
"Input"
,
"Mean3"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Var3"
),
"Input"
,
"Var3"
,
"ResNetBasicBlockOp"
);
}
// Check output
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Y"
),
"Output"
,
"Y"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Conv1"
),
"Output"
,
"Conv1"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedMean1"
),
"Output"
,
"SavedMean1"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedInvstd1"
),
"Output"
,
"SavedInvstd1"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Mean1Out"
),
"Output"
,
"Mean1Out"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Var1Out"
),
"Output"
,
"Var1Out"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Conv2"
),
"Output"
,
"Conv2"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedMean2"
),
"Output"
,
"SavedMean2"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedInvstd2"
),
"Output"
,
"SavedInvstd2"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Mean2Out"
),
"Output"
,
"Mean2Out"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Var2Out"
),
"Output"
,
"Var2Out"
,
"ResNetBasicBlockOp"
);
if
(
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Conv3"
),
"Output"
,
"Conv3"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedMean3"
),
"Output"
,
"SavedMean3"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedInvstd3"
),
"Output"
,
"SavedInvstd3"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Mean3Out"
),
"Output"
,
"Mean3Out"
,
"ResNetBasicBlockOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Var3Out"
),
"Output"
,
"Var3Out"
,
"ResNetBasicBlockOp"
);
}
// make sure Mean/RunningMean and Var/RunningVar share memory
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"Mean1"
)[
0
],
ctx
->
Outputs
(
"Mean1Out"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"Mean1 and Mean1Out should share the same memory"
));
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"Var1"
)[
0
],
ctx
->
Outputs
(
"Var1Out"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"Var1 and Var1Out should share the same memory"
));
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"Mean2"
)[
0
],
ctx
->
Outputs
(
"Mean2Out"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"Mean2 and Mean2Out should share the same memory"
));
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"Var2"
)[
0
],
ctx
->
Outputs
(
"Var2Out"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"Var2 and Var2Out should share the same memory"
));
if
(
has_shortcut
)
{
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"Mean3"
)[
0
],
ctx
->
Outputs
(
"Mean3Out"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"Mean3 and Mean3Out should share the same memory"
));
PADDLE_ENFORCE_EQ
(
ctx
->
Inputs
(
"Var3"
)[
0
],
ctx
->
Outputs
(
"Var3Out"
)[
0
],
platform
::
errors
::
InvalidArgument
(
"Var3 and Var3Out should share the same memory"
));
}
// Check dims of inputs
auto
data_format
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_EQ
(
data_format
,
"NCHW"
,
platform
::
errors
::
InvalidArgument
(
"The data format must equal to NCHW. "
"But received: the data format "
"= [%s]"
,
data_format
));
int
stride1
=
ctx
->
Attrs
().
Get
<
int
>
(
"stride1"
);
int
stride2
=
ctx
->
Attrs
().
Get
<
int
>
(
"stride2"
);
int
padding1
=
ctx
->
Attrs
().
Get
<
int
>
(
"padding1"
);
int
padding2
=
ctx
->
Attrs
().
Get
<
int
>
(
"padding2"
);
const
auto
x1_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
w1_dims
=
ctx
->
GetInputDim
(
"Filter1"
);
const
auto
bn1_param_dims
=
ctx
->
GetInputDim
(
"Scale1"
);
PADDLE_ENFORCE_EQ
(
x1_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of input "
"must equal to 4."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]"
,
x1_dims
,
x1_dims
.
size
()));
// Calculate the dims of output1
int
batch
=
x1_dims
[
0
];
int
output1_channel
=
w1_dims
[
0
];
int
filter1_size
=
w1_dims
[
2
];
int
out1_h
=
(
x1_dims
[
2
]
+
padding1
*
2
-
filter1_size
)
/
stride1
+
1
;
int
out1_w
=
(
x1_dims
[
3
]
+
padding1
*
2
-
filter1_size
)
/
stride1
+
1
;
std
::
vector
<
int
>
out1_shape
=
{
batch
,
output1_channel
,
out1_h
,
out1_w
};
const
auto
w2_dims
=
ctx
->
GetInputDim
(
"Filter2"
);
const
auto
bn2_param_dims
=
ctx
->
GetInputDim
(
"Scale2"
);
int
output2_channel
=
w2_dims
[
0
];
int
filter2_size
=
w2_dims
[
2
];
int
out2_h
=
(
out1_h
+
padding2
*
2
-
filter2_size
)
/
stride2
+
1
;
int
out2_w
=
(
out1_w
+
padding2
*
2
-
filter2_size
)
/
stride2
+
1
;
std
::
vector
<
int
>
out2_shape
=
{
batch
,
output2_channel
,
out2_h
,
out2_w
};
auto
y_dims
=
phi
::
make_ddim
(
out2_shape
);
auto
conv1_dims
=
phi
::
make_ddim
(
out1_shape
);
ctx
->
SetOutputDim
(
"Y"
,
y_dims
);
ctx
->
SetOutputDim
(
"Conv1"
,
conv1_dims
);
ctx
->
SetOutputDim
(
"SavedMean1"
,
bn1_param_dims
);
ctx
->
SetOutputDim
(
"SavedInvstd1"
,
bn1_param_dims
);
ctx
->
SetOutputDim
(
"Mean1Out"
,
bn1_param_dims
);
ctx
->
SetOutputDim
(
"Var1Out"
,
bn1_param_dims
);
ctx
->
SetOutputDim
(
"Conv2"
,
y_dims
);
ctx
->
SetOutputDim
(
"Conv2Input"
,
conv1_dims
);
ctx
->
SetOutputDim
(
"SavedMean2"
,
bn2_param_dims
);
ctx
->
SetOutputDim
(
"SavedInvstd2"
,
bn2_param_dims
);
ctx
->
SetOutputDim
(
"Mean2Out"
,
bn2_param_dims
);
ctx
->
SetOutputDim
(
"Var2Out"
,
bn2_param_dims
);
if
(
has_shortcut
)
{
ctx
->
SetOutputDim
(
"Conv3"
,
y_dims
);
ctx
->
SetOutputDim
(
"SavedMean3"
,
bn2_param_dims
);
ctx
->
SetOutputDim
(
"SavedInvstd3"
,
bn2_param_dims
);
ctx
->
SetOutputDim
(
"Mean3Out"
,
bn2_param_dims
);
ctx
->
SetOutputDim
(
"Var3Out"
,
bn2_param_dims
);
}
bool
find_max
=
ctx
->
Attrs
().
Get
<
bool
>
(
"find_conv_input_max"
);
if
(
find_max
)
{
auto
max_dims
=
phi
::
make_ddim
({
6
});
ctx
->
SetOutputDim
(
"MaxInput1"
,
max_dims
);
ctx
->
SetOutputDim
(
"MaxFilter1"
,
max_dims
);
ctx
->
SetOutputDim
(
"MaxInput2"
,
max_dims
);
ctx
->
SetOutputDim
(
"MaxFilter2"
,
max_dims
);
if
(
has_shortcut
)
{
ctx
->
SetOutputDim
(
"MaxInput3"
,
max_dims
);
ctx
->
SetOutputDim
(
"MaxFilter3"
,
max_dims
);
}
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
// By default, the type of the scale, bias, mean,
// and var tensors should be float when input tensor's dtype is float16.
auto
bn_param_type
=
framework
::
proto
::
VarType
::
FP32
;
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
TransToProtoVarType
(
ctx
.
Input
<
Tensor
>
(
"Scale1"
)
->
dtype
()),
platform
::
errors
::
InvalidArgument
(
"Scale input should be of float type"
));
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
TransToProtoVarType
(
ctx
.
Input
<
Tensor
>
(
"Bias1"
)
->
dtype
()),
platform
::
errors
::
InvalidArgument
(
"Bias input should be of float type"
));
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
TransToProtoVarType
(
ctx
.
Input
<
Tensor
>
(
"Scale2"
)
->
dtype
()),
platform
::
errors
::
InvalidArgument
(
"Scale input should be of float type"
));
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
TransToProtoVarType
(
ctx
.
Input
<
Tensor
>
(
"Bias2"
)
->
dtype
()),
platform
::
errors
::
InvalidArgument
(
"Bias input should be of float type"
));
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library
);
}
};
class
ResNetBasicBlockOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
// has_shortcut = True: X else: X
// / /
// | | | |
// CONV1 | CONV1 |
// | | | |
// BN1 | BN1 |
// | | | |
// RELU1 | RELU1 |
// | | | |
// CONV2 CONV3 CONV2 |
// | | | |
// BN2 BN3 BN2 |
// \ / \ /
// ADD ADD
// | |
// RELU RELU
// | |
// Y Y
AddInput
(
"X"
,
"Input tensor of conv 1"
);
AddInput
(
"Filter1"
,
"Filter tensor of conv 1"
);
AddInput
(
"Scale1"
,
"Scale tensor of bn 1"
);
AddInput
(
"Bias1"
,
"Bias tensor of bn 1"
);
AddInput
(
"Mean1"
,
"Mean tensor of bn 1"
);
AddInput
(
"Var1"
,
"Variance tensor of bn 1"
);
AddInput
(
"Filter2"
,
"Filter tensor of conv 2"
);
AddInput
(
"Scale2"
,
"Scale tensor of bn 2"
);
AddInput
(
"Bias2"
,
"Bias tensor of bn 2"
);
AddInput
(
"Mean2"
,
"Mean tensor of bn 2"
);
AddInput
(
"Var2"
,
"Variance tensor of bn 2"
);
AddInput
(
"Filter3"
,
"Filter tensor of conv 3"
).
AsDispensable
();
AddInput
(
"Scale3"
,
"Scale tensor of bn 3"
).
AsDispensable
();
AddInput
(
"Bias3"
,
"Bias tensor of bn 3"
).
AsDispensable
();
AddInput
(
"Mean3"
,
"Mean tensor of bn 3"
).
AsDispensable
();
AddInput
(
"Var3"
,
"Variance tensor of bn 3"
).
AsDispensable
();
AddOutput
(
"Y"
,
"The result of ssd resnet unit"
);
AddOutput
(
"Conv1"
,
"The result of conv 1"
);
AddOutput
(
"SavedMean1"
,
"Mean of input 1 after conv 1"
);
AddOutput
(
"SavedInvstd1"
,
"Invstd of input 1 after conv 1"
);
AddOutput
(
"Mean1Out"
,
"Shared memory with Mean1"
);
AddOutput
(
"Var1Out"
,
"Shared memory with Var1"
);
AddOutput
(
"Conv2"
,
"The result of conv 2"
);
AddOutput
(
"Conv2Input"
,
"Conv2 input data"
);
AddOutput
(
"SavedMean2"
,
"Mean of input 2 after conv 2"
);
AddOutput
(
"SavedInvstd2"
,
"Invstd of input 2 after conv 2"
);
AddOutput
(
"Mean2Out"
,
"Shared memory with Mean2"
);
AddOutput
(
"Var2Out"
,
"Shared memory with Var2"
);
AddOutput
(
"Conv3"
,
"The result of conv 3"
).
AsDispensable
();
AddOutput
(
"SavedMean3"
,
"Mean of input 3 after conv 3"
).
AsDispensable
();
AddOutput
(
"SavedInvstd3"
,
"Invstd of input 3 after conv 3"
).
AsDispensable
();
AddOutput
(
"Mean3Out"
,
"Shared memory with Mean3"
).
AsDispensable
();
AddOutput
(
"Var3Out"
,
"Shared memory with Var3"
).
AsDispensable
();
AddOutput
(
"MaxInput1"
,
"The max value of conv1 input tensor"
)
.
AsDispensable
();
AddOutput
(
"MaxFilter1"
,
"The max value of conv1 filter tensor"
)
.
AsDispensable
();
AddOutput
(
"MaxInput2"
,
"The max value of conv2 input tensor"
)
.
AsDispensable
();
AddOutput
(
"MaxFilter2"
,
"The max value of conv2 filter tensor"
)
.
AsDispensable
();
AddOutput
(
"MaxInput3"
,
"The max value of conv3 input tensor"
)
.
AsDispensable
();
AddOutput
(
"MaxFilter3"
,
"The max value of conv3 filter tensor"
)
.
AsDispensable
();
AddAttr
<
int
>
(
"stride1"
,
"Stride of conv1"
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"stride2"
,
"Stride of conv2"
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"stride3"
,
"Stride of conv3"
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"padding1"
,
"Padding of conv1"
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"padding2"
,
"Padding of conv2"
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"padding3"
,
"Padding of conv3"
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"dilation1"
,
"Dilation of conv1"
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"dilation2"
,
"Dilation of conv2"
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"dilation3"
,
"Dilation of conv3"
).
SetDefault
(
1
);
AddAttr
<
int
>
(
"group"
,
"Group of all the 3 conv"
).
SetDefault
(
1
);
AddAttr
<
float
>
(
"momentum"
,
"Momentum of all the 3 bn"
).
SetDefault
(
0.9
);
AddAttr
<
float
>
(
"epsilon"
,
"Epsilon of all the 3 bn"
).
SetDefault
(
1e-5
);
AddAttr
<
std
::
string
>
(
"data_format"
,
""
).
SetDefault
(
"NCHW"
);
AddAttr
<
bool
>
(
"has_shortcut"
,
""
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"use_global_stats"
,
""
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
"is_test"
,
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"trainable_statistics"
,
"(bool, default false) Whether to calculate mean and variance "
"in test mode. If setting true in test mode, mean and variace "
"will be calculated by current batch statistics."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"act_type"
,
"The activation type to be fused."
)
.
SetDefault
(
"relu"
);
AddAttr
<
bool
>
(
"find_conv_input_max"
,
"(bool, default true) Whether to calculate max value of conv "
"input tensor."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
Fusion op of the basic unit of ssd resnet block.
** This is only use for XPU, if has problems, concat zhangyikun02@baidu.com **
)DOC"
);
}
};
template
<
typename
T
>
class
ResNetBasicBlockGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"resnet_basic_block_grad"
);
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"Filter1"
,
this
->
Input
(
"Filter1"
));
op
->
SetInput
(
"Conv1"
,
this
->
Output
(
"Conv1"
));
op
->
SetInput
(
"Scale1"
,
this
->
Input
(
"Scale1"
));
op
->
SetInput
(
"Bias1"
,
this
->
Input
(
"Bias1"
));
op
->
SetInput
(
"SavedMean1"
,
this
->
Output
(
"SavedMean1"
));
op
->
SetInput
(
"SavedInvstd1"
,
this
->
Output
(
"SavedInvstd1"
));
op
->
SetInput
(
"Filter2"
,
this
->
Input
(
"Filter2"
));
op
->
SetInput
(
"Conv2"
,
this
->
Output
(
"Conv2"
));
op
->
SetInput
(
"Conv2Input"
,
this
->
Output
(
"Conv2Input"
));
op
->
SetInput
(
"Scale2"
,
this
->
Input
(
"Scale2"
));
op
->
SetInput
(
"Bias2"
,
this
->
Input
(
"Bias2"
));
op
->
SetInput
(
"SavedMean2"
,
this
->
Output
(
"SavedMean2"
));
op
->
SetInput
(
"SavedInvstd2"
,
this
->
Output
(
"SavedInvstd2"
));
op
->
SetInput
(
"Filter3"
,
this
->
Input
(
"Filter3"
));
op
->
SetInput
(
"Conv3"
,
this
->
Output
(
"Conv3"
));
op
->
SetInput
(
"Scale3"
,
this
->
Input
(
"Scale3"
));
op
->
SetInput
(
"Bias3"
,
this
->
Input
(
"Bias3"
));
op
->
SetInput
(
"SavedMean3"
,
this
->
Output
(
"SavedMean3"
));
op
->
SetInput
(
"SavedInvstd3"
,
this
->
Output
(
"SavedInvstd3"
));
op
->
SetInput
(
"MaxInput1"
,
this
->
Output
(
"MaxInput1"
));
op
->
SetInput
(
"MaxFilter1"
,
this
->
Output
(
"MaxFilter1"
));
op
->
SetInput
(
"MaxInput2"
,
this
->
Output
(
"MaxInput2"
));
op
->
SetInput
(
"MaxFilter2"
,
this
->
Output
(
"MaxFilter2"
));
op
->
SetInput
(
"MaxInput3"
,
this
->
Output
(
"MaxInput3"
));
op
->
SetInput
(
"MaxFilter3"
,
this
->
Output
(
"MaxFilter3"
));
op
->
SetInput
(
"Y"
,
this
->
Output
(
"Y"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Y"
),
this
->
OutputGrad
(
"Y"
));
op
->
SetAttrMap
(
this
->
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Filter1"
),
this
->
InputGrad
(
"Filter1"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Scale1"
),
this
->
InputGrad
(
"Scale1"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Bias1"
),
this
->
InputGrad
(
"Bias1"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Filter2"
),
this
->
InputGrad
(
"Filter2"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Scale2"
),
this
->
InputGrad
(
"Scale2"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Bias2"
),
this
->
InputGrad
(
"Bias2"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Filter3"
),
this
->
InputGrad
(
"Filter3"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Scale3"
),
this
->
InputGrad
(
"Scale3"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Bias3"
),
this
->
InputGrad
(
"Bias3"
));
}
};
class
ResNetBasicBlockOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
GetInputOutputWithSameType
()
const
override
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
};
class
ResNetBasicBlockGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
// check input
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Filter1"
),
"Input"
,
"Filter1"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Conv1"
),
"Input"
,
"Conv1"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale1"
),
"Input"
,
"Scale1"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias1"
),
"Input"
,
"Bias1"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedMean1"
),
"Input"
,
"SavedMean1"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedInvstd1"
),
"Input"
,
"SavedInvstd1"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Filter2"
),
"Input"
,
"Filter2"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Conv2"
),
"Input"
,
"Conv2"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale2"
),
"Input"
,
"Scale2"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias2"
),
"Input"
,
"Bias2"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedMean2"
),
"Input"
,
"SavedMean2"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedInvstd2"
),
"Input"
,
"SavedInvstd2"
,
"ResNetBasicBlockGradOp"
);
bool
has_shortcut
=
ctx
->
Attrs
().
Get
<
bool
>
(
"has_shortcut"
);
if
(
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Filter3"
),
"Input"
,
"Filter3"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale3"
),
"Input"
,
"Scale3"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias3"
),
"Input"
,
"Bias3"
,
"ResNetBasicBlockGradOp"
);
}
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Y"
),
"Input"
,
"Y"
,
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input"
,
framework
::
GradVarName
(
"Y"
),
"ResNetBasicBlockGradOp"
);
// check output
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Filter1"
)),
"Output"
,
framework
::
GradVarName
(
"Filter1"
),
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Scale1"
)),
"Output"
,
framework
::
GradVarName
(
"Scale1"
),
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias1"
)),
"Output"
,
framework
::
GradVarName
(
"Bias1"
),
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Filter2"
)),
"Output"
,
framework
::
GradVarName
(
"Filter2"
),
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Scale2"
)),
"Output"
,
framework
::
GradVarName
(
"Scale2"
),
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias2"
)),
"Output"
,
framework
::
GradVarName
(
"Bias2"
),
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output"
,
framework
::
GradVarName
(
"X"
),
"ResNetBasicBlockGradOp"
);
if
(
has_shortcut
)
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Filter3"
)),
"Output"
,
framework
::
GradVarName
(
"Filter3"
),
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Scale3"
)),
"Output"
,
framework
::
GradVarName
(
"Scale3"
),
"ResNetBasicBlockGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias3"
)),
"Output"
,
framework
::
GradVarName
(
"Bias3"
),
"ResNetBasicBlockGradOp"
);
}
const
auto
x1_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
filter1_x_dims
=
ctx
->
GetInputDim
(
"Filter1"
);
const
auto
param1_dims
=
ctx
->
GetInputDim
(
"Scale1"
);
const
auto
filter2_x_dims
=
ctx
->
GetInputDim
(
"Filter2"
);
const
auto
param2_dims
=
ctx
->
GetInputDim
(
"Scale2"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x1_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Filter1"
),
filter1_x_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Scale1"
),
param1_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias1"
),
param1_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Filter2"
),
filter2_x_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Scale2"
),
param2_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias2"
),
param2_dims
);
if
(
has_shortcut
)
{
const
auto
filter_z_dims
=
ctx
->
GetInputDim
(
"Filter3"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Filter3"
),
filter_z_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Scale3"
),
param2_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias3"
),
param2_dims
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
platform
::
errors
::
NotFound
(
"Can not find Y@GRAD in the execution context."
));
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
(),
layout
,
library
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
resnet_basic_block
,
ops
::
ResNetBasicBlockOp
,
ops
::
ResNetBasicBlockOpMaker
,
ops
::
ResNetBasicBlockOpInferVarType
,
ops
::
ResNetBasicBlockGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
ResNetBasicBlockGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
resnet_basic_block_grad
,
ops
::
ResNetBasicBlockGradOp
);
paddle/fluid/pybind/op_function_generator.h
浏览文件 @
1e6137b5
...
...
@@ -208,6 +208,23 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{
"trilinear_interp"
,
{
"X"
,
"OutSize"
}},
{
"nearest_interp"
,
{
"X"
,
"OutSize"
}},
{
"bicubic_interp"
,
{
"X"
,
"OutSize"
}},
{
"resnet_basic_block"
,
{
"X"
,
"Filter1"
,
"Scale1"
,
"Bias1"
,
"Mean1"
,
"Var1"
,
"Filter2"
,
"Scale2"
,
"Bias2"
,
"Mean2"
,
"Var2"
,
"Filter3"
,
"Scale3"
,
"Bias3"
,
"Mean3"
,
"Var3"
}},
};
// NOTE(zhiqiu): Like op_ins_map.
...
...
@@ -309,6 +326,12 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"Beta2PowOut"
,
"MasterParamOut"
}},
{
"fused_multi_transformer"
,
{
"CacheKVOut"
,
"Out"
}},
{
"resnet_basic_block"
,
{
"Y"
,
"Conv1"
,
"SavedMean1"
,
"SavedInvstd1"
,
"Mean1Out"
,
"Var1Out"
,
"Conv2"
,
"SavedMean2"
,
"SavedInvstd2"
,
"Mean2Out"
,
"Var2Out"
,
"Conv3"
,
"SavedMean3"
,
"SavedInvstd3"
,
"Mean3Out"
,
"Var3Out"
,
"MaxInput1"
,
"MaxFilter1"
,
"MaxInput2"
,
"MaxFilter2"
,
"MaxInput3"
,
"MaxFilter3"
}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
...
...
@@ -408,6 +431,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{
"concat"
,
{
"Out"
}},
{
"fused_multi_transformer"
,
{
"CacheKVOut"
}},
{
"group_norm"
,
{
"Mean"
,
"Variance"
}},
{
"resnet_basic_block"
,
{
"Mean1Out"
,
"Var1Out"
,
"Mean2Out"
,
"Var2Out"
,
"Mean3Out"
,
"Var3Out"
}},
};
// NOTE(pangyoki): Tensor View Strategy.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录