Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d7be46b3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d7be46b3
编写于
7月 08, 2022
作者:
Z
zhangyikun02
提交者:
GitHub
7月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add implement of resnet_basic_block op for XPU2, test=kunlun (#44143)
上级
337bb47b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
997 addition
and
18 deletion
+997
-18
paddle/fluid/operators/fused/resnet_basic_block_op.cc
paddle/fluid/operators/fused/resnet_basic_block_op.cc
+19
-18
paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc
paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc
+970
-0
paddle/fluid/platform/device/xpu/xpu2_op_list.h
paddle/fluid/platform/device/xpu/xpu2_op_list.h
+8
-0
未找到文件。
paddle/fluid/operators/fused/resnet_basic_block_op.cc
浏览文件 @
d7be46b3
...
...
@@ -258,24 +258,25 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel {
class
ResNetBasicBlockOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
// has_shortcut = True: X else: X
// / /
// | | | |
// CONV1 | CONV1 |
// | | | |
// BN1 | BN1 |
// | | | |
// RELU1 | RELU1 |
// | | | |
// CONV2 CONV3 CONV2 |
// | | | |
// BN2 BN3 BN2 |
// \ / \ /
// ADD ADD
// | |
// RELU RELU
// | |
// Y Y
// has_shortcut = True: else:
// X X
// / /
// | | | |
// CONV1 | CONV1 |
// | | | |
// BN1 | BN1 |
// | | | |
// RELU1 | RELU1 |
// | | | |
// CONV2 CONV3 CONV2 |
// | | | |
// BN2 BN3 BN2 |
// \ / \ /
// ADD ADD
// | |
// RELU RELU
// | |
// Y Y
AddInput
(
"X"
,
"Input tensor of conv 1"
);
AddInput
(
"Filter1"
,
"Filter tensor of conv 1"
);
AddInput
(
"Scale1"
,
"Scale tensor of bn 1"
);
...
...
paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc
0 → 100644
浏览文件 @
d7be46b3
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/phi/api/all.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
ResnetBasicBlockAttr
{
public:
explicit
ResnetBasicBlockAttr
(
const
framework
::
ExecutionContext
&
ctx
)
{
padding1
=
ctx
.
Attr
<
int
>
(
"padding1"
);
padding2
=
ctx
.
Attr
<
int
>
(
"padding2"
);
padding3
=
ctx
.
Attr
<
int
>
(
"padding3"
);
stride1
=
ctx
.
Attr
<
int
>
(
"stride1"
);
stride2
=
ctx
.
Attr
<
int
>
(
"stride2"
);
stride3
=
ctx
.
Attr
<
int
>
(
"stride3"
);
dilation1
=
ctx
.
Attr
<
int
>
(
"dilation1"
);
dilation2
=
ctx
.
Attr
<
int
>
(
"dilation2"
);
dilation3
=
ctx
.
Attr
<
int
>
(
"dilation3"
);
group
=
ctx
.
Attr
<
int
>
(
"group"
);
eps
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
momentum
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"momentum"
));
has_shortcut
=
ctx
.
Attr
<
bool
>
(
"has_shortcut"
);
find_max
=
ctx
.
Attr
<
bool
>
(
"find_conv_input_max"
);
const
auto
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
auto
use_global_stats
=
ctx
.
Attr
<
bool
>
(
"use_global_stats"
);
const
auto
trainable_stats
=
ctx
.
Attr
<
bool
>
(
"trainable_statistics"
);
bool
test_mode
=
is_test
&&
(
!
trainable_stats
);
global_stats
=
test_mode
||
use_global_stats
;
// init shape
auto
input1
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
filter1
=
ctx
.
Input
<
Tensor
>
(
"Filter1"
);
auto
conv1_out
=
ctx
.
Output
<
Tensor
>
(
"Conv1"
);
auto
filter2
=
ctx
.
Input
<
Tensor
>
(
"Filter2"
);
auto
conv2_out
=
ctx
.
Output
<
Tensor
>
(
"Conv2"
);
conv1_input_shape
=
phi
::
vectorize
<
int
>
(
input1
->
dims
());
conv1_output_shape
=
phi
::
vectorize
<
int
>
(
conv1_out
->
dims
());
conv1_filter_shape
=
phi
::
vectorize
<
int
>
(
filter1
->
dims
());
conv1_filter_numel
=
filter1
->
numel
();
conv1_input_numel
=
input1
->
numel
();
conv1_output_numel
=
conv1_out
->
numel
();
conv2_input_shape
=
phi
::
vectorize
<
int
>
(
conv1_out
->
dims
());
conv2_output_shape
=
phi
::
vectorize
<
int
>
(
conv2_out
->
dims
());
conv2_filter_shape
=
phi
::
vectorize
<
int
>
(
filter2
->
dims
());
conv2_filter_numel
=
filter2
->
numel
();
conv2_input_numel
=
conv1_out
->
numel
();
conv2_output_numel
=
conv2_out
->
numel
();
if
(
has_shortcut
)
{
auto
filter3
=
ctx
.
Input
<
Tensor
>
(
"Filter3"
);
auto
conv3_out
=
ctx
.
Output
<
Tensor
>
(
"Conv3"
);
conv3_input_shape
=
phi
::
vectorize
<
int
>
(
input1
->
dims
());
conv3_output_shape
=
phi
::
vectorize
<
int
>
(
conv3_out
->
dims
());
conv3_filter_shape
=
phi
::
vectorize
<
int
>
(
filter3
->
dims
());
conv3_filter_numel
=
filter3
->
numel
();
conv3_input_numel
=
input1
->
numel
();
conv3_output_numel
=
conv3_out
->
numel
();
}
}
int
padding1
;
int
padding2
;
int
padding3
;
int
stride1
;
int
stride2
;
int
stride3
;
int
dilation1
;
int
dilation2
;
int
dilation3
;
int
group
;
double
eps
;
double
momentum
;
bool
has_shortcut
;
bool
find_max
;
bool
global_stats
;
std
::
vector
<
int
>
conv1_input_shape
;
std
::
vector
<
int
>
conv1_output_shape
;
std
::
vector
<
int
>
conv1_filter_shape
;
std
::
vector
<
int
>
conv2_input_shape
;
std
::
vector
<
int
>
conv2_output_shape
;
std
::
vector
<
int
>
conv2_filter_shape
;
std
::
vector
<
int
>
conv3_input_shape
;
std
::
vector
<
int
>
conv3_output_shape
;
std
::
vector
<
int
>
conv3_filter_shape
;
int
conv1_filter_numel
;
int
conv2_filter_numel
;
int
conv3_filter_numel
;
int
conv1_input_numel
;
int
conv2_input_numel
;
int
conv3_input_numel
;
int
conv1_output_numel
;
int
conv2_output_numel
;
int
conv3_output_numel
;
};
class
ResnetBasicBlockGradAttr
{
public:
explicit
ResnetBasicBlockGradAttr
(
const
framework
::
ExecutionContext
&
ctx
)
{
padding1
=
ctx
.
Attr
<
int
>
(
"padding1"
);
padding2
=
ctx
.
Attr
<
int
>
(
"padding2"
);
padding3
=
ctx
.
Attr
<
int
>
(
"padding3"
);
stride1
=
ctx
.
Attr
<
int
>
(
"stride1"
);
stride2
=
ctx
.
Attr
<
int
>
(
"stride2"
);
stride3
=
ctx
.
Attr
<
int
>
(
"stride3"
);
dilation1
=
ctx
.
Attr
<
int
>
(
"dilation1"
);
dilation2
=
ctx
.
Attr
<
int
>
(
"dilation2"
);
dilation3
=
ctx
.
Attr
<
int
>
(
"dilation3"
);
group
=
ctx
.
Attr
<
int
>
(
"group"
);
has_shortcut
=
ctx
.
Attr
<
bool
>
(
"has_shortcut"
);
find_max
=
ctx
.
Attr
<
bool
>
(
"find_conv_input_max"
);
// init shape
auto
input1
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
filter1
=
ctx
.
Input
<
Tensor
>
(
"Filter1"
);
auto
conv1_out
=
ctx
.
Input
<
Tensor
>
(
"Conv1"
);
auto
filter2
=
ctx
.
Input
<
Tensor
>
(
"Filter2"
);
auto
conv2_out
=
ctx
.
Input
<
Tensor
>
(
"Conv2"
);
conv1_input_shape
=
phi
::
vectorize
<
int
>
(
input1
->
dims
());
conv1_output_shape
=
phi
::
vectorize
<
int
>
(
conv1_out
->
dims
());
conv1_filter_shape
=
phi
::
vectorize
<
int
>
(
filter1
->
dims
());
conv1_filter_numel
=
filter1
->
numel
();
conv1_input_numel
=
input1
->
numel
();
conv1_output_numel
=
conv1_out
->
numel
();
conv2_input_shape
=
phi
::
vectorize
<
int
>
(
conv1_out
->
dims
());
conv2_output_shape
=
phi
::
vectorize
<
int
>
(
conv2_out
->
dims
());
conv2_filter_shape
=
phi
::
vectorize
<
int
>
(
filter2
->
dims
());
conv2_filter_numel
=
filter2
->
numel
();
conv2_input_numel
=
conv1_out
->
numel
();
conv2_output_numel
=
conv2_out
->
numel
();
if
(
has_shortcut
)
{
auto
filter3
=
ctx
.
Input
<
Tensor
>
(
"Filter3"
);
auto
conv3_out
=
ctx
.
Input
<
Tensor
>
(
"Conv3"
);
conv3_input_shape
=
phi
::
vectorize
<
int
>
(
input1
->
dims
());
conv3_output_shape
=
phi
::
vectorize
<
int
>
(
conv3_out
->
dims
());
conv3_filter_shape
=
phi
::
vectorize
<
int
>
(
filter3
->
dims
());
conv3_filter_numel
=
filter3
->
numel
();
conv3_input_numel
=
input1
->
numel
();
conv3_output_numel
=
conv3_out
->
numel
();
}
}
int
padding1
;
int
padding2
;
int
padding3
;
int
stride1
;
int
stride2
;
int
stride3
;
int
dilation1
;
int
dilation2
;
int
dilation3
;
int
group
;
bool
has_shortcut
;
bool
find_max
;
std
::
vector
<
int
>
conv1_input_shape
;
std
::
vector
<
int
>
conv1_output_shape
;
std
::
vector
<
int
>
conv1_filter_shape
;
std
::
vector
<
int
>
conv2_input_shape
;
std
::
vector
<
int
>
conv2_output_shape
;
std
::
vector
<
int
>
conv2_filter_shape
;
std
::
vector
<
int
>
conv3_input_shape
;
std
::
vector
<
int
>
conv3_output_shape
;
std
::
vector
<
int
>
conv3_filter_shape
;
int
conv1_filter_numel
;
int
conv2_filter_numel
;
int
conv3_filter_numel
;
int
conv1_input_numel
;
int
conv2_input_numel
;
int
conv3_input_numel
;
int
conv1_output_numel
;
int
conv2_output_numel
;
int
conv3_output_numel
;
};
template
<
typename
T
>
static
inline
void
xpu_conv2d
(
xpu
::
Context
*
ctx
,
const
T
*
input_data
,
const
T
*
filter_data
,
T
*
output_data
,
float
*
input_max_data
,
float
*
filter_max_data
,
const
std
::
vector
<
int
>&
input_shape
,
const
std
::
vector
<
int
>&
filter_shape
,
int
padding
,
int
stride
,
int
dilation
,
int
group
)
{
std
::
vector
<
int
>
ksize
{
filter_shape
[
2
],
filter_shape
[
3
]};
std
::
vector
<
int
>
stride_vec
{
stride
,
stride
};
std
::
vector
<
int
>
dilation_vec
{
dilation
,
dilation
};
std
::
vector
<
int
>
padding_vec
{
padding
,
padding
};
int
N
=
input_shape
[
0
];
int
C
=
input_shape
[
1
];
int
H
=
input_shape
[
2
];
int
W
=
input_shape
[
3
];
int
r
=
xpu
::
conv2d
<
T
,
T
,
T
,
int16_t
>
(
ctx
,
input_data
,
filter_data
,
output_data
,
N
,
C
,
H
,
W
,
filter_shape
[
0
],
ksize
,
stride_vec
,
padding_vec
,
dilation_vec
,
group
,
input_max_data
,
filter_max_data
,
nullptr
,
true
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"conv2d"
);
}
template
<
typename
T
>
static
inline
void
xpu_conv2d_grad
(
xpu
::
Context
*
ctx
,
const
T
*
input_data
,
const
T
*
filter_data
,
const
T
*
output_grad_data
,
T
*
input_grad_data
,
T
*
filter_grad_data
,
const
float
*
input_max_data
,
const
float
*
filter_max_data
,
const
std
::
vector
<
int
>&
input_shape
,
const
std
::
vector
<
int
>&
filter_shape
,
int
padding
,
int
stride
,
int
dilation
,
int
group
)
{
std
::
vector
<
int
>
ksize
{
filter_shape
[
2
],
filter_shape
[
3
]};
std
::
vector
<
int
>
stride_vec
{
stride
,
stride
};
std
::
vector
<
int
>
dilation_vec
{
dilation
,
dilation
};
std
::
vector
<
int
>
padding_vec
{
padding
,
padding
};
int
N
=
input_shape
[
0
];
int
C
=
input_shape
[
1
];
int
H
=
input_shape
[
2
];
int
W
=
input_shape
[
3
];
int
r
=
xpu
::
conv2d_grad
<
T
,
T
,
T
,
int16_t
>
(
ctx
,
input_data
,
filter_data
,
output_grad_data
,
input_grad_data
,
filter_grad_data
,
N
,
C
,
H
,
W
,
filter_shape
[
0
],
ksize
,
stride_vec
,
padding_vec
,
dilation_vec
,
group
,
input_max_data
,
filter_max_data
,
nullptr
,
nullptr
,
nullptr
,
true
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"conv2d_grad"
);
}
template
<
typename
T
>
class
ResNetBasicBlockXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
using
XPUT
=
typename
XPUTypeTrait
<
T
>::
Type
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It must use XPUPlace."
));
// input
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
filter1
=
ctx
.
Input
<
Tensor
>
(
"Filter1"
);
const
Tensor
*
scale1
=
ctx
.
Input
<
Tensor
>
(
"Scale1"
);
const
Tensor
*
bias1
=
ctx
.
Input
<
Tensor
>
(
"Bias1"
);
const
Tensor
*
filter2
=
ctx
.
Input
<
Tensor
>
(
"Filter2"
);
const
Tensor
*
scale2
=
ctx
.
Input
<
Tensor
>
(
"Scale2"
);
const
Tensor
*
bias2
=
ctx
.
Input
<
Tensor
>
(
"Bias2"
);
// output
Tensor
*
conv1_output
=
ctx
.
Output
<
Tensor
>
(
"Conv1"
);
Tensor
*
conv2_output
=
ctx
.
Output
<
Tensor
>
(
"Conv2"
);
Tensor
*
conv2_input
=
ctx
.
Output
<
Tensor
>
(
"Conv2Input"
);
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
place
=
ctx
.
GetPlace
();
auto
x_data
=
reinterpret_cast
<
const
XPUT
*>
(
x
->
data
<
T
>
());
auto
conv1_filter_data
=
reinterpret_cast
<
const
XPUT
*>
(
filter1
->
data
<
T
>
());
auto
conv2_filter_data
=
reinterpret_cast
<
const
XPUT
*>
(
filter2
->
data
<
T
>
());
auto
conv1_output_data
=
reinterpret_cast
<
XPUT
*>
(
conv1_output
->
mutable_data
<
T
>
(
place
));
auto
conv2_input_data
=
reinterpret_cast
<
XPUT
*>
(
conv2_input
->
mutable_data
<
T
>
(
place
));
auto
conv2_output_data
=
reinterpret_cast
<
XPUT
*>
(
conv2_output
->
mutable_data
<
T
>
(
place
));
auto
scale1_data
=
scale1
->
data
<
float
>
();
auto
scale2_data
=
scale2
->
data
<
float
>
();
auto
bias1_data
=
bias1
->
data
<
float
>
();
auto
bias2_data
=
bias2
->
data
<
float
>
();
auto
output_data
=
reinterpret_cast
<
XPUT
*>
(
output
->
mutable_data
<
T
>
(
place
));
float
*
conv1_input_max_data
=
nullptr
;
float
*
conv1_filter_max_data
=
nullptr
;
float
*
conv2_input_max_data
=
nullptr
;
float
*
conv2_filter_max_data
=
nullptr
;
float
*
conv3_input_max_data
=
nullptr
;
float
*
conv3_filter_max_data
=
nullptr
;
ResnetBasicBlockAttr
attr
(
ctx
);
// init find max
if
(
attr
.
find_max
)
{
Tensor
*
max_input1
=
ctx
.
Output
<
Tensor
>
(
"MaxInput1"
);
Tensor
*
max_filter1
=
ctx
.
Output
<
Tensor
>
(
"MaxFilter1"
);
conv1_input_max_data
=
max_input1
->
mutable_data
<
float
>
(
place
);
conv1_filter_max_data
=
max_filter1
->
mutable_data
<
float
>
(
place
);
Tensor
*
max_input2
=
ctx
.
Output
<
Tensor
>
(
"MaxInput2"
);
Tensor
*
max_filter2
=
ctx
.
Output
<
Tensor
>
(
"MaxFilter2"
);
conv2_input_max_data
=
max_input2
->
mutable_data
<
float
>
(
place
);
conv2_filter_max_data
=
max_filter2
->
mutable_data
<
float
>
(
place
);
if
(
attr
.
has_shortcut
)
{
Tensor
*
max_input3
=
ctx
.
Output
<
Tensor
>
(
"MaxInput3"
);
Tensor
*
max_filter3
=
ctx
.
Output
<
Tensor
>
(
"MaxFilter3"
);
conv3_input_max_data
=
max_input3
->
mutable_data
<
float
>
(
place
);
conv3_filter_max_data
=
max_filter3
->
mutable_data
<
float
>
(
place
);
}
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
XPUDeviceContext
>();
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
int
r
=
XPU_SUCCESS
;
// 1. short
const
XPUT
*
z_out_data
=
nullptr
;
if
(
attr
.
has_shortcut
)
{
Tensor
*
conv3_out
=
ctx
.
Output
<
Tensor
>
(
"Conv3"
);
const
Tensor
*
filter3
=
ctx
.
Input
<
Tensor
>
(
"Filter3"
);
auto
conv3_filter_data
=
reinterpret_cast
<
const
XPUT
*>
(
filter3
->
data
<
T
>
());
auto
conv3_output_data
=
reinterpret_cast
<
XPUT
*>
(
conv3_out
->
mutable_data
<
T
>
(
place
));
XPUT
*
conv3_input_l3_data
=
nullptr
;
XPUT
*
conv3_filter_l3_data
=
RAII_GUARD
.
alloc_l3
<
XPUT
>
(
attr
.
conv3_filter_numel
);
if
(
attr
.
find_max
)
{
r
=
xpu
::
findmax_copy_fusion
(
dev_ctx
.
x_context
(),
x_data
,
conv3_input_max_data
,
conv3_input_l3_data
,
attr
.
conv3_input_numel
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"findmax_copy_fusion"
);
r
=
xpu
::
findmax_copy_fusion
(
dev_ctx
.
x_context
(),
conv3_filter_data
,
conv3_filter_max_data
,
conv3_filter_l3_data
,
attr
.
conv3_filter_numel
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"findmax_copy_fusion"
);
}
xpu_conv2d
(
dev_ctx
.
x_context
(),
conv3_input_l3_data
!=
nullptr
?
conv3_input_l3_data
:
x_data
,
conv3_filter_l3_data
,
conv3_output_data
,
conv3_input_max_data
,
conv3_filter_max_data
,
attr
.
conv3_input_shape
,
attr
.
conv3_filter_shape
,
attr
.
padding3
,
attr
.
stride3
,
attr
.
dilation3
,
attr
.
group
);
// bn3
const
Tensor
*
scale3
=
ctx
.
Input
<
Tensor
>
(
"Scale3"
);
const
Tensor
*
bias3
=
ctx
.
Input
<
Tensor
>
(
"Bias3"
);
auto
bias3_data
=
bias3
->
data
<
float
>
();
auto
scale3_data
=
scale3
->
data
<
float
>
();
auto
bn3_output_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv3_output_numel
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
bn3_output_data
);
if
(
!
attr
.
global_stats
)
{
Tensor
*
saved_mean3
=
ctx
.
Output
<
Tensor
>
(
"SavedMean3"
);
Tensor
*
saved_invstd3
=
ctx
.
Output
<
Tensor
>
(
"SavedInvstd3"
);
Tensor
*
running_mean3
=
ctx
.
Output
<
Tensor
>
(
"Mean3Out"
);
Tensor
*
running_var3
=
ctx
.
Output
<
Tensor
>
(
"Var3Out"
);
auto
saved_mean3_data
=
saved_mean3
->
mutable_data
<
float
>
(
place
);
auto
saved_invstd3_data
=
saved_invstd3
->
mutable_data
<
float
>
(
place
);
auto
running_mean3_data
=
running_mean3
->
mutable_data
<
float
>
(
place
);
auto
running_var3_data
=
running_var3
->
mutable_data
<
float
>
(
place
);
r
=
xpu
::
batch_norm_fusion
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv3_output_data
,
bn3_output_data
,
attr
.
conv3_output_shape
[
0
],
attr
.
conv3_output_shape
[
1
],
attr
.
conv3_output_shape
[
3
],
attr
.
conv3_output_shape
[
3
],
attr
.
eps
,
attr
.
momentum
,
scale3_data
,
bias3_data
,
saved_mean3_data
,
saved_invstd3_data
,
running_mean3_data
,
running_var3_data
,
true
,
nullptr
,
xpu
::
Activation_t
::
LINEAR
,
nullptr
,
0
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_fusion"
);
}
else
{
const
auto
*
mean3
=
ctx
.
Input
<
Tensor
>
(
"Mean3"
);
const
auto
*
var3
=
ctx
.
Input
<
Tensor
>
(
"Var3"
);
const
auto
*
mean3_data
=
mean3
->
data
<
float
>
();
const
auto
*
variance3_data
=
var3
->
data
<
float
>
();
r
=
xpu
::
batch_norm_infer
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv3_output_data
,
bn3_output_data
,
attr
.
conv3_output_shape
[
0
],
attr
.
conv3_output_shape
[
1
],
attr
.
conv3_output_shape
[
2
],
attr
.
conv3_output_shape
[
3
],
attr
.
eps
,
scale3_data
,
bias3_data
,
mean3_data
,
variance3_data
,
true
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_infer"
);
}
z_out_data
=
reinterpret_cast
<
const
XPUT
*>
(
bn3_output_data
);
}
else
{
z_out_data
=
x_data
;
}
// 2. conv1
XPUT
*
conv1_input_l3_data
=
nullptr
;
XPUT
*
conv1_filter_l3_data
=
RAII_GUARD
.
alloc_l3
<
XPUT
>
(
attr
.
conv1_filter_numel
);
if
(
attr
.
find_max
)
{
r
=
xpu
::
findmax_copy_fusion
(
dev_ctx
.
x_context
(),
x_data
,
conv1_input_max_data
,
conv1_input_l3_data
,
attr
.
conv1_input_numel
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"findmax_copy_fusion"
);
r
=
xpu
::
findmax_copy_fusion
(
dev_ctx
.
x_context
(),
conv1_filter_data
,
conv1_filter_max_data
,
conv1_filter_l3_data
,
attr
.
conv1_filter_numel
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"findmax_copy_fusion"
);
}
xpu_conv2d
(
dev_ctx
.
x_context
(),
conv1_input_l3_data
!=
nullptr
?
conv1_input_l3_data
:
x_data
,
conv1_filter_l3_data
,
conv1_output_data
,
conv1_input_max_data
,
conv1_filter_max_data
,
attr
.
conv1_input_shape
,
attr
.
conv1_filter_shape
,
attr
.
padding1
,
attr
.
stride1
,
attr
.
dilation1
,
attr
.
group
);
// 3. bn1 + relu
if
(
!
attr
.
global_stats
)
{
Tensor
*
saved_mean1
=
ctx
.
Output
<
Tensor
>
(
"SavedMean1"
);
Tensor
*
saved_invstd1
=
ctx
.
Output
<
Tensor
>
(
"SavedInvstd1"
);
Tensor
*
running_mean1
=
ctx
.
Output
<
Tensor
>
(
"Mean1Out"
);
Tensor
*
running_var1
=
ctx
.
Output
<
Tensor
>
(
"Var1Out"
);
auto
saved_mean1_data
=
saved_mean1
->
mutable_data
<
float
>
(
place
);
auto
saved_invstd1_data
=
saved_invstd1
->
mutable_data
<
float
>
(
place
);
auto
running_mean1_data
=
running_mean1
->
mutable_data
<
float
>
(
place
);
auto
running_var1_data
=
running_var1
->
mutable_data
<
float
>
(
place
);
r
=
xpu
::
batch_norm_fusion
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv1_output_data
,
conv2_input_data
,
attr
.
conv1_output_shape
[
0
],
attr
.
conv1_output_shape
[
1
],
attr
.
conv1_output_shape
[
2
],
attr
.
conv1_output_shape
[
3
],
attr
.
eps
,
attr
.
momentum
,
scale1_data
,
bias1_data
,
saved_mean1_data
,
saved_invstd1_data
,
running_mean1_data
,
running_var1_data
,
true
,
nullptr
,
xpu
::
Activation_t
::
RELU
,
nullptr
,
0
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_fusion"
);
}
else
{
// bn --> relu
auto
bn1_output_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv1_output_numel
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
bn1_output_data
);
const
auto
*
mean1
=
ctx
.
Input
<
Tensor
>
(
"Mean1"
);
const
auto
*
var1
=
ctx
.
Input
<
Tensor
>
(
"Var1"
);
const
auto
*
mean_data
=
mean1
->
data
<
float
>
();
const
auto
*
variance_data
=
var1
->
data
<
float
>
();
r
=
xpu
::
batch_norm_infer
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv1_output_data
,
bn1_output_data
,
attr
.
conv1_output_shape
[
0
],
attr
.
conv1_output_shape
[
1
],
attr
.
conv1_output_shape
[
2
],
attr
.
conv1_output_shape
[
3
],
attr
.
eps
,
scale1_data
,
bias1_data
,
mean_data
,
variance_data
,
true
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_infer"
);
r
=
xpu
::
relu
(
dev_ctx
.
x_context
(),
bn1_output_data
,
conv2_input_data
,
attr
.
conv1_output_numel
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"relu"
);
}
// 4. conv2
XPUT
*
conv2_input_l3_data
=
nullptr
;
XPUT
*
conv2_filter_l3_data
=
RAII_GUARD
.
alloc_l3
<
XPUT
>
(
attr
.
conv2_filter_numel
);
if
(
attr
.
find_max
)
{
Tensor
*
max_input2
=
ctx
.
Output
<
Tensor
>
(
"MaxInput2"
);
Tensor
*
max_filter2
=
ctx
.
Output
<
Tensor
>
(
"MaxFilter2"
);
conv2_input_max_data
=
max_input2
->
mutable_data
<
float
>
(
place
);
conv2_filter_max_data
=
max_filter2
->
mutable_data
<
float
>
(
place
);
r
=
xpu
::
findmax_copy_fusion
(
dev_ctx
.
x_context
(),
conv2_input_data
,
conv2_input_max_data
,
conv2_input_l3_data
,
attr
.
conv2_input_numel
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"findmax_copy_fusion"
);
r
=
xpu
::
findmax_copy_fusion
(
dev_ctx
.
x_context
(),
conv2_filter_data
,
conv2_filter_max_data
,
conv2_filter_l3_data
,
attr
.
conv2_filter_numel
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"findmax_copy_fusion"
);
}
xpu_conv2d
(
dev_ctx
.
x_context
(),
conv2_input_l3_data
!=
nullptr
?
conv2_input_l3_data
:
conv2_input_data
,
conv2_filter_l3_data
,
conv2_output_data
,
conv2_input_max_data
,
conv2_filter_max_data
,
attr
.
conv2_input_shape
,
attr
.
conv2_filter_shape
,
attr
.
padding2
,
attr
.
stride2
,
attr
.
dilation2
,
attr
.
group
);
// 5. bn2
if
(
!
attr
.
global_stats
)
{
Tensor
*
saved_mean2
=
ctx
.
Output
<
Tensor
>
(
"SavedMean2"
);
Tensor
*
saved_var2
=
ctx
.
Output
<
Tensor
>
(
"SavedInvstd2"
);
Tensor
*
running_mean2
=
ctx
.
Output
<
Tensor
>
(
"Mean2Out"
);
Tensor
*
running_var2
=
ctx
.
Output
<
Tensor
>
(
"Var2Out"
);
auto
saved_mean2_data
=
saved_mean2
->
mutable_data
<
float
>
(
place
);
auto
saved_var2_data
=
saved_var2
->
mutable_data
<
float
>
(
place
);
auto
running_mean2_data
=
running_mean2
->
mutable_data
<
float
>
(
place
);
auto
running_var2_data
=
running_var2
->
mutable_data
<
float
>
(
place
);
r
=
xpu
::
batch_norm_fusion
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv2_output_data
,
output_data
,
attr
.
conv2_output_shape
[
0
],
attr
.
conv2_output_shape
[
1
],
attr
.
conv2_output_shape
[
2
],
attr
.
conv2_output_shape
[
3
],
attr
.
eps
,
attr
.
momentum
,
scale2_data
,
bias2_data
,
saved_mean2_data
,
saved_var2_data
,
running_mean2_data
,
running_var2_data
,
true
,
z_out_data
,
xpu
::
Activation_t
::
RELU
,
nullptr
,
0
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_fusion"
);
}
else
{
auto
bn2_out_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv2_output_numel
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
bn2_out_data
);
const
auto
*
mean2
=
ctx
.
Input
<
Tensor
>
(
"Mean2"
);
const
auto
*
var2
=
ctx
.
Input
<
Tensor
>
(
"Var2"
);
const
auto
*
mean_data
=
mean2
->
data
<
float
>
();
const
auto
*
variance_data
=
var2
->
data
<
float
>
();
r
=
xpu
::
batch_norm_infer
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv2_output_data
,
bn2_out_data
,
attr
.
conv2_output_shape
[
0
],
attr
.
conv2_output_shape
[
1
],
attr
.
conv2_output_shape
[
2
],
attr
.
conv2_output_shape
[
3
],
attr
.
eps
,
scale2_data
,
bias2_data
,
mean_data
,
variance_data
,
true
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_infer"
);
r
=
xpu
::
add_activation_fusion
<
XPUT
>
(
dev_ctx
.
x_context
(),
bn2_out_data
,
z_out_data
,
output_data
,
output
->
numel
(),
nullptr
,
nullptr
,
nullptr
,
xpu
::
Activation_t
::
RELU
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"add_activation_fusion"
);
}
}
};
template
<
typename
T
>
class
ResNetBasicBlockGradXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
using
XPUT
=
typename
XPUTypeTrait
<
T
>::
Type
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It must use XPUPlace."
));
const
Tensor
*
y_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
const
Tensor
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
filter1
=
ctx
.
Input
<
Tensor
>
(
"Filter1"
);
const
Tensor
*
scale1
=
ctx
.
Input
<
Tensor
>
(
"Scale1"
);
const
Tensor
*
filter2
=
ctx
.
Input
<
Tensor
>
(
"Filter2"
);
const
Tensor
*
scale2
=
ctx
.
Input
<
Tensor
>
(
"Scale2"
);
const
Tensor
*
saved_mean1
=
ctx
.
Input
<
Tensor
>
(
"SavedMean1"
);
const
Tensor
*
saved_invstd1
=
ctx
.
Input
<
Tensor
>
(
"SavedInvstd1"
);
const
Tensor
*
saved_mean2
=
ctx
.
Input
<
Tensor
>
(
"SavedMean2"
);
const
Tensor
*
saved_invstd2
=
ctx
.
Input
<
Tensor
>
(
"SavedInvstd2"
);
const
Tensor
*
conv1_out
=
ctx
.
Input
<
Tensor
>
(
"Conv1"
);
const
Tensor
*
conv2_out
=
ctx
.
Input
<
Tensor
>
(
"Conv2"
);
const
Tensor
*
conv2_input
=
ctx
.
Input
<
Tensor
>
(
"Conv2Input"
);
const
Tensor
*
filter3
=
ctx
.
Input
<
Tensor
>
(
"Filter3"
);
const
Tensor
*
conv3_out
=
ctx
.
Input
<
Tensor
>
(
"Conv3"
);
const
Tensor
*
scale3
=
ctx
.
Input
<
Tensor
>
(
"Scale3"
);
const
Tensor
*
saved_mean3
=
ctx
.
Input
<
Tensor
>
(
"SavedMean3"
);
const
Tensor
*
saved_invstd3
=
ctx
.
Input
<
Tensor
>
(
"SavedInvstd3"
);
const
Tensor
*
conv1_input_max
=
ctx
.
Input
<
Tensor
>
(
"MaxInput1"
);
const
Tensor
*
conv1_filter_max
=
ctx
.
Input
<
Tensor
>
(
"MaxFilter1"
);
const
Tensor
*
conv2_input_max
=
ctx
.
Input
<
Tensor
>
(
"MaxInput2"
);
const
Tensor
*
conv2_filter_max
=
ctx
.
Input
<
Tensor
>
(
"MaxFilter2"
);
const
Tensor
*
conv3_input_max
=
ctx
.
Input
<
Tensor
>
(
"MaxInput3"
);
const
Tensor
*
conv3_filter_max
=
ctx
.
Input
<
Tensor
>
(
"MaxFilter3"
);
Tensor
*
x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
Tensor
*
filter1_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Filter1"
));
Tensor
*
scale1_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale1"
));
Tensor
*
bias1_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias1"
));
Tensor
*
filter2_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Filter2"
));
Tensor
*
scale2_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale2"
));
Tensor
*
bias2_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias2"
));
Tensor
*
filter3_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Filter3"
));
Tensor
*
scale3_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale3"
));
Tensor
*
bias3_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias3"
));
// attrs
ResnetBasicBlockGradAttr
attr
(
ctx
);
auto
place
=
ctx
.
GetPlace
();
const
auto
*
y_grad_data
=
reinterpret_cast
<
const
XPUT
*>
(
y_grad
->
data
<
T
>
());
const
auto
*
y_data
=
reinterpret_cast
<
const
XPUT
*>
(
y
->
data
<
T
>
());
const
auto
*
x_data
=
reinterpret_cast
<
const
XPUT
*>
(
x
->
data
<
T
>
());
const
auto
*
conv1_output_data
=
reinterpret_cast
<
const
XPUT
*>
(
conv1_out
->
data
<
T
>
());
const
auto
*
conv1_filter_data
=
reinterpret_cast
<
const
XPUT
*>
(
filter1
->
data
<
T
>
());
const
auto
*
conv2_input_data
=
reinterpret_cast
<
const
XPUT
*>
(
conv2_input
->
data
<
T
>
());
const
auto
*
conv2_output_data
=
reinterpret_cast
<
const
XPUT
*>
(
conv2_out
->
data
<
T
>
());
const
auto
*
conv2_filter_data
=
reinterpret_cast
<
const
XPUT
*>
(
filter2
->
data
<
T
>
());
const
auto
*
scale2_data
=
scale2
->
data
<
float
>
();
const
auto
*
saved_mean2_data
=
saved_mean2
->
data
<
float
>
();
const
auto
*
saved_invstd2_data
=
saved_invstd2
->
data
<
float
>
();
const
auto
*
scale1_data
=
scale1
->
data
<
float
>
();
const
auto
*
saved_mean1_data
=
saved_mean1
->
data
<
float
>
();
const
auto
*
saved_invstd1_data
=
saved_invstd1
->
data
<
float
>
();
auto
*
scale2_grad_data
=
scale2_grad
->
mutable_data
<
float
>
(
place
);
auto
*
bias2_grad_data
=
bias2_grad
->
mutable_data
<
float
>
(
place
);
const
float
*
conv1_input_max_data
=
nullptr
;
const
float
*
conv1_filter_max_data
=
nullptr
;
const
float
*
conv2_input_max_data
=
nullptr
;
const
float
*
conv2_filter_max_data
=
nullptr
;
const
float
*
conv3_input_max_data
=
nullptr
;
const
float
*
conv3_filter_max_data
=
nullptr
;
if
(
attr
.
find_max
)
{
conv1_input_max_data
=
reinterpret_cast
<
const
float
*>
(
conv1_input_max
->
data
<
float
>
());
conv1_filter_max_data
=
reinterpret_cast
<
const
float
*>
(
conv1_filter_max
->
data
<
float
>
());
conv2_input_max_data
=
reinterpret_cast
<
const
float
*>
(
conv2_input_max
->
data
<
float
>
());
conv2_filter_max_data
=
reinterpret_cast
<
const
float
*>
(
conv2_filter_max
->
data
<
float
>
());
if
(
attr
.
has_shortcut
)
{
conv3_input_max_data
=
reinterpret_cast
<
const
float
*>
(
conv3_input_max
->
data
<
float
>
());
conv3_filter_max_data
=
reinterpret_cast
<
const
float
*>
(
conv3_filter_max
->
data
<
float
>
());
}
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
XPUDeviceContext
>();
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
int
r
=
XPU_SUCCESS
;
// 0. bn2, bn2_fusion grad
auto
conv2_output_grad_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv2_output_numel
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
conv2_output_grad_data
);
XPUT
*
z_output_grad_data
=
nullptr
;
XPUT
*
z_grad_data
=
nullptr
;
if
(
!
attr
.
has_shortcut
)
{
z_output_grad_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv1_input_numel
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
z_output_grad_data
);
z_grad_data
=
z_output_grad_data
;
}
else
{
z_output_grad_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv3_output_numel
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
z_output_grad_data
);
z_grad_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv1_input_numel
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
z_grad_data
);
}
r
=
xpu
::
batch_norm_grad_fusion
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv2_output_data
,
y_data
,
y_grad_data
,
conv2_output_grad_data
,
attr
.
conv2_output_shape
[
0
],
attr
.
conv2_output_shape
[
1
],
attr
.
conv2_output_shape
[
2
],
attr
.
conv2_output_shape
[
3
],
scale2_data
,
saved_mean2_data
,
saved_invstd2_data
,
scale2_grad_data
,
bias2_grad_data
,
true
,
z_output_grad_data
,
xpu
::
Activation_t
::
RELU
,
nullptr
,
0
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_grad_fusion"
);
if
(
attr
.
has_shortcut
)
{
// bn3 grad
const
auto
*
conv3_output_data
=
reinterpret_cast
<
const
XPUT
*>
(
conv3_out
->
data
<
T
>
());
const
auto
*
scale3_data
=
scale3
->
data
<
float
>
();
const
auto
*
saved_mean3_data
=
saved_mean3
->
data
<
float
>
();
const
auto
*
saved_invstd3_data
=
saved_invstd3
->
data
<
float
>
();
auto
*
scale3_grad_data
=
scale3_grad
->
mutable_data
<
float
>
(
place
);
auto
*
bias3_grad_data
=
bias3_grad
->
mutable_data
<
float
>
(
place
);
auto
*
conv3_output_grad_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv3_output_numel
);
r
=
xpu
::
batch_norm_grad
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv3_output_data
,
z_output_grad_data
,
conv3_output_grad_data
,
attr
.
conv3_output_shape
[
0
],
attr
.
conv3_output_shape
[
1
],
attr
.
conv3_output_shape
[
2
],
attr
.
conv3_output_shape
[
3
],
scale3_data
,
saved_mean3_data
,
saved_invstd3_data
,
scale3_grad_data
,
bias3_grad_data
,
true
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_grad"
);
// conv3 grad
auto
*
conv3_filter_grad_data
=
reinterpret_cast
<
XPUT
*>
(
filter3_grad
->
mutable_data
<
T
>
(
place
));
auto
*
conv3_filter_data
=
reinterpret_cast
<
const
XPUT
*>
(
filter3
->
data
<
T
>
());
xpu_conv2d_grad
(
dev_ctx
.
x_context
(),
x_data
,
conv3_filter_data
,
conv3_output_grad_data
,
z_grad_data
,
conv3_filter_grad_data
,
conv3_input_max_data
,
conv3_filter_max_data
,
attr
.
conv3_input_shape
,
attr
.
conv3_filter_shape
,
attr
.
padding3
,
attr
.
stride3
,
attr
.
dilation3
,
attr
.
group
);
}
// 2. conv2_grad
auto
*
conv2_filter_grad_data
=
reinterpret_cast
<
XPUT
*>
(
filter2_grad
->
mutable_data
<
T
>
(
place
));
auto
*
conv2_input_grad_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv2_input_numel
);
xpu_conv2d_grad
(
dev_ctx
.
x_context
(),
conv2_input_data
,
conv2_filter_data
,
conv2_output_grad_data
,
conv2_input_grad_data
,
conv2_filter_grad_data
,
conv2_input_max_data
,
conv2_filter_max_data
,
attr
.
conv2_input_shape
,
attr
.
conv2_filter_shape
,
attr
.
padding2
,
attr
.
stride2
,
attr
.
dilation2
,
attr
.
group
);
// 3. b1 grad
auto
*
conv1_output_grad_data
=
RAII_GUARD
.
alloc
<
XPUT
>
(
attr
.
conv1_output_numel
);
PADDLE_ENFORCE_XDNN_NOT_NULL
(
conv1_output_grad_data
);
auto
*
scale1_grad_data
=
scale1_grad
->
mutable_data
<
float
>
(
ctx
.
GetPlace
());
auto
*
bias1_grad_data
=
bias1_grad
->
mutable_data
<
float
>
(
ctx
.
GetPlace
());
r
=
xpu
::
batch_norm_grad_fusion
<
XPUT
>
(
dev_ctx
.
x_context
(),
conv1_output_data
,
conv2_input_data
,
conv2_input_grad_data
,
conv1_output_grad_data
,
attr
.
conv1_output_shape
[
0
],
attr
.
conv1_output_shape
[
1
],
attr
.
conv1_output_shape
[
2
],
attr
.
conv1_output_shape
[
3
],
scale1_data
,
saved_mean1_data
,
saved_invstd1_data
,
scale1_grad_data
,
bias1_grad_data
,
true
,
nullptr
,
xpu
::
Activation_t
::
RELU
,
nullptr
,
0
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"batch_norm_grad_fusion"
);
// 4. conv1_grad
auto
*
x_grad_data
=
reinterpret_cast
<
XPUT
*>
(
x_grad
->
mutable_data
<
T
>
(
place
));
auto
*
conv1_filter_grad_data
=
reinterpret_cast
<
XPUT
*>
(
filter1_grad
->
mutable_data
<
T
>
(
place
));
xpu_conv2d_grad
(
dev_ctx
.
x_context
(),
x_data
,
conv1_filter_data
,
conv1_output_grad_data
,
x_grad_data
,
conv1_filter_grad_data
,
conv1_input_max_data
,
conv1_filter_max_data
,
attr
.
conv1_input_shape
,
attr
.
conv1_filter_shape
,
attr
.
padding1
,
attr
.
stride1
,
attr
.
dilation1
,
attr
.
group
);
// add z_grad to x_grad
r
=
xpu
::
add
<
XPUT
>
(
dev_ctx
.
x_context
(),
x_grad_data
,
z_grad_data
,
x_grad_data
,
x
->
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"add"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_XPU_KERNEL
(
resnet_basic_block
,
ops
::
ResNetBasicBlockXPUKernel
<
float
>
,
ops
::
ResNetBasicBlockXPUKernel
<
paddle
::
platform
::
float16
>
);
REGISTER_OP_XPU_KERNEL
(
resnet_basic_block_grad
,
ops
::
ResNetBasicBlockGradXPUKernel
<
float
>
,
ops
::
ResNetBasicBlockGradXPUKernel
<
paddle
::
platform
::
float16
>
);
#endif
paddle/fluid/platform/device/xpu/xpu2_op_list.h
浏览文件 @
d7be46b3
...
...
@@ -505,6 +505,14 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"sequence_conv_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
// Fused op
{
"resnet_basic_block_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
FP16
,
XPUPlace
())})},
{
"resnet_basic_block"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
FP16
,
XPUPlace
())})},
};
return
s_xpu2_kernels
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录