Unverified commit a0d465f8, authored by zyfncg, committed by GitHub

【PTen】Add fill_constant kernel using ScalarArray in pten (#37481)

* add scalar and scalar_array

* remove DenseTensor include from Scalar and ScalarArray

* remove inner header from scalar_array

* refactor the fill_constant method and add some comments

* add fill_constant kernel using ScalarArray

* modify some prompts

* remove fill_constant kernel with no shape
Parent 3e088aaf
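The heart of the change: fill_constant's output shape now travels as a pten::ScalarArray, a small host-side container of int64 dimensions that can be built either from a plain attribute vector or from shape-carrying input tensors. A minimal standalone sketch of that idea (hypothetical MiniScalarArray and FillConstantSketch names, not the real pten classes):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Simplified stand-in for pten::ScalarArray: it only stores dims on the host,
// regardless of which source they came from.
class MiniScalarArray {
 public:
  MiniScalarArray() = default;
  explicit MiniScalarArray(std::vector<std::int64_t> dims)
      : dims_(std::move(dims)) {}
  // Build from raw pointer + length, as when reading a shape tensor's buffer.
  MiniScalarArray(const std::int64_t* data, std::int64_t n)
      : dims_(data, data + n) {}
  const std::vector<std::int64_t>& GetData() const { return dims_; }

 private:
  std::vector<std::int64_t> dims_;
};

// Sketch of the merged kernel's behavior: resize from the ScalarArray, then
// fill every element (the real kernel does this with eigen::fill).
template <typename T>
void FillConstantSketch(const MiniScalarArray& shape, T val,
                        std::vector<T>* out) {
  std::int64_t numel = 1;
  for (std::int64_t d : shape.GetData()) numel *= d;
  out->assign(static_cast<std::size_t>(numel), val);
}

With this, the previous pair of kernels (one that ignored shape and one "dynamic shape" variant) collapses into a single fill_constant kernel, which is exactly what the diff below does.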
@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"
+#include "paddle/pten/common/scalar_array.h"

namespace paddle {
namespace framework {
@@ -1903,26 +1904,59 @@ void OperatorWithKernel::BuildPtenKernelContext(
  }

  for (size_t i = 0; i < attr_names.size(); ++i) {
-    auto& attr = Attrs().at(attr_names[i]);
-    if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) {
+    if (attr_defs[i].type_index ==
+        std::type_index(typeid(pten::ScalarArray))) {
+      auto attr_iter = Attrs().find(attr_names[i]);
+      if (attr_iter != Attrs().end()) {  // shape is in the attribute
+        if (std::type_index(attr_iter->second.type()) ==
+            std::type_index(typeid(std::vector<int64_t>))) {
+          pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
+              BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported cast op attribute `%s` to ScalarArray when "
+              "construct KernelContext.",
+              attr_names[i]));
+        }
+      } else {  // shape is in the input
+        auto& ins_vector = ctx.inputs.at(attr_names[i]);
+        if (ins_vector.size() == 1) {  // ShapeTensor
+          pt_kernel_context_->EmplaceBackAttr(std::move(
+              experimental::MakePtenScalarArrayFromVar(*ins_vector.front())));
+        } else {  // ShapeTensorList
+          pt_kernel_context_->EmplaceBackAttr(std::move(
+              experimental::MakePtenScalarArrayFromVarList(ins_vector)));
+        }
+      }
+    } else if (attr_defs[i].type_index ==
+               std::type_index(typeid(pten::Scalar))) {
      // TODO(chenweihang): support other attrs later
      // TODO(zhangyunfei): Scalar should hold scaler type, and we should check
      // attribtue type by attr_defs
-      if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
-        pt_kernel_context_->EmplaceBackAttr(
-            std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
-      } else if (std::type_index(attr.type()) ==
-                 std::type_index(typeid(std::string))) {
-        pt_kernel_context_->EmplaceBackAttr(
-            std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "unsupported cast op attribute `%s` to Scalar when construct "
-            "KernelContext.",
-            attr_names[i]));
+      auto attr_iter = Attrs().find(attr_names[i]);
+      if (attr_iter != Attrs().end()) {  // scalar is in the attribute
+        auto& attr = Attrs().at(attr_names[i]);
+        if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
+          pt_kernel_context_->EmplaceBackAttr(
+              std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
+        } else if (std::type_index(attr.type()) ==
+                   std::type_index(typeid(std::string))) {
+          pt_kernel_context_->EmplaceBackAttr(
+              std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported cast op attribute `%s` to Scalar when construct "
+              "KernelContext.",
+              attr_names[i]));
+        }
+      } else {
+        auto& ins_vector = ctx.inputs.at(attr_names[i]);
+        pt_kernel_context_->EmplaceBackAttr(std::move(
+            experimental::MakePtenScalarFromVar(*ins_vector.front())));
      }
    } else {
      // TODO(chenweihang): support other attrs later
+      auto& attr = Attrs().at(attr_names[i]);
      if (attr_defs[i].type_index == std::type_index(typeid(int))) {
        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
  ...
@@ -1949,7 +1983,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
-            "unsupported cast op attribute `%s` when construct "
+            "Unsupported cast op attribute `%s` when construct "
            "KernelContext.",
            attr_names[i]));
      }
  ...
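Worth noting the resolution order this block implements: an attribute wins if present; otherwise a single input is treated as a ShapeTensor holding all dims; otherwise the input list is treated as a ShapeTensorList with one tensor per dim. A standalone model of that precedence (mock containers; an assumed simplification, since the real code works on framework::Variable):

#include <cstdint>
#include <map>
#include <string>
#include <vector>

using Shape = std::vector<std::int64_t>;

// Mock resolution of where the shape comes from, mirroring the branches above.
Shape ResolveShape(const std::map<std::string, Shape>& attrs,
                   const std::map<std::string, std::vector<Shape>>& inputs,
                   const std::string& name) {
  auto it = attrs.find(name);
  if (it != attrs.end()) {
    return it->second;  // shape is in the attribute
  }
  const auto& ins = inputs.at(name);
  if (ins.size() == 1) {  // ShapeTensor: one tensor carries every dim
    return ins.front();
  }
  Shape shape;  // ShapeTensorList: one single-element tensor per dim
  shape.reserve(ins.size());
  for (const auto& dim_tensor : ins) {
    shape.push_back(dim_tensor.front());
  }
  return shape;
}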
@@ -19,6 +19,7 @@
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/pten/common/scalar.h"
+#include "paddle/pten/common/scalar_array.h"
#include "paddle/utils/small_vector.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
@@ -385,26 +386,66 @@ static void BuildDygraphPtenKernelContext(
  }

  for (size_t i = 0; i < attr_names.size(); ++i) {
-    auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
-    if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) {
+    if (attr_defs[i].type_index ==
+        std::type_index(typeid(pten::ScalarArray))) {
+      if (attrs.find(attr_names[i]) !=
+          attrs.end()) {  // shape is in the attribute
+        auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
+        if (std::type_index(attr.type()) ==
+            std::type_index(typeid(std::vector<int64_t>))) {
+          kernel_ctx->EmplaceBackAttr(std::move(
+              pten::ScalarArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported cast op attribute `%s` to VectorTensor when "
+              "construct KernelContext.",
+              attr_names[i]));
+        }
+      } else {  // shape is in the input
+        auto& ins_vector = ins.at(attr_names[i]);
+        if (ins_vector.size() == 1) {  // ShapeTensor
+          kernel_ctx->EmplaceBackAttr(std::move(
+              experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var())));
+        } else {  // ShapeTensorList
+          std::vector<framework::Variable*> variables;
+          variables.reserve(ins_vector.size());
+          for (const auto& var_base : ins_vector) {
+            variables.push_back(var_base->MutableVar());
+          }
+          kernel_ctx->EmplaceBackAttr(std::move(
+              experimental::MakePtenScalarArrayFromVarList(variables)));
+        }
+      }
+    } else if (attr_defs[i].type_index ==
+               std::type_index(typeid(pten::Scalar))) {
      // TODO(chenweihang): support other attrs later
      // TODO(zhangyunfei): Scalar should hold scaler type, and we should check
      // attribtue type by attr_defs
-      if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
-        kernel_ctx->EmplaceBackAttr(
-            std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
-      } else if (std::type_index(attr.type()) ==
-                 std::type_index(typeid(std::string))) {
-        kernel_ctx->EmplaceBackAttr(
-            std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "unsupported cast op attribute `%s` to Scalar when construct "
-            "KernelContext in dygraph.",
-            attr_names[i]));
+      if (attrs.find(attr_names[i]) != attrs.end() ||
+          default_attrs.find(attr_names[i]) !=
+              default_attrs.end()) {  // scalar is in the attribute
+        auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
+        if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
+          kernel_ctx->EmplaceBackAttr(
+              std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
+        } else if (std::type_index(attr.type()) ==
+                   std::type_index(typeid(std::string))) {
+          kernel_ctx->EmplaceBackAttr(
+              std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported cast op attribute `%s` to Scalar when construct "
+              "KernelContext in dygraph.",
+              attr_names[i]));
+        }
+      } else {  // scalar is in the input
+        auto& ins_vector = ins.at(attr_names[i]);
+        kernel_ctx->EmplaceBackAttr(std::move(
+            experimental::MakePtenScalarFromVar(ins_vector[0]->Var())));
      }
    } else {
      // TODO(chenweihang): support other attrs later
+      auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
      if (attr_defs[i].type_index == std::type_index(typeid(int))) {
        kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
      } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
  ...
@@ -430,7 +471,7 @@ static void BuildDygraphPtenKernelContext(
      // TODO(YuanRisheng) Need support vector<int64_t> attr
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
-          "unsupported cast op attribute `%s` when construct "
+          "Unsupported cast op attribute `%s` when construct "
          "KernelContext in dygraph.",
          attr_names[i]));
    }
  ...
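One dygraph-specific wrinkle: for Scalar attributes the lookup consults both attrs and default_attrs before falling back to the inputs, which is what GetAttr abstracts. A sketch of that fallback on mock maps (an assumed simplification of the real helper, which works on framework attribute maps):

#include <map>
#include <string>

// Mock of the attrs -> default_attrs fallback used by the dygraph path.
template <typename T>
const T& GetAttrSketch(const std::map<std::string, T>& attrs,
                       const std::map<std::string, T>& default_attrs,
                       const std::string& name) {
  auto it = attrs.find(name);
  if (it != attrs.end()) {
    return it->second;  // explicitly set on this op call
  }
  return default_attrs.at(name);  // falls back to the op's declared default
}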
@@ -102,13 +102,23 @@ class FillConstantOp : public framework::OperatorWithKernel {
  framework::KernelSignature GetExpectedPtenKernelArgs(
      const framework::ExecutionContext& ctx) const override {
-    if (!ctx.HasInput("ShapeTensor") &&
-        ctx.MultiInput<framework::Tensor>("ShapeTensorList").empty() &&
-        !ctx.HasInput("ValueTensor") &&
-        !ctx.OutputVar("Out")->IsType<framework::SelectedRows>()) {
+    std::string shape;
+    if (ctx.HasInput("ShapeTensor")) {
+      shape = "ShapeTensor";
+    } else if (ctx.MultiInput<framework::Tensor>("ShapeTensorList").size()) {
+      shape = "ShapeTensorList";
+    } else {
+      shape = "shape";
+    }
+    std::string value;
+    if (ctx.HasInput("ValueTensor")) {
+      value = "ValueTensor";
+    } else {
      const auto& str_value = ctx.Attr<std::string>("str_value");
-      std::string value = str_value.empty() ? "value" : "str_value";
-      return framework::KernelSignature("fill_constant.scalar", {}, {value},
+      value = str_value.empty() ? "value" : "str_value";
+    }
+    if (!ctx.OutputVar("Out")->IsType<framework::SelectedRows>()) {
+      return framework::KernelSignature("fill_constant", {}, {shape, value},
                                        {"Out"});
    }
    return framework::KernelSignature("fill_constant.unregistered", {}, {}, {});
  ...
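The net effect: a single "fill_constant" kernel signature now covers every shape/value source, and the chosen argument names simply record where the data lives. An illustrative mapping derived from the branches above (not an exhaustive enumeration):

// shape attribute    + value attribute    -> {"shape",           "value"}
// ShapeTensor input  + str_value set      -> {"ShapeTensor",     "str_value"}
// ShapeTensorList    + ValueTensor input  -> {"ShapeTensorList", "ValueTensor"}
// SelectedRows "Out"                      -> "fill_constant.unregistered"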
@@ -97,6 +97,168 @@ std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
  }
}

pten::Scalar MakePtenScalar(const paddle::framework::LoDTensor& src) {
PADDLE_ENFORCE_EQ(src.numel(),
1,
paddle::platform::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, "
"but now Tensor has %d element.",
src.numel()));
switch (src.type()) {
case paddle::framework::proto::VarType::FP32:
return {src.template data<float>()[0]};
case paddle::framework::proto::VarType::FP64:
return {src.template data<double>()[0]};
case paddle::framework::proto::VarType::FP16:
return {src.template data<float16>()[0]};
case paddle::framework::proto::VarType::BF16:
return {src.template data<bfloat16>()[0]};
case paddle::framework::proto::VarType::INT32:
return {src.template data<int32_t>()[0]};
case paddle::framework::proto::VarType::INT64:
return {src.template data<int64_t>()[0]};
case paddle::framework::proto::VarType::INT16:
return {src.template data<int16_t>()[0]};
case paddle::framework::proto::VarType::INT8:
return {src.template data<int8_t>()[0]};
case paddle::framework::proto::VarType::UINT8:
return {src.template data<uint8_t>()[0]};
case paddle::framework::proto::VarType::BOOL:
return {src.template data<bool>()[0]};
case paddle::framework::proto::VarType::COMPLEX64:
return {src.template data<complex64>()[0]};
case paddle::framework::proto::VarType::COMPLEX128:
return {src.template data<complex128>()[0]};
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. Don't support casting a %d LoDTensor to Scalar.",
src.type()));
}
}
pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU);
if (variable.IsType<framework::LoDTensor>()) {
const auto& tensor = variable.Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
return MakePtenScalar(tmp_tensor);
} else {
return MakePtenScalar(tensor);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to Scalar when call pt "
"kernel.",
framework::ToTypeName(variable.Type())));
}
}
pten::ScalarArray MakePtenScalarArray(const paddle::framework::LoDTensor& src) {
if (src.type() == paddle::framework::proto::VarType::INT64) {
return {src.data<int64_t>(), src.numel()};
} else if (src.type() == paddle::framework::proto::VarType::INT32) {
return {src.data<int32_t>(), src.numel()};
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to ScalarArray, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
src.type()));
}
}
pten::ScalarArray MakePtenScalarArrayFromVar(
const framework::Variable& variable) {
auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU);
if (variable.IsType<framework::LoDTensor>()) {
const auto& tensor = variable.Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
return MakePtenScalarArray(tmp_tensor);
} else {
return MakePtenScalarArray(tensor);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to ScalarArray when call pt "
"kernel.",
framework::ToTypeName(variable.Type())));
}
}
pten::ScalarArray MakePtenScalarArrayFromVarList(
const std::vector<framework::Variable*>& variable_list) {
if (variable_list.size() == 0) {
return pten::ScalarArray();
}
auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU);
paddle::framework::proto::VarType::Type data_type;
auto* first_var = variable_list.front();
if (first_var->IsType<framework::LoDTensor>()) {
const auto& tensor = first_var->Get<framework::LoDTensor>();
data_type = tensor.type();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(first_var->Type())));
}
std::vector<int64_t> vector_data;
vector_data.reserve(variable_list.size());
if (data_type == paddle::framework::proto::VarType::INT64) {
for (auto* var : variable_list) {
if (var->IsType<framework::LoDTensor>()) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
vector_data.push_back(*tmp_tensor.data<int64_t>());
} else {
vector_data.push_back(*tensor.data<int64_t>());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
}
} else if (data_type == paddle::framework::proto::VarType::INT32) {
for (auto* var : variable_list) {
if (var->IsType<framework::LoDTensor>()) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
vector_data.push_back(*tmp_tensor.data<int32_t>());
} else {
vector_data.push_back(*tensor.data<int32_t>());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
}
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to VectorTensor, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
data_type));
}
return {vector_data};
}
std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
    const framework::Variable& variable, const pten::TensorArgDef& arg_def) {
  auto expected_place = pten::TransToFluidPlace(arg_def.backend);
  ...
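Two contracts in these helpers deserve emphasis: MakePtenScalar insists the source tensor hold exactly one element, and MakePtenScalarArrayFromVarList takes its element type from the first tensor in the list (int32 or int64), reading every subsequent tensor the same way. A standalone model of both (mock MiniTensor type; the real code additionally syncs non-CPU tensors to host with TensorCopySync before reading):

#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

struct MiniTensor {  // hypothetical stand-in for a host-side LoDTensor
  bool is_int64 = true;
  std::vector<std::int64_t> data;
  std::int64_t numel() const { return static_cast<std::int64_t>(data.size()); }
};

// Scalar conversion requires exactly one element, like MakePtenScalar.
std::int64_t ToScalarValue(const MiniTensor& t) {
  assert(t.numel() == 1 && "Scalar only supports tensors with 1 element");
  return t.data.front();
}

// List conversion mirrors MakePtenScalarArrayFromVarList: an empty list yields
// an empty array; otherwise the first tensor's dtype decides how all are read.
std::vector<std::int64_t> ToShape(const std::vector<MiniTensor>& list) {
  if (list.empty()) return {};
  const bool use_i64 = list.front().is_int64;
  std::vector<std::int64_t> dims;
  dims.reserve(list.size());
  for (const auto& t : list) {
    if (t.is_int64 != use_i64) {
      // The real code would fail inside tensor.data<T>() on a dtype mismatch;
      // the sketch makes that check explicit.
      throw std::invalid_argument("mixed dtypes in shape tensor list");
    }
    dims.push_back(ToScalarValue(t));
  }
  return dims;
}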
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/api/lib/utils/storage.h"
+#include "paddle/pten/common/scalar.h"
+#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_factory.h"
@@ -34,6 +36,18 @@ std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
    const paddle::framework::LoDTensor& src);

pten::Scalar MakePtenScalar(const paddle::framework::LoDTensor& src);
pten::ScalarArray MakePtenScalarArray(const paddle::framework::LoDTensor& src);
pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable);
pten::ScalarArray MakePtenScalarArrayFromVar(
const framework::Variable& variable);
pten::ScalarArray MakePtenScalarArrayFromVarList(
const std::vector<framework::Variable*>& variable_list);
std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
    const framework::Variable& variable, const pten::TensorArgDef& arg_def);
...
@@ -118,7 +118,7 @@ class ScalarArrayBase {
  /// \brief Assign the data_ from const data pointer value of type T.
  template <typename TYPE>
  void AssignData(const TYPE* value_data, int64_t n) {
-    if (value_data) {
+    if (value_data || n == 0) {
      array_.reserve(n);
      for (auto i = 0; i < n; ++i) {
        array_.push_back(static_cast<int64_t>(value_data[i]));
  ...
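This one-line guard change is what makes empty shapes work end to end: with n == 0 the data pointer may legitimately be null (an empty ShapeTensorList produces a default ScalarArray), and AssignData should still run, since the loop body never executes. A standalone rendering of the fixed logic (the member vector is passed explicitly here purely for illustration):

#include <cstdint>
#include <vector>

template <typename TYPE>
void AssignDataSketch(std::vector<std::int64_t>* array, const TYPE* value_data,
                      std::int64_t n) {
  // Accept a null pointer when n == 0: reserve(0) is a no-op and the loop
  // never dereferences value_data.
  if (value_data || n == 0) {
    array->reserve(n);
    for (std::int64_t i = 0; i < n; ++i) {
      array->push_back(static_cast<std::int64_t>(value_data[i]));
    }
  }
}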
@@ -52,16 +52,9 @@ void FillAnyLike(const CPUContext& dev_ctx,
template <typename T>
void FillConstant(const CPUContext& dev_ctx,
+                 const ScalarArray& shape,
                  const Scalar& val,
                  DenseTensor* out) {
-  eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
-}
-
-template <typename T>
-void FillConstantDynamicShape(const CPUContext& dev_ctx,
-                              const ScalarArray& shape,
-                              const Scalar& val,
-                              DenseTensor* out) {
  out->Resize(paddle::framework::make_ddim(shape.GetData()));
  eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
}
@@ -81,26 +74,10 @@ PT_REGISTER_KERNEL("fill_any_like",
                   bool,
                   paddle::platform::float16) {}

-PT_REGISTER_KERNEL("fill_constant.scalar",
-                   CPU,
-                   ANY,
-                   pten::FillConstant,
-                   float,
-                   double,
-                   uint8_t,
-                   int16_t,
-                   int,
-                   int64_t,
-                   bool,
-                   paddle::platform::float16,
-                   paddle::platform::bfloat16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
-
PT_REGISTER_KERNEL("fill_constant",
                   CPU,
                   ANY,
-                   pten::FillConstantDynamicShape,
+                   pten::FillConstant,
                   float,
                   double,
                   uint8_t,
  ...
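After the registration swap, callers reach one kernel for every variant. A usage sketch of the new contract (assumes an initialized CPU context and an allocated output DenseTensor; abbreviated, not a complete program):

// std::vector<int64_t> dims = {2, 3};
// pten::FillConstant<float>(dev_ctx,
//                           pten::ScalarArray(dims),  // target shape
//                           pten::Scalar(1.0f),       // fill value
//                           &out);
// // out is resized to {2, 3} and every element is set to 1.0f.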
@@ -31,13 +31,8 @@ void FillAnyLike(const CPUContext& dev_ctx,
template <typename T>
void FillConstant(const CPUContext& dev_ctx,
+                 const ScalarArray& shape,
                  const Scalar& val,
                  DenseTensor* out);

-template <typename T>
-void FillConstantDynamicShape(const CPUContext& dev_ctx,
-                              const ScalarArray& shape,
-                              const Scalar& val,
-                              DenseTensor* out);
-
}  // namespace pten
@@ -53,16 +53,9 @@ void FillAnyLike(const CUDAContext& dev_ctx,
template <typename T>
void FillConstant(const CUDAContext& dev_ctx,
+                 const ScalarArray& shape,
                  const Scalar& val,
                  DenseTensor* out) {
-  eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<T>());
-}
-
-template <typename T>
-void FillConstantDynamicShape(const CUDAContext& dev_ctx,
-                              const ScalarArray& shape,
-                              const Scalar& val,
-                              DenseTensor* out) {
  out->Resize(paddle::framework::make_ddim(shape.GetData()));
  eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<T>());
}
@@ -82,25 +75,10 @@ PT_REGISTER_KERNEL("fill_any_like",
                   bool,
                   paddle::platform::float16) {}

-PT_REGISTER_KERNEL("fill_constant.scalar",
-                   CUDA,
-                   ANY,
-                   pten::FillConstant,
-                   float,
-                   double,
-                   uint8_t,
-                   int16_t,
-                   int,
-                   int64_t,
-                   bool,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {}
-
PT_REGISTER_KERNEL("fill_constant",
                   CUDA,
                   ANY,
-                   pten::FillConstantDynamicShape,
+                   pten::FillConstant,
                   float,
                   double,
                   uint8_t,
  ...
@@ -34,15 +34,10 @@ void FillAnyLike(const CUDAContext& dev_ctx,
template <typename T>
void FillConstant(const CUDAContext& dev_ctx,
+                 const ScalarArray& shape,
                  const Scalar& val,
                  DenseTensor* out);

-template <typename T>
-void FillConstantDynamicShape(const CUDAContext& dev_ctx,
-                              const ScalarArray& shape,
-                              const Scalar& val,
-                              DenseTensor* out);
-
}  // namespace pten

#endif