未验证 提交 6b756fb7 编写于 作者: A Aurelius84 提交者: GitHub

[Pten]Modify framework::VisitDataType into Pten::VisitDataType (#39550)

* Modify framework::VisitDataType into Pten::VisitDataType

* migrate unittest
上级 2c7f6e6d
......@@ -458,4 +458,5 @@ if(WITH_GPU OR WITH_ROCM)
else()
cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place)
endif()
cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils)
cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor)
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "gtest/gtest.h"
namespace pten {
namespace tests {
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
......
......@@ -57,7 +57,7 @@ inline void VisitDataType(pten::DataType type, Visitor visitor) {
_PtenForEachDataType_(PtenVisitDataTypeCallback);
#undef PtenVisitDataTypeCallback
PADDLE_THROW(pten::errors::Unimplemented(
"Not supported proto::VarType::Type(%d) as data type.",
"Not supported pten::DataType(%d) as data type.",
static_cast<int>(type)));
}
} // namespace pten
......@@ -229,9 +229,7 @@ void set_constant_with_place<paddle::platform::CPUPlace>(
const paddle::platform::DeviceContext& context,
paddle::framework::Tensor* tensor,
float value) {
paddle::framework::VisitDataType(
paddle::framework::TransToProtoVarType(tensor->type()),
TensorSetConstantCPU(tensor, value));
pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value));
}
template <>
......@@ -248,9 +246,7 @@ void set_constant_with_place<paddle::platform::CUDAPinnedPlace>(
const paddle::platform::DeviceContext& context,
paddle::framework::Tensor* tensor,
float value) {
paddle::framework::VisitDataType(
paddle::framework::TransToProtoVarType(tensor->type()),
TensorSetConstantCPU(tensor, value));
pten::VisitDataType(tensor->dtype(), TensorSetConstantCPU(tensor, value));
}
struct TensorSetConstantWithPlace : public boost::static_visitor<void> {
......
......@@ -226,9 +226,8 @@ void set_constant_with_place<paddle::platform::CUDAPlace>(
const paddle::platform::DeviceContext& context,
paddle::framework::Tensor* tensor,
float value) {
paddle::framework::VisitDataType(
paddle::framework::TransToProtoVarType(tensor->type()),
TensorSetConstantGPU(context, tensor, value));
pten::VisitDataType(tensor->dtype(),
TensorSetConstantGPU(context, tensor, value));
}
template <typename T>
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/utils/data_type.h"
namespace pten {
namespace funcs {
......
......@@ -30,8 +30,8 @@ void SetConstant<DeviceContext, T>::operator()(
#ifdef PADDLE_WITH_XPU
if (paddle::platform::is_xpu_place(context.GetPlace())) {
xpu_place = true;
paddle::framework::VisitDataType(
paddle::framework::TransToProtoVarType(tensor->type()),
pten::VisitDataType(
tensor->dtype(),
TensorSetConstantXPU<T>(tensor, num, context.GetPlace()));
}
#endif
......
cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor)
cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc)
cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils)
cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel)
cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor sparse_coo_tensor)
cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册