Unverified commit 3f280ab3, authored by zyfncg, committed by GitHub

Refine set kernel output (#45573)

* support selected_rows kernel for multiply in dygraph

* refine SetKernelOutput
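
This change removes the apparently unused `Backend` parameter from the `SetKernelOutput` family of helpers (including `SetSelectedRowsKernelOutput` and `SetStringsKernelOutput`), so call sites and the API code generators no longer need to thread `kernel_backend` through. A before/after sketch of the affected declarations, based on the header hunk below:

```cpp
// Before: each overload carried a Backend argument that the
// implementation did not use.
phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out);
phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out);

// After: the output Tensor slot alone determines what gets allocated.
phi::DenseTensor* SetKernelOutput(Tensor* out);
phi::SelectedRows* SetSelectedRowsKernelOutput(Tensor* out);
```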
Parent 45171911
@@ -65,7 +65,7 @@ Tensor embedding_impl(const Tensor& x,
  auto input_x = PrepareData(x, kernel.InputAt(0), {});
  auto input_weight = PrepareData(weight, kernel.InputAt(1), {});
- auto* kernel_out = SetKernelOutput(kernel_key.backend(), &api_output);
+ auto* kernel_out = SetKernelOutput(&api_output);
  phi::MetaTensor meta_out(kernel_out);
  phi::EmbeddingInferMeta(MakeMetaTensor(*input_x),
@@ -94,7 +94,7 @@ Tensor embedding_impl(const Tensor& x,
  auto input_x = PrepareData(x, kernel.InputAt(0), {});
  auto input_weight = TensorToSelectedRows(weight);
- auto* kernel_out = SetKernelOutput(kernel_key.backend(), &api_output);
+ auto* kernel_out = SetKernelOutput(&api_output);
  phi::MetaTensor meta_out(kernel_out);
  phi::EmbeddingInferMeta(MakeMetaTensor(*input_x),
@@ -150,7 +150,7 @@ std::vector<Tensor> split_impl(const Tensor& x,
  }
  std::vector<Tensor> out;
- auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
+ auto dense_outs = SetKernelOutput(out_number, &out);
  std::vector<phi::MetaTensor> meta_outs;
  meta_outs.reserve(out_number);
  std::vector<phi::MetaTensor*> meta_out_ptrs;
@@ -231,14 +231,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
  auto input_variance = PrepareData(variance, kernel.InputAt(4), {});
  std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output;
- auto kernel_out_0 = SetKernelOutput(kernel_backend, &std::get<0>(api_output));
+ auto kernel_out_0 = SetKernelOutput(&std::get<0>(api_output));
  std::get<1>(api_output).set_impl(mean.impl());
  std::get<2>(api_output).set_impl(variance.impl());
- auto kernel_out_1 = SetKernelOutput(kernel_backend, &std::get<1>(api_output));
- auto kernel_out_2 = SetKernelOutput(kernel_backend, &std::get<2>(api_output));
- auto kernel_out_3 = SetKernelOutput(kernel_backend, &std::get<3>(api_output));
- auto kernel_out_4 = SetKernelOutput(kernel_backend, &std::get<4>(api_output));
- auto kernel_out_5 = SetKernelOutput(kernel_backend, &std::get<5>(api_output));
+ auto kernel_out_1 = SetKernelOutput(&std::get<1>(api_output));
+ auto kernel_out_2 = SetKernelOutput(&std::get<2>(api_output));
+ auto kernel_out_3 = SetKernelOutput(&std::get<3>(api_output));
+ auto kernel_out_4 = SetKernelOutput(&std::get<4>(api_output));
+ auto kernel_out_5 = SetKernelOutput(&std::get<5>(api_output));
  phi::MetaTensor meta_out_0(kernel_out_0);
  phi::MetaTensor meta_out_1(kernel_out_1);
  phi::MetaTensor meta_out_2(kernel_out_2);
@@ -325,7 +325,7 @@ void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
  auto dense_out_grad = TensorToDenseTensor(out_grad);
- auto kernel_out = SetKernelOutput(kernel_key.backend(), x_grad);
+ auto kernel_out = SetKernelOutput(x_grad);
  phi::MetaTensor meta_out(kernel_out);
  phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);
@@ -365,8 +365,7 @@ void embedding_grad_impl(const Tensor& x,
  auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
  if (sparse) {
- auto* kernel_out =
-     SetSelectedRowsKernelOutput(kernel_key.backend(), weight_grad);
+ auto* kernel_out = SetSelectedRowsKernelOutput(weight_grad);
  phi::MetaTensor meta_out(kernel_out);
  meta_out.set_dims(input_weight->dims());
  meta_out.set_dtype(input_weight->dtype());
@@ -386,7 +385,7 @@ void embedding_grad_impl(const Tensor& x,
  padding_idx,
  kernel_out);
  } else {
- auto* kernel_out = SetKernelOutput(kernel_key.backend(), weight_grad);
+ auto* kernel_out = SetKernelOutput(weight_grad);
  phi::MetaTensor meta_out(kernel_out);
  phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out);
  using kernel_signature = void (*)(const platform::DeviceContext&,
@@ -418,8 +417,7 @@ void embedding_grad_impl(const Tensor& x,
  auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
  if (sparse) {
- auto* kernel_out =
-     SetSelectedRowsKernelOutput(kernel_key.backend(), weight_grad);
+ auto* kernel_out = SetSelectedRowsKernelOutput(weight_grad);
  phi::MetaTensor meta_out(kernel_out);
  phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out);
  using kernel_signature = void (*)(const platform::DeviceContext&,
@@ -436,7 +434,7 @@ void embedding_grad_impl(const Tensor& x,
  padding_idx,
  kernel_out);
  } else {
- auto* kernel_out = SetKernelOutput(kernel_key.backend(), weight_grad);
+ auto* kernel_out = SetKernelOutput(weight_grad);
  phi::MetaTensor meta_out(kernel_out);
  meta_out.set_dims(input_weight->GetCompleteDims());
  meta_out.set_dtype(input_weight->dtype());
@@ -472,7 +470,7 @@ void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
  auto dense_out_grad = TensorToDenseTensor(out_grad);
- auto kernel_out = SetKernelOutput(kernel_key.backend(), x_grad);
+ auto kernel_out = SetKernelOutput(x_grad);
  phi::MetaTensor meta_out(kernel_out);
  phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out);
@@ -144,7 +144,7 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
  /* ------------------ for output ----------------------- */
- phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
+ phi::DenseTensor* SetKernelOutput(Tensor* out) {
  if (out) {
  if (out->impl() == nullptr) {
  out->set_impl(std::make_shared<phi::DenseTensor>());
@@ -155,7 +155,6 @@ phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
  }
  std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size,
- Backend backend,
  std::vector<Tensor>* out) {
  out->reserve(out_size);
  std::vector<phi::DenseTensor*> results(out_size);
@@ -169,7 +168,7 @@ std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size,
  }
  std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput(
- size_t out_size, Backend backend, std::vector<Tensor>* out) {
+ size_t out_size, std::vector<Tensor>* out) {
  std::vector<phi::DenseTensor*> results(out->size(), nullptr);
  for (size_t i = 0; i < out->size(); ++i) {
  results[i] = static_cast<phi::DenseTensor*>(out->at(i).impl().get());
@@ -178,9 +177,7 @@ std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput(
  }
  std::vector<phi::DenseTensor*> SetInplaceOptionalVectorKernelOutput(
- size_t out_size,
- Backend backend,
- const paddle::optional<std::vector<Tensor>>& out) {
+ size_t out_size, const paddle::optional<std::vector<Tensor>>& out) {
  std::vector<phi::DenseTensor*> results;
  if (out) {
  results = std::vector<phi::DenseTensor*>(out->size(), nullptr);
@@ -203,7 +200,7 @@ std::vector<phi::DenseTensor*> SetKernelOutput(std::vector<Tensor*>* out) {
  return results;
  }
- phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out) {
+ phi::SelectedRows* SetSelectedRowsKernelOutput(Tensor* out) {
  if (!out->initialized()) {
  auto select_rows = std::make_shared<phi::SelectedRows>();
  out->set_impl(select_rows);
@@ -236,9 +233,7 @@ phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type) {
  return out->impl().get();
  }
- phi::TensorBase* SetStringsKernelOutput(Backend backend,
- Tensor* out,
- TensorType type) {
+ phi::TensorBase* SetStringsKernelOutput(Tensor* out, TensorType type) {
  if (!out->initialized()) {
  if (type == TensorType::STRING_TENSOR) {
  if (out->impl() == nullptr) {
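For reference, a minimal sketch of the refactored dense-tensor overload after this change, reconstructed from the hunk above; the two return statements are assumed, since the diff context cuts off before them:

```cpp
// Sketch reconstructed from the diff context above; the return
// statements are assumed rather than shown in the hunk.
phi::DenseTensor* SetKernelOutput(Tensor* out) {
  if (out) {
    // Lazily allocate a DenseTensor impl the first time this output is set.
    if (out->impl() == nullptr) {
      out->set_impl(std::make_shared<phi::DenseTensor>());
    }
    return static_cast<phi::DenseTensor*>(out->impl().get());
  }
  return nullptr;
}
```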
@@ -73,30 +73,25 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
  /* ------------------ for output ----------------------- */
- phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out);
+ phi::DenseTensor* SetKernelOutput(Tensor* out);
  std::vector<phi::DenseTensor*> SetKernelOutput(size_t out_size,
- Backend backend,
  std::vector<Tensor>* out);
  std::vector<phi::DenseTensor*> SetInplaceVectorKernelOutput(
- size_t out_size, Backend backend, std::vector<Tensor>* out);
+ size_t out_size, std::vector<Tensor>* out);
  std::vector<phi::DenseTensor*> SetInplaceOptionalVectorKernelOutput(
- size_t out_size,
- Backend backend,
- const paddle::optional<std::vector<Tensor>>& out);
+ size_t out_size, const paddle::optional<std::vector<Tensor>>& out);
  // For backward api
  std::vector<phi::DenseTensor*> SetKernelOutput(std::vector<Tensor*>* out);
- phi::SelectedRows* SetSelectedRowsKernelOutput(Backend backend, Tensor* out);
+ phi::SelectedRows* SetSelectedRowsKernelOutput(Tensor* out);
  phi::TensorBase* SetSparseKernelOutput(Tensor* out, TensorType type);
- phi::TensorBase* SetStringsKernelOutput(Backend backend,
- Tensor* out,
- TensorType type);
+ phi::TensorBase* SetStringsKernelOutput(Tensor* out, TensorType type);
  }  // namespace experimental
  }  // namespace paddle
@@ -41,7 +41,7 @@ void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) {
  auto dense_x = TensorToDenseTensor(src);
- auto kernel_out = SetKernelOutput(kernel_key.backend(), dst);
+ auto kernel_out = SetKernelOutput(dst);
  phi::MetaTensor meta_out(kernel_out);
  phi::UnchangedInferMeta(*dense_x, &meta_out);
@@ -140,7 +140,7 @@ void Tensor::copy_(const Tensor &src,
  }
  if (kernel_type == KernelType::DENSE_TENSOR_KENREL) {
- SetKernelOutput(kernel_backend, this);
+ SetKernelOutput(this);
  phi::MetaTensor meta_out(impl_.get());
  phi::UnchangedInferMeta(
  MakeMetaTensor(
@@ -152,7 +152,7 @@ void Tensor::copy_(const Tensor &src,
  blocking,
  static_cast<phi::DenseTensor *>(impl_.get()));
  } else if (kernel_type == KernelType::SELECTED_ROWS_KENREL) {
- SetSelectedRowsKernelOutput(kernel_backend, this);
+ SetSelectedRowsKernelOutput(this);
  phi::MetaTensor meta_out(impl_.get());
  phi::UnchangedInferMeta(
  MakeMetaTensor(
@@ -156,11 +156,11 @@ class ForwardAPI(BaseAPI):
  assert self.outputs['out_size_expr'][0] is not None, \
  f"{self.api}: The out size expr : '{{expr}}' should be set when output has Tensor[]. You can refer 'split' api."
  output_create = output_create + f"""
- {code_indent}  auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, kernel_backend, &api_output);"""
+ {code_indent}  auto kernel_out = {set_out_func}({self.outputs['out_size_expr'][0]}, &api_output);"""
  else:
  output_create = output_create + f"""
- {code_indent}  auto kernel_out = {set_out_func}(kernel_backend, &api_output);"""
+ {code_indent}  auto kernel_out = {set_out_func}(&api_output);"""
  if not inplace_flag and self.view_map is not None and self.outputs[
  'names'][0] in self.view_map:
@@ -207,11 +207,11 @@ class ForwardAPI(BaseAPI):
  set_out_func = "SetInplaceOptionalVectorKernelOutput"
  get_out_code = f"std::get<{i}>(api_output)"
  output_create = output_create + f"""
- {code_indent}  auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, kernel_backend, {get_out_code});"""
+ {code_indent}  auto kernel_out_{i} = {set_out_func}({self.outputs['out_size_expr'][i]}, {get_out_code});"""
  else:
  output_create = output_create + f"""
- {code_indent}  auto kernel_out_{i} = {set_out_func}(kernel_backend, {get_out_code});"""
+ {code_indent}  auto kernel_out_{i} = {set_out_func}({get_out_code});"""
  if not inplace_flag and self.view_map is not None and self.outputs[
  'names'][i] in self.view_map:
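To make the generator change concrete: when `set_out_func` resolves to `SetKernelOutput`, the emitted C++ for a single-output API loses its `kernel_backend` argument. A hypothetical rendering (illustrative, not generated verbatim):

```cpp
// Emitted before this change:
auto kernel_out = SetKernelOutput(kernel_backend, &api_output);
// Emitted after this change:
auto kernel_out = SetKernelOutput(&api_output);
```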
@@ -151,7 +151,7 @@ PADDLE_API void {api_func_name}({self.get_declare_args()});
  else:
  output_create = output_create + f"""
- {code_indent}  auto kernel_out = {set_out_func}(kernel_backend, {self.outputs['names'][0]});"""
+ {code_indent}  auto kernel_out = {set_out_func}({self.outputs['names'][0]});"""
  elif len(out_dtype_list) > 1:
  output_create = ""
@@ -167,7 +167,7 @@ PADDLE_API void {api_func_name}({self.get_declare_args()});
  {code_indent}  *{self.outputs['names'][i]} = {self.inplace_map[self.outputs['names'][i]]};"""
  output_create = output_create + f"""
- {code_indent}  auto kernel_out_{i} = {set_out_func}(kernel_backend, {self.outputs['names'][i]});"""
+ {code_indent}  auto kernel_out_{i} = {set_out_func}({self.outputs['names'][i]});"""
  else:
  if inplace_flag and self.inplace_map is not None and self.outputs[
@@ -71,7 +71,7 @@ class StringsAPI(ForwardAPI):
  'names'][0] in self.inplace_map else ""
  output_create = f"""
  {return_type} api_output{inplace_assign};
- {tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(kernel_backend, &api_output, {kernel_tensor_out_type}));"""
+ {tensor_type}* kernel_out = dynamic_cast<{tensor_type}*>(SetStringsKernelOutput(&api_output, {kernel_tensor_out_type}));"""
  elif len(out_dtype_list) > 1:
  output_create = f"""