From 6b756fb76454866f26f77d0c136941fb78650da1 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 16 Feb 2022 10:41:53 +0800 Subject: [PATCH] [Pten]Modify framework::VisitDataType into Pten::VisitDataType (#39550) * Modify framework::VisitDataType into Pten::VisitDataType * migrate unittest --- paddle/fluid/framework/CMakeLists.txt | 1 + .../framework/convert_utils_test.cc} | 2 +- paddle/fluid/operators/mlu/mlu_baseop.cc | 1 + paddle/pten/core/utils/data_type.h | 2 +- paddle/pten/kernels/funcs/math_function.cc | 8 ++------ paddle/pten/kernels/funcs/math_function.cu | 5 ++--- paddle/pten/kernels/funcs/math_function.h | 1 + paddle/pten/kernels/funcs/math_function_impl.h | 4 ++-- paddle/pten/tests/core/CMakeLists.txt | 1 - 9 files changed, 11 insertions(+), 14 deletions(-) rename paddle/{pten/tests/core/test_convert_utils.cc => fluid/framework/convert_utils_test.cc} (100%) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index a3f0ed39264..8b3842adbb8 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -458,4 +458,5 @@ if(WITH_GPU OR WITH_ROCM) else() cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place) endif() +cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils) cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor) diff --git a/paddle/pten/tests/core/test_convert_utils.cc b/paddle/fluid/framework/convert_utils_test.cc similarity index 100% rename from paddle/pten/tests/core/test_convert_utils.cc rename to paddle/fluid/framework/convert_utils_test.cc index 977e49aafb9..d547070e6d1 100644 --- a/paddle/pten/tests/core/test_convert_utils.cc +++ b/paddle/fluid/framework/convert_utils_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "gtest/gtest.h" #include "paddle/fluid/framework/convert_utils.h" +#include "gtest/gtest.h" namespace pten { namespace tests { diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index b1001b4e568..068b31a6b7d 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/operator.h" diff --git a/paddle/pten/core/utils/data_type.h b/paddle/pten/core/utils/data_type.h index ee223afb3b0..ca0c678e062 100644 --- a/paddle/pten/core/utils/data_type.h +++ b/paddle/pten/core/utils/data_type.h @@ -57,7 +57,7 @@ inline void VisitDataType(pten::DataType type, Visitor visitor) { _PtenForEachDataType_(PtenVisitDataTypeCallback); #undef PtenVisitDataTypeCallback PADDLE_THROW(pten::errors::Unimplemented( - "Not supported proto::VarType::Type(%d) as data type.", + "Not supported pten::DataType(%d) as data type.", static_cast(type))); } } // namespace pten diff --git a/paddle/pten/kernels/funcs/math_function.cc b/paddle/pten/kernels/funcs/math_function.cc index 09717ee65e0..dec89e79565 100644 --- a/paddle/pten/kernels/funcs/math_function.cc +++ b/paddle/pten/kernels/funcs/math_function.cc @@ -229,9 +229,7 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - paddle::framework::VisitDataType( - paddle::framework::TransToProtoVarType(tensor->type()), - TensorSetConstantCPU(tensor, value)); + pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value)); } template <> @@ -248,9 +246,7 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - paddle::framework::VisitDataType( - paddle::framework::TransToProtoVarType(tensor->type()), - TensorSetConstantCPU(tensor, value)); + pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value)); } struct TensorSetConstantWithPlace : public boost::static_visitor { diff --git a/paddle/pten/kernels/funcs/math_function.cu b/paddle/pten/kernels/funcs/math_function.cu index f7cee12b2df..8ed72dbd1c1 100644 --- a/paddle/pten/kernels/funcs/math_function.cu +++ b/paddle/pten/kernels/funcs/math_function.cu @@ -226,9 +226,8 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - paddle::framework::VisitDataType( - paddle::framework::TransToProtoVarType(tensor->type()), - TensorSetConstantGPU(context, tensor, value)); + pten::VisitDataType(tensor->dtype(), + TensorSetConstantGPU(context, tensor, value)); } template diff --git a/paddle/pten/kernels/funcs/math_function.h b/paddle/pten/kernels/funcs/math_function.h index 73b9dd00bc6..14f5b5b4148 100644 --- a/paddle/pten/kernels/funcs/math_function.h +++ b/paddle/pten/kernels/funcs/math_function.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/utils/data_type.h" namespace pten { namespace funcs { diff --git a/paddle/pten/kernels/funcs/math_function_impl.h b/paddle/pten/kernels/funcs/math_function_impl.h index 19f3082c05c..a6669236357 100644 --- a/paddle/pten/kernels/funcs/math_function_impl.h +++ b/paddle/pten/kernels/funcs/math_function_impl.h @@ -30,8 +30,8 @@ void SetConstant::operator()( #ifdef PADDLE_WITH_XPU if (paddle::platform::is_xpu_place(context.GetPlace())) { xpu_place = true; - paddle::framework::VisitDataType( - paddle::framework::TransToProtoVarType(tensor->type()), + pten::VisitDataType( + tensor->dtype(), TensorSetConstantXPU(tensor, num, context.GetPlace())); } #endif diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index 32e6e0784da..971d9112eea 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -1,7 +1,6 @@ cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_type_info SRCS test_type_info.cc) -cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor sparse_coo_tensor) cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor) -- GitLab