未验证 提交 6b756fb7 编写于 作者: A Aurelius84 提交者: GitHub

[Pten]Modify framework::VisitDataType into Pten::VisitDataType (#39550)

* Modify framework::VisitDataType into Pten::VisitDataType

* migrate unittest
上级 2c7f6e6d
...@@ -458,4 +458,5 @@ if(WITH_GPU OR WITH_ROCM) ...@@ -458,4 +458,5 @@ if(WITH_GPU OR WITH_ROCM)
else() else()
cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place) cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place)
endif() endif()
cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils)
cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor) cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor)
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "gtest/gtest.h"
namespace pten { namespace pten {
namespace tests { namespace tests {
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
......
...@@ -57,7 +57,7 @@ inline void VisitDataType(pten::DataType type, Visitor visitor) { ...@@ -57,7 +57,7 @@ inline void VisitDataType(pten::DataType type, Visitor visitor) {
_PtenForEachDataType_(PtenVisitDataTypeCallback); _PtenForEachDataType_(PtenVisitDataTypeCallback);
#undef PtenVisitDataTypeCallback #undef PtenVisitDataTypeCallback
PADDLE_THROW(pten::errors::Unimplemented( PADDLE_THROW(pten::errors::Unimplemented(
"Not supported proto::VarType::Type(%d) as data type.", "Not supported pten::DataType(%d) as data type.",
static_cast<int>(type))); static_cast<int>(type)));
} }
} // namespace pten } // namespace pten
...@@ -229,9 +229,7 @@ void set_constant_with_place<paddle::platform::CPUPlace>( ...@@ -229,9 +229,7 @@ void set_constant_with_place<paddle::platform::CPUPlace>(
const paddle::platform::DeviceContext& context, const paddle::platform::DeviceContext& context,
paddle::framework::Tensor* tensor, paddle::framework::Tensor* tensor,
float value) { float value) {
paddle::framework::VisitDataType( pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value));
paddle::framework::TransToProtoVarType(tensor->type()),
TensorSetConstantCPU(tensor, value));
} }
template <> template <>
...@@ -248,9 +246,7 @@ void set_constant_with_place<paddle::platform::CUDAPinnedPlace>( ...@@ -248,9 +246,7 @@ void set_constant_with_place<paddle::platform::CUDAPinnedPlace>(
const paddle::platform::DeviceContext& context, const paddle::platform::DeviceContext& context,
paddle::framework::Tensor* tensor, paddle::framework::Tensor* tensor,
float value) { float value) {
paddle::framework::VisitDataType( pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value));
paddle::framework::TransToProtoVarType(tensor->type()),
TensorSetConstantCPU(tensor, value));
} }
struct TensorSetConstantWithPlace : public boost::static_visitor<void> { struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
......
...@@ -226,8 +226,7 @@ void set_constant_with_place<paddle::platform::CUDAPlace>( ...@@ -226,8 +226,7 @@ void set_constant_with_place<paddle::platform::CUDAPlace>(
const paddle::platform::DeviceContext& context, const paddle::platform::DeviceContext& context,
paddle::framework::Tensor* tensor, paddle::framework::Tensor* tensor,
float value) { float value) {
paddle::framework::VisitDataType( pten::VisitDataType(tensor->dtype(),
paddle::framework::TransToProtoVarType(tensor->type()),
TensorSetConstantGPU(context, tensor, value)); TensorSetConstantGPU(context, tensor, value));
} }
......
...@@ -25,6 +25,7 @@ limitations under the License. */ ...@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/utils/data_type.h"
namespace pten { namespace pten {
namespace funcs { namespace funcs {
......
...@@ -30,8 +30,8 @@ void SetConstant<DeviceContext, T>::operator()( ...@@ -30,8 +30,8 @@ void SetConstant<DeviceContext, T>::operator()(
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (paddle::platform::is_xpu_place(context.GetPlace())) { if (paddle::platform::is_xpu_place(context.GetPlace())) {
xpu_place = true; xpu_place = true;
paddle::framework::VisitDataType( pten::VisitDataType(
paddle::framework::TransToProtoVarType(tensor->type()), tensor->dtype(),
TensorSetConstantXPU<T>(tensor, num, context.GetPlace())); TensorSetConstantXPU<T>(tensor, num, context.GetPlace()));
} }
#endif #endif
......
cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor)
cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_type_info SRCS test_type_info.cc)
cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils)
cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel)
cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor sparse_coo_tensor) cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor sparse_coo_tensor)
cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor) cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册