From a0d465f80fa29e347331a7600bdff05ed18d1f2f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 25 Nov 2021 19:58:39 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PTen=E3=80=91Add=20fill=5Fconstant=20k?= =?UTF-8?q?ernel=20using=20ScalarArray=20in=20pten=20(#37481)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add scalar and scalar_array * remove DenseTensor include from Scalar and ScalarArray * remove inner header from scalar_array * refactor the method of fill_constant and add some comment * add fill_constant kernel using ScalarArray * modify some prompt * remove fill_constant kernel with no shape --- paddle/fluid/framework/operator.cc | 62 +++++-- paddle/fluid/imperative/prepared_operator.cc | 71 ++++++-- paddle/fluid/operators/fill_constant_op.cc | 22 ++- paddle/pten/api/lib/utils/tensor_utils.cc | 162 +++++++++++++++++++ paddle/pten/api/lib/utils/tensor_utils.h | 14 ++ paddle/pten/common/scalar_array.h | 2 +- paddle/pten/kernels/cpu/creation.cc | 27 +--- paddle/pten/kernels/cpu/creation.h | 7 +- paddle/pten/kernels/cuda/creation.cu | 26 +-- paddle/pten/kernels/cuda/creation.h | 7 +- 10 files changed, 303 insertions(+), 97 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 1a60acf49a4..4b1b0d4f05c 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -30,6 +30,7 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" namespace paddle { namespace framework { @@ -1903,26 +1904,59 @@ void OperatorWithKernel::BuildPtenKernelContext( } for (size_t i = 0; i < attr_names.size(); ++i) { - auto& attr = Attrs().at(attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) { + if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) { + auto attr_iter = Attrs().find(attr_names[i]); + if (attr_iter != Attrs().end()) { // shape is in the attribute + if (std::type_index(attr_iter->second.type()) == + std::type_index(typeid(std::vector))) { + pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray( + BOOST_GET_CONST(std::vector, attr_iter->second)))); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to ScalarArray when " + "construct KernelContext.", + attr_names[i])); + } + } else { // shape is in the input + auto& ins_vector = ctx.inputs.at(attr_names[i]); + if (ins_vector.size() == 1) { // ShapeTensor + pt_kernel_context_->EmplaceBackAttr(std::move( + experimental::MakePtenScalarArrayFromVar(*ins_vector.front()))); + } else { // ShapeTensorList + pt_kernel_context_->EmplaceBackAttr(std::move( + experimental::MakePtenScalarArrayFromVarList(ins_vector))); + } + } + } else if (attr_defs[i].type_index == + std::type_index(typeid(pten::Scalar))) { // TODO(chenweihang): support other attrs later // TODO(zhangyunfei): Scalar should hold scaler type, and we should check // attribtue type by attr_defs - if (std::type_index(attr.type()) == std::type_index(typeid(float))) { - pt_kernel_context_->EmplaceBackAttr( - std::move(pten::Scalar(BOOST_GET_CONST(float, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::string))) { - pt_kernel_context_->EmplaceBackAttr( - std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); + auto attr_iter = Attrs().find(attr_names[i]); + if (attr_iter != Attrs().end()) { // scalar is in the attribute + auto& attr = Attrs().at(attr_names[i]); + if (std::type_index(attr.type()) == std::type_index(typeid(float))) { + pt_kernel_context_->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(float, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::string))) { + pt_kernel_context_->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to Scalar when construct " + "KernelContext.", + attr_names[i])); + } } else { - PADDLE_THROW(platform::errors::Unimplemented( - "unsupported cast op attribute `%s` to Scalar when construct " - "KernelContext.", - attr_names[i])); + auto& ins_vector = ctx.inputs.at(attr_names[i]); + pt_kernel_context_->EmplaceBackAttr(std::move( + experimental::MakePtenScalarFromVar(*ins_vector.front()))); } + } else { // TODO(chenweihang): support other attrs later + auto& attr = Attrs().at(attr_names[i]); if (attr_defs[i].type_index == std::type_index(typeid(int))) { pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { @@ -1949,7 +1983,7 @@ void OperatorWithKernel::BuildPtenKernelContext( } else { PADDLE_THROW(platform::errors::Unimplemented( - "unsupported cast op attribute `%s` when construct " + "Unsupported cast op attribute `%s` when construct " "KernelContext.", attr_names[i])); } diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 167afdfb196..604f9d2be9e 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/imperative/infer_shape_context.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" #include "paddle/utils/small_vector.h" #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/device/xpu/xpu_op_list.h" @@ -385,26 +386,66 @@ static void BuildDygraphPtenKernelContext( } for (size_t i = 0; i < attr_names.size(); ++i) { - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) { + if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) { + if (attrs.find(attr_names[i]) != + attrs.end()) { // shape is in the attribute + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr(std::move( + pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to VectorTensor when " + "construct KernelContext.", + attr_names[i])); + } + } else { // shape is in the input + auto& ins_vector = ins.at(attr_names[i]); + if (ins_vector.size() == 1) { // ShapeTensor + kernel_ctx->EmplaceBackAttr(std::move( + experimental::MakePtenScalarArrayFromVar(ins_vector[0]->Var()))); + } else { // ShapeTensorList + std::vector variables; + variables.reserve(ins_vector.size()); + for (const auto& var_base : ins_vector) { + variables.push_back(var_base->MutableVar()); + } + kernel_ctx->EmplaceBackAttr(std::move( + experimental::MakePtenScalarArrayFromVarList(variables))); + } + } + } else if (attr_defs[i].type_index == + std::type_index(typeid(pten::Scalar))) { // TODO(chenweihang): support other attrs later // TODO(zhangyunfei): Scalar should hold scaler type, and we should check // attribtue type by attr_defs - if (std::type_index(attr.type()) == std::type_index(typeid(float))) { - kernel_ctx->EmplaceBackAttr( - std::move(pten::Scalar(BOOST_GET_CONST(float, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::string))) { - kernel_ctx->EmplaceBackAttr( - std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "unsupported cast op attribute `%s` to Scalar when construct " - "KernelContext in dygraph.", - attr_names[i])); + if (attrs.find(attr_names[i]) != attrs.end() || + default_attrs.find(attr_names[i]) != + default_attrs.end()) { // scalar is in the attribute + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); + if (std::type_index(attr.type()) == std::type_index(typeid(float))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(float, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::string))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr)))); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to Scalar when construct " + "KernelContext in dygraph.", + attr_names[i])); + } + } else { // scalar is in the input + auto& ins_vector = ins.at(attr_names[i]); + kernel_ctx->EmplaceBackAttr(std::move( + experimental::MakePtenScalarFromVar(ins_vector[0]->Var()))); } + } else { // TODO(chenweihang): support other attrs later + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); if (attr_defs[i].type_index == std::type_index(typeid(int))) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { @@ -430,7 +471,7 @@ static void BuildDygraphPtenKernelContext( // TODO(YuanRisheng) Need support vector attr } else { PADDLE_THROW(platform::errors::Unimplemented( - "unsupported cast op attribute `%s` when construct " + "Unsupported cast op attribute `%s` when construct " "KernelContext in dygraph.", attr_names[i])); } diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 480a08bfb7e..c28ca45fc14 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -102,13 +102,23 @@ class FillConstantOp : public framework::OperatorWithKernel { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext& ctx) const override { - if (!ctx.HasInput("ShapeTensor") && - ctx.MultiInput("ShapeTensorList").empty() && - !ctx.HasInput("ValueTensor") && - !ctx.OutputVar("Out")->IsType()) { + std::string shape; + if (ctx.HasInput("ShapeTensor")) { + shape = "ShapeTensor"; + } else if (ctx.MultiInput("ShapeTensorList").size()) { + shape = "ShapeTensorList"; + } else { + shape = "shape"; + } + std::string value; + if (ctx.HasInput("ValueTensor")) { + value = "ValueTensor"; + } else { const auto& str_value = ctx.Attr("str_value"); - std::string value = str_value.empty() ? "value" : "str_value"; - return framework::KernelSignature("fill_constant.scalar", {}, {value}, + value = str_value.empty() ? "value" : "str_value"; + } + if (!ctx.OutputVar("Out")->IsType()) { + return framework::KernelSignature("fill_constant", {}, {shape, value}, {"Out"}); } return framework::KernelSignature("fill_constant.unregistered", {}, {}, {}); diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc index 04494640824..0983abfa921 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.cc +++ b/paddle/pten/api/lib/utils/tensor_utils.cc @@ -97,6 +97,168 @@ std::unique_ptr MakePtenDenseTensor( } } +pten::Scalar MakePtenScalar(const paddle::framework::LoDTensor& src) { + PADDLE_ENFORCE_EQ(src.numel(), + 1, + paddle::platform::errors::InvalidArgument( + "The Scalar only supports Tensor with 1 element, " + "but now Tensor has %d element.", + src.numel())); + switch (src.type()) { + case paddle::framework::proto::VarType::FP32: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::FP64: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::FP16: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::BF16: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::INT32: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::INT64: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::INT16: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::INT8: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::UINT8: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::BOOL: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::COMPLEX64: + return {src.template data()[0]}; + case paddle::framework::proto::VarType::COMPLEX128: + return {src.template data()[0]}; + default: + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Data type error. Don't support casting a %d LoDTensor to Scalar.", + src.type())); + } +} + +pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) { + auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU); + if (variable.IsType()) { + const auto& tensor = variable.Get(); + if (!platform::is_same_place(tensor.place(), expected_place)) { + framework::LoDTensor tmp_tensor; + framework::TensorCopySync(tensor, expected_place, &tmp_tensor); + return MakePtenScalar(tmp_tensor); + } else { + return MakePtenScalar(tensor); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupport casting input `%s` type to Scalar when call pt " + "kernel.", + framework::ToTypeName(variable.Type()))); + } +} + +pten::ScalarArray MakePtenScalarArray(const paddle::framework::LoDTensor& src) { + if (src.type() == paddle::framework::proto::VarType::INT64) { + return {src.data(), src.numel()}; + } else if (src.type() == paddle::framework::proto::VarType::INT32) { + return {src.data(), src.numel()}; + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Data type error. When cast a LoDTensor to ScalarArray, " + "the data type of LoDTensor must be int32 or int64, " + "but now data type is %s.", + src.type())); + } +} + +pten::ScalarArray MakePtenScalarArrayFromVar( + const framework::Variable& variable) { + auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU); + if (variable.IsType()) { + const auto& tensor = variable.Get(); + if (!platform::is_same_place(tensor.place(), expected_place)) { + framework::LoDTensor tmp_tensor; + framework::TensorCopySync(tensor, expected_place, &tmp_tensor); + return MakePtenScalarArray(tmp_tensor); + } else { + return MakePtenScalarArray(tensor); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupport casting input `%s` type to ScalarArray when call pt " + "kernel.", + framework::ToTypeName(variable.Type()))); + } +} + +pten::ScalarArray MakePtenScalarArrayFromVarList( + const std::vector& variable_list) { + if (variable_list.size() == 0) { + return pten::ScalarArray(); + } + auto expected_place = pten::TransToFluidPlace(pten::Backend::CPU); + + paddle::framework::proto::VarType::Type data_type; + auto* first_var = variable_list.front(); + if (first_var->IsType()) { + const auto& tensor = first_var->Get(); + data_type = tensor.type(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupport casting input `%s` type to VectorTensor when call pt " + "kernel.", + framework::ToTypeName(first_var->Type()))); + } + + std::vector vector_data; + vector_data.reserve(variable_list.size()); + + if (data_type == paddle::framework::proto::VarType::INT64) { + for (auto* var : variable_list) { + if (var->IsType()) { + const auto& tensor = var->Get(); + if (!platform::is_same_place(tensor.place(), expected_place)) { + framework::LoDTensor tmp_tensor; + framework::TensorCopySync(tensor, expected_place, &tmp_tensor); + vector_data.push_back(*tmp_tensor.data()); + } else { + vector_data.push_back(*tensor.data()); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupport casting input `%s` type to VectorTensor when call pt " + "kernel.", + framework::ToTypeName(var->Type()))); + } + } + + } else if (data_type == paddle::framework::proto::VarType::INT32) { + for (auto* var : variable_list) { + if (var->IsType()) { + const auto& tensor = var->Get(); + if (!platform::is_same_place(tensor.place(), expected_place)) { + framework::LoDTensor tmp_tensor; + framework::TensorCopySync(tensor, expected_place, &tmp_tensor); + vector_data.push_back(*tmp_tensor.data()); + } else { + vector_data.push_back(*tensor.data()); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupport casting input `%s` type to VectorTensor when call pt " + "kernel.", + framework::ToTypeName(var->Type()))); + } + } + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Data type error. When cast a LoDTensor to VectorTensor, " + "the data type of LoDTensor must be int32 or int64, " + "but now data type is %s.", + data_type)); + } + + return {vector_data}; +} + std::unique_ptr MakePtenTensorBaseFromVar( const framework::Variable& variable, const pten::TensorArgDef& arg_def) { auto expected_place = pten::TransToFluidPlace(arg_def.backend); diff --git a/paddle/pten/api/lib/utils/tensor_utils.h b/paddle/pten/api/lib/utils/tensor_utils.h index 62d4cab02b6..04f0f6c1ff0 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.h +++ b/paddle/pten/api/lib/utils/tensor_utils.h @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/storage.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/convert_utils.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_factory.h" @@ -34,6 +36,18 @@ std::unique_ptr MakePtenDenseTensor( std::unique_ptr MakePtenDenseTensor( const paddle::framework::LoDTensor& src); +pten::Scalar MakePtenScalar(const paddle::framework::LoDTensor& src); + +pten::ScalarArray MakePtenScalarArray(const paddle::framework::LoDTensor& src); + +pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable); + +pten::ScalarArray MakePtenScalarArrayFromVar( + const framework::Variable& variable); + +pten::ScalarArray MakePtenScalarArrayFromVarList( + const std::vector& variable_list); + std::unique_ptr MakePtenTensorBaseFromVar( const framework::Variable& variable, const pten::TensorArgDef& arg_def); diff --git a/paddle/pten/common/scalar_array.h b/paddle/pten/common/scalar_array.h index 701f777d4a0..b4d21b98ca0 100644 --- a/paddle/pten/common/scalar_array.h +++ b/paddle/pten/common/scalar_array.h @@ -118,7 +118,7 @@ class ScalarArrayBase { /// \brief Assign the data_ from const data pointer value of type T. template void AssignData(const TYPE* value_data, int64_t n) { - if (value_data) { + if (value_data || n == 0) { array_.reserve(n); for (auto i = 0; i < n; ++i) { array_.push_back(static_cast(value_data[i])); diff --git a/paddle/pten/kernels/cpu/creation.cc b/paddle/pten/kernels/cpu/creation.cc index 84db03a78ec..71c4e9f1ebe 100644 --- a/paddle/pten/kernels/cpu/creation.cc +++ b/paddle/pten/kernels/cpu/creation.cc @@ -52,16 +52,9 @@ void FillAnyLike(const CPUContext& dev_ctx, template void FillConstant(const CPUContext& dev_ctx, + const ScalarArray& shape, const Scalar& val, DenseTensor* out) { - eigen::fill(dev_ctx, out, val.to()); -} - -template -void FillConstantDynamicShape(const CPUContext& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DenseTensor* out) { out->Resize(paddle::framework::make_ddim(shape.GetData())); eigen::fill(dev_ctx, out, val.to()); } @@ -81,26 +74,10 @@ PT_REGISTER_KERNEL("fill_any_like", bool, paddle::platform::float16) {} -PT_REGISTER_KERNEL("fill_constant.scalar", - CPU, - ANY, - pten::FillConstant, - float, - double, - uint8_t, - int16_t, - int, - int64_t, - bool, - paddle::platform::float16, - paddle::platform::bfloat16, - paddle::platform::complex, - paddle::platform::complex) {} - PT_REGISTER_KERNEL("fill_constant", CPU, ANY, - pten::FillConstantDynamicShape, + pten::FillConstant, float, double, uint8_t, diff --git a/paddle/pten/kernels/cpu/creation.h b/paddle/pten/kernels/cpu/creation.h index 668e242be9c..33e0107f1ac 100644 --- a/paddle/pten/kernels/cpu/creation.h +++ b/paddle/pten/kernels/cpu/creation.h @@ -31,13 +31,8 @@ void FillAnyLike(const CPUContext& dev_ctx, template void FillConstant(const CPUContext& dev_ctx, + const ScalarArray& shape, const Scalar& val, DenseTensor* out); -template -void FillConstantDynamicShape(const CPUContext& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DenseTensor* out); - } // namespace pten diff --git a/paddle/pten/kernels/cuda/creation.cu b/paddle/pten/kernels/cuda/creation.cu index 203562a6205..92d3b73ff18 100644 --- a/paddle/pten/kernels/cuda/creation.cu +++ b/paddle/pten/kernels/cuda/creation.cu @@ -53,16 +53,9 @@ void FillAnyLike(const CUDAContext& dev_ctx, template void FillConstant(const CUDAContext& dev_ctx, + const ScalarArray& shape, const Scalar& val, DenseTensor* out) { - eigen::fill(dev_ctx, out, val.to()); -} - -template -void FillConstantDynamicShape(const CUDAContext& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DenseTensor* out) { out->Resize(paddle::framework::make_ddim(shape.GetData())); eigen::fill(dev_ctx, out, val.to()); } @@ -82,25 +75,10 @@ PT_REGISTER_KERNEL("fill_any_like", bool, paddle::platform::float16) {} -PT_REGISTER_KERNEL("fill_constant.scalar", - CUDA, - ANY, - pten::FillConstant, - float, - double, - uint8_t, - int16_t, - int, - int64_t, - bool, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) {} - PT_REGISTER_KERNEL("fill_constant", CUDA, ANY, - pten::FillConstantDynamicShape, + pten::FillConstant, float, double, uint8_t, diff --git a/paddle/pten/kernels/cuda/creation.h b/paddle/pten/kernels/cuda/creation.h index 45ea5348e21..4943f720761 100644 --- a/paddle/pten/kernels/cuda/creation.h +++ b/paddle/pten/kernels/cuda/creation.h @@ -34,15 +34,10 @@ void FillAnyLike(const CUDAContext& dev_ctx, template void FillConstant(const CUDAContext& dev_ctx, + const ScalarArray& shape, const Scalar& val, DenseTensor* out); -template -void FillConstantDynamicShape(const CUDAContext& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DenseTensor* out); - } // namespace pten #endif -- GitLab