diff --git a/paddle/pten/core/selected_rows.cc b/paddle/pten/core/selected_rows.cc index 6f64602bdcf4d9f70d57a76677a1796b373808ac..1dfcfa49347b50d305c2b37ccc4379eedb08a107 100644 --- a/paddle/pten/core/selected_rows.cc +++ b/paddle/pten/core/selected_rows.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/pten/core/selected_rows.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/framework/data_type.h" +#include "paddle/pten/core/utils/data_type.h" namespace pten { @@ -191,16 +189,16 @@ void SelectedRows::Get(const pten::DenseTensor& ids, int64_t index = AutoGrownIndex(id, auto_grown, is_test); if (index < 0) { VLOG(5) << "id " << id << " not in the table, return 0"; - paddle::framework::VisitDataType( - value_->type(), + pten::VisitDataType( + value_->dtype(), TensorFillVisitor(value, i * value_width, value_width, 0.0)); } else { - paddle::framework::VisitDataType(value_->type(), - TensorCopyVisitor(value, - i * value_width, - *value_.get(), - index * value_width, - value_width)); + pten::VisitDataType(value_->dtype(), + TensorCopyVisitor(value, + i * value_width, + *value_.get(), + index * value_width, + value_width)); } } } diff --git a/paddle/pten/core/utils/data_type.h b/paddle/pten/core/utils/data_type.h new file mode 100644 index 0000000000000000000000000000000000000000..ee223afb3b03c0e2b770097e4313ce31c45927ea --- /dev/null +++ b/paddle/pten/core/utils/data_type.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/core/enforce.h" +#include "paddle/pten/kernels/funcs/eigen/extensions.h" + +namespace pten { + +#define _PtenForEachDataTypeHelper_(callback, cpp_type, data_type) \ + callback(cpp_type, data_type); + +#define _PtenForEachDataType_(callback) \ + _PtenForEachDataTypeHelper_(callback, float, DataType::FLOAT32); \ + _PtenForEachDataTypeHelper_( \ + callback, ::paddle::platform::float16, DataType::FLOAT16); \ + _PtenForEachDataTypeHelper_( \ + callback, ::paddle::platform::bfloat16, DataType::BFLOAT16); \ + _PtenForEachDataTypeHelper_(callback, double, DataType::FLOAT64); \ + _PtenForEachDataTypeHelper_(callback, int, DataType::INT32); \ + _PtenForEachDataTypeHelper_(callback, int64_t, DataType::INT64); \ + _PtenForEachDataTypeHelper_(callback, bool, DataType::BOOL); \ + _PtenForEachDataTypeHelper_(callback, uint8_t, DataType::UINT8); \ + _PtenForEachDataTypeHelper_(callback, int16_t, DataType::INT16); \ + _PtenForEachDataTypeHelper_(callback, int8_t, DataType::INT8); \ + _PtenForEachDataTypeHelper_( \ + callback, ::paddle::platform::complex, DataType::COMPLEX64); \ + _PtenForEachDataTypeHelper_( \ + callback, ::paddle::platform::complex, DataType::COMPLEX128); + +template +inline void VisitDataType(pten::DataType type, Visitor visitor) { +#define PtenVisitDataTypeCallback(cpp_type, data_type) \ + do { \ + if (type == data_type) { \ + visitor.template apply(); \ + return; \ + } \ + } while (0) + + _PtenForEachDataType_(PtenVisitDataTypeCallback); +#undef PtenVisitDataTypeCallback + PADDLE_THROW(pten::errors::Unimplemented( + "Not supported proto::VarType::Type(%d) as data type.", + static_cast(type))); +} +} // namespace pten