提交 f56f03ea 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!598 GPU fix cudnn type error

Merge pull request !598 from VectorSL/fix-cudnn-type-error
......@@ -26,5 +26,8 @@ MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
TensorAddGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TensorAddGpuFwdKernel, int)
} // namespace kernel
} // namespace mindspore
......@@ -71,6 +71,9 @@ class TensorAddGpuFwdKernel : public GpuKernel {
bool Init(const CNodePtr &kernel_node) {
InitResource();
cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
cudnn_data_type_ = CUDNN_DATA_FLOAT;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs.";
......
......@@ -101,7 +101,7 @@ class BiasAddGradGpuKernel : public GpuKernel {
cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, cudnn_data_type_, CUDNN_NOT_PROPAGATE_NAN,
cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN,
CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
"cudnnSetReduceTensorDescriptor failed");
......
......@@ -151,7 +151,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_scale = loss_scale_manager.get_loss_scale()
update_cell = loss_scale_manager.get_update_cell()
if update_cell is not None:
if not context.get_context("enable_ge"):
if not (context.get_context("enable_ge") or (context.get_context("device_target") == "GPU")):
raise ValueError("Only `loss_scale_manager=None` and "
"`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
"are supported in current version. If you use `O2` option, please"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册