Unverified commit 0fa8309a, authored by Z zhangkaihuo, committed by GitHub

[cherry-pick]add sync_batch_norm_bn and deliver indices_dict (#47407)

add sync_batch_norm_bn and deliver indices_dict 
Parent eec93bda
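The change adds a sparse sync_batch_norm op (forward and backward) with GPU kernels for SparseCooTensor inputs, propagates the cached indices_dict from input to output in the sparse coalesce/unary/scale/cast kernels, and switches the Python-side SyncBatchNorm layer to call the new op directly. A minimal usage sketch of the resulting layer follows; it assumes a GPU build (the sync_batch_norm_coo kernels in this change are registered for the GPU backend only), and the shapes, channel count, and the paddle.sparse.nn.SyncBatchNorm path are illustrative rather than taken from this diff.

# Minimal usage sketch (assumption: GPU build with the sparse sync_batch_norm
# kernels available; shapes and channel count are illustrative).
import paddle

channels = 3
dense_x = paddle.randn((1, 6, 6, 6, channels), dtype='float32')  # NDHWC layout
sparse_x = dense_x.to_sparse_coo(4)          # keep the channel dimension dense

sync_bn = paddle.sparse.nn.SyncBatchNorm(channels)
out = sync_bn(sparse_x)                      # SparseCooTensor with normalized values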
@@ -213,7 +213,7 @@ class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
    AddAttr<bool>("fuse_with_relu",
                  "(bool), attribute 4 for sparse_batch_norm op.");
    AddComment(R"DOC(
TODO: Documentation of sparse_batch_norm op.
)DOC");
  }
};
......
@@ -22,7 +22,6 @@ from api_base import PREFIX_TENSOR_NAME

class SparseAPI(ForwardAPI):
    def __init__(self, api_item_yaml):
        super(SparseAPI, self).__init__(api_item_yaml)

@@ -32,11 +31,13 @@ class SparseAPI(ForwardAPI):
{super(SparseAPI, self).gene_api_declaration()}
"""

    def gene_output(
        self,
        out_dtype_list,
        out_tensor_type_list=None,
        code_indent='',
        inplace_flag=False,
    ):
        kernel_output = []
        output_names = []
        output_create = ""

@@ -44,15 +45,19 @@ class SparseAPI(ForwardAPI):
        output_type_map = {
            'dense': 'TensorType::DENSE_TENSOR',
            'sparse_coo': 'TensorType::SPARSE_COO',
            'sparse_csr': 'TensorType::SPARSE_CSR',
        }
        if len(out_dtype_list) == 1:
            kernel_output.append('kernel_out')
            output_names.append('kernel_out')
            inplace_assign = (
                " = " + self.inplace_map[self.outputs['names'][0]]
                if inplace_flag
                and self.inplace_map is not None
                and self.outputs['names'][0] in self.inplace_map
                else ""
            )
            output_create = f"""
{return_type} api_output{inplace_assign};
  auto* kernel_out = SetSparseKernelOutput(&api_output, {output_type_map[out_dtype_list[0]]});"""

@@ -67,8 +72,9 @@ class SparseAPI(ForwardAPI):
                for out_name in self.outputs['names']:
                    if out_name in self.inplace_map:
                        output_create = (
                            output_create + self.inplace_map[out_name] + ', '
                        )
                    else:
                        output_create += 'Tensor(), '
                output_create = output_create[:-2] + '};'

@@ -76,28 +82,30 @@ class SparseAPI(ForwardAPI):
            for i in range(len(out_dtype_list)):
                kernel_output.append(f'kernel_out_{i}')
                output_names.append(f'kernel_out_{i}')
                output_create = (
                    output_create
                    + f"""
  auto* kernel_out_{i} = SetSparseKernelOutput(&std::get<{i}>(api_output), {output_type_map[out_dtype_list[i]]});"""
                )
        else:
            raise ValueError(
                "{} : Output error: the output should not be empty.".format(
                    self.api
                )
            )

        return kernel_output, output_names, output_create

    def gen_sparse_kernel_context(self, kernel_output_names):
        input_trans_map = {
            'const Tensor&': 'const phi::TenseBase&',
            'const std::vector<Tensor>&': 'const std::vector<phi::TenseBase>&',
            'const paddle::optional<Tensor>&': 'paddle::optional<const phi::TenseBase&>',
        }
        out_trans_map = {
            'Tensor': 'phi::TenseBase*',
            'std::vector<Tensor>': 'std::vector<phi::TenseBase*>',
        }
        input_names = self.inputs['names']
        input_infos = self.inputs['input_info']

@@ -111,11 +119,17 @@ class SparseAPI(ForwardAPI):
        for param in kernel_param:
            if param in input_names:
                if param in self.optional_vars:
                    kernel_context_code = (
                        kernel_context_code
                        + f"""
  kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);"""
                    )
                else:
                    kernel_context_code = (
                        kernel_context_code
                        + f"""
  kernel_context.EmplaceBackInput({param}.impl().get());"""
                    )
                continue
            if param in attr_names:

@@ -128,12 +142,18 @@ class SparseAPI(ForwardAPI):
                    param = str(param).lower()
                else:
                    param + str(param) + ", "
                kernel_context_code = (
                    kernel_context_code
                    + f"""
  kernel_context.EmplaceBackAttr({param});"""
                )

        for out_name in kernel_output_names:
            kernel_context_code = (
                kernel_context_code
                + f"""
  kernel_context.EmplaceBackOutput({out_name});"""
            )

        return kernel_context_code

@@ -143,20 +163,25 @@ class SparseAPI(ForwardAPI):
        attr_names = self.attrs['names']
        infer_meta = self.infer_meta
        infer_meta_params = (
            infer_meta['param']
            if infer_meta['param'] is not None
            else input_names + attr_names
        )
        create_input_var_code = ""
        tensor_type_map = {
            'dense': 'phi::DenseTensor',
            'sparse_coo': 'phi::SparseCooTensor',
            'sparse_csr': 'phi::SparseCsrTensor',
        }
        for param in infer_meta_params:
            if param in input_names:
                var_name = "auto " + PREFIX_TENSOR_NAME + param + " = "
                if self.inputs['input_info'][param] == "const Tensor&":
                    create_input_var_code = (
                        create_input_var_code + var_name + param + ".impl();\n"
                    )
                elif param in self.optional_vars:
                    tensor_type = 'phi::DenseTensor'
                    for name, input_type in zip(input_names, input_types):

@@ -164,17 +189,35 @@ class SparseAPI(ForwardAPI):
                            tensor_type = tensor_type_map[input_type]
                            break
                    optional_var = "paddle::optional<" + tensor_type + ">("
                    create_input_var_code = (
                        create_input_var_code
                        + var_name
                        + param
                        + " ? "
                        + optional_var
                        + "*static_cast<"
                        + tensor_type
                        + "*>((*"
                        + param
                        + ").impl().get())) : "
                        + optional_var
                        + "paddle::none);\n"
                    )
        return f"""{create_input_var_code}"""

    def gen_sparse_kernel_code(self, kernel_name, inplace_flag=False):
        _, kernel_output_names, output_create = self.gene_output(
            self.kernel['dispatch'][kernel_name][1], None, '', inplace_flag
        )
        kernel_context_code = self.gen_sparse_kernel_context(
            kernel_output_names
        )
        return_code = (
            ""
            if len(self.gene_return_code()) == 0
            else " " + self.gene_return_code()
        )
        return f"""
  VLOG(6) << "{self.api} api sparse kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
@@ -192,12 +235,13 @@ class SparseAPI(ForwardAPI):
{return_code}"""

    def get_condition_code(self, kernel_name):
        assert self.kernel['dispatch'][
            kernel_name
        ], f"{self.api} api: the tensor type of inputs and outputs for kernel isn't set, see also 'kernel:func' of 'conv3d' in sparse_ops.yaml."
        input_types = self.kernel['dispatch'][kernel_name][0]
        sparse_type_map = {
            'sparse_coo': 'DataLayout::SPARSE_COO',
            'sparse_csr': 'DataLayout::SPARSE_CSR',
        }
        condition_list = []
        tensor_type_list = []

@@ -214,10 +258,12 @@ class SparseAPI(ForwardAPI):
            else:
                if in_type == 'sparse_coo':
                    condition_list.append(
                        f"{self.inputs['names'][i]}.is_sparse_coo_tensor()"
                    )
                else:
                    condition_list.append(
                        f"{self.inputs['names'][i]}.is_sparse_csr_tensor()"
                    )
                tensor_type_list.append(in_type)
        self.inputs['tensor_type'] = tensor_type_list

@@ -237,10 +283,11 @@ class SparseAPI(ForwardAPI):
        kernel_dispatch_code = f"{self.gene_kernel_select()}\n"
        for kernel_name in self.kernel['func']:
            kernel_dispatch_code += self.gene_dispatch_code(
                kernel_name, inplace_flag
            )
        return f"""
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{kernel_dispatch_code}
  PADDLE_THROW(phi::errors::Unimplemented(
      "The kernel of ({self.api}) for input tensors is unimplemented, please check the type of input tensors."));

@@ -283,17 +330,20 @@ def source_include(header_file_path):

def api_namespace():
    return (
        """
namespace paddle {
namespace experimental {
namespace sparse {

""",
        """

}  // namespace sparse
}  // namespace experimental
}  // namespace paddle
""",
    )


def generate_api(api_yaml_path, header_file_path, source_file_path):

@@ -329,18 +379,25 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):

def main():
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ Sparse API files'
    )
    parser.add_argument(
        '--api_yaml_path',
        help='path to sparse api yaml file',
        default='paddle/phi/api/yaml/sparse_ops.yaml',
    )

    parser.add_argument(
        '--api_header_path',
        help='output of generated api header code file',
        default='paddle/phi/api/include/sparse_api.h',
    )

    parser.add_argument(
        '--api_source_path',
        help='output of generated api source code file',
        default='paddle/phi/api/lib/sparse_api.cc',
    )

    options = parser.parse_args()
......
@@ -367,6 +367,18 @@
    func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
           subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}

- backward_op : sync_batch_norm_grad
  forward : sync_batch_norm_(Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
  args : (Tensor x, Tensor scale, Tensor bias, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
  infer_meta :
    func : GeneralTernaryGradInferMeta
    param : [x, scale, bias]
  kernel :
    func : sync_batch_norm_coo_grad{sparse_coo, dense, dense, dense, dense, dense, sparse_coo -> sparse_coo, dense, dense}
    data_type : out_grad
  optional : reserve_space

- backward_op : tan_grad
  forward : tan(Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
......
@@ -95,6 +95,7 @@
  kernel :
    func : batch_norm_coo {sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
    data_type : x
  view : (mean -> mean_out), (variance -> variance_out)
  backward : batch_norm_grad

- op : cast

@@ -378,7 +379,8 @@
  args : (Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
    param : [input]
  kernel :
    func : addmm_csr_dense {dense, sparse_csr, dense -> dense},
           addmm_csr_csr {sparse_csr, sparse_csr, sparse_csr -> sparse_csr},

@@ -480,6 +482,17 @@
    layout : x
  backward : transpose_grad

- op : sync_batch_norm_
  args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
  output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
  infer_meta :
    func : BatchNormInferMeta
  kernel :
    func : sync_batch_norm_coo{sparse_coo, dense, dense, dense, dense -> sparse_coo, dense, dense, dense, dense, dense}
    data_type : x
  backward : sync_batch_norm_grad
  inplace : (mean -> mean_out), (variance -> variance_out)

- op : reshape
  args : (Tensor x, IntArray shape)
  output : Tensor(out)
......
@@ -31,6 +31,7 @@ void EmptyLikeCooKernel(const Context& dev_ctx,
  const DenseTensor& x_values = x.values();
  DenseTensor* out_values = out->mutable_values();
  out_values->Resize(x_values.dims());
  out->set_meta(x.meta());
  dev_ctx.template Alloc<T>(out_values);
}

@@ -44,6 +45,7 @@ void EmptyLikeCsrKernel(const Context& dev_ctx,
  const DenseTensor& x_values = x.values();
  DenseTensor* out_values = out->mutable_values();
  out_values->Resize(x_values.dims());
  out->set_meta(x.meta());
  dev_ctx.template Alloc<T>(out_values);
}
......
@@ -169,6 +169,7 @@ void CoalesceGPUKernel(const GPUContext& dev_ctx,
      indexs_ptr, const_dims, out_nnz, sparse_dim, out_indices.data<IntT>());
  out->SetMember(out_indices, out_values, x.dims(), true);
  out->SetIndicesDict(x.GetIndicesDict());
}

template <typename T, typename Context>
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sync_batch_norm_grad_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SyncBatchNormCooGradKernel(
    const Context& dev_ctx,
    const SparseCooTensor& x,
    const DenseTensor& scale,
    const DenseTensor& bias,
    const DenseTensor& saved_mean,
    const DenseTensor& saved_variance,
    const paddle::optional<DenseTensor>& reserve_space,
    const SparseCooTensor& y_grad,
    float momentum,
    float epsilon,
    const std::string& data_layout,
    bool is_test,
    bool use_global_stats,
    bool trainable_statistics,
    bool fuse_with_relu,
    SparseCooTensor* x_grad,
    DenseTensor* scale_grad,
    DenseTensor* bias_grad) {
  EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
  *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
  *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
  phi::SyncBatchNormGradKernel<T, Context>(dev_ctx,
                                           x.values(),
                                           scale,
                                           bias,
                                           saved_mean,
                                           saved_variance,
                                           reserve_space,
                                           y_grad.values(),
                                           momentum,
                                           epsilon,
                                           data_layout,
                                           is_test,
                                           use_global_stats,
                                           trainable_statistics,
                                           fuse_with_relu,
                                           x_grad->mutable_values(),
                                           scale_grad,
                                           bias_grad);
}

}  // namespace sparse
}  // namespace phi

#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SyncBatchNormCooGradKernel,
                   float,
                   phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SyncBatchNormCooGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sync_batch_norm_kernel.h"

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sync_batch_norm_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SyncBatchNormCooKernel(const Context& dev_ctx,
                            const SparseCooTensor& x,
                            const DenseTensor& scale,
                            const DenseTensor& bias,
                            const DenseTensor& mean,
                            const DenseTensor& variance,
                            float momentum,
                            float epsilon,
                            const std::string& data_layout,
                            bool is_test,
                            bool use_global_stats,
                            bool trainable_statistics,
                            bool fuse_with_relu,
                            SparseCooTensor* y,
                            DenseTensor* mean_out,
                            DenseTensor* variance_out,
                            DenseTensor* saved_mean,
                            DenseTensor* saved_variance,
                            DenseTensor* reserve_space) {
  EmptyLikeCooKernel<T, Context>(dev_ctx, x, y);
  phi::SyncBatchNormKernel<T, Context>(dev_ctx,
                                       x.values(),
                                       scale,
                                       bias,
                                       mean,
                                       variance,
                                       momentum,
                                       epsilon,
                                       data_layout,
                                       is_test,
                                       use_global_stats,
                                       trainable_statistics,
                                       fuse_with_relu,
                                       y->mutable_values(),
                                       mean_out,
                                       variance_out,
                                       saved_mean,
                                       saved_variance,
                                       reserve_space);
  y->SetIndicesDict(x.GetIndicesDict());
}

}  // namespace sparse
}  // namespace phi

#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SyncBatchNormCooKernel,
                   float,
                   phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(sync_batch_norm_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SyncBatchNormCooKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
#endif
@@ -37,6 +37,7 @@ namespace sparse {
    EmptyLikeCooKernel<T, Context>(dev_ctx, x, out); \
    phi::prefix##Kernel<T, Context>( \
        dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements()); \
    out->SetIndicesDict(x.GetIndicesDict()); \
  } \
  \
  template <typename T, typename Context> \

@@ -105,6 +106,7 @@ void ScaleCooKernel(const Context& dev_ctx,
                    bias,
                    bias_after_scale,
                    out->mutable_non_zero_elements());
  out->SetIndicesDict(x.GetIndicesDict());
}

template <typename T, typename Context>

@@ -155,6 +157,7 @@ void CastCooKernel(const Context& dev_ctx,
    meta.set_dtype(value_dtype);
    phi::CastKernel<T, Context>(dev_ctx, x_values, value_dtype, out_values);
  }
  out->SetIndicesDict(x.GetIndicesDict());
}

template <typename T, typename Context>
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SyncBatchNormCooGradKernel(
    const Context& dev_ctx,
    const SparseCooTensor& x,
    const DenseTensor& scale,
    const DenseTensor& bias,
    const DenseTensor& saved_mean,
    const DenseTensor& saved_variance,
    const paddle::optional<DenseTensor>& reserve_space,
    const SparseCooTensor& y_grad,
    float momentum,
    float epsilon,
    const std::string& data_layout,
    bool is_test,
    bool use_global_stats,
    bool trainable_statistics,
    bool fuse_with_relu,
    SparseCooTensor* x_grad,
    DenseTensor* scale_grad,
    DenseTensor* bias_grad);

}  // namespace sparse
}  // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void SyncBatchNormCooKernel(const Context& dev_ctx,
                            const SparseCooTensor& x,
                            const DenseTensor& scale,
                            const DenseTensor& bias,
                            const DenseTensor& mean,
                            const DenseTensor& variance,
                            float momentum,
                            float epsilon,
                            const std::string& data_layout,
                            bool is_test,
                            bool use_global_stats,
                            bool trainable_statistics,
                            bool fuse_with_relu,
                            SparseCooTensor* y,
                            DenseTensor* mean_out,
                            DenseTensor* variance_out,
                            DenseTensor* saved_mean,
                            DenseTensor* saved_variance,
                            DenseTensor* reserve_space);

}  // namespace sparse
}  // namespace phi
@@ -323,13 +323,22 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):
        )

    def forward(self, x):
        self._check_data_format()
        sync_batch_norm_out, _, _, _, _, _ = _C_ops.sparse_sync_batch_norm_(
            x,
            self.weight,
            self.bias,
            self._mean,
            self._variance,
            self._momentum,
            self._epsilon,
            self._data_format,
            not self.training,
            False,
            False,
            False,
        )
        return sync_batch_norm_out

    @classmethod
    def convert_sync_batchnorm(cls, layer):
......
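The class above also keeps the convert_sync_batchnorm classmethod shown in the diff, which swaps batch-norm layers in an existing model for their synchronized counterparts. A hedged usage sketch follows; the paddle.sparse.nn.Conv3D and paddle.sparse.nn.BatchNorm layer names and channel sizes are assumptions for illustration, not part of this change.

# Hedged sketch: convert an existing sparse model's BatchNorm layers for
# multi-GPU training. Layer names and sizes below are illustrative assumptions.
import paddle

model = paddle.nn.Sequential(
    paddle.sparse.nn.Conv3D(3, 16, kernel_size=3),
    paddle.sparse.nn.BatchNorm(16),
)
sync_model = paddle.sparse.nn.SyncBatchNorm.convert_sync_batchnorm(model)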