diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc
index 3b7d5fb4d8cd972b1bcf210739c684780d4251a0..8e94a04ab161be1dd21db6775a2b978fe76aa9ed 100644
--- a/paddle/fluid/framework/data_layout_transform.cc
+++ b/paddle/fluid/framework/data_layout_transform.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/data_layout_transform.h"
+#include "paddle/fluid/framework/op_kernel_type.h"
 #include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
diff --git a/paddle/fluid/imperative/tests/test_group.cc b/paddle/fluid/imperative/tests/test_group.cc
index fef8c346f4b5b4ac852d52dbf888f10d59c63d77..d5f09868b899725cfac9bd961e3899ce37e586fc 100644
--- a/paddle/fluid/imperative/tests/test_group.cc
+++ b/paddle/fluid/imperative/tests/test_group.cc
@@ -18,6 +18,10 @@
 #include "gtest/gtest.h"
 #include "paddle/fluid/imperative/reducer.h"
 
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/phi/core/utils/data_type.h"
+
 namespace paddle {
 namespace imperative {
diff --git a/paddle/fluid/operators/detection/anchor_generator_op.h b/paddle/fluid/operators/detection/anchor_generator_op.h
index 70194a0abcbb27ae945703a46e42595e80d3039b..726b65fb1f427ce19cb75f9495bf94f6cfb237a0 100644
--- a/paddle/fluid/operators/detection/anchor_generator_op.h
+++ b/paddle/fluid/operators/detection/anchor_generator_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/common/transform.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/detection/prior_box_op.h b/paddle/fluid/operators/detection/prior_box_op.h
index 4c5249ec56fce69f3a6659ebcba3a07101fd96f4..b49841399c71f9139e70cb114e02591366a16f0e 100644
--- a/paddle/fluid/operators/detection/prior_box_op.h
+++ b/paddle/fluid/operators/detection/prior_box_op.h
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/index_select_op.h b/paddle/fluid/operators/index_select_op.h
index 6bb91f325f9535fca8d8b354230045c9ff509100..ad1542666fd39d8c853ef8e941974386c53d4bd3 100644
--- a/paddle/fluid/operators/index_select_op.h
+++ b/paddle/fluid/operators/index_select_op.h
@@ -17,6 +17,7 @@
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/interpolate_op.h b/paddle/fluid/operators/interpolate_op.h
index ad67efc4b78d55cf94beed770a6f954cf1a03d0a..12378a5f1f1d01ff346d15ebcb5a7071790ba66d 100644
--- a/paddle/fluid/operators/interpolate_op.h
+++ b/paddle/fluid/operators/interpolate_op.h
@@ -16,6 +16,7 @@
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h
index 5596a9fefed1be2d9b83da358f5951c9a25aff08..00ff1fbcbc38dba637fa13c9d262bbdc8a24452c 100644
--- a/paddle/fluid/operators/math/prelu.h
+++ b/paddle/fluid/operators/math/prelu.h
@@ -16,6 +16,7 @@ limitations under the License. */
 #include
 
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index eed7b64a3c83255e3404498508791a90de299602..8dbeff2bce1350221ac4f6b326a6de0fab188df5 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include
 
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/jit/kernels.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
diff --git a/paddle/fluid/operators/math/tree2col.cu b/paddle/fluid/operators/math/tree2col.cu
index 22bdc48768dae5626c22dd4bfda019804b3c1e17..abaf5d3f3bbf95c23f267b284453b6c6b19c3b44 100644
--- a/paddle/fluid/operators/math/tree2col.cu
+++ b/paddle/fluid/operators/math/tree2col.cu
@@ -14,6 +14,7 @@
 
 #include
 
+#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/math/tree2col.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
diff --git a/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu b/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
index a724524716be39e554c6046ca809624b7fbb053a..b94a78f898f3e76ea74f4b5d5b16ed4af9749419 100644
--- a/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
+++ b/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
+#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
 
 namespace paddle {
@@ -20,7 +21,7 @@ namespace operators {
 namespace details {
 
 TEST(test_reduce_rank_check, all) {
-  using EnforceNotMet = paddle::platform::EnforceNotMet;
+  using EnforceNotMet = phi::EnforceNotMet;
   constexpr int kMaxRank = framework::DDim::kMaxRank;
 
   for (int rank = 0; rank < kMaxRank; rank++) {
@@ -42,7 +43,7 @@ TEST(test_reduce_rank_check, all) {
         phi::funcs::details::CheckReduceRank(reduce_rank, rank);
       } else {
         ASSERT_THROW(phi::funcs::details::CheckReduceRank(reduce_rank, rank),
-                     paddle::platform::EnforceNotMet);
+                     EnforceNotMet);
       }
     }
   }
diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.h b/paddle/fluid/operators/sequence_ops/sequence_expand_op.h
index 7a7a6f7b3e7148fca4ed50d8f017158147de4815..9270b97cfc36d7796688fc59eeb9dd8cb1d4cf72 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.h
@@ -17,6 +17,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
diff --git a/paddle/phi/core/utils/visit_place.h b/paddle/phi/core/utils/visit_place.h
new file mode 100644
index 0000000000000000000000000000000000000000..e2e2ffec1bfee5f1d2ac7c55ce6fef134b17f5f6
--- /dev/null
+++ b/paddle/phi/core/utils/visit_place.h
@@ -0,0 +1,112 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/enforce.h"
+
+namespace phi {
+
+// need add dependency to phi_place when use phi::VisitPlace
+template <typename Visitor>
+typename Visitor::result_type VisitPlace(const phi::Place& place,
+                                         const Visitor& visitor) {
+  switch (place.GetType()) {
+    case phi::AllocationType::GPU: {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      phi::GPUPlace p(place.GetDeviceId());
+      return visitor(p);
+#else
+      PADDLE_THROW(phi::errors::Unavailable(
+          ("Paddle is not compiled with CUDA. Cannot visit cuda_pinned")));
+      return typename Visitor::result_type();
+#endif
+    }
+    case phi::AllocationType::GPUPINNED: {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      phi::GPUPinnedPlace p;
+      return visitor(p);
+#else
+      PADDLE_THROW(phi::errors::Unavailable(
+          ("Paddle is not compiled with CUDA. Cannot visit cuda_pinned")));
+      return typename Visitor::result_type();
+#endif
+    }
+    case phi::AllocationType::XPU: {
+#ifdef PADDLE_WITH_XPU
+      phi::XPUPlace p(place.GetDeviceId());
+      return visitor(p);
+#else
+      PADDLE_THROW(phi::errors::Unavailable(
+          ("Paddle is not compiled with XPU. Cannot visit xpu device")));
+      return typename Visitor::result_type();
+#endif
+    }
+    case phi::AllocationType::NPU: {
+#ifdef PADDLE_WITH_ASCEND_CL
+      phi::NPUPlace p(place.GetDeviceId());
+      return visitor(p);
+#else
+      PADDLE_THROW(phi::errors::Unavailable(
+          ("Paddle is not compiled with NPU. Cannot visit npu_pinned")));
+      return typename Visitor::result_type();
+#endif
+    }
+    case phi::AllocationType::NPUPINNED: {
+#ifdef PADDLE_WITH_ASCEND_CL
+      phi::NPUPinnedPlace p;
+      return visitor(p);
+#else
+      PADDLE_THROW(phi::errors::Unavailable(
+          ("Paddle is not compiled with NPU. Cannot visit npu_pinned")));
+      return typename Visitor::result_type();
+#endif
+    }
+    case phi::AllocationType::IPU: {
+#ifdef PADDLE_WITH_IPU
+      phi::IPUPlace p(place.GetDeviceId());
+      return visitor(p);
+#else
+      PADDLE_THROW(phi::errors::Unavailable(
+          ("Paddle is not compiled with IPU. Cannot visit ipu device")));
+      return typename Visitor::result_type();
+#endif
+    }
+    case phi::AllocationType::MLU: {
+#ifdef PADDLE_WITH_MLU
+      phi::MLUPlace p(place.GetDeviceId());
+      return visitor(p);
+#else
+      PADDLE_THROW(phi::errors::Unavailable(
+          ("Paddle is not compiled with MLU. Cannot visit mlu device")));
+#endif
+    }
+    case phi::AllocationType::CUSTOM: {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+      phi::CustomPlace p(place.GetDeviceType(), place.GetDeviceId());
+      return visitor(p);
+#else
+      PADDLE_THROW(phi::errors::Unavailable(
+          ("Paddle is not compiled with CUSTOM. Cannot visit custom device")));
+#endif
+    }
+    default: {
+      phi::CPUPlace p;
+      return visitor(p);
+    }
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc
index 521d620eb7380e2a3e78a88a354e10671ce0a39b..75875f81beee8bf82a5429b5214344f9317a911d 100644
--- a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc
@@ -45,29 +45,25 @@ void RepeatInterleaveWithTensorIndexGradKernel(
                           repeats_tensor.dims()[0],
                           x_grad->dims()[dim]));
 
-    const auto& index_type =
-        paddle::framework::TransToProtoVarType(repeats_tensor.dtype());
+    const auto& index_type = repeats_tensor.dtype();
     bool index_type_match =
-        index_type == paddle::framework::proto::VarType::INT32 ||
-        index_type == paddle::framework::proto::VarType::INT64;
+        index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
     PADDLE_ENFORCE_EQ(index_type_match,
                       true,
                       phi::errors::InvalidArgument(
                           "Input(Repeats) holds the wrong type, it holds %s, but "
                           "desires to be %s or %s",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              paddle::framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              paddle::framework::proto::VarType::INT64)));
+                          phi::DataTypeToString(index_type),
+                          phi::DataTypeToString(phi::DataType::INT32),
+                          phi::DataTypeToString(phi::DataType::INT64)));
 
     phi::DeviceContextPool::Instance().Get(repeats_tensor.place());
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == phi::DataType::INT32) {
       phi::funcs::RepeatsTensor2IndexTensor(
           ctx, repeats_tensor, &index);
       IndexSelectGradInner(ctx, out_grad, index, x_grad, dim);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == phi::DataType::INT64) {
      phi::funcs::RepeatsTensor2IndexTensor(
           ctx, repeats_tensor, &index);
       IndexSelectGradInner(
diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
index cfdccb5c8d9bacbdd65bdf1b56b78d6e3d1219f6..175b4a750a82030518279a44e97a516cce70c004 100644
--- a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
+++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
@@ -45,8 +45,7 @@ struct EmbeddingCPUSparseFunctor {
     int64_t row_width = table_t.value().dims()[1];
     const auto* table = table_t.value().template data();
     auto* output = dev_ctx_.template Alloc(output_t);
-    auto input_data_type =
-        paddle::framework::TransToProtoVarType(table_t.value().dtype());
+    auto input_data_type = table_t.value().dtype();
 
     for (int64_t i = 0; i < ids_numel; ++i) {
       if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
@@ -66,7 +65,7 @@ struct EmbeddingCPUSparseFunctor {
             phi::errors::InvalidArgument(
                 "the input key should be exists. But received %d.", id_index));
 
-        if (input_data_type == paddle::framework::proto::VarType::BF16) {
+        if (input_data_type == phi::DataType::BFLOAT16) {
           memcpy(output + i * row_width,
                  table + id_index * row_width,
                  row_width * sizeof(T));
diff --git a/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
index acfc0d4c29d092de0f68876e4b87dceb9b9096fa..8a7238203ec6476e2222b1df9d087e7c67b36a16 100644
--- a/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/take_along_axis_grad_kernel.cc
@@ -43,16 +43,15 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
   phi::funcs::SetConstant functor;
   functor(dev_ctx, x_grad, static_cast(0));
 
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(index.dtype());
-  if (index_type == paddle::framework::proto::VarType::INT32) {
+  const auto& index_type = index.dtype();
+  if (index_type == phi::DataType::INT32) {
     phi::funcs::cpu_scatter_add_kernel(
         *x_grad,
         axis,
         index,
         out_grad,
         dev_ctx);  // the gradient of gather is scatter
-  } else if (index_type == paddle::framework::proto::VarType::INT64) {
+  } else if (index_type == phi::DataType::INT64) {
     phi::funcs::cpu_scatter_add_kernel(
         *x_grad, axis, index, out_grad, dev_ctx);
   }
diff --git a/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc b/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc
index 560578ed22843e35c3404969326ef744901bace0..f2e0574991407f03559700204c9a3a095d75b375 100644
--- a/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc
+++ b/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc
@@ -12,8 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/unique_consecutive_kernel.h"
+#include
+
 #include "paddle/phi/kernels/cpu/unique_consecutive_functor.h"
+#include "paddle/phi/kernels/unique_consecutive_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/errors.h"
diff --git a/paddle/phi/kernels/cpu/unique_kernel.cc b/paddle/phi/kernels/cpu/unique_kernel.cc
index 15c19b24444dd7927c8c27d31ae2cd37a131d8c4..3b742fbd1dfd1e9339a70e0091c5851f7060903f 100644
--- a/paddle/phi/kernels/cpu/unique_kernel.cc
+++ b/paddle/phi/kernels/cpu/unique_kernel.cc
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include
+
 #include "paddle/phi/kernels/unique_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc
index 8e051623c4da57bc17cbe5c846937dbf9971fd57..e8bd17efc7d24f175e297a444c1ecd9a5fb16b6c 100644
--- a/paddle/phi/kernels/funcs/math_function.cc
+++ b/paddle/phi/kernels/funcs/math_function.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/core/utils/visit_place.h"
 
 #ifdef PADDLE_WITH_MKLML
 #include "paddle/phi/backends/dynload/mklml.h"
@@ -236,7 +237,7 @@ void set_constant(const phi::DeviceContext& context,
 #endif
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   // tensor->place().apply_visitor(func);
-  paddle::platform::VisitPlace(tensor->place(), func);
+  phi::VisitPlace(tensor->place(), func);
 #elif defined(PADDLE_WITH_XPU)
   func(phi::XPUPlace());
 #else
diff --git a/paddle/phi/kernels/funcs/math_function.h b/paddle/phi/kernels/funcs/math_function.h
index 7e69402c350f0ec2bb2b3fd98c7c19486754e848..d2de413dad51b90d82e00dc4e09fdc0733110655 100644
--- a/paddle/phi/kernels/funcs/math_function.h
+++ b/paddle/phi/kernels/funcs/math_function.h
@@ -17,12 +17,10 @@ limitations under the License. */
 #include
 #include
 
-#include "paddle/fluid/framework/operator.h"
 #include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/utils/data_type.h"
-#include "paddle/phi/kernels/funcs/eigen/common.h"
 
 namespace phi {
 namespace funcs {
diff --git a/paddle/phi/kernels/funcs/math_function_impl.h b/paddle/phi/kernels/funcs/math_function_impl.h
index 4e540a19d6c0a75d84680937c0e4a1ce2e764705..ed8e0669ab74f7cbc20a2aedaa9f786d3e61620d 100644
--- a/paddle/phi/kernels/funcs/math_function_impl.h
+++ b/paddle/phi/kernels/funcs/math_function_impl.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include
 
 #include "paddle/phi/common/data_type.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
diff --git a/paddle/phi/kernels/funcs/segment_pooling.cu b/paddle/phi/kernels/funcs/segment_pooling.cu
index f776a5d1905064e5110833f81f058d08f3e23376..2624b5850e1b2a94bf9e7b6870d07b02d68e7da8 100644
--- a/paddle/phi/kernels/funcs/segment_pooling.cu
+++ b/paddle/phi/kernels/funcs/segment_pooling.cu
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/segment_pooling.h"
diff --git a/paddle/phi/kernels/funcs/unique_functor.h b/paddle/phi/kernels/funcs/unique_functor.h
index 510236e278d896583e7d4feb4d5bc0ea38d98e62..913ee1afb9f4c7c0430b045658c045ccde43b63b 100644
--- a/paddle/phi/kernels/funcs/unique_functor.h
+++ b/paddle/phi/kernels/funcs/unique_functor.h
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #pragma once
+#include
+
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
diff --git a/paddle/phi/kernels/gpu/bincount_kernel.cu b/paddle/phi/kernels/gpu/bincount_kernel.cu
index 1308d435bba4b908259e500775994671252ef93f..b1000dac6f72ab98b37534e60c3ac0a5ab540423 100644
--- a/paddle/phi/kernels/gpu/bincount_kernel.cu
+++ b/paddle/phi/kernels/gpu/bincount_kernel.cu
@@ -109,9 +109,6 @@ void BincountCUDAInner(const Context& dev_ctx,
         <<>>(
             input_data, input_numel, has_weights, weights_data, output_data);
   } else {
-    const auto& weights_type =
-        paddle::framework::TransToProtoVarType(weights->dtype());
-
     if (weights->dtype() == DataType::FLOAT32) {
       float* output_data = dev_ctx.template Alloc(output);
       phi::funcs::SetConstant()(
diff --git a/paddle/phi/kernels/gpu/class_center_sample_kernel.cu b/paddle/phi/kernels/gpu/class_center_sample_kernel.cu
index da5624e2d9d7e15159e65d8a32ee0ba23963e157..f63baadbde526e141ae8454588406a1b752c5cef 100644
--- a/paddle/phi/kernels/gpu/class_center_sample_kernel.cu
+++ b/paddle/phi/kernels/gpu/class_center_sample_kernel.cu
@@ -375,9 +375,7 @@ void ClassCenterSampleKernel(const Context& dev_ctx,
           num_classes_per_device_ptr,
           num_classes_per_device_ptr,
           num_classes_per_device.numel(),
-          paddle::platform::ToNCCLDataType(
-              paddle::framework::TransToProtoVarType(
-                  num_classes_per_device.dtype())),
+          phi::ToNCCLDataType(num_classes_per_device.dtype()),
           ncclSum,
           comm->comm(),
           calcu_stream));
diff --git a/paddle/phi/kernels/gpu/edit_distance_kernel.cu b/paddle/phi/kernels/gpu/edit_distance_kernel.cu
index d4d8433fdc0bba6c81698ede802dc6c63219ab9f..cb5b096ba3f78d341b9e66125e5689bbc697552a 100644
--- a/paddle/phi/kernels/gpu/edit_distance_kernel.cu
+++ b/paddle/phi/kernels/gpu/edit_distance_kernel.cu
@@ -21,6 +21,7 @@
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/mixed_vector.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
diff --git a/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
index 87faf0aad58479dc53bb54b5a80ff4755abae8fa..51f5e28b032f8e317bdda4e1575b934e9db7eaf3 100644
--- a/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
@@ -96,8 +96,7 @@ void GetClassInterval(const gpuStream_t& stream,
         num_classes_per_device_ptr,
         num_classes_per_device_ptr,
         num_classes_per_device.numel(),
-        paddle::platform::ToNCCLDataType(paddle::framework::TransToProtoVarType(
-            num_classes_per_device.dtype())),
+        phi::ToNCCLDataType(num_classes_per_device.dtype()),
         ncclSum,
         comm->comm(),
         calcu_stream));
@@ -188,8 +187,7 @@ void MarginCrossEntropyGradKernel(const Context& dev_ctx,
   int blocks = NumBlocks(N * D);
   int threads = kNumCUDAThreads;
 
-  const auto& label_type =
-      paddle::framework::TransToProtoVarType(label.dtype());
+  const auto& label_type = label.dtype();
 
   DenseTensor class_interval;
   GetClassInterval(dev_ctx.stream(),
@@ -201,7 +199,7 @@ void MarginCrossEntropyGradKernel(const Context& dev_ctx,
                    D,
                    &class_interval);
 
-  if (label_type == paddle::framework::proto::VarType::INT32) {
+  if (label_type == phi::DataType::INT32) {
     typedef int32_t LabelT;
     CalculateGrad
         <<>>(logits_grad->data(),
@@ -215,7 +213,7 @@ void MarginCrossEntropyGradKernel(const Context& dev_ctx,
                       N,
                       D,
                       class_interval.data());
-  } else if (label_type == paddle::framework::proto::VarType::INT64) {
+  } else if (label_type == phi::DataType::INT64) {
     typedef int64_t LabelT;
     CalculateGrad
         <<>>(logits_grad->data(),
diff --git a/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
index 0bc442058acedf68db6894d66efdab2f5e71d542..5cbb21c45b76d8f42372bbf125de6d725fce9a95 100644
--- a/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
+++ b/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
@@ -92,8 +92,7 @@ void GetClassInterval(const gpuStream_t& stream,
         num_classes_per_device_ptr,
         num_classes_per_device_ptr,
         num_classes_per_device.numel(),
-        paddle::platform::ToNCCLDataType(paddle::framework::TransToProtoVarType(
-            num_classes_per_device.dtype())),
+        phi::ToNCCLDataType(num_classes_per_device.dtype()),
        ncclSum,
         comm->comm(),
         calcu_stream));
@@ -265,8 +264,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
   int blocks = NumBlocks(N);
   int threads = kNumCUDAThreads;
 
-  const auto& label_type =
-      paddle::framework::TransToProtoVarType(labels.dtype());
+  const auto& label_type = labels.dtype();
 
   // copy logits to softmax variable since we can't modify logits,
  // and it also be used when calculate grad
@@ -291,7 +289,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
   // theta = acos(x_i)
   // (cos(m1 * theta + m2) - m3)
   // save match_logits, used for gradient computation.
-  if (label_type == paddle::framework::proto::VarType::INT32) {
+  if (label_type == phi::DataType::INT32) {
     typedef int32_t LabelT;
     AddMarginToPositiveLogitsKernel
         <<>>(
@@ -305,7 +303,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
             N,
             D,
             class_interval.data());
-  } else if (label_type == paddle::framework::proto::VarType::INT64) {
+  } else if (label_type == phi::DataType::INT64) {
     typedef int64_t LabelT;
     AddMarginToPositiveLogitsKernel
         <<>>(
@@ -357,15 +355,14 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
       auto task = pg->AllReduce(in_tensor, out_tensor, opts);
       task->Wait();
     } else {
-      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
-          logits_max_buff,
-          logits_max_buff,
-          logits_max.numel(),
-          paddle::platform::ToNCCLDataType(
-              paddle::framework::TransToProtoVarType(logits_max.dtype())),
-          ncclMax,
-          comm->comm(),
-          stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          phi::dynload::ncclAllReduce(logits_max_buff,
+                                      logits_max_buff,
+                                      logits_max.numel(),
+                                      phi::ToNCCLDataType(logits_max.dtype()),
+                                      ncclMax,
+                                      comm->comm(),
+                                      stream));
     }
   }
 #endif
@@ -403,8 +400,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
           sum_exp_logits_buff,
           sum_exp_logits_buff,
           sum_exp_logits.numel(),
-          paddle::platform::ToNCCLDataType(
-              paddle::framework::TransToProtoVarType(sum_exp_logits.dtype())),
+          phi::ToNCCLDataType(sum_exp_logits.dtype()),
           ncclSum,
           comm->comm(),
           stream));
@@ -423,7 +419,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
   phi::funcs::SetConstant functor;
   functor(dev_ctx, loss, static_cast(0.0));
 
-  if (label_type == paddle::framework::proto::VarType::INT32) {
+  if (label_type == phi::DataType::INT32) {
     typedef int32_t LabelT;
     HardLabelSoftmaxWithCrossEntropyKernel
         <<>>(loss_ptr,
@@ -433,7 +429,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
                              N,
                              D,
                              class_interval.data());
-  } else if (label_type == paddle::framework::proto::VarType::INT64) {
+  } else if (label_type == phi::DataType::INT64) {
     typedef int64_t LabelT;
     HardLabelSoftmaxWithCrossEntropyKernel
         <<>>(loss_ptr,
@@ -458,15 +454,14 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
       auto task = pg->AllReduce(in_tensor, out_tensor, opts);
       task->Wait();
     } else {
-      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
-          loss_ptr,
-          loss_ptr,
-          loss->numel(),
-          paddle::platform::ToNCCLDataType(
-              paddle::framework::TransToProtoVarType(loss->dtype())),
-          ncclSum,
-          comm->comm(),
-          stream));
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          phi::dynload::ncclAllReduce(loss_ptr,
+                                      loss_ptr,
+                                      loss->numel(),
+                                      phi::ToNCCLDataType(loss->dtype()),
+                                      ncclSum,
+                                      comm->comm(),
+                                      stream));
     }
   }
 #endif
diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
index 448004fc4b89efc262179cae41c6383e9cb1237f..19b9f5845bf76aeb6805cd5d7cf38c74aad8c4c0 100644
--- a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
@@ -108,8 +108,7 @@ void SyncBatchNormKernel(const Context &ctx,
   }
 
   if (comm) {
-    int dtype = paddle::platform::ToNCCLDataType(
-        paddle::framework::TransToProtoVarType(mean_out->dtype()));
+    int dtype = phi::ToNCCLDataType(mean_out->dtype());
     // In-place operation
     PADDLE_ENFORCE_GPU_SUCCESS(
         phi::dynload::ncclAllReduce(stats,
diff --git a/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h b/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h
index d8a65afaf2cfd7164806d8be44c62fe41d5ad824..ff413c7b61a292f6ad950123c456a2fc14d74fce 100644
--- a/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h
+++ b/paddle/phi/kernels/impl/repeat_interleave_kernel_impl.h
@@ -131,32 +131,28 @@ void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
                         "But received: [%s], required: [%d].",
                         repeats_tensor.dims()[0],
                         x.dims()[dim]));
 
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(repeats_tensor.dtype());
+  const auto& index_type = repeats_tensor.dtype();
   bool index_type_match =
-      index_type == paddle::framework::proto::VarType::INT32 ||
-      index_type == paddle::framework::proto::VarType::INT64;
+      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
   PADDLE_ENFORCE_EQ(
       index_type_match,
       true,
       phi::errors::InvalidArgument(
          "Input(RepeatsTensor) holds the wrong type, it holds %s, but "
          "desires to be %s or %s",
-          paddle::framework::DataTypeToString(index_type),
-          paddle::framework::DataTypeToString(
-              paddle::framework::proto::VarType::INT32),
-          paddle::framework::DataTypeToString(
-              paddle::framework::proto::VarType::INT64)));
+          phi::DataTypeToString(index_type),
+          phi::DataTypeToString(phi::DataType::INT32),
+          phi::DataTypeToString(phi::DataType::INT64)));
 
   if (place == cpu_place) {
     auto x_copy = x;
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == phi::DataType::INT32) {
      phi::funcs::RepeatsTensor2IndexTensor(
           ctx, repeats_tensor, &index);
       auto output_dim = phi::vectorize(x.dims());
       output_dim[dim] = index.dims()[0];
       out->Resize(phi::make_ddim(output_dim));
       IndexSelectInner(ctx, &x_copy, index, out, dim);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == phi::DataType::INT64) {
       phi::funcs::RepeatsTensor2IndexTensor(
           ctx, repeats_tensor, &index);
       auto output_dim = phi::vectorize(x.dims());
@@ -170,7 +166,7 @@ void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
     int64_t stride = stride_dim[dim];
     auto stream = ctx.stream();
     auto* in_data = x.data();
-    if (index_type == paddle::framework::proto::VarType::INT64) {
+    if (index_type == phi::DataType::INT64) {
       phi::funcs::RepeatsTensor2IndexTensor(
           ctx, repeats_tensor, &index);
diff --git a/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h b/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h
index de930734be6122245073c8f283281dda284b45ad..02e5323c5b6c00ab21f0e604b6d053f30b3039fb 100644
--- a/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h
@@ -188,21 +188,21 @@ void SetValueGradImpl(const Context& dev_ctx,
               (value_grad_dims_size + decrease_axis_size - num_decrease));
 
           fake_value_grad_dims[i] = value_grad_dims[index_grad];
 
-          PADDLE_ENFORCE_EQ((out_dims[i] == value_grad_dims[index_grad]) ||
-                                (value_grad_dims[index_grad] == 1),
-                            true,
-                            errors::InvalidArgument(
-                                "An error occurred while calculating %s: "
-                                "[%s] can not be accumulated into [%s].",
-                                paddle::framework::GradVarName("ValueTensor"),
-                                out_dims,
-                                value_grad_dims));
+          PADDLE_ENFORCE_EQ(
+              (out_dims[i] == value_grad_dims[index_grad]) ||
+                  (value_grad_dims[index_grad] == 1),
+              true,
+              errors::InvalidArgument("An error occurred while calculating %s: "
+                                      "[%s] can not be accumulated into [%s].",
+                                      "ValueTensor@GRAD",
+                                      out_dims,
+                                      value_grad_dims));
         }
       }
 
       VLOG(3) << "Dimensions of "
-              << paddle::framework::GradVarName("ValueTensor") << "(["
-              << value_grad_dims << "])is broadcasted into ["
+              << "ValueTensor@GRAD"
+              << "([" << value_grad_dims << "])is broadcasted into ["
               << fake_value_grad_dims << "].";
 
       auto extent = Eigen::DSizes();
diff --git a/paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc b/paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc
index 82efcd2959c24cf7ca2e236cdcbd6e10453705a1..e23b1052d1844f5486a0292c910bf448e433c3f0 100644
--- a/paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc
+++ b/paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc
@@ -32,11 +32,11 @@ static void Sort(const XPUContext& dev_ctx,
   scores_slice_cpu.Resize({value.numel()});
   T* scores_slice_cpu_data =
       dev_ctx.template HostAlloc(&scores_slice_cpu);
-  paddle::memory::Copy(cpu_place,
-                       scores_slice_cpu_data,
-                       place,
-                       value_data,
-                       sizeof(T) * value.numel());
+  memory_utils::Copy(cpu_place,
+                     scores_slice_cpu_data,
+                     place,
+                     value_data,
+                     sizeof(T) * value.numel());
   // Sort index
   DenseTensor index_t;
   index_t.Resize({value.numel()});
@@ -52,7 +52,7 @@ static void Sort(const XPUContext& dev_ctx,
   std::sort(index, index + value.numel(), compare);
   index_out->Resize({index_t.numel()});
   int* idx_out = dev_ctx.template Alloc(index_out);
-  paddle::memory::Copy(
+  memory_utils::Copy(
       place, idx_out, cpu_place, index, sizeof(T) * index_t.numel());
 }
diff --git a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
index 26ba5e9308720dd68bfdc5a2c9248202a2bdf4bd..affc6b0fe94f75b7b8e1b275e392382a99d3ede2 100644
--- a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc
@@ -222,21 +222,21 @@ void SetValueGradImpl(const Context& dev_ctx,
               (value_grad_dims_size + decrease_axis_size - num_decrease));
 
           fake_value_grad_dims[i] = value_grad_dims[index_grad];
 
-          PADDLE_ENFORCE_EQ((out_dims[i] == value_grad_dims[index_grad]) ||
-                                (value_grad_dims[index_grad] == 1),
-                            true,
-                            errors::InvalidArgument(
-                                "An error occurred while calculating %s: "
-                                "[%s] can not be accumulated into [%s].",
-                                paddle::framework::GradVarName("ValueTensor"),
-                                out_dims,
-                                value_grad_dims));
+          PADDLE_ENFORCE_EQ(
+              (out_dims[i] == value_grad_dims[index_grad]) ||
+                  (value_grad_dims[index_grad] == 1),
+              true,
+              errors::InvalidArgument("An error occurred while calculating %s: "
+                                      "[%s] can not be accumulated into [%s].",
+                                      "ValueTensor@GRAD",
+                                      out_dims,
+                                      value_grad_dims));
        }
      }
 
      VLOG(3) << "Dimensions of "
-             << paddle::framework::GradVarName("ValueTensor") << "(["
-             << value_grad_dims << "])is broadcasted into ["
+             << "ValueTensor@GRAD"
+             << "([" << value_grad_dims << "])is broadcasted into ["
              << fake_value_grad_dims << "].";
 
      std::vector slice_end(RANK, 0);