未验证 提交 a02eb143 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse] Support static graph (#46245)

上级 52009d19
......@@ -237,7 +237,7 @@ cc_test(
cc_library(
var_type_traits
SRCS var_type_traits.cc
DEPS framework_proto scope tensor_array)
DEPS framework_proto scope tensor_array sparse_coo_tensor)
if(WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda)
endif()
......@@ -1185,7 +1185,8 @@ cc_library(
phi
phi_api_utils
op_info
shape_inference)
shape_inference
sparse_coo_tensor)
cc_test(
infershape_utils_test
SRCS infershape_utils_test.cc
......
......@@ -22,10 +22,11 @@ limitations under the License. */
namespace paddle {
namespace framework {
using FeedType = paddle::variant<LoDTensor, Strings>;
using FeedType = paddle::variant<LoDTensor, Strings, phi::SparseCooTensor>;
using FeedList = std::vector<FeedType>;
using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>;
using FetchType = paddle::
variant<LoDTensor, LoDTensorArray, framework::Vocab, phi::SparseCooTensor>;
using FetchList = std::vector<FetchType>;
using FetchUnmergedList = std::vector<std::vector<FetchType>>;
......@@ -52,6 +53,13 @@ inline bool data_is_string_tensor(const FeedType &data) {
return false;
}
inline bool data_is_sparse_coo_tensor(const FetchType &data) {
if (data.type() == typeid(phi::SparseCooTensor)) {
return true;
}
return false;
}
static const char kFeedOpType[] = "feed";
static const char kFetchOpType[] = "fetch";
......
......@@ -154,6 +154,8 @@ message VarType {
FEED_LIST = 28;
// The data type of phi::StringTensor
PSTRING = 29;
// the data type of phi::SparseCooTensor
SPARSE_COO = 30;
}
required Type type = 1;
......@@ -186,6 +188,7 @@ message VarType {
optional TensorDesc string = 8;
optional TensorDesc strings = 9;
optional TensorDesc vocab = 10;
optional TensorDesc sparse_coo = 11;
}
message VarDesc {
......
......@@ -110,6 +110,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
});
}
bool IsSparseCooTensorInput(const std::string& name) const override {
auto var_type = ctx_.GetInputVarType(name);
return var_type == proto::VarType::SPARSE_COO;
}
bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return std::all_of(var_types.begin(),
......@@ -192,6 +197,8 @@ DDim CompatMetaTensor::dims() const {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().GetCompleteDims();
} else if (var->IsType<phi::SparseCooTensor>()) {
return var->Get<phi::SparseCooTensor>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
......@@ -217,6 +224,8 @@ phi::DataType CompatMetaTensor::dtype() const {
return var->Get<phi::DenseTensor>().dtype();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dtype();
} else if (var->IsType<phi::SparseCooTensor>()) {
return var->Get<phi::SparseCooTensor>().dtype();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get dtype from LoDTensorArray now
......@@ -239,6 +248,8 @@ DataLayout CompatMetaTensor::layout() const {
return var->Get<phi::DenseTensor>().layout();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().layout();
} else if (var->IsType<phi::SparseCooTensor>()) {
return var->Get<phi::SparseCooTensor>().layout();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get layout from LoDTensorArray now
......@@ -264,6 +275,9 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<phi::SelectedRows>()) {
var->GetMutable<phi::SelectedRows>()->set_height(dims[0]);
} else if (var->IsType<phi::SparseCooTensor>()) {
auto* tensor = var->GetMutable<phi::SparseCooTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: Here I want enforce `tensor_array->size() == 0UL`, because
......@@ -295,6 +309,9 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<phi::SparseCooTensor>()) {
auto* tensor = var->GetMutable<phi::SparseCooTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
......@@ -318,6 +335,9 @@ void CompatMetaTensor::set_layout(DataLayout layout) {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<phi::SparseCooTensor>()) {
auto* tensor = var->GetMutable<phi::SparseCooTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
......
......@@ -2382,6 +2382,17 @@ void OperatorWithKernel::ParseInputDataType(
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
t = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<phi::SparseCooTensor>()) {
const phi::SparseCooTensor* sp_t = &(var->Get<phi::SparseCooTensor>());
PADDLE_ENFORCE_EQ(
sp_t->initialized(),
true,
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(),
name));
*data_type = paddle::framework::TransToProtoVarType(sp_t->dtype());
return;
} else if (var->IsType<LoDTensorArray>()) {
auto t_arr = &var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr->size(); j++) {
......@@ -2419,6 +2430,29 @@ void OperatorWithKernel::ParseMultiInputDataType(
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
t = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<phi::SparseCooTensor>()) {
const phi::SparseCooTensor* sp_t = &(var->Get<phi::SparseCooTensor>());
PADDLE_ENFORCE_EQ(
sp_t->initialized(),
true,
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(),
name));
proto::VarType::Type tmp =
paddle::framework::TransToProtoVarType(sp_t->dtype());
PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
"The DataType of %s Op's duplicable or different "
"slot Variable %s must be "
"consistent or reigster GetExpectedKernelType. The "
"current variable type is (%s), but the "
"previous variable type is (%s).",
Type(),
name,
DataTypeToString(tmp),
DataTypeToString(*data_type)));
*data_type = tmp;
} else if (var->IsType<LoDTensorArray>()) {
auto t_arr = &var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr->size(); j++) {
......@@ -2663,6 +2697,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
} else if (var->IsType<phi::SelectedRows>()) {
tensor_in = &(var->Get<phi::SelectedRows>());
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<phi::SparseCooTensor>()) {
tensor_in = &(var->Get<phi::SparseCooTensor>());
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::LoDTensorArray>()) {
need_prepare_phi_data_ = true;
tensor_in = &(var->Get<framework::LoDTensorArray>());
......@@ -2708,6 +2745,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SparseCooTensor>()) {
tensor_out = var->template GetMutable<phi::SparseCooTensor>();
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
tensor_out = var->template GetMutable<framework::LoDTensorArray>();
// Note: If the input LoDTensorArray size is 0, the output
......
......@@ -531,6 +531,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
});
}
bool IsSparseCooTensorInput(const std::string& name) const override {
const auto* var = ctx_.InputVar(name);
return var->IsType<phi::SparseCooTensor>();
}
bool IsDenseTensorOutput(const std::string& name) const override {
auto vars = ctx_.MultiOutputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace paddle {
namespace framework {
......
......@@ -237,6 +237,8 @@ const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
return desc_.type().strings();
case proto::VarType::VOCAB:
return desc_.type().vocab();
case proto::VarType::SPARSE_COO:
return desc_.type().sparse_coo();
default:
PADDLE_THROW(platform::errors::Unavailable(
"Getting 'tensor_desc' is not supported by the %s type variable.",
......@@ -284,6 +286,8 @@ proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
return desc_.mutable_type()->mutable_strings();
case proto::VarType::VOCAB:
return desc_.mutable_type()->mutable_vocab();
case proto::VarType::SPARSE_COO:
return desc_.mutable_type()->mutable_sparse_coo();
default:
PADDLE_THROW(
platform::errors::Unavailable("Getting 'mutable_tensor_desc' is not "
......
......@@ -33,6 +33,7 @@ inline proto::VarType::Type ToVarType(int type) {
switch (type) {
case proto::VarType::LOD_TENSOR:
case proto::VarType::SELECTED_ROWS:
case proto::VarType::SPARSE_COO:
case proto::VarType::LOD_RANK_TABLE:
case proto::VarType::LOD_TENSOR_ARRAY:
case proto::VarType::FETCH_LIST:
......@@ -59,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case proto::VarType::SELECTED_ROWS:
visitor(var.Get<phi::SelectedRows>());
return;
case proto::VarType::SPARSE_COO:
visitor(var.Get<phi::SparseCooTensor>());
return;
case proto::VarType::READER:
visitor(var.Get<ReaderHolder>());
return;
......
......@@ -54,6 +54,7 @@
namespace phi {
class DenseTensor;
class SelectedRows;
class SparseCooTensor;
} // namespace phi
// Users should add forward declarations here
......@@ -180,6 +181,7 @@ struct VarTypeRegistryImpl {
using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor,
phi::SelectedRows,
phi::SparseCooTensor,
std::vector<Scope *>,
LoDRankTable,
Strings,
......@@ -252,6 +254,7 @@ REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32);
REG_PROTO_VAR_TYPE_TRAIT(Vocab, proto::VarType::VOCAB);
REG_PROTO_VAR_TYPE_TRAIT(String, proto::VarType::STRING);
REG_PROTO_VAR_TYPE_TRAIT(Strings, proto::VarType::STRINGS);
REG_PROTO_VAR_TYPE_TRAIT(phi::SparseCooTensor, proto::VarType::SPARSE_COO);
/** End of variable type registration */
......
......@@ -52,6 +52,8 @@ void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else if (var_type == proto::VarType::SPARSE_COO) {
var->GetMutable<phi::SparseCooTensor>();
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Variable type %d is not in "
......
......@@ -108,6 +108,10 @@ bool PluginArgumentMappingContext::IsSelectedRowsInputs(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsSparseCooTensorInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
......
......@@ -48,6 +48,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsSelectedRowsInputs(const std::string& name) const override;
bool IsSparseCooTensorInput(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override;
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/tensor_utils.h"
namespace paddle {
namespace framework {
......@@ -61,6 +62,22 @@ class FeedVariableVisitor {
*out_str = in_str;
}
void operator()(const phi::SparseCooTensor &in_tensor) const {
phi::SparseCooTensor *out_tensor =
out_var_->GetMutable<phi::SparseCooTensor>();
if (platform::is_same_place(in_tensor.place(), place_)) {
*out_tensor = in_tensor;
} else {
platform::DeviceContext *context =
platform::DeviceContextPool::Instance().Get(place_);
phi::DenseTensor indices, values;
framework::TensorCopy(in_tensor.indices(), place_, *context, &indices);
framework::TensorCopy(in_tensor.values(), place_, *context, &values);
out_tensor->SetMember(indices, values, in_tensor.meta());
}
}
private:
framework::Variable *out_var_;
const platform::Place &place_;
......
......@@ -123,6 +123,9 @@ class FetchOp : public framework::OperatorBase {
auto &src_item = fetch_var->Get<framework::Vocab>();
auto *dst_item = &(PADDLE_GET(framework::Vocab, fetch_list->at(col)));
*dst_item = src_item;
} else if (fetch_var->IsType<phi::SparseCooTensor>()) {
auto &src_item = fetch_var->Get<phi::SparseCooTensor>();
fetch_list->at(col) = src_item;
} else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
framework::LoDTensorArray tmp(src_item.size());
......
......@@ -98,6 +98,12 @@ class FetchV2Op : public framework::OperatorWithKernel {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
}
} else if (fetch_var->IsType<phi::SparseCooTensor>()) {
auto &src_item = fetch_var->Get<phi::SparseCooTensor>();
if (!src_item.initialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
}
} else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
if (src_item.empty() || !src_item[0].IsInitialized()) {
......@@ -163,6 +169,12 @@ class FetchV2Kernel {
dst_item->ShareDataWith(src_item);
dst_item->set_lod(src_item.lod());
}
} else if (fetch_var->IsType<phi::SparseCooTensor>()) {
auto &src_item = fetch_var->Get<phi::SparseCooTensor>();
if (!src_item.initialized()) {
return;
}
fetch_list->at(col) = src_item;
} else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
framework::LoDTensorArray tmp(src_item.size());
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class SparseSparseCooTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("values", "(Tensor), input 0 of sparse_coo_tensor op.");
AddInput("indices", "(Tensor), input 1 of sparse_coo_tensor op.");
AddOutput("out", "(Tensor), output 0 of sparse_coo_tensor op.");
AddAttr<std::vector<int>>(
"dense_shape", "(vector<int>), attribute 0 for sparse_coo_tensor op.");
AddComment(R"DOC(
TODO: Documentation of sparse_coo_tensor op.
)DOC");
}
};
class SparseSparseCooTensorOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(
sparse_sparse_coo_tensor,
SparseSparseCooTensorInferShapeFunctor,
PD_INFER_META(phi::sparse::SparseCooTensorInferMeta));
class SparseValuesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_values op.");
AddOutput("out", "(Tensor), output 0 of sparse_values op.");
AddComment(R"DOC(
TODO: Documentation of sparse_values op.
)DOC");
}
};
class SparseValuesOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_values,
SparseValuesInferShapeFunctor,
PD_INFER_META(phi::sparse::ValuesInferMeta));
class SparseIndicesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_indices op.");
AddOutput("out", "(Tensor), output 0 of sparse_indices op.");
AddComment(R"DOC(
TODO: Documentation of sparse_indices op.
)DOC");
}
};
class SparseIndicesOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_indices,
SparseIndicesInferShapeFunctor,
PD_INFER_META(phi::sparse::IndicesInferMeta));
class SparseToDenseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_to_dense op.");
AddOutput("out", "(Tensor), output 0 of sparse_to_dense op.");
AddComment(R"DOC(
TODO: Documentation of sparse_to_dense op.
)DOC");
}
};
class SparseToDenseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_to_dense,
SparseToDenseInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_relu op.");
AddOutput("out", "(Tensor), output 0 of sparse_relu op.");
AddComment(R"DOC(
TODO: Documentation of sparse_relu op.
)DOC");
}
};
class SparseReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_relu,
SparseReluInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseConv3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_conv3d op.");
AddInput("kernel", "(Tensor), input 1 of sparse_conv3d op.");
AddOutput("out", "(Tensor), output 0 of sparse_conv3d op.");
AddOutput("rulebook", "(Tensor), output 1 of sparse_conv3d op.");
AddOutput("counter", "(Tensor), output 2 of sparse_conv3d op.");
AddAttr<std::vector<int>>(
"paddings", "(vector<int>), attribute 0 for sparse_conv3d op.");
AddAttr<std::vector<int>>(
"dilations", "(vector<int>), attribute 1 for sparse_conv3d op.");
AddAttr<std::vector<int>>(
"strides", "(vector<int>), attribute 2 for sparse_conv3d op.");
AddAttr<int>("groups", "(int), attribute 3 for sparse_conv3d op.");
AddAttr<bool>("subm", "(bool), attribute 4 for conv3d_coo op.");
AddAttr<std::string>("key", "(string), attribute 5 for sparse_conv3d op.")
.SetDefault("");
AddComment(R"DOC(
TODO: Documentation of sparse_conv3d op.
)DOC");
}
};
class SparseConv3dOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_conv3d,
SparseConv3dInferShapeFunctor,
PD_INFER_META(phi::sparse::Conv3dInferMeta));
class SparseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_add op.");
AddInput("y", "(Tensor), input 1 of sparse_add op.");
AddOutput("out", "(Tensor), output 0 of sparse_add op.");
AddComment(R"DOC(
TODO: Documentation of sparse_add op.
)DOC");
}
};
class SparseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_add,
SparseAddInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sparse_sparse_coo_tensor,
ops::SparseSparseCooTensorOp,
ops::SparseSparseCooTensorOpMaker,
ops::SparseSparseCooTensorInferShapeFunctor);
REGISTER_OPERATOR(sparse_values,
ops::SparseValuesOp,
ops::SparseValuesOpMaker,
ops::SparseValuesInferShapeFunctor);
REGISTER_OPERATOR(sparse_indices,
ops::SparseIndicesOp,
ops::SparseIndicesOpMaker,
ops::SparseIndicesInferShapeFunctor);
REGISTER_OPERATOR(sparse_to_dense,
ops::SparseToDenseOp,
ops::SparseToDenseOpMaker,
ops::SparseToDenseInferShapeFunctor);
REGISTER_OPERATOR(sparse_relu,
ops::SparseReluOp,
ops::SparseReluOpMaker,
ops::SparseReluInferShapeFunctor);
REGISTER_OPERATOR(sparse_conv3d,
ops::SparseConv3dOp,
ops::SparseConv3dOpMaker,
ops::SparseConv3dInferShapeFunctor);
REGISTER_OPERATOR(sparse_add,
ops::SparseAddOp,
ops::SparseAddOpMaker,
ops::SparseAddInferShapeFunctor);
......@@ -275,7 +275,8 @@ void BindVarDsec(pybind11::module *m) {
.value("RAW", pd::proto::VarType::RAW)
.value("STRING", pd::proto::VarType::STRING)
.value("STRINGS", pd::proto::VarType::STRINGS)
.value("VOCAB", pd::proto::VarType::VOCAB);
.value("VOCAB", pd::proto::VarType::VOCAB)
.value("SPARSE_COO", pd::proto::VarType::SPARSE_COO);
}
void BindOpDesc(pybind11::module *m) {
......
......@@ -1937,6 +1937,9 @@ All parameter, weight, gradient are variables in Paddle.
if (data_is_lod_tensor(self[i])) {
auto &data = PADDLE_GET(LoDTensor, self[i]);
res[i] = py::cast(std::move(data));
} else if (data_is_sparse_coo_tensor(self[i])) {
auto &data = PADDLE_GET(phi::SparseCooTensor, self[i]);
res[i] = py::cast(std::move(data));
} else {
auto &data = PADDLE_GET(LoDTensorArray, self[i]);
py::list tmp(data.size());
......
......@@ -1105,6 +1105,20 @@ void BindTensor(pybind11::module &m) { // NOLINT
std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
return new_rows;
});
py::class_<phi::SparseCooTensor>(m, "SparseCooTensor")
.def("__init__",
[](phi::SparseCooTensor &instance) {
new (&instance) phi::SparseCooTensor();
})
.def("numel",
[](const phi::SparseCooTensor &self) -> int64_t {
return self.numel();
})
.def("indices",
[](const phi::SparseCooTensor &self) -> framework::Tensor {
return self.indices();
});
}
} // namespace pybind
......
......@@ -109,6 +109,7 @@ class ArgumentMappingContext {
virtual bool IsDenseTensorInputs(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInputs(const std::string& name) const = 0;
virtual bool IsSparseCooTensorInput(const std::string& name) const = 0;
// For compatibility with LoDTensorArray
virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// TODO(zhangkaihuo): add csr op
KernelSignature SparseSparseCooTensorOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sparse_coo_tensor", {"values", "indices"}, {"dense_shape"}, {"out"});
}
KernelSignature SparseValuesOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("values_coo", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseIndicesOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("indices_coo", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseToDenseOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("coo_to_dense", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("relu_coo", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseConv3dOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature(
"conv3d_coo",
{"x", "kernel"},
{"paddings", "dilations", "strides", "groups", "subm", "key"},
{"out", "rulebook", "counter"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x") && ctx.IsSparseCooTensorInput("y")) {
return KernelSignature("add_coo_coo", {"x", "y"}, {}, {"out"});
} else if (ctx.IsSparseCooTensorInput("x") && ctx.IsDenseTensorInput("y")) {
return KernelSignature("add_coo_dense", {"x", "y"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(sparse_sparse_coo_tensor, sparse_coo_tensor);
PD_REGISTER_ARG_MAPPING_FN(sparse_sparse_coo_tensor,
phi::SparseSparseCooTensorOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_values, values_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_values, phi::SparseValuesOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_indices, indices_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_indices, phi::SparseIndicesOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_to_dense, coo_to_dense);
PD_REGISTER_ARG_MAPPING_FN(sparse_to_dense,
phi::SparseToDenseOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_relu, relu_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_relu, phi::SparseReluOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_conv3d, conv3d_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_conv3d, phi::SparseConv3dOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_add, add_coo_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_add, phi::SparseAddOpArgumentMapping);
......@@ -86,6 +86,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return false;
}
bool IsSparseCooTensorInput(const std::string& name) const override {
return false;
}
bool IsDenseTensorOutput(const std::string& name) const override {
return dense_tensor_outputs.count(name) > 0;
}
......
......@@ -1377,6 +1377,9 @@ class Variable(object):
type = core.VarDesc.VarType.STRINGS
lod_level = None
if type == core.VarDesc.VarType.SPARSE_COO:
lod_level = None
self.belong_to_optimizer = belong_to_optimizer
self.error_clip = error_clip
......
......@@ -408,6 +408,30 @@ class LayerHelperBase(object):
persistable=False,
stop_gradient=stop_gradient)
def create_sparse_variable_for_type_inference(self,
dtype,
stop_gradient=False,
shape=None):
"""Create a temporary sparse variable that should be type inferred layer.
Note:
The default type will be set to SPARSE_COO. However, when
the var is used as operator output, its type will be updated
based on operator's `VarTypeInference` implementation in
infer_var_type.
"""
# set global dtype
if not dtype:
dtype = self.__dtype
return self.main_program.current_block().create_var(
name=unique_name.generate_with_ignorable_key(".".join(
[self.name, 'tmp'])),
dtype=dtype,
shape=shape,
type=core.VarDesc.VarType.SPARSE_COO,
persistable=False,
stop_gradient=stop_gradient)
def create_variable(self, *args, **kwargs):
"""Create Variable for this layers.
Returns created Variable.
......
......@@ -78,6 +78,10 @@ def monkey_patch_variable():
tmp_name = unique_tmp_name()
return block.create_var(name=tmp_name, dtype=dtype)
def create_new_tmp_sparse_var(block, dtype, type):
tmp_name = unique_tmp_name()
return block.create_var(name=tmp_name, dtype=dtype, type=type)
def create_tensor(block, value, dtype, shape):
value = float(value)
var = create_new_tmp_var(block, dtype)
......@@ -431,6 +435,33 @@ def monkey_patch_variable():
__impl__.__name__ = method_name
return __impl__
def values(var):
block = current_block(var)
out = create_new_tmp_var(block, var.dtype)
block.append_op(type="sparse_values",
inputs={"x": [var]},
outputs={"out": [out]},
attrs={})
return out
def indices(var):
block = current_block(var)
out = create_new_tmp_var(block, var.dtype)
block.append_op(type="sparse_indices",
inputs={"x": [var]},
outputs={"out": [out]},
attrs={})
return out
def to_dense(var):
block = current_block(var)
out = create_new_tmp_var(block, var.dtype)
block.append_op(type="sparse_to_dense",
inputs={"x": [var]},
outputs={"out": [out]},
attrs={})
return out
variable_methods = [
# b=-a
('__neg__', _neg_),
......@@ -483,7 +514,10 @@ def monkey_patch_variable():
('__lt__', _binary_creator_('__lt__', 'less_than', False, None)),
('__le__', _binary_creator_('__le__', 'less_equal', False, None)),
('__gt__', _binary_creator_('__gt__', 'greater_than', False, None)),
('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None))
('__ge__', _binary_creator_('__ge__', 'greater_equal', False, None)),
('values', values),
('indices', indices),
('to_dense', to_dense),
]
global _already_patch_variable
......
......@@ -18,6 +18,7 @@ import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard
import paddle.incubate.sparse as sparse
class TestSparseConv(unittest.TestCase):
......@@ -158,3 +159,66 @@ class TestSparseConv(unittest.TestCase):
sp_conv3d.bias.grad.numpy(),
atol=1e-5,
rtol=1e-5)
class TestStatic(unittest.TestCase):
def test(self):
paddle.enable_static()
indices = paddle.static.data(name='indices',
shape=[4, 4],
dtype='int32')
values = paddle.static.data(name='values',
shape=[4, 1],
dtype='float32')
dense_shape = [1, 1, 3, 4, 1]
sp_x = sparse.sparse_coo_tensor(indices, values, dense_shape)
weight_shape = [1, 3, 3, 1, 1]
weight = paddle.static.data(name='weight',
shape=weight_shape,
dtype='float32')
bias_shape = [1]
bias = paddle.static.data(name='bias',
shape=bias_shape,
dtype='float32')
out = sparse.nn.functional.conv3d(sp_x,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
data_format="NDHWC")
sp_out = sparse.nn.functional.relu(out)
out_indices = sp_out.indices()
out_values = sp_out.values()
out = sp_out.to_dense()
exe = paddle.static.Executor()
indices_data = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values_data = [[1.0], [2.0], [3.0], [4.0]]
weight_data = np.array([[[[[1], [1], [1]], [[1], [1], [1]],
[[1], [1], [1]]]]]).astype('float32')
weight_data = weight_data.reshape(weight_shape)
bias_data = np.array([1]).astype('float32')
fetch = exe.run(feed={
'indices': indices_data,
'values': values_data,
'weight': weight_data,
'bias': bias_data
},
fetch_list=[out, out_indices, out_values],
return_numpy=True)
correct_out = np.array([[[[[5.0], [11.0]]]]]).astype('float64')
correct_out_values = [[5.0], [11.0]]
assert np.array_equal(correct_out, fetch[0])
assert np.array_equal(correct_out_values, fetch[2])
assert out_indices.dtype == paddle.int32
paddle.disable_static()
if __name__ == "__main__":
unittest.main()
......@@ -14,6 +14,9 @@
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.framework import dygraph_only, core
from paddle import in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
from .unary import cast
__all__ = []
......@@ -254,7 +257,19 @@ def add(x, y, name=None):
"""
if y.dtype != x.dtype:
y = cast(y, None, x.dtype)
return _C_ops.sparse_add(x, y)
if in_dynamic_mode():
return _C_ops.sparse_add(x, y)
else:
op_type = 'sparse_add'
inputs = {'x': x, 'y': y}
helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(x.dtype)
helper.append_op(type=op_type,
inputs=inputs,
outputs={'out': out},
attrs={})
return out
@dygraph_only
......
......@@ -18,6 +18,8 @@ from paddle.fluid.framework import core, dygraph_only
from paddle.fluid.framework import _current_expected_place, _get_paddle_place
from paddle.tensor import to_tensor, max
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from paddle import in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
import numpy as np
......@@ -64,7 +66,6 @@ def _check_indices_dtype(dtype):
)
@dygraph_only
def sparse_coo_tensor(indices,
values,
shape=None,
......@@ -114,53 +115,68 @@ def sparse_coo_tensor(indices,
# values=[1., 2., 3.])
"""
place = _get_place(place)
if in_dynamic_mode():
place = _get_place(place)
if not isinstance(indices, core.eager.Tensor):
indices = to_tensor(indices,
dtype=None,
place=place,
stop_gradient=True)
if not isinstance(values, core.eager.Tensor):
values = to_tensor(values, dtype, place, stop_gradient)
if len(indices.shape) != 2:
raise ValueError("'indices' must be 2-D.")
if not isinstance(indices, core.eager.Tensor):
indices = to_tensor(indices,
dtype=None,
place=place,
stop_gradient=True)
if not isinstance(values, core.eager.Tensor):
values = to_tensor(values, dtype, place, stop_gradient)
if len(indices.shape) != 2:
raise ValueError("'indices' must be 2-D.")
nnz = indices.shape[1]
sparse_dim = indices.shape[0]
nnz = indices.shape[1]
sparse_dim = indices.shape[0]
_check_indices_dtype(indices.dtype)
_check_indices_dtype(indices.dtype)
if nnz != values.shape[0]:
raise ValueError(
"the indices and values must have same number of non-zero, but get {} and {}"
.format(nnz, values.shape[0]))
if nnz != values.shape[0]:
raise ValueError(
"the indices and values must have same number of non-zero, but get {} and {}"
.format(nnz, values.shape[0]))
dense_dim = len(values.shape) - 1
dense_dim = len(values.shape) - 1
if not indices.place._equals(place):
indices = indices._copy_to(place, False)
if not indices.place._equals(place):
indices = indices._copy_to(place, False)
if not values.place._equals(place):
values = values._copy_to(place, False)
values = _handle_dtype(values, dtype)
values.stop_gradient = stop_gradient
if not values.place._equals(place):
values = values._copy_to(place, False)
values = _handle_dtype(values, dtype)
values.stop_gradient = stop_gradient
min_shape = _infer_dense_shape(indices, values)
min_shape = _infer_dense_shape(indices, values)
if shape is None:
shape = min_shape
else:
if shape < min_shape:
raise ValueError(
"the minimun shape required is {}, but get {}".format(
min_shape, shape))
if len(shape) != sparse_dim + dense_dim:
raise ValueError(
"the number of dimensions(len(shape) must be sparse_dim({}) + dense_dim({}), but get {}"
.format(sparse_dim, dense_dim, len(shape)))
if shape is None:
shape = min_shape
else:
if shape < min_shape:
raise ValueError(
"the minimun shape required is {}, but get {}".format(
min_shape, shape))
if len(shape) != sparse_dim + dense_dim:
raise ValueError(
"the number of dimensions(len(shape) must be sparse_dim({}) + dense_dim({}), but get {}"
.format(sparse_dim, dense_dim, len(shape)))
return _C_ops.sparse_sparse_coo_tensor(values, indices, shape)
return _C_ops.sparse_sparse_coo_tensor(values, indices, shape)
else:
op_type = 'sparse_sparse_coo_tensor'
inputs = {'values': values, 'indices': indices}
if shape[0] is None:
shape[0] = -1
attrs = {'dense_shape': shape}
helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(dtype)
helper.append_op(type=op_type,
inputs=inputs,
outputs={'out': out},
attrs=attrs)
return out
#TODO: need to support shape is None
......
......@@ -16,9 +16,10 @@ __all__ = []
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid.framework import dygraph_only
from paddle import in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
@dygraph_only
def relu(x, name=None):
"""
sparse relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
......@@ -45,7 +46,17 @@ def relu(x, name=None):
out = paddle.incubate.sparse.nn.functional.relu(sparse_x)
# [0., 0., 1.]
"""
return _C_ops.sparse_relu(x)
if in_dynamic_mode():
return _C_ops.sparse_relu(x)
else:
op_type = 'sparse_relu'
helper = LayerHelper(op_type)
out = helper.create_sparse_variable_for_type_inference(x.dtype)
helper.append_op(type=op_type,
inputs={'x': x},
outputs={'out': out},
attrs={})
return out
@dygraph_only
......
......@@ -19,8 +19,8 @@ from paddle.fluid.layers.utils import convert_to_list
from paddle.fluid.layers.nn import elementwise_add
from ...creation import sparse_coo_tensor
from ...binary import add
from paddle.tensor import arange
from paddle.nn.functional.conv import _update_padding_nd
from paddle.fluid.layer_helper import LayerHelper
def _conv3d(x,
......@@ -34,7 +34,6 @@ def _conv3d(x,
key=None,
data_format="NDHWC",
name=None):
assert in_dynamic_mode(), "Currently, only support dynamic mode"
assert groups == 1, "Currently, only support groups=1"
dims = 3
......@@ -63,15 +62,41 @@ def _conv3d(x,
padding, padding_algorithm = _update_padding_nd(padding, channel_last, dims)
stride = convert_to_list(stride, dims, 'stride')
dilation = convert_to_list(dilation, dims, 'dilation')
op_type = "conv3d"
pre_bias = _C_ops.sparse_conv3d(x, weight, padding, dilation, stride,
groups, subm,
key if key is not None else "")
if bias is not None:
return add(pre_bias, bias)
if in_dynamic_mode():
pre_bias = _C_ops.sparse_conv3d(x, weight, padding, dilation, stride,
groups, subm,
key if key is not None else "")
if bias is not None:
return add(pre_bias, bias)
else:
return pre_bias
else:
return pre_bias
inputs = {'x': x, 'kernel': weight}
attrs = {
'paddings': padding,
'dilations': dilation,
'strides': stride,
'groups': groups,
'subm': subm,
'key': key
}
op_type = 'sparse_conv3d'
helper = LayerHelper(op_type, **locals())
rulebook = helper.create_variable_for_type_inference(dtype='int32',
stop_gradient=True)
counter = helper.create_variable_for_type_inference(dtype='int32',
stop_gradient=True)
pre_bias = helper.create_sparse_variable_for_type_inference(x.dtype)
outputs = {"out": pre_bias, "rulebook": rulebook, "counter": counter}
helper.append_op(type=op_type,
inputs=inputs,
outputs=outputs,
attrs=attrs)
if bias is not None:
return add(pre_bias, bias)
else:
return pre_bias
def conv3d(x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册