未验证 提交 5cd6a707 编写于 作者: Z zhangkaihuo 提交者: GitHub

add sync_batch_norm_kernel (#46430)

上级 912be4f8
......@@ -213,7 +213,7 @@ class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("fuse_with_relu",
"(bool), attribute 4 for sparse_batch_norm op.");
AddComment(R"DOC(
TODO: Documentation of sparse_conv3d op.
TODO: Documentation of sparse_batch_norm op.
)DOC");
}
};
......
......@@ -356,6 +356,18 @@
func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
- backward_op : sync_batch_norm_grad
forward : sync_batch_norm(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, scale, bias]
kernel :
func : sync_batch_norm_coo_grad{sparse_coo, dense, dense, dense, dense, dense, sparse_coo -> sparse_coo, dense, dense}
data_type : out_grad
optional : reserve_space
- backward_op : tan_grad
forward : tan(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -479,3 +479,13 @@
transpose_csr{sparse_csr -> sparse_csr}
layout : x
backward : transpose_grad
- op : sync_batch_norm
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta :
func : BatchNormInferMeta
kernel :
func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
data_type : x
backward : sync_batch_norm_grad
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/sync_batch_norm_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SyncBatchNormCooGradKernel(
const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
const SparseCooTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
*scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
*bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
phi::SyncBatchNormGradKernel<T, Context>(dev_ctx,
x.values(),
scale,
bias,
saved_mean,
saved_variance,
reserve_space,
y_grad.values(),
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
x_grad->mutable_values(),
scale_grad,
bias_grad);
}
} // namespace sparse
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SyncBatchNormCooGradKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SyncBatchNormCooGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/sync_batch_norm_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sync_batch_norm_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SyncBatchNormCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, y);
phi::SyncBatchNormKernel<T, Context>(dev_ctx,
x.values(),
scale,
bias,
mean,
variance,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
y->mutable_values(),
mean_out,
variance_out,
saved_mean,
saved_variance,
reserve_space);
y->SetIndicesDict(x.GetIndicesDict());
}
} // namespace sparse
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm_coo,
GPU,
ALL_LAYOUT,
phi::sparse::SyncBatchNormCooKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(sync_batch_norm_coo,
GPU,
ALL_LAYOUT,
phi::sparse::SyncBatchNormCooKernel,
float,
double,
phi::dtype::float16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SyncBatchNormCooGradKernel(
const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
const SparseCooTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SyncBatchNormCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space);
} // namespace sparse
} // namespace phi
......@@ -296,11 +296,12 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):
bias_attr, data_format, name)
def forward(self, x):
assert x.is_sparse_coo(
), "SyncBatchNorm only support SparseTensor in COO format."
out = super(SyncBatchNorm, self).forward(x.values())
return paddle.incubate.sparse.sparse_coo_tensor(
x.indices(), out, shape=x.shape, stop_gradient=x.stop_gradient)
self._check_data_format()
sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm(
x, self.weight, self.bias, self._mean, self._variance,
self._momentum, self._epsilon, self._data_format, not self.training,
False, False, False)
return sync_batch_norm_out
@classmethod
def convert_sync_batchnorm(cls, layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册