未验证 提交 0fa8309a 编写于 作者: Z zhangkaihuo 提交者: GitHub

[cherry-pick]add sync_batch_norm_bn and deliver indices_dict (#47407)

add sync_batch_norm_bn and deliver indices_dict 
上级 eec93bda
......@@ -213,7 +213,7 @@ class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("fuse_with_relu",
"(bool), attribute 4 for sparse_batch_norm op.");
AddComment(R"DOC(
TODO: Documentation of sparse_conv3d op.
TODO: Documentation of sparse_batch_norm op.
)DOC");
}
};
......
......@@ -22,7 +22,6 @@ from api_base import PREFIX_TENSOR_NAME
class SparseAPI(ForwardAPI):
def __init__(self, api_item_yaml):
super(SparseAPI, self).__init__(api_item_yaml)
......@@ -32,11 +31,13 @@ class SparseAPI(ForwardAPI):
{super(SparseAPI, self).gene_api_declaration()}
"""
def gene_output(self,
out_dtype_list,
out_tensor_type_list=None,
code_indent='',
inplace_flag=False):
def gene_output(
self,
out_dtype_list,
out_tensor_type_list=None,
code_indent='',
inplace_flag=False,
):
kernel_output = []
output_names = []
output_create = ""
......@@ -44,15 +45,19 @@ class SparseAPI(ForwardAPI):
output_type_map = {
'dense': 'TensorType::DENSE_TENSOR',
'sparse_coo': 'TensorType::SPARSE_COO',
'sparse_csr': 'TensorType::SPARSE_CSR'
'sparse_csr': 'TensorType::SPARSE_CSR',
}
if len(out_dtype_list) == 1:
kernel_output.append('kernel_out')
output_names.append('kernel_out')
inplace_assign = " = " + self.inplace_map[self.outputs['names'][
0]] if inplace_flag and self.inplace_map is not None and self.outputs[
'names'][0] in self.inplace_map else ""
inplace_assign = (
" = " + self.inplace_map[self.outputs['names'][0]]
if inplace_flag
and self.inplace_map is not None
and self.outputs['names'][0] in self.inplace_map
else ""
)
output_create = f"""
{return_type} api_output{inplace_assign};
auto* kernel_out = SetSparseKernelOutput(&api_output, {output_type_map[out_dtype_list[0]]});"""
......@@ -67,8 +72,9 @@ class SparseAPI(ForwardAPI):
for out_name in self.outputs['names']:
if out_name in self.inplace_map:
output_create = output_create + self.inplace_map[
out_name] + ', '
output_create = (
output_create + self.inplace_map[out_name] + ', '
)
else:
output_create += 'Tensor(), '
output_create = output_create[:-2] + '};'
......@@ -76,28 +82,30 @@ class SparseAPI(ForwardAPI):
for i in range(len(out_dtype_list)):
kernel_output.append(f'kernel_out_{i}')
output_names.append(f'kernel_out_{i}')
output_create = output_create + f"""
output_create = (
output_create
+ f"""
auto* kernel_out_{i} = SetSparseKernelOutput(&std::get<{i}>(api_output), {output_type_map[out_dtype_list[i]]});"""
)
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.api))
self.api
)
)
return kernel_output, output_names, output_create
def gen_sparse_kernel_context(self, kernel_output_names):
input_trans_map = {
'const Tensor&':
'const phi::TenseBase&',
'const std::vector<Tensor>&':
'const std::vector<phi::TenseBase>&',
'const paddle::optional<Tensor>&':
'paddle::optional<const phi::TenseBase&>'
'const Tensor&': 'const phi::TenseBase&',
'const std::vector<Tensor>&': 'const std::vector<phi::TenseBase>&',
'const paddle::optional<Tensor>&': 'paddle::optional<const phi::TenseBase&>',
}
out_trans_map = {
'Tensor': 'phi::TenseBase*',
'std::vector<Tensor>': 'std::vector<phi::TenseBase*>'
'std::vector<Tensor>': 'std::vector<phi::TenseBase*>',
}
input_names = self.inputs['names']
input_infos = self.inputs['input_info']
......@@ -111,11 +119,17 @@ class SparseAPI(ForwardAPI):
for param in kernel_param:
if param in input_names:
if param in self.optional_vars:
kernel_context_code = kernel_context_code + f"""
kernel_context_code = (
kernel_context_code
+ f"""
kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);"""
)
else:
kernel_context_code = kernel_context_code + f"""
kernel_context_code = (
kernel_context_code
+ f"""
kernel_context.EmplaceBackInput({param}.impl().get());"""
)
continue
if param in attr_names:
......@@ -128,12 +142,18 @@ class SparseAPI(ForwardAPI):
param = str(param).lower()
else:
param + str(param) + ", "
kernel_context_code = kernel_context_code + f"""
kernel_context_code = (
kernel_context_code
+ f"""
kernel_context.EmplaceBackAttr({param});"""
)
for out_name in kernel_output_names:
kernel_context_code = kernel_context_code + f"""
kernel_context_code = (
kernel_context_code
+ f"""
kernel_context.EmplaceBackOutput({out_name});"""
)
return kernel_context_code
......@@ -143,20 +163,25 @@ class SparseAPI(ForwardAPI):
attr_names = self.attrs['names']
infer_meta = self.infer_meta
infer_meta_params = infer_meta['param'] if infer_meta[
'param'] is not None else input_names + attr_names
infer_meta_params = (
infer_meta['param']
if infer_meta['param'] is not None
else input_names + attr_names
)
create_input_var_code = ""
tensor_type_map = {
'dense': 'phi::DenseTensor',
'sparse_coo': 'phi::SparseCooTensor',
'sparse_csr': 'phi::SparseCsrTensor'
'sparse_csr': 'phi::SparseCsrTensor',
}
for param in infer_meta_params:
if param in input_names:
var_name = "auto " + PREFIX_TENSOR_NAME + param + " = "
if self.inputs['input_info'][param] == "const Tensor&":
create_input_var_code = create_input_var_code + var_name + param + ".impl();\n"
create_input_var_code = (
create_input_var_code + var_name + param + ".impl();\n"
)
elif param in self.optional_vars:
tensor_type = 'phi::DenseTensor'
for name, input_type in zip(input_names, input_types):
......@@ -164,17 +189,35 @@ class SparseAPI(ForwardAPI):
tensor_type = tensor_type_map[input_type]
break
optional_var = "paddle::optional<" + tensor_type + ">("
create_input_var_code = create_input_var_code + var_name + param + " ? " + optional_var + "*static_cast<" + tensor_type + "*>((*" + param + ").impl().get())) : " + optional_var + "paddle::none);\n"
create_input_var_code = (
create_input_var_code
+ var_name
+ param
+ " ? "
+ optional_var
+ "*static_cast<"
+ tensor_type
+ "*>((*"
+ param
+ ").impl().get())) : "
+ optional_var
+ "paddle::none);\n"
)
return f"""{create_input_var_code}"""
def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False):
_, kernel_output_names, output_create = self.gene_output(
self.kernel['dispatch'][kernel_name][1], None, '', inplace_flag)
self.kernel['dispatch'][kernel_name][1], None, '', inplace_flag
)
kernel_context_code = self.gen_sparse_kernel_context(
kernel_output_names)
return_code = "" if len(
self.gene_return_code()) == 0 else " " + self.gene_return_code()
kernel_output_names
)
return_code = (
""
if len(self.gene_return_code()) == 0
else " " + self.gene_return_code()
)
return f"""
VLOG(6) << "{self.api} api sparse kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
......@@ -192,12 +235,13 @@ class SparseAPI(ForwardAPI):
{return_code}"""
def get_condition_code(self, kernel_name):
assert self.kernel['dispatch'][kernel_name], \
f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'conv3d' in sparse_ops.yaml."
assert self.kernel['dispatch'][
kernel_name
], f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'conv3d' in sparse_ops.yaml."
input_types = self.kernel['dispatch'][kernel_name][0]
sparse_type_map = {
'sparse_coo': 'DataLayout::SPARSE_COO',
'sparse_csr': 'DataLayout::SPARSE_CSR'
'sparse_csr': 'DataLayout::SPARSE_CSR',
}
condition_list = []
tensor_type_list = []
......@@ -214,10 +258,12 @@ class SparseAPI(ForwardAPI):
else:
if in_type == 'sparse_coo':
condition_list.append(
f"{self.inputs['names'][i]}.is_sparse_coo_tensor()")
f"{self.inputs['names'][i]}.is_sparse_coo_tensor()"
)
else:
condition_list.append(
f"{self.inputs['names'][i]}.is_sparse_csr_tensor()")
f"{self.inputs['names'][i]}.is_sparse_csr_tensor()"
)
tensor_type_list.append(in_type)
self.inputs['tensor_type'] = tensor_type_list
......@@ -237,10 +283,11 @@ class SparseAPI(ForwardAPI):
kernel_dispatch_code = f"{self.gene_kernel_select()}\n"
for kernel_name in self.kernel['func']:
kernel_dispatch_code += self.gene_dispatch_code(
kernel_name, inplace_flag)
kernel_name, inplace_flag
)
return f"""
PADDLE_API {self.get_return_type()} {api_func_name}({self.get_define_args()}) {{
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{kernel_dispatch_code}
PADDLE_THROW(phi::errors::Unimplemented(
"The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));
......@@ -283,17 +330,20 @@ def source_include(header_file_path):
def api_namespace():
return ("""
return (
"""
namespace paddle {
namespace experimental {
namespace sparse {
""", """
""",
"""
} // namespace sparse
} // namespace experimental
} // namespace paddle
""")
""",
)
def generate_api(api_yaml_path, header_file_path, source_file_path):
......@@ -329,18 +379,25 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
def main():
parser = argparse.ArgumentParser(
description='Generate PaddlePaddle C++ Sparse API files')
parser.add_argument('--api_yaml_path',
help='path to sparse api yaml file',
default='paddle/phi/api/yaml/sparse_ops.yaml')
parser.add_argument('--api_header_path',
help='output of generated api header code file',
default='paddle/phi/api/include/sparse_api.h')
parser.add_argument('--api_source_path',
help='output of generated api source code file',
default='paddle/phi/api/lib/sparse_api.cc')
description='Generate PaddlePaddle C++ Sparse API files'
)
parser.add_argument(
'--api_yaml_path',
help='path to sparse api yaml file',
default='paddle/phi/api/yaml/sparse_ops.yaml',
)
parser.add_argument(
'--api_header_path',
help='output of generated api header code file',
default='paddle/phi/api/include/sparse_api.h',
)
parser.add_argument(
'--api_source_path',
help='output of generated api source code file',
default='paddle/phi/api/lib/sparse_api.cc',
)
options = parser.parse_args()
......
......@@ -367,6 +367,18 @@
func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
- backward_op : sync_batch_norm_grad
forward : sync_batch_norm_(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, scale, bias]
kernel :
func : sync_batch_norm_coo_grad{sparse_coo, dense, dense, dense, dense, dense, sparse_coo -> sparse_coo, dense, dense}
data_type : out_grad
optional : reserve_space
- backward_op : tan_grad
forward : tan(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -95,6 +95,7 @@
kernel :
func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
data_type : x
view : (mean -> mean_out), (variance -> variance_out)
backward : batch_norm_grad
- op : cast
......@@ -378,7 +379,8 @@
args : (Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0)
output : Tensor(out)
infer_meta :
func : AddmmInferMeta
func : UnchangedInferMeta
param : [input]
kernel :
func : addmm_csr_dense {dense, sparse_csr, dense -> dense},
addmm_csr_csr {sparse_csr, sparse_csr, sparse_csr -> sparse_csr},
......@@ -480,6 +482,17 @@
layout : x
backward : transpose_grad
- op : sync_batch_norm_
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta :
func : BatchNormInferMeta
kernel :
func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
data_type : x
backward : sync_batch_norm_grad
inplace : (mean -> mean_out), (variance -> variance_out)
- op : reshape
args : (Tensor x, IntArray shape)
output : Tensor(out)
......
......@@ -31,6 +31,7 @@ void EmptyLikeCooKernel(const Context& dev_ctx,
const DenseTensor& x_values = x.values();
DenseTensor* out_values = out->mutable_values();
out_values->Resize(x_values.dims());
out->set_meta(x.meta());
dev_ctx.template Alloc<T>(out_values);
}
......@@ -44,6 +45,7 @@ void EmptyLikeCsrKernel(const Context& dev_ctx,
const DenseTensor& x_values = x.values();
DenseTensor* out_values = out->mutable_values();
out_values->Resize(x_values.dims());
out->set_meta(x.meta());
dev_ctx.template Alloc<T>(out_values);
}
......
......@@ -169,6 +169,7 @@ void CoalesceGPUKernel(const GPUContext& dev_ctx,
indexs_ptr, const_dims, out_nnz, sparse_dim, out_indices.data<IntT>());
out->SetMember(out_indices, out_values, x.dims(), true);
out->SetIndicesDict(x.GetIndicesDict());
}
template <typename T, typename Context>
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/sync_batch_norm_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SyncBatchNormCooGradKernel(
const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
const SparseCooTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
*scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
*bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
phi::SyncBatchNormGradKernel<T, Context>(dev_ctx,
x.values(),
scale,
bias,
saved_mean,
saved_variance,
reserve_space,
y_grad.values(),
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
x_grad->mutable_values(),
scale_grad,
bias_grad);
}
} // namespace sparse
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SyncBatchNormCooGradKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SyncBatchNormCooGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/sync_batch_norm_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sync_batch_norm_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SyncBatchNormCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, y);
phi::SyncBatchNormKernel<T, Context>(dev_ctx,
x.values(),
scale,
bias,
mean,
variance,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
y->mutable_values(),
mean_out,
variance_out,
saved_mean,
saved_variance,
reserve_space);
y->SetIndicesDict(x.GetIndicesDict());
}
} // namespace sparse
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm_coo,
GPU,
ALL_LAYOUT,
phi::sparse::SyncBatchNormCooKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(sync_batch_norm_coo,
GPU,
ALL_LAYOUT,
phi::sparse::SyncBatchNormCooKernel,
float,
double,
phi::dtype::float16) {}
#endif
......@@ -37,6 +37,7 @@ namespace sparse {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, out); \
phi::prefix##Kernel<T, Context>( \
dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements()); \
out->SetIndicesDict(x.GetIndicesDict()); \
} \
\
template <typename T, typename Context> \
......@@ -105,6 +106,7 @@ void ScaleCooKernel(const Context& dev_ctx,
bias,
bias_after_scale,
out->mutable_non_zero_elements());
out->SetIndicesDict(x.GetIndicesDict());
}
template <typename T, typename Context>
......@@ -155,6 +157,7 @@ void CastCooKernel(const Context& dev_ctx,
meta.set_dtype(value_dtype);
phi::CastKernel<T, Context>(dev_ctx, x_values, value_dtype, out_values);
}
out->SetIndicesDict(x.GetIndicesDict());
}
template <typename T, typename Context>
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SyncBatchNormCooGradKernel(
const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const paddle::optional<DenseTensor>& reserve_space,
const SparseCooTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SyncBatchNormCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
const DenseTensor& mean,
const DenseTensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu,
SparseCooTensor* y,
DenseTensor* mean_out,
DenseTensor* variance_out,
DenseTensor* saved_mean,
DenseTensor* saved_variance,
DenseTensor* reserve_space);
} // namespace sparse
} // namespace phi
......@@ -323,13 +323,22 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):
)
def forward(self, x):
assert (
x.is_sparse_coo()
), "SyncBatchNorm only support SparseTensor in COO format."
out = super(SyncBatchNorm, self).forward(x.values())
return paddle.sparse.sparse_coo_tensor(
x.indices(), out, shape=x.shape, stop_gradient=x.stop_gradient
self._check_data_format()
sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm_(
x,
self.weight,
self.bias,
self._mean,
self._variance,
self._momentum,
self._epsilon,
self._data_format,
not self.training,
False,
False,
False,
)
return sync_batch_norm_out
@classmethod
def convert_sync_batchnorm(cls, layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册