提交 f56f03ea 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!598 GPU fix cudnn type error

Merge pull request !598 from VectorSL/fix-cudnn-type-error
...@@ -26,5 +26,8 @@ MS_REG_GPU_KERNEL_ONE( ...@@ -26,5 +26,8 @@ MS_REG_GPU_KERNEL_ONE(
TensorAdd, TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
TensorAddGpuFwdKernel, half) TensorAddGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TensorAddGpuFwdKernel, int)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -71,6 +71,9 @@ class TensorAddGpuFwdKernel : public GpuKernel { ...@@ -71,6 +71,9 @@ class TensorAddGpuFwdKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) { bool Init(const CNodePtr &kernel_node) {
InitResource(); InitResource();
cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
cudnn_data_type_ = CUDNN_DATA_FLOAT;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) { if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs."; MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs.";
......
...@@ -101,7 +101,7 @@ class BiasAddGradGpuKernel : public GpuKernel { ...@@ -101,7 +101,7 @@ class BiasAddGradGpuKernel : public GpuKernel {
cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()),
"cudnnSetTensorNdDescriptor failed"); "cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, cudnn_data_type_, CUDNN_NOT_PROPAGATE_NAN, cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES), CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
"cudnnSetReduceTensorDescriptor failed"); "cudnnSetReduceTensorDescriptor failed");
......
...@@ -151,7 +151,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): ...@@ -151,7 +151,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_scale = loss_scale_manager.get_loss_scale() loss_scale = loss_scale_manager.get_loss_scale()
update_cell = loss_scale_manager.get_update_cell() update_cell = loss_scale_manager.get_update_cell()
if update_cell is not None: if update_cell is not None:
if not context.get_context("enable_ge"): if not (context.get_context("enable_ge") or (context.get_context("device_target") == "GPU")):
raise ValueError("Only `loss_scale_manager=None` and " raise ValueError("Only `loss_scale_manager=None` and "
"`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`" "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
"are supported in current version. If you use `O2` option, please" "are supported in current version. If you use `O2` option, please"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册