From 655f4e3fcafacc479ba9231ab5ba3e13574f2f0f Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Mon, 8 Nov 2021 20:53:00 +0800
Subject: [PATCH] [PTen] Add full kernel in pten (incomplete) (#36930)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* initial tensor design & sign kernel demo
* add move constructor for meta & add lodtensor
* add dirs & sign xpu kernel
* add mean cpu&cuda kernel impl
* move sign & mean xpu & npu kernel
* add selected_rows basic impl
* refactor design, BaseTensor to DenseTensor, etc.
* add scale mkldnn kernel
* polish xpu & npu impl details
* fix mkldnn reuse compile failed
* change tensor operation lib name
* rename util filename
* add more comments
* change TensorImplInterface to TensorInterface
* add kernel key and factory
* remove MKLDNNTensorMeta, add MKLDNNDenseTensor
* change XXDeviceContext to XXContext
* add base kernel registrar utils & test on sign
* replace boost::any by paddle::any
* fix several ci failed
* fix npu compile error
* add ordered map util
* fix multiple ordered_map compile errors
* move dev into include dir
* support sign op in static op run
* fix static op run error
* fix new executor compile failed
* add dygraph branch & remove sign_op.h
* fix test_infer_no_need_buffer_slots
* fix rocm compile link error
* fix unitybuild error & clear glog
* fix npu compile failed
* skip quant trans test
* fix part windows compile problem
* fix xpu enforce error
* fix inference test failed
* remove ordered_map to solve quant failed
* fix part of rocm compile failed
* add more register kernels
* revert scale kernel temporarily
* fix code format error
* add new kernel registrar macro
* rename top to tcmpt
* revert xpu, npu, mkldnn impl & remove op def
* add kernel args parse functor to auto parse args
* revert some change & add scale kernels
* add op proto in dygraph kernelcontext building
* polish kernel dispatch logic & naming rule
* fix scale kernel match error
* fix scale test failed
* add mean API and unittest
* test mean api success
* add branch to solve compiled error
* skip clang format error
* add mean skip rule in op_library
* add dot kernel, api and unittest (#6)
* remove old kernel and add symbol link
* fix dot compile failed
* add macro for module declare
* fix npu and xpu compile error
* revert sign, mean, scale, dot kernel removing
* add comment for keeping old kernel impl
* fix mutable_data error
* fix bfloat16 conflict
* fix inference undef error
* adapt to msvc compile rules
* polish comment for template inst
* add cmake template instantiation for win
* fix backend to place device id bug
* fix ifdef error
* Op2functor (#7)
* add kernel args maker class
* make args maker non-const
* remove debug log
* modify codes by review options
* split constructPrKernelContext function
* fix output name bug
* fix test_mean_op test_sign_op failed
* fill_any_like kernel refactor (#10)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* skip dtype for fill_any_like
* add attrs for kernel key construct
* add use_pt_kernel Flags to control whether to use pt kernel (#13)
* add use_pt_kernel Flags to control whether to use pt kernel
* change the default value to true for checking pt kernels
* fix mutable_data cuda place error
* move high level apis into hapi
* remove selectedrows adapting temporarily
* Support Scalar in Tensor Compute Library (#14)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* remove mkldnn tensor & polish details
* use flat_hash_map and small_vector in kernel factory
* Refactor flatten kernel (#12)
* refactor flatten kernel
* update infershape function
* fix compile bugs
* fix bugs when merge
* fix compiler bugs
* fix bugs when run test_flatten_api
* fix bugs when run test
* Revert "use flat_hash_map and small_vector in kernel factory"

This reverts commit 23091495cfdd3df8cc1be592d30f09ea66a7c72b.

* Move cpu, cuda and other device code into kernels (#15)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Perfect unittests (#16)
* perfect unittest
* update license
* replace with flat_hash_map, small_vector (#19)
* fix small_vector build error on windows platform
* replace with flat_hash_map, small_vector
* remove todo
* Perfect unittests (#20)
* perfect unittest
* update license
* fix bug when run tcmpt_utils_test
* refactor execution adapting impl
* fix insert conflict
* Fix CI bug of test_yolov3 (#21)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Fix CI bug of test_yolov3
* add the tensor base class, test=develop (#17)
* update the tensor base class, test=develop
* remove two funcs, test=develop
* update the error msg, test=develop

Co-authored-by: Chen Weihang

* [no-verify] commit backend and tensor signature changes
* Rename tcmpt to pten (#23)
* rename tcmpt to pten
* update omitted files for rename to pten
* update omitted file for rename to pten
* remove k of all enum var
* remove kernel_instantiate (#26)
* remove symbols and spatial_tensor
* change common to functions
* readd share tensor impl methods
* add a candidate dense tensor class, test=develop (#28)
* change all Pt to Pten
* resolve conflict with xiaowei
* Op2functor opt1 (#27)
* replace to small vector and change to const &
* add std::move

Co-authored-by: Chen Weihang

* polish kernel factory and kernel registry
* fix operator test error msg mismatch
* remove tensor signature and backend set member
* move scalar and polish enforce
* revert dtype layout change to fix error
* fix enum operator override error
* add several base unittests
* add pten utils tests
* polish some details
* Dev/op2func refactor 3 (#30)
* add a candidate dense tensor class, test=develop
* remove TensorBase::backend(), test=develop
* remove some ops, test=develop
* cherry-pick the pr of tensor meta, test=develop
* moves the dense tensor and some ops, test=develop
* update the linalg operator, test=develop
* update other operators, test=develop
* fix errors, test=develop
* fix bugs, test=develop
* try to resolve the problem of windows ci, test=develop
* updates codes, test=develop
* fix the tensor_utils.cc, test=develop
* modify the dense tensor, test=develop
* fix the data type, test=develop

Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

* polish some details
* polish kernel signature details
* fix a bug about offsets of the tensor, test=develop (#31)

Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

* polish some details
* add fill_constant kernel in pten
* fix bug of full api (c++)
* remove the support for SelectedRows in new fill_constant kernel
* fix bug of setting fill_any_like kernel key
* merge code conflict
* modify fill_constant GetExpectedKernelType
* fix fill_constant KernelType bug
* polish code of build pten KernelContext
* refactor code of fill_constant in pten

Co-authored-by: Chen Weihang
Co-authored-by: chentianyu03
Co-authored-by: YuanRisheng
Co-authored-by: 石晓伟 <39303645+Shixiaowei02@users.noreply.github.com>
---
 paddle/fluid/framework/operator.cc           |  4 ++
 paddle/fluid/imperative/prepared_operator.cc |  4 ++
 paddle/fluid/operators/fill_constant_op.cc   | 44 +++++++++++++++++-
 paddle/pten/CMakeLists.txt                   |  2 +-
 paddle/pten/api/include/creation.h           |  6 +++
 paddle/pten/api/lib/creation.cc              | 40 +++++++++++++++-
 paddle/pten/common/scalar.h                  | 12 +++++
 paddle/pten/core/kernel_utils.h              |  1 +
 paddle/pten/include/infershape.h             |  1 +
 paddle/pten/infershape/CMakeLists.txt        |  1 +
 paddle/pten/infershape/nary.cc               | 27 +++++++++++
 paddle/pten/infershape/nary.h                | 34 ++++++++++++++
 paddle/pten/kernels/cpu/creation.cc          | 49 +++++++++++++++++++-
 paddle/pten/kernels/cpu/creation.h           |  5 ++
 paddle/pten/kernels/cuda/creation.cu         | 47 +++++++++++++++++++
 paddle/pten/kernels/cuda/creation.h          |  5 ++
 paddle/pten/kernels/functions/eigen/fill.h   | 25 ----------
 paddle/pten/tests/api/test_fill_api.cc       | 34 ++++++++++++--
 18 files changed, 308 insertions(+), 33 deletions(-)
 create mode 100644 paddle/pten/infershape/nary.cc
 create mode 100644 paddle/pten/infershape/nary.h

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index e75fb4e3633..2fc2deb087e 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1838,6 +1838,10 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
       if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
         op_kernel_ctx.EmplaceBackAttr(
             std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
+      } else if (std::type_index(attr.type()) ==
+                 std::type_index(typeid(std::string))) {
+        op_kernel_ctx.EmplaceBackAttr(
+            std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "unsupported cast op attribute `%s` to Scalar when construct "
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index b2d55babc7e..7c0aaed25ab 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -321,6 +321,10 @@ static pten::KernelContext BuildDygraphPtenKernelContext(
       if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
         op_kernel_ctx.EmplaceBackAttr(
             std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
+      } else if (std::type_index(attr.type()) ==
+                 std::type_index(typeid(std::string))) {
+        op_kernel_ctx.EmplaceBackAttr(
+            std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "unsupported cast op attribute `%s` to Scalar when construct "
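NOTE (illustrative sketch, not part of the diff): the two hunks above add the same
std::string branch to the attribute-to-Scalar dispatch used when building a
pten::KernelContext, once for the static-graph path and once for the dygraph path.
A minimal standalone rendering of the pattern, using std::any and a hypothetical
ToScalarPayload in place of paddle::any and pten::Scalar:

    #include <any>
    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <typeindex>

    // Match a type-erased attribute by its runtime type and convert it to a
    // single numeric payload, as the new std::string branch does via Scalar.
    double ToScalarPayload(const std::any& attr) {
      if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
        return std::any_cast<float>(attr);
      } else if (std::type_index(attr.type()) ==
                 std::type_index(typeid(std::string))) {
        return std::stod(std::any_cast<std::string>(attr));
      }
      throw std::runtime_error("unsupported attribute type");
    }

    int main() {
      std::cout << ToScalarPayload(std::any(1.5f)) << "\n";                // 1.5
      std::cout << ToScalarPayload(std::any(std::string("2.5"))) << "\n";  // 2.5
    }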
diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc
index 44dcc343a4b..aea149fbedc 100644
--- a/paddle/fluid/operators/fill_constant_op.cc
+++ b/paddle/fluid/operators/fill_constant_op.cc
@@ -64,9 +64,51 @@ class FillConstantOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
+    framework::OpKernelType kt = framework::OpKernelType(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
         ctx.GetPlace());
+    // TODO(zyfncg): The force_cpu and place_type attributes conflict; this is
+    // an issue left over from before, and the two may be merged in the future.
+    // In order to invoke the new fill_constant kernel, the place of the
+    // OpKernelType is set from force_cpu and place_type here.
+    if (ctx.Attr<bool>("force_cpu")) {
+      kt.place_ = platform::CPUPlace();
+    }
+    auto place_type = ctx.Attr<int>("place_type");
+    if (place_type != -1) {
+      switch (place_type) {
+        case 0:
+          kt.place_ = platform::CPUPlace();
+          break;
+        case 1:
+        case 2:
+          kt.place_ = platform::CUDAPlace();
+          break;
+        case 3:
+          kt.place_ = platform::XPUPlace();
+          break;
+        default:
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Could NOT determine the place of variable, place_type = %d .",
+              place_type));
+      }
+    }
+
+    return kt;
+  }
+
+  framework::KernelSignature GetExpectedPtenKernelArgs(
+      const framework::ExecutionContext& ctx) const override {
+    if (!ctx.HasInput("ShapeTensor") &&
+        ctx.MultiInput<framework::Tensor>("ShapeTensorList").empty() &&
+        !ctx.HasInput("ValueTensor") &&
+        !ctx.OutputVar("Out")->IsType<framework::SelectedRows>()) {
+      const auto& str_value = ctx.Attr<std::string>("str_value");
+      std::string value = str_value.empty() ? "value" : "str_value";
+      return framework::KernelSignature("fill_constant.scalar", {}, {value},
+                                        {"Out"});
+    }
+    return framework::KernelSignature("fill_constant.unregistered", {}, {},
+                                      {});
   }
 };
diff --git a/paddle/pten/CMakeLists.txt b/paddle/pten/CMakeLists.txt
index e72ec1f8ae6..0b3bb255703 100644
--- a/paddle/pten/CMakeLists.txt
+++ b/paddle/pten/CMakeLists.txt
@@ -13,7 +13,7 @@ add_subdirectory(tests)
 # make an unity target for compile deps
 set(PTEN_DEPS convert_utils dense_tensor kernel_factory kernel_context)
 set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu)
-set(PTEN_DEPS ${PTEN_DEPS} unary binary)
+set(PTEN_DEPS ${PTEN_DEPS} nary unary binary)
 if(WITH_GPU OR WITH_ROCM)
   set(PTEN_DEPS ${PTEN_DEPS} math_cuda linalg_cuda creation_cuda manipulation_cuda)
 endif()
diff --git a/paddle/pten/api/include/creation.h b/paddle/pten/api/include/creation.h
index 755038adb1f..b7e7bf55c6b 100644
--- a/paddle/pten/api/include/creation.h
+++ b/paddle/pten/api/include/creation.h
@@ -21,6 +21,12 @@ namespace paddle {
 namespace experimental {
 
+Tensor full(const std::vector<int64_t>& shape,
+            const Scalar& value,
+            DataType dtype = DataType::FLOAT32,
+            Backend backend = Backend::CPU,
+            DataLayout layout = DataLayout::NCHW);
+
 Tensor full_like(const Tensor& x,
                  const Scalar& value,
                  DataType dtype = DataType::UNDEFINED);
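NOTE (illustrative sketch, not part of the diff): hypothetical call sites implied
by the declaration above. The explicit-dtype form matches the new test at the end
of this patch; the fully defaulted form follows from the default arguments
(FLOAT32, CPU, NCHW):

    // Explicit dtype; backend defaults to CPU and layout to NCHW.
    auto a = paddle::experimental::full({3, 2}, 1.0f, pten::DataType::FLOAT32);
    // Everything defaulted: a FLOAT32 CPU tensor of shape [4].
    auto b = paddle::experimental::full({4}, 0.0f);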
diff --git a/paddle/pten/api/lib/creation.cc b/paddle/pten/api/lib/creation.cc
index 893f8b6fbc6..047b19010a2 100644
--- a/paddle/pten/api/lib/creation.cc
+++ b/paddle/pten/api/lib/creation.cc
@@ -26,6 +26,41 @@ limitations under the License. */
 namespace paddle {
 namespace experimental {
 
+Tensor full(const std::vector<int64_t>& shape,
+            const Scalar& value,
+            DataType dtype,
+            Backend backend,
+            DataLayout layout) {
+  // 1. Get kernel signature and kernel
+  pten::KernelKey kernel_key{backend, layout, dtype};
+  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
+      "fill_constant.scalar", kernel_key);
+
+  // 2. Get Device Context
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
+  auto kernel_context = pten::KernelContext(*dev_ctx);
+
+  // 3. Auto data transform
+  kernel_context.EmplaceBackAttr(value);
+
+  // 4. InferShape
+  auto out_meta = pten::FullInferShape(shape, dtype, layout);
+
+  // 5. Prepare outputs
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          pten::TransToFluidPlace(kernel_key.backend()));
+  auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
+  kernel_context.EmplaceBackOutput(dense_out);
+  Tensor out;
+  out.set_impl(dense_out);
+
+  // 6. Call kernel
+  kernel(&kernel_context);
+
+  return out;
+}
+
 Tensor full_like(const Tensor& x,
                  const Scalar& value,
                  paddle::experimental::DataType dtype) {
@@ -33,7 +68,10 @@ Tensor full_like(const Tensor& x,
   auto kernel_key_set = ParseKernelKeyByInputArgs(x);
   auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
   auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
-      "fill_any_like", kernel_key);
+      "fill_any_like",
+      {kernel_key.backend(),
+       kernel_key.layout(),
+       dtype == DataType::UNDEFINED ? kernel_key.dtype() : dtype});
 
   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
diff --git a/paddle/pten/common/scalar.h b/paddle/pten/common/scalar.h
index c55b700979a..ef648ba70f3 100644
--- a/paddle/pten/common/scalar.h
+++ b/paddle/pten/common/scalar.h
@@ -34,6 +34,18 @@ class Scalar {
 
   Scalar(bool val) : tag(Tag::HAS_B) { data_.b = val; }  // NOLINT
 
+  Scalar(const std::string& str_value) : tag(Tag::HAS_D) {  // NOLINT
+    if (str_value == "inf") {
+      data_.d = std::numeric_limits<double>::infinity();
+    } else if (str_value == "-inf") {
+      data_.d = -std::numeric_limits<double>::infinity();
+    } else if (str_value == "nan") {
+      data_.d = std::numeric_limits<double>::quiet_NaN();
+    } else {
+      data_.d = std::stod(str_value);
+    }
+  }
+
   template <typename T>
   inline T to() const {
     switch (tag) {
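NOTE (illustrative sketch, not part of the diff): the new constructor above gives
string attributes (such as fill_constant's str_value) a numeric interpretation. A
standalone check of the parsing rules; ParseScalarString is a stand-in that
mirrors the constructor's four branches:

    #include <cassert>
    #include <cmath>
    #include <limits>
    #include <string>

    double ParseScalarString(const std::string& s) {
      if (s == "inf") return std::numeric_limits<double>::infinity();
      if (s == "-inf") return -std::numeric_limits<double>::infinity();
      if (s == "nan") return std::numeric_limits<double>::quiet_NaN();
      return std::stod(s);  // like the constructor, throws on non-numeric input
    }

    int main() {
      assert(std::isinf(ParseScalarString("inf")));
      assert(ParseScalarString("-inf") < 0);
      assert(std::isnan(ParseScalarString("nan")));
      assert(ParseScalarString("1.5") == 1.5);
      return 0;
    }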
diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h
index 45020260764..23143c06244 100644
--- a/paddle/pten/core/kernel_utils.h
+++ b/paddle/pten/core/kernel_utils.h
@@ -207,6 +207,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
 
   /* Output Helpers */
diff --git a/paddle/pten/include/infershape.h b/paddle/pten/include/infershape.h
index 8c1bd43aaa2..d8dd2837a72 100644
--- a/paddle/pten/include/infershape.h
+++ b/paddle/pten/include/infershape.h
@@ -16,4 +16,5 @@ limitations under the License. */
 // See Note: [ How do we organize the kernel directory ]
 #include "paddle/pten/infershape/binary.h"
+#include "paddle/pten/infershape/nary.h"
 #include "paddle/pten/infershape/unary.h"
diff --git a/paddle/pten/infershape/CMakeLists.txt b/paddle/pten/infershape/CMakeLists.txt
index 0b3771df357..b32ec0a51c7 100644
--- a/paddle/pten/infershape/CMakeLists.txt
+++ b/paddle/pten/infershape/CMakeLists.txt
@@ -1,2 +1,3 @@
+cc_library(nary SRCS nary.cc DEPS convert_utils)
 cc_library(unary SRCS unary.cc DEPS convert_utils)
 cc_library(binary SRCS binary.cc DEPS convert_utils)
diff --git a/paddle/pten/infershape/nary.cc b/paddle/pten/infershape/nary.cc
new file mode 100644
index 00000000000..b8745dd9b83
--- /dev/null
+++ b/paddle/pten/infershape/nary.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/pten/infershape/nary.h"
+
+namespace pten {
+
+DenseTensorMeta FullInferShape(const std::vector<int64_t>& shape,
+                               DataType dtype,
+                               DataLayout layout) {
+  const auto& out_dims = paddle::framework::make_ddim(shape);
+  return {dtype, out_dims, layout};
+}
+
+}  // namespace pten
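NOTE (illustrative sketch, not part of the diff): FullInferShape is the first
member of the new nary (zero-input) infershape family; the output meta is built
entirely from the call arguments, since a creation op has no input tensor to
inspect. A hypothetical call, mirroring step 4 of full() in
paddle/pten/api/lib/creation.cc above:

    // No input tensors are consulted; shape, dtype, and layout fully
    // determine the resulting DenseTensorMeta ({dtype, dims, layout}).
    auto out_meta = pten::FullInferShape({3, 2}, pten::DataType::FLOAT32,
                                         pten::DataLayout::NCHW);
    // out_meta now describes a 3x2 FLOAT32 NCHW tensor.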
diff --git a/paddle/pten/infershape/nary.h b/paddle/pten/infershape/nary.h
new file mode 100644
index 00000000000..8900e0ed71c
--- /dev/null
+++ b/paddle/pten/infershape/nary.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/pten/core/tensor_meta.h"
+
+namespace pten {
+
+// Common InferShape functions for 0-ary operators (no input tensor), in the
+// format:
+//
+//   1. DenseTensorMeta [OpName]InferShape(...)
+//
+// NOTE: The name "InferShape" may not be appropriate; "InferMeta" may be
+// better, because the functions in this file not only infer shape but also
+// need to infer lod or other useful metadata.
+
+DenseTensorMeta FullInferShape(const std::vector<int64_t>& shape,
+                               DataType dtype,
+                               DataLayout layout);
+
+}  // namespace pten
diff --git a/paddle/pten/kernels/cpu/creation.cc b/paddle/pten/kernels/cpu/creation.cc
index c3986c985bd..2ab2537a844 100644
--- a/paddle/pten/kernels/cpu/creation.cc
+++ b/paddle/pten/kernels/cpu/creation.cc
@@ -24,7 +24,38 @@ void FillAnyLike(const CPUContext& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& val,
                  DenseTensor* out) {
-  eigen::fill<CPUContext, T>(dev_ctx, out, val.to<float>());
+  auto value = val.to<float>();
+  using CommonType = typename std::common_type<
+      float,
+      typename std::conditional<
+          std::is_same<T, paddle::platform::float16>::value,
+          float,
+          T>::type>::type;
+
+  auto common_type_value = static_cast<CommonType>(value);
+
+  PADDLE_ENFORCE_EQ(
+      (common_type_value >=
+       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
+          (common_type_value <=
+           static_cast<CommonType>(std::numeric_limits<T>::max())),
+      true,
+      paddle::platform::errors::InvalidArgument(
+          "The filled value is out of range for target type, "
+          "current kernel type is %s, the range should between %f "
+          "and %f, but now value is %f.",
+          typeid(T).name(),
+          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
+          static_cast<CommonType>(std::numeric_limits<T>::max()),
+          static_cast<float>(value)));
+  eigen::fill<CPUContext, T>(dev_ctx, out, value);
+}
+
+template <typename T>
+void FillConstant(const CPUContext& dev_ctx,
+                  const Scalar& val,
+                  DenseTensor* out) {
+  eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
 }
 
 }  // namespace pten
@@ -41,3 +72,19 @@ PT_REGISTER_KERNEL("fill_any_like",
                    int64_t,
                    bool,
                    paddle::platform::float16) {}
+
+PT_REGISTER_KERNEL("fill_constant.scalar",
+                   CPU,
+                   ANY,
+                   pten::FillConstant,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::bfloat16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
diff --git a/paddle/pten/kernels/cpu/creation.h b/paddle/pten/kernels/cpu/creation.h
index 9991df31555..6d7732033ae 100644
--- a/paddle/pten/kernels/cpu/creation.h
+++ b/paddle/pten/kernels/cpu/creation.h
@@ -29,4 +29,9 @@ void FillAnyLike(const CPUContext& dev_ctx,
                  const Scalar& val,
                  DenseTensor* out);
 
+template <typename T>
+void FillConstant(const CPUContext& dev_ctx,
+                  const Scalar& val,
+                  DenseTensor* out);
+
 }  // namespace pten
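NOTE (illustrative sketch, not part of the diff): the same out-of-range check now
lives in the CPU FillAnyLike above and the CUDA FillAnyLike below, and is deleted
from the shared eigen::fill helper at the end of this patch. A standalone
illustration of why the comparison runs in a common type (FitsIn is hypothetical;
the kernels additionally map float16 to float first, elided here):

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <type_traits>

    // Compare in std::common_type<float, T> so neither the candidate value
    // nor the numeric limits of T are truncated during the comparison.
    template <typename T>
    bool FitsIn(float value) {
      using CommonType = typename std::common_type<float, T>::type;
      auto v = static_cast<CommonType>(value);
      return v >= static_cast<CommonType>(std::numeric_limits<T>::lowest()) &&
             v <= static_cast<CommonType>(std::numeric_limits<T>::max());
    }

    int main() {
      std::cout << FitsIn<int8_t>(127.0f) << "\n";  // 1: representable
      std::cout << FitsIn<int8_t>(128.0f) << "\n";  // 0: out of range
    }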
diff --git a/paddle/pten/kernels/cuda/creation.cu b/paddle/pten/kernels/cuda/creation.cu
index 40e965e5aac..b96b5ebea9b 100644
--- a/paddle/pten/kernels/cuda/creation.cu
+++ b/paddle/pten/kernels/cuda/creation.cu
@@ -24,9 +24,41 @@ void FillAnyLike(const CUDAContext& dev_ctx,
                  const DenseTensor& x,
                  const Scalar& val,
                  DenseTensor* out) {
+  auto value = val.to<float>();
+  using CommonType = typename std::common_type<
+      float,
+      typename std::conditional<
+          std::is_same<T, paddle::platform::float16>::value,
+          float,
+          T>::type>::type;
+
+  auto common_type_value = static_cast<CommonType>(value);
+
+  PADDLE_ENFORCE_EQ(
+      (common_type_value >=
+       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
+          (common_type_value <=
+           static_cast<CommonType>(std::numeric_limits<T>::max())),
+      true,
+      paddle::platform::errors::InvalidArgument(
+          "The filled value is out of range for target type, "
+          "current kernel type is %s, the range should between %f "
+          "and %f, but now value is %f.",
+          typeid(T).name(),
+          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
+          static_cast<CommonType>(std::numeric_limits<T>::max()),
+          static_cast<float>(value)));
+
   eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<float>());
 }
 
+template <typename T>
+void FillConstant(const CUDAContext& dev_ctx,
+                  const Scalar& val,
+                  DenseTensor* out) {
+  eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<T>());
+}
+
 }  // namespace pten
 
 PT_REGISTER_MODULE(CreationCUDA);
@@ -41,3 +73,18 @@ PT_REGISTER_KERNEL("fill_any_like",
                    int64_t,
                    bool,
                    paddle::platform::float16) {}
+
+PT_REGISTER_KERNEL("fill_constant.scalar",
+                   CUDA,
+                   ANY,
+                   pten::FillConstant,
+                   float,
+                   double,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
diff --git a/paddle/pten/kernels/cuda/creation.h b/paddle/pten/kernels/cuda/creation.h
index 84a868e917b..025cd6ba51b 100644
--- a/paddle/pten/kernels/cuda/creation.h
+++ b/paddle/pten/kernels/cuda/creation.h
@@ -32,6 +32,11 @@ void FillAnyLike(const CUDAContext& dev_ctx,
                  const Scalar& val,
                  DenseTensor* out);
 
+template <typename T>
+void FillConstant(const CUDAContext& dev_ctx,
+                  const Scalar& val,
+                  DenseTensor* out);
+
 }  // namespace pten
 
 #endif
diff --git a/paddle/pten/kernels/functions/eigen/fill.h b/paddle/pten/kernels/functions/eigen/fill.h
index 3897da415c6..122a6aef22d 100644
--- a/paddle/pten/kernels/functions/eigen/fill.h
+++ b/paddle/pten/kernels/functions/eigen/fill.h
@@ -26,31 +26,6 @@ namespace eigen {
 template <typename DeviceContext, typename T, typename VType>
 void fill(const DeviceContext& context, DenseTensor* tensor, VType val) {
   tensor->mutable_data<T>();
-
-  using CommonType = typename std::common_type<
-      float,
-      typename std::conditional<
-          std::is_same<T, paddle::platform::float16>::value,
-          float,
-          T>::type>::type;
-
-  auto common_type_value = static_cast<CommonType>(val);
-
-  PADDLE_ENFORCE_EQ(
-      (common_type_value >=
-       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
-          (common_type_value <=
-           static_cast<CommonType>(std::numeric_limits<T>::max())),
-      true,
-      paddle::platform::errors::InvalidArgument(
-          "The filled value is out of range for target type, "
-          "current kernel type is %s, the range should between %f "
-          "and %f, but now value is %f.",
-          typeid(T).name(),
-          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
-          static_cast<CommonType>(std::numeric_limits<T>::max()),
-          static_cast<float>(val)));
-
   auto t = pten::EigenVector<T>::Flatten(*tensor);
   t.device(*context.eigen_device()) = t.constant(static_cast<T>(val));
 }
diff --git a/paddle/pten/tests/api/test_fill_api.cc b/paddle/pten/tests/api/test_fill_api.cc
index cbac4c541f5..89763794254 100644
--- a/paddle/pten/tests/api/test_fill_api.cc
+++ b/paddle/pten/tests/api/test_fill_api.cc
@@ -81,21 +81,21 @@ TEST(API, zeros_like) {
   paddle::experimental::Tensor x(dense_x);
 
   // 2. test API
-  auto out = paddle::experimental::zeros_like(x, pten::DataType::FLOAT32);
+  auto out = paddle::experimental::zeros_like(x, pten::DataType::INT32);
 
   // 3. check result
   ASSERT_EQ(out.shape().size(), 2);
   ASSERT_EQ(out.shape()[0], 3);
   ASSERT_EQ(out.numel(), 6);
   ASSERT_EQ(out.is_cpu(), true);
-  ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
+  ASSERT_EQ(out.type(), pten::DataType::INT32);
   ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
   ASSERT_EQ(out.initialized(), true);
 
   auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
-  auto* actual_result = dense_out->data<float>();
+  auto* actual_result = dense_out->data<int32_t>();
   for (auto i = 0; i < 6; i++) {
-    ASSERT_NEAR(actual_result[i], 0, 1e-6f);
+    ASSERT_EQ(actual_result[i], 0);
   }
 }
 
@@ -131,3 +131,29 @@ TEST(API, ones_like) {
     ASSERT_EQ(actual_result[i], 1);
   }
 }
+
+TEST(API, full) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+
+  float val = 1.0;
+
+  // 2. test API
+  auto out = paddle::experimental::full({3, 2}, val, pten::DataType::FLOAT32);
+
+  // 3. check result
+  ASSERT_EQ(out.shape().size(), 2);
+  ASSERT_EQ(out.shape()[0], 3);
+  ASSERT_EQ(out.numel(), 6);
+  ASSERT_EQ(out.is_cpu(), true);
+  ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
+  ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
+  ASSERT_EQ(out.initialized(), true);
+
+  auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
+  auto* actual_result = dense_out->data<float>();
+  for (auto i = 0; i < 6; i++) {
+    ASSERT_NEAR(actual_result[i], val, 1e-6f);
+  }
+}
--
GitLab
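NOTE (illustrative sketch, not part of the diff): a minimal view of the
user-facing creation surface exercised by the tests above, assuming only the
headers and signatures shown in this patch:

    #include "paddle/pten/api/include/creation.h"

    void CreationDemo() {
      // New in this patch: a 3x2 FLOAT32 tensor, every element 1.0, on CPU.
      auto ones = paddle::experimental::full({3, 2}, 1.0f,
                                             pten::DataType::FLOAT32);
      // Pre-existing companion tested above: zeros shaped like an existing
      // tensor, with the dtype overridden.
      auto zeros =
          paddle::experimental::zeros_like(ones, pten::DataType::INT32);
    }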