From 4ab8255adf21253be1151d485f3ad7e69f458054 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 25 Mar 2022 11:32:44 +0800 Subject: [PATCH] [Phi] Move part sum op kernel (#40873) * move part sum op kernel * remove deprecated names --- paddle/fluid/framework/infershape_utils.cc | 26 +++- paddle/fluid/framework/operator.h | 25 +++- paddle/phi/core/compat/op_utils.h | 2 - paddle/phi/kernels/add_n_kernel.h | 26 ++++ paddle/phi/kernels/cpu/add_n_kernel.cc | 78 ++++++++++ paddle/phi/kernels/gpu/add_n_kernel.cu | 157 +++++++++++++++++++++ paddle/phi/ops/compat/sum_sig.cc | 31 ++++ 7 files changed, 333 insertions(+), 12 deletions(-) create mode 100644 paddle/phi/kernels/add_n_kernel.h create mode 100644 paddle/phi/kernels/cpu/add_n_kernel.cc create mode 100644 paddle/phi/kernels/gpu/add_n_kernel.cu create mode 100644 paddle/phi/ops/compat/sum_sig.cc diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 504fadedba..1b6f5c6535 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/infershape_utils.h" +#include #include #include "paddle/fluid/framework/convert_utils.h" @@ -69,27 +70,42 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { bool IsDenseTensorInput(const std::string& name) const override { auto var_types = ctx_.GetInputsVarType(name); - return var_types[0] == proto::VarType::LOD_TENSOR; + return std::all_of(var_types.begin(), var_types.end(), + [](const proto::VarType::Type& type) { + return type == proto::VarType::LOD_TENSOR; + }); } bool IsSelectedRowsInput(const std::string& name) const override { auto var_types = ctx_.GetInputsVarType(name); - return var_types[0] == proto::VarType::SELECTED_ROWS; + return std::all_of(var_types.begin(), var_types.end(), + [](const proto::VarType::Type& type) { + return type == proto::VarType::SELECTED_ROWS; + }); } bool IsDenseTensorVectorInput(const std::string& name) const override { auto var_types = ctx_.GetInputsVarType(name); - return var_types[0] == proto::VarType::LOD_TENSOR_ARRAY; + return std::all_of(var_types.begin(), var_types.end(), + [](const proto::VarType::Type& type) { + return type == proto::VarType::LOD_TENSOR_ARRAY; + }); } bool IsDenseTensorOutput(const std::string& name) const override { auto var_types = ctx_.GetOutputsVarType(name); - return var_types[0] == proto::VarType::LOD_TENSOR; + return std::all_of(var_types.begin(), var_types.end(), + [](const proto::VarType::Type& type) { + return type == proto::VarType::LOD_TENSOR; + }); } bool IsSelectedRowsOutput(const std::string& name) const override { auto var_types = ctx_.GetOutputsVarType(name); - return var_types[0] == proto::VarType::SELECTED_ROWS; + return std::all_of(var_types.begin(), var_types.end(), + [](const proto::VarType::Type& type) { + return type == proto::VarType::SELECTED_ROWS; + }); } bool IsForInferShape() const override { return true; } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 6f68c261d2..4048995a44 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -476,23 +476,38 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { } bool IsDenseTensorInput(const std::string& name) const override { - return ctx_.InputVar(name)->IsType(); + auto vars = ctx_.MultiInputVar(name); + return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { + return var->IsType(); + }); } bool IsSelectedRowsInput(const std::string& name) const override { - return ctx_.InputVar(name)->IsType(); + auto vars = ctx_.MultiInputVar(name); + return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { + return var->IsType(); + }); } bool IsDenseTensorVectorInput(const std::string& name) const override { - return ctx_.InputVar(name)->IsType(); + auto vars = ctx_.MultiInputVar(name); + return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { + return var->IsType(); + }); } bool IsDenseTensorOutput(const std::string& name) const override { - return ctx_.OutputVar(name)->IsType(); + auto vars = ctx_.MultiOutputVar(name); + return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { + return var->IsType(); + }); } bool IsSelectedRowsOutput(const std::string& name) const override { - return ctx_.OutputVar(name)->IsType(); + auto vars = ctx_.MultiOutputVar(name); + return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { + return var->IsType(); + }); } bool IsForInferShape() const override { return false; } diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index b4616c8c1b..6716f47918 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -64,9 +64,7 @@ const std::unordered_set deprecated_op_names({"diag", "expand_as", "expand_grad", "expand_as_grad", - "sum", "one_hot", - "sum_grad", "top_k", "top_k_grad"}); diff --git a/paddle/phi/kernels/add_n_kernel.h b/paddle/phi/kernels/add_n_kernel.h new file mode 100644 index 0000000000..c35dc2270a --- /dev/null +++ b/paddle/phi/kernels/add_n_kernel.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AddNKernel(const Context& dev_ctx, + const std::vector& x, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/add_n_kernel.cc b/paddle/phi/kernels/cpu/add_n_kernel.cc new file mode 100644 index 0000000000..d658b55758 --- /dev/null +++ b/paddle/phi/kernels/cpu/add_n_kernel.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/add_n_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void AddNKernel(const Context& dev_ctx, + const std::vector& x, + DenseTensor* out) { + size_t in_num = x.size(); + bool in_place = out == x[0]; + auto* out_ptr = dev_ctx.template Alloc(out); + if (in_num >= 1 && x[0]->initialized()) { + if (x[0]->numel() > 0) { + in_place = (x[0]->data() == out_ptr); + } + } + + auto result = EigenVector::Flatten(*out); + auto& place = *dev_ctx.eigen_device(); + int start = in_place ? 1 : 0; + if (!in_place) { + if ((in_num >= 2) && x[0]->initialized() && x[1]->initialized()) { + auto& in_0 = *x[0]; + auto& in_1 = *x[1]; + if (in_0.numel() && in_1.numel()) { + auto in_0_e = EigenVector::Flatten(in_0); + auto in_1_e = EigenVector::Flatten(in_1); + result.device(place) = in_0_e + in_1_e; + start = 2; + } + } + if (start != 2) { + VLOG(10) << "Fill with constant = 0 in sum kernel."; + funcs::SetConstant constant_functor; + constant_functor(dev_ctx, out, static_cast(0)); + } + } + + // If in_place, just skip the first tensor + for (size_t i = start; i < in_num; i++) { + auto& in_t = *x[i]; + if (!in_t.initialized() || in_t.numel() == 0) { + continue; + } + auto in = EigenVector::Flatten(in_t); + result.device(place) = result + in; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(add_n, + CPU, + ALL_LAYOUT, + phi::AddNKernel, + float, + double, + int, + phi::dtype::bfloat16, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/add_n_kernel.cu b/paddle/phi/kernels/gpu/add_n_kernel.cu new file mode 100644 index 0000000000..87636631a9 --- /dev/null +++ b/paddle/phi/kernels/gpu/add_n_kernel.cu @@ -0,0 +1,157 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/add_n_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" + +namespace phi { + +#define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) + +template +__global__ void SumArrayCUDAKernel( + T **in, T *out, int64_t N, size_t in_size, bool read_dst) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + while (id < N) { + T total(read_dst ? out[id] : static_cast(0)); + for (int i = 0; i < in_size; ++i) { + const T *tmp = in[i]; + if (tmp) { + total += tmp[id]; + } + } + out[id] = total; + id += blockDim.x * gridDim.x; + } +} + +template +void AddNKernel(const Context &dev_ctx, + const std::vector &x, + DenseTensor *out) { + const size_t in_num = x.size(); + + constexpr size_t theory_sm_threads = 1024; + auto stream = dev_ctx.stream(); + + auto max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + auto sm_count = max_threads / theory_sm_threads; + size_t tile_size = 0; + dim3 grids; + dim3 blocks; + + auto ComputeKernelParameter = [&](size_t length) { + if (length >= max_threads) + tile_size = 1024; + else if (length < max_threads && length > sm_count * 128) + tile_size = 512; + else if (length <= sm_count * 128) + tile_size = 256; + grids = dim3(CEIL_DIV(length, tile_size), 1, 1); + blocks = dim3(tile_size, 1, 1); + }; + + bool in_place = x[0] == out; + + if (!in_place) { + auto *out_ptr = dev_ctx.template Alloc(out); + if (in_num >= 1) { + auto &in_0_tensor = *x[0]; + if (in_0_tensor.numel() > 0) { + in_place = (in_0_tensor.data() == out_ptr); + } + } + } + + // Sum of two tensors + if (in_num == 2) { + auto &in_0 = *x[0]; + auto &in_1 = *x[1]; + int64_t length_0 = in_0.numel(); + int64_t length_1 = in_1.numel(); + if (length_0 && length_1 && in_0.initialized() && in_1.initialized()) { + auto result = EigenVector::Flatten(*out); + auto &place = *dev_ctx.eigen_device(); + auto in_0_e = EigenVector::Flatten(in_0); + auto in_1_e = EigenVector::Flatten(in_1); + result.device(place) = in_0_e + in_1_e; + } else if (length_0 && in_0.initialized()) { + auto result = EigenVector::Flatten(*out); + auto &place = *dev_ctx.eigen_device(); + result.device(place) = EigenVector::Flatten(in_0); + } else if (length_1 && in_1.initialized()) { + auto result = EigenVector::Flatten(*out); + auto &place = *dev_ctx.eigen_device(); + result.device(place) = EigenVector::Flatten(in_1); + } + return; + } + + int start = in_place ? 1 : 0; + if (!in_place) { + funcs::SetConstant constant_functor; + constant_functor(dev_ctx, out, static_cast(0)); + } + + std::vector in_data; + int64_t lod_length = 0; + bool dst_write = false; + for (int i = start; i < in_num; ++i) { + auto &in_i = *x[i]; + lod_length = in_i.numel(); + if (lod_length && in_i.initialized()) { + in_data.emplace_back(in_i.data()); + } + } + + // if indata not null, merge into one kernel call. + if (!in_data.empty()) { + auto tmp_in_array = + paddle::memory::Alloc(dev_ctx, in_data.size() * sizeof(T *)); + + paddle::memory::Copy(dev_ctx.GetPlace(), + tmp_in_array->ptr(), + phi::CPUPlace(), + reinterpret_cast(in_data.data()), + in_data.size() * sizeof(T *), + dev_ctx.stream()); + + T **in_array_data = reinterpret_cast(tmp_in_array->ptr()); + ComputeKernelParameter(lod_length); + SumArrayCUDAKernel<<>>(in_array_data, + out->data(), + lod_length, + in_data.size(), + dst_write | in_place); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(add_n, + GPU, + ALL_LAYOUT, + phi::AddNKernel, + float, + double, + int, + int64_t, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/ops/compat/sum_sig.cc b/paddle/phi/ops/compat/sum_sig.cc new file mode 100644 index 0000000000..4364047b0e --- /dev/null +++ b/paddle/phi/ops/compat/sum_sig.cc @@ -0,0 +1,31 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SumOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("X")) { + return KernelSignature("add_n", {"X"}, {}, {"Out"}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(sum, add_n); + +PD_REGISTER_ARG_MAPPING_FN(sum, phi::SumOpArgumentMapping); -- GitLab