Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ae50c37c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae50c37c
编写于
7月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3092 GPU add fusion: replace momentum cast
Merge pull request !3092 from VectorSL/momentum
上级
4e31cb9b
14017418
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
218 addition
and
71 deletion
+218
-71
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
...rc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
+20
-7
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
...c/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
+2
-2
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
...re/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
+6
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc
...end/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc
+16
-16
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h
...kend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h
+6
-6
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc
...kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc
+5
-5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h
.../kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h
+5
-5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc
...src/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc
+36
-27
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h
...csrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h
+3
-3
mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc
...src/backend/optimizer/gpu/replace_momentum_cast_fusion.cc
+63
-0
mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h
...csrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h
+46
-0
mindspore/ccsrc/backend/session/gpu_session.cc
mindspore/ccsrc/backend/session/gpu_session.cc
+10
-0
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu
浏览文件 @
ae50c37c
...
@@ -15,9 +15,9 @@
...
@@ -15,9 +15,9 @@
*/
*/
#include "momentum_impl.cuh"
#include "momentum_impl.cuh"
template
<
typename
T
,
typename
S
>
template
<
typename
T
,
typename
S
,
typename
G
>
__global__
void
MomentumUpdateVariableKernel
(
const
size_t
size
,
T
*
variable
,
T
*
accumulation
,
const
S
*
learning_rate
,
__global__
void
MomentumUpdateVariableKernel
(
const
size_t
size
,
T
*
variable
,
T
*
accumulation
,
const
S
*
learning_rate
,
const
T
*
gradient
,
const
S
*
momentum
)
{
const
G
*
gradient
,
const
S
*
momentum
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
size
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
size
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
accumulation
[
i
]
=
momentum
[
0
]
*
accumulation
[
i
]
+
gradient
[
i
];
accumulation
[
i
]
=
momentum
[
0
]
*
accumulation
[
i
]
+
gradient
[
i
];
variable
[
i
]
-=
learning_rate
[
0
]
*
accumulation
[
i
];
variable
[
i
]
-=
learning_rate
[
0
]
*
accumulation
[
i
];
...
@@ -34,19 +34,32 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable,
...
@@ -34,19 +34,32 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable,
}
}
return
;
return
;
}
}
template
<
typename
T
,
typename
S
>
template
<
>
void
MomentumUpdateVariable
(
const
size_t
size
,
T
*
variable
,
T
*
accumulation
,
const
S
*
learning_rate
,
const
T
*
gradient
,
__global__
void
MomentumUpdateVariableKernel
(
const
size_t
size
,
float
*
variable
,
float
*
accumulation
,
const
float
*
learning_rate
,
const
half
*
gradient
,
const
float
*
momentum
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
size
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
accumulation
[
i
]
=
momentum
[
0
]
*
accumulation
[
i
]
+
__half2float
(
gradient
[
i
]);
variable
[
i
]
-=
learning_rate
[
0
]
*
accumulation
[
i
];
}
return
;
}
template
<
typename
T
,
typename
S
,
typename
G
>
void
MomentumUpdateVariable
(
const
size_t
size
,
T
*
variable
,
T
*
accumulation
,
const
S
*
learning_rate
,
const
G
*
gradient
,
const
S
*
momentum
,
cudaStream_t
cuda_stream
)
{
const
S
*
momentum
,
cudaStream_t
cuda_stream
)
{
MomentumUpdateVariableKernel
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
size
,
variable
,
accumulation
,
MomentumUpdateVariableKernel
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
size
,
variable
,
accumulation
,
learning_rate
,
gradient
,
momentum
);
learning_rate
,
gradient
,
momentum
);
return
;
return
;
}
}
template
void
MomentumUpdateVariable
<
float
,
float
>(
const
size_t
size
,
float
*
variable
,
float
*
accumulation
,
template
void
MomentumUpdateVariable
<
float
,
float
,
float
>(
const
size_t
size
,
float
*
variable
,
float
*
accumulation
,
const
float
*
learning_rate
,
const
float
*
gradient
,
const
float
*
learning_rate
,
const
float
*
gradient
,
const
float
*
momentum
,
cudaStream_t
cuda_stream
);
const
float
*
momentum
,
cudaStream_t
cuda_stream
);
template
void
MomentumUpdateVariable
<
half
,
half
>(
const
size_t
size
,
half
*
variable
,
half
*
accumulation
,
template
void
MomentumUpdateVariable
<
half
,
half
,
half
>(
const
size_t
size
,
half
*
variable
,
half
*
accumulation
,
const
half
*
learning_rate
,
const
half
*
gradient
,
const
half
*
learning_rate
,
const
half
*
gradient
,
const
half
*
momentum
,
cudaStream_t
cuda_stream
);
const
half
*
momentum
,
cudaStream_t
cuda_stream
);
template
void
MomentumUpdateVariable
<
half
,
float
>(
const
size_t
size
,
half
*
variable
,
half
*
accumulation
,
template
void
MomentumUpdateVariable
<
half
,
float
,
half
>(
const
size_t
size
,
half
*
variable
,
half
*
accumulation
,
const
float
*
learning_rate
,
const
half
*
gradient
,
const
float
*
momentum
,
cudaStream_t
cuda_stream
);
template
void
MomentumUpdateVariable
<
float
,
float
,
half
>(
const
size_t
size
,
float
*
variable
,
float
*
accumulation
,
const
float
*
learning_rate
,
const
half
*
gradient
,
const
float
*
learning_rate
,
const
half
*
gradient
,
const
float
*
momentum
,
cudaStream_t
cuda_stream
);
const
float
*
momentum
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh
浏览文件 @
ae50c37c
...
@@ -18,8 +18,8 @@
...
@@ -18,8 +18,8 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
#include "runtime/device/gpu/cuda_common.h"
#include "runtime/device/gpu/cuda_common.h"
template
<
typename
T
,
typename
S
>
template
<
typename
T
,
typename
S
,
typename
G
>
void
MomentumUpdateVariable
(
const
size_t
size
,
T
*
variable
,
T
*
accumulation
,
const
S
*
learning_rate
,
const
T
*
gradient
,
void
MomentumUpdateVariable
(
const
size_t
size
,
T
*
variable
,
T
*
accumulation
,
const
S
*
learning_rate
,
const
G
*
gradient
,
const
S
*
momentum
,
cudaStream_t
cuda_stream
);
const
S
*
momentum
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
浏览文件 @
ae50c37c
...
@@ -88,6 +88,12 @@ class GpuKernelRegister {
...
@@ -88,6 +88,12 @@ class GpuKernelRegister {
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \
[]() { return new OPCLASS<T, S>(); });
[]() { return new OPCLASS<T, S>(); });
// Register a mixed-accuracy GPU kernel whose class template takes three type parameters.
//   OPNAME  - operator name; stringized (#OPNAME) for the registrar and pasted into its identifier.
//   ATTR    - kernel attribute object forwarded to GpuKernelRegister.
//   OPCLASS - kernel class template; OPCLASS<T, S, G> must derive from GpuKernel (checked by static_assert).
//   T, S, G - template arguments; they are token-pasted into the registrar's name, so each must be a
//             single identifier token.
// The creator lambda handed to GpuKernelRegister returns a raw `new OPCLASS<T, S, G>()`.
#define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg( \
#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc
浏览文件 @
ae50c37c
...
@@ -34,15 +34,15 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
...
@@ -34,15 +34,15 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
MS_REG_GPU_KERNEL_ONE
(
FusedBatchNorm
,
MS_REG_GPU_KERNEL_ONE
(
FusedBatchNorm
,
KernelAttr
()
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
),
.
AddOutputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
),
FusedBatchNormGpuKernel
,
half
)
FusedBatchNormGpuKernel
,
half
)
MS_REG_GPU_KERNEL_ONE
(
BatchNorm
,
MS_REG_GPU_KERNEL_ONE
(
BatchNorm
,
KernelAttr
()
KernelAttr
()
...
@@ -60,15 +60,15 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
...
@@ -60,15 +60,15 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
MS_REG_GPU_KERNEL_ONE
(
BatchNorm
,
MS_REG_GPU_KERNEL_ONE
(
BatchNorm
,
KernelAttr
()
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
),
.
AddOutputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
),
FusedBatchNormGpuKernel
,
half
)
FusedBatchNormGpuKernel
,
half
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h
浏览文件 @
ae50c37c
...
@@ -56,17 +56,17 @@ class FusedBatchNormGpuKernel : public GpuKernel {
...
@@ -56,17 +56,17 @@ class FusedBatchNormGpuKernel : public GpuKernel {
return
true
;
return
true
;
}
}
auto
x
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
auto
x
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
auto
scale
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
auto
scale
=
GetDeviceAddress
<
float
>
(
inputs
,
1
);
auto
bias
=
GetDeviceAddress
<
T
>
(
inputs
,
2
);
auto
bias
=
GetDeviceAddress
<
float
>
(
inputs
,
2
);
auto
runing_mean
=
GetDeviceAddress
<
T
>
(
inputs
,
3
);
auto
runing_mean
=
GetDeviceAddress
<
float
>
(
inputs
,
3
);
auto
runnig_variance
=
GetDeviceAddress
<
T
>
(
inputs
,
4
);
auto
runnig_variance
=
GetDeviceAddress
<
float
>
(
inputs
,
4
);
auto
y
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
auto
y
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
const
float
alpha
=
1
;
const
float
alpha
=
1
;
const
float
beta
=
0
;
const
float
beta
=
0
;
if
(
is_train_
)
{
if
(
is_train_
)
{
auto
save_mean
=
GetDeviceAddress
<
T
>
(
outputs
,
3
);
auto
save_mean
=
GetDeviceAddress
<
float
>
(
outputs
,
3
);
auto
save_variance
=
GetDeviceAddress
<
T
>
(
outputs
,
4
);
auto
save_variance
=
GetDeviceAddress
<
float
>
(
outputs
,
4
);
CHECK_CUDNN_RET_WITH_EXCEPT
(
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnBatchNormalizationForwardTraining
(
handle_
,
mode_
,
&
alpha
,
&
beta
,
x_desc_
,
x
,
y_desc_
,
y
,
cudnnBatchNormalizationForwardTraining
(
handle_
,
mode_
,
&
alpha
,
&
beta
,
x_desc_
,
x
,
y_desc_
,
y
,
scale_bias_mean_var_desc_
,
scale
,
bias
,
exp_avg_factor_
,
runing_mean
,
scale_bias_mean_var_desc_
,
scale
,
bias
,
exp_avg_factor_
,
runing_mean
,
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc
浏览文件 @
ae50c37c
...
@@ -33,12 +33,12 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
...
@@ -33,12 +33,12 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
KernelAttr
()
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
),
.
AddOutputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
),
FusedBatchNormGradGpuKernel
,
half
)
FusedBatchNormGradGpuKernel
,
half
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h
浏览文件 @
ae50c37c
...
@@ -55,12 +55,12 @@ class FusedBatchNormGradGpuKernel : public GpuKernel {
...
@@ -55,12 +55,12 @@ class FusedBatchNormGradGpuKernel : public GpuKernel {
}
}
auto
dy
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
auto
dy
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
auto
x
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
auto
x
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
auto
scale
=
GetDeviceAddress
<
T
>
(
inputs
,
2
);
auto
scale
=
GetDeviceAddress
<
float
>
(
inputs
,
2
);
auto
save_mean
=
GetDeviceAddress
<
T
>
(
inputs
,
3
);
auto
save_mean
=
GetDeviceAddress
<
float
>
(
inputs
,
3
);
auto
save_variance
=
GetDeviceAddress
<
T
>
(
inputs
,
4
);
auto
save_variance
=
GetDeviceAddress
<
float
>
(
inputs
,
4
);
auto
dx
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
auto
dx
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
auto
bn_scale
=
GetDeviceAddress
<
T
>
(
outputs
,
1
);
auto
bn_scale
=
GetDeviceAddress
<
float
>
(
outputs
,
1
);
auto
bn_bias
=
GetDeviceAddress
<
T
>
(
outputs
,
2
);
auto
bn_bias
=
GetDeviceAddress
<
float
>
(
outputs
,
2
);
const
float
alpha_data_diff
=
1
;
const
float
alpha_data_diff
=
1
;
const
float
beta_data_diff
=
0
;
const
float
beta_data_diff
=
0
;
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc
浏览文件 @
ae50c37c
...
@@ -18,32 +18,41 @@
...
@@ -18,32 +18,41 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
kernel
{
namespace
kernel
{
MS_REG_GPU_KERNEL_TWO
(
ApplyMomentum
,
MS_REG_GPU_KERNEL_THREE
(
ApplyMomentum
,
KernelAttr
()
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
),
.
AddOutputAttr
(
kNumberTypeFloat32
),
MomentumGpuKernel
,
float
,
float
)
MomentumGpuKernel
,
float
,
float
,
float
)
MS_REG_GPU_KERNEL_TWO
(
ApplyMomentum
,
MS_REG_GPU_KERNEL_THREE
(
ApplyMomentum
,
KernelAttr
()
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddOutputAttr
(
kNumberTypeFloat16
),
.
AddOutputAttr
(
kNumberTypeFloat16
),
MomentumGpuKernel
,
half
,
half
)
MomentumGpuKernel
,
half
,
half
,
half
)
MS_REG_GPU_KERNEL_TWO
(
ApplyMomentum
,
MS_REG_GPU_KERNEL_THREE
(
ApplyMomentum
,
KernelAttr
()
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat16
),
.
AddOutputAttr
(
kNumberTypeFloat16
),
MomentumGpuKernel
,
half
,
float
)
MomentumGpuKernel
,
half
,
float
,
half
)
MS_REG_GPU_KERNEL_THREE
(
ApplyMomentum
,
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat16
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
),
MomentumGpuKernel
,
float
,
float
,
half
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h
浏览文件 @
ae50c37c
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace
mindspore
{
namespace
mindspore
{
namespace
kernel
{
namespace
kernel
{
template
<
typename
T
,
typename
S
>
template
<
typename
T
,
typename
S
,
typename
G
>
class
MomentumGpuKernel
:
public
GpuKernel
{
class
MomentumGpuKernel
:
public
GpuKernel
{
public:
public:
MomentumGpuKernel
()
MomentumGpuKernel
()
...
@@ -38,7 +38,7 @@ class MomentumGpuKernel : public GpuKernel {
...
@@ -38,7 +38,7 @@ class MomentumGpuKernel : public GpuKernel {
T
*
variable
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
T
*
variable
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
T
*
accumulation
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
T
*
accumulation
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
S
*
learning_rate
=
GetDeviceAddress
<
S
>
(
inputs
,
2
);
S
*
learning_rate
=
GetDeviceAddress
<
S
>
(
inputs
,
2
);
T
*
gradient
=
GetDeviceAddress
<
T
>
(
inputs
,
3
);
G
*
gradient
=
GetDeviceAddress
<
G
>
(
inputs
,
3
);
S
*
momentum
=
GetDeviceAddress
<
S
>
(
inputs
,
4
);
S
*
momentum
=
GetDeviceAddress
<
S
>
(
inputs
,
4
);
MomentumUpdateVariable
(
inputs
[
0
]
->
size
/
sizeof
(
T
),
variable
,
accumulation
,
learning_rate
,
gradient
,
momentum
,
MomentumUpdateVariable
(
inputs
[
0
]
->
size
/
sizeof
(
T
),
variable
,
accumulation
,
learning_rate
,
gradient
,
momentum
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
...
@@ -54,7 +54,7 @@ class MomentumGpuKernel : public GpuKernel {
...
@@ -54,7 +54,7 @@ class MomentumGpuKernel : public GpuKernel {
variable_size_
=
sizeof
(
T
);
variable_size_
=
sizeof
(
T
);
accumulation_size_
=
sizeof
(
T
);
accumulation_size_
=
sizeof
(
T
);
learning_rate_size_
=
sizeof
(
S
);
learning_rate_size_
=
sizeof
(
S
);
gradient_size_
=
sizeof
(
T
);
gradient_size_
=
sizeof
(
G
);
momentum_size_
=
sizeof
(
S
);
momentum_size_
=
sizeof
(
S
);
auto
variable_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
auto
variable_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
...
...
mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc
0 → 100644
浏览文件 @
ae50c37c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace
mindspore
{
namespace
opt
{
// Builds the pattern this pass matches:
//   ApplyMomentum(var_, acc_, lr_, Cast(grad_), mom_)
// i.e. an ApplyMomentum node whose gradient operand is produced by a Cast.
const BaseRef ReplaceMomentumCastFusion::DefinePattern() const {
  // Inner sub-pattern: the Cast applied to the gradient.
  VectorRef cast_on_grad({prim::kPrimCast, grad_});
  // Outer pattern: ApplyMomentum consuming that Cast as its gradient operand.
  VectorRef apply_momentum({prim::kPrimApplyMomentum, var_, acc_, lr_, cast_on_grad, mom_});
  return apply_momentum;
}
// Rewrites a matched ApplyMomentum(var, acc, lr, Cast(grad), mom) node so that it
// consumes `grad` directly, then re-applies the node's inferred output types/shapes.
//
// @param graph  Function graph that owns `node`; its manager performs the replacement.
// @param node   The matched ApplyMomentum node.
// @param equiv  Pattern-match bindings (only null-checked here).
// @return       `node`, after the Cast feeding its gradient input has been replaced.
const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                    const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);

  // Position of the gradient operand among ApplyMomentum's inputs (var, acc, lr, grad, mom).
  constexpr size_t kGradInputIndex = 3;
  // Output slot the original code patched with the pre-cast data type.
  constexpr size_t kPatchedOutputIndex = 3;

  auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), kGradInputIndex);
  // Validate the Cast node before it is dereferenced to fetch its own input
  // (the check used to come after that use).
  MS_EXCEPTION_IF_NULL(grad_cast);
  auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
  MS_EXCEPTION_IF_NULL(grad);

  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Bypass the Cast: every user of the Cast now reads the un-cast gradient.
  manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad));

  // Snapshot the node's current inferred output types and shapes.
  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
  const auto output_num = AnfAlgo::GetOutputTensorNum(node);
  outputs_type.reserve(output_num);
  outputs_shape.reserve(output_num);
  for (size_t i = 0; i < output_num; ++i) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i));
    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i));
  }

  // The original wrote outputs_type[3] unconditionally, which is an out-of-bounds
  // write whenever the node has fewer than four outputs. Only patch the slot when
  // it actually exists.
  // NOTE(review): confirm whether this assignment is needed for ApplyMomentum at all.
  if (kPatchedOutputIndex < outputs_type.size()) {
    outputs_type[kPatchedOutputIndex] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
  }

  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get());
  return node;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.h
0 → 100644
浏览文件 @
ae50c37c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace
mindspore
{
namespace
opt
{
class
ReplaceMomentumCastFusion
:
public
PatternProcessPass
{
public:
explicit
ReplaceMomentumCastFusion
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"replace_momentum_cast"
,
multigraph
)
{
var_
=
std
::
make_shared
<
Var
>
();
acc_
=
std
::
make_shared
<
Var
>
();
lr_
=
std
::
make_shared
<
Var
>
();
grad_
=
std
::
make_shared
<
Var
>
();
mom_
=
std
::
make_shared
<
Var
>
();
}
~
ReplaceMomentumCastFusion
()
override
=
default
;
const
BaseRef
DefinePattern
()
const
override
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
private:
VarPtr
var_
;
VarPtr
acc_
;
VarPtr
lr_
;
VarPtr
grad_
;
VarPtr
mom_
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
mindspore/ccsrc/backend/session/gpu_session.cc
浏览文件 @
ae50c37c
...
@@ -25,6 +25,11 @@
...
@@ -25,6 +25,11 @@
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
#include "backend/optimizer/gpu/adam_fusion.h"
#include "backend/optimizer/gpu/adam_fusion.h"
#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
#include "backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h"
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include "backend/optimizer/gpu/replace_addn_fusion.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "predict/predict.h"
#include "predict/predict.h"
#include "common/utils.h"
#include "common/utils.h"
...
@@ -59,6 +64,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
...
@@ -59,6 +64,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
auto
pm
=
std
::
make_shared
<
opt
::
PassManager
>
();
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamWeightDecayFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamWeightDecayFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
AdamFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceBNCastFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceBNGradCastFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceBNGradCast2Fusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceMomentumCastFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceAddNFusion
>
());
optimizer
->
AddPassManager
(
pm
);
optimizer
->
AddPassManager
(
pm
);
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
(
void
)
optimizer
->
Optimize
(
kernel_graph
);
kernel_graph
->
SetExecOrderByDefault
();
kernel_graph
->
SetExecOrderByDefault
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录