diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index ae67e2ebb35ccef7fe07ee8c76db33a459b1dfce..79b8ac6d0b8352b2e817e6bdbefca74c835ad6b2 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/transfer_layout_kernel.h" diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index e9108787082d071d67ea0012add837bc18592a0a..16fae8d879cc33a23180b115ecc2620e3e227420 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -9,13 +9,22 @@ add_subdirectory(funcs) # phi depends all phi kernel targets set_property(GLOBAL PROPERTY PHI_KERNELS "") +# [ 1. Common kernel compilation dependencies ] set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor softmax) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) -# NOTE: Some kernels depend on some targets that are not commonly used. +# [ 2. Kernels that most kernels depend on ] +# There are a few kernels that are very basic operations, and most of the +# kernels depend on these kernels. +set(COMMON_BAISC_KERNELS empty_kernel full_kernel) +kernel_library(empty_kernel DEPS ${COMMON_KERNEL_DEPS}) +kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) + +# [ 3. Kernels with special dependencies ] +# Some kernels depend on some targets that are not commonly used. # These targets are not suitable for common dependencies. # In this case, you need to manually generate them here. set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel) @@ -24,8 +33,8 @@ kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce) -# auto parse and build kernel targets by cmake -register_kernels(EXCLUDES ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS}) +# 4. auto parse and build kernel targets by cmake +register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} ) # phi sparse kernels add_subdirectory(sparse) diff --git a/paddle/phi/kernels/cast_kernel.h b/paddle/phi/kernels/cast_kernel.h index c760b2842d0c97f0afd848ee3dcc333517349c9e..5e07388f5fb20d3a791bcf288e1e6597479e12c5 100644 --- a/paddle/phi/kernels/cast_kernel.h +++ b/paddle/phi/kernels/cast_kernel.h @@ -29,7 +29,7 @@ template DenseTensor Cast(const Context& dev_ctx, const DenseTensor& x, DataType out_dtype) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); CastInferMeta(x, out_dtype, &meta_out); CastKernel(dev_ctx, x, out_dtype, &dense_out); diff --git a/paddle/phi/kernels/complex_kernel.h b/paddle/phi/kernels/complex_kernel.h index 2c52001ece1c44a800372f83b2c59c872af2fb5c..07f93f9b926f174c374bbc20b7c655a65732423f 100644 --- a/paddle/phi/kernels/complex_kernel.h +++ b/paddle/phi/kernels/complex_kernel.h @@ -38,7 +38,7 @@ template < std::is_same>::value, bool> = true> DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); UnchangedInferMeta(x, &meta_out); ConjKernel(dev_ctx, x, &dense_out); @@ -64,7 +64,7 @@ template < std::is_same>::value, bool> = true> DenseTensor Real(const Context& dev_ctx, const DenseTensor& x) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); RealAndImagInferMeta(x, &meta_out); RealKernel(dev_ctx, x, &dense_out); @@ -90,7 +90,7 @@ template < std::is_same>::value, bool> = true> DenseTensor Imag(const Context& dev_ctx, const DenseTensor& x) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); RealAndImagInferMeta(x, &meta_out); ImagKernel(dev_ctx, x, &dense_out); diff --git a/paddle/phi/kernels/concat_kernel.h b/paddle/phi/kernels/concat_kernel.h index ed969e963ec0e4d9a7fe3b2ebc3df2253747df27..4e72159aeca671614ccfe483ec1496f70e6b1d6a 100644 --- a/paddle/phi/kernels/concat_kernel.h +++ b/paddle/phi/kernels/concat_kernel.h @@ -38,7 +38,7 @@ DenseTensor Concat(const Context& dev_ctx, meta_x_ptr.push_back(&meta_x.back()); } - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); ConcatInferMeta(meta_x_ptr, axis.to(), &meta_out, /*is_runtime=*/true); ConcatKernel(dev_ctx, x, axis, &dense_out); diff --git a/paddle/phi/kernels/dot_kernel.h b/paddle/phi/kernels/dot_kernel.h index 9377fba204bea4afea5d5346ee4ad13bb1730586..9c7703440d8aeea4dd518436a5bb62dac2f12519 100644 --- a/paddle/phi/kernels/dot_kernel.h +++ b/paddle/phi/kernels/dot_kernel.h @@ -29,7 +29,7 @@ template DenseTensor Dot(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); DotInferMeta(x, y, &meta_out); DotKernel(dev_ctx, x, y, &dense_out); diff --git a/paddle/phi/kernels/empty_kernel.h b/paddle/phi/kernels/empty_kernel.h index 0b8d95ee94fb5480684023ec6c71698ba06d9c13..f66f4419fd7f5853565d561751d793b8f10c9b46 100644 --- a/paddle/phi/kernels/empty_kernel.h +++ b/paddle/phi/kernels/empty_kernel.h @@ -14,9 +14,9 @@ #pragma once -#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/common/scalar_array.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" #include "paddle/phi/infermeta/nullary.h" #include "paddle/phi/infermeta/unary.h" @@ -34,28 +34,17 @@ void EmptyLikeKernel(const Context& dev_ctx, DataType dtype, DenseTensor* out); -// TODO(chenweihang): the tensor creation method need to be replaced later, -// all kernel api call Empty here instead of making tensor self template DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) { - phi::DenseTensor dense_out( - phi::make_intrusive( - dev_ctx.GetPlace()), - std::move(meta)); + phi::DenseTensor dense_out; + dense_out.set_meta(meta); + dev_ctx.Alloc(&dense_out, dense_out.dtype()); return dense_out; } -template -DenseTensor Empty(const Context& dev_ctx) { - return Empty(dev_ctx, - {paddle::experimental::CppTypeToDataType::Type(), - {-1}, - DataLayout::NCHW}); -} - template DenseTensor Empty(const Context& dev_ctx, const ScalarArray& shape) { - auto dense_out = Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); DataType dtype = paddle::experimental::CppTypeToDataType::Type(); CreateInferMeta(shape, dtype, &meta_out); @@ -65,7 +54,7 @@ DenseTensor Empty(const Context& dev_ctx, const ScalarArray& shape) { template DenseTensor EmptyLike(const Context& dev_ctx, const DenseTensor& x) { - auto dense_out = Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); DataType dtype = paddle::experimental::CppTypeToDataType::Type(); CreateLikeInferMeta(x, dtype, &meta_out); diff --git a/paddle/phi/kernels/flatten_kernel.h b/paddle/phi/kernels/flatten_kernel.h index de57dcf2e8d3a06b027190aabffb7e7d1d8ebcfe..808af7d9b7beedfc01da7ac53234d9b469c5239f 100644 --- a/paddle/phi/kernels/flatten_kernel.h +++ b/paddle/phi/kernels/flatten_kernel.h @@ -40,7 +40,7 @@ DenseTensor Flatten(const Context& dev_ctx, const DenseTensor& x, int start_axis, int stop_axis) { - auto dense_out = Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); FlattenInferMeta(x, start_axis, stop_axis, &meta_out); FlattenKernel(dev_ctx, x, start_axis, stop_axis, &dense_out); diff --git a/paddle/phi/kernels/full_kernel.h b/paddle/phi/kernels/full_kernel.h index 05929ba83f3b8e61c79358874a0e20064c981038..c44f048051d5dc1c4bc26b0ab34c1e5193e1fa74 100644 --- a/paddle/phi/kernels/full_kernel.h +++ b/paddle/phi/kernels/full_kernel.h @@ -41,7 +41,7 @@ template DenseTensor Full(const Context& dev_ctx, const ScalarArray& shape, const Scalar& val) { - auto dense_out = Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); DataType dtype = paddle::experimental::CppTypeToDataType::Type(); CreateInferMeta(shape, dtype, &meta_out); @@ -53,7 +53,7 @@ template DenseTensor FullLike(const Context& dev_ctx, const DenseTensor& x, const Scalar& val) { - auto dense_out = Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); DataType dtype = paddle::experimental::CppTypeToDataType::Type(); CreateLikeInferMeta(x, dtype, &meta_out); diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index 7df772682ecf9d05f77edccdc38d93d4220c6496..ce6bb0d559c8143ffa443238043541bec987e0d3 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -344,9 +344,8 @@ struct ReduceConfig { const phi::GPUContext& dev_ctx, phi::DenseTensor* tmp) { if (should_reduce_again) { - tmp->ResizeAndAllocate(phi::make_ddim( + tmp->Resize(phi::make_ddim( {static_cast(left_num * grid.z * grid.y * sizeof(Ty))})); - output_data = dev_ctx.Alloc(tmp); } else { output_data = y_data; @@ -1053,8 +1052,8 @@ CubTensorReduceImpl(const Tx* x_data, reducer, reducer.initial(), stream); - phi::DenseTensor tmp = - phi::Empty(dev_ctx, {static_cast(temp_storage_bytes)}); + phi::DenseTensor tmp = phi::Empty( + dev_ctx, {static_cast(temp_storage_bytes)}); auto* temp_storage = dev_ctx.Alloc(&tmp); @@ -1106,7 +1105,7 @@ void TensorReduceImpl(const phi::GPUContext& dev_ctx, // y_data; phi::DDim tmp_ddim; - phi::DenseTensor tmp = phi::Empty(dev_ctx); + phi::DenseTensor tmp; auto x_data = x.data(); auto y_data = y->data(); diff --git a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h index 7c8d10e05653d0bacd0ff93d5363f0c6a617f0c3..d06bdc55030567ae4de8ba51bec7282231cc8661 100644 --- a/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_grad_kernel_impl.h @@ -329,8 +329,8 @@ void MatmulGradKernel(const Context& dev_ctx, x_conj = Conj(dev_ctx, x); y_conj = Conj(dev_ctx, y); - DenseTensor dx_help = Empty(dev_ctx); - DenseTensor dy_help = Empty(dev_ctx); + DenseTensor dx_help; + DenseTensor dy_help; if (transpose_x) { if (transpose_y) { @@ -686,8 +686,8 @@ void MatmulDoubleGradKernel(const Context& dev_ctx, y_conj = Conj(dev_ctx, y); } - DenseTensor dx_help = Empty(dev_ctx); - DenseTensor dy_help = Empty(dev_ctx); + DenseTensor dx_help; + DenseTensor dy_help; if (transpose_x) { if (transpose_y) { @@ -1373,10 +1373,10 @@ void MatmulTripleGradKernel(const Context& dev_ctx, VLOG(3) << "It need cost much time to reduce sum for the broadcast and " "wastes the memory. So we should avoid the case in reality"; - DenseTensor out_dx_help = Empty(dev_ctx); - DenseTensor out_dy_help = Empty(dev_ctx); - DenseTensor out_d_ddx_help = Empty(dev_ctx); - DenseTensor out_d_ddy_help = Empty(dev_ctx); + DenseTensor out_dx_help; + DenseTensor out_dy_help; + DenseTensor out_d_ddx_help; + DenseTensor out_d_ddy_help; if (out_d_dout) { ddx_conj = Conj(dev_ctx, ddx); diff --git a/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h index a6868ebe6ca51c1e412249695469d6e3ec35363c..9b1e4b1d3a65d5c0da831a36152cff85a3353fa3 100644 --- a/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h @@ -49,7 +49,7 @@ void TriangularSolveGradKernel(const Context& dev_ctx, DenseTensor dy_bst = phi::Empty(dev_ctx, y_bst_dims_array); if (dy) { // calculate x's conjugate for complex - DenseTensor x_conj = phi::Empty(dev_ctx); + DenseTensor x_conj; x_conj.Resize(x.dims()); phi::funcs::ForRange x_for_range(dev_ctx, x.numel()); @@ -76,7 +76,7 @@ void TriangularSolveGradKernel(const Context& dev_ctx, DenseTensor dx_bst = phi::Empty(dev_ctx, x_bst_dims_array); if (dx) { // calculate x's conjugate for complex - DenseTensor out_conj = phi::Empty(dev_ctx); + DenseTensor out_conj; out_conj.Resize(out.dims()); phi::funcs::ForRange out_for_range(dev_ctx, out.numel()); diff --git a/paddle/phi/kernels/math_kernel.h b/paddle/phi/kernels/math_kernel.h index 342393d79bd4d3729afdae45b423b50827a37d61..fe8f3b749cdd8a72836cb05a75bc75ba10066e0e 100644 --- a/paddle/phi/kernels/math_kernel.h +++ b/paddle/phi/kernels/math_kernel.h @@ -109,7 +109,7 @@ template DenseTensor Add(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); ElementwiseInferMeta(x, y, &meta_out); AddKernel(dev_ctx, x, y, &dense_out); @@ -120,7 +120,7 @@ template DenseTensor Subtract(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); ElementwiseInferMeta(x, y, &meta_out); SubtractKernel(dev_ctx, x, y, &dense_out); @@ -131,7 +131,7 @@ template DenseTensor Divide(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); ElementwiseInferMeta(x, y, &meta_out); DivideKernel(dev_ctx, x, y, &dense_out); @@ -142,7 +142,7 @@ template DenseTensor Multiply(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); ElementwiseInferMeta(x, y, &meta_out); MultiplyKernel(dev_ctx, x, y, &dense_out); @@ -154,7 +154,7 @@ DenseTensor Mean(const Context& dev_ctx, const DenseTensor& x, const std::vector& axis, bool keep_dim) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); ReduceInferMetaBase(x, axis, keep_dim, false, x.dtype(), &meta_out); MeanKernel(dev_ctx, x, axis, keep_dim, &dense_out); @@ -167,7 +167,7 @@ DenseTensor Sum(const Context& dev_ctx, const std::vector& axis, DataType dtype, bool keep_dim) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); SumInferMeta(x, axis, dtype, keep_dim, &meta_out); SumKernel(dev_ctx, x, axis, dtype, keep_dim, &dense_out); diff --git a/paddle/phi/kernels/matmul_kernel.h b/paddle/phi/kernels/matmul_kernel.h index 1f1cb22c2717b14339dc0f85be268adb72d75994..b524b9e5863dcbcacaea11df9a96b71570312213 100644 --- a/paddle/phi/kernels/matmul_kernel.h +++ b/paddle/phi/kernels/matmul_kernel.h @@ -35,7 +35,7 @@ DenseTensor Matmul(const Context& dev_ctx, const DenseTensor& y, bool transpose_x = false, bool transpose_y = false) { - auto dense_out = Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); MatmulInferMeta(x, y, transpose_x, transpose_y, &meta_out); MatmulKernel(dev_ctx, x, y, transpose_x, transpose_y, &dense_out); diff --git a/paddle/phi/kernels/reshape_kernel.h b/paddle/phi/kernels/reshape_kernel.h index 1a3d0db8a8a3b8668431ced5b95480f5d4758566..848f162a2a881ddc4d4ea136313216fd569accfd 100644 --- a/paddle/phi/kernels/reshape_kernel.h +++ b/paddle/phi/kernels/reshape_kernel.h @@ -38,7 +38,7 @@ template DenseTensor Reshape(const Context& dev_ctx, const DenseTensor& x, const std::vector& shape) { - auto dense_out = Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); InferMetaFromVecValue(x, shape, &meta_out); ReshapeKernel(dev_ctx, x, ScalarArray(shape), &dense_out); diff --git a/paddle/phi/kernels/scale_kernel.h b/paddle/phi/kernels/scale_kernel.h index 22e6efb03ac2d113180ab6010e3e28736c928373..7537dc1130b83ea3963ae128bf8ce3859411c199 100644 --- a/paddle/phi/kernels/scale_kernel.h +++ b/paddle/phi/kernels/scale_kernel.h @@ -34,7 +34,7 @@ DenseTensor Scale(const Context& dev_ctx, const Scalar& scale, float bias, bool bias_after_scale) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); UnchangedInferMeta(x, &meta_out); ScaleKernel( diff --git a/paddle/phi/kernels/sign_kernel.h b/paddle/phi/kernels/sign_kernel.h index 7ee1145012dbd94d077bc229e0ffb8e833eb52a4..4b5900d90f45daa01c117b9f1649a152734c5b76 100644 --- a/paddle/phi/kernels/sign_kernel.h +++ b/paddle/phi/kernels/sign_kernel.h @@ -25,7 +25,7 @@ void SignKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); template DenseTensor Sign(const Context& dev_ctx, const DenseTensor& x) { - auto dense_out = phi::Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); UnchangedInferMeta(x, &meta_out); SignKernel(dev_ctx, x, &dense_out); diff --git a/paddle/phi/kernels/sparse/convolution_grad_kernel.h b/paddle/phi/kernels/sparse/convolution_grad_kernel.h index 1a6ac852448a5f4a25248d2a2b6919a301a04874..3ada3473355d075188c746e95ebfa35939a094be 100644 --- a/paddle/phi/kernels/sparse/convolution_grad_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_grad_kernel.h @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/sparse/convolution_kernel.h" namespace phi { namespace sparse { @@ -45,6 +47,7 @@ std::vector Conv3dGrad(const Context& dev_ctx, const int groups) { DenseTensor x_grad = phi::Empty(dev_ctx); DenseTensor kernel_grad = phi::Empty(dev_ctx); + // TODO(zhangkaihuo): call InferMeta func here Conv3dGradKernel(dev_ctx, x, rulebook, diff --git a/paddle/phi/kernels/sparse/convolution_kernel.h b/paddle/phi/kernels/sparse/convolution_kernel.h index 71160a6365dc778e40476af960f21443cac698e5..1c1e62c8306c264c363ff90740764942ebd09dc6 100644 --- a/paddle/phi/kernels/sparse/convolution_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_kernel.h @@ -14,11 +14,24 @@ limitations under the License. */ #pragma once +#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" namespace phi { + +template +DenseTensor Empty(const Context& dev_ctx) { + phi::DenseTensor dense_out( + phi::make_intrusive( + dev_ctx.GetPlace()), + {paddle::experimental::CppTypeToDataType::Type(), + {-1}, + DataLayout::NCHW}); + return dense_out; +} + namespace sparse { struct Dims4D { diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index d4f770ce8713aa84c7f87f0e49bf8468467ffdbf..cb6cf435435e6a93543555fc57dde661bb5326d0 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -74,6 +74,7 @@ void Conv3dGradKernel(const Context& dev_ctx, dev_ctx.Alloc( kernel_grad, kernel_grad->dtype(), kernel_grad->numel() * sizeof(T)); T* d_kernel_ptr = kernel_grad->data(); + memset(d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel()); Gather(x.non_zero_elements().data(), rulebook_ptr + rulebook_len, diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_kernel.h index d96d134a26b08a0208122a7ea9a62ce07c033d51..c83b2130ed4550540a98148aec26e42332c8060d 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" @@ -63,8 +64,8 @@ template SparseCooTensor DenseToSparseCoo(const Context& dev_ctx, const DenseTensor& x, const int64_t sparse_dim) { - DenseTensor indices = phi::Empty(dev_ctx); - DenseTensor values = phi::Empty(dev_ctx); + DenseTensor indices; + DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); DenseToSparseCooKernel(dev_ctx, x, sparse_dim, &coo); return coo; @@ -78,8 +79,8 @@ void SparseCsrToCooKernel(const Context& dev_ctx, template SparseCooTensor SparseCsrToCoo(const Context& dev_ctx, const SparseCsrTensor& x) { - DenseTensor indices = phi::Empty(dev_ctx); - DenseTensor values = phi::Empty(dev_ctx); + DenseTensor indices; + DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); SparseCsrToCooKernel(dev_ctx, x, &coo); return coo; @@ -93,9 +94,9 @@ void SparseCooToCsrKernel(const Context& dev_ctx, template SparseCsrTensor SparseCooToCsr(const Context& dev_ctx, const SparseCooTensor& x) { - DenseTensor non_zero_crows = phi::Empty(dev_ctx); - DenseTensor non_zero_cols = phi::Empty(dev_ctx); - DenseTensor non_zero_elements = phi::Empty(dev_ctx); + DenseTensor non_zero_crows; + DenseTensor non_zero_cols; + DenseTensor non_zero_elements; SparseCsrTensor csr( non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); SparseCooToCsrKernel(dev_ctx, x, &csr); @@ -113,8 +114,8 @@ void DenseToSparseCsrKernel(const Context& dev_ctx, phi::errors::InvalidArgument( "SparseCsrTensor only support 2-D or 3-D Tensor.")); const int64_t sparse_dim = x_dims.size() == 2 ? 2 : 3; - DenseTensor indices = phi::Empty(dev_ctx); - DenseTensor values = phi::Empty(dev_ctx); + DenseTensor indices; + DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); DenseToSparseCooKernel(dev_ctx, x, sparse_dim, &coo); SparseCooToCsrKernel(dev_ctx, coo, out); @@ -122,9 +123,9 @@ void DenseToSparseCsrKernel(const Context& dev_ctx, template SparseCsrTensor DenseToSparseCsr(const Context& dev_ctx, const DenseTensor& x) { - DenseTensor non_zero_crows = phi::Empty(dev_ctx); - DenseTensor non_zero_cols = phi::Empty(dev_ctx); - DenseTensor non_zero_elements = phi::Empty(dev_ctx); + DenseTensor non_zero_crows; + DenseTensor non_zero_cols; + DenseTensor non_zero_elements; SparseCsrTensor csr( non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); DenseToSparseCsrKernel(dev_ctx, x, &csr); @@ -148,8 +149,8 @@ template void SparseCsrToDenseKernel(const Context& dev_ctx, const SparseCsrTensor& x, DenseTensor* out) { - DenseTensor indices = phi::Empty(dev_ctx); - DenseTensor values = phi::Empty(dev_ctx); + DenseTensor indices; + DenseTensor values; SparseCooTensor coo(indices, values, x.dims()); SparseCsrToCooKernel(dev_ctx, x, &coo); SparseCooToDenseKernel(dev_ctx, coo, out); diff --git a/paddle/phi/kernels/split_kernel.h b/paddle/phi/kernels/split_kernel.h index 840fe4366ce7eaca82608612dfb41cc7f7783f4c..e42b25e60c42291b3949c144c4aba6f62ec98a44 100644 --- a/paddle/phi/kernels/split_kernel.h +++ b/paddle/phi/kernels/split_kernel.h @@ -50,7 +50,7 @@ std::vector Split(const Context& dev_ctx, result.reserve(out_number); for (size_t i = 0; i < out_number; ++i) { - result.emplace_back(phi::Empty(dev_ctx)); + result.emplace_back(DenseTensor()); out_meta.emplace_back(&result.back()); out_meta_ptr.push_back(&out_meta.back()); } diff --git a/paddle/phi/kernels/transpose_kernel.h b/paddle/phi/kernels/transpose_kernel.h index 3d89b324bab5b08490457183b7aa31fd4704744b..b8d7fbaa2757d76db1005ce57498d181046d77c9 100644 --- a/paddle/phi/kernels/transpose_kernel.h +++ b/paddle/phi/kernels/transpose_kernel.h @@ -32,7 +32,7 @@ template DenseTensor Transpose(const Context& dev_ctx, const DenseTensor& x, const std::vector& axis) { - auto dense_out = Empty(dev_ctx); + DenseTensor dense_out; MetaTensor meta_out(&dense_out); TransposeInferMeta(x, axis, &meta_out); TransposeKernel(dev_ctx, x, axis, &dense_out); diff --git a/paddle/phi/tests/api/scale_api.h b/paddle/phi/tests/api/scale_api.h index d93f00129b9a14170b979dfd23eb6e292e996ce8..6b9bb7aecefe6fabb2334e6f4d150a3937628f34 100644 --- a/paddle/phi/tests/api/scale_api.h +++ b/paddle/phi/tests/api/scale_api.h @@ -20,6 +20,7 @@ #include "paddle/phi/api/lib/api_registry.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar_array.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc index 3e2ad0495f3ba85836dc08afa3f4fa4ed0b10afd..b8f214b79e290c2e102fc2c08dab2ddc6a61dd71 100644 --- a/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc @@ -90,6 +90,10 @@ void TestDenseToSparseCoo(const DenseTensor& dense_x, phi::CPUContext dev_ctx_cpu; dev_ctx_cpu.Init(); + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(phi::CPUPlace()) + .get()); // 1. test cpu auto cpu_sparse_out = @@ -300,6 +304,11 @@ void TestSparseCsrToCoo(const DDim& dense_dims, // 1. test cpu phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.Init(); + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(phi::CPUPlace()) + .get()); auto cpu_sparse_out = sparse::SparseCsrToCoo(dev_ctx_cpu, csr); CheckResult(&dev_ctx_cpu, cpu_sparse_out, @@ -473,6 +482,11 @@ void TestCooToCsr(const DDim& dense_dims, // 1. test cpu phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.Init(); + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(phi::CPUPlace()) + .get()); auto cpu_sparse_out = sparse::SparseCooToCsr(dev_ctx_cpu, coo); CheckCsrResult(&dev_ctx_cpu, cpu_sparse_out, @@ -563,6 +577,11 @@ void TestDenseToSparseCsr(const DenseTensor& dense_x, const auto alloc = std::make_shared( paddle::platform::CPUPlace()); phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.Init(); + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(phi::CPUPlace()) + .get()); // 1. test cpu auto cpu_sparse_out = sparse::DenseToSparseCsr(dev_ctx_cpu, dense_x); @@ -667,6 +686,11 @@ void TestSparseCooToDense(const DDim& dense_dims, const int64_t non_zero_num, const int64_t sparse_dim) { phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.Init(); + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(phi::CPUPlace()) + .get()); const auto alloc = std::make_shared( paddle::platform::CPUPlace()); @@ -836,6 +860,11 @@ void TestSparseCsrToDense(const DDim& dense_dims, // 1. test cpu phi::CPUContext dev_ctx_cpu; + dev_ctx_cpu.Init(); + dev_ctx_cpu.SetAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(phi::CPUPlace()) + .get()); DenseTensor cpu_sparse_out = sparse::SparseCsrToDense(dev_ctx_cpu, csr); int cmp_cpu = memcmp(cpu_sparse_out.data(), dense_data.data(),