Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
58b013c3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
58b013c3
编写于
4月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!363 clear the warmming scan by package
Merge pull request !363 from SanjayChan/labao
上级
0a06d5e8
b77f41d6
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
42 addition
and
37 deletion
+42
-37
mindspore/ccsrc/device/gpu/gpu_memory_manager.cc
mindspore/ccsrc/device/gpu/gpu_memory_manager.cc
+1
-1
mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
+1
-1
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc
...pore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
...spore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
+0
-1
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc
...ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc
...spore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
+0
-1
mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc
.../ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc
+0
-1
mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
...e/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
+0
-1
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
...src/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
+35
-27
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h
...csrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h
+5
-0
未找到文件。
mindspore/ccsrc/device/gpu/gpu_memory_manager.cc
浏览文件 @
58b013c3
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
浏览文件 @
58b013c3
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
浏览文件 @
58b013c3
...
...
@@ -19,7 +19,6 @@
namespace
mindspore
{
namespace
kernel
{
DropoutGpuFwdKernel
::
DropoutGpuFwdKernel
()
:
cudnn_handle_
(
nullptr
),
is_null_input_
(
false
),
...
...
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc
浏览文件 @
58b013c3
...
...
@@ -18,7 +18,6 @@
namespace
mindspore
{
namespace
kernel
{
MS_REG_GPU_KERNEL_ONE
(
BatchNormFold2
,
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
...
...
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
浏览文件 @
58b013c3
...
...
@@ -132,7 +132,6 @@ class BatchNormFold2GpuKernel : public GpuKernel {
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
};
}
// namespace kernel
}
// namespace mindspore
...
...
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc
浏览文件 @
58b013c3
...
...
@@ -18,7 +18,6 @@
namespace
mindspore
{
namespace
kernel
{
MS_REG_GPU_KERNEL_ONE
(
BatchNormFold2Grad
,
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
...
...
mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc
浏览文件 @
58b013c3
...
...
@@ -18,7 +18,6 @@
namespace
mindspore
{
namespace
kernel
{
MS_REG_GPU_KERNEL_ONE
(
BatchNormFold
,
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
...
...
mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
浏览文件 @
58b013c3
...
...
@@ -54,7 +54,6 @@ class CorrectionMulGpuKernel : public GpuKernel {
}
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
if
(
input_shape
.
size
()
!=
4
)
{
MS_LOG
(
ERROR
)
<<
"CorrectionMulGpuKernel input shape needs (N,C,H,W)."
;
return
false
;
...
...
mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc
浏览文件 @
58b013c3
...
...
@@ -19,7 +19,6 @@
namespace
mindspore
{
namespace
kernel
{
MS_REG_GPU_KERNEL_ONE
(
CorrectionMulGrad
,
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
...
...
mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
浏览文件 @
58b013c3
...
...
@@ -61,7 +61,6 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
}
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
if
(
input_shape
.
size
()
!=
4
)
{
MS_LOG
(
ERROR
)
<<
"CorrectionMulGradGpuKernel input shape needs (N,C,H,W)."
;
return
false
;
...
...
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
浏览文件 @
58b013c3
...
...
@@ -114,6 +114,36 @@ void FakeQuantPerChannelGpuKernel::InitSizeLists() {
workspace_size_list_
.
push_back
(
workspace_size_
);
}
void
FakeQuantPerChannelGpuKernel
::
CalFakeQuantizeForTraining
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
d_nudge_min
,
float
*
d_nudge_max
,
float
*
d_scale
,
uintptr_t
stream_ptr
)
{
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel
(
input
,
input_min
,
input_max
,
input_size_
/
sizeof
(
float
),
channel_out_
,
ema_decay_
,
ema_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
// control flow for quant_delay
if
(
global_step_
>=
quant_delay_
)
{
// real launch
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
channel_out_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantizePerChannel
(
input
,
output
,
input_size_
/
sizeof
(
float
),
channel_out_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
else
{
CHECK_CUDA_RET_WITH_ERROR
(
cudaMemcpy
(
output
,
input
,
input_size_
,
cudaMemcpyDeviceToDevice
),
"Copy gpu memory failed."
);
}
global_step_
++
;
}
void
FakeQuantPerChannelGpuKernel
::
CalFakeQuantizeForInfer
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
d_nudge_min
,
float
*
d_nudge_max
,
float
*
d_scale
,
uintptr_t
stream_ptr
)
{
// real launch
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
channel_out_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantizePerChannel
(
input
,
output
,
input_size_
/
sizeof
(
float
),
channel_out_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
bool
FakeQuantPerChannelGpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
{
...
...
@@ -126,11 +156,8 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
if
(
input
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantPerChannelGpuKernel input is null."
;
}
if
(
input_min
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantPerChannelGpuKernel input min is null."
;
}
if
(
input_max
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantPerChannelGpuKernel input max is null."
;
if
(
input_min
==
nullptr
||
input_max
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantPerChannelGpuKernel input min or max is null."
;
}
// Allocate space for device copies
...
...
@@ -143,30 +170,11 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
"Malloc gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_nudge_max
),
sizeof
(
float
)
*
channel_out_
),
"Malloc gpu memory failed"
);
int
total_size
=
input_size_
/
sizeof
(
float
);
bool
symmetric
=
false
;
if
(
training_
)
{
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel
(
input
,
input_min
,
input_max
,
total_size
,
channel_out_
,
ema_decay_
,
ema_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
// control flow for quant_delay
if
(
global_step_
>=
quant_delay_
)
{
// real launch
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
channel_out_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantizePerChannel
(
input
,
output
,
total_size
,
channel_out_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
symmetric
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
else
{
CHECK_CUDA_RET_WITH_ERROR
(
cudaMemcpy
(
output
,
input
,
input_size_
,
cudaMemcpyDeviceToDevice
),
"Copy gpu memory failed."
);
}
global_step_
++
;
CalFakeQuantizeForTraining
(
input
,
output
,
input_min
,
input_max
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
stream_ptr
);
}
else
{
// real launch
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
channel_out_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantizePerChannel
(
input
,
output
,
total_size
,
channel_out_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
symmetric
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantizeForInfer
(
input
,
output
,
input_min
,
input_max
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
stream_ptr
);
}
// Cleanup
...
...
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h
浏览文件 @
58b013c3
...
...
@@ -39,6 +39,11 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
void
InitSizeLists
()
override
;
private:
void
CalFakeQuantizeForTraining
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
d_nudge_min
,
float
*
d_nudge_max
,
float
*
d_scale
,
uintptr_t
stream_ptr
);
void
CalFakeQuantizeForInfer
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
d_nudge_min
,
float
*
d_nudge_max
,
float
*
d_scale
,
uintptr_t
stream_ptr
);
size_t
input_size_
;
size_t
min_size_
;
size_t
max_size_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录