PaddlePaddle / Paddle
Commit b2b78cd4 (unverified)
Authored on May 26, 2022 by YuanRisheng; committed via GitHub on May 26, 2022.
Parent: 6af32a7f

move instance_norm_double_grad (#43021)
Showing 15 changed files with 832 additions and 834 deletions (+832, −834).
paddle/fluid/operators/instance_norm_op.cc            +18  -349
paddle/fluid/operators/instance_norm_op.cu            +0   -434
paddle/fluid/operators/instance_norm_op.h             +0   -35
paddle/phi/infermeta/backward.cc                      +57  -0
paddle/phi/infermeta/backward.h                       +24  -0
paddle/phi/infermeta/ternary.cc                       +105 -0
paddle/phi/infermeta/ternary.h                        +9   -0
paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc   +202 -0
paddle/phi/kernels/funcs/norm_utils.h                 +46  -0
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu      +2   -2
paddle/phi/kernels/gpu/batch_norm_kernel.cu           +2   -2
paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu   +332 -9
paddle/phi/kernels/gpu/instance_norm_kernel.cu        +2   -3
paddle/phi/kernels/instance_norm_grad_kernel.h        +15  -0
paddle/phi/ops/compat/instance_norm_sig.cc            +18  -0
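The overall direction of the change: the double-grad computation that previously lived in the fluid operator files is re-homed under paddle/phi, and shape inference moves from the operators' InferShape methods into phi InferMeta functions. As a rough sketch of the registration shift this commit performs (the kernel names and types are taken from the diff below; the condensed layout here is illustrative only):

// Before: kernel registered through the fluid Op framework (removed below).
REGISTER_OP_CPU_KERNEL(
    instance_norm_grad_grad,
    ops::InstanceNormDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                      float>);

// After: a free function template registered in the phi kernel registry
// (added in paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc below).
PD_REGISTER_KERNEL(instance_norm_double_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::InstanceNormDoubleGradKernel,
                   float,
                   double) {}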
paddle/fluid/operators/instance_norm_op.cc

@@ -17,93 +17,16 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 
 #include "paddle/fluid/framework/data_layout.h"
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/ternary.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
 namespace operators {
 
-void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const {
-  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNorm");
-  OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "InstanceNorm");
-  OP_INOUT_CHECK(ctx->HasOutput("SavedMean"), "Output", "SavedMean",
-                 "InstanceNorm");
-  OP_INOUT_CHECK(ctx->HasOutput("SavedVariance"), "Output", "SavedVariance",
-                 "InstanceNorm");
-
-  const auto x_dims = ctx->GetInputDim("X");
-  PADDLE_ENFORCE_NE(phi::product(x_dims), 0,
-                    platform::errors::PreconditionNotMet(
-                        "The Input variable X(%s) has not "
-                        "been initialized. You may need to confirm "
-                        "if you put exe.run(startup_program) "
-                        "after optimizer.minimize function.",
-                        ctx->Inputs("X").front()));
-  PADDLE_ENFORCE_GE(
-      x_dims.size(), 2,
-      platform::errors::InvalidArgument(
-          "ShapeError: the dimension of input X must "
-          "greater than or equal to 2. But received: the shape of input "
-          "X = [%s], the dimension of input X =[%d]",
-          x_dims, x_dims.size()));
-  PADDLE_ENFORCE_LE(
-      x_dims.size(), 5,
-      platform::errors::InvalidArgument(
-          "ShapeError: the dimension of input X must "
-          "smaller than or equal to 5, But received: the shape of input "
-          "X = [%s], the dimension of input X = [%d]",
-          x_dims, x_dims.size()));
-  auto N = x_dims[0];
-  auto C = x_dims[1];
-  auto NxC = N * C;
-  if (ctx->HasInput("Scale")) {
-    auto scale_dim = ctx->GetInputDim("Scale");
-    PADDLE_ENFORCE_EQ(
-        scale_dim.size(), 1UL,
-        platform::errors::InvalidArgument(
-            "ShapeError: the dimension of scale must equal to 1."
-            "But received: the shape of scale is [%s], the dimension "
-            "of scale is [%d]",
-            scale_dim, scale_dim.size()));
-    bool check = !((!ctx->IsRuntime()) && (phi::product(scale_dim) <= 0));
-    if (check) {
-      PADDLE_ENFORCE_EQ(scale_dim[0], C,
-                        platform::errors::InvalidArgument(
-                            "ShapeError: the shape of scale must equal to [%d]"
-                            "But received: the shape of scale is [%d]",
-                            C, scale_dim[0]));
-    }
-  }
-  if (ctx->HasInput("Bias")) {
-    auto bias_dim = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE_EQ(
-        bias_dim.size(), 1UL,
-        platform::errors::InvalidArgument(
-            "ShapeError: the dimension of bias must equal to 1."
-            "But received: the shape of bias is [%s],the dimension "
-            "of bias is [%d]",
-            bias_dim, bias_dim.size()));
-    bool check = !((!ctx->IsRuntime()) && (phi::product(bias_dim) <= 0));
-    if (check) {
-      PADDLE_ENFORCE_EQ(bias_dim[0], C,
-                        platform::errors::InvalidArgument(
-                            "ShapeError: the shape of bias must equal to [%d]"
-                            "But received: the shape of bias is [%d]",
-                            C, bias_dim[0]));
-    }
-  }
-  ctx->SetOutputDim("Y", x_dims);
-  ctx->SetOutputDim("SavedMean", {NxC});
-  ctx->SetOutputDim("SavedVariance", {NxC});
-  ctx->ShareLoD("X", "Y");
-}
-
 framework::OpKernelType InstanceNormOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
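For orientation, the shape contract encoded by the removed InferShape above, worked through with illustrative values:

// Hypothetical 4-D NCHW input, N = 2, C = 3, H = 4, W = 5:
//   X             -> [2, 3, 4, 5]
//   Scale, Bias   -> [3]            (1-D, length must equal C)
//   Y             -> [2, 3, 4, 5]   (same dims as X; LoD shared from X)
//   SavedMean     -> [6]            (N * C, one statistic per (n, c) instance)
//   SavedVariance -> [6]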
@@ -170,29 +93,6 @@ NCHW `[batch, in_channels, in_height, in_width]`

 )DOC");
 }
 
-void InstanceNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
-  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNormGrad");
-  OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
-                 framework::GradVarName("Y"), "InstanceNormGrad");
-  OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
-                 "InstanceNormGrad");
-  OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
-                 "InstanceNormGrad");
-  // check output
-  OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
-                 framework::GradVarName("X"), "InstanceNormGrad");
-  const auto x_dims = ctx->GetInputDim("X");
-  const int C = x_dims[1];
-  ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-  if (ctx->HasOutput(framework::GradVarName("Scale"))) {
-    ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
-  }
-  if (ctx->HasOutput(framework::GradVarName("Bias"))) {
-    ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
-  }
-}
-
 framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   const auto *var = ctx.InputVar(framework::GradVarName("Y"));
@@ -214,34 +114,6 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType(

       OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
 }
 
-void InstanceNormDoubleGradOp::InferShape(
-    framework::InferShapeContext *ctx) const {
-  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "InstanceNormDoubleGrad");
-  OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
-                 "InstanceNormDoubleGrad");
-  OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
-                 "InstanceNormDoubleGrad");
-  OP_INOUT_CHECK(ctx->HasInput("DDX"), "Input", "DDX",
-                 "InstanceNormDoubleGrad");
-  OP_INOUT_CHECK(ctx->HasInput("DY"), "Input", "DY", "InstanceNormDoubleGrad");
-  // check output
-  OP_INOUT_CHECK(ctx->HasOutput("DX"), "Output", "DX",
-                 "InstanceNormDoubleGrad");
-  const auto x_dims = ctx->GetInputDim("X");
-  const int C = x_dims[1];
-  if (ctx->HasOutput("DX")) {
-    ctx->SetOutputDim("DX", x_dims);
-  }
-  if (ctx->HasOutput("DScale")) {
-    ctx->SetOutputDim("DScale", {C});
-  }
-  if (ctx->HasOutput("DDY")) {
-    ctx->ShareDim("X", "DDY");
-  }
-}
-
 framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType(
     const framework::ExecutionContext &ctx) const {
   const auto *var = ctx.InputVar("DY");
@@ -263,213 +135,6 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType(

       OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
 }
 
-template <typename T>
-class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto *X = ctx.Input<Tensor>("X");
-    const auto *Scale = ctx.Input<Tensor>("Scale");
-    const auto *dY = ctx.Input<Tensor>("DY");
-    const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
-    const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
-    const auto *ddX = ctx.Input<Tensor>("DDX");
-    const auto *ddScale = ctx.Input<Tensor>("DDScale");
-    const auto *ddBias = ctx.Input<Tensor>("DDBias");
-    auto *dX = ctx.Output<Tensor>("DX");
-    auto *dScale = ctx.Output<Tensor>("DScale");
-    auto *ddY = ctx.Output<Tensor>("DDY");
-
-    auto &dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
-    phi::funcs::SetConstant<platform::CPUDeviceContext, T> set_constant;
-
-    const auto &x_dims = X->dims();
-    int N, C, H, W, D;
-    ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
-    const int sample_size = X->numel() / N / C;
-    const int NxC = N * C;
-
-    const T *mean_data = Saved_mean->data<T>();
-    const T *inv_var_data = Saved_variance->data<T>();
-    Tensor mean_tensor;
-    Tensor inv_var_tensor;
-    ConstEigenArrayMap<T> x_arr(X->data<T>(), sample_size, NxC);
-    ConstEigenVectorArrayMap<T> mean_arr(mean_data, NxC);
-    ConstEigenVectorArrayMap<T> inv_var_arr(inv_var_data, NxC);
-
-    Tensor mean_tile;
-    mean_tile.Resize({sample_size, NxC});
-    mean_tile.mutable_data<T>(ctx.GetPlace());
-    EigenArrayMap<T> mean_tile_data(mean_tile.mutable_data<T>(ctx.GetPlace()),
-                                    sample_size, NxC);
-    Tensor inv_var_tile;
-    inv_var_tile.Resize({sample_size, NxC});
-    inv_var_tile.mutable_data<T>(ctx.GetPlace());
-    EigenArrayMap<T> inv_var_tile_data(
-        inv_var_tile.mutable_data<T>(ctx.GetPlace()), sample_size, NxC);
-
-    mean_tile_data = mean_arr.transpose().replicate(sample_size, 1);
-    inv_var_tile_data = inv_var_arr.transpose().replicate(sample_size, 1);
-
-    Tensor Scale_data;
-    if (!Scale) {
-      Scale_data.mutable_data<T>({C}, ctx.GetPlace());
-      set_constant(dev_ctx, &Scale_data, static_cast<T>(1));
-    }
-    ConstEigenVectorArrayMap<T> scale_arr(
-        Scale ? Scale->data<T>() : Scale_data.data<T>(), C);
-
-    Tensor scale_tile;
-    scale_tile.Resize({sample_size, NxC});
-    scale_tile.mutable_data<T>(ctx.GetPlace());
-    EigenArrayMap<T> scale_tile_data(scale_tile.mutable_data<T>(ctx.GetPlace()),
-                                     sample_size, NxC);
-    scale_tile_data = scale_arr.transpose().replicate(sample_size, N);
-
-    ConstEigenArrayMap<T> dy_arr(dY->data<T>(), sample_size, NxC);
-    ConstEigenArrayMap<T> ddx_arr(ddX->data<T>(), sample_size, NxC);
-
-    // math: dx = scale * ((x - mean) * inv_var / HxW * (np.mean(ddx,
-    //       axis=(h,w)) * np.sum(dy, axis=(h,w)) -
-    //       np.sum(dy * ddx, axis=(h,w)) + 3 * np.mean(dy * (x - mean),
-    //       axis=(h,w)) * inv_var.pow(2) *
-    //       np.sum(ddx * (x - mean), axis=(h,w))) + inv_var.pow(3) / HxW *
-    //       np.sum(ddx * (x - mean)) *
-    //       (np.mean(dy, axis=(h,w)) - dy) + inv_var.pow(3) / HxW *
-    //       np.sum(dy, axis=(h,w)) * (x - mean) *
-    //       (np.mean(ddx, axis=(h,w)) - ddx)) + ddr * (dy * inv_var -
-    //       inv_var * np.mean(dy, axis=(h,w)) -
-    //       inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
-    //       axis=(h,w)))
-
-    Tensor x_sub_mean_mul_invstd;
-    x_sub_mean_mul_invstd.Resize({sample_size, NxC});
-    x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace());
-    EigenArrayMap<T> x_sub_mean_mul_invstd_arr(
-        x_sub_mean_mul_invstd.mutable_data<T>(ctx.GetPlace()), sample_size,
-        NxC);
-    x_sub_mean_mul_invstd_arr = (x_arr - mean_tile_data) * inv_var_tile_data;
-
-    if (dX) {
-      dX->mutable_data<T>(ctx.GetPlace());
-      set_constant(dev_ctx, dX, static_cast<T>(0));
-      EigenArrayMap<T> dx_arr(dX->mutable_data<T>(ctx.GetPlace()), sample_size,
-                              NxC);
-      if (ddX) {
-        dx_arr +=
-            x_sub_mean_mul_invstd_arr * inv_var_tile_data * inv_var_tile_data /
-            sample_size *
-            (ddx_arr.colwise().sum() * dy_arr.colwise().sum() / sample_size -
-             (dy_arr * ddx_arr).colwise().sum() +
-             3. * (dy_arr * x_sub_mean_mul_invstd_arr).colwise().sum() *
-                 (ddx_arr * x_sub_mean_mul_invstd_arr).colwise().sum() /
-                 sample_size);
-        dx_arr += (ddx_arr * x_sub_mean_mul_invstd_arr).colwise().sum() /
-                  sample_size * inv_var_tile_data * inv_var_tile_data *
-                  (dy_arr.colwise().sum() / sample_size - dy_arr);
-        dx_arr += (dy_arr * x_sub_mean_mul_invstd_arr).colwise().sum() /
-                  sample_size * inv_var_tile_data * inv_var_tile_data *
-                  (ddx_arr.colwise().sum() / sample_size - ddx_arr);
-        dx_arr = scale_tile_data * dx_arr;
-      }
-      if (ddScale) {
-        ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
-        Tensor ddscale_tile;
-        ddscale_tile.Resize({sample_size, NxC});
-        ddscale_tile.mutable_data<T>(ctx.GetPlace());
-        EigenArrayMap<T> ddscale_tile_data(
-            ddscale_tile.mutable_data<T>(ctx.GetPlace()), sample_size, NxC);
-        ddscale_tile_data = ddscale_arr.transpose().replicate(sample_size, N);
-
-        dx_arr += (dy_arr * inv_var_tile_data -
-                   dy_arr.colwise().sum() / sample_size * inv_var_tile_data -
-                   x_sub_mean_mul_invstd_arr * inv_var_tile_data *
-                       (dy_arr * x_sub_mean_mul_invstd_arr).colwise().sum() /
-                       sample_size) *
-                  ddscale_tile_data;
-      }
-    }
-    if (dScale) {
-      // math: dscale = inv_var * (dy - np.mean(dy, axis=(h,w) - (x-mean) *
-      //       inv_var.pow(2) * np.mean(dy * (x-mean), axis=(h,w)))) * ddx
-      dScale->mutable_data<T>(ctx.GetPlace());
-      set_constant(dev_ctx, dScale, static_cast<T>(0));
-      EigenVectorArrayMap<T> dscale_arr(
-          dScale->mutable_data<T>(ctx.GetPlace()), C);
-      if (ddX) {
-        Tensor first_grad;
-        first_grad.Resize({sample_size, NxC});
-        first_grad.mutable_data<T>(ctx.GetPlace());
-        set_constant(dev_ctx, &first_grad, static_cast<T>(0));
-        EigenArrayMap<T> first_grad_arr(
-            first_grad.mutable_data<T>(ctx.GetPlace()), sample_size, NxC);
-
-        first_grad_arr +=
-            inv_var_tile_data *
-            (dy_arr -
-             dy_arr.colwise().sum().replicate(sample_size, 1) / sample_size -
-             x_sub_mean_mul_invstd_arr *
-                 (dy_arr * x_sub_mean_mul_invstd_arr)
-                     .colwise()
-                     .sum()
-                     .replicate(sample_size, 1) /
-                 sample_size);
-        first_grad_arr = first_grad_arr * ddx_arr;
-        for (int nc = 0; nc < NxC; ++nc) {
-          int c = nc % C;
-          dscale_arr(c) += first_grad_arr.colwise().sum()(nc);
-        }
-      }
-    }
-    if (ddY) {
-      // math: ddy = (x - mean) * inv_var * ddscale + ddbias +
-      //       scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
-      //       np.mean(ddx * (x - mean), axis=(h,w)))
-      ddY->mutable_data<T>(ctx.GetPlace());
-      set_constant(dev_ctx, ddY, static_cast<T>(0));
-      EigenArrayMap<T> ddy_arr(ddY->mutable_data<T>(ctx.GetPlace()),
-                               sample_size, NxC);
-      if (ddX) {
-        ddy_arr += scale_tile_data * inv_var_tile_data *
-                   (ddx_arr - ddx_arr.colwise().sum() / sample_size -
-                    x_sub_mean_mul_invstd_arr *
-                        (ddx_arr * x_sub_mean_mul_invstd_arr).colwise().sum() /
-                        sample_size);
-      }
-      if (ddScale && ddBias) {
-        ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
-        Tensor ddscale_tile;
-        ddscale_tile.Resize({sample_size, NxC});
-        ddscale_tile.mutable_data<T>(ctx.GetPlace());
-        EigenArrayMap<T> ddscale_tile_data(
-            ddscale_tile.mutable_data<T>(ctx.GetPlace()), sample_size, NxC);
-        ddscale_tile_data = ddscale_arr.transpose().replicate(sample_size, N);
-
-        ConstEigenVectorArrayMap<T> ddbias_arr(ddBias->data<T>(), C);
-        Tensor ddbias_tile;
-        ddbias_tile.Resize({sample_size, NxC});
-        ddbias_tile.mutable_data<T>(ctx.GetPlace());
-        EigenArrayMap<T> ddbias_tile_data(
-            ddbias_tile.mutable_data<T>(ctx.GetPlace()), sample_size, NxC);
-        ddbias_tile_data = ddbias_arr.transpose().replicate(sample_size, N);
-
-        ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data;
-        ddy_arr += ddbias_tile_data;
-      }
-    }
-  }
-};
-
 DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
                            {"DY", "DDY"});
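The ddy branch above is the easiest of the three outputs to check in isolation. A minimal, self-contained Eigen sketch of that formula (standalone toy code, not the Paddle API; shapes, epsilon, and scale == 1 are made up for illustration, and the ddscale/ddbias terms are omitted):

#include <Eigen/Dense>
#include <iostream>

int main() {
  using Arr = Eigen::ArrayXXf;  // columns index the N*C instances
  const int sample_size = 4, NxC = 2;
  Arr x = Arr::Random(sample_size, NxC);
  Arr ddx = Arr::Random(sample_size, NxC);
  // Per-(n, c) statistics replicated across the sample dimension,
  // mirroring mean_tile_data / inv_var_tile_data in the kernel.
  Arr mean = x.colwise().mean().replicate(sample_size, 1);
  Arr inv_std = ((x - mean).square().colwise().mean() + 1e-5f)
                    .rsqrt()
                    .replicate(sample_size, 1);
  Arr scale_tile = Arr::Ones(sample_size, NxC);  // scale == 1 for the sketch
  Arr x_hat = (x - mean) * inv_std;
  // ddy = scale * inv_std * (ddx - mean(ddx) - x_hat * mean(ddx * x_hat)),
  // i.e. the ddX branch of the kernel above, with broadcasts made explicit.
  Arr ddy = scale_tile * inv_std *
            (ddx -
             (ddx.colwise().sum() / sample_size).replicate(sample_size, 1) -
             x_hat * ((ddx * x_hat).colwise().sum() / sample_size)
                         .replicate(sample_size, 1));
  std::cout << ddy << "\n";
  return 0;
}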
@@ -477,22 +142,26 @@ DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,

 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(instance_norm, InstanceNormInferShapeFunctor,
+                            PD_INFER_META(phi::InstanceNormInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(instance_norm_grad,
+                            InstanceNormGradInferShapeFunctor,
+                            PD_INFER_META(phi::InstanceNormGradInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(instance_norm_grad_grad,
+                            InstanceNormDoubleGradInferShapeFunctor,
+                            PD_INFER_META(phi::InstanceNormDoubleGradInferMeta));
 REGISTER_OPERATOR(instance_norm, ops::InstanceNormOp, ops::InstanceNormOpMaker,
                   ops::InstanceNormOpInferVarType,
                   ops::InstanceNormGradMaker<paddle::framework::OpDesc>,
-                  ops::InstanceNormGradMaker<paddle::imperative::OpBase>);
+                  ops::InstanceNormGradMaker<paddle::imperative::OpBase>,
+                  InstanceNormInferShapeFunctor);
 REGISTER_OPERATOR(instance_norm_grad, ops::InstanceNormGradOp,
                   ops::InstanceNormDoubleGradMaker<paddle::framework::OpDesc>,
-                  ops::InstanceNormDoubleGradMaker<paddle::imperative::OpBase>);
+                  ops::InstanceNormDoubleGradMaker<paddle::imperative::OpBase>,
+                  InstanceNormGradInferShapeFunctor);
 REGISTER_OPERATOR(instance_norm_grad_grad, ops::InstanceNormDoubleGradOp,
-                  ops::InstanceNormDoubleGradOpInplaceInferer);
-
-REGISTER_OP_CPU_KERNEL(
-    instance_norm_grad_grad,
-    ops::InstanceNormDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                      float>,
-    ops::InstanceNormDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                                      double>);
+                  ops::InstanceNormDoubleGradOpInplaceInferer,
+                  InstanceNormDoubleGradInferShapeFunctor);
 
 REGISTER_OP_VERSION(instance_norm)
     .AddCheckpoint(
...
paddle/fluid/operators/instance_norm_op.cu (deleted; file mode 100644 → 0)
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/instance_norm_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;

template <typename T>
static __global__ void repeat_param(const T *input, T *output,
                                    const int repeat_num, const int C) {
  CUDA_KERNEL_LOOP(i, repeat_num * C) {
    int index = i % C;
    output[i] = input[index];
  }
}

template <typename T, int BlockDim, bool AVG>
static __global__ void add_param(const T *input, T *output,
                                 const int repeat_num, const int C) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ou_storage;
  for (int i = blockIdx.x; i < C; i += gridDim.x) {
    T ou = static_cast<T>(0);
    for (int j = threadIdx.x; j < repeat_num; j += blockDim.x) {
      const int index = j * C + i;
      ou += static_cast<T>(input[index]);
    }
    ou = BlockReduce(ou_storage).Reduce(ou, cub::Sum());
    if (threadIdx.x == 0) {
      output[i] = ou;
    }
    __syncthreads();

    if (AVG) {
      output[i] /= repeat_num;
    }
  }
}

template <typename T, int BlockDim>
static __global__ void GradComputeDX(const T *dy,
                                     const BatchNormParamType<T> *scale,
                                     const BatchNormParamType<T> *mean,
                                     const T *x,
                                     const BatchNormParamType<T> *variance,
                                     const int C, const int sample_size,
                                     T *dx) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  BatchNormParamType<T> mean_val = mean[ncid];
  BatchNormParamType<T> inv_var_val = variance[ncid];

  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  __shared__ BatchNormParamType<T> dy_sum_val;
  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;

  BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
  BatchNormParamType<T> dy_x_sub_mean_sum =
      static_cast<BatchNormParamType<T>>(0);

  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    BatchNormParamType<T> dy_i = static_cast<BatchNormParamType<T>>(dy[i]);
    dy_sum += dy_i;
    dy_x_sub_mean_sum +=
        dy_i * (static_cast<BatchNormParamType<T>>(x[i]) - mean_val);
  }
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  dy_x_sub_mean_sum =
      BlockReduce(dy_x_sub_mean_storage).Reduce(dy_x_sub_mean_sum, cub::Sum());

  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
  }
  __syncthreads();

  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    dx[i] =
        (static_cast<BatchNormParamType<T>>(dy[i]) -
         dy_sum_val / static_cast<BatchNormParamType<T>>(sample_size) -
         (static_cast<BatchNormParamType<T>>(x[i]) - mean_val) *
             dy_x_sub_mean_sum_val * inv_var_val * inv_var_val / sample_size) *
        scale[c] * inv_var_val;
  }
}

static __device__ __forceinline__ float real_sqrt(float x) {
  return 1. / sqrtf(x);
}
static __device__ __forceinline__ double real_sqrt(double x) {
  return 1. / sqrt(x);
}

template <typename T, int BlockDim>
__global__ void DoubleGradComputeDX(const T *x, const T *mean,
                                    const T *variance, const T *ddx,
                                    const T *dy, const T *scale,
                                    const T *ddscale, int C, int sample_size,
                                    const double epsilon, T *dx) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  T mean_val = mean[ncid];
  T var_val = variance[ncid];

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T dy_sum_val;
  __shared__ T ddx_sum_val;
  __shared__ T dy_mul_ddx_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  T dy_sum = 0;
  T ddx_sum = 0;
  T dy_mul_ddx_sum = 0;
  T dy_mul_x_sub_mean_sum = 0;
  T ddx_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T ddx_i = ddx[i];
    T dy_i = dy[i];
    T tmp = x[i] - mean_val;

    dy_sum += dy_i;
    ddx_sum += ddx_i;
    dy_mul_ddx_sum += (ddx_i * dy_i);

    dy_mul_x_sub_mean_sum += (dy_i * tmp);
    ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
  }

  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  dy_mul_ddx_sum =
      BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());

  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    ddx_sum_val = ddx_sum;
    dy_mul_ddx_sum_val = dy_mul_ddx_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();

  if (ddx != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dx[i] +=
          ((x[i] - mean_val) * var_val * var_val * var_val / sample_size *
               (ddx_sum_val * dy_sum_val / sample_size - dy_mul_ddx_sum_val +
                3. * dy_mul_x_sub_mean_sum_val * var_val *
                    ddx_mul_x_sub_mean_sum_val * var_val / sample_size) +
           ddx_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (dy_sum_val / sample_size - dy[i]) +
           dy_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (ddx_sum_val / sample_size - ddx[i])) *
          scale[c];
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dx[i] += (dy[i] * var_val - dy_sum_val / sample_size * var_val -
                (x[i] - mean_val) * var_val * dy_mul_x_sub_mean_sum_val *
                    var_val / sample_size) *
               ddscale[c];
    }
  }
}

template <typename T, int BlockDim>
__global__ void DoubleGradComputeDDY(const T *x, const T *mean,
                                     const T *variance, const T *ddscale,
                                     const T *ddbias, const T *ddx,
                                     const T *scale, int C, int sample_size,
                                     const double epsilon, T *ddy) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  T mean_val = mean[ncid];
  T var_val = variance[ncid];

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T ddx_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;

  T ddx_sum = 0;
  T ddx_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T ddx_i = ddx[i];
    ddx_sum += ddx_i;
    ddx_mul_x_sub_mean_sum += (ddx_i * (x[i] - mean_val));
  }
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());
  if (threadIdx.x == 0) {
    ddx_sum_val = ddx_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();

  if (ddx != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += scale[c] * var_val *
                (ddx[i] - ddx_sum_val / sample_size -
                 (x[i] - mean_val) * var_val * ddx_mul_x_sub_mean_sum_val *
                     var_val / sample_size);
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += (x[i] - mean_val) * var_val * ddscale[c];
    }
  }
  __syncthreads();
  if (ddbias != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += ddbias[c];
    }
  }
}

template <typename T, int BlockDim>
__global__ void DoubleGradComputeDScale(const T *x, const T *mean,
                                        const T *variance, const T *ddx,
                                        const T *dy, int C, int sample_size,
                                        const double epsilon, T *dscale) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;

  T mean_val = mean[ncid];
  T var_val = variance[ncid];

  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
  __shared__ T dy_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;

  T dy_sum = 0;
  T dy_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T dy_i = dy[i];
    dy_sum += dy_i;
    dy_mul_x_sub_mean_sum += (dy_i * (x[i] - mean_val));
  }
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());

  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
  }
  __syncthreads();

  if (ddx != nullptr) {
    T dscale_tmp = 0;
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dscale_tmp += ddx[i] * var_val *
                    (dy[i] - dy_sum_val / sample_size -
                     dy_mul_x_sub_mean_sum_val * (x[i] - mean_val) * var_val *
                         var_val / sample_size);
    }
    dscale_tmp = BlockReduce(dscale_tmp_storage).Reduce(dscale_tmp, cub::Sum());
    if (threadIdx.x == 0) {
      dscale[ncid] += dscale_tmp;
    }
    __syncthreads();
  }
}

template <typename T>
class InstanceNormDoubleGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *X = ctx.Input<Tensor>("X");
    const auto *Scale = ctx.Input<Tensor>("Scale");
    const auto *dY = ctx.Input<Tensor>("DY");
    const auto *Saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto *Saved_variance = ctx.Input<Tensor>("SavedVariance");
    const auto *running_mean = ctx.Input<Tensor>("Mean");
    const auto *running_var = ctx.Input<Tensor>("Variance");
    const auto *ddX = ctx.Input<Tensor>("DDX");
    const auto *ddScale = ctx.Input<Tensor>("DDScale");
    const auto *ddBias = ctx.Input<Tensor>("DDBias");
    const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    auto *dX = ctx.Output<Tensor>("DX");
    auto *dScale = ctx.Output<Tensor>("DScale");
    auto *ddY = ctx.Output<Tensor>("DDY");

    const T *x_data = X->data<T>();
    const T *dy_data = dY->data<T>();
    const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());
    const T *ddscale_data =
        (ddScale == nullptr ? nullptr : ddScale->data<T>());
    const T *ddbias_data = (ddScale == nullptr ? nullptr : ddBias->data<T>());
    const T *mean_data = Saved_mean->data<T>();
    const T *variance_data = Saved_variance->data<T>();

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_zero;

    auto &x_dims = X->dims();
    int N, C, H, W, D;
    ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
    int NxC = N * C;
    const int n = X->numel();
    int sample_size = n / N / C;

    Tensor scale_tmp;
    if (!Scale) {
      scale_tmp.mutable_data<T>({C}, ctx.GetPlace());
      set_zero(dev_ctx, &scale_tmp, static_cast<T>(1));
    }
    const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();

    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    const int grid = NxC;
    const int grid1 = (C + block - 1) / block;

    if (dX) {
      T *dx_data = dX->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, dX, static_cast<T>(0));
      DoubleGradComputeDX<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddx_data, dy_data, scale_data,
          ddscale_data, C, sample_size, epsilon, dx_data);
    }
    if (dScale) {
      Tensor dscale_tmp =
          ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({NxC},
                                                                dev_ctx);
      set_zero(dev_ctx, &dscale_tmp, static_cast<T>(0));
      T *dscale_tmp_data = dscale_tmp.mutable_data<T>(ctx.GetPlace());

      T *dscale_data = dScale->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, dScale, static_cast<T>(0));
      DoubleGradComputeDScale<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddx_data, dy_data, C, sample_size,
          epsilon, dscale_tmp_data);
      add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
          dscale_tmp.data<T>(), dScale->data<T>(), N, C);
    }
    if (ddY) {
      T *ddy_data = ddY->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, ddY, static_cast<T>(0));
      DoubleGradComputeDDY<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x_data, mean_data, variance_data, ddscale_data, ddbias_data,
          ddx_data, scale_data, C, sample_size, epsilon, ddy_data);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
REGISTER_OP_CUDA_KERNEL(
    instance_norm_grad_grad,
    ops::InstanceNormDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                      float>);
#else
REGISTER_OP_CUDA_KERNEL(
    instance_norm_grad_grad,
    ops::InstanceNormDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                      float>,
    ops::InstanceNormDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                      double>);
#endif
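For reference, the launch geometry used throughout the deleted file: one thread block per (n, c) instance (grid = N * C, block = 512), with a block-wide cub reduction supplying the per-instance sums. A minimal CUDA sketch of that pattern (toy kernel with a plain shared-memory reduction in place of cub::BlockReduce; not the Paddle API):

#include <cstdio>
#include <cuda_runtime.h>

// One block per (n, c) slice; each block reduces its sample_size elements.
__global__ void PerInstanceSum(const float *x, float *sums, int sample_size) {
  const int ncid = blockIdx.x;  // which (n, c) slice this block owns
  float partial = 0.f;
  for (int i = threadIdx.x; i < sample_size; i += blockDim.x)
    partial += x[ncid * sample_size + i];
  // Shared-memory tree reduction across the block.
  __shared__ float buf[512];
  buf[threadIdx.x] = partial;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) buf[threadIdx.x] += buf[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) sums[ncid] = buf[0];
}

int main() {
  const int N = 2, C = 3, sample_size = 1024, NxC = N * C;
  float *x, *sums;
  cudaMallocManaged(&x, NxC * sample_size * sizeof(float));
  cudaMallocManaged(&sums, NxC * sizeof(float));
  for (int i = 0; i < NxC * sample_size; ++i) x[i] = 1.f;
  const int block = 512;                                 // as in the file above
  PerInstanceSum<<<NxC, block>>>(x, sums, sample_size);  // grid = N * C
  cudaDeviceSynchronize();
  printf("%f\n", sums[0]);  // expect 1024
  cudaFree(x);
  cudaFree(sums);
  return 0;
}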
paddle/fluid/operators/instance_norm_op.h

@@ -16,9 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include <unordered_map>
-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/norm_utils.h"
 
 namespace paddle {
 namespace operators {
...
@@ -27,22 +25,9 @@ using Tensor = framework::Tensor;

 using LoDTensor = framework::LoDTensor;
 using DataLayout = framework::DataLayout;
 
-template <typename T>
-using EigenArrayMap =
-    Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
-template <typename T>
-using ConstEigenArrayMap =
-    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
-template <typename T>
-using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
-template <typename T>
-using ConstEigenVectorArrayMap =
-    Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
-
 class InstanceNormOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override;
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
...
@@ -52,7 +37,6 @@ class InstanceNormOp : public framework::OperatorWithKernel {

 class InstanceNormGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override;
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
...
@@ -62,7 +46,6 @@ class InstanceNormGradOp : public framework::OperatorWithKernel {

 class InstanceNormDoubleGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override;
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
...
@@ -130,23 +113,5 @@ class InstanceNormOpInferVarType

   }
 };
 
-template <typename DeviceContext, typename T>
-class InstanceNormKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override;
-};
-
-template <typename DeviceContext, typename T>
-class InstanceNormGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override;
-};
-
-template <typename DeviceContext, typename T>
-class InstanceNormDoubleGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override;
-};
-
 }  // namespace operators
 }  // namespace paddle
paddle/phi/infermeta/backward.cc

@@ -312,6 +312,63 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
   dx->share_meta(dout);
 }
 
+void InstanceNormGradInferMeta(const MetaTensor& x,
+                               const MetaTensor& y_grad,
+                               paddle::optional<const MetaTensor&> scale,
+                               const MetaTensor& saved_mean,
+                               const MetaTensor& saved_variance,
+                               float epsilon,
+                               MetaTensor* x_grad,
+                               MetaTensor* scale_grad,
+                               MetaTensor* bias_grad) {
+  PADDLE_ENFORCE_NE(
+      x_grad,
+      nullptr,
+      phi::errors::InvalidArgument(
+          "The X@GRAD in InstanceNormGradInferMeta can't be nullptr."));
+  const auto x_dims = x.dims();
+  const int C = x_dims[1];
+  x_grad->set_dims(x_dims);
+  x_grad->set_dtype(x.dtype());
+  x_grad->set_layout(x.layout());
+  if (scale_grad) {
+    scale_grad->set_dims({C});
+  }
+  if (bias_grad) {
+    bias_grad->set_dims({C});
+  }
+}
+void InstanceNormDoubleGradInferMeta(const MetaTensor& x,
+                                     paddle::optional<const MetaTensor&> scale,
+                                     const MetaTensor& saved_mean,
+                                     const MetaTensor& saved_variance,
+                                     const MetaTensor& dy,
+                                     paddle::optional<const MetaTensor&> ddx,
+                                     paddle::optional<const MetaTensor&> ddscale,
+                                     paddle::optional<const MetaTensor&> ddbias,
+                                     float epsilon,
+                                     MetaTensor* dx,
+                                     MetaTensor* dscale,
+                                     MetaTensor* ddy) {
+  PADDLE_ENFORCE_NE(
+      dx,
+      nullptr,
+      phi::errors::InvalidArgument(
+          "The DX in InstanceNormDoubleGradInferMeta can't be nullptr."));
+  const auto x_dims = x.dims();
+  const int C = x_dims[1];
+  dx->set_dims(x_dims);
+  dx->set_dtype(x.dtype());
+  dx->set_layout(x.layout());
+  if (dscale) {
+    dscale->set_dims({C});
+  }
+  if (ddy) {
+    ddy->share_dims(x);
+  }
+}
+
 void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx) {
   auto xshape_dims = xshape.dims();
   auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
...
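Concretely, for an input x of shape [N, C, H, W], the two meta functions added above establish (worked through with illustrative values):

// InstanceNormGradInferMeta, x = [2, 3, 4, 5]:
//   x_grad     -> [2, 3, 4, 5]  (dims, dtype, layout copied from x)
//   scale_grad -> [3]           ({C}, set only when requested)
//   bias_grad  -> [3]
// InstanceNormDoubleGradInferMeta, same x:
//   dx     -> [2, 3, 4, 5]
//   dscale -> [3]
//   ddy    -> [2, 3, 4, 5]      (share_dims(x))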
paddle/phi/infermeta/backward.h

@@ -144,6 +144,30 @@ void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
                                 int axis,
                                 MetaTensor* dx);
 
+void InstanceNormGradInferMeta(const MetaTensor& x,
+                               const MetaTensor& y_grad,
+                               paddle::optional<const MetaTensor&> scale,
+                               const MetaTensor& saved_mean,
+                               const MetaTensor& saved_variance,
+                               float epsilon,
+                               MetaTensor* x_grad,
+                               MetaTensor* scale_grad,
+                               MetaTensor* bias_grad);
+
+void InstanceNormDoubleGradInferMeta(const MetaTensor& x,
+                                     paddle::optional<const MetaTensor&> scale,
+                                     const MetaTensor& saved_mean,
+                                     const MetaTensor& saved_variance,
+                                     const MetaTensor& dy,
+                                     paddle::optional<const MetaTensor&> ddx,
+                                     paddle::optional<const MetaTensor&> ddscale,
+                                     paddle::optional<const MetaTensor&> ddbias,
+                                     float epsilon,
+                                     MetaTensor* dx,
+                                     MetaTensor* dscale,
+                                     MetaTensor* ddy);
+
 void KernelWithXShapeInferMeta(const MetaTensor& xshape, MetaTensor* dx);
 
 void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
...
paddle/phi/infermeta/ternary.cc

@@ -191,6 +191,111 @@ void ArangeInferMeta(const MetaTensor& start,
   out->set_dtype(start.dtype());
 }
 
+void InstanceNormInferMeta(const MetaTensor& x,
+                           paddle::optional<const MetaTensor&> scale,
+                           paddle::optional<const MetaTensor&> bias,
+                           float epsilon,
+                           MetaTensor* y,
+                           MetaTensor* saved_mean,
+                           MetaTensor* saved_variance,
+                           MetaConfig config) {
+  PADDLE_ENFORCE_NE(y,
+                    nullptr,
+                    phi::errors::InvalidArgument(
+                        "The y in InstanceNormInferMeta can't be nullptr."));
+  PADDLE_ENFORCE_NE(
+      saved_mean,
+      nullptr,
+      phi::errors::InvalidArgument(
+          "The saved_mean in InstanceNormInferMeta can't be nullptr."));
+  PADDLE_ENFORCE_NE(
+      saved_variance,
+      nullptr,
+      phi::errors::InvalidArgument(
+          "The saved_variance in InstanceNormInferMeta can't be nullptr."));
+  const auto x_dims = x.dims();
+  PADDLE_ENFORCE_NE(phi::product(x_dims),
+                    0,
+                    phi::errors::PreconditionNotMet(
+                        "The Input variable X has not "
+                        "been initialized. You may need to confirm "
+                        "if you put exe.run(startup_program) "
+                        "after optimizer.minimize function."));
+  PADDLE_ENFORCE_GE(
+      x_dims.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "ShapeError: the dimension of input X must "
+          "greater than or equal to 2. But received: the shape of input "
+          "X = [%s], the dimension of input X =[%d]",
+          x_dims,
+          x_dims.size()));
+  PADDLE_ENFORCE_LE(
+      x_dims.size(),
+      5,
+      phi::errors::InvalidArgument(
+          "ShapeError: the dimension of input X must "
+          "smaller than or equal to 5, But received: the shape of input "
+          "X = [%s], the dimension of input X = [%d]",
+          x_dims,
+          x_dims.size()));
+  auto N = x_dims[0];
+  auto C = x_dims[1];
+  auto NxC = N * C;
+  const auto scale_ptr = scale.get_ptr();
+  if (scale_ptr) {
+    auto scale_dim = scale_ptr->dims();
+    PADDLE_ENFORCE_EQ(
+        scale_dim.size(),
+        1UL,
+        phi::errors::InvalidArgument(
+            "ShapeError: the dimension of scale must equal to 1."
+            "But received: the shape of scale is [%s], the dimension "
+            "of scale is [%d]",
+            scale_dim,
+            scale_dim.size()));
+    bool check = !((!config.is_runtime) && (phi::product(scale_dim) <= 0));
+    if (check) {
+      PADDLE_ENFORCE_EQ(scale_dim[0],
+                        C,
+                        phi::errors::InvalidArgument(
+                            "ShapeError: the shape of scale must equal to [%d]"
+                            "But received: the shape of scale is [%d]",
+                            C,
+                            scale_dim[0]));
+    }
+  }
+  const auto bias_ptr = bias.get_ptr();
+  if (bias_ptr) {
+    auto bias_dim = bias_ptr->dims();
+    PADDLE_ENFORCE_EQ(
+        bias_dim.size(),
+        1UL,
+        phi::errors::InvalidArgument(
+            "ShapeError: the dimension of bias must equal to 1."
+            "But received: the shape of bias is [%s],the dimension "
+            "of bias is [%d]",
+            bias_dim,
+            bias_dim.size()));
+    bool check = !((!config.is_runtime) && (phi::product(bias_dim) <= 0));
+    if (check) {
+      PADDLE_ENFORCE_EQ(bias_dim[0],
+                        C,
+                        phi::errors::InvalidArgument(
+                            "ShapeError: the shape of bias must equal to [%d]"
+                            "But received: the shape of bias is [%d]",
+                            C,
+                            bias_dim[0]));
+    }
+  }
+  y->set_dims(x_dims);
+  saved_mean->set_dims({NxC});
+  saved_variance->set_dims({NxC});
+  y->share_lod(x);
+  y->set_dtype(x.dtype());
+  y->set_layout(x.layout());
+}
+
 void GraphSendRecvInferMeta(const MetaTensor& x,
                             const MetaTensor& src_index,
                             const MetaTensor& dst_index,
...
paddle/phi/infermeta/ternary.h

@@ -52,6 +52,15 @@ void ArangeInferMeta(const MetaTensor& start,
                      const MetaTensor& step,
                      MetaTensor* out);
 
+void InstanceNormInferMeta(const MetaTensor& x,
+                           paddle::optional<const MetaTensor&> scale,
+                           paddle::optional<const MetaTensor&> bias,
+                           float epsilon,
+                           MetaTensor* y,
+                           MetaTensor* saved_mean,
+                           MetaTensor* saved_variance,
+                           MetaConfig config = MetaConfig());
+
 void GraphSendRecvInferMeta(const MetaTensor& x,
                             const MetaTensor& src_index,
                             const MetaTensor& dst_index,
...
paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc

@@ -23,8 +23,22 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
namespace
phi
{
namespace
phi
{
template
<
typename
T
>
using
ConstEigenArrayMap
=
Eigen
::
Map
<
const
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
;
template
<
typename
T
>
using
ConstEigenVectorArrayMap
=
Eigen
::
Map
<
const
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
1
>>
;
template
<
typename
T
>
using
EigenArrayMap
=
Eigen
::
Map
<
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
;
template
<
typename
T
>
using
EigenVectorArrayMap
=
Eigen
::
Map
<
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
1
>>
;
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
InstanceNormGradKernel
(
const
Context
&
dev_ctx
,
void
InstanceNormGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
...
@@ -136,6 +150,188 @@ void InstanceNormGradKernel(const Context& dev_ctx,
...
@@ -136,6 +150,188 @@ void InstanceNormGradKernel(const Context& dev_ctx,
.
broadcast
(
bcast
));
.
broadcast
(
bcast
));
}
}
template
<
typename
T
,
typename
Context
>
void
InstanceNormDoubleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
paddle
::
optional
<
const
DenseTensor
&>
scale
,
const
DenseTensor
&
saved_mean
,
const
DenseTensor
&
saved_variance
,
const
DenseTensor
&
dy
,
paddle
::
optional
<
const
DenseTensor
&>
ddx
,
paddle
::
optional
<
const
DenseTensor
&>
ddscale
,
paddle
::
optional
<
const
DenseTensor
&>
ddbias
,
float
epsilon
,
DenseTensor
*
dx
,
DenseTensor
*
dscale
,
DenseTensor
*
ddy
)
{
const
auto
*
Scale
=
scale
.
get_ptr
();
const
auto
*
ddScale
=
ddscale
.
get_ptr
();
const
auto
*
ddX
=
ddx
.
get_ptr
();
const
auto
*
ddBias
=
ddbias
.
get_ptr
();
phi
::
funcs
::
SetConstant
<
CPUContext
,
T
>
set_constant
;
const
auto
&
x_dims
=
x
.
dims
();
int
N
,
C
,
H
,
W
,
D
;
funcs
::
ExtractNCWHD
(
x_dims
,
DataLayout
::
kNCHW
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
const
int
sample_size
=
x
.
numel
()
/
N
/
C
;
const
int
NxC
=
N
*
C
;
const
T
*
mean_data
=
saved_mean
.
data
<
T
>
();
const
T
*
inv_var_data
=
saved_variance
.
data
<
T
>
();
DenseTensor
mean_tensor
;
DenseTensor
inv_var_tensor
;
ConstEigenArrayMap
<
T
>
x_arr
(
x
.
data
<
T
>
(),
sample_size
,
NxC
);
ConstEigenVectorArrayMap
<
T
>
mean_arr
(
mean_data
,
NxC
);
ConstEigenVectorArrayMap
<
T
>
inv_var_arr
(
inv_var_data
,
NxC
);
DenseTensor
mean_tile
;
mean_tile
.
Resize
({
sample_size
,
NxC
});
dev_ctx
.
template
Alloc
<
T
>(
&
mean_tile
);
EigenArrayMap
<
T
>
mean_tile_data
(
mean_tile
.
data
<
T
>
(),
sample_size
,
NxC
);
DenseTensor
inv_var_tile
;
inv_var_tile
.
Resize
({
sample_size
,
NxC
});
dev_ctx
.
template
Alloc
<
T
>(
&
inv_var_tile
);
EigenArrayMap
<
T
>
inv_var_tile_data
(
inv_var_tile
.
data
<
T
>
(),
sample_size
,
NxC
);
mean_tile_data
=
mean_arr
.
transpose
().
replicate
(
sample_size
,
1
);
inv_var_tile_data
=
inv_var_arr
.
transpose
().
replicate
(
sample_size
,
1
);
DenseTensor
Scale_data
;
if
(
!
Scale
)
{
Scale_data
.
Resize
({
C
});
dev_ctx
.
template
Alloc
<
T
>(
&
Scale_data
);
set_constant
(
dev_ctx
,
&
Scale_data
,
static_cast
<
T
>
(
1
));
}
ConstEigenVectorArrayMap
<
T
>
scale_arr
(
Scale
?
Scale
->
data
<
T
>
()
:
Scale_data
.
data
<
T
>
(),
C
);
DenseTensor
scale_tile
;
scale_tile
.
Resize
({
sample_size
,
NxC
});
dev_ctx
.
template
Alloc
<
T
>(
&
scale_tile
);
EigenArrayMap
<
T
>
scale_tile_data
(
scale_tile
.
data
<
T
>
(),
sample_size
,
NxC
);
scale_tile_data
=
scale_arr
.
transpose
().
replicate
(
sample_size
,
N
);
ConstEigenArrayMap
<
T
>
dy_arr
(
dy
.
data
<
T
>
(),
sample_size
,
NxC
);
ConstEigenArrayMap
<
T
>
ddx_arr
(
ddX
->
data
<
T
>
(),
sample_size
,
NxC
);
// math: dx = scale * ((x - mean) * inv_var / HxW * (np.mean(ddx,
// axis=(h,w)) * np.sum(dy, axis=(h,w)) -
// np.sum(dy * ddx, axis=(h,w)) + 3 * np.mean(dy * (x - mean),
// axis=(h,w)) * inv_var.pow(2) *
// np.sum(ddx * (x - mean), axis=(h,w))) + inv_var.pow(3) / HxW *
// np.sum(ddx * (x - mean)) *
// (np.mean(dy, axis=(h,w)) - dy) + inv_var.pow(3) / HxW *
// np.sum(dy, axis=(h,w)) * (x - mean) *
// (np.mean(ddx, axis=(h,w)) - ddx)) + ddr * (dy * inv_var -
// inv_var * np.mean(dy, axis=(h,w)) - inv_var.pow(3) *
// (x - mean) * np.mean(dy * (x - mean), axis=(h,w)))
DenseTensor
x_sub_mean_mul_invstd
;
x_sub_mean_mul_invstd
.
Resize
({
sample_size
,
NxC
});
dev_ctx
.
template
Alloc
<
T
>(
&
x_sub_mean_mul_invstd
);
EigenArrayMap
<
T
>
x_sub_mean_mul_invstd_arr
(
x_sub_mean_mul_invstd
.
data
<
T
>
(),
sample_size
,
NxC
);
x_sub_mean_mul_invstd_arr
=
(
x_arr
-
mean_tile_data
)
*
inv_var_tile_data
;
if
(
dx
)
{
dev_ctx
.
template
Alloc
<
T
>(
dx
);
set_constant
(
dev_ctx
,
dx
,
static_cast
<
T
>
(
0
));
EigenArrayMap
<
T
>
dx_arr
(
dx
->
data
<
T
>
(),
sample_size
,
NxC
);
if
(
ddX
)
{
dx_arr
+=
x_sub_mean_mul_invstd_arr
*
inv_var_tile_data
*
inv_var_tile_data
/
sample_size
*
(
ddx_arr
.
colwise
().
sum
()
*
dy_arr
.
colwise
().
sum
()
/
sample_size
-
(
dy_arr
*
ddx_arr
).
colwise
().
sum
()
+
3.
*
(
dy_arr
*
x_sub_mean_mul_invstd_arr
).
colwise
().
sum
()
*
(
ddx_arr
*
x_sub_mean_mul_invstd_arr
).
colwise
().
sum
()
/
sample_size
);
dx_arr
+=
(
ddx_arr
*
x_sub_mean_mul_invstd_arr
).
colwise
().
sum
()
/
sample_size
*
inv_var_tile_data
*
inv_var_tile_data
*
(
dy_arr
.
colwise
().
sum
()
/
sample_size
-
dy_arr
);
dx_arr
+=
(
dy_arr
*
x_sub_mean_mul_invstd_arr
).
colwise
().
sum
()
/
sample_size
*
inv_var_tile_data
*
inv_var_tile_data
*
(
ddx_arr
.
colwise
().
sum
()
/
sample_size
-
ddx_arr
);
dx_arr
=
scale_tile_data
*
dx_arr
;
}
if
(
ddScale
)
{
ConstEigenVectorArrayMap
<
T
>
ddscale_arr
(
ddScale
->
data
<
T
>
(),
C
);
DenseTensor
ddscale_tile
;
ddscale_tile
.
Resize
({
sample_size
,
NxC
});
dev_ctx
.
template
Alloc
<
T
>(
&
ddscale_tile
);
EigenArrayMap
<
T
>
ddscale_tile_data
(
ddscale_tile
.
data
<
T
>
(),
sample_size
,
NxC
);
ddscale_tile_data
=
ddscale_arr
.
transpose
().
replicate
(
sample_size
,
N
);
dx_arr
+=
(
dy_arr
*
inv_var_tile_data
-
dy_arr
.
colwise
().
sum
()
/
sample_size
*
inv_var_tile_data
-
x_sub_mean_mul_invstd_arr
*
inv_var_tile_data
*
(
dy_arr
*
x_sub_mean_mul_invstd_arr
).
colwise
().
sum
()
/
sample_size
)
*
ddscale_tile_data
;
}
}
if
(
dscale
)
{
// math: dscale = inv_var * (dy - np.mean(dy, axis=(h,w) - (x-mean) *
// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(h,w)))) * ddx
dev_ctx
.
template
Alloc
<
T
>(
dscale
);
set_constant
(
dev_ctx
,
dscale
,
static_cast
<
T
>
(
0
));
EigenVectorArrayMap
<
T
>
dscale_arr
(
dscale
->
data
<
T
>
(),
C
);
if
(
ddX
)
{
DenseTensor
first_grad
;
first_grad
.
Resize
({
sample_size
,
NxC
});
dev_ctx
.
template
Alloc
<
T
>(
&
first_grad
);
set_constant
(
dev_ctx
,
&
first_grad
,
static_cast
<
T
>
(
0
));
EigenArrayMap
<
T
>
first_grad_arr
(
first_grad
.
data
<
T
>
(),
sample_size
,
NxC
);
first_grad_arr
+=
inv_var_tile_data
*
(
dy_arr
-
dy_arr
.
colwise
().
sum
().
replicate
(
sample_size
,
1
)
/
sample_size
-
x_sub_mean_mul_invstd_arr
*
(
dy_arr
*
x_sub_mean_mul_invstd_arr
)
.
colwise
()
.
sum
()
.
replicate
(
sample_size
,
1
)
/
sample_size
);
first_grad_arr
=
first_grad_arr
*
ddx_arr
;
for
(
int
nc
=
0
;
nc
<
NxC
;
++
nc
)
{
int
c
=
nc
%
C
;
dscale_arr
(
c
)
+=
first_grad_arr
.
colwise
().
sum
()(
nc
);
}
}
}
if
(
ddy
)
{
// math: ddy = (x - mean) * inv_var * ddscale + ddbias +
// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
// np.mean(ddx * (x - mean), axis=(h,w)))
dev_ctx
.
template
Alloc
<
T
>(
ddy
);
set_constant
(
dev_ctx
,
ddy
,
static_cast
<
T
>
(
0
));
EigenArrayMap
<
T
>
ddy_arr
(
ddy
->
data
<
T
>
(),
sample_size
,
NxC
);
if
(
ddX
)
{
ddy_arr
+=
scale_tile_data
*
inv_var_tile_data
*
(
ddx_arr
-
ddx_arr
.
colwise
().
sum
()
/
sample_size
-
x_sub_mean_mul_invstd_arr
*
(
ddx_arr
*
x_sub_mean_mul_invstd_arr
).
colwise
().
sum
()
/
sample_size
);
}
if
(
ddScale
&&
ddBias
)
{
ConstEigenVectorArrayMap
<
T
>
ddscale_arr
(
ddScale
->
data
<
T
>
(),
C
);
DenseTensor
ddscale_tile
;
ddscale_tile
.
Resize
({
sample_size
,
NxC
});
dev_ctx
.
template
Alloc
<
T
>(
&
ddscale_tile
);
EigenArrayMap
<
T
>
ddscale_tile_data
(
ddscale_tile
.
data
<
T
>
(),
sample_size
,
NxC
);
ddscale_tile_data
=
ddscale_arr
.
transpose
().
replicate
(
sample_size
,
N
);
ConstEigenVectorArrayMap
<
T
>
ddbias_arr
(
ddBias
->
data
<
T
>
(),
C
);
DenseTensor
ddbias_tile
;
ddbias_tile
.
Resize
({
sample_size
,
NxC
});
dev_ctx
.
template
Alloc
<
T
>(
&
ddbias_tile
);
EigenArrayMap
<
T
>
ddbias_tile_data
(
ddbias_tile
.
data
<
T
>
(),
sample_size
,
NxC
);
ddbias_tile_data
=
ddbias_arr
.
transpose
().
replicate
(
sample_size
,
N
);
ddy_arr
+=
x_sub_mean_mul_invstd_arr
*
ddscale_tile_data
;
ddy_arr
+=
ddbias_tile_data
;
}
}
}
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL(instance_norm_grad,
...
@@ -144,3 +340,9 @@ PD_REGISTER_KERNEL(instance_norm_grad,
                   phi::InstanceNormGradKernel,
                   float,
                   double) {}
PD_REGISTER_KERNEL(instance_norm_double_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::InstanceNormDoubleGradKernel,
                   float,
                   double) {}
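A note for readers decoding the `// math` comments in the CPU kernel above: per (n, c) instance with m = H * W elements, write \mu for the saved mean, s = (\sigma^2 + \epsilon)^{-1/2} for the saved inverse standard deviation, \hat{x} = (x - \mu)s, and an overbar for the mean over the m spatial positions. The three branches then compute, reconstructed here for readability (the comments above are the authoritative form):

$$ dX \mathrel{+}= ddScale \cdot s\,\big(dY - \overline{dY} - \hat{x}\,\overline{dY \odot \hat{x}}\big) $$

$$ dScale_c = \sum_{n}\sum_{h,w} ddX \odot s\,\big(dY - \overline{dY} - \hat{x}\,\overline{dY \odot \hat{x}}\big) $$

$$ ddY = \hat{x} \cdot ddScale + ddBias + scale \cdot s\,\big(ddX - \overline{ddX} - \hat{x}\,\overline{ddX \odot \hat{x}}\big) $$

The same bracketed factor appears in the ddScale contribution to dX and, with ddX in place of ddScale, in dScale; that is why first_grad_arr in the dscale branch mirrors the expression multiplied by ddscale_tile_data in the dx branch.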
paddle/phi/kernels/funcs/norm_utils.h
0 → 100644 (new file)
View file @ b2b78cd4
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"

namespace phi {
namespace funcs {

inline void ExtractNCWHD(const phi::DDim &dims,
                         const DataLayout &data_layout,
                         int *N,
                         int *C,
                         int *H,
                         int *W,
                         int *D) {
  *N = dims[0];
  if (dims.size() == 2) {
    *C = dims[1];
    *H = 1;
    *W = 1;
    *D = 1;
  } else {
    *C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
    *H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
    *W = dims.size() > 3
             ? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2])
             : 1;
    *D = dims.size() > 4
             ? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3])
             : 1;
  }
}

}  // namespace funcs
}  // namespace phi
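ExtractNCWHD simply splits a tensor shape into batch, channel, and spatial extents according to the layout, defaulting trailing spatial dims to 1. A minimal usage sketch, assuming only the header above is available; the test harness and values are illustrative, not part of the commit:

#include <cstdio>

#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"

int main() {
  // The same 4-D shape read under both layouts.
  phi::DDim dims = phi::make_ddim({2, 3, 4, 5});
  int N, C, H, W, D;
  phi::funcs::ExtractNCWHD(dims, phi::DataLayout::kNCHW, &N, &C, &H, &W, &D);
  std::printf("NCHW: N=%d C=%d H=%d W=%d D=%d\n", N, C, H, W, D);  // 2 3 4 5 1
  phi::funcs::ExtractNCWHD(dims, phi::DataLayout::kNHWC, &N, &C, &H, &W, &D);
  std::printf("NHWC: N=%d C=%d H=%d W=%d D=%d\n", N, C, H, W, D);  // 2 5 3 4 1
  return 0;
}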
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
View file @ b2b78cd4
...
@@ -20,7 +20,7 @@
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/operators/norm_utils.cu.h"
-#include "paddle/fluid/operators/norm_utils.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/layout_utils.h"
...
@@ -351,7 +351,7 @@ void BatchNormGradRawKernel(const Context &ctx,
                                        x_dims.size(),
                                        x_dims));
  int N, C, H, W, D;
-  paddle::operators::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+  phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
  // init output
  if (d_x) {
...
paddle/phi/kernels/gpu/batch_norm_kernel.cu
View file @ b2b78cd4
...
@@ -27,7 +27,7 @@ namespace cub = hipcub;
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/fluid/operators/norm_utils.cu.h"
-#include "paddle/fluid/operators/norm_utils.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/layout_utils.h"
...
@@ -179,7 +179,7 @@ void BatchNormKernel(const Context &ctx,
  ctx.template Alloc<T>(y);
  int N, C, H, W, D;
-  paddle::operators::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+  phi::funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
  auto dtype = paddle::platform::CudnnDataType<T>::type;
...
paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
View file @ b2b78cd4
...
@@ -14,16 +14,15 @@
#include "paddle/phi/kernels/instance_norm_grad_kernel.h"
-#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/gpu/instance_norm_utils.h"

namespace phi {

template <typename T, int BlockDim>
static __global__ void GradComputeDX(const T *dy,
                                     const BatchNormParamType<T> *scale,
...
@@ -37,16 +36,13 @@ static __global__ void GradComputeDX(const T *dy,
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;
  BatchNormParamType<T> mean_val = mean[ncid];
  BatchNormParamType<T> inv_var_val = variance[ncid];
  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
  __shared__ BatchNormParamType<T> dy_sum_val;
  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;
  BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
  BatchNormParamType<T> dy_x_sub_mean_sum =
      static_cast<BatchNormParamType<T>>(0);
...
@@ -60,13 +56,11 @@ static __global__ void GradComputeDX(const T *dy,
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  dy_x_sub_mean_sum =
      BlockReduce(dy_x_sub_mean_storage).Reduce(dy_x_sub_mean_sum, cub::Sum());
  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
  }
  __syncthreads();
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    dx[i] =
        (static_cast<BatchNormParamType<T>>(dy[i]) -
...
@@ -77,6 +71,222 @@ static __global__ void GradComputeDX(const T *dy,
  }
}
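The double-grad kernels that follow all reuse the reduction pattern visible in GradComputeDX above: one thread block per (n, c) instance (so the grid size is N * C), a BlockDim-strided loop accumulating per-thread partials over the instance's sample_size elements, one cub::BlockReduce temporary storage per concurrent reduction, and a __syncthreads() before the shared-memory results are consumed. A stripped-down sketch of just that pattern, with a hypothetical kernel name and no Paddle dependencies:

#include <cub/cub.cuh>

// One block reduces x[b * sample_size .. (b + 1) * sample_size) for b = blockIdx.x.
template <typename T, int BlockDim>
__global__ void PerInstanceSum(const T *x, int sample_size, T *out) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage storage;
  int beg = blockIdx.x * sample_size + threadIdx.x;
  int end = (blockIdx.x + 1) * sample_size;
  T sum = 0;
  for (int i = beg; i < end; i += BlockDim) {
    sum += x[i];  // strided per-thread partial
  }
  // The reduced value is only valid in thread 0, which publishes it.
  sum = BlockReduce(storage).Reduce(sum, cub::Sum());
  if (threadIdx.x == 0) {
    out[blockIdx.x] = sum;
  }
}

// Launch sketch for N * C instances of H * W elements each:
//   PerInstanceSum<float, 512><<<N * C, 512>>>(x, H * W, out);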
static __device__ __forceinline__ float real_sqrt(float x) {
  return 1. / sqrtf(x);
}
static __device__ __forceinline__ double real_sqrt(double x) {
  return 1. / sqrt(x);
}

template <typename T, int BlockDim>
__global__ void DoubleGradComputeDX(const T *x,
                                    const T *mean,
                                    const T *variance,
                                    const T *ddx,
                                    const T *dy,
                                    const T *scale,
                                    const T *ddscale,
                                    int C,
                                    int sample_size,
                                    const double epsilon,
                                    T *dx) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;
  T mean_val = mean[ncid];
  T var_val = variance[ncid];
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T dy_sum_val;
  __shared__ T ddx_sum_val;
  __shared__ T dy_mul_ddx_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;
  T dy_sum = 0;
  T ddx_sum = 0;
  T dy_mul_ddx_sum = 0;
  T dy_mul_x_sub_mean_sum = 0;
  T ddx_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T ddx_i = ddx[i];
    T dy_i = dy[i];
    T tmp = x[i] - mean_val;
    dy_sum += dy_i;
    ddx_sum += ddx_i;
    dy_mul_ddx_sum += (ddx_i * dy_i);
    dy_mul_x_sub_mean_sum += (dy_i * tmp);
    ddx_mul_x_sub_mean_sum += (ddx_i * tmp);
  }
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  dy_mul_ddx_sum =
      BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());
  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    ddx_sum_val = ddx_sum;
    dy_mul_ddx_sum_val = dy_mul_ddx_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();
  if (ddx != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dx[i] +=
          ((x[i] - mean_val) * var_val * var_val * var_val / sample_size *
               (ddx_sum_val * dy_sum_val / sample_size - dy_mul_ddx_sum_val +
                3. * dy_mul_x_sub_mean_sum_val * var_val *
                    ddx_mul_x_sub_mean_sum_val * var_val / sample_size) +
           ddx_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (dy_sum_val / sample_size - dy[i]) +
           dy_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
               var_val * (ddx_sum_val / sample_size - ddx[i])) *
          scale[c];
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dx[i] += (dy[i] * var_val - dy_sum_val / sample_size * var_val -
                (x[i] - mean_val) * var_val * dy_mul_x_sub_mean_sum_val *
                    var_val / sample_size) *
               ddscale[c];
    }
  }
}
template <typename T, int BlockDim>
__global__ void DoubleGradComputeDDY(const T *x,
                                     const T *mean,
                                     const T *variance,
                                     const T *ddscale,
                                     const T *ddbias,
                                     const T *ddx,
                                     const T *scale,
                                     int C,
                                     int sample_size,
                                     const double epsilon,
                                     T *ddy) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;
  T mean_val = mean[ncid];
  T var_val = variance[ncid];
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage ddx_storage;
  __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
  __shared__ T ddx_sum_val;
  __shared__ T ddx_mul_x_sub_mean_sum_val;
  T ddx_sum = 0;
  T ddx_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T ddx_i = ddx[i];
    ddx_sum += ddx_i;
    ddx_mul_x_sub_mean_sum += (ddx_i * (x[i] - mean_val));
  }
  ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
  ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
                               .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum());
  if (threadIdx.x == 0) {
    ddx_sum_val = ddx_sum;
    ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum;
  }
  __syncthreads();
  if (ddx != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += scale[c] * var_val *
                (ddx[i] - ddx_sum_val / sample_size -
                 (x[i] - mean_val) * var_val * ddx_mul_x_sub_mean_sum_val *
                     var_val / sample_size);
    }
  }
  __syncthreads();
  if (ddscale != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += (x[i] - mean_val) * var_val * ddscale[c];
    }
  }
  __syncthreads();
  if (ddbias != nullptr) {
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      ddy[i] += ddbias[c];
    }
  }
}
template <typename T, int BlockDim>
__global__ void DoubleGradComputeDScale(const T *x,
                                        const T *mean,
                                        const T *variance,
                                        const T *ddx,
                                        const T *dy,
                                        int C,
                                        int sample_size,
                                        const double epsilon,
                                        T *dscale) {
  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
  int end_idx = (blockIdx.x + 1) * sample_size;
  int ncid = blockIdx.x;
  int c = ncid % C;
  T mean_val = mean[ncid];
  T var_val = variance[ncid];
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage dy_storage;
  __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
  __shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
  __shared__ T dy_sum_val;
  __shared__ T dy_mul_x_sub_mean_sum_val;
  T dy_sum = 0;
  T dy_mul_x_sub_mean_sum = 0;
  for (int i = beg_idx; i < end_idx; i += BlockDim) {
    T dy_i = dy[i];
    dy_sum += dy_i;
    dy_mul_x_sub_mean_sum += (dy_i * (x[i] - mean_val));
  }
  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
  dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
                              .Reduce(dy_mul_x_sub_mean_sum, cub::Sum());
  if (threadIdx.x == 0) {
    dy_sum_val = dy_sum;
    dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum;
  }
  __syncthreads();
  if (ddx != nullptr) {
    T dscale_tmp = 0;
    for (int i = beg_idx; i < end_idx; i += BlockDim) {
      dscale_tmp +=
          ddx[i] * var_val *
          (dy[i] - dy_sum_val / sample_size -
           dy_mul_x_sub_mean_sum_val * (x[i] - mean_val) * var_val * var_val /
               sample_size);
    }
    dscale_tmp = BlockReduce(dscale_tmp_storage).Reduce(dscale_tmp, cub::Sum());
    if (threadIdx.x == 0) {
      dscale[ncid] += dscale_tmp;
    }
    __syncthreads();
  }
}
template <typename T, typename Context>
void InstanceNormGradKernel(const Context &dev_ctx,
                            const DenseTensor &x,
...
@@ -94,8 +304,7 @@ void InstanceNormGradKernel(const Context &dev_ctx,
  const auto &x_dims = x.dims();
  int N, C, H, W, D;
-  paddle::operators::ExtractNCWHD(
-      x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
+  funcs::ExtractNCWHD(
+      x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
  int NxC = N * C;
  DenseTensor x_tmp, d_y_tmp;
...
@@ -303,12 +512,120 @@ void InstanceNormGradKernel(const Context &dev_ctx,
      paddle::platform::dynload::cudnnDestroyTensorDescriptor(in_param_desc_));
#endif
}

template <typename T, typename Context>
void InstanceNormDoubleGradKernel(const Context &dev_ctx,
                                  const DenseTensor &x,
                                  paddle::optional<const DenseTensor &> scale,
                                  const DenseTensor &saved_mean,
                                  const DenseTensor &saved_variance,
                                  const DenseTensor &dy,
                                  paddle::optional<const DenseTensor &> ddx,
                                  paddle::optional<const DenseTensor &> ddscale,
                                  paddle::optional<const DenseTensor &> ddbias,
                                  float epsilon_f,
                                  DenseTensor *dx,
                                  DenseTensor *dscale,
                                  DenseTensor *ddy) {
  const auto *Scale = scale.get_ptr();
  const auto *ddX = ddx.get_ptr();
  const auto *ddScale = ddscale.get_ptr();
  const auto *ddBias = ddbias.get_ptr();
  const double epsilon = static_cast<double>(epsilon_f);
  const T *x_data = x.data<T>();
  const T *dy_data = dy.data<T>();
  const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());
  const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data<T>());
  const T *ddbias_data = (ddScale == nullptr ? nullptr : ddBias->data<T>());
  const T *mean_data = saved_mean.data<T>();
  const T *variance_data = saved_variance.data<T>();
  phi::funcs::SetConstant<GPUContext, T> set_zero;
  auto &x_dims = x.dims();
  int N, C, H, W, D;
  funcs::ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
  int NxC = N * C;
  const int n = x.numel();
  int sample_size = n / N / C;
  DenseTensor scale_tmp;
  if (!Scale) {
    scale_tmp.Resize({C});
    dev_ctx.template Alloc<T>(&scale_tmp);
    set_zero(dev_ctx, &scale_tmp, static_cast<T>(1));
  }
  const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
  const int block = 512;
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / block, 1);
  const int grid = NxC;
  const int grid1 = (C + block - 1) / block;
  if (dx) {
    T *dx_data = dev_ctx.template Alloc<T>(dx);
    set_zero(dev_ctx, dx, static_cast<T>(0));
    DoubleGradComputeDX<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
        x_data,
        mean_data,
        variance_data,
        ddx_data,
        dy_data,
        scale_data,
        ddscale_data,
        C,
        sample_size,
        epsilon,
        dx_data);
  }
  if (dscale) {
    DenseTensor dscale_tmp;
    dscale_tmp.Resize({NxC});
    dev_ctx.template Alloc<T>(&dscale_tmp);
    set_zero(dev_ctx, &dscale_tmp, static_cast<T>(0));
    T *dscale_tmp_data = dscale_tmp.data<T>();
    T *dscale_data = dev_ctx.template Alloc<T>(dscale);
    set_zero(dev_ctx, dscale, static_cast<T>(0));
    DoubleGradComputeDScale<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
        x_data,
        mean_data,
        variance_data,
        ddx_data,
        dy_data,
        C,
        sample_size,
        epsilon,
        dscale_tmp_data);
    add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
        dscale_tmp.data<T>(), dscale->data<T>(), N, C);
  }
  if (ddy) {
    T *ddy_data = dev_ctx.template Alloc<T>(ddy);
    set_zero(dev_ctx, ddy, static_cast<T>(0));
    DoubleGradComputeDDY<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
        x_data,
        mean_data,
        variance_data,
        ddscale_data,
        ddbias_data,
        ddx_data,
        scale_data,
        C,
        sample_size,
        epsilon,
        ddy_data);
  }
}
}  // namespace phi

#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
PD_REGISTER_KERNEL(instance_norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormGradKernel,
                   float) {}
PD_REGISTER_KERNEL(instance_norm_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormDoubleGradKernel,
                   float) {}
#else
PD_REGISTER_KERNEL(instance_norm_grad,
                   GPU,
...
@@ -316,4 +633,10 @@ PD_REGISTER_KERNEL(instance_norm_grad,
                   phi::InstanceNormGradKernel,
                   float,
                   double) {}
PD_REGISTER_KERNEL(instance_norm_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::InstanceNormDoubleGradKernel,
                   float,
                   double) {}
#endif
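Two launch-configuration details in the file above are worth spelling out: every double-grad kernel runs with grid = N * C blocks of block = 512 threads, i.e. one block per (n, c) instance, matching the reduction pattern sketched earlier; and dscale takes two steps, with DoubleGradComputeDScale first accumulating a per-instance partial of length N * C into dscale_tmp, after which add_param (on grid1 = (C + block - 1) / block blocks) folds the N partials of each channel into the final C-length dscale.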
paddle/phi/kernels/gpu/instance_norm_kernel.cu
View file @ b2b78cd4
...
@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/instance_norm_kernel.h"
-#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/gpu/instance_norm_utils.h"

namespace phi {
...
@@ -51,8 +51,7 @@ void InstanceNormKernel(const Context &dev_ctx,
                                     "the size of X's dimensions is [%d]",
                                     x_dims.size()));
  int N, C, H, W, D;
-  paddle::operators::ExtractNCWHD(
-      x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
+  funcs::ExtractNCWHD(
+      x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
  int NxC = N * C;
  DenseTensor x_tmp;
  x_tmp.ShareDataWith(x).Resize({1, NxC, H, W, D});
...
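The ShareDataWith(x).Resize({1, NxC, H, W, D}) context line above is the core trick of the phi instance-norm kernels: x of shape [N, C, H, W, D] is viewed, without any copy, as a single batch of N * C channels, so per-channel batch-norm statistics over the reshaped tensor are exactly per-instance statistics over the original one.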
paddle/phi/kernels/instance_norm_grad_kernel.h
View file @ b2b78cd4
...
@@ -30,4 +30,19 @@ void InstanceNormGradKernel(const Context& dev_ctx,
                            DenseTensor* scale_grad,
                            DenseTensor* bias_grad);

template <typename T, typename Context>
void InstanceNormDoubleGradKernel(const Context& dev_ctx,
                                  const DenseTensor& x,
                                  paddle::optional<const DenseTensor&> scale,
                                  const DenseTensor& saved_mean,
                                  const DenseTensor& saved_variance,
                                  const DenseTensor& dy,
                                  paddle::optional<const DenseTensor&> ddx,
                                  paddle::optional<const DenseTensor&> ddscale,
                                  paddle::optional<const DenseTensor&> ddbias,
                                  float epsilon,
                                  DenseTensor* dx,
                                  DenseTensor* dscale,
                                  DenseTensor* ddy);

}  // namespace phi
paddle/phi/ops/compat/instance_norm_sig.cc
View file @ b2b78cd4
...
@@ -31,8 +31,26 @@ KernelSignature InstanceNormGradOpArgumentMapping(
                         {"epsilon"},
                         {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
}

KernelSignature InstanceNormDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("instance_norm_double_grad",
                         {"X",
                          "Scale",
                          "SavedMean",
                          "SavedVariance",
                          "DY",
                          "DDX",
                          "DDScale",
                          "DDBias"},
                         {"epsilon"},
                         {"DX", "DScale", "DDY"});
}

}  // namespace phi

PD_REGISTER_BASE_KERNEL_NAME(instance_norm_grad_grad, instance_norm_double_grad);

PD_REGISTER_ARG_MAPPING_FN(instance_norm, phi::InstanceNormOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(instance_norm_grad,
                           phi::InstanceNormGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(instance_norm_grad_grad,
                           phi::InstanceNormDoubleGradOpArgumentMapping);
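This argument-mapping file is what routes the legacy fluid operator to the new kernel: PD_REGISTER_BASE_KERNEL_NAME declares that the op instance_norm_grad_grad corresponds to the phi kernel instance_norm_double_grad, and the KernelSignature lists the op's inputs, attributes, and outputs in the order the kernel expects them. As a hedged illustration of the same pattern with hypothetical names (not real Paddle ops):

// Inside namespace phi, mirroring the mapping above; "foo" is made up.
KernelSignature FooDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("foo_double_grad",
                         {"X", "DDX"},  // op inputs, in kernel-argument order
                         {"axis"},      // op attributes
                         {"DDOut"});    // op outputs
}

// Outside the namespace: alias the op name and register the mapping.
PD_REGISTER_BASE_KERNEL_NAME(foo_grad_grad, foo_double_grad);
PD_REGISTER_ARG_MAPPING_FN(foo_grad_grad, phi::FooDoubleGradOpArgumentMapping);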