未验证 提交 3f280ab3 编写于 作者: Z zyfncg 提交者: GitHub

Refine set kernel output (#45573)

* support selected_rows kernel for multiply in dygraph

* refine SetKernelOutput
上级 45171911
...@@ -65,7 +65,7 @@ Tensor embedding_impl(const Tensor& x, ...@@ -65,7 +65,7 @@ Tensor embedding_impl(const Tensor& x,
auto input_x = PrepareData(x, kernel.InputAt(0), {}); auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_weight = PrepareData(weight, kernel.InputAt(1), {}); auto input_weight = PrepareData(weight, kernel.InputAt(1), {});
auto* kernel_out = SetKernelOutput(kernel_key.backend(), &api_output); auto* kernel_out = SetKernelOutput(&api_output);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
phi::EmbeddingInferMeta(MakeMetaTensor(*input_x), phi::EmbeddingInferMeta(MakeMetaTensor(*input_x),
...@@ -94,7 +94,7 @@ Tensor embedding_impl(const Tensor& x, ...@@ -94,7 +94,7 @@ Tensor embedding_impl(const Tensor& x,
auto input_x = PrepareData(x, kernel.InputAt(0), {}); auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_weight = TensorToSelectedRows(weight); auto input_weight = TensorToSelectedRows(weight);
auto* kernel_out = SetKernelOutput(kernel_key.backend(), &api_output); auto* kernel_out = SetKernelOutput(&api_output);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
phi::EmbeddingInferMeta(MakeMetaTensor(*input_x), phi::EmbeddingInferMeta(MakeMetaTensor(*input_x),
...@@ -150,7 +150,7 @@ std::vector<Tensor> split_impl(const Tensor& x, ...@@ -150,7 +150,7 @@ std::vector<Tensor> split_impl(const Tensor& x,
} }
std::vector<Tensor> out; std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out); auto dense_outs = SetKernelOutput(out_number, &out);
std::vector<phi::MetaTensor> meta_outs; std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number); meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs; std::vector<phi::MetaTensor*> meta_out_ptrs;
...@@ -231,14 +231,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl( ...@@ -231,14 +231,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
auto input_variance = PrepareData(variance, kernel.InputAt(4), {}); auto input_variance = PrepareData(variance, kernel.InputAt(4), {});
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output; std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output;
auto kernel_out_0 = SetKernelOutput(kernel_backend, &std::get<0>(api_output)); auto kernel_out_0 = SetKernelOutput(&std::get<0>(api_output));
std::get<1>(api_output).set_impl(mean.impl()); std::get<1>(api_output).set_impl(mean.impl());
std::get<2>(api_output).set_impl(variance.impl()); std::get<2>(api_output).set_impl(variance.impl());
auto kernel_out_1 = SetKernelOutput(kernel_backend, &std::get<1>(api_output)); auto kernel_out_1 = SetKernelOutput(&std::get<1>(api_output));
auto kernel_out_2 = SetKernelOutput(kernel_backend, &std::get<2>(api_output)); auto kernel_out_2 = SetKernelOutput(&std::get<2>(api_output));
auto kernel_out_3 = SetKernelOutput(kernel_backend, &std::get<3>(api_output)); auto kernel_out_3 = SetKernelOutput(&std::get<3>(api_output));
auto kernel_out_4 = SetKernelOutput(kernel_backend, &std::get<4>(api_output)); auto kernel_out_4 = SetKernelOutput(&std::get<4>(api_output));
auto kernel_out_5 = SetKernelOutput(kernel_backend, &std::get<5>(api_output)); auto kernel_out_5 = SetKernelOutput(&std::get<5>(api_output));
phi::MetaTensor meta_out_0(kernel_out_0); phi::MetaTensor meta_out_0(kernel_out_0);
phi::MetaTensor meta_out_1(kernel_out_1); phi::MetaTensor meta_out_1(kernel_out_1);
phi::MetaTensor meta_out_2(kernel_out_2); phi::MetaTensor meta_out_2(kernel_out_2);
...@@ -325,7 +325,7 @@ void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) { ...@@ -325,7 +325,7 @@ void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
auto dense_out_grad = TensorToDenseTensor(out_grad); auto dense_out_grad = TensorToDenseTensor(out_grad);
auto kernel_out = SetKernelOutput(kernel_key.backend(), x_grad); auto kernel_out = SetKernelOutput(x_grad);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out); phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);
...@@ -365,8 +365,7 @@ void embedding_grad_impl(const Tensor& x, ...@@ -365,8 +365,7 @@ void embedding_grad_impl(const Tensor& x,
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {}); auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
if (sparse) { if (sparse) {
auto* kernel_out = auto* kernel_out = SetSelectedRowsKernelOutput(weight_grad);
SetSelectedRowsKernelOutput(kernel_key.backend(), weight_grad);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
meta_out.set_dims(input_weight->dims()); meta_out.set_dims(input_weight->dims());
meta_out.set_dtype(input_weight->dtype()); meta_out.set_dtype(input_weight->dtype());
...@@ -386,7 +385,7 @@ void embedding_grad_impl(const Tensor& x, ...@@ -386,7 +385,7 @@ void embedding_grad_impl(const Tensor& x,
padding_idx, padding_idx,
kernel_out); kernel_out);
} else { } else {
auto* kernel_out = SetKernelOutput(kernel_key.backend(), weight_grad); auto* kernel_out = SetKernelOutput(weight_grad);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out); phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out);
using kernel_signature = void (*)(const platform::DeviceContext&, using kernel_signature = void (*)(const platform::DeviceContext&,
...@@ -418,8 +417,7 @@ void embedding_grad_impl(const Tensor& x, ...@@ -418,8 +417,7 @@ void embedding_grad_impl(const Tensor& x,
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {}); auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
if (sparse) { if (sparse) {
auto* kernel_out = auto* kernel_out = SetSelectedRowsKernelOutput(weight_grad);
SetSelectedRowsKernelOutput(kernel_key.backend(), weight_grad);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out); phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out);
using kernel_signature = void (*)(const platform::DeviceContext&, using kernel_signature = void (*)(const platform::DeviceContext&,
...@@ -436,7 +434,7 @@ void embedding_grad_impl(const Tensor& x, ...@@ -436,7 +434,7 @@ void embedding_grad_impl(const Tensor& x,
padding_idx, padding_idx,
kernel_out); kernel_out);
} else { } else {
auto* kernel_out = SetKernelOutput(kernel_key.backend(), weight_grad); auto* kernel_out = SetKernelOutput(weight_grad);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
meta_out.set_dims(input_weight->GetCompleteDims()); meta_out.set_dims(input_weight->GetCompleteDims());
meta_out.set_dtype(input_weight->dtype()); meta_out.set_dtype(input_weight->dtype());
...@@ -472,7 +470,7 @@ void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) { ...@@ -472,7 +470,7 @@ void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
auto dense_out_grad = TensorToDenseTensor(out_grad); auto dense_out_grad = TensorToDenseTensor(out_grad);
auto kernel_out = SetKernelOutput(kernel_key.backend(), x_grad); auto kernel_out = SetKernelOutput(x_grad);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out); phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);
......
...@@ -144,7 +144,7 @@ std::vector<phi::MetaTensor> MakeMetaTensor( ...@@ -144,7 +144,7 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
/* ------------------ for output ----------------------- */ /* ------------------ for output ----------------------- */
phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { phi::DenseTensor* SetKernelOutput(Tensor* out) {
if (out) { if (out) {
if (out->impl() == nullptr) { if (out->impl() == nullptr) {
out->set_impl(std::make_shared<phi::DenseTensor>()); out->set_impl(std::make_shared<phi::DenseTensor>());
...@@ -155,7 +155,6 @@ phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) { ...@@ -155,7 +155,6 @@ phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
} }
std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size, std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size,
Backend backend,
std::vector<Tensor>* out) { std::vector<Tensor>* out) {
out->reserve(out_size); out->reserve(out_size);
std::vector<phi::DenseTensor*> results(out_size); std::vector<phi::DenseTensor*> results(out_size);
...@@ -169,7 +168,7 @@ std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size, ...@@ -169,7 +168,7 @@ std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size,
} }
std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput( std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput(
size_t out_size, Backend backend, std::vector<Tensor>* out) { size_t out_size, std::vector<Tensor>* out) {
std::vector<phi::DenseTensor*> results(out->size(), nullptr); std::vector<phi::DenseTensor*> results(out->size(), nullptr);
for (size_t i = 0; i < out->size(); ++i) { for (size_t i = 0; i < out->size(); ++i) {
results[i] = static_cast<phi::DenseTensor*>(out->at(i).impl().get()); results[i] = static_cast<phi::DenseTensor*>(out->at(i).impl().get());
...@@ -178,9 +177,7 @@ std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput( ...@@ -178,9 +177,7 @@ std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput(
} }
std::vector<phi::DenseTensor*> SetInplaceOptionalVectorKernelOutput( std::vector<phi::DenseTensor*> SetInplaceOptionalVectorKernelOutput(
size_t out_size, size_t out_size, const paddle::optional<std::vector<Tensor>>& out) {
Backend backend,
const paddle::optional<std::vector<Tensor>>& out) {
std::vector<phi::DenseTensor*> results; std::vector<phi::DenseTensor*> results;
if (out) { if (out) {
results = std::vector<phi::DenseTensor*>(out->size(), nullptr); results = std::vector<phi::DenseTensor*>(out->size(), nullptr);
...@@ -203,7 +200,7 @@ std::vector<phi::DenseTensor*> SetKernelOutput(std::vector<Tensor*>* out) { ...@@ -203,7 +200,7 @@ std::vector<phi::DenseTensor*> SetKernelOutput(std::vector<Tensor*>* out) {
return results; return results;
} }
phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out) { phi::SelectedRows* SetSelectedRowsKernelOutput(Tensor* out) {
if (!out->initialized()) { if (!out->initialized()) {
auto select_rows = std::make_shared<phi::SelectedRows>(); auto select_rows = std::make_shared<phi::SelectedRows>();
out->set_impl(select_rows); out->set_impl(select_rows);
...@@ -236,9 +233,7 @@ phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) { ...@@ -236,9 +233,7 @@ phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) {
return out->impl().get(); return out->impl().get();
} }
phi::TensorBase* SetStringsKernelOutput(Backend backend, phi::TensorBase* SetStringsKernelOutput(Tensor* out, TensorType type) {
Tensor* out,
TensorType type) {
if (!out->initialized()) { if (!out->initialized()) {
if (type == TensorType::STRING_TENSOR) { if (type == TensorType::STRING_TENSOR) {
if (out->impl() == nullptr) { if (out->impl() == nullptr) {
......
...@@ -73,30 +73,25 @@ std::vector<phi::MetaTensor> MakeMetaTensor( ...@@ -73,30 +73,25 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
/* ------------------ for output ----------------------- */ /* ------------------ for output ----------------------- */
phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out); phi::DenseTensor* SetKernelOutput(Tensor* out);
std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size, std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size,
Backend backend,
std::vector<Tensor>* out); std::vector<Tensor>* out);
std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput( std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput(
size_t out_size, Backend backend, std::vector<Tensor>* out); size_t out_size, std::vector<Tensor>* out);
std::vector<phi::DenseTensor*> SetInplaceOptionalVectorKernelOutput( std::vector<phi::DenseTensor*> SetInplaceOptionalVectorKernelOutput(
size_t out_size, size_t out_size, const paddle::optional<std::vector<Tensor>>& out);
Backend backend,
const paddle::optional<std::vector<Tensor>>& out);
// For backward api // For backward api
std::vector<phi::DenseTensor*> SetKernelOutput(std::vector<Tensor*>* out); std::vector<phi::DenseTensor*> SetKernelOutput(std::vector<Tensor*>* out);
phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out); phi::SelectedRows* SetSelectedRowsKernelOutput(Tensor* out);
phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type); phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type);
phi::TensorBase* SetStringsKernelOutput(Backend backend, phi::TensorBase* SetStringsKernelOutput(Tensor* out, TensorType type);
Tensor* out,
TensorType type);
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
...@@ -41,7 +41,7 @@ void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) { ...@@ -41,7 +41,7 @@ void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) {
auto dense_x = TensorToDenseTensor(src); auto dense_x = TensorToDenseTensor(src);
auto kernel_out = SetKernelOutput(kernel_key.backend(), dst); auto kernel_out = SetKernelOutput(dst);
phi::MetaTensor meta_out(kernel_out); phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(*dense_x, &meta_out); phi::UnchangedInferMeta(*dense_x, &meta_out);
......
...@@ -140,7 +140,7 @@ void Tensor::copy_(const Tensor &src, ...@@ -140,7 +140,7 @@ void Tensor::copy_(const Tensor &src,
} }
if (kernel_type == KernelType::DENSE_TENSOR_KENREL) { if (kernel_type == KernelType::DENSE_TENSOR_KENREL) {
SetKernelOutput(kernel_backend, this); SetKernelOutput(this);
phi::MetaTensor meta_out(impl_.get()); phi::MetaTensor meta_out(impl_.get());
phi::UnchangedInferMeta( phi::UnchangedInferMeta(
MakeMetaTensor( MakeMetaTensor(
...@@ -152,7 +152,7 @@ void Tensor::copy_(const Tensor &src, ...@@ -152,7 +152,7 @@ void Tensor::copy_(const Tensor &src,
blocking, blocking,
static_cast<phi::DenseTensor *>(impl_.get())); static_cast<phi::DenseTensor *>(impl_.get()));
} else if (kernel_type == KernelType::SELECTED_ROWS_KENREL) { } else if (kernel_type == KernelType::SELECTED_ROWS_KENREL) {
SetSelectedRowsKernelOutput(kernel_backend, this); SetSelectedRowsKernelOutput(this);
phi::MetaTensor meta_out(impl_.get()); phi::MetaTensor meta_out(impl_.get());
phi::UnchangedInferMeta( phi::UnchangedInferMeta(
MakeMetaTensor( MakeMetaTensor(
......
...@@ -156,11 +156,11 @@ class ForwardAPI(BaseAPI): ...@@ -156,11 +156,11 @@ class ForwardAPI(BaseAPI):
assert self.outputs['out_size_expr'][0] is not None, \ assert self.outputs['out_size_expr'][0] is not None, \
f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api." f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);""" {code_indent} auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, &api_output);"""
else: else:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, &api_output);""" {code_indent} auto kernel_out = {set_out_func}(&api_output);"""
if not inplace_flag and self.view_map is not None and self.outputs[ if not inplace_flag and self.view_map is not None and self.outputs[
'names'][0] in self.view_map: 'names'][0] in self.view_map:
...@@ -207,11 +207,11 @@ class ForwardAPI(BaseAPI): ...@@ -207,11 +207,11 @@ class ForwardAPI(BaseAPI):
set_out_func = "SetInplaceOptionalVectorKernelOutput" set_out_func = "SetInplaceOptionalVectorKernelOutput"
get_out_code = f"std::get<{i}>(api_output)" get_out_code = f"std::get<{i}>(api_output)"
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, kernel_backend, {get_out_code});""" {code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});"""
else: else:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});""" {code_indent} auto kernel_out_{i} = {set_out_func}({get_out_code});"""
if not inplace_flag and self.view_map is not None and self.outputs[ if not inplace_flag and self.view_map is not None and self.outputs[
'names'][i] in self.view_map: 'names'][i] in self.view_map:
......
...@@ -151,7 +151,7 @@ PADDLE_API void {api_func_name}({self.get_declare_args()}); ...@@ -151,7 +151,7 @@ PADDLE_API void {api_func_name}({self.get_declare_args()});
else: else:
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out = {set_out_func}(kernel_backend, {self.outputs['names'][0]});""" {code_indent} auto kernel_out = {set_out_func}({self.outputs['names'][0]});"""
elif len(out_dtype_list) > 1: elif len(out_dtype_list) > 1:
output_create = "" output_create = ""
...@@ -167,7 +167,7 @@ PADDLE_API void {api_func_name}({self.get_declare_args()}); ...@@ -167,7 +167,7 @@ PADDLE_API void {api_func_name}({self.get_declare_args()});
{code_indent} *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};""" {code_indent} *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};"""
output_create = output_create + f""" output_create = output_create + f"""
{code_indent} auto kernel_out_{i} = {set_out_func}(kernel_backend, {self.outputs['names'][i]});""" {code_indent} auto kernel_out_{i} = {set_out_func}({self.outputs['names'][i]});"""
else: else:
if inplace_flag and self.inplace_map is not None and self.outputs[ if inplace_flag and self.inplace_map is not None and self.outputs[
......
...@@ -71,7 +71,7 @@ class StringsAPI(ForwardAPI): ...@@ -71,7 +71,7 @@ class StringsAPI(ForwardAPI):
'names'][0] in self.inplace_map else "" 'names'][0] in self.inplace_map else ""
output_create = f""" output_create = f"""
{return_type} api_output{inplace_assign}; {return_type} api_output{inplace_assign};
{tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(kernel_backend, &api_output, {kernel_tensor_out_type}));""" {tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(&api_output, {kernel_tensor_out_type}));"""
elif len(out_dtype_list) > 1: elif len(out_dtype_list) > 1:
output_create = f""" output_create = f"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册