未验证 提交 2f34fc7a 编写于 作者: H huangjiyi 提交者: GitHub

rm "paddle/fluid/framework/convert_utils.h" in phi (#48001)

上级 f3650201
......@@ -162,40 +162,5 @@ DataType String2DataType(const std::string& str) {
}
}
std::string DataType2String(DataType dtype) {
switch (dtype) {
case DataType::BOOL:
return "bool";
case DataType::INT8:
return "int8";
case DataType::UINT8:
return "uint8";
case DataType::INT16:
return "int16";
case DataType::INT32:
return "int32";
case DataType::INT64:
return "int64";
case DataType::FLOAT16:
return "float16";
case DataType::FLOAT32:
return "float32";
case DataType::FLOAT64:
return "float64";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
case DataType::PSTRING:
return "pstring";
case DataType::BFLOAT16:
return "bfloat16";
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unknow phi::DataType, the int value = %d.",
static_cast<int>(dtype)));
return "";
}
}
} // namespace framework
} // namespace paddle
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/utils/data_type.h"
// TODO(chenweihang): this file may need to be removed
......@@ -37,7 +38,8 @@ paddle::framework::proto::VarType::Type TransToProtoVarType(
size_t DataTypeSize(DataType dtype);
DataType String2DataType(const std::string& str);
std::string DataType2String(DataType dtype);
using phi::DataType2String;
} // namespace framework
} // namespace paddle
......@@ -121,7 +121,7 @@ class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
framework::TensorCopy(*expert_count, context.GetPlace(), &expert_count_out);
PruneGateByCapacityFunctor<DeviceContext, T> functor(
context, gate_idx, &expert_count_out, new_gate_idx_data);
VisitDataType(expert_count->type(), functor);
::paddle::operators::VisitDataType(expert_count->type(), functor);
}
};
......
......@@ -41,6 +41,14 @@ static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
{6, phi::DataType::FLOAT64},
{20, phi::DataType::UINT8}};
static std::map<phi::DataType, int> map_to_var_type{{phi::DataType::INT16, 1},
{phi::DataType::INT32, 2},
{phi::DataType::INT64, 3},
{phi::DataType::FLOAT16, 4},
{phi::DataType::FLOAT32, 5},
{phi::DataType::FLOAT64, 6},
{phi::DataType::UINT8, 20}};
#define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
callback(cpp_type, data_type);
......@@ -129,4 +137,41 @@ inline DataType ToRealType(const DataType& type) {
type));
}
}
inline std::string DataType2String(DataType dtype) {
switch (dtype) {
case DataType::BOOL:
return "bool";
case DataType::INT8:
return "int8";
case DataType::UINT8:
return "uint8";
case DataType::INT16:
return "int16";
case DataType::INT32:
return "int32";
case DataType::INT64:
return "int64";
case DataType::FLOAT16:
return "float16";
case DataType::FLOAT32:
return "float32";
case DataType::FLOAT64:
return "float64";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
case DataType::PSTRING:
return "pstring";
case DataType::BFLOAT16:
return "bfloat16";
default:
PADDLE_THROW(
errors::InvalidArgument("Unknow phi::DataType, the int value = %d.",
static_cast<int>(dtype)));
return "";
}
}
} // namespace phi
......@@ -17,11 +17,11 @@ limitations under the License. */
#include <algorithm>
#include <set>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
......@@ -133,12 +133,9 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
phi::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<paddle::framework::proto::VarType::Type>(dtype))));
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64),
phi::DataType2String(var_type_map[dtype])));
if (!config.is_runtime && axis.FromTensor()) {
std::vector<int64_t> vec;
......@@ -180,11 +177,10 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
auto x_rank = x_dims.size();
if (int_axis < 0) int_axis += x_rank;
if (config.is_runtime) {
if (dtype == paddle::framework::proto::VarType::INT32) {
if (dtype == map_to_var_type[DataType::INT32]) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[int_axis];
}
......
......@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/index_sample_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context, typename IndexT = int>
void IndexSampleGradInner(const Context& context,
......@@ -76,18 +76,14 @@ void IndexSampleGradKernel(const Context& ctx,
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
PADDLE_ENFORCE_EQ(index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
phi::DataType2String(index_type),
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64)));
if (index_type == DataType::INT32) {
IndexSampleGradInner<T, Context, int>(ctx, out_grad, index, x_grad);
} else if (index_type == DataType::INT64) {
......
......@@ -21,11 +21,11 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context, typename IndexT = int>
void IndexSampleInner(const Context &context,
......@@ -89,18 +89,14 @@ void IndexSampleKernel(const Context &ctx,
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
PADDLE_ENFORCE_EQ(index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
phi::DataType2String(index_type),
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64)));
if (index_type == DataType::INT32) {
IndexSampleInner<T, Context, int>(ctx, x, index, out);
} else if (index_type == DataType::INT64) {
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
true,
errors::PreconditionNotMet("PutAlongAxisGradOpKernel only runs on CPU."));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
const auto& index_type = index.dtype();
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
// Here passing an unused argument out_grad, because it's
// convenient to instantiate a bunch of template function with the
......@@ -60,10 +59,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if (value_grad) {
value_grad->Resize(index.dims());
value_grad->mutable_data<T>(dev_ctx.GetPlace());
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -37,29 +37,28 @@ void PutAlongAxisKernel(const Context& dev_ctx,
errors::PreconditionNotMet("PutAlongAxisOpKernel only runs on CPU."));
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
const auto& index_type = index.dtype();
if (reduce == "add") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "assign") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
......
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
out->Resize(index.dims());
dev_ctx.template Alloc<T>(out);
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
const auto& index_type = index.dtype();
if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
x, axis, index, *out, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
x, axis, index, *out, dev_ctx);
}
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <memory>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
......
......@@ -13,8 +13,8 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -77,18 +77,14 @@ struct UniqueOpFunctor {
const auto& index_type = index_->dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
phi::errors::InvalidArgument(
"Index holds the wrong type, it holds %s, "
"but desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT64))));
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Index holds the wrong type, it holds %s, "
"but desires to be %s or %s",
phi::DataType2String(index_type),
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64)));
if (index_type == DataType::INT32) {
for (auto i = 0; i < in_->numel(); ++i) {
......
......@@ -17,7 +17,6 @@
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
......
......@@ -17,11 +17,11 @@
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -70,18 +70,14 @@ void IndexSampleGradKernel(const Context& ctx,
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
PADDLE_ENFORCE_EQ(index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
phi::DataType2String(index_type),
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64)));
auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
auto input_num = x.numel();
......
......@@ -17,10 +17,10 @@
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -59,18 +59,14 @@ void IndexSampleKernel(const Context& ctx,
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
PADDLE_ENFORCE_EQ(index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
phi::DataType2String(index_type),
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64)));
const T* in_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
......
......@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
......@@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
errors::PreconditionNotMet(
"PutAlongAxisGradOpCUDAKernel only runs on GPU."));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
const auto& index_type = index.dtype();
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_input_grad_kernel<T, int32_t>(
out_grad, axis, index, *x_grad, dev_ctx);
} else {
......@@ -52,14 +51,14 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if (value_grad) {
value_grad->Resize(index.dims());
value_grad->mutable_data<T>(dev_ctx.GetPlace());
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::gpu_gather_kernel<T, int32_t>(
out_grad,
axis,
index,
*value_grad,
dev_ctx); // the gradient of scatter is gather
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
......
......@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/put_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
......@@ -36,31 +36,30 @@ void PutAlongAxisKernel(const Context& dev_ctx,
errors::PreconditionNotMet(
"PutAlongAxisCUDAKernel only runs on GPU device."));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
const auto& index_type = index.dtype();
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
if (reduce == "add") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "assign") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
......
......@@ -30,7 +30,6 @@ namespace cub = hipcub;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
......@@ -431,8 +430,7 @@ void SyncBatchNormGradFunctor(
}
if (comm) {
int dtype = paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(scale.dtype()));
int dtype = paddle::platform::ToNCCLDataType(scale.dtype());
// In-place operation
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclAllReduce(stats,
......
......@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......@@ -43,17 +43,16 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
// Set to zero tensor.
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
const auto& index_type = index.dtype();
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::gpu_scatter_add_kernel<T, int32_t>(
*x_grad,
axis,
index,
out_grad,
dev_ctx); // the gradient of gather is scatter
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx);
}
......
......@@ -14,11 +14,11 @@
#include "paddle/phi/kernels/take_along_axis_kernel.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
......@@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
out->Resize(index.dims());
dev_ctx.template Alloc<T>(out);
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
const auto& index_type = index.dtype();
if (index_type == DataType::INT32) {
paddle::operators::gpu_gather_kernel<T, int32_t>(
x, axis, index, *out, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::gpu_gather_kernel<T, int64_t>(
x, axis, index, *out, dev_ctx);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册