提交 14017418 编写于 作者: V VectorSL

gpu add fusion: replace momentum cast

上级 2f2dc390
......@@ -15,9 +15,9 @@
*/
#include "momentum_impl.cuh"
template <typename T, typename S>
template <typename T, typename S, typename G>
__global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate,
const T *gradient, const S *momentum) {
const G *gradient, const S *momentum) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
variable[i] -= learning_rate[0] * accumulation[i];
......@@ -34,19 +34,32 @@ __global__ void MomentumUpdateVariableKernel(const size_t size, half *variable,
}
return;
}
template <typename T, typename S>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient,
// Momentum update for float variable/accumulation buffers fed with half-precision gradients.
// Every gradient element is widened with __half2float before it is accumulated:
//   accumulation[i] = momentum[0] * accumulation[i] + float(gradient[i])
//   variable[i]    -= learning_rate[0] * accumulation[i]
// Only element 0 of learning_rate and momentum is read.
// The loop advances by blockDim.x * gridDim.x, so only the x dimension of the launch
// configuration is used to cover the `size` elements.
template <>
__global__ void MomentumUpdateVariableKernel(const size_t size, float *variable, float *accumulation,
                                             const float *learning_rate, const half *gradient,
                                             const float *momentum) {
  // blockIdx, blockDim and gridDim hold 32-bit unsigned values. Widen one operand before
  // multiplying so that neither the start index nor the stride can wrap around (a wrapped
  // stride of 0 would make the loop spin forever).
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < size; i += stride) {
    accumulation[i] = momentum[0] * accumulation[i] + __half2float(gradient[i]);
    variable[i] -= learning_rate[0] * accumulation[i];
  }
}
// Host-side launcher for MomentumUpdateVariableKernel.
//   size          - number of elements to update
//   variable      - buffer updated in place (type T)
//   accumulation  - momentum accumulator updated in place (type T)
//   learning_rate - learning-rate buffer (type S)
//   gradient      - gradient buffer (type G, may differ from T)
//   momentum      - momentum-coefficient buffer (type S)
//   cuda_stream   - stream the kernel is enqueued on
// The grid and block sizes come from GET_BLOCKS(size) and GET_THREADS; no dynamic shared
// memory is requested. The launch is not followed by any error check or synchronization here.
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
                            const S *momentum, cudaStream_t cuda_stream) {
  const auto grid_dim = GET_BLOCKS(size);
  const auto block_dim = GET_THREADS;
  MomentumUpdateVariableKernel<<<grid_dim, block_dim, 0, cuda_stream>>>(size, variable, accumulation, learning_rate,
                                                                        gradient, momentum);
}
template void MomentumUpdateVariable<float, float>(const size_t size, float *variable, float *accumulation,
template void MomentumUpdateVariable<float, float, float>(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const float *gradient,
const float *momentum, cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, half>(const size_t size, half *variable, half *accumulation,
template void MomentumUpdateVariable<half, half, half>(const size_t size, half *variable, half *accumulation,
const half *learning_rate, const half *gradient,
const half *momentum, cudaStream_t cuda_stream);
template void MomentumUpdateVariable<half, float>(const size_t size, half *variable, half *accumulation,
template void MomentumUpdateVariable<half, float, half>(const size_t size, half *variable, half *accumulation,
const float *learning_rate, const half *gradient,
const float *momentum, cudaStream_t cuda_stream);
// Explicit instantiation for the mixed-precision case: float variable, accumulation,
// learning rate and momentum, combined with a half-precision gradient.
template void MomentumUpdateVariable<float, float, half>(const size_t size, float *variable, float *accumulation,
const float *learning_rate, const half *gradient,
const float *momentum, cudaStream_t cuda_stream);
......@@ -18,8 +18,8 @@
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient,
// Enqueues the momentum update kernel for `size` elements on `cuda_stream`.
//   T - element type of `variable` and `accumulation` (both are written)
//   S - element type of the `learning_rate` and `momentum` buffers
//   G - element type of `gradient`; it is a separate parameter so the gradient may use a
//       different precision than the variable
template <typename T, typename S, typename G>
void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const G *gradient,
const S *momentum, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
......@@ -88,6 +88,12 @@ class GpuKernelRegister {
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \
[]() { return new OPCLASS<T, S>(); });
// Registers a mixed-precision GPU kernel whose class template takes three type parameters.
// The static_assert rejects any OPCLASS<T, S, G> that does not derive from GpuKernel.
// The static GpuKernelRegister object is constructed with the stringified OPNAME, ATTR and a
// factory lambda returning `new OPCLASS<T, S, G>()`.
// T, S and G are token-pasted into the name of that static object, so each must be a single
// identifier token.
#define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \
static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S, G>>::value, " must be base of GpuKernel"); \
static const GpuKernelRegister g_##OPNAME##_##T##_##S##_##G##_gpu_kernel_reg( \
#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); });
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_
......@@ -34,15 +34,15 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(BatchNorm,
KernelAttr()
......@@ -60,15 +60,15 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
MS_REG_GPU_KERNEL_ONE(BatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGpuKernel, half)
} // namespace kernel
} // namespace mindspore
......@@ -56,17 +56,17 @@ class FusedBatchNormGpuKernel : public GpuKernel {
return true;
}
auto x = GetDeviceAddress<T>(inputs, 0);
auto scale = GetDeviceAddress<T>(inputs, 1);
auto bias = GetDeviceAddress<T>(inputs, 2);
auto runing_mean = GetDeviceAddress<T>(inputs, 3);
auto runnig_variance = GetDeviceAddress<T>(inputs, 4);
auto scale = GetDeviceAddress<float>(inputs, 1);
auto bias = GetDeviceAddress<float>(inputs, 2);
auto runing_mean = GetDeviceAddress<float>(inputs, 3);
auto runnig_variance = GetDeviceAddress<float>(inputs, 4);
auto y = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
if (is_train_) {
auto save_mean = GetDeviceAddress<T>(outputs, 3);
auto save_variance = GetDeviceAddress<T>(outputs, 4);
auto save_mean = GetDeviceAddress<float>(outputs, 3);
auto save_variance = GetDeviceAddress<float>(outputs, 4);
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y,
scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean,
......
......@@ -33,12 +33,12 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore
......@@ -55,12 +55,12 @@ class FusedBatchNormGradGpuKernel : public GpuKernel {
}
auto dy = GetDeviceAddress<T>(inputs, 0);
auto x = GetDeviceAddress<T>(inputs, 1);
auto scale = GetDeviceAddress<T>(inputs, 2);
auto save_mean = GetDeviceAddress<T>(inputs, 3);
auto save_variance = GetDeviceAddress<T>(inputs, 4);
auto scale = GetDeviceAddress<float>(inputs, 2);
auto save_mean = GetDeviceAddress<float>(inputs, 3);
auto save_variance = GetDeviceAddress<float>(inputs, 4);
auto dx = GetDeviceAddress<T>(outputs, 0);
auto bn_scale = GetDeviceAddress<T>(outputs, 1);
auto bn_bias = GetDeviceAddress<T>(outputs, 2);
auto bn_scale = GetDeviceAddress<float>(outputs, 1);
auto bn_bias = GetDeviceAddress<float>(outputs, 2);
const float alpha_data_diff = 1;
const float beta_data_diff = 0;
......
......@@ -18,32 +18,41 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MomentumGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
MomentumGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
MomentumGpuKernel, half, float)
// ApplyMomentum registrations for MomentumGpuKernel<T, S, G>.
// The five input attributes follow the order in which MomentumGpuKernel reads its inputs:
// variable (T), accumulation (T), learning_rate (S), gradient (G), momentum (S).

// All buffers float32.
MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MomentumGpuKernel, float, float, float)
// All buffers float16.
MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
MomentumGpuKernel, half, half, half)
// float16 variable/accumulation/gradient with float32 learning rate and momentum.
MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
MomentumGpuKernel, half, float, half)
// float32 variable/accumulation/learning rate/momentum with a float16 gradient.
MS_REG_GPU_KERNEL_THREE(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
MomentumGpuKernel, float, float, half)
} // namespace kernel
} // namespace mindspore
......@@ -23,7 +23,7 @@
#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
template <typename T, typename S, typename G>
class MomentumGpuKernel : public GpuKernel {
public:
MomentumGpuKernel()
......@@ -38,7 +38,7 @@ class MomentumGpuKernel : public GpuKernel {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *accumulation = GetDeviceAddress<T>(inputs, 1);
S *learning_rate = GetDeviceAddress<S>(inputs, 2);
T *gradient = GetDeviceAddress<T>(inputs, 3);
G *gradient = GetDeviceAddress<G>(inputs, 3);
S *momentum = GetDeviceAddress<S>(inputs, 4);
MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum,
reinterpret_cast<cudaStream_t>(stream_ptr));
......@@ -54,7 +54,7 @@ class MomentumGpuKernel : public GpuKernel {
variable_size_ = sizeof(T);
accumulation_size_ = sizeof(T);
learning_rate_size_ = sizeof(S);
gradient_size_ = sizeof(T);
gradient_size_ = sizeof(G);
momentum_size_ = sizeof(S);
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
// Describes the sub-graph this pass looks for: an ApplyMomentum node whose fourth operand
// (the gradient slot) is a Cast node. The remaining operands are unconstrained pattern
// variables.
const BaseRef ReplaceMomentumCastFusion::DefinePattern() const {
  // Inner pattern: Cast(grad_).
  VectorRef cast_of_grad({prim::kPrimCast, grad_});
  // Outer pattern: ApplyMomentum(var_, acc_, lr_, Cast(grad_), mom_).
  VectorRef apply_momentum({prim::kPrimApplyMomentum, var_, acc_, lr_, cast_of_grad, mom_});
  return apply_momentum;
}
// Rewrites a matched ApplyMomentum(var, acc, lr, Cast(grad), mom) node so that the Cast is
// bypassed: the Cast node is replaced with its own input through the graph manager.
// Afterwards the inferred output types/shapes of `node` are collected and set again.
//   graph - function graph that owns `node`; its manager performs the replacement
//   node  - the matched ApplyMomentum node
//   equiv - match bindings (only null-checked here)
// Returns `node`.
// graph, node, equiv, the Cast node, the Cast's input and the graph manager are all
// validated with MS_EXCEPTION_IF_NULL.
const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                    const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);

  // Position of the gradient among ApplyMomentum's inputs (var, acc, lr, grad, mom).
  constexpr size_t kGradIndex = 3;

  auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), kGradIndex);
  // Validate the Cast node before it is used to look up its own input.
  MS_EXCEPTION_IF_NULL(grad_cast);
  auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
  MS_EXCEPTION_IF_NULL(grad);

  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad));

  std::vector<TypeId> outputs_type;
  std::vector<std::vector<size_t>> outputs_shape;
  auto output_num = AnfAlgo::GetOutputTensorNum(node);
  outputs_type.reserve(output_num);
  outputs_shape.reserve(output_num);
  for (size_t i = 0; i < output_num; i++) {
    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, i));
    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(node, i));
  }
  // The vector holds exactly `output_num` entries, so slot kGradIndex only exists when the
  // node has at least four outputs. Writing it unconditionally would index past the end.
  // NOTE(review): kGradIndex is an *input* position; confirm whether overriding an *output*
  // type at this index is intended at all.
  if (kGradIndex < outputs_type.size()) {
    outputs_type[kGradIndex] = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
  }
  AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, node.get());
  return node;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ReplaceMomentumCastFusion : public PatternProcessPass {
public:
explicit ReplaceMomentumCastFusion(bool multigraph = true) : PatternProcessPass("replace_momentum_cast", multigraph) {
var_ = std::make_shared<Var>();
acc_ = std::make_shared<Var>();
lr_ = std::make_shared<Var>();
grad_ = std::make_shared<Var>();
mom_ = std::make_shared<Var>();
}
~ReplaceMomentumCastFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr var_;
VarPtr acc_;
VarPtr lr_;
VarPtr grad_;
VarPtr mom_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REPLACE_MOMENTUM_CAST_FUSION_H_
......@@ -25,6 +25,11 @@
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
#include "backend/optimizer/gpu/adam_fusion.h"
#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
#include "backend/optimizer/gpu/replace_bn_grad_cast2_fusion.h"
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include "backend/optimizer/gpu/replace_addn_fusion.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "predict/predict.h"
#include "common/utils.h"
......@@ -59,6 +64,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
pm->AddPass(std::make_shared<opt::AdamFusion>());
pm->AddPass(std::make_shared<opt::ReplaceBNCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceBNGradCast2Fusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册