“5de576b0af93519236a2307855b1182c86c5d142”上不存在“paddle/phi/ops/compat/fill_any_like_sig.cc”
提交 25af911e 编写于 作者: V VectorSL

gpu update bn

上级 420ef2a3
...@@ -82,6 +82,7 @@ class FusedBatchNormGpuKernel : public GpuKernel { ...@@ -82,6 +82,7 @@ class FusedBatchNormGpuKernel : public GpuKernel {
} }
bool Init(const CNodePtr &kernel_node) override { bool Init(const CNodePtr &kernel_node) override {
InitResource(); InitResource();
cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) { if (input_num != 5) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5"; MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5";
...@@ -112,11 +113,11 @@ class FusedBatchNormGpuKernel : public GpuKernel { ...@@ -112,11 +113,11 @@ class FusedBatchNormGpuKernel : public GpuKernel {
} }
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_), cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set x desc failed"); "Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, batch_, channel_, height_, width_), cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set y desc failed"); "Set y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
......
...@@ -110,7 +110,7 @@ class FusedBatchNormGradGpuKernel : public GpuKernel { ...@@ -110,7 +110,7 @@ class FusedBatchNormGradGpuKernel : public GpuKernel {
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
"Set dx desc failed"); "Set dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT( CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 1, channel_, 1, 1), cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
"Set para desc failed"); "Set para desc failed");
InitSizeLists(); InitSizeLists();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册