From b143e0086155b357bcccdb31fd55aa9366dcaace Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Thu, 27 Oct 2022 17:14:58 +0800
Subject: [PATCH] [cherry-pick] add batch_norm_kernel (#47394)

* cherry-pick #46359 and resolve conflict
---
 paddle/fluid/operators/sparse_manual_op.cc    |  47 ++++
 paddle/phi/api/yaml/sparse_backward.yaml      |  12 +
 paddle/phi/api/yaml/sparse_ops.yaml           |  10 +
 paddle/phi/infermeta/multiary.cc              |   1 +
 .../kernels/sparse/batch_norm_grad_kernel.cc  | 108 ++++++++
 .../kernels/sparse/batch_norm_grad_kernel.h   |  48 ++++
 .../phi/kernels/sparse/batch_norm_kernel.cc   | 117 +++++++++
 paddle/phi/kernels/sparse/batch_norm_kernel.h |  47 ++++
 paddle/phi/ops/compat/sparse_manual_op_sig.cc |  27 ++
 .../tests/unittests/test_sparse_norm_op.py    | 109 ++++++--
 python/paddle/sparse/nn/layer/norm.py         | 241 +++++++++++-------
 11 files changed, 657 insertions(+), 110 deletions(-)
 create mode 100644 paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/sparse/batch_norm_grad_kernel.h
 create mode 100644 paddle/phi/kernels/sparse/batch_norm_kernel.cc
 create mode 100644 paddle/phi/kernels/sparse/batch_norm_kernel.h

diff --git a/paddle/fluid/operators/sparse_manual_op.cc b/paddle/fluid/operators/sparse_manual_op.cc
index e2ed1ed0ff..04e12391b4 100644
--- a/paddle/fluid/operators/sparse_manual_op.cc
+++ b/paddle/fluid/operators/sparse_manual_op.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/binary.h"
+#include "paddle/phi/infermeta/multiary.h"
 #include "paddle/phi/infermeta/sparse/binary.h"
 #include "paddle/phi/infermeta/sparse/unary.h"
 #include "paddle/phi/infermeta/unary.h"
@@ -185,6 +186,47 @@ DECLARE_INFER_SHAPE_FUNCTOR(sparse_add,
                             SparseAddInferShapeFunctor,
                             PD_INFER_META(phi::UnchangedInferMeta));
 
+class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("x", "(Tensor), input 0 of sparse_batch_norm op.");
+    AddInput("scale", "(Tensor), input 1 of sparse_batch_norm op.");
+    AddInput("bias", "(Tensor), input 2 of sparse_batch_norm op.");
+    AddInput("mean", "(Tensor), input 3 of sparse_batch_norm op.");
+    AddInput("variance", "(Tensor), input 4 of sparse_batch_norm op.");
+    AddOutput("y", "(Tensor), output 0 of sparse_batch_norm op.");
+    AddOutput("mean_out", "(Tensor), output 1 of sparse_batch_norm op.");
+    AddOutput("variance_out", "(Tensor), output 2 of sparse_batch_norm op.");
+    AddOutput("saved_mean", "(Tensor), output 3 of sparse_batch_norm op.");
+    AddOutput("saved_variance", "(Tensor), output 4 of sparse_batch_norm op.");
+    AddOutput("reserve_space", "(Tensor), output 5 of sparse_batch_norm op.");
+    AddAttr<float>("momentum",
+                   "(float), attribute 0 for sparse_batch_norm op.");
+    AddAttr<float>("epsilon", "(float), attribute 1 for sparse_batch_norm op.");
+    AddAttr<std::string>("data_layout",
+                         "(string), attribute 2 for sparse_batch_norm op.");
+    AddAttr<bool>("is_test", "(bool), attribute 3 for sparse_batch_norm op.");
+    AddAttr<bool>("use_global_stats",
+                  "(bool), attribute 4 for sparse_batch_norm op.");
+    AddAttr<bool>("trainable_statistics",
+                  "(bool), attribute 5 for sparse_batch_norm op.");
+    AddAttr<bool>("fuse_with_relu",
+                  "(bool), attribute 6 for sparse_batch_norm op.");
+    AddComment(R"DOC(
+TODO: Documentation of sparse_batch_norm op.
+)DOC"); + } +}; + +class SparseBatchNormOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +DECLARE_INFER_SHAPE_FUNCTOR(sparse_batch_norm, + SparseBatchNormInferShapeFunctor, + PD_INFER_META(phi::BatchNormInferMeta)); + } // namespace operators } // namespace paddle @@ -224,3 +266,8 @@ REGISTER_OPERATOR(sparse_add, ops::SparseAddOp, ops::SparseAddOpMaker, ops::SparseAddInferShapeFunctor); + +REGISTER_OPERATOR(sparse_batch_norm, + ops::SparseBatchNormOp, + ops::SparseBatchNormOpMaker, + ops::SparseBatchNormInferShapeFunctor); diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml index 40b646cb38..13b4454626 100644 --- a/paddle/phi/api/yaml/sparse_backward.yaml +++ b/paddle/phi/api/yaml/sparse_backward.yaml @@ -100,6 +100,18 @@ func : atanh_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, atanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr} +- backward_op : batch_norm_grad + forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) + args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) + output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) + infer_meta : + func : GeneralTernaryGradInferMeta + param : [x, scale, bias] + kernel : + func : batch_norm_coo_grad {sparse_coo, dense, dense, dense, dense, dense, dense, dense, sparse_coo -> sparse_coo, dense, dense} + data_type : out_grad + optional : mean_out, variance_out, reserve_space + - backward_op : cast_grad forward : cast(Tensor x, DataType index_dtype, DataType value_dtype) -> Tensor(out) args : (Tensor x, Tensor out_grad, DataType value_dtype) diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index c6ad1bfa58..2d96b22e5a 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -87,6 +87,16 @@ layout : x backward : atanh_grad +- op : batch_norm + args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) + output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) + infer_meta : + func : BatchNormInferMeta + kernel : + func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense} + data_type : x + backward : batch_norm_grad + - op : cast args : (Tensor x, DataType index_dtype=DataType::UNDEFINED, DataType value_dtype=DataType::UNDEFINED) output : Tensor(out) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 56dc40cc7c..ac0e865022 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -609,6 +609,7 @@ void BatchNormInferMeta(const MetaTensor& x, saved_variance->set_dims({C}); } y->share_lod(x); + y->set_dtype(x.dtype()); } void BatchNormInferInferMeta(const MetaTensor& x, diff --git 
a/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc b/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
new file mode 100644
index 0000000000..f9a96b15ee
--- /dev/null
+++ b/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
@@ -0,0 +1,108 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/batch_norm_grad_kernel.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void BatchNormCooGradKernel(const Context& dev_ctx,
+                            const SparseCooTensor& x,
+                            const DenseTensor& scale,
+                            const DenseTensor& bias,
+                            const paddle::optional<DenseTensor>& mean,
+                            const paddle::optional<DenseTensor>& variance,
+                            const DenseTensor& saved_mean,
+                            const DenseTensor& saved_variance,
+                            const paddle::optional<DenseTensor>& reserve_space,
+                            const SparseCooTensor& y_grad,
+                            float momentum,
+                            float epsilon,
+                            const std::string& data_layout,
+                            bool is_test,
+                            bool use_global_stats,
+                            bool trainable_statistics,
+                            bool fuse_with_relu,
+                            SparseCooTensor* x_grad,
+                            DenseTensor* scale_grad,
+                            DenseTensor* bias_grad) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
+  *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
+  *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
+  phi::BatchNormGradKernel<T, Context>(dev_ctx,
+                                       x.values(),
+                                       scale,
+                                       bias,
+                                       mean,
+                                       variance,
+                                       saved_mean,
+                                       saved_variance,
+                                       reserve_space,
+                                       y_grad.values(),
+                                       momentum,
+                                       epsilon,
+                                       data_layout,
+                                       is_test,
+                                       use_global_stats,
+                                       trainable_statistics,
+                                       fuse_with_relu,
+                                       x_grad->mutable_values(),
+                                       scale_grad,
+                                       bias_grad);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(batch_norm_coo_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::BatchNormCooGradKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+#if defined(PADDLE_WITH_HIP)
+PD_REGISTER_KERNEL(batch_norm_coo_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::BatchNormCooGradKernel,
+                   float,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+#endif
+
+#if defined(PADDLE_WITH_CUDA)
+PD_REGISTER_KERNEL(batch_norm_coo_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::BatchNormCooGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);  // x_grad
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);  // scale_grad
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);  // bias_grad
+  }
+}
+#endif
diff --git a/paddle/phi/kernels/sparse/batch_norm_grad_kernel.h b/paddle/phi/kernels/sparse/batch_norm_grad_kernel.h
new file mode 100644
index 0000000000..b705168317
--- /dev/null
+++ b/paddle/phi/kernels/sparse/batch_norm_grad_kernel.h
@@ -0,0 +1,48 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void BatchNormCooGradKernel(const Context& dev_ctx,
+                            const SparseCooTensor& x,
+                            const DenseTensor& scale,
+                            const DenseTensor& bias,
+                            const paddle::optional<DenseTensor>& mean,
+                            const paddle::optional<DenseTensor>& variance,
+                            const DenseTensor& saved_mean,
+                            const DenseTensor& saved_variance,
+                            const paddle::optional<DenseTensor>& reserve_space,
+                            const SparseCooTensor& y_grad,
+                            float momentum,
+                            float epsilon,
+                            const std::string& data_layout,
+                            bool is_test,
+                            bool use_global_stats,
+                            bool trainable_statistics,
+                            bool fuse_with_relu,
+                            SparseCooTensor* x_grad,
+                            DenseTensor* scale_grad,
+                            DenseTensor* bias_grad);
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/batch_norm_kernel.cc b/paddle/phi/kernels/sparse/batch_norm_kernel.cc
new file mode 100644
index 0000000000..4f925e83a9
--- /dev/null
+++ b/paddle/phi/kernels/sparse/batch_norm_kernel.cc
@@ -0,0 +1,117 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/batch_norm_kernel.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/batch_norm_kernel.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void BatchNormCooKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        const DenseTensor& scale,
+                        const DenseTensor& bias,
+                        const DenseTensor& mean,
+                        const DenseTensor& variance,
+                        float momentum,
+                        float epsilon,
+                        const std::string& data_layout,
+                        bool is_test,
+                        bool use_global_stats,
+                        bool trainable_statistics,
+                        bool fuse_with_relu,
+                        SparseCooTensor* y,
+                        DenseTensor* mean_out,
+                        DenseTensor* variance_out,
+                        DenseTensor* saved_mean,
+                        DenseTensor* saved_variance,
+                        DenseTensor* reserve_space) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, y);
+  phi::BatchNormKernel<T, Context>(dev_ctx,
+                                   x.values(),
+                                   scale,
+                                   bias,
+                                   mean,
+                                   variance,
+                                   momentum,
+                                   epsilon,
+                                   data_layout,
+                                   is_test,
+                                   use_global_stats,
+                                   trainable_statistics,
+                                   fuse_with_relu,
+                                   y->mutable_values(),
+                                   mean_out,
+                                   variance_out,
+                                   saved_mean,
+                                   saved_variance,
+                                   reserve_space);
+  y->SetIndicesDict(x.GetIndicesDict());
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(batch_norm_coo,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::BatchNormCooKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+#if defined(PADDLE_WITH_HIP)
+PD_REGISTER_KERNEL(batch_norm_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::BatchNormCooKernel,
+                   float,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+  kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
+  kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
+  kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
+  kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+  kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+}
+#endif
+
+#if defined(PADDLE_WITH_CUDA)
+PD_REGISTER_KERNEL(batch_norm_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::BatchNormCooKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+  }
+}
+#endif
diff --git a/paddle/phi/kernels/sparse/batch_norm_kernel.h b/paddle/phi/kernels/sparse/batch_norm_kernel.h
new file mode 100644
index 0000000000..282a8de7b3
--- /dev/null
+++ b/paddle/phi/kernels/sparse/batch_norm_kernel.h
@@ -0,0 +1,47 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" + +namespace phi { +namespace sparse { + +template +void BatchNormKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& scale, + const DenseTensor& bias, + const DenseTensor& mean, + const DenseTensor& variance, + float momentum, + float epsilon, + const std::string& data_layout, + bool is_test, + bool use_global_stats, + bool trainable_statistics, + bool fuse_with_relu, + SparseCooTensor* y, + DenseTensor* mean_out, + DenseTensor* variance_out, + DenseTensor* saved_mean, + DenseTensor* saved_variance, + DenseTensor* reserve_space); + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/ops/compat/sparse_manual_op_sig.cc b/paddle/phi/ops/compat/sparse_manual_op_sig.cc index 45f8a417a1..6c2a2bc9f4 100644 --- a/paddle/phi/ops/compat/sparse_manual_op_sig.cc +++ b/paddle/phi/ops/compat/sparse_manual_op_sig.cc @@ -82,6 +82,29 @@ KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) { } } +KernelSignature SparseBatchNormOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsSparseCooTensorInput("x")) { + return KernelSignature("batch_norm_coo", + {"x", "scale", "bias", "mean", "variance"}, + {"momentum", + "epsilon", + "data_layout", + "is_test", + "use_global_stats", + "trainable_statistics", + "fuse_with_relu"}, + {"y", + "mean_out", + "variance_out", + "saved_mean", + "saved_variance", + "reserve_space"}); + } else { + return KernelSignature("unregistered", {}, {}, {}); + } +} + } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(sparse_sparse_coo_tensor, sparse_coo_tensor); @@ -106,3 +129,7 @@ PD_REGISTER_ARG_MAPPING_FN(sparse_conv3d, phi::SparseConv3dOpArgumentMapping); PD_REGISTER_BASE_KERNEL_NAME(sparse_add, add_coo_coo); PD_REGISTER_ARG_MAPPING_FN(sparse_add, phi::SparseAddOpArgumentMapping); + +PD_REGISTER_BASE_KERNEL_NAME(sparse_batch_norm, batch_norm_coo); +PD_REGISTER_ARG_MAPPING_FN(sparse_batch_norm, + phi::SparseBatchNormOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py b/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py index b0d6e3749c..90660e35fe 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py @@ -17,18 +17,18 @@ import unittest import numpy as np import paddle from paddle.sparse import nn +import paddle.sparse as sparse import paddle.fluid as fluid import copy class TestSparseBatchNorm(unittest.TestCase): - def test(self): fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) paddle.seed(0) channels = 4 shape = [2, 3, 6, 6, channels] - #there is no zero in dense_x + # there is no zero in dense_x dense_x = paddle.randn(shape) dense_x.stop_gradient = False @@ -48,17 +48,21 @@ class TestSparseBatchNorm(unittest.TestCase): sparse_y = sparse_batch_norm(sparse_x) # compare the result with dense batch_norm - assert np.allclose(dense_y.flatten().numpy(), - sparse_y.values().flatten().numpy(), - atol=1e-5, - rtol=1e-5) + assert np.allclose( + 
dense_y.flatten().numpy(), + sparse_y.values().flatten().numpy(), + atol=1e-5, + rtol=1e-5, + ) # test backward sparse_y.backward(sparse_y) - assert np.allclose(dense_x.grad.flatten().numpy(), - sparse_x.grad.values().flatten().numpy(), - atol=1e-5, - rtol=1e-5) + assert np.allclose( + dense_x.grad.flatten().numpy(), + sparse_x.grad.values().flatten().numpy(), + atol=1e-5, + rtol=1e-5, + ) fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) def test_error_layout(self): @@ -66,8 +70,9 @@ class TestSparseBatchNorm(unittest.TestCase): shape = [2, 3, 6, 6, 3] x = paddle.randn(shape) sparse_x = x.to_sparse_coo(4) - sparse_batch_norm = paddle.sparse.nn.BatchNorm(3, - data_format='NCDHW') + sparse_batch_norm = paddle.sparse.nn.BatchNorm( + 3, data_format='NCDHW' + ) sparse_batch_norm(sparse_x) def test2(self): @@ -86,10 +91,10 @@ class TestSparseBatchNorm(unittest.TestCase): class TestSyncBatchNorm(unittest.TestCase): - def test_sync_batch_norm(self): - x = np.array([[[[0.3, 0.4], [0.3, 0.07]], - [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32') + x = np.array( + [[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]] + ).astype('float32') x = paddle.to_tensor(x) sparse_x = x.to_sparse_coo(len(x.shape) - 1) @@ -100,23 +105,81 @@ class TestSyncBatchNorm(unittest.TestCase): dense_sync_bn = paddle.nn.SyncBatchNorm(2) x = x.reshape((-1, x.shape[-1])) dense_hidden = dense_sync_bn(x) - assert np.allclose(sparse_hidden.values().numpy(), - dense_hidden.numpy()) + assert np.allclose( + sparse_hidden.values().numpy(), dense_hidden.numpy() + ) def test_convert(self): - base_model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5), - nn.BatchNorm(5)) + base_model = paddle.nn.Sequential( + nn.Conv3D(3, 5, 3), nn.BatchNorm(5), nn.BatchNorm(5) + ) model = paddle.nn.Sequential( - nn.Conv3D(3, 5, 3), nn.BatchNorm(5), - nn.BatchNorm(5, - weight_attr=fluid.ParamAttr(name='bn.scale'), - bias_attr=fluid.ParamAttr(name='bn.bias'))) + nn.Conv3D(3, 5, 3), + nn.BatchNorm(5), + nn.BatchNorm( + 5, + weight_attr=fluid.ParamAttr(name='bn.scale'), + bias_attr=fluid.ParamAttr(name='bn.bias'), + ), + ) model = nn.SyncBatchNorm.convert_sync_batchnorm(model) for idx, sublayer in enumerate(base_model.sublayers()): if isinstance(sublayer, nn.BatchNorm): self.assertEqual(isinstance(model[idx], nn.SyncBatchNorm), True) +class TestStatic(unittest.TestCase): + def test(self): + paddle.enable_static() + indices = paddle.static.data( + name='indices', shape=[4, 4], dtype='int32' + ) + values = paddle.static.data( + name='values', shape=[4, 1], dtype='float32' + ) + channels = 1 + dense_shape = [1, 1, 3, 4, channels] + sp_x = sparse.sparse_coo_tensor(indices, values, dense_shape) + + sparse_batch_norm = paddle.sparse.nn.BatchNorm(channels) + sp_y = sparse_batch_norm(sp_x) + out = sp_y.to_dense() + + exe = paddle.static.Executor() + indices_data = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values_data = np.array([[1.0], [2.0], [3.0], [4.0]]).astype('float32') + bias_data = np.array([1.0]).astype('float32') + weight_data = np.array([2.0]).astype('float32') + mean_data = np.array([1.0]).astype('float32') + variance_data = np.array([2.0]).astype('float32') + + fetch = exe.run( + feed={ + 'indices': indices_data, + 'values': values_data, + 'batch_norm_0.b_0': bias_data, + 'batch_norm_0.w_0': weight_data, + 'batch_norm_0.w_1': mean_data, + 'batch_norm_0.w_2': variance_data, + }, + fetch_list=[out], + return_numpy=True, + ) + correct_out = np.array( + [ + [ + [ + [[0.0], [-1.6832708], [0.0], [0.1055764]], + 
[[0.0], [0.0], [1.8944236], [0.0]], + [[0.0], [0.0], [0.0], [3.683271]], + ] + ] + ] + ).astype('float32') + np.testing.assert_allclose(correct_out, fetch[0], rtol=1e-5) + paddle.disable_static() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/sparse/nn/layer/norm.py b/python/paddle/sparse/nn/layer/norm.py index 7bdec9cd18..617ea1a78d 100644 --- a/python/paddle/sparse/nn/layer/norm.py +++ b/python/paddle/sparse/nn/layer/norm.py @@ -12,23 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import paddle import warnings from paddle.nn.layer.norm import _BatchNormBase from paddle.framework import no_grad +from paddle import _C_ops, in_dynamic_mode +from paddle.fluid.layer_helper import LayerHelper class BatchNorm(paddle.nn.BatchNorm1D): @@ -108,57 +97,112 @@ class BatchNorm(paddle.nn.BatchNorm1D): # [1, 6, 6, 6, 3] """ - def __init__(self, - num_features, - momentum=0.9, - epsilon=1e-05, - weight_attr=None, - bias_attr=None, - data_format='NDHWC', - use_global_stats=None, - name=None): - super(BatchNorm, self).__init__(num_features, - momentum=momentum, - epsilon=epsilon, - weight_attr=weight_attr, - bias_attr=bias_attr, - data_format=data_format, - use_global_stats=use_global_stats, - name=name) + def __init__( + self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NDHWC', + use_global_stats=None, + name=None, + ): + super(BatchNorm, self).__init__( + num_features, + momentum=momentum, + epsilon=epsilon, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format, + use_global_stats=use_global_stats, + name=name, + ) def _check_data_format(self, input): if input != "NDHWC": raise ValueError('sparse BatchNorm only support layout of "NDHWC"') def forward(self, input): - values = input.values() self._check_data_format(self._data_format) - if len(values.shape) != 2: - raise ValueError('expected 2D input.values() (got {}D)'.format( - len(values.shape))) - if self.training: warnings.warn( - "When training, we now always track global mean and variance.") - - batch_norm_out = paddle.nn.functional.batch_norm( - values, - self._mean, - self._variance, - weight=self.weight, - bias=self.bias, - training=self.training, - momentum=self._momentum, - epsilon=self._epsilon, - data_format='NC', - use_global_stats=self._use_global_stats) - - return paddle.sparse.sparse_coo_tensor( - input.indices(), - batch_norm_out, - shape=input.shape, - stop_gradient=input.stop_gradient) + "When training, we now always track global mean and variance." 
+ ) + + if self._use_global_stats == None: + self._use_global_stats = not self.training + trainable_statistics = False + else: + trainable_statistics = not self._use_global_stats + + data_format = 'NCHW' if self._data_format[1] == 'C' else 'NHWC' + + if in_dynamic_mode(): + batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm( + input, + self.weight, + self.bias, + self._mean, + self._variance, + self._momentum, + self._epsilon, + data_format, + not self.training, + self._use_global_stats, + trainable_statistics, + False, + ) + return batch_norm_out + else: + inputs = { + 'x': input, + 'scale': self.weight, + 'bias': self.bias, + 'mean': self._mean, + 'variance': self._variance, + } + attrs = { + 'momentum': self._momentum, + 'epsilon': self._epsilon, + 'data_layout': data_format, + 'is_test': not self.training, + 'use_global_stats': self._use_global_stats, + 'trainable_statistics': trainable_statistics, + 'fuse_with_relu': False, + } + op_type = 'sparse_batch_norm' + helper = LayerHelper(op_type) + dtype = input.dtype + mean_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True + ) + variance_out = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True + ) + saved_mean = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True + ) + saved_variance = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True + ) + reserve_space = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True + ) + y = helper.create_sparse_variable_for_type_inference(dtype) + outputs = { + "y": y, + "mean_out": mean_out, + "variance_out": variance_out, + "saved_mean": saved_mean, + "saved_variance": saved_variance, + "reserve_space": reserve_space, + } + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs + ) + return y class SyncBatchNorm(paddle.nn.SyncBatchNorm): @@ -258,26 +302,34 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm): # [-0.88415730, 1.57439375]]) """ - def __init__(self, - num_features, - momentum=0.9, - epsilon=1e-05, - weight_attr=None, - bias_attr=None, - data_format='NCHW', - name=None): - super(SyncBatchNorm, - self).__init__(num_features, momentum, epsilon, weight_attr, - bias_attr, data_format, name) + def __init__( + self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NCHW', + name=None, + ): + super(SyncBatchNorm, self).__init__( + num_features, + momentum, + epsilon, + weight_attr, + bias_attr, + data_format, + name, + ) def forward(self, x): - assert x.is_sparse_coo( + assert ( + x.is_sparse_coo() ), "SyncBatchNorm only support SparseTensor in COO format." 
out = super(SyncBatchNorm, self).forward(x.values()) - return paddle.sparse.sparse_coo_tensor(x.indices(), - out, - shape=x.shape, - stop_gradient=x.stop_gradient) + return paddle.sparse.sparse_coo_tensor( + x.indices(), out, shape=x.shape, stop_gradient=x.stop_gradient + ) @classmethod def convert_sync_batchnorm(cls, layer): @@ -303,27 +355,41 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm): layer_output = layer if isinstance(layer, _BatchNormBase): - if layer._weight_attr != None and not isinstance( - layer._weight_attr, - bool) and layer._weight_attr.name != None: + if ( + layer._weight_attr != None + and not isinstance(layer._weight_attr, bool) + and layer._weight_attr.name != None + ): layer._weight_attr.name = layer._weight_attr.name + '_sync' - if layer._bias_attr != None and not isinstance( - layer._bias_attr, bool) and layer._bias_attr.name != None: + if ( + layer._bias_attr != None + and not isinstance(layer._bias_attr, bool) + and layer._bias_attr.name != None + ): layer._bias_attr.name = layer._bias_attr.name + '_sync' - #convert sparse BatchNorm + # convert sparse BatchNorm if isinstance(layer, BatchNorm): - layer_output = SyncBatchNorm(layer._num_features, - layer._momentum, layer._epsilon, - layer._weight_attr, - layer._bias_attr, - layer._data_format, layer._name) - #convert dense BatchNorm + layer_output = SyncBatchNorm( + layer._num_features, + layer._momentum, + layer._epsilon, + layer._weight_attr, + layer._bias_attr, + layer._data_format, + layer._name, + ) + # convert dense BatchNorm else: layer_output = paddle.nn.SyncBatchNorm( - layer._num_features, layer._momentum, layer._epsilon, - layer._weight_attr, layer._bias_attr, layer._data_format, - layer._name) + layer._num_features, + layer._momentum, + layer._epsilon, + layer._weight_attr, + layer._bias_attr, + layer._data_format, + layer._name, + ) if layer._weight_attr != False and layer._bias_attr != False: with no_grad(): @@ -333,7 +399,8 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm): layer_output._variance = layer._variance for name, sublayer in layer.named_children(): - layer_output.add_sublayer(name, - cls.convert_sync_batchnorm(sublayer)) + layer_output.add_sublayer( + name, cls.convert_sync_batchnorm(sublayer) + ) del layer return layer_output -- GitLab
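
Example usage (a minimal dynamic-mode sketch modeled on the TestSparseBatchNorm case in this patch; the input shape and channel count are illustrative, not part of the change):

    import paddle
    from paddle.sparse import nn

    paddle.seed(0)
    channels = 4
    # The sparse BatchNorm added by this patch only supports the "NDHWC" layout,
    # so the channel axis stays last and is kept dense below.
    dense_x = paddle.randn([2, 3, 6, 6, channels])
    sparse_x = dense_x.to_sparse_coo(4)  # first 4 dims sparse, channel dim dense

    batch_norm = nn.BatchNorm(channels)
    sparse_y = batch_norm(sparse_x)   # SparseCooTensor, same sparsity pattern as sparse_x
    print(sparse_y.values().shape)    # [nnz, channels]

In static-graph mode the layer instead appends a sparse_batch_norm op through LayerHelper, which is the path exercised by the new TestStatic case above.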