diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index a3f0ed392646c370e731f2d2f573f3dde348a5c9..8b3842adbb8aa924bb392b9e0b7db985586b3406 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -458,4 +458,5 @@ if(WITH_GPU OR WITH_ROCM) else() cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place) endif() +cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils) cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor) diff --git a/paddle/pten/tests/core/test_convert_utils.cc b/paddle/fluid/framework/convert_utils_test.cc similarity index 100% rename from paddle/pten/tests/core/test_convert_utils.cc rename to paddle/fluid/framework/convert_utils_test.cc index 977e49aafb9bd4e84e6626e1f3bbe16a30ef4c52..d547070e6d1f092f5a65ccfef6d743de6e6331e2 100644 --- a/paddle/pten/tests/core/test_convert_utils.cc +++ b/paddle/fluid/framework/convert_utils_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "gtest/gtest.h" #include "paddle/fluid/framework/convert_utils.h" +#include "gtest/gtest.h" namespace pten { namespace tests { diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index b1001b4e5684be02df4784711ad459cd2005affb..068b31a6b7d2155ea78a1ebdfa9e3cda6e61d49a 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/operator.h" diff --git a/paddle/pten/core/utils/data_type.h b/paddle/pten/core/utils/data_type.h index ee223afb3b03c0e2b770097e4313ce31c45927ea..ca0c678e0623d7b7a38b8d87170fc448798f7ea6 100644 --- a/paddle/pten/core/utils/data_type.h +++ b/paddle/pten/core/utils/data_type.h @@ -57,7 +57,7 @@ inline void VisitDataType(pten::DataType type, Visitor visitor) { _PtenForEachDataType_(PtenVisitDataTypeCallback); #undef PtenVisitDataTypeCallback PADDLE_THROW(pten::errors::Unimplemented( - "Not supported proto::VarType::Type(%d) as data type.", + "Not supported pten::DataType(%d) as data type.", static_cast(type))); } } // namespace pten diff --git a/paddle/pten/kernels/funcs/math_function.cc b/paddle/pten/kernels/funcs/math_function.cc index 09717ee65e0452a8563b063dfb790821297800f3..dec89e79565dea863b1f2837334db372ed415522 100644 --- a/paddle/pten/kernels/funcs/math_function.cc +++ b/paddle/pten/kernels/funcs/math_function.cc @@ -229,9 +229,7 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - paddle::framework::VisitDataType( - paddle::framework::TransToProtoVarType(tensor->type()), - TensorSetConstantCPU(tensor, value)); + pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value)); } template <> @@ -248,9 +246,7 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - paddle::framework::VisitDataType( - paddle::framework::TransToProtoVarType(tensor->type()), - TensorSetConstantCPU(tensor, value)); + pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value)); } struct TensorSetConstantWithPlace : public boost::static_visitor { diff --git a/paddle/pten/kernels/funcs/math_function.cu b/paddle/pten/kernels/funcs/math_function.cu index f7cee12b2dfd42c2296a4bd30a739bfe181efb13..8ed72dbd1c1278d320ccebfd7463e83f7c101065 100644 --- a/paddle/pten/kernels/funcs/math_function.cu +++ b/paddle/pten/kernels/funcs/math_function.cu @@ -226,9 +226,8 @@ void set_constant_with_place( const paddle::platform::DeviceContext& context, paddle::framework::Tensor* tensor, float value) { - paddle::framework::VisitDataType( - paddle::framework::TransToProtoVarType(tensor->type()), - TensorSetConstantGPU(context, tensor, value)); + pten::VisitDataType(tensor->dtype(), + TensorSetConstantGPU(context, tensor, value)); } template diff --git a/paddle/pten/kernels/funcs/math_function.h b/paddle/pten/kernels/funcs/math_function.h index 73b9dd00bc64095ea2796154ff5d32c407fd9f1b..14f5b5b41489d09e53a47a1ece22d394c22f1c53 100644 --- a/paddle/pten/kernels/funcs/math_function.h +++ b/paddle/pten/kernels/funcs/math_function.h @@ -25,6 +25,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/utils/data_type.h" namespace pten { namespace funcs { diff --git a/paddle/pten/kernels/funcs/math_function_impl.h b/paddle/pten/kernels/funcs/math_function_impl.h index 19f3082c05cc27c265fe1354fba666226b88ce1c..a66692363572adf06a0d064fbdf9c82e44eb6d6a 100644 --- a/paddle/pten/kernels/funcs/math_function_impl.h +++ b/paddle/pten/kernels/funcs/math_function_impl.h @@ -30,8 +30,8 @@ void SetConstant::operator()( #ifdef PADDLE_WITH_XPU if (paddle::platform::is_xpu_place(context.GetPlace())) { xpu_place = true; - paddle::framework::VisitDataType( - paddle::framework::TransToProtoVarType(tensor->type()), + pten::VisitDataType( + tensor->dtype(), TensorSetConstantXPU(tensor, num, context.GetPlace())); } #endif diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index 32e6e0784dad0c716cfea384b46933f11adbe5d0..971d9112eead97f46ab1f165c9073ac525464676 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -1,7 +1,6 @@ cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_type_info SRCS test_type_info.cc) -cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor sparse_coo_tensor) cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor)