From 2e845182d90e68252ebcee216b89c998cf5003b1 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Thu, 15 Oct 2020 11:50:33 +0800 Subject: [PATCH] support channel last in BatchNorm*d 1. support channel last in BatchNorm*d (#27875) 2. fix a bug in batch_norm_op cuda kernel by extracting ResizeToChannelFist(Last), TransToChannelFirst(Last) to operators/layer_utils.h --- paddle/fluid/operators/batch_norm_op.h | 122 +------------- paddle/fluid/operators/conv_op.h | 97 +---------- paddle/fluid/operators/layout_utils.h | 155 ++++++++++++++++++ .../tests/unittests/test_batch_norm_op_v2.py | 54 ++++++ python/paddle/nn/functional/norm.py | 10 +- python/paddle/nn/layer/norm.py | 29 +++- 6 files changed, 236 insertions(+), 231 deletions(-) create mode 100644 paddle/fluid/operators/layout_utils.h diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h index 1440b74290c..32e956e1528 100644 --- a/paddle/fluid/operators/batch_norm_op.h +++ b/paddle/fluid/operators/batch_norm_op.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/layout_utils.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/norm_utils.h" @@ -41,127 +42,6 @@ template using ConstEigenVectorArrayMap = Eigen::Map>; -template -inline void ResizeToChannelFirst(const framework::ExecutionContext& context, - const Tensor* input, - Tensor* transformed_input) { - int dim = input->dims().size() - 2; - if (dim == 3) { - // input - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[4]; - in_dims_vec[2] = input->dims()[1]; - in_dims_vec[3] = input->dims()[2]; - in_dims_vec[4] = input->dims()[3]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - - } else if (dim == 2) { - // input - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[3]; - in_dims_vec[2] = input->dims()[1]; - in_dims_vec[3] = input->dims()[2]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - } else if (dim == 1) { - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[2]; - in_dims_vec[2] = input->dims()[1]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - } -} - -template -inline void TransToChannelFirst(const framework::ExecutionContext& context, - const Tensor* input, - Tensor* transformed_input) { - int dim = input->dims().size() - 2; - if (dim == 3) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 4, 1, 2, 3}; - math::Transpose trans5; - trans5(dev_ctx, *input, transformed_input, axis); - - } else if (dim == 2) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 3, 1, 2}; - math::Transpose trans4; - trans4(dev_ctx, *input, transformed_input, axis); - } else if (dim == 1) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 2, 1}; - math::Transpose trans3; - trans3(dev_ctx, *input, transformed_input, axis); - } -} - -template -inline void ResizeToChannelLast(const framework::ExecutionContext& context, - const Tensor* input, - Tensor* transformed_input) { - int dim = input->dims().size() - 2; - if (dim == 3) { - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[2]; - in_dims_vec[2] = input->dims()[3]; - in_dims_vec[3] = input->dims()[4]; - in_dims_vec[4] = input->dims()[1]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - - } else if (dim == 2) { - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[2]; - in_dims_vec[2] = input->dims()[3]; - in_dims_vec[3] = input->dims()[1]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - } else if (dim == 1) { - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[2]; - in_dims_vec[2] = input->dims()[1]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - } -} - -template -inline void TransToChannelLast(const framework::ExecutionContext& context, - const Tensor* input, Tensor* transformed_input) { - int dim = input->dims().size() - 2; - if (dim == 3) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 2, 3, 4, 1}; - math::Transpose trans5; - trans5(dev_ctx, *input, transformed_input, axis); - - } else if (dim == 2) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 2, 3, 1}; - math::Transpose trans4; - trans4(dev_ctx, *input, transformed_input, axis); - } else if (dim == 1) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 2, 1}; - math::Transpose trans3; - trans3(dev_ctx, *input, transformed_input, axis); - } -} - class BatchNormOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index 662fac9e77e..364e3ab8d26 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/layout_utils.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/im2col.h" @@ -138,102 +139,6 @@ inline bool IsExpand(const std::vector& filter_dim, return !(filter_1 && strides_1 && padding_0 && dilation_1); } -template -inline void ResizeToChannelFirst(const framework::ExecutionContext& context, - const Tensor* input, - Tensor* transformed_input) { - int dim = input->dims().size() - 2; - if (dim == 3) { - // input - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[4]; - in_dims_vec[2] = input->dims()[1]; - in_dims_vec[3] = input->dims()[2]; - in_dims_vec[4] = input->dims()[3]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - - } else if (dim == 2) { - // input - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[3]; - in_dims_vec[2] = input->dims()[1]; - in_dims_vec[3] = input->dims()[2]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - } -} - -template -inline void ResizeToChannelLast(const framework::ExecutionContext& context, - const Tensor* input, - Tensor* transformed_input) { - int dim = input->dims().size() - 2; - if (dim == 3) { - // input - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[2]; - in_dims_vec[2] = input->dims()[3]; - in_dims_vec[3] = input->dims()[4]; - in_dims_vec[4] = input->dims()[1]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - - } else if (dim == 2) { - // input - transformed_input->Resize(input->dims()); - - auto in_dims_vec = framework::vectorize(input->dims()); - in_dims_vec[1] = input->dims()[2]; - in_dims_vec[2] = input->dims()[3]; - in_dims_vec[3] = input->dims()[1]; - transformed_input->Resize(framework::make_ddim(in_dims_vec)); - transformed_input->mutable_data(context.GetPlace()); - } -} - -template -inline void TransToChannelFirst(const framework::ExecutionContext& context, - const Tensor* input, - Tensor* transformed_input) { - int dim = input->dims().size() - 2; - if (dim == 3) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 4, 1, 2, 3}; - math::Transpose trans5; - trans5(dev_ctx, *input, transformed_input, axis); - - } else if (dim == 2) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 3, 1, 2}; - math::Transpose trans4; - trans4(dev_ctx, *input, transformed_input, axis); - } -} - -template -inline void TransToChannelLast(const framework::ExecutionContext& context, - const Tensor* input, Tensor* transformed_input) { - int dim = input->dims().size() - 2; - if (dim == 3) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 2, 3, 4, 1}; - math::Transpose trans5; - trans5(dev_ctx, *input, transformed_input, axis); - - } else if (dim == 2) { - auto& dev_ctx = context.template device_context(); - std::vector axis{0, 2, 3, 1}; - math::Transpose trans4; - trans4(dev_ctx, *input, transformed_input, axis); - } -} // Define Op classes in .h file so that other conv // operator implementations can reuse the code. class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/layout_utils.h b/paddle/fluid/operators/layout_utils.h new file mode 100644 index 00000000000..52fa7fd1079 --- /dev/null +++ b/paddle/fluid/operators/layout_utils.h @@ -0,0 +1,155 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +inline void ResizeToChannelFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + // input + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[4]; + in_dims_vec[2] = input->dims()[1]; + in_dims_vec[3] = input->dims()[2]; + in_dims_vec[4] = input->dims()[3]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + + } else if (dim == 2) { + // input + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[3]; + in_dims_vec[2] = input->dims()[1]; + in_dims_vec[3] = input->dims()[2]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } else if (dim == 1) { + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } +} + +template +inline void ResizeToChannelLast(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + // input + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[3]; + in_dims_vec[3] = input->dims()[4]; + in_dims_vec[4] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + + } else if (dim == 2) { + // input + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[3]; + in_dims_vec[3] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } else if (dim == 1) { + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } +} + +template +inline void TransToChannelFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + VLOG(5) << "Why am I called?"; + int dim = input->dims().size() - 2; + if (dim == 3) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 4, 1, 2, 3}; + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + + } else if (dim == 2) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 3, 1, 2}; + math::Transpose trans4; + trans4(dev_ctx, *input, transformed_input, axis); + } else if (dim == 1) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 1}; + math::Transpose trans3; + trans3(dev_ctx, *input, transformed_input, axis); + } +} + +template +inline void TransToChannelLast(const framework::ExecutionContext& context, + const Tensor* input, Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 3, 4, 1}; + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + + } else if (dim == 2) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 3, 1}; + math::Transpose trans4; + trans4(dev_ctx, *input, transformed_input, axis); + } else if (dim == 1) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 1}; + math::Transpose trans3; + trans3(dev_ctx, *input, transformed_input, axis); + } +} + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py index 2af0b31d6fc..324d4cf7110 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py @@ -168,5 +168,59 @@ class TestBatchNorm(unittest.TestCase): self.assertTrue(np.allclose(y1, y2)) +class TestBatchNormChannelLast(unittest.TestCase): + def setUp(self): + self.original_dtyep = paddle.get_default_dtype() + paddle.set_default_dtype("float64") + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + self.places.append(fluid.CUDAPlace(0)) + + def tearDown(self): + paddle.set_default_dtype(self.original_dtyep) + + def test_1d(self): + for p in self.places: + with fluid.dygraph.guard(p): + x = paddle.randn([2, 6, 4]) + net1 = paddle.nn.BatchNorm1d(4, data_format="NLC") + net2 = paddle.nn.BatchNorm1d(4) + net2.weight = net1.weight + net2.bias = net1.bias + y1 = net1(x) + channel_first_x = paddle.transpose(x, [0, 2, 1]) + y2 = net2(channel_first_x) + y2 = paddle.transpose(y2, [0, 2, 1]) + self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True) + + def test_2d(self): + for p in self.places: + with fluid.dygraph.guard(p): + x = paddle.randn([2, 6, 6, 4]) + net1 = paddle.nn.BatchNorm2d(4, data_format="NHWC") + net2 = paddle.nn.BatchNorm2d(4) + net2.weight = net1.weight + net2.bias = net1.bias + y1 = net1(x) + channel_first_x = paddle.transpose(x, [0, 3, 1, 2]) + y2 = net2(channel_first_x) + y2 = paddle.transpose(y2, [0, 2, 3, 1]) + self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True) + + def test_3d(self): + for p in self.places: + with fluid.dygraph.guard(p): + x = paddle.randn([2, 6, 6, 6, 4]) + net1 = paddle.nn.BatchNorm3d(4, data_format="NDHWC") + net2 = paddle.nn.BatchNorm3d(4) + net2.weight = net1.weight + net2.bias = net1.bias + y1 = net1(x) + channel_first_x = paddle.transpose(x, [0, 4, 1, 2, 3]) + y2 = net2(channel_first_x) + y2 = paddle.transpose(y2, [0, 2, 3, 4, 1]) + self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 0beedd96eb7..9b783682591 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -138,7 +138,7 @@ def batch_norm(x, epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Defalut False. - data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW" or "NCDHW". Defalut "NCHW". + data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Defalut "NCHW". name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Returns: @@ -174,13 +174,13 @@ def batch_norm(x, mean_out = running_mean variance_out = running_var - true_data_format = ['NC', 'NCL', 'NCHW', 'NCDHW'] + true_data_format = ['NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC'] if data_format not in true_data_format: raise ValueError( - "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', but receive {}". - format(data_format)) + "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', " + "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format)) - data_format = 'NCHW' + data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC' if in_dygraph_mode(): # for dygraph need tuple diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index e9a2df3dc6e..ad8dc9b64e7 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -719,14 +719,15 @@ class BatchNorm1d(_BatchNormBase): If it is set to None or one attribute of ParamAttr, batch_norm will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. - data_format(str, optional): Specify the input data format, may be "NC", "NCL". Defalut "NCL". + data_format(str, optional): Specify the input data format, may be "NC", "NCL" or "NLC". Defalut "NCL". track_running_stats(bool, optional): Whether to use global mean and variance. In train period, True will track global mean and variance used for inference. When inference, track_running_stats must be True. Default: True. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: - - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length). + - x: 2-D or 3-D tensor with shape: (batch, num_features) or (batch, num_features, length) when data_format is "NC" or "NCL", + (batch, length, num_features) when data_format is "NLC". - output: 3-D tensor with same shape as input x. Returns: @@ -755,8 +756,11 @@ class BatchNorm1d(_BatchNormBase): def _check_data_format(self, input): if input == 'NCHW' or input == 'NC' or input == 'NCL': self._data_format = 'NCHW' + elif input == "NHWC" or input == 'NLC': + self._data_format = "NHWC" else: - raise ValueError('expected NC , NCL or None for data_format input') + raise ValueError( + 'expected NC , NCL, NLC or None for data_format input') def _check_input_dim(self, input): if len(input.shape) != 2 and len(input.shape) != 3: @@ -812,14 +816,15 @@ class BatchNorm2d(_BatchNormBase): If it is set to None or one attribute of ParamAttr, batch_norm will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. - data_format(str, optional): Specify the input data format, the data format can be "NCHW". Default: NCHW. + data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW. track_running_stats(bool, optional): Whether to use global mean and variance. In train period, True will track global mean and variance used for inference. When inference, track_running_stats must be True. Default: True. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: - - x: 4-D tensor with shape: (batch, num_features, height, weight). + - x: 4-D tensor with shape: (batch, num_features, height, weight) when data_format is "NCHW", + or (batch, height, weight, num_features) when data_format is "NHWC". - output: 4-D tensor with same shape as input x. Returns: @@ -847,8 +852,10 @@ class BatchNorm2d(_BatchNormBase): def _check_data_format(self, input): if input == 'NCHW': self._data_format = input + elif input == "NHWC": + self._data_format = input else: - raise ValueError('expected NCHW for data_format input') + raise ValueError('expected NCHW or NHWC for data_format input') def _check_input_dim(self, input): if len(input.shape) != 4: @@ -904,14 +911,15 @@ class BatchNorm3d(_BatchNormBase): If it is set to None or one attribute of ParamAttr, batch_norm will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. - data_format(str, optional): Specify the input data format, the data format can be "NCDHW". Default: NCDHW. + data_format(str, optional): Specify the input data format, the data format can be "NCDHW" or "NDHWC. Default: NCDHW. track_running_stats(bool, optional): Whether to use global mean and variance. In train period, True will track global mean and variance used for inference. When inference, track_running_stats must be True. Default: True. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: - - x: 5-D tensor with shape: (batch, num_features, dims, height, weight). + - x: 5-D tensor with shape: (batch, num_features, dims, height, weight) when data_format is "NCDHW", + or (batch, dims, height, weight, num_features) when data_format is "NDHWC". - output: 5-D tensor with same shape as input x. Returns: @@ -939,8 +947,11 @@ class BatchNorm3d(_BatchNormBase): def _check_data_format(self, input): if input == 'NCHW' or input == 'NCDHW': self._data_format = 'NCHW' + elif input == "NHWC" or input == "NDHWC": + self._data_format = 'NHWC' else: - raise ValueError('expected NCDHW or None for data_format input') + raise ValueError( + 'expected NCDHW, NDHWC or None for data_format input') def _check_input_dim(self, input): if len(input.shape) != 5: -- GitLab