Unverified commit 2f34fc7a, authored by huangjiyi and committed by GitHub

rm "paddle/fluid/framework/convert_utils.h" in phi (#48001)

Parent f3650201
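This commit moves DataType2String out of paddle::framework into phi and drops the remaining phi-side includes of paddle/fluid/framework/convert_utils.h. The recurring pattern across the touched files is sketched below; the helper function name is hypothetical and only illustrates the before/after call shape, assuming a regular Paddle source tree that provides the listed header.

#include <string>

#include "paddle/phi/core/utils/data_type.h"  // now provides phi::DataType2String

// Hypothetical helper: how phi code obtains a dtype name after this change.
inline std::string DTypeName(phi::DataType dtype) {
  // Before: phi code detoured through fluid, e.g.
  //   paddle::framework::DataTypeToString(
  //       paddle::framework::TransToProtoVarType(dtype));
  // After: the conversion lives in phi itself.
  return phi::DataType2String(dtype);
}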
@@ -162,40 +162,5 @@ DataType String2DataType(const std::string& str) {
   }
 }
-std::string DataType2String(DataType dtype) {
-  switch (dtype) {
-    case DataType::BOOL:
-      return "bool";
-    case DataType::INT8:
-      return "int8";
-    case DataType::UINT8:
-      return "uint8";
-    case DataType::INT16:
-      return "int16";
-    case DataType::INT32:
-      return "int32";
-    case DataType::INT64:
-      return "int64";
-    case DataType::FLOAT16:
-      return "float16";
-    case DataType::FLOAT32:
-      return "float32";
-    case DataType::FLOAT64:
-      return "float64";
-    case DataType::COMPLEX64:
-      return "complex64";
-    case DataType::COMPLEX128:
-      return "complex128";
-    case DataType::PSTRING:
-      return "pstring";
-    case DataType::BFLOAT16:
-      return "bfloat16";
-    default:
-      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
-          "Unknow phi::DataType, the int value = %d.",
-          static_cast<int>(dtype)));
-      return "";
-  }
-}
 }  // namespace framework
 }  // namespace paddle
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/fluid/framework/data_type.h"
+#include "paddle/phi/core/utils/data_type.h"
 // TODO(chenweihang): this file may need to be removed
@@ -37,7 +38,8 @@ paddle::framework::proto::VarType::Type TransToProtoVarType(
 size_t DataTypeSize(DataType dtype);
 DataType String2DataType(const std::string& str);
-std::string DataType2String(DataType dtype);
+using phi::DataType2String;
 }  // namespace framework
 }  // namespace paddle
@@ -121,7 +121,7 @@ class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
     framework::TensorCopy(*expert_count, context.GetPlace(), &expert_count_out);
     PruneGateByCapacityFunctor<DeviceContext, T> functor(
         context, gate_idx, &expert_count_out, new_gate_idx_data);
-    VisitDataType(expert_count->type(), functor);
+    ::paddle::operators::VisitDataType(expert_count->type(), functor);
   }
 };
...
@@ -41,6 +41,14 @@ static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
                                                  {6, phi::DataType::FLOAT64},
                                                  {20, phi::DataType::UINT8}};
+static std::map<phi::DataType, int> map_to_var_type{{phi::DataType::INT16, 1},
+                                                    {phi::DataType::INT32, 2},
+                                                    {phi::DataType::INT64, 3},
+                                                    {phi::DataType::FLOAT16, 4},
+                                                    {phi::DataType::FLOAT32, 5},
+                                                    {phi::DataType::FLOAT64, 6},
+                                                    {phi::DataType::UINT8, 20}};
 #define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
   callback(cpp_type, data_type);
@@ -129,4 +137,41 @@ inline DataType ToRealType(const DataType& type) {
       type));
 }
+inline std::string DataType2String(DataType dtype) {
+  switch (dtype) {
+    case DataType::BOOL:
+      return "bool";
+    case DataType::INT8:
+      return "int8";
+    case DataType::UINT8:
+      return "uint8";
+    case DataType::INT16:
+      return "int16";
+    case DataType::INT32:
+      return "int32";
+    case DataType::INT64:
+      return "int64";
+    case DataType::FLOAT16:
+      return "float16";
+    case DataType::FLOAT32:
+      return "float32";
+    case DataType::FLOAT64:
+      return "float64";
+    case DataType::COMPLEX64:
+      return "complex64";
+    case DataType::COMPLEX128:
+      return "complex128";
+    case DataType::PSTRING:
+      return "pstring";
+    case DataType::BFLOAT16:
+      return "bfloat16";
+    default:
+      PADDLE_THROW(
+          errors::InvalidArgument("Unknow phi::DataType, the int value = %d.",
+                                  static_cast<int>(dtype)));
+      return "";
+  }
+}
 }  // namespace phi
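As a quick orientation for the kernel changes that follow, this is roughly the index-dtype validation pattern the call sites now share. It is a minimal sketch: the PADDLE_ENFORCE_EQ call and message mirror the diff, while the standalone function wrapping them is hypothetical.

#include <string>

#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"

namespace phi {

// Hypothetical helper mirroring the index-dtype check used by the kernels below.
inline void CheckIndexDType(DataType index_type) {
  bool index_type_match =
      index_type == DataType::INT32 || index_type == DataType::INT64;
  PADDLE_ENFORCE_EQ(index_type_match,
                    true,
                    errors::InvalidArgument(
                        "Input(Index) holds the wrong type, it holds %s, but "
                        "desires to be %s or %s",
                        DataType2String(index_type),
                        DataType2String(DataType::INT32),
                        DataType2String(DataType::INT64)));
}

}  // namespace phi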
@@ -17,11 +17,11 @@ limitations under the License. */
 #include <algorithm>
 #include <set>
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/parse_qr_mode.h"
 #include "paddle/phi/kernels/funcs/pooling.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"
@@ -133,12 +133,9 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
       phi::errors::InvalidArgument(
           "The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
           "received [%s]",
-          paddle::framework::DataTypeToString(
-              paddle::framework::proto::VarType::INT32),
-          paddle::framework::DataTypeToString(
-              paddle::framework::proto::VarType::INT64),
-          paddle::framework::DataTypeToString(
-              static_cast<paddle::framework::proto::VarType::Type>(dtype))));
+          phi::DataType2String(DataType::INT32),
+          phi::DataType2String(DataType::INT64),
+          phi::DataType2String(var_type_map[dtype])));
   if (!config.is_runtime && axis.FromTensor()) {
     std::vector<int64_t> vec;
@@ -180,11 +177,10 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
   auto x_rank = x_dims.size();
   if (int_axis < 0) int_axis += x_rank;
   if (config.is_runtime) {
-    if (dtype == paddle::framework::proto::VarType::INT32) {
+    if (dtype == map_to_var_type[DataType::INT32]) {
       int64_t all_element_num = 0;
       if (flatten) {
         all_element_num = phi::product(x_dims);
       } else {
         all_element_num = x_dims[int_axis];
       }
...
@@ -14,11 +14,11 @@
 #include "paddle/phi/kernels/index_sample_grad_kernel.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 namespace phi {
 template <typename T, typename Context, typename IndexT = int>
 void IndexSampleGradInner(const Context& context,
@@ -76,18 +76,14 @@ void IndexSampleGradKernel(const Context& ctx,
   auto index_type = index.dtype();
   bool index_type_match =
       index_type == DataType::INT32 || index_type == DataType::INT64;
-  PADDLE_ENFORCE_EQ(
-      index_type_match,
-      true,
-      errors::InvalidArgument(
-          "Input(Index) holds the wrong type, it holds %s, but "
-          "desires to be %s or %s",
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType(index_type)),
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType(DataType::INT32)),
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType((DataType::INT64)))));
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    errors::InvalidArgument(
+                        "Input(Index) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        phi::DataType2String(index_type),
+                        phi::DataType2String(DataType::INT32),
+                        phi::DataType2String(DataType::INT64)));
   if (index_type == DataType::INT32) {
     IndexSampleGradInner<T, Context, int>(ctx, out_grad, index, x_grad);
   } else if (index_type == DataType::INT64) {
...
@@ -21,11 +21,11 @@
 #include <utility>
 #include <vector>
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 namespace phi {
 template <typename T, typename Context, typename IndexT = int>
 void IndexSampleInner(const Context &context,
@@ -89,18 +89,14 @@ void IndexSampleKernel(const Context &ctx,
   auto index_type = index.dtype();
   bool index_type_match =
       index_type == DataType::INT32 || index_type == DataType::INT64;
-  PADDLE_ENFORCE_EQ(
-      index_type_match,
-      true,
-      errors::InvalidArgument(
-          "Input(Index) holds the wrong type, it holds %s, but "
-          "desires to be %s or %s",
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType(index_type)),
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType(DataType::INT32)),
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType((DataType::INT64)))));
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    errors::InvalidArgument(
+                        "Input(Index) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        phi::DataType2String(index_type),
+                        phi::DataType2String(DataType::INT32),
+                        phi::DataType2String(DataType::INT64)));
   if (index_type == DataType::INT32) {
     IndexSampleInner<T, Context, int>(ctx, x, index, out);
   } else if (index_type == DataType::INT64) {
...
@@ -14,9 +14,9 @@
 #include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/operators/gather_scatter_kernel.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
       true,
       errors::PreconditionNotMet("PutAlongAxisGradOpKernel only runs on CPU."));
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(index.dtype());
+  const auto& index_type = index.dtype();
   if (x_grad) {
     phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
           // Here passing an unused argument out_grad, because it's
          // convenient to instantiate a bunch of template function with the
@@ -60,10 +59,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
   if (value_grad) {
     value_grad->Resize(index.dims());
     value_grad->mutable_data<T>(dev_ctx.GetPlace());
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::cpu_gather_kernel<T, int32_t>(
           out_grad, axis, index, *value_grad, dev_ctx);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == DataType::INT64) {
       paddle::operators::cpu_gather_kernel<T, int64_t>(
           out_grad, axis, index, *value_grad, dev_ctx);
     }
...
@@ -14,9 +14,9 @@
 #include "paddle/phi/kernels/put_along_axis_kernel.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/operators/gather_scatter_kernel.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
@@ -37,29 +37,28 @@ void PutAlongAxisKernel(const Context& dev_ctx,
       errors::PreconditionNotMet("PutAlongAxisOpKernel only runs on CPU."));
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(index.dtype());
+  const auto& index_type = index.dtype();
   if (reduce == "add") {
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
           *out, axis, index, value, dev_ctx);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == DataType::INT64) {
       paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
           *out, axis, index, value, dev_ctx);
     }
   } else if (reduce == "multiply" || reduce == "mul") {
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::cpu_scatter_mul_kernel<T, int32_t>(
           *out, axis, index, value, dev_ctx);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == DataType::INT64) {
       paddle::operators::cpu_scatter_mul_kernel<T, int64_t>(
          *out, axis, index, value, dev_ctx);
     }
   } else if (reduce == "assign") {
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
      paddle::operators::cpu_scatter_assign_kernel<T, int32_t>(
          *out, axis, index, value, dev_ctx);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == DataType::INT64) {
      paddle::operators::cpu_scatter_assign_kernel<T, int64_t>(
          *out, axis, index, value, dev_ctx);
    }
...
@@ -14,9 +14,9 @@
 #include "paddle/phi/kernels/take_along_axis_kernel.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/operators/gather_scatter_kernel.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
   out->Resize(index.dims());
   dev_ctx.template Alloc<T>(out);
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(index.dtype());
-  if (index_type == paddle::framework::proto::VarType::INT32) {
+  const auto& index_type = index.dtype();
+  if (index_type == DataType::INT32) {
     paddle::operators::cpu_gather_kernel<T, int32_t>(
         x, axis, index, *out, dev_ctx);
-  } else if (index_type == paddle::framework::proto::VarType::INT64) {
+  } else if (index_type == DataType::INT64) {
     paddle::operators::cpu_gather_kernel<T, int64_t>(
         x, axis, index, *out, dev_ctx);
   }
...
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <memory>
 #include <vector>
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
...
@@ -13,8 +13,8 @@
 // limitations under the License.
 #pragma once
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -77,18 +77,14 @@ struct UniqueOpFunctor {
     const auto& index_type = index_->dtype();
     bool index_type_match =
         index_type == DataType::INT32 || index_type == DataType::INT64;
-    PADDLE_ENFORCE_EQ(
-        index_type_match,
-        true,
-        phi::errors::InvalidArgument(
-            "Index holds the wrong type, it holds %s, "
-            "but desires to be %s or %s",
-            paddle::framework::DataTypeToString(
-                paddle::framework::TransToProtoVarType(index_type)),
-            paddle::framework::DataTypeToString(
-                paddle::framework::TransToProtoVarType(DataType::INT32)),
-            paddle::framework::DataTypeToString(
-                paddle::framework::TransToProtoVarType(DataType::INT64))));
+    PADDLE_ENFORCE_EQ(index_type_match,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "Index holds the wrong type, it holds %s, "
+                          "but desires to be %s or %s",
+                          phi::DataType2String(index_type),
+                          phi::DataType2String(DataType::INT32),
+                          phi::DataType2String(DataType::INT64)));
     if (index_type == DataType::INT32) {
       for (auto i = 0; i < in_->numel(); ++i) {
...
@@ -17,7 +17,6 @@
 #include <algorithm>
 #include <vector>
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
...
@@ -17,11 +17,11 @@
 #include <algorithm>
 #include <vector>
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 namespace phi {
@@ -70,18 +70,14 @@ void IndexSampleGradKernel(const Context& ctx,
   auto index_type = index.dtype();
   bool index_type_match =
       index_type == DataType::INT32 || index_type == DataType::INT64;
-  PADDLE_ENFORCE_EQ(
-      index_type_match,
-      true,
-      errors::InvalidArgument(
-          "Input(Index) holds the wrong type, it holds %s, but "
-          "desires to be %s or %s",
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType(index_type)),
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType(DataType::INT32)),
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType((DataType::INT64)))));
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    errors::InvalidArgument(
+                        "Input(Index) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        phi::DataType2String(index_type),
+                        phi::DataType2String(DataType::INT32),
+                        phi::DataType2String(DataType::INT64)));
   auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
   auto input_num = x.numel();
...
@@ -17,10 +17,10 @@
 #include <algorithm>
 #include <vector>
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 namespace phi {
@@ -59,18 +59,14 @@ void IndexSampleKernel(const Context& ctx,
   auto index_type = index.dtype();
   bool index_type_match =
       index_type == DataType::INT32 || index_type == DataType::INT64;
-  PADDLE_ENFORCE_EQ(
-      index_type_match,
-      true,
-      errors::InvalidArgument(
-          "Input(Index) holds the wrong type, it holds %s, but "
-          "desires to be %s or %s",
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType(index_type)),
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType(DataType::INT32)),
-          paddle::framework::DataTypeToString(
-              paddle::framework::TransToProtoVarType((DataType::INT64)))));
+  PADDLE_ENFORCE_EQ(index_type_match,
+                    true,
+                    errors::InvalidArgument(
+                        "Input(Index) holds the wrong type, it holds %s, but "
+                        "desires to be %s or %s",
+                        phi::DataType2String(index_type),
+                        phi::DataType2String(DataType::INT32),
+                        phi::DataType2String(DataType::INT64)));
   const T* in_data = x.data<T>();
   T* out_data = ctx.template Alloc<T>(out);
   auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
...
@@ -14,12 +14,12 @@
 #include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/operators/gather_scatter_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/core/utils/data_type.h"
 namespace phi {
@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
       errors::PreconditionNotMet(
           "PutAlongAxisGradOpCUDAKernel only runs on GPU."));
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(index.dtype());
+  const auto& index_type = index.dtype();
   if (x_grad) {
     phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::gpu_scatter_input_grad_kernel<T, int32_t>(
           out_grad, axis, index, *x_grad, dev_ctx);
     } else {
@@ -52,14 +51,14 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
   if (value_grad) {
     value_grad->Resize(index.dims());
     value_grad->mutable_data<T>(dev_ctx.GetPlace());
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::gpu_gather_kernel<T, int32_t>(
           out_grad,
           axis,
          index,
           *value_grad,
           dev_ctx);  // the gradient of scatter is gather
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == DataType::INT64) {
       paddle::operators::gpu_gather_kernel<T, int64_t>(
           out_grad, axis, index, *value_grad, dev_ctx);
     }
...
@@ -14,12 +14,12 @@
 #include "paddle/phi/kernels/put_along_axis_kernel.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/operators/gather_scatter_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/core/utils/data_type.h"
 namespace phi {
@@ -36,31 +36,30 @@ void PutAlongAxisKernel(const Context& dev_ctx,
       errors::PreconditionNotMet(
           "PutAlongAxisCUDAKernel only runs on GPU device."));
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(index.dtype());
+  const auto& index_type = index.dtype();
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
   if (reduce == "add") {
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
           *out, axis, index, value, dev_ctx);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == DataType::INT64) {
       paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
          *out, axis, index, value, dev_ctx);
     }
   } else if (reduce == "multiply" || reduce == "mul") {
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::gpu_scatter_mul_kernel<T, int32_t>(
           *out, axis, index, value, dev_ctx);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == DataType::INT64) {
       paddle::operators::gpu_scatter_mul_kernel<T, int64_t>(
           *out, axis, index, value, dev_ctx);
     }
   } else if (reduce == "assign") {
-    if (index_type == paddle::framework::proto::VarType::INT32) {
+    if (index_type == DataType::INT32) {
       paddle::operators::gpu_scatter_assign_kernel<T, int32_t>(
           *out, axis, index, value, dev_ctx);
-    } else if (index_type == paddle::framework::proto::VarType::INT64) {
+    } else if (index_type == DataType::INT64) {
       paddle::operators::gpu_scatter_assign_kernel<T, int64_t>(
           *out, axis, index, value, dev_ctx);
     }
...
@@ -30,7 +30,6 @@ namespace cub = hipcub;
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
 #endif
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
@@ -431,8 +430,7 @@ void SyncBatchNormGradFunctor(
   }
   if (comm) {
-    int dtype = paddle::platform::ToNCCLDataType(
-        paddle::framework::TransToProtoVarType(scale.dtype()));
+    int dtype = paddle::platform::ToNCCLDataType(scale.dtype());
     // In-place operation
     PADDLE_ENFORCE_GPU_SUCCESS(
         phi::dynload::ncclAllReduce(stats,
...
@@ -14,11 +14,11 @@
 #include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/operators/gather_scatter_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 namespace phi {
@@ -43,17 +43,16 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
   // Set to zero tensor.
   phi::funcs::SetConstant<Context, T> functor;
   functor(dev_ctx, x_grad, static_cast<T>(0));
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(index.dtype());
-  if (index_type == paddle::framework::proto::VarType::INT32) {
+  const auto& index_type = index.dtype();
+  if (index_type == DataType::INT32) {
     paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
         *x_grad,
         axis,
         index,
         out_grad,
         dev_ctx);  // the gradient of gather is scatter
-  } else if (index_type == paddle::framework::proto::VarType::INT64) {
+  } else if (index_type == DataType::INT64) {
    paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
        *x_grad, axis, index, out_grad, dev_ctx);
  }
...
@@ -14,11 +14,11 @@
 #include "paddle/phi/kernels/take_along_axis_kernel.h"
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/operators/gather_scatter_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 namespace phi {
@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
   out->Resize(index.dims());
   dev_ctx.template Alloc<T>(out);
-  const auto& index_type =
-      paddle::framework::TransToProtoVarType(index.dtype());
-  if (index_type == paddle::framework::proto::VarType::INT32) {
+  const auto& index_type = index.dtype();
+  if (index_type == DataType::INT32) {
     paddle::operators::gpu_gather_kernel<T, int32_t>(
         x, axis, index, *out, dev_ctx);
-  } else if (index_type == paddle::framework::proto::VarType::INT64) {
+  } else if (index_type == DataType::INT64) {
     paddle::operators::gpu_gather_kernel<T, int64_t>(
         x, axis, index, *out, dev_ctx);
   }
...