diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc
index 7d042264b6cd88c30114970551117560fe1f517b..8bb65963d87051132c8aadb67acdcca0bfb1a898 100644
--- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc
+++ b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc b/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
index 39d5ca3fe60687aa7ffc39e50c30c4a9ddee4903..08a19aa469078074e6f39f18d70b48c5638a0309 100644
--- a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
+++ b/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
index eeec8365dacfd8b7e785caab17b3a1a93de507b5..937f38137fb30e9f2d1456d0577f02421efd15d8 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
@@ -19,7 +19,6 @@
 namespace mindspore {
 namespace kernel {
-
 DropoutGpuFwdKernel::DropoutGpuFwdKernel()
     : cudnn_handle_(nullptr),
       is_null_input_(false),
diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc
index a95c1b78dd15d52f39ff8e19018fcd7ba5f7bca3..af957674076f04489eda302d2989380567c3975e 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc
@@ -18,7 +18,6 @@
 namespace mindspore {
 namespace kernel {
-
 MS_REG_GPU_KERNEL_ONE(BatchNormFold2,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
index 3e246f18f64eb7a61af8fc7f73ec41a859cb2a9f..beeeb12a9a663f2a67cc680fe42f7431c4d71739 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h
@@ -132,7 +132,6 @@ class BatchNormFold2GpuKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
 };
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc
index d5932f1984b983410600f34673a4feb8cf01738c..93862aeeddf6b40af474b61dff85dd0e3bbed473 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc
@@ -18,7 +18,6 @@
 namespace mindspore {
 namespace kernel {
-
 MS_REG_GPU_KERNEL_ONE(BatchNormFold2Grad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc
index b5fbfe4927052cf7532d1b15e73b913f02eb982d..4f968a0fa3284799bed416e54c9e4f86cfa4bea3 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc
@@ -18,7 +18,6 @@
 namespace mindspore {
 namespace kernel {
-
 MS_REG_GPU_KERNEL_ONE(BatchNormFold,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
index 7608ae5d3c55ca4ed905245a063be5d8f9fe9921..eeab872ab367b4b48c0678c74f9aca51cd55b6a7 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h
@@ -54,7 +54,6 @@ class CorrectionMulGpuKernel : public GpuKernel {
     }
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-
     if (input_shape.size() != 4) {
       MS_LOG(ERROR) << "CorrectionMulGpuKernel input shape needs (N,C,H,W).";
       return false;
diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc
index 211c515e025777f1ca849088d0e65692b7a7fa6a..28b5d56e68494dd9baf4906785f5000879078dd3 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc
@@ -19,7 +19,6 @@
 namespace mindspore {
 namespace kernel {
-
 MS_REG_GPU_KERNEL_ONE(CorrectionMulGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
index 2439826cc39457496e0742aa6686b6450047b492..29aeb3be139e7e16a839b37100027e5a714552eb 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h
@@ -61,7 +61,6 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
     }
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-
     if (input_shape.size() != 4) {
       MS_LOG(ERROR) << "CorrectionMulGradGpuKernel input shape needs (N,C,H,W).";
       return false;
diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
index 302ef8d99f657fd9c9e1e0a817fcd7f455bbc686..1da9f457a1d09681ff674eb5808391bb41d8108f 100644
--- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
@@ -114,6 +114,36 @@ void FakeQuantPerChannelGpuKernel::InitSizeLists() {
   workspace_size_list_.push_back(workspace_size_);
 }
 
+void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
+                                                              float *input_max, float *d_nudge_min, float *d_nudge_max,
+                                                              float *d_scale, uintptr_t stream_ptr) {
+  // Calculate the input min and max according to the parameters ema and ema_decay.
+  CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+  // control flow for quant_delay
+  if (global_step_ >= quant_delay_) {
+    // real launch
+    CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
+                              d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  } else {
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
+                              "Copy gpu memory failed.");
+  }
+  global_step_++;
+}
+
+void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min,
+                                                           float *input_max, float *d_nudge_min, float *d_nudge_max,
+                                                           float *d_scale, uintptr_t stream_ptr) {
+  // real launch
+  CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+  CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
+                            d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+}
+
 bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
                                           const std::vector<AddressPtr> &workspace,
                                           const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
@@ -126,11 +156,8 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
   if (input == nullptr) {
     MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
   }
-  if (input_min == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min is null.";
-  }
-  if (input_max == nullptr) {
-    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input max is null.";
+  if (input_min == nullptr || input_max == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null.";
   }
 
   // Allocate space for device copies
@@ -143,30 +170,11 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
                             "Malloc gpu memory failed");
   CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), sizeof(float) * channel_out_),
                             "Malloc gpu memory failed");
-  int total_size = input_size_ / sizeof(float);
-  bool symmetric = false;
+
   if (training_) {
-    // calculate the input min and max according by the parameter ema and ema_decay.
-    CalMinMaxPerChannel(input, input_min, input_max, total_size, channel_out_, ema_decay_, ema_,
-                        reinterpret_cast<cudaStream_t>(stream_ptr));
-    // control flow for quant_delay
-    if (global_step_ >= quant_delay_) {
-      // real launch
-      CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
-                         reinterpret_cast<cudaStream_t>(stream_ptr));
-      CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric,
-                                reinterpret_cast<cudaStream_t>(stream_ptr));
-    } else {
-      CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
-                                "Copy gpu memory failed.");
-    }
-    global_step_++;
+    CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
   } else {
-    // real launch
-    CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
-                       reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric,
-                              reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
   }
 
   // Cleanup
diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h
index faf8684fcac6bf10ba6153825e076be9c6eba761..8a1bb7293aafd3b1653c8c3dcec639922566154d 100755
--- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h
@@ -39,6 +39,11 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
   void InitSizeLists() override;
 
  private:
+  void CalFakeQuantizeForTraining(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
+                                  float *d_nudge_max, float *d_scale, uintptr_t stream_ptr);
+  void CalFakeQuantizeForInfer(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
+                               float *d_nudge_max, float *d_scale, uintptr_t stream_ptr);
+
   size_t input_size_;
   size_t min_size_;
   size_t max_size_;
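
Note on the refactor: Launch now only dispatches on training_, so the EMA min/max update and the quant_delay pass-through live entirely in CalFakeQuantizeForTraining, while CalFakeQuantizeForInfer just nudges the stored min/max and quantizes. The CUDA kernels CalNudgePerChannel and CalFakeQuantizePerChannel are not part of this diff; as a reference only, the per-channel nudge-and-quantize math they are expected to implement can be sketched on the host like below. Function and parameter names here (FakeQuantChannel, ch_min, ch_max) are illustrative, not code from the repository, and assume the standard asymmetric fake-quant formulation.

#include <algorithm>
#include <cmath>

// Host-side sketch of per-channel fake quantization for one channel of n elements.
// quant_min / quant_max are the integer grid bounds, e.g. 0 and 255 for 8 bits.
void FakeQuantChannel(const float *in, float *out, int n, float ch_min, float ch_max,
                      float quant_min, float quant_max) {
  if (ch_max <= ch_min) {  // degenerate range: pass the input through unchanged
    std::copy(in, in + n, out);
    return;
  }
  // Nudge the channel range so that zero maps exactly onto a grid point.
  float scale = (ch_max - ch_min) / (quant_max - quant_min);
  float zero_point_from_min = quant_min - ch_min / scale;
  float nudged_zero_point = std::round(std::min(std::max(zero_point_from_min, quant_min), quant_max));
  float nudged_min = (quant_min - nudged_zero_point) * scale;
  float nudged_max = (quant_max - nudged_zero_point) * scale;
  // Quantize then dequantize each element (fake quantization).
  for (int i = 0; i < n; ++i) {
    float clamped = std::min(std::max(in[i], nudged_min), nudged_max);
    out[i] = std::round((clamped - nudged_min) / scale) * scale + nudged_min;
  }
}

In the GPU kernel the same steps are split: CalNudgePerChannel produces d_nudge_min, d_nudge_max and d_scale per channel, and CalFakeQuantizePerChannel applies the clamp/round/dequantize loop element-wise on the stream passed through stream_ptr.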