Commit 73863fd0 authored by mindspore-ci-bot, committed by Gitee

!4819 Add FusedBatchEx support

Merge pull request !4819 from zyli2020/Add_FusedBatchEx_support
......@@ -130,8 +130,9 @@ class NcclGpuKernel : public GpuKernel {
for (size_t j = 0; j < shape.size(); j++) {
size *= IntToSize(shape[j]);
}
input_size_list_.push_back(size);
input_size_ += size;
size_t aligned_size = AlignMemorySize(size);
input_size_list_.push_back(aligned_size);
input_size_ += aligned_size;
}
for (size_t i = 0; i < output_num; ++i) {
auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
......@@ -139,8 +140,9 @@ class NcclGpuKernel : public GpuKernel {
for (size_t j = 0; j < shape.size(); j++) {
size *= IntToSize(shape[j]);
}
output_size_list_.push_back(size);
output_size_ += size;
size_t aligned_size = AlignMemorySize(size);
output_size_list_.push_back(aligned_size);
output_size_ += aligned_size;
}
InferCommType(kernel_node);
......@@ -193,6 +195,13 @@ class NcclGpuKernel : public GpuKernel {
return;
}
size_t AlignMemorySize(size_t size) const {
if (size == 0) {
return COMMUNICATION_MEM_ALIGN_SIZE;
}
return ((size + COMMUNICATION_MEM_ALIGN_SIZE - 1) / COMMUNICATION_MEM_ALIGN_SIZE) * COMMUNICATION_MEM_ALIGN_SIZE;
}
NcclKernelType nccl_kernel_type_;
ncclRedOp_t nccl_reduce_type_;
ncclDataType_t nccl_data_type_;
......@@ -205,6 +214,8 @@ class NcclGpuKernel : public GpuKernel {
std::vector<size_t> workspace_size_list_;
const void *collective_handle_;
cudaStream_t comm_stream_;
static const size_t COMMUNICATION_MEM_ALIGN_SIZE = 16;
};
} // namespace kernel
} // namespace mindspore
......
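Aside: the new AlignMemorySize rounds every NCCL buffer size up to a 16-byte boundary, and a zero-size buffer still gets one full block. A minimal standalone sketch of the same rounding, outside any MindSpore API (names here are illustrative):

```cpp
#include <cstddef>
#include <cstdio>

constexpr size_t kCommMemAlignSize = 16;

// Round size up to the next multiple of the alignment; zero maps to one block.
size_t AlignMemorySize(size_t size) {
  if (size == 0) {
    return kCommMemAlignSize;
  }
  return ((size + kCommMemAlignSize - 1) / kCommMemAlignSize) * kCommMemAlignSize;
}

int main() {
  printf("%zu\n", AlignMemorySize(0));   // 16
  printf("%zu\n", AlignMemorySize(1));   // 16
  printf("%zu\n", AlignMemorySize(16));  // 16
  printf("%zu\n", AlignMemorySize(17));  // 32
  return 0;
}
```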
......@@ -75,7 +75,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null.";
......@@ -89,9 +89,15 @@ class ActivationGpuFwdKernel : public GpuKernel {
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");
if (AnfAlgo::GetInputFormat(kernel_node, 0) == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
shape[0], shape[3], shape[1], shape[2]),
"cudnnSetTensor4dDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");
}
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
}
......
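Aside: when the input format is NHWC, the padded device shape is [N, H, W, C], while cudnnSetTensor4dDescriptor always takes its extents in (n, c, h, w) order; that is why the NHWC branch above passes shape[3] as the channel count. A sketch of just that mapping (the helper name is illustrative, not MindSpore code):

```cpp
#include <cudnn.h>
#include <vector>

// Map a padded 4-D device shape onto cuDNN's (n, c, h, w) argument order
// for the two layouts handled in the diff.
cudnnStatus_t SetActivationDescriptor(cudnnTensorDescriptor_t desc, cudnnDataType_t dtype,
                                      const std::vector<int> &shape, bool is_nhwc) {
  if (is_nhwc) {
    // Device shape is [N, H, W, C]; cuDNN still expects (n, c, h, w).
    return cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NHWC, dtype,
                                      shape[0], shape[3], shape[1], shape[2]);
  }
  // Device shape is [N, C, H, W].
  return cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, dtype,
                                    shape[0], shape[1], shape[2], shape[3]);
}
```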
......@@ -28,7 +28,7 @@ namespace mindspore {
namespace opt {
const BaseRef ReplaceBNCastFusion::DefinePattern() const {
VectorRef in_cast = VectorRef({prim::kPrimCast, x_});
VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNorm, in_cast, scale_, bias_, mean_, var_});
VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNormEx, in_cast, scale_, bias_, mean_, var_});
VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_});
return tupleget;
}
......
......@@ -28,7 +28,7 @@ namespace mindspore {
namespace opt {
const BaseRef ReplaceBNGradCastFusion::DefinePattern() const {
VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_});
VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_});
VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGradEx, dy_cast, x_, scale_, mean_, var_, reserve_});
VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_});
return tupleget;
}
......
......@@ -33,6 +33,7 @@ class ReplaceBNGradCastFusion : public PatternProcessPass {
bn_scale_ = std::make_shared<Var>();
bn_bias_ = std::make_shared<Var>();
index_ = std::make_shared<Var>();
reserve_ = std::make_shared<Var>();
}
~ReplaceBNGradCastFusion() override = default;
const BaseRef DefinePattern() const override;
......@@ -48,6 +49,7 @@ class ReplaceBNGradCastFusion : public PatternProcessPass {
VarPtr bn_scale_;
VarPtr bn_bias_;
VarPtr index_;
VarPtr reserve_;
};
} // namespace opt
} // namespace mindspore
......
......@@ -28,6 +28,9 @@
#include "backend/optimizer/gpu/adam_fusion.h"
#include "backend/optimizer/gpu/replace_bn_cast_fusion.h"
#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
#include "backend/optimizer/gpu/batch_norm_relu_fusion.h"
#include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
#include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include "backend/optimizer/gpu/replace_addn_fusion.h"
#include "backend/optimizer/gpu/insert_format_transform_op.h"
......@@ -70,6 +73,9 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
......
......@@ -40,17 +40,21 @@ bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, T
return ret;
}
if (size != size_) {
MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_;
return true;
// NCCL kernel input and output memory sizes are aligned, so the host size may differ from the device size
MS_LOG(INFO) << "Sync memory size mismatch, host size: " << size << ", device size " << size_;
}
return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size_);
return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size);
}
bool GPUDeviceAddress::SyncHostToDevice(const std::vector<int> &, size_t, TypeId, const void *host_ptr) const {
bool GPUDeviceAddress::SyncHostToDevice(const std::vector<int> &, size_t size, TypeId, const void *host_ptr) const {
MS_EXCEPTION_IF_NULL(host_ptr);
auto &stream = GPUDeviceManager::GetInstance().default_stream();
MS_EXCEPTION_IF_NULL(stream);
if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(ptr_, host_ptr, size_, stream)) {
if (size != size_) {
// NCCL kernel input and output memory sizes are aligned, so the host size may differ from the device size
MS_LOG(INFO) << "Sync memory size mismatch, host size: " << size << ", device size " << size_;
}
if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(ptr_, host_ptr, size, stream)) {
MS_LOG(ERROR) << "CopyHostMemToDeviceAsync failed";
return false;
}
......
......@@ -1001,7 +1001,10 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfN
size_t total_size = 0;
std::vector<size_t> size_list;
DeviceAddressPtrList addr_list;
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto input_sizes = kernel_mod->GetInputSizeList();
for (size_t i = 0; i < input_sizes.size(); ++i) {
DeviceAddressPtr device_address;
if (mem_reuse_util_->is_all_nop_node()) {
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
......@@ -1016,8 +1019,8 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfN
} else {
is_need_free_memory = true;
}
total_size += device_address->size_;
size_list.emplace_back(device_address->size_);
total_size += input_sizes[i];
size_list.emplace_back(input_sizes[i]);
addr_list.emplace_back(device_address);
}
AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list);
......
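Aside: the input sizes now come from the kernel's own GetInputSizeList(), i.e. the aligned sizes computed in NcclGpuKernel, so the communication buffer is carved up on 16-byte boundaries. A sketch of that carving, under the assumption that AllocCommunicationOpMemory hands out one contiguous block (names below are illustrative, not MindSpore API):

```cpp
#include <cstdint>
#include <vector>

struct Slice {
  uint8_t *ptr;
  size_t size;
};

// Carve one contiguous device allocation into per-input buffers using the
// aligned sizes; each slice then starts on a 16-byte boundary.
std::vector<Slice> CarveBuffer(uint8_t *base, const std::vector<size_t> &aligned_sizes) {
  std::vector<Slice> slices;
  size_t offset = 0;
  for (size_t s : aligned_sizes) {
    slices.push_back({base + offset, s});
    offset += s;
  }
  return slices;
}
```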
......@@ -180,7 +180,7 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
if (input_shape.size() != 4) {
return false;
}
return false;
return true;
}
void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
......
......@@ -86,6 +86,7 @@ class _BatchNorm(Cell):
self.dtype = P.DType()
self.reshape = P.Reshape()
self.is_ascend = context.get_context("device_target") == "Ascend"
self.is_gpu = context.get_context("device_target") == "GPU"
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
self.momentum = 1.0 - momentum
if context.get_context("enable_ge"):
......@@ -96,6 +97,10 @@ class _BatchNorm(Cell):
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps)
elif self.is_gpu:
self.bn_train = P.FusedBatchNormEx(mode=1,
epsilon=self.eps,
momentum=self.momentum)
else:
self.bn_train = P.FusedBatchNorm(mode=1,
epsilon=self.eps,
......
......@@ -535,6 +535,24 @@ def get_bprop_fused_batch_norm(self):
return bprop
@bprop_getters.register(P.FusedBatchNormEx)
def get_bprop_fused_batch_norm_ex(self):
"""Grad definition for `FusedBatchNormEx` operation."""
input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum)
def bprop(x, scale, b, mean, variance, out, dout):
saved_mean = out[3]
saved_variance = out[4]
reserve = out[5]
out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve)
dx = out[0]
dscale = out[1]
dbias = out[2]
return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)
return bprop
@bprop_getters.register(P.BatchNorm)
def get_bprop_batch_norm(self):
"""Grad definition for `BatchNorm` operation."""
......
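Aside: the bprop above feeds dout, x, scale, the saved statistics (out[3], out[4]) and the cuDNN reserve space (out[5]) into FusedBatchNormGradEx. For reference, these are the standard per-channel batch-norm gradient identities such a kernel is expected to realize (a textbook formulation, not taken from this diff), with r = 1/sqrt(var + eps) and N elements per channel:

```latex
\hat{x} = (x - \mu)\,r, \qquad
d\beta = \textstyle\sum dout, \qquad
d\gamma = \textstyle\sum dout\,\hat{x}, \qquad
dx = \frac{\gamma\,r}{N}\bigl(N\,dout - d\beta - \hat{x}\,d\gamma\bigr)
```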
......@@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
BiasAdd, Conv2D,
DepthwiseConv2dNative,
DropoutDoMask, DropoutGrad, Dropout,
DropoutGenMask, Flatten, FusedBatchNorm, BNTrainingReduce, BNTrainingUpdate,
DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
Gelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2,
LogSoftmax,
......@@ -118,6 +118,7 @@ __all__ = [
'Flatten',
'MaxPoolWithArgmax',
'FusedBatchNorm',
'FusedBatchNormEx',
'BNTrainingReduce',
'BNTrainingUpdate',
'BatchNorm',
......
......@@ -491,6 +491,22 @@ class FusedBatchNormGrad(Primitive):
raise NotImplementedError
class FusedBatchNormGradEx(PrimitiveWithInfer):
"""Gradients of FusedBatchNormEx operation."""
@prim_attr_register
def __init__(self, epsilon=0.0, momentum=0.1):
self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance', 'reserve'],
outputs=['dx', 'bn_scale', 'bn_bias'])
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape):
return (x_shape, scale_shape, scale_shape)
def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_type, save_variance_type, reserve_type):
return (x_type, scale_type, scale_type)
class UniqueGrad(Primitive):
"""Gradients of Unique operation."""
......
......@@ -623,6 +623,73 @@ class FusedBatchNorm(Primitive):
self._update_parameter = True
class FusedBatchNormEx(PrimitiveWithInfer):
r"""
FusedBatchNormEx is an extension of FusedBatchNorm.
Args:
mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
epsilon (float): A small value added for numerical stability. Default: 1e-5.
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
(e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
Momentum value must be in the range [0, 1]. Default: 0.9.
Inputs:
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
- **scale** (Tensor) - Tensor of shape :math:`(C,)`.
- **bias** (Tensor) - Tensor of shape :math:`(C,)`.
- **mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **variance** (Tensor) - Tensor of shape :math:`(C,)`.
Outputs:
Tuple of 6 Tensors: the normalized input and the updated parameters.
- **output_x** (Tensor) - The same type and shape as the `input_x`.
- **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
- **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
- **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
- **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
- **reserve** (Tensor) - Tensor of shape :math:`(C,)`.
Examples:
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> scale = Tensor(np.ones([64]), mindspore.float32)
>>> bias = Tensor(np.ones([64]), mindspore.float32)
>>> mean = Tensor(np.ones([64]), mindspore.float32)
>>> variance = Tensor(np.ones([64]), mindspore.float32)
>>> op = P.FusedBatchNormEx()
>>> output = op(input_x, scale, bias, mean, variance)
"""
@prim_attr_register
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self._update_parameter = True
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, input_x, scale, bias, mean, variance):
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale, scale)
def infer_dtype(self, input_x, scale, bias, mean, variance):
validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
args = {"scale": scale, "bias": bias}
validator.check_tensor_type_same(args, [mstype.float32], self.name)
args_moving = {"mean": mean, "variance": variance}
valid_types = [mstype.tensor_type(mstype.float32)]
validator.check_type_same(args_moving, valid_types, self.name)
return (input_x, scale, scale, scale, scale, scale)
class BNTrainingReduce(PrimitiveWithInfer):
"""
reduce sum at axis [0, 2, 3].
......
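Aside: in training mode the op normalizes with the current batch statistics and updates the running statistics using the momentum convention stated in the docstring (note that the _BatchNorm Cell above passes self.momentum = 1.0 - momentum to this primitive). Written out, with m the op-level momentum:

```latex
y = \gamma\,\frac{x - \mu_{batch}}{\sqrt{\sigma^2_{batch} + \epsilon}} + \beta,
\qquad
\mu_{run} \leftarrow m\,\mu_{run} + (1 - m)\,\mu_{batch},
\qquad
\sigma^2_{run} \leftarrow m\,\sigma^2_{run} + (1 - m)\,\sigma^2_{batch}
```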