未验证 提交 888223b7 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse] Add a batch_norm kernel (#46359)

上级 57cdde13
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/unary.h"
......@@ -185,6 +186,47 @@ DECLARE_INFER_SHAPE_FUNCTOR(sparse_add,
SparseAddInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_batch_norm op.");
AddInput("scale", "(Tensor), input 1 of sparse_batch_norm op.");
AddInput("bias", "(Tensor), input 2 of sparse_batch_norm op.");
AddInput("mean", "(Tensor), input 3 of sparse_batch_norm op.");
AddInput("variance", "(Tensor), input 4 of sparse_batch_norm op.");
AddOutput("y", "(Tensor), output 0 of sparse_batch_norm op.");
AddOutput("mean_out", "(Tensor), output 1 of sparse_batch_norm op.");
AddOutput("variance_out", "(Tensor), output 2 of sparse_batch_norm op.");
AddOutput("saved_mean", "(Tensor), output 3 of sparse_batch_norm op.");
AddOutput("saved_variance", "(Tensor), output 4 of sparse_batch_norm op.");
AddOutput("reserve_space", "(Tensor), output 5 of sparse_batch_norm op.");
AddAttr<float>("momentum",
"(float), attribute 0 for sparse_batch_norm op.");
AddAttr<float>("epsilon", "(float), attribute 1 for sparse_batch_norm op.");
AddAttr<std::string>("data_layout",
"(string), attribute 2 for sparse_batch_norm op.");
AddAttr<bool>("is_test", "(bool), attribute 3 for sparse_batch_norm op.");
AddAttr<bool>("use_global_stats",
"(bool), attribute 4 for sparse_batch_norm op.");
AddAttr<bool>("trainable_statistics",
"(bool), attribute 4 for sparse_batch_norm op.");
AddAttr<bool>("fuse_with_relu",
"(bool), attribute 4 for sparse_batch_norm op.");
AddComment(R"DOC(
TODO: Documentation of sparse_conv3d op.
)DOC");
}
};
class SparseBatchNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_batch_norm,
SparseBatchNormInferShapeFunctor,
PD_INFER_META(phi::BatchNormInferMeta));
} // namespace operators
} // namespace paddle
......@@ -224,3 +266,8 @@ REGISTER_OPERATOR(sparse_add,
ops::SparseAddOp,
ops::SparseAddOpMaker,
ops::SparseAddInferShapeFunctor);
REGISTER_OPERATOR(sparse_batch_norm,
ops::SparseBatchNormOp,
ops::SparseBatchNormOpMaker,
ops::SparseBatchNormInferShapeFunctor);
......@@ -100,6 +100,18 @@
func : atanh_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
atanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
- backward_op : batch_norm_grad
forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, scale, bias]
kernel :
func : batch_norm_coo_grad {sparse_coo, dense, dense, dense, dense, dense, dense, dense, sparse_coo -> sparse_coo, dense, dense}
data_type : out_grad
optional : mean_out, variance_out, reserve_space
- backward_op : cast_grad
forward : cast(Tensor x, DataType index_dtype, DataType value_dtype) -> Tensor(out)
args : (Tensor x, Tensor out_grad, DataType value_dtype)
......
......@@ -87,6 +87,16 @@
layout : x
backward : atanh_grad
- op : batch_norm
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta :
func : BatchNormInferMeta
kernel :
func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
data_type : x
backward : batch_norm_grad
- op : cast
args : (Tensor x, DataType index_dtype=DataType::UNDEFINED, DataType value_dtype=DataType::UNDEFINED)
output : Tensor(out)
......
......@@ -638,6 +638,7 @@ void BatchNormInferMeta(const MetaTensor& x,
saved_variance->set_dims({C});
}
y->share_lod(x);
y->set_dtype(x.dtype());
}
void BatchNormInferInferMeta(const MetaTensor& x,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/batch_norm_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void BatchNormCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const paddle::optional<DenseTensor>& mean,
const paddle::optional<DenseTensor>& variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
const SparseCooTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
*scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
*bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
phi::BatchNormGradKernel<T, Context>(dev_ctx,
x.values(),
scale,
bias,
mean,
variance,
saved_mean,
saved_variance,
reserve_space,
y_grad.values(),
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
x_grad->mutable_values(),
scale_grad,
bias_grad);
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(batch_norm_coo_grad,
CPU,
ALL_LAYOUT,
phi::sparse::BatchNormCooGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#if defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(batch_norm_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::BatchNormCooGradKernel,
float,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#endif
#if defined(PADDLE_WITH_CUDA)
PD_REGISTER_KERNEL(batch_norm_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::BatchNormCooGradKernel,
float,
double,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
}
}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void BatchNormCooGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const paddle::optional<DenseTensor>& mean,
const paddle::optional<DenseTensor>& variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
const SparseCooTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/batch_norm_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void BatchNormCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, y);
phi::BatchNormKernel<T, Context>(dev_ctx,
x.values(),
scale,
bias,
mean,
variance,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
y->mutable_values(),
mean_out,
variance_out,
saved_mean,
saved_variance,
reserve_space);
y->SetIndicesDict(x.GetIndicesDict());
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(batch_norm_coo,
CPU,
ALL_LAYOUT,
phi::sparse::BatchNormCooKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#if defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(batch_norm_coo,
GPU,
ALL_LAYOUT,
phi::sparse::BatchNormCooKernel,
float,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
#endif
#if defined(PADDLE_WITH_CUDA)
PD_REGISTER_KERNEL(batch_norm_coo,
GPU,
ALL_LAYOUT,
phi::sparse::BatchNormCooKernel,
float,
double,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void BatchNormKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space);
} // namespace sparse
} // namespace phi
......@@ -82,6 +82,29 @@ KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
}
KernelSignature SparseBatchNormOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("batch_norm_coo",
{"x", "scale", "bias", "mean", "variance"},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{"y",
"mean_out",
"variance_out",
"saved_mean",
"saved_variance",
"reserve_space"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(sparse_sparse_coo_tensor, sparse_coo_tensor);
......@@ -106,3 +129,7 @@ PD_REGISTER_ARG_MAPPING_FN(sparse_conv3d, phi::SparseConv3dOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_add, add_coo_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_add, phi::SparseAddOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_batch_norm, batch_norm_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_batch_norm,
phi::SparseBatchNormOpArgumentMapping);
......@@ -16,6 +16,7 @@ import unittest
import numpy as np
import paddle
from paddle.incubate.sparse import nn
import paddle.incubate.sparse as sparse
import paddle.fluid as fluid
import copy
......@@ -117,5 +118,49 @@ class TestSyncBatchNorm(unittest.TestCase):
self.assertEqual(isinstance(model[idx], nn.SyncBatchNorm), True)
class TestStatic(unittest.TestCase):
def test(self):
paddle.enable_static()
indices = paddle.static.data(name='indices',
shape=[4, 4],
dtype='int32')
values = paddle.static.data(name='values',
shape=[4, 1],
dtype='float32')
channels = 1
dense_shape = [1, 1, 3, 4, channels]
sp_x = sparse.sparse_coo_tensor(indices, values, dense_shape)
sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
sp_y = sparse_batch_norm(sp_x)
out = sp_y.to_dense()
exe = paddle.static.Executor()
indices_data = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values_data = np.array([[1.0], [2.0], [3.0], [4.0]]).astype('float32')
bias_data = np.array([1.0]).astype('float32')
weight_data = np.array([2.0]).astype('float32')
mean_data = np.array([1.0]).astype('float32')
variance_data = np.array([2.0]).astype('float32')
fetch = exe.run(feed={
'indices': indices_data,
'values': values_data,
'batch_norm_0.b_0': bias_data,
'batch_norm_0.w_0': weight_data,
'batch_norm_0.w_1': mean_data,
'batch_norm_0.w_2': variance_data
},
fetch_list=[out],
return_numpy=True)
correct_out = np.array([[[[[0.0], [-1.6832708], [0.0], [0.1055764]],
[[0.0], [0.0], [1.8944236], [0.0]],
[[0.0], [0.0], [0.0],
[3.683271]]]]]).astype('float32')
np.testing.assert_allclose(correct_out, fetch[0], rtol=1e-5)
paddle.disable_static()
if __name__ == "__main__":
unittest.main()
......@@ -12,23 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import warnings
from paddle.nn.layer.norm import _BatchNormBase
from paddle.framework import no_grad
from paddle import _C_ops, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
class BatchNorm(paddle.nn.BatchNorm1D):
......@@ -131,34 +120,70 @@ class BatchNorm(paddle.nn.BatchNorm1D):
raise ValueError('sparse BatchNorm only support layout of "NDHWC"')
def forward(self, input):
values = input.values()
self._check_data_format(self._data_format)
if len(values.shape) != 2:
raise ValueError('expected 2D input.values() (got {}D)'.format(
len(values.shape)))
if self.training:
warnings.warn(
"When training, we now always track global mean and variance.")
batch_norm_out = paddle.nn.functional.batch_norm(
values,
self._mean,
self._variance,
weight=self.weight,
bias=self.bias,
training=self.training,
momentum=self._momentum,
epsilon=self._epsilon,
data_format='NC',
use_global_stats=self._use_global_stats)
return paddle.incubate.sparse.sparse_coo_tensor(
input.indices(),
batch_norm_out,
shape=input.shape,
stop_gradient=input.stop_gradient)
if self._use_global_stats == None:
self._use_global_stats = not self.training
trainable_statistics = False
else:
trainable_statistics = not self._use_global_stats
data_format = 'NCHW' if self._data_format[1] == 'C' else 'NHWC'
if in_dynamic_mode():
batch_norm_out, _, _, _, _, _ = _C_ops.sparse_batch_norm(
input, self.weight, self.bias, self._mean, self._variance,
self._momentum, self._epsilon, data_format, not self.training,
self._use_global_stats, trainable_statistics, False)
return batch_norm_out
else:
inputs = {
'x': input,
'scale': self.weight,
'bias': self.bias,
'mean': self._mean,
'variance': self._variance
}
attrs = {
'momentum': self._momentum,
'epsilon': self._epsilon,
'data_layout': data_format,
'is_test': not self.training,
'use_global_stats': self._use_global_stats,
'trainable_statistics': trainable_statistics,
'fuse_with_relu': False
}
op_type = 'sparse_batch_norm'
helper = LayerHelper(op_type)
dtype = input.dtype
mean_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
saved_mean = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
saved_variance = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
reserve_space = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
y = helper.create_sparse_variable_for_type_inference(dtype)
outputs = {
"y": y,
"mean_out": mean_out,
"variance_out": variance_out,
"saved_mean": saved_mean,
"saved_variance": saved_variance,
"reserve_space": reserve_space
}
helper.append_op(type=op_type,
inputs=inputs,
outputs=outputs,
attrs=attrs)
return y
class SyncBatchNorm(paddle.nn.SyncBatchNorm):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册