Unverified commit a0d465f8, authored by zyfncg, committed by GitHub

【PTen】Add fill_constant kernel using ScalarArray in pten (#37481)

* add scalar and scalar_array

* remove DenseTensor include from Scalar and ScalarArray

* remove inner header from scalar_array

* refactor the fill_constant method and add some comments

* add fill_constant kernel using ScalarArray

* modify some error prompts

* remove fill_constant kernel with no shape
Parent 3e088aaf
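A minimal sketch of the kernel this commit converges on, as it appears in the CPU hunk further down: the shape now travels as a pten::ScalarArray regardless of whether it originated as the `shape` attribute, a ShapeTensor input, or a ShapeTensorList input, and the kernel resizes the output before filling.

template <typename T>
void FillConstant(const CPUContext& dev_ctx,
                  const ScalarArray& shape,
                  const Scalar& val,
                  DenseTensor* out) {
  // Resolve the output shape from the ScalarArray, then fill with the value.
  out->Resize(paddle::framework::make_ddim(shape.GetData()));
  eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
}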
@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
namespace paddle {
namespace framework {
@@ -1903,26 +1904,59 @@ void OperatorWithKernel::BuildPtenKernelContext(
}
for (size_t i = 0; i < attr_names.size(); ++i) {
auto& attr = Attrs().at(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) {
if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) {
auto attr_iter = Attrs().find(attr_names[i]);
if (attr_iter != Attrs().end()) { // shape is in the attribute
if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
"construct KernelContext.",
attr_names[i]));
}
} else { // shape is in the input
auto& ins_vector = ctx.inputs.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
pt_kernel_context_->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(*ins_vector.front())));
} else { // ShapeTensorList
pt_kernel_context_->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(ins_vector)));
}
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::Scalar))) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
pt_kernel_context_->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
pt_kernel_context_->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
auto attr_iter = Attrs().find(attr_names[i]);
if (attr_iter != Attrs().end()) { // scalar is in the attribute
auto& attr = Attrs().at(attr_names[i]);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
pt_kernel_context_->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
pt_kernel_context_->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext.",
attr_names[i]));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext.",
attr_names[i]));
auto& ins_vector = ctx.inputs.at(attr_names[i]);
pt_kernel_context_->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(*ins_vector.front())));
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = Attrs().at(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
@@ -1949,7 +1983,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
"Unsupported cast op attribute `%s` when construct "
"KernelContext.",
attr_names[i]));
}
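The static-graph dispatch above keys on the attribute type declared by the kernel (attr_defs[i].type_index) and falls back to the op inputs when the shape is absent from the attribute map. A self-contained illustration of the std::type_index dispatch pattern (hypothetical types, not the Paddle API):

#include <cstdint>
#include <iostream>
#include <typeindex>
#include <typeinfo>
#include <vector>

struct ScalarArray {
  std::vector<int64_t> data;
};

// Dispatch on the declared attribute type, mirroring the branches above.
void EmplaceAttr(std::type_index declared, const std::vector<int64_t>& shape) {
  if (declared == std::type_index(typeid(ScalarArray))) {
    ScalarArray arr{shape};  // shape came straight from an attribute
    std::cout << "ScalarArray with " << arr.data.size() << " dims\n";
  } else {
    std::cout << "plain attribute handling\n";  // int/float/etc. branches
  }
}

int main() { EmplaceAttr(std::type_index(typeid(ScalarArray)), {2, 3}); }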
@@ -19,6 +19,7 @@
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/utils/small_vector.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
@@ -385,26 +386,66 @@ static void BuildDygraphPtenKernelContext(
}
for (size_t i = 0; i < attr_names.size(); ++i) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) {
if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) {
if (attrs.find(attr_names[i]) !=
attrs.end()) { // shape is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to VectorTensor when "
"construct KernelContext.",
attr_names[i]));
}
} else { // shape is in the input
auto& ins_vector = ins.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var())));
} else { // ShapeTensorList
std::vector<framework::Variable*> variables;
variables.reserve(ins_vector.size());
for (const auto& var_base : ins_vector) {
variables.push_back(var_base->MutableVar());
}
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarArrayFromVarList(variables)));
}
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(pten::Scalar))) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext in dygraph.",
attr_names[i]));
if (attrs.find(attr_names[i]) != attrs.end() ||
default_attrs.find(attr_names[i]) !=
default_attrs.end()) { // scalar is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
} else { // scalar is in the input
auto& ins_vector = ins.at(attr_names[i]);
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePtenScalarFromVar(ins_vector[0]->Var())));
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
@@ -430,7 +471,7 @@ static void BuildDygraphPtenKernelContext(
// TODO(YuanRisheng) Need support vector<int64_t> attr
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
"Unsupported cast op attribute `%s` when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
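In the dygraph path the shape tensors arrive as VarBase objects, so the ShapeTensorList branch above first collects the underlying Variable pointers before converting. Conceptually, a ShapeTensorList is a list of one-element integer tensors, and the resulting ScalarArray reads one dimension from each; a sketch of that gather with hypothetical plain-pointer inputs:

#include <cstdint>
#include <vector>

// Each pointer stands in for a one-element shape tensor already on the CPU.
std::vector<int64_t> GatherShape(const std::vector<const int64_t*>& dims) {
  std::vector<int64_t> shape;
  shape.reserve(dims.size());
  for (const int64_t* d : dims) {
    shape.push_back(*d);  // one dimension value per tensor
  }
  return shape;
}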
@@ -102,13 +102,23 @@ class FillConstantOp : public framework::OperatorWithKernel {
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
if (!ctx.HasInput("ShapeTensor") &&
ctx.MultiInput<framework::Tensor>("ShapeTensorList").empty() &&
!ctx.HasInput("ValueTensor") &&
!ctx.OutputVar("Out")->IsType<framework::SelectedRows>()) {
std::string shape;
if (ctx.HasInput("ShapeTensor")) {
shape = "ShapeTensor";
} else if (ctx.MultiInput<framework::Tensor>("ShapeTensorList").size()) {
shape = "ShapeTensorList";
} else {
shape = "shape";
}
std::string value;
if (ctx.HasInput("ValueTensor")) {
value = "ValueTensor";
} else {
const auto& str_value = ctx.Attr<std::string>("str_value");
std::string value = str_value.empty() ? "value" : "str_value";
return framework::KernelSignature("fill_constant.scalar", {}, {value},
value = str_value.empty() ? "value" : "str_value";
}
if (!ctx.OutputVar("Out")->IsType<framework::SelectedRows>()) {
return framework::KernelSignature("fill_constant", {}, {shape, value},
{"Out"});
}
return framework::KernelSignature("fill_constant.unregistered", {}, {}, {});
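The branching above amounts to a small pure selection function; restated as a sketch (the booleans stand in for the ctx.HasInput/MultiInput queries):

#include <string>
#include <utility>

std::pair<std::string, std::string> SelectFillConstantSlots(
    bool has_shape_tensor, bool has_shape_tensor_list, bool has_value_tensor,
    const std::string& str_value) {
  // Shape slot: tensor input wins over tensor-list input, else the attribute.
  std::string shape = has_shape_tensor
                          ? "ShapeTensor"
                          : has_shape_tensor_list ? "ShapeTensorList" : "shape";
  // Value slot: tensor input wins, else str_value when set, else value.
  std::string value = has_value_tensor
                          ? "ValueTensor"
                          : str_value.empty() ? "value" : "str_value";
  // Feeds KernelSignature("fill_constant", {}, {shape, value}, {"Out"}).
  return {shape, value};
}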
@@ -97,6 +97,168 @@ std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
}
}
pten::Scalar MakePtenScalar(const paddle::framework::LoDTensor& src) {
PADDLE_ENFORCE_EQ(src.numel(),
1,
paddle::platform::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, "
"but now Tensor has %d element.",
src.numel()));
switch (src.type()) {
case paddle::framework::proto::VarType::FP32:
return {src.template data<float>()[0]};
case paddle::framework::proto::VarType::FP64:
return {src.template data<double>()[0]};
case paddle::framework::proto::VarType::FP16:
return {src.template data<float16>()[0]};
case paddle::framework::proto::VarType::BF16:
return {src.template data<bfloat16>()[0]};
case paddle::framework::proto::VarType::INT32:
return {src.template data<int32_t>()[0]};
case paddle::framework::proto::VarType::INT64:
return {src.template data<int64_t>()[0]};
case paddle::framework::proto::VarType::INT16:
return {src.template data<int16_t>()[0]};
case paddle::framework::proto::VarType::INT8:
return {src.template data<int8_t>()[0]};
case paddle::framework::proto::VarType::UINT8:
return {src.template data<uint8_t>()[0]};
case paddle::framework::proto::VarType::BOOL:
return {src.template data<bool>()[0]};
case paddle::framework::proto::VarType::COMPLEX64:
return {src.template data<complex64>()[0]};
case paddle::framework::proto::VarType::COMPLEX128:
return {src.template data<complex128>()[0]};
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. Don't support casting a %d LoDTensor to Scalar.",
src.type()));
}
}
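// Usage sketch (illustrative, not part of the change; hypothetical helper
// name, assuming the usual LoDTensor Resize/mutable_data API): build a
// one-element float tensor and convert it with MakePtenScalar above.
inline pten::Scalar MakePtenScalarExample() {
  framework::LoDTensor t;
  t.Resize(framework::make_ddim({1}));
  *t.mutable_data<float>(platform::CPUPlace()) = 3.5f;
  return MakePtenScalar(t);  // holds 3.5f; read back via .to<float>()
}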
pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU);
if (variable.IsType<framework::LoDTensor>()) {
const auto& tensor = variable.Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
return MakePtenScalar(tmp_tensor);
} else {
return MakePtenScalar(tensor);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to Scalar when call pt "
"kernel.",
framework::ToTypeName(variable.Type())));
}
}
pten::ScalarArray MakePtenScalarArray(const paddle::framework::LoDTensor& src) {
if (src.type() == paddle::framework::proto::VarType::INT64) {
return {src.data<int64_t>(), src.numel()};
} else if (src.type() == paddle::framework::proto::VarType::INT32) {
return {src.data<int32_t>(), src.numel()};
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to ScalarArray, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
src.type()));
}
}
pten::ScalarArray MakePtenScalarArrayFromVar(
const framework::Variable& variable) {
auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU);
if (variable.IsType<framework::LoDTensor>()) {
const auto& tensor = variable.Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
return MakePtenScalarArray(tmp_tensor);
} else {
return MakePtenScalarArray(tensor);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to ScalarArray when call pt "
"kernel.",
framework::ToTypeName(variable.Type())));
}
}
pten::ScalarArray MakePtenScalarArrayFromVarList(
const std::vector<framework::Variable*>& variable_list) {
if (variable_list.size() == 0) {
return pten::ScalarArray();
}
auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU);
paddle::framework::proto::VarType::Type data_type;
auto* first_var = variable_list.front();
if (first_var->IsType<framework::LoDTensor>()) {
const auto& tensor = first_var->Get<framework::LoDTensor>();
data_type = tensor.type();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(first_var->Type())));
}
std::vector<int64_t> vector_data;
vector_data.reserve(variable_list.size());
if (data_type == paddle::framework::proto::VarType::INT64) {
for (auto* var : variable_list) {
if (var->IsType<framework::LoDTensor>()) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
vector_data.push_back(*tmp_tensor.data<int64_t>());
} else {
vector_data.push_back(*tensor.data<int64_t>());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
}
} else if (data_type == paddle::framework::proto::VarType::INT32) {
for (auto* var : variable_list) {
if (var->IsType<framework::LoDTensor>()) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
vector_data.push_back(*tmp_tensor.data<int32_t>());
} else {
vector_data.push_back(*tensor.data<int32_t>());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
}
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to VectorTensor, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
data_type));
}
return {vector_data};
}
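// Companion sketch (illustrative, hypothetical helper name): two one-element
// int64 tensors convert to a ScalarArray holding {2, 3}. The data is copied
// out of the tensors, so the locals may safely go out of scope.
inline pten::ScalarArray MakePtenScalarArrayExample() {
  framework::Variable v0;
  framework::Variable v1;
  auto* t0 = v0.GetMutable<framework::LoDTensor>();
  t0->Resize(framework::make_ddim({1}));
  *t0->mutable_data<int64_t>(platform::CPUPlace()) = 2;
  auto* t1 = v1.GetMutable<framework::LoDTensor>();
  t1->Resize(framework::make_ddim({1}));
  *t1->mutable_data<int64_t>(platform::CPUPlace()) = 3;
  return MakePtenScalarArrayFromVarList({&v0, &v1});  // GetData() == {2, 3}
}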
std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
const framework::Variable& variable, const pten::TensorArgDef& arg_def) {
auto expected_place = pten::TransToFluidPlace(arg_def.backend);
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_factory.h"
@@ -34,6 +36,18 @@ std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
const paddle::framework::LoDTensor& src);
pten::Scalar MakePtenScalar(const paddle::framework::LoDTensor& src);
pten::ScalarArray MakePtenScalarArray(const paddle::framework::LoDTensor& src);
pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable);
pten::ScalarArray MakePtenScalarArrayFromVar(
const framework::Variable& variable);
pten::ScalarArray MakePtenScalarArrayFromVarList(
const std::vector<framework::Variable*>& variable_list);
std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
const framework::Variable& variable, const pten::TensorArgDef& arg_def);
@@ -118,7 +118,7 @@ class ScalarArrayBase {
/// \brief Assign the data_ from const data pointer value of type T.
template <typename TYPE>
void AssignData(const TYPE* value_data, int64_t n) {
if (value_data) {
if (value_data || n == 0) {
array_.reserve(n);
for (auto i = 0; i < n; ++i) {
array_.push_back(static_cast<int64_t>(value_data[i]));
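The relaxed guard above accepts n == 0 even when value_data is null, so an empty ShapeTensorList yields a valid empty ScalarArray (the copy loop simply runs zero times). A self-contained analogue of the guarded copy (a sketch, not the class itself):

#include <cstdint>
#include <vector>

std::vector<int64_t> CopyChecked(const int64_t* value_data, int64_t n) {
  std::vector<int64_t> out;
  // n == 0 with a null pointer is now valid and yields an empty vector.
  if (value_data || n == 0) {
    out.reserve(n);
    for (int64_t i = 0; i < n; ++i) {
      out.push_back(value_data[i]);
    }
  }
  return out;
}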
@@ -52,16 +52,9 @@ void FillAnyLike(const CPUContext& dev_ctx,
template <typename T>
void FillConstant(const CPUContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
}
template <typename T>
void FillConstantDynamicShape(const CPUContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
}
@@ -81,26 +74,10 @@ PT_REGISTER_KERNEL("fill_any_like",
bool,
paddle::platform::float16) {}
PT_REGISTER_KERNEL("fill_constant.scalar",
CPU,
ANY,
pten::FillConstant,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL("fill_constant",
CPU,
ANY,
pten::FillConstantDynamicShape,
pten::FillConstant,
float,
double,
uint8_t,
@@ -31,13 +31,8 @@ void FillAnyLike(const CPUContext& dev_ctx,
template <typename T>
void FillConstant(const CPUContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out);
template <typename T>
void FillConstantDynamicShape(const CPUContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out);
} // namespace pten
@@ -53,16 +53,9 @@ void FillAnyLike(const CUDAContext& dev_ctx,
template <typename T>
void FillConstant(const CUDAContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<T>());
}
template <typename T>
void FillConstantDynamicShape(const CUDAContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<T>());
}
@@ -82,25 +75,10 @@ PT_REGISTER_KERNEL("fill_any_like",
bool,
paddle::platform::float16) {}
PT_REGISTER_KERNEL("fill_constant.scalar",
CUDA,
ANY,
pten::FillConstant,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL("fill_constant",
CUDA,
ANY,
pten::FillConstantDynamicShape,
pten::FillConstant,
float,
double,
uint8_t,
@@ -34,15 +34,10 @@ void FillAnyLike(const CUDAContext& dev_ctx,
template <typename T>
void FillConstant(const CUDAContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out);
template <typename T>
void FillConstantDynamicShape(const CUDAContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out);
} // namespace pten
#endif