Unverified commit 4ab8255a, authored by Chen Weihang, committed by GitHub

[Phi] Move part sum op kernel (#40873)

* move part sum op kernel

* remove deprecated names

Parent: 41f813e9
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/infershape_utils.h"

+#include <algorithm>
 #include <string>

 #include "paddle/fluid/framework/convert_utils.h"
@@ -69,27 +70,42 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
   bool IsDenseTensorInput(const std::string& name) const override {
     auto var_types = ctx_.GetInputsVarType(name);
-    return var_types[0] == proto::VarType::LOD_TENSOR;
+    return std::all_of(var_types.begin(), var_types.end(),
+                       [](const proto::VarType::Type& type) {
+                         return type == proto::VarType::LOD_TENSOR;
+                       });
   }

   bool IsSelectedRowsInput(const std::string& name) const override {
     auto var_types = ctx_.GetInputsVarType(name);
-    return var_types[0] == proto::VarType::SELECTED_ROWS;
+    return std::all_of(var_types.begin(), var_types.end(),
+                       [](const proto::VarType::Type& type) {
+                         return type == proto::VarType::SELECTED_ROWS;
+                       });
   }

   bool IsDenseTensorVectorInput(const std::string& name) const override {
     auto var_types = ctx_.GetInputsVarType(name);
-    return var_types[0] == proto::VarType::LOD_TENSOR_ARRAY;
+    return std::all_of(var_types.begin(), var_types.end(),
+                       [](const proto::VarType::Type& type) {
+                         return type == proto::VarType::LOD_TENSOR_ARRAY;
+                       });
   }

   bool IsDenseTensorOutput(const std::string& name) const override {
     auto var_types = ctx_.GetOutputsVarType(name);
-    return var_types[0] == proto::VarType::LOD_TENSOR;
+    return std::all_of(var_types.begin(), var_types.end(),
+                       [](const proto::VarType::Type& type) {
+                         return type == proto::VarType::LOD_TENSOR;
+                       });
   }

   bool IsSelectedRowsOutput(const std::string& name) const override {
     auto var_types = ctx_.GetOutputsVarType(name);
-    return var_types[0] == proto::VarType::SELECTED_ROWS;
+    return std::all_of(var_types.begin(), var_types.end(),
+                       [](const proto::VarType::Type& type) {
+                         return type == proto::VarType::SELECTED_ROWS;
+                       });
   }

   bool IsForInferShape() const override { return true; }
......
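Editor's note: the substance of this hunk (and of the execution-context hunk that follows) is that a duplicable input such as sum's "X" binds a *list* of variables, and the old checks inspected only the first element, so a mixed list (e.g. a LoDTensor followed by a SelectedRows) was misclassified. A minimal standalone sketch of the new all-elements semantics; the enum and helper below are illustrative stand-ins, not Paddle types:

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// Illustrative stand-in for proto::VarType::Type.
enum class VarType { LOD_TENSOR, SELECTED_ROWS };

// Mirrors the new IsDenseTensorInput: true only if *every* variable
// behind the input name is a dense (LoD) tensor.
bool AllDenseTensor(const std::vector<VarType>& var_types) {
  return std::all_of(var_types.begin(), var_types.end(),
                     [](VarType type) { return type == VarType::LOD_TENSOR; });
}

int main() {
  // The old first-element check would have returned true for this list.
  std::vector<VarType> mixed = {VarType::LOD_TENSOR, VarType::SELECTED_ROWS};
  std::cout << std::boolalpha << AllDenseTensor(mixed) << "\n";  // false
}
```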
@@ -476,23 +476,38 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
   }

   bool IsDenseTensorInput(const std::string& name) const override {
-    return ctx_.InputVar(name)->IsType<framework::LoDTensor>();
+    auto vars = ctx_.MultiInputVar(name);
+    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
+      return var->IsType<phi::DenseTensor>();
+    });
   }

   bool IsSelectedRowsInput(const std::string& name) const override {
-    return ctx_.InputVar(name)->IsType<phi::SelectedRows>();
+    auto vars = ctx_.MultiInputVar(name);
+    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
+      return var->IsType<phi::SelectedRows>();
+    });
   }

   bool IsDenseTensorVectorInput(const std::string& name) const override {
-    return ctx_.InputVar(name)->IsType<framework::LoDTensorArray>();
+    auto vars = ctx_.MultiInputVar(name);
+    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
+      return var->IsType<framework::LoDTensorArray>();
+    });
   }

   bool IsDenseTensorOutput(const std::string& name) const override {
-    return ctx_.OutputVar(name)->IsType<framework::LoDTensor>();
+    auto vars = ctx_.MultiOutputVar(name);
+    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
+      return var->IsType<phi::DenseTensor>();
+    });
   }

   bool IsSelectedRowsOutput(const std::string& name) const override {
-    return ctx_.OutputVar(name)->IsType<phi::SelectedRows>();
+    auto vars = ctx_.MultiOutputVar(name);
+    return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
+      return var->IsType<phi::SelectedRows>();
+    });
   }

   bool IsForInferShape() const override { return false; }
......
@@ -64,9 +64,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
                                                            "expand_as",
                                                            "expand_grad",
                                                            "expand_as_grad",
-                                                           "sum",
                                                            "one_hot",
-                                                           "sum_grad",
                                                            "top_k",
                                                            "top_k_grad"});
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void AddNKernel(const Context& dev_ctx,
                const std::vector<const DenseTensor*>& x,
                DenseTensor* out);

}  // namespace phi
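Editor's note: the contract declared here is elementwise: out[j] = x[0][j] + x[1][j] + ... + x[n-1][j], with all inputs of the same shape. A hedged scalar analog using std::vector<float> as a stand-in for DenseTensor (AddN below is a hypothetical helper, not the Paddle API):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical scalar analog of AddNKernel: sum a list of equally sized
// buffers elementwise into a fresh output buffer.
std::vector<float> AddN(const std::vector<const std::vector<float>*>& x) {
  std::vector<float> out(x.empty() ? 0 : x[0]->size(), 0.0f);
  for (const std::vector<float>* in : x) {
    for (std::size_t j = 0; j < out.size(); ++j) {
      out[j] += (*in)[j];
    }
  }
  return out;
}
```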
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void AddNKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
DenseTensor* out) {
size_t in_num = x.size();
bool in_place = out == x[0];
auto* out_ptr = dev_ctx.template Alloc<T>(out);
if (in_num >= 1 && x[0]->initialized()) {
if (x[0]->numel() > 0) {
in_place = (x[0]->data<T>() == out_ptr);
}
}
auto result = EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
int start = in_place ? 1 : 0;
if (!in_place) {
if ((in_num >= 2) && x[0]->initialized() && x[1]->initialized()) {
auto& in_0 = *x[0];
auto& in_1 = *x[1];
if (in_0.numel() && in_1.numel()) {
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
start = 2;
}
}
if (start != 2) {
VLOG(10) << "Fill with constant = 0 in sum kernel.";
funcs::SetConstant<Context, T> constant_functor;
constant_functor(dev_ctx, out, static_cast<T>(0));
}
}
// If in_place, just skip the first tensor
for (size_t i = start; i < in_num; i++) {
auto& in_t = *x[i];
if (!in_t.initialized() || in_t.numel() == 0) {
continue;
}
auto in = EigenVector<T>::Flatten(in_t);
result.device(place) = result + in;
}
}
} // namespace phi
PD_REGISTER_KERNEL(add_n,
CPU,
ALL_LAYOUT,
phi::AddNKernel,
float,
double,
int,
phi::dtype::bfloat16,
int64_t) {}
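Editor's note: two details of the CPU kernel are easy to miss. When out aliases x[0], the loop starts at index 1 and accumulates onto the existing contents; when it does not, the kernel tries to seed out with x[0] + x[1] in a single Eigen expression (start = 2) so the zero-fill pass is only a fallback. A simplified sketch of that control flow with plain float buffers (stand-ins, not Paddle types):

```cpp
#include <cstddef>
#include <vector>

// Sketch of the accumulation strategy above, assuming all inputs are the
// same size. When `out` does not alias x[0], the first two inputs are
// combined in one fused pass so the zero-fill of `out` can be skipped.
void AddNLike(const std::vector<const std::vector<float>*>& x,
              std::vector<float>* out, bool in_place) {
  std::size_t start = in_place ? 1 : 0;
  if (!in_place) {
    if (x.size() >= 2) {
      for (std::size_t j = 0; j < out->size(); ++j) {
        (*out)[j] = (*x[0])[j] + (*x[1])[j];  // seed with x[0] + x[1]
      }
      start = 2;
    } else {
      out->assign(out->size(), 0.0f);  // fallback: zero, then accumulate
    }
  }
  for (std::size_t i = start; i < x.size(); ++i) {
    for (std::size_t j = 0; j < out->size(); ++j) {
      (*out)[j] += (*x[i])[j];
    }
  }
}
```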
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
namespace phi {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
template <class T>
__global__ void SumArrayCUDAKernel(
T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
T total(read_dst ? out[id] : static_cast<T>(0));
for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i];
if (tmp) {
total += tmp[id];
}
}
out[id] = total;
id += blockDim.x * gridDim.x;
}
}
template <typename T, typename Context>
void AddNKernel(const Context &dev_ctx,
const std::vector<const DenseTensor *> &x,
DenseTensor *out) {
const size_t in_num = x.size();
constexpr size_t theory_sm_threads = 1024;
auto stream = dev_ctx.stream();
auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
auto sm_count = max_threads / theory_sm_threads;
size_t tile_size = 0;
dim3 grids;
dim3 blocks;
auto ComputeKernelParameter = [&](size_t length) {
if (length >= max_threads)
tile_size = 1024;
else if (length < max_threads && length > sm_count * 128)
tile_size = 512;
else if (length <= sm_count * 128)
tile_size = 256;
grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
blocks = dim3(tile_size, 1, 1);
};
bool in_place = x[0] == out;
if (!in_place) {
auto *out_ptr = dev_ctx.template Alloc<T>(out);
if (in_num >= 1) {
auto &in_0_tensor = *x[0];
if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr);
}
}
}
// Sum of two tensors
if (in_num == 2) {
auto &in_0 = *x[0];
auto &in_1 = *x[1];
int64_t length_0 = in_0.numel();
int64_t length_1 = in_1.numel();
if (length_0 && length_1 && in_0.initialized() && in_1.initialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
} else if (length_0 && in_0.initialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
result.device(place) = EigenVector<T>::Flatten(in_0);
} else if (length_1 && in_1.initialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
result.device(place) = EigenVector<T>::Flatten(in_1);
}
return;
}
int start = in_place ? 1 : 0;
if (!in_place) {
funcs::SetConstant<Context, T> constant_functor;
constant_functor(dev_ctx, out, static_cast<T>(0));
}
std::vector<const T *> in_data;
int64_t lod_length = 0;
bool dst_write = false;
for (int i = start; i < in_num; ++i) {
auto &in_i = *x[i];
lod_length = in_i.numel();
if (lod_length && in_i.initialized()) {
in_data.emplace_back(in_i.data<T>());
}
}
// if indata not null, merge into one kernel call.
if (!in_data.empty()) {
auto tmp_in_array =
paddle::memory::Alloc(dev_ctx, in_data.size() * sizeof(T *));
paddle::memory::Copy(dev_ctx.GetPlace(),
tmp_in_array->ptr(),
phi::CPUPlace(),
reinterpret_cast<void *>(in_data.data()),
in_data.size() * sizeof(T *),
dev_ctx.stream());
T **in_array_data = reinterpret_cast<T **>(tmp_in_array->ptr());
ComputeKernelParameter(lod_length);
SumArrayCUDAKernel<T><<<grids, blocks, 0, stream>>>(in_array_data,
out->data<T>(),
lod_length,
in_data.size(),
dst_write | in_place);
}
}
} // namespace phi
PD_REGISTER_KERNEL(add_n,
GPU,
ALL_LAYOUT,
phi::AddNKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
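Editor's note: the launch configuration deserves a remark. Because SumArrayCUDAKernel is a grid-stride loop, correctness does not depend on the grid exactly covering N; ComputeKernelParameter only tunes the block size to the problem length. A host-only sketch of that heuristic (constants mirror the code above; max_threads is illustrative and would really come from dev_ctx.GetMaxPhysicalThreadCount()):

```cpp
#include <cstddef>
#include <iostream>

// Pick a block size from the problem length, then derive the grid size
// with a ceiling division, as in ComputeKernelParameter above.
void ComputeLaunchParams(std::size_t length, std::size_t max_threads,
                         std::size_t* num_blocks, std::size_t* block_size) {
  std::size_t sm_count = max_threads / 1024;  // theory_sm_threads
  if (length >= max_threads) {
    *block_size = 1024;  // large problems: widest blocks
  } else if (length > sm_count * 128) {
    *block_size = 512;
  } else {
    *block_size = 256;  // small problems: narrower blocks
  }
  *num_blocks = (length + *block_size - 1) / *block_size;  // CEIL_DIV
}

int main() {
  std::size_t blocks = 0, threads = 0;
  ComputeLaunchParams(1 << 20, 65536, &blocks, &threads);
  std::cout << blocks << " blocks x " << threads << " threads\n";  // 1024 x 1024
}
```

Note also that the kernel takes T** rather than separate tensor arguments, which is why the host code stages the array of device pointers through a temporary device buffer with paddle::memory::Copy before the launch.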
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SumOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
return KernelSignature("add_n", {"X"}, {}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(sum, add_n);
PD_REGISTER_ARG_MAPPING_FN(sum, phi::SumOpArgumentMapping);
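Editor's note: this mapping is the glue for the partial migration the commit title describes: the legacy sum op is routed to the new add_n kernel only when every input variable is a dense tensor, and otherwise an unmatched signature is returned so the other input kinds (SelectedRows, LoDTensorArray) keep using the original fluid path. A simplified standalone model of that dispatch; the enum and function below are illustrative, not Paddle API:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for the variable kinds the mapping context sees.
enum class VarKind { DenseTensor, SelectedRows, TensorArray };

// Simplified model of SumOpArgumentMapping: route "sum" to "add_n" only
// when every input is a dense tensor.
std::string MapSumOp(const std::vector<VarKind>& inputs) {
  for (VarKind k : inputs) {
    if (k != VarKind::DenseTensor) return "unregistered";
  }
  return "add_n";
}

int main() {
  std::cout << MapSumOp({VarKind::DenseTensor, VarKind::DenseTensor}) << "\n";   // add_n
  std::cout << MapSumOp({VarKind::DenseTensor, VarKind::SelectedRows}) << "\n";  // unregistered
}
```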