Unverified Commit 4b3f2af1 authored by: Y YuanRisheng committed by: GitHub

[PHI]Move sum op to PHI (#45860)

* move sum

* fix ci bugs

* fix ci bugs

* fix set_lod bugs

* fix infershape bugs

* fix ci bugs

* fix ci unittest bug

* fix ci bugs

* perfect code

* update code according comment

* add unittest

* fix ci bugs
Parent d963e2e4
......@@ -87,6 +87,15 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
});
}
bool IsSelectedRowsInputs(const std::string& name) const override {
auto var_types = ctx_.GetInputsVarType(name);
return std::all_of(var_types.begin(),
var_types.end(),
[](const proto::VarType::Type& type) {
return type == proto::VarType::SELECTED_ROWS;
});
}
bool IsSelectedRowsInput(const std::string& name) const override {
auto var_type = ctx_.GetInputVarType(name);
return var_type == proto::VarType::SELECTED_ROWS;
......@@ -145,6 +154,36 @@ int64_t CompatMetaTensor::numel() const {
}
}
bool CompatMetaTensor::is_selected_rows() const {
if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_);
return var->IsType<phi::SelectedRows>();
} else {
auto* var = PADDLE_GET_CONST(VarDesc*, var_);
return var->GetType() == proto::VarType::SELECTED_ROWS;
}
}
bool CompatMetaTensor::is_dense() const {
if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_);
return var->IsType<phi::DenseTensor>();
} else {
auto* var = PADDLE_GET_CONST(VarDesc*, var_);
return var->GetType() == proto::VarType::LOD_TENSOR;
}
}
bool CompatMetaTensor::is_tensor_array() const {
if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_);
return var->IsType<framework::LoDTensorArray>();
} else {
auto* var = PADDLE_GET_CONST(VarDesc*, var_);
return var->GetType() == proto::VarType::LOD_TENSOR_ARRAY;
}
}
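The three predicates added above share one compatibility pattern: at runtime the boxed Variable is inspected directly, while at compile time only the proto type recorded in the VarDesc is available. A minimal sketch of that branching, using hypothetical stand-in types rather than the real Variable/VarDesc:

#include <variant>

// Hypothetical stand-ins for Variable (runtime) and VarDesc (compile time);
// these only illustrate the branching pattern above, not PR code.
enum class ProtoType { LOD_TENSOR, SELECTED_ROWS, LOD_TENSOR_ARRAY };
struct RuntimeVar { ProtoType held; };      // what the live container holds
struct CompileVar { ProtoType declared; };  // what the program declares

bool IsSelectedRowsVar(const std::variant<RuntimeVar*, CompileVar*>& var) {
  if (auto* rt = std::get_if<RuntimeVar*>(&var)) {
    return (*rt)->held == ProtoType::SELECTED_ROWS;  // runtime: ask the variable
  }
  // compile time: trust the declared proto type
  return std::get<CompileVar*>(var)->declared == ProtoType::SELECTED_ROWS;
}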
DDim CompatMetaTensor::dims() const {
ValidCheck(*this);
if (is_runtime_) {
......@@ -152,7 +191,7 @@ DDim CompatMetaTensor::dims() const {
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().GetCompleteDims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
......@@ -224,8 +263,7 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<phi::SelectedRows>()) {
var->GetMutable<phi::SelectedRows>()->set_height(dims[0]);
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: Here I want enforce `tensor_array->size() == 0UL`, because
......@@ -299,7 +337,7 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
ValidCheck(meta_tensor);
if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>() && meta_tensor.is_dense()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
......@@ -309,6 +347,10 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
}
} else {
auto* var = PADDLE_GET(VarDesc*, var_);
if (!meta_tensor.is_dense() && !meta_tensor.is_tensor_array()) {
VLOG(3) << "input metatensor is not LoDTensor or LoDTensorArray.";
return;
}
var->SetLoDLevel(
static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
}
......
......@@ -59,6 +59,10 @@ class CompatMetaTensor : public phi::MetaTensor {
bool initialized() const override { return initialized_; };
bool is_selected_rows() const;
bool is_tensor_array() const;
bool is_dense() const;
operator unspecified_bool_type() const override {
return initialized_ ? unspecified_bool_true : 0;
}
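The operator above is the classic safe-bool idiom: the tensor converts to an unspecified member-pointer type so it can be tested in boolean context without also converting to int. A self-contained sketch of the idiom (illustration only, not PR code):

class Handle {
  // Member-function-pointer type: usable in `if`, but not arithmetic.
  using unspecified_bool_type = void (Handle::*)() const;
  void true_value() const {}
  bool valid_ = false;

 public:
  explicit Handle(bool v) : valid_(v) {}
  operator unspecified_bool_type() const {
    return valid_ ? &Handle::true_value : nullptr;
  }
};
// usage: Handle h(true); if (h) { /* tests validity, no int conversion */ }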
......
......@@ -50,7 +50,7 @@ USE_OP_ITSELF(concat_grad);
USE_OP_ITSELF(elementwise_mul_grad);
USE_OP_ITSELF(sigmoid_grad);
USE_OP_ITSELF(tanh_grad);
USE_OP_ITSELF(sum);
USE_OP_ITSELF(slice_grad);
USE_OP_ITSELF(lookup_table_grad);
USE_OP_ITSELF(sqrt);
......@@ -101,6 +101,7 @@ PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cross_entropy_with_softmax, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cross_entropy_with_softmax_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_n, GPU, ALL_LAYOUT);
namespace paddle {
namespace framework {
......
......@@ -506,6 +506,13 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
});
}
bool IsSelectedRowsInputs(const std::string& name) const override {
auto vars = ctx_.MultiInputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
return var->IsType<phi::SelectedRows>();
});
}
bool IsSelectedRowsInput(const std::string& name) const override {
const auto* var = ctx_.InputVar(name);
return var->IsType<phi::SelectedRows>();
......
......@@ -104,6 +104,10 @@ bool PluginArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsSelectedRowsInputs(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
......
......@@ -46,6 +46,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsSelectedRowsInput(const std::string& name) const override;
bool IsSelectedRowsInputs(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override;
......
......@@ -24,7 +24,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace phi {
......@@ -37,6 +38,9 @@ namespace operators {
using paddle::platform::MKLDNNDeviceContext;
using phi::CPUContext;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = framework::LoDTensor;
template <typename T>
class SumMKLDNNHandler
......
......@@ -9,15 +9,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -32,94 +34,6 @@ class SumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "sum");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sum");
if (ctx->IsRuntime() && ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR_ARRAY) {
return; // skip runtime InferShape when the output is a tensor array
}
auto x_var_types = ctx->GetInputsVarType("X");
auto x_dims = ctx->GetInputsDim("X");
auto N = x_dims.size();
PADDLE_ENFORCE_GT(
N,
0,
platform::errors::InvalidArgument(
"The input tensor X's dimensions of SumOp "
"should be larger than 0. But received X's dimensions %d, "
"X's shape = [%s].",
N,
&x_dims));
if (N == 1) {
VLOG(3) << "Warning: SumOp have only one input, may waste memory";
}
framework::DDim in_dim({0});
for (size_t i = 0; i < x_dims.size(); ++i) {
auto& x_dim = x_dims[i];
// x_dim.size() == 1 means the real dim of selected rows is [0]
if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
x_dim.size() == 1) {
continue;
}
if (phi::product(x_dim) == 0) {
continue;
}
if (phi::product(in_dim) == 0) {
in_dim = x_dim;
} else {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(in_dim,
x_dim,
platform::errors::InvalidArgument(
"The input tensor X of SumOp must"
" have same shape. But received X[0]'s shape = "
"[%s], X[%d]'s shape = [%s].",
in_dim,
i,
x_dim));
} else {
PADDLE_ENFORCE_EQ(
in_dim.size(),
x_dim.size(),
platform::errors::InvalidArgument(
"The input tensor X of SumOp must have same "
"dimensions. But received X[0]'s dimensions = %d, X[0]'s "
"shape = "
"[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
in_dim.size(),
in_dim,
i,
x_dim.size(),
i,
x_dim));
// if in_dim or x_dim has -1, not check equal
for (int j = 0; j < x_dim.size(); ++j) {
if (x_dim[j] == -1 || in_dim[j] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(
in_dim[j],
x_dim[j],
platform::errors::InvalidArgument(
"The input tensor X of SumOp must have same shape "
"if not -1."
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
in_dim,
i,
x_dim));
}
}
}
}
ctx->SetOutputDim("Out", in_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -350,18 +264,16 @@ DECLARE_INPLACE_OP_INFERER(SumInplaceInferer, {"X", "Out"});
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(sum,
AddNInferShapeFunctor,
PD_INFER_META(phi::AddNTensorArrayInferMeta));
REGISTER_OPERATOR(sum,
ops::SumOp,
ops::SumOpMaker,
ops::SumGradDescMaker,
ops::SumGradOpBaseMaker,
ops::SumOpVarTypeInference,
ops::SumInplaceInferer,
AddNInferShapeFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/float16.h"
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
using LoDTensor = framework::LoDTensor;
template <class T>
__global__ void Sum2CUDAKernel(const T *in_0,
const T *in_1,
T *out,
int64_t N) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
out[id] = in_0[id] + in_1[id];
id += blockDim.x * gridDim.x;
}
}
template <class T>
__global__ void SumArrayCUDAKernel(
T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
T total(read_dst ? out[id] : static_cast<T>(0));
for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i];
if (tmp) {
total += tmp[id];
}
}
out[id] = total;
id += blockDim.x * gridDim.x;
}
}
template <class T>
__global__ void SumSelectedRowsCUDAKernel(T **sr_in_out,
int64_t N,
size_t rows) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
for (int i = 0; i < 2 * rows; i += 2) {
const T *tmp = sr_in_out[i];
T *tmp_out = sr_in_out[i + 1];
if (tmp && tmp_out) {
tmp_out[id] += tmp[id];
}
}
id += blockDim.x * gridDim.x;
}
}
template <class T>
void SumToLoDTensor(const framework::ExecutionContext &context) {
auto in_vars = context.MultiInputVar("X");
const size_t in_num = in_vars.size();
constexpr size_t theory_sm_threads = 1024;
auto &dev_ctx = context.template device_context<phi::GPUContext>();
auto stream = dev_ctx.stream();
auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
auto sm_count = max_threads / theory_sm_threads;
size_t tile_size = 0;
dim3 grids;
dim3 blocks;
auto ComputeKernelParameter = [&](size_t length) {
if (length >= max_threads)
tile_size = 1024;
else if (length < max_threads && length > sm_count * 128)
tile_size = 512;
else if (length <= sm_count * 128)
tile_size = 256;
grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
blocks = dim3(tile_size, 1, 1);
};
auto *out = context.Output<LoDTensor>("Out");
bool in_place = in_vars[0] == context.OutputVar("Out");
if (!in_place) {
auto *out_ptr = out->mutable_data<T>(context.GetPlace());
if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>()) {
auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr);
}
}
}
// Sum of two tensors
if (in_num == 2 && in_vars[0]->IsType<framework::LoDTensor>() &&
in_vars[1]->IsType<framework::LoDTensor>()) {
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
int64_t length_0 = in_0.numel();
int64_t length_1 = in_1.numel();
if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
} else if (length_0 && in_0.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
result.device(place) = EigenVector<T>::Flatten(in_0);
} else if (length_1 && in_1.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
result.device(place) = EigenVector<T>::Flatten(in_1);
}
return;
}
int start = in_place ? 1 : 0;
if (!in_place) {
phi::funcs::SetConstant<phi::GPUContext, T> constant_functor;
constant_functor(context.template device_context<phi::GPUContext>(),
out,
static_cast<T>(0));
}
std::vector<const T *> in_data;
std::vector<int> selectrow_index;
int64_t lod_length = 0;
bool dst_write = false;
for (int i = start; i < in_num; ++i) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto &in_i = in_vars[i]->Get<framework::LoDTensor>();
lod_length = in_i.numel();
if (lod_length && in_i.IsInitialized()) {
in_data.emplace_back(in_i.data<T>());
}
} else if (in_vars[i]->IsType<phi::SelectedRows>()) {
selectrow_index.push_back(i);
}
}
// compute selected rows separately.
if (!selectrow_index.empty()) {
std::vector<const T *> sr_in_out_data;
size_t rows = 0;
int64_t length = 0;
for (auto index : selectrow_index) {
auto &sr = in_vars[index]->Get<phi::SelectedRows>();
auto &sr_value = sr.value();
auto &sr_rows = sr.rows();
auto row_numel = sr_value.numel() / sr_rows.size();
auto out_dims = out->dims();
PADDLE_ENFORCE_EQ(sr.height(),
out_dims[0],
platform::errors::InvalidArgument(
"The table height of input must be same as output, "
"but received input height is %d"
", output height is %d",
sr.height(),
out_dims[0]));
PADDLE_ENFORCE_EQ(row_numel,
out->numel() / sr.height(),
platform::errors::InvalidArgument(
"The table width of input must be same as output, "
"but received input width is %d"
", output width is %d",
row_numel,
out->numel() / sr.height()));
auto *sr_data = sr_value.data<T>();
auto *sr_out_data = out->data<T>();
rows += sr_rows.size();
length = row_numel;
for (size_t i = 0; i < sr_rows.size(); ++i) {
sr_in_out_data.emplace_back(&sr_data[i * row_numel]);
sr_in_out_data.emplace_back(&sr_out_data[sr_rows[i] * row_numel]);
}
}
if (!sr_in_out_data.empty()) {
auto tmp_sr_in_out_array = memory::Alloc(
dev_ctx.GetPlace(),
sr_in_out_data.size() * sizeof(T *),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
memory::Copy(dev_ctx.GetPlace(),
tmp_sr_in_out_array->ptr(),
platform::CPUPlace(),
reinterpret_cast<void *>(sr_in_out_data.data()),
sr_in_out_data.size() * sizeof(T *),
dev_ctx.stream());
T **sr_in_out_array_data =
reinterpret_cast<T **>(tmp_sr_in_out_array->ptr());
ComputeKernelParameter(length);
SumSelectedRowsCUDAKernel<T>
<<<grids, blocks, 0, stream>>>(sr_in_out_array_data, length, rows);
dst_write = true;
}
}
// if in_data is not empty, merge into one kernel call.
if (!in_data.empty()) {
auto tmp_in_array = memory::Alloc(
dev_ctx.GetPlace(),
in_data.size() * sizeof(T *),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
memory::Copy(dev_ctx.GetPlace(),
tmp_in_array->ptr(),
platform::CPUPlace(),
reinterpret_cast<void *>(in_data.data()),
in_data.size() * sizeof(T *),
dev_ctx.stream());
T **in_array_data = reinterpret_cast<T **>(tmp_in_array->ptr());
ComputeKernelParameter(lod_length);
SumArrayCUDAKernel<T><<<grids, blocks, 0, stream>>>(in_array_data,
out->data<T>(),
lod_length,
in_data.size(),
dst_write | in_place);
}
}
template <typename T>
class SumKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto out_var = context.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
SumToLoDTensor<T>(context);
} else if (out_var->IsType<phi::SelectedRows>()) {
SelectedRowsCompute<phi::GPUContext, T>(context);
} else if (out_var->IsType<framework::LoDTensorArray>()) {
LodTensorArrayCompute<phi::GPUContext, T>(context);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Expected type of Output(out) must be Tensor, SelectedRows or "
"LodTensorArray. But got "
"unsupport type: %s.",
framework::ToTypeName(out_var->Type())));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(sum,
ops::SumKernel<phi::GPUContext, float>,
ops::SumKernel<phi::GPUContext, double>,
ops::SumKernel<phi::GPUContext, int>,
ops::SumKernel<phi::GPUContext, int64_t>,
ops::SumKernel<phi::GPUContext, plat::float16>,
ops::SumKernel<phi::GPUContext, plat::bfloat16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = framework::LoDTensor;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
void SelectedRowsCompute(const framework::ExecutionContext &context) {
auto in_vars = context.MultiInputVar("X");
auto out_var = context.OutputVar("Out");
bool in_place = out_var == in_vars[0];
if (in_place && in_vars.size() < 2) {
return;
}
std::vector<const phi::SelectedRows *> inputs;
SelectedRows temp_in0;
if (in_place) {
auto &in0 = in_vars[0]->Get<phi::SelectedRows>();
temp_in0.set_height(in0.height());
temp_in0.set_rows(in0.rows());
framework::TensorCopy(in0.value(),
in0.place(),
context.device_context(),
temp_in0.mutable_value());
inputs.push_back(&temp_in0);
for (size_t i = 1; i < in_vars.size(); ++i) {
auto &in = in_vars[i]->Get<phi::SelectedRows>();
if (in.rows().size() > 0) {
inputs.push_back(&in);
}
}
} else {
for (auto &in_var : in_vars) {
auto &in = in_var->Get<phi::SelectedRows>();
if (in.rows().size() > 0) {
inputs.push_back(&in_var->Get<phi::SelectedRows>());
}
}
}
auto *out = context.Output<phi::SelectedRows>("Out");
out->mutable_rows()->clear();
bool has_data = false;
for (auto &in : inputs) {
if (in->rows().size() > 0) {
has_data = true;
break;
}
}
if (has_data) {
math::scatter::MergeAdd<DeviceContext, T> merge_add;
merge_add(context.template device_context<DeviceContext>(), inputs, out);
out->SyncIndex();
} else {
// no data, just set an empty out tensor.
out->mutable_value()->mutable_data<T>(phi::make_ddim({0}),
context.GetPlace());
}
}
template <typename DeviceContext, typename T>
void LodTensorArrayCompute(const framework::ExecutionContext &context) {
auto in_vars = context.MultiInputVar("X");
auto out_var = context.OutputVar("Out");
bool in_place = out_var == in_vars[0];
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
PADDLE_ENFORCE_EQ(in_vars[i]->IsType<framework::LoDTensorArray>(),
true,
platform::errors::InvalidArgument(
"Only support all inputs are TensorArray, "
"but inputs[%d] is not TensorArray.",
i));
auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < in_array.size(); ++i) {
if (in_array[i].IsInitialized() && (in_array[i].numel() != 0)) {
if (i >= out_array.size()) {
out_array.resize(i + 1);
}
if (!out_array[i].IsInitialized() || (out_array[i].numel() == 0)) {
framework::TensorCopy(in_array[i],
in_array[i].place(),
context.device_context(),
&out_array[i]);
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE_EQ(
out_array[i].lod(),
in_array[i].lod(),
platform::errors::InvalidArgument(
"The lod message between inputs[%d] and"
" outputs[%d] must be same, but now is not same.",
i,
i));
auto in = EigenVector<T>::Flatten(in_array[i]);
auto result = EigenVector<T>::Flatten(out_array[i]);
result.device(*context.template device_context<DeviceContext>()
.eigen_device()) = result + in;
}
}
}
}
}
template <typename DeviceContext, typename T>
class SumKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
VLOG(10) << "start sum kernel";
auto in_vars = context.MultiInputVar("X");
size_t in_num = in_vars.size();
auto out_var = context.OutputVar("Out");
bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) {
auto *out = out_var->GetMutable<framework::LoDTensor>();
auto *out_ptr = out->mutable_data<T>(context.GetPlace());
if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>() &&
in_vars[0]->Get<framework::LoDTensor>().IsInitialized()) {
auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr);
}
}
auto result = EigenVector<T>::Flatten(*out);
auto &place =
*context.template device_context<DeviceContext>().eigen_device();
int start = in_place ? 1 : 0;
if (!in_place) {
if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() &&
in_vars[1]->IsType<framework::LoDTensor>() &&
in_vars[0]->Get<framework::LoDTensor>().IsInitialized() &&
in_vars[1]->Get<framework::LoDTensor>().IsInitialized()) {
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
if (in_0.numel() && in_1.numel()) {
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
start = 2;
}
}
if (start != 2) {
VLOG(10) << "Fill with constant = 0 in sum kernel.";
phi::funcs::SetConstant<DeviceContext, T> constant_functor;
constant_functor(context.template device_context<DeviceContext>(),
out,
static_cast<T>(0));
}
}
math::SelectedRowsAddToTensor<DeviceContext, T> functor;
// If in_place, just skip the first tensor
for (size_t i = start; i < in_num; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
if (!in_t.IsInitialized() || in_t.numel() == 0) {
continue;
}
auto in = EigenVector<T>::Flatten(in_t);
result.device(place) = result + in;
} else if (in_vars[i]->IsType<phi::SelectedRows>()) {
auto &in_t = in_vars[i]->Get<phi::SelectedRows>();
functor(context.template device_context<DeviceContext>(), in_t, out);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Expected type of Input(X) of %d-th must be Tensor, "
"SelectedRows. But got "
"unsupport type: %s.",
framework::ToTypeName(in_vars[i]->Type())));
}
}
} else if (out_var->IsType<phi::SelectedRows>()) {
SelectedRowsCompute<DeviceContext, T>(context);
} else if (out_var->IsType<framework::LoDTensorArray>()) {
LodTensorArrayCompute<DeviceContext, T>(context);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Expected type of Output(out) must be Tensor, SelectedRows, "
"LoDTensorArray. But got "
"unsupport type: %s.",
framework::ToTypeName(out_var->Type())));
}
VLOG(10) << "end sum kernel";
}
};
} // namespace operators
} // namespace paddle
......@@ -12,13 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/sum_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class SumMLUKernel : public framework::OpKernel<T> {
......
......@@ -16,13 +16,16 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class SumNPUKernel : public framework::OpKernel<T> {
......
......@@ -13,14 +13,16 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class SumXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
......
......@@ -69,10 +69,16 @@ bool ProtoArgumentMappingContext::IsDenseTensorInputs(
return true;
}
bool ProtoArgumentMappingContext::IsSelectedRowsInputs(
const std::string& name) const {
return false;
}
bool ProtoArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const {
return false;
}
bool ProtoArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
......
......@@ -45,6 +45,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsDenseTensorInput(const std::string& name) const override;
bool IsDenseTensorInputs(const std::string& name) const override;
bool IsSelectedRowsInput(const std::string& name) const override;
bool IsSelectedRowsInputs(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override;
......
......@@ -34,6 +34,95 @@ namespace experimental {
////////////////// Forward api impls //////////////////////
Tensor add_n_impl(const std::vector<Tensor>& x) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
bool is_sr_kernel = true;
for (auto& input : x) {
if (phi::DenseTensor::classof(input.impl().get())) {
is_sr_kernel = false;
break;
}
}
const std::string kernel_name = (is_sr_kernel ? "add_n_sr" : "add_n");
VLOG(6) << "add_n API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
const auto& kernel = kernel_result.kernel;
VLOG(6) << kernel_name << " kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(
kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend);
Tensor api_output;
if (is_sr_kernel) {
std::vector<const phi::SelectedRows*> input_x(x.size());
for (size_t i = 0; i < input_x.size(); ++i) {
input_x[i] = static_cast<phi::SelectedRows*>(x[i].impl().get());
}
auto x_meta_vec = MakeMetaTensor(input_x);
std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
for (size_t i = 0; i < x_meta_vec.size(); ++i) {
x_metas[i] = &x_meta_vec[i];
}
auto kernel_out = SetSelectedRowsKernelOutput(&api_output);
phi::MetaTensor meta_out(kernel_out);
phi::AddNInferMeta(x_metas, &meta_out);
using kernel_signature =
void (*)(const platform::DeviceContext&,
const std::vector<const phi::SelectedRows*>&,
phi::SelectedRows*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_x, kernel_out);
} else {
std::vector<const phi::TensorBase*> input_x(x.size());
for (size_t i = 0; i < input_x.size(); ++i) {
input_x[i] = x[i].impl().get();
}
auto x_meta_vec = MakeMetaTensor(input_x);
std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
for (size_t i = 0; i < x_meta_vec.size(); ++i) {
x_metas[i] = &x_meta_vec[i];
}
auto kernel_out = SetKernelOutput(&api_output);
phi::MetaTensor meta_out(kernel_out);
phi::AddNInferMeta(x_metas, &meta_out);
using kernel_signature =
void (*)(const platform::DeviceContext&,
const std::vector<const phi::TensorBase*>&,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_x, kernel_out);
}
return api_output;
}
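Note that is_sr_kernel starts true and flips to false as soon as any input wraps a phi::DenseTensor, so the add_n_sr path is taken only when every input is SelectedRows. A compact restatement of that rule (a sketch, not code from this PR):

#include <vector>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/dense_tensor.h"

// Sketch: "add_n_sr" is selected only when no input holds a DenseTensor.
bool UseSelectedRowsKernel(
    const std::vector<paddle::experimental::Tensor>& xs) {
  for (const auto& t : xs) {
    if (phi::DenseTensor::classof(t.impl().get())) return false;
  }
  return true;
}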
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
Tensor out;
copy(x, place, blocking, &out);
......
......@@ -31,6 +31,8 @@ namespace experimental {
////////////////// Forward api impls //////////////////////
Tensor add_n_impl(const std::vector<Tensor>& x);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
......
......@@ -94,6 +94,16 @@ phi::MetaTensor MakeMetaTensor(const phi::TensorBase& tensor) {
return phi::MetaTensor(tensor);
}
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<const phi::TensorBase*>& tensors) {
std::vector<phi::MetaTensor> meta_tensors;
meta_tensors.reserve(tensors.size());
for (const auto* t : tensors) {
meta_tensors.emplace_back(*t);
}
return meta_tensors;
}
phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::DenseTensor>& tensor) {
if (tensor) {
......@@ -112,6 +122,16 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
return meta_tensors;
}
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<const phi::SelectedRows*>& tensors) {
std::vector<phi::MetaTensor> meta_tensors;
meta_tensors.reserve(tensors.size());
for (const auto* t : tensors) {
meta_tensors.emplace_back(*t);
}
return meta_tensors;
}
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor*>& tensors) {
std::vector<phi::MetaTensor> meta_tensors;
......
......@@ -65,12 +65,18 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor*>& tensors);
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<const phi::SelectedRows*>& tensors);
phi::MetaTensor MakeMetaTensor(
const paddle::optional<phi::SelectedRows>& tensor);
std::vector<phi::MetaTensor> MakeMetaTensor(
const paddle::optional<std::vector<const phi::DenseTensor*>>& tensors);
std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<const phi::TensorBase*>& tensors);
/* ------------------ for output ----------------------- */
phi::DenseTensor* SetKernelOutput(Tensor* out);
......
......@@ -102,10 +102,7 @@
- op : add_n
args : (Tensor[] x)
output : Tensor
infer_meta :
func : AddNInferMeta
kernel :
func : add_n
invoke : add_n_impl(x)
backward : add_n_grad
- op : addmm
......
......@@ -108,6 +108,7 @@ class ArgumentMappingContext {
virtual bool IsDenseTensorInput(const std::string& name) const = 0;
virtual bool IsDenseTensorInputs(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInputs(const std::string& name) const = 0;
// For compatibility with LoDTensorArray
virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
......
......@@ -100,6 +100,24 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(
const std::vector<const SelectedRows*>&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(
const std::vector<const TensorBase*>&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(
const std::vector<const TensorArray*>&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
......
......@@ -270,6 +270,8 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorBase);
PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(DenseTensor);
......
......@@ -139,6 +139,14 @@ void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
}
}
TensorBase* MetaTensor::tensor() const { return tensor_; }
bool MetaTensor::is_dense() const { return DenseTensor::classof(tensor_); }
bool MetaTensor::is_selected_rows() const {
return SelectedRows::classof(tensor_);
}
bool MetaTensor::is_tensor_array() const { return false; }
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
ValidCheck(*this);
bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
......@@ -178,6 +186,4 @@ const LoD& MetaTensor::lod() const {
}
}
TensorBase* MetaTensor::tensor() const { return tensor_; }
} // namespace phi
......@@ -68,6 +68,12 @@ class MetaTensor {
virtual bool initialized() const;
virtual bool is_selected_rows() const;
virtual bool is_dense() const;
// TODO(YuanRisheng) This API is for compatible with Fluid
// and it will be deleted in the future.
virtual bool is_tensor_array() const;
virtual operator unspecified_bool_type() const {
return tensor_ == nullptr ? 0 : unspecified_bool_true;
}
......
......@@ -301,6 +301,10 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
phi::DDim in_dim({0});
for (size_t i = 0; i < x.size(); ++i) {
auto x_dim = x[i]->dims();
// x_dim.size() == 1 means the real dim of selected rows is [0]
if (x[i]->is_selected_rows() && x_dim.size() == 1) {
continue;
}
if (phi::product(x_dim) == 0) {
continue;
}
......@@ -355,6 +359,31 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
out->share_lod(*x[0]);
}
// TODO(YuanRisheng) This InferMeta is used in Fluid
// and will be deleted in the future.
void AddNTensorArrayInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config) {
int64_t max_length = 0;
bool has_tensor_array = false;
for (auto input : x) {
if (input->is_tensor_array()) {
has_tensor_array = true;
// if input is lod_tensor_array, dims() will return its size (one element)
max_length =
input->dims()[0] > max_length ? input->dims()[0] : max_length;
}
}
if (has_tensor_array) {
if (out->is_tensor_array()) {
out->set_dims(make_ddim({max_length}));
}
} else {
AddNInferMeta(x, out, config);
}
}
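For a LoDTensorArray, dims() packs the array size into a one-element DDim (see CompatMetaTensor::dims earlier in this diff), so the loop above reduces to taking the maximum array length. A self-contained sketch of that reduction, assuming the same dims()[0] convention:

#include <algorithm>
#include <cstdint>
#include <vector>
#include "paddle/phi/core/meta_tensor.h"

// Sketch only: the longest tensor-array length among the inputs.
int64_t MaxTensorArrayLength(const std::vector<const phi::MetaTensor*>& x) {
  int64_t max_length = 0;
  for (const auto* t : x) {
    if (t->is_tensor_array()) {
      // dims()[0] is the array size under the convention noted above.
      max_length = std::max(max_length, t->dims()[0]);
    }
  }
  return max_length;
}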
void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
......
......@@ -123,6 +123,10 @@ void AddNInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void AddNTensorArrayInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config);
void AucInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& stat_pos,
......
......@@ -15,12 +15,20 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_array.h"
namespace phi {
// Note(YuanRisheng): std::vector<const TensorBase*> shouldn't be widely used in
// PHI. Here, we use it to be compatible with Fluid.
template <typename T, typename Context>
void AddNKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const TensorBase*>& x,
DenseTensor* out);
template <typename T, typename Context>
void AddNArrayKernel(const Context& dev_ctx,
const std::vector<const TensorArray*>& x,
TensorArray* out);
} // namespace phi
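Since the dense kernel now receives heterogeneous TensorBase pointers, implementations discriminate with classof and downcast, as the CPU and GPU bodies later in this diff do. A minimal sketch of the pattern (illustration only, not code from this PR):

#include <cstdint>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_base.h"

// Count the elements contributed by dense inputs, skipping any other
// TensorBase subtype (e.g. SelectedRows).
int64_t CountDenseElements(const std::vector<const phi::TensorBase*>& x) {
  int64_t n = 0;
  for (const phi::TensorBase* t : x) {
    if (phi::DenseTensor::classof(t)) {
      n += static_cast<const phi::DenseTensor*>(t)->numel();
    }
  }
  return n;
}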
......@@ -12,24 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void AddNKernel(const Context& dev_ctx,
const std::vector<const TensorBase*>& x,
DenseTensor* out) {
size_t in_num = x.size();
dev_ctx.template Alloc<T>(out);
bool in_place = false;
if (x.size() > 0 && x[0]->initialized() && DenseTensor::classof(x[0])) {
if ((static_cast<const DenseTensor*>(x[0]))->Holder() == out->Holder()) {
in_place = true;
}
}
......@@ -37,9 +34,11 @@ void AddNKernel(const Context& dev_ctx,
auto& place = *dev_ctx.eigen_device();
int start = in_place ? 1 : 0;
if (!in_place) {
if ((in_num >= 2) && DenseTensor::classof(x[0]) &&
DenseTensor::classof(x[1]) && x[0]->initialized() &&
x[1]->initialized()) {
auto& in_0 = *(static_cast<const DenseTensor*>(x[0]));
auto& in_1 = *(static_cast<const DenseTensor*>(x[1]));
if (in_0.numel() && in_1.numel()) {
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
......@@ -49,20 +48,33 @@ void AddNKernel(const Context& dev_ctx,
}
if (start != 2) {
VLOG(10) << "Fill with constant = 0 in sum kernel.";
phi::funcs::SetConstant<Context, T> constant_functor;
constant_functor(dev_ctx, out, static_cast<T>(0));
}
}
paddle::operators::math::SelectedRowsAddToTensor<Context, T> functor;
// If in_place, just skip the first tensor
for (size_t i = start; i < in_num; i++) {
if (DenseTensor::classof(x[i])) {
auto& in_t = *(static_cast<const DenseTensor*>(x[i]));
if (!in_t.initialized() || in_t.numel() == 0) {
continue;
}
auto in = EigenVector<T>::Flatten(in_t);
result.device(place) = result + in;
} else if (SelectedRows::classof(x[i])) {
auto& in_t = *(static_cast<const SelectedRows*>(x[i]));
functor(dev_ctx, in_t, out);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Expected type of Input(X) of %d-th must be Tensor, "
"SelectedRows. But got "
"unsupport type: %s.",
x[i]->type_info().name()));
}
}
VLOG(10) << "end add_n kernel";
}
} // namespace phi
......@@ -76,3 +88,13 @@ PD_REGISTER_KERNEL(add_n,
int,
phi::dtype::bfloat16,
int64_t) {}
PD_REGISTER_KERNEL(add_n_array,
CPU,
ALL_LAYOUT,
phi::AddNArrayKernel,
float,
double,
int,
phi::dtype::bfloat16,
int64_t) {}
......@@ -14,16 +14,27 @@
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
template <class T>
__global__ void Sum2CUDAKernel(const T *in_0,
const T *in_1,
T *out,
int64_t N) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
out[id] = in_0[id] + in_1[id];
id += blockDim.x * gridDim.x;
}
}
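Sum2CUDAKernel and the kernels below use grid-stride loops, so any launch configuration covers all N elements: each thread starts at its global index and advances by blockDim.x * gridDim.x per step. A hypothetical host-side launcher for the kernel above (not part of this PR):

// Assumed helper: launches the grid-stride kernel with a fixed block size;
// the grid is sized to cover n once, and the stride loop handles the rest.
template <class T>
void LaunchSum2(const T *in_0, const T *in_1, T *out, int64_t n,
                cudaStream_t stream) {
  const int threads = 256;
  const int grid = static_cast<int>((n + threads - 1) / threads);
  Sum2CUDAKernel<T><<<grid, threads, 0, stream>>>(in_0, in_1, out, n);
}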
template <class T>
__global__ void SumArrayCUDAKernel(
T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
......@@ -41,9 +52,26 @@ __global__ void SumArrayCUDAKernel(
}
}
template <class T>
__global__ void SumSelectedRowsCUDAKernel(T **sr_in_out,
int64_t N,
size_t rows) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
for (int i = 0; i < 2 * rows; i += 2) {
const T *tmp = sr_in_out[i];
T *tmp_out = sr_in_out[i + 1];
if (tmp && tmp_out) {
tmp_out[id] += tmp[id];
}
}
id += blockDim.x * gridDim.x;
}
}
template <typename T, typename Context>
void AddNKernel(const Context &dev_ctx,
const std::vector<const TensorBase *> &x,
DenseTensor *out) {
const size_t in_num = x.size();
......@@ -66,36 +94,38 @@ void AddNKernel(const Context &dev_ctx,
grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
blocks = dim3(tile_size, 1, 1);
};
auto *out_ptr = dev_ctx.template Alloc<T>(out);
bool in_place = false;
if (x.size() > 0 && x[0]->initialized() && DenseTensor::classof(x[0])) {
if ((static_cast<const DenseTensor *>(x[0]))->data() == out->data()) {
in_place = true;
}
}
if (!in_place && in_num >= 1 && DenseTensor::classof(x[0])) {
auto &in_0_tensor = *(static_cast<const DenseTensor *>(x[0]));
if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr);
}
}
// Sum of two tensors
if (in_num == 2 && DenseTensor::classof(x[0]) && DenseTensor::classof(x[1])) {
auto &in_0 = *(static_cast<const DenseTensor *>(x[0]));
auto &in_1 = *(static_cast<const DenseTensor *>(x[1]));
int64_t length_0 = in_0.numel();
int64_t length_1 = in_1.numel();
if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
} else if (length_0 && in_0.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
result.device(place) = EigenVector<T>::Flatten(in_0);
} else if (length_1 && in_1.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
result.device(place) = EigenVector<T>::Flatten(in_1);
......@@ -105,27 +135,90 @@ void AddNKernel(const Context &dev_ctx,
int start = in_place ? 1 : 0;
if (!in_place) {
phi::funcs::SetConstant<phi::GPUContext, T> constant_functor;
constant_functor(dev_ctx, out, static_cast<T>(0));
}
std::vector<const T *> in_data;
std::vector<int> selectrow_index;
int64_t lod_length = 0;
bool dst_write = false;
for (int i = start; i < in_num; ++i) {
if (DenseTensor::classof(x[i])) {
auto &in_i = *(static_cast<const DenseTensor *>(x[i]));
lod_length = in_i.numel();
if (lod_length && in_i.IsInitialized()) {
in_data.emplace_back(in_i.data<T>());
}
} else if (SelectedRows::classof(x[i])) {
selectrow_index.push_back(i);
}
}
// compute selected rows separately.
if (!selectrow_index.empty()) {
std::vector<const T *> sr_in_out_data;
size_t rows = 0;
int64_t length = 0;
for (auto index : selectrow_index) {
auto &sr = *(static_cast<const SelectedRows *>(x[index]));
auto &sr_value = sr.value();
auto &sr_rows = sr.rows();
auto row_numel = sr_value.numel() / sr_rows.size();
auto out_dims = out->dims();
PADDLE_ENFORCE_EQ(sr.height(),
out_dims[0],
errors::InvalidArgument(
"The table height of input must be same as output, "
"but received input height is %d"
", output height is %d",
sr.height(),
out_dims[0]));
PADDLE_ENFORCE_EQ(row_numel,
out->numel() / sr.height(),
errors::InvalidArgument(
"The table width of input must be same as output, "
"but received input width is %d"
", output width is %d",
row_numel,
out->numel() / sr.height()));
auto *sr_data = sr_value.data<T>();
auto *sr_out_data = out->data<T>();
rows += sr_rows.size();
length = row_numel;
for (size_t i = 0; i < sr_rows.size(); ++i) {
sr_in_out_data.emplace_back(&sr_data[i * row_numel]);
sr_in_out_data.emplace_back(&sr_out_data[sr_rows[i] * row_numel]);
}
}
if (!sr_in_out_data.empty()) {
auto tmp_sr_in_out_array = paddle::memory::Alloc(
dev_ctx.GetPlace(), sr_in_out_data.size() * sizeof(T *));
paddle::memory::Copy(dev_ctx.GetPlace(),
tmp_sr_in_out_array->ptr(),
phi::CPUPlace(),
reinterpret_cast<void *>(sr_in_out_data.data()),
sr_in_out_data.size() * sizeof(T *),
dev_ctx.stream());
T **sr_in_out_array_data =
reinterpret_cast<T **>(tmp_sr_in_out_array->ptr());
ComputeKernelParameter(length);
SumSelectedRowsCUDAKernel<T>
<<<grids, blocks, 0, stream>>>(sr_in_out_array_data, length, rows);
dst_write = true;
}
}
// if in_data is not empty, merge into one kernel call.
if (!in_data.empty()) {
auto tmp_in_array =
paddle::memory::Alloc(dev_ctx.GetPlace(), in_data.size() * sizeof(T *));
paddle::memory::Copy(dev_ctx.GetPlace(),
tmp_in_array->ptr(),
......@@ -153,6 +246,17 @@ PD_REGISTER_KERNEL(add_n,
float,
double,
int,
phi::dtype::bfloat16,
phi::dtype::float16,
int64_t) {}
PD_REGISTER_KERNEL(add_n_array,
GPU,
ALL_LAYOUT,
phi::AddNArrayKernel,
float,
double,
int,
phi::dtype::bfloat16,
phi::dtype::float16,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace phi {
template <typename T, typename Context>
void AddNArrayKernel(const Context& dev_ctx,
const std::vector<const TensorArray*>& x,
TensorArray* out) {
for (auto& ele : *out) {
dev_ctx.template Alloc<T>(&ele);
}
bool in_place = true;
if (x.size() > 0 && x[0]->size() == out->size()) {
for (size_t i = 0; i < out->size(); i++) {
if (x[0]->at(i).IsInitialized() &&
out->at(i).data() != x[0]->at(i).data()) {
in_place = false;
break;
}
}
} else {
in_place = false;
}
for (size_t i = in_place ? 1 : 0; i < x.size(); ++i) {
auto* in_array = x.at(i);
for (size_t j = 0; j < in_array->size(); ++j) {
if (in_array->at(j).IsInitialized() && (in_array->at(j).numel() != 0)) {
if (j >= out->size()) {
out->resize(j + 1);
}
if (!out->at(j).IsInitialized() || (out->at(j).numel() == 0)) {
Copy<Context>(dev_ctx,
in_array->at(j),
in_array->at(j).place(),
false,
&out->at(j));
out->at(j).set_lod(in_array->at(j).lod());
} else {
PADDLE_ENFORCE_EQ(
out->at(j).lod(),
in_array->at(j).lod(),
phi::errors::InvalidArgument(
"The lod message between inputs[%d] and"
" outputs[%d] must be same, but now is not same.",
j,
j));
auto in = EigenVector<T>::Flatten(in_array->at(j));
auto result = EigenVector<T>::Flatten(out->at(j));
result.device(*dev_ctx.eigen_device()) = result + in;
}
}
}
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void AddNKernel(const Context& dev_ctx,
const std::vector<const SelectedRows*>& x,
SelectedRows* out);
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/impl/add_n_kernel_impl.h"
PD_REGISTER_KERNEL(add_n_sr,
CPU,
ALL_LAYOUT,
phi::sr::AddNKernel,
float,
double,
int,
phi::dtype::bfloat16,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/impl/add_n_kernel_impl.h"
PD_REGISTER_KERNEL(add_n_sr,
GPU,
ALL_LAYOUT,
phi::sr::AddNKernel,
float,
double,
int,
phi::dtype::bfloat16,
phi::dtype::float16,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/selected_rows/add_n_kernel.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void AddNKernel(const Context &dev_ctx,
const std::vector<const SelectedRows *> &x,
SelectedRows *out) {
dev_ctx.template Alloc<T>(out->mutable_value());
bool in_place = false;
if (x.size() > 0 && x[0]->value().Holder() == out->value().Holder()) {
in_place = true;
}
if (in_place && x.size() < 2) {
return;
}
std::vector<const phi::SelectedRows *> inputs;
SelectedRows temp_in0;
if (in_place) {
auto &in0 = *x[0];
temp_in0.set_height(in0.height());
temp_in0.set_rows(in0.rows());
Copy<Context>(
dev_ctx, in0.value(), in0.place(), false, temp_in0.mutable_value());
inputs.push_back(&temp_in0);
for (size_t i = 1; i < x.size(); ++i) {
auto &in = *x[i];
if (in.rows().size() > 0) {
inputs.push_back(&in);
}
}
} else {
for (auto in_var : x) {
auto &in = *in_var;
if (in.rows().size() > 0) {
inputs.push_back(in_var);
}
}
}
out->mutable_rows()->clear();
bool has_data = false;
for (auto &in : inputs) {
if (in->rows().size() > 0) {
has_data = true;
break;
}
}
if (has_data) {
paddle::operators::math::scatter::MergeAdd<Context, T> merge_add;
merge_add(dev_ctx, inputs, out);
out->SyncIndex();
} else {
// no data, just set an empty out tensor.
auto *out_dense = out->mutable_value();
out_dense->clear();
out_dense->Resize(phi::make_ddim({0}));
dev_ctx.template Alloc<T>(out_dense);
}
}
} // namespace sr
} // namespace phi
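For context, paddle's scatter::MergeAdd concatenates the selected rows of all inputs and accumulates the values of rows that share a row index; SyncIndex then rebuilds the output's id-to-index map. A small worked example of the intended semantics (row ordering in the merged output may differ):

// x0: rows = {0, 2}, value = [[1.0], [2.0]]
// x1: rows = {2},    value = [[10.0]]
// merge_add(x0, x1) -> rows = {0, 2}, value = [[1.0], [12.0]]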
......@@ -27,6 +27,15 @@ KernelSignature MemcpyD2HOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("memcpy_d2h", {"X"}, {"dst_place_type"}, {"Out"});
}
KernelSignature MemcpyOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
return KernelSignature("memcpy", {"X"}, {"dst_place_type"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(memcpy_d2h, phi::MemcpyD2HOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(memcpy, phi::MemcpyOpArgumentMapping);
......@@ -18,10 +18,13 @@
namespace phi {
KernelSignature SumOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInputs("X")) {
if (ctx.IsSelectedRowsInputs("X")) {
return KernelSignature("add_n_sr", {"X"}, {}, {"Out"});
} else if (ctx.IsDenseTensorVectorInput("X")) {
return KernelSignature("add_n_array", {"X"}, {}, {"Out"});
} else {
return KernelSignature("add_n", {"X"}, {}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace phi
......
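Together with the kernel registrations earlier in this diff, the mapping resolves the fluid sum op to one of three PHI kernels, keyed purely on the variable type of input "X". A compact restatement (a sketch, not code from this PR):

// all inputs SelectedRows -> "add_n_sr"
// LoDTensorArray inputs   -> "add_n_array"
// anything else           -> "add_n" (dense or mixed dense/SelectedRows)
const char* SumKernelNameFor(bool all_selected_rows, bool tensor_array_input) {
  if (all_selected_rows) return "add_n_sr";
  if (tensor_array_input) return "add_n_array";
  return "add_n";
}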
......@@ -109,3 +109,7 @@ cc_test(
test_strings_lower_upper_api
SRCS test_strings_lower_upper_api.cc
DEPS ${COMMON_API_TEST_DEPS})
cc_test(
test_add_n_api
SRCS test_add_n_api.cc
DEPS ${COMMON_API_TEST_DEPS})
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/selected_rows.h"
PD_DECLARE_KERNEL(add_n_sr, CPU, ALL_LAYOUT);
namespace paddle {
namespace tests {
TEST(API, add_n) {
// 1. create tensor
std::vector<int64_t> rows = {0, 1, 2, 3, 4, 5, 6};
int64_t row_numel = 12;
auto x_sr = std::make_shared<phi::SelectedRows>(rows, 10);
auto x_meta = phi::DenseTensorMeta(
phi::DataType::FLOAT32,
phi::make_ddim({static_cast<int64_t>(rows.size()), row_numel}),
phi::DataLayout::NCHW);
x_sr->mutable_value()->set_meta(x_meta);
x_sr->AllocateFrom(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get(),
phi::DataType::FLOAT32);
auto* dense_x_data = x_sr->mutable_value()->data<float>();
auto y_sr = std::make_shared<phi::SelectedRows>(rows, 10);
y_sr->mutable_value()->set_meta(x_meta);
y_sr->AllocateFrom(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get(),
phi::DataType::FLOAT32);
auto* dense_y_data = y_sr->mutable_value()->data<float>();
float sum[84] = {0.0};
for (size_t i = 0; i < 7; ++i) {
for (size_t j = 0; j < 12; ++j) {
dense_x_data[i * 12 + j] = (i * 4 + j);
dense_y_data[i * 12 + j] = (i * 4 + j);
sum[i * 12 + j] += (i * 4 + j) * 2;
}
}
paddle::experimental::Tensor x(x_sr);
paddle::experimental::Tensor y(y_sr);
auto out = paddle::experimental::add_n_impl({x, y});
// check slice result
ASSERT_EQ(
static_cast<int>(std::dynamic_pointer_cast<phi::SelectedRows>(out.impl())
->rows()
.size()),
7);
for (int64_t i = 0; i < 84; ++i) {
ASSERT_EQ(sum[i],
std::dynamic_pointer_cast<phi::SelectedRows>(out.impl())
->value()
.data<float>()[i]);
}
}
} // namespace tests
} // namespace paddle
......@@ -77,6 +77,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return selected_rows_inputs.count(name) > 0;
}
bool IsSelectedRowsInputs(const std::string& name) const override {
return selected_rows_inputs.count(name) > 0;
}
// add member if needed
bool IsDenseTensorVectorInput(const std::string& name) const override {
return false;
......
......@@ -1040,9 +1040,7 @@ class Optimizer(object):
assert regularization_term is not None
if framework.in_dygraph_mode():
return _C_ops.add_n([grad, regularization_term])
elif framework._in_legacy_dygraph():
return _legacy_C_ops.sum([grad, regularization_term])
......
......@@ -1501,9 +1501,6 @@ def add_n(inputs, name=None):
if in_dygraph_mode():
if isinstance(inputs, Variable):
inputs = [inputs]
return _C_ops.add_n(inputs)
if _in_legacy_dygraph():
if isinstance(inputs, Variable):
......