Unverified · Commit 9ffdb2b7 authored by RedContritio, committed by GitHub

【Hackathon No.67】remove reference to operator.h in phi [part 1] (#50624)

* add visit_place to phi/core/utils

* remove reference to operator.h in phi/kernels/funcs/math_function.h

* update data type from framework.proto to phi

* fix enforce error in fluid
Parent 7d138402
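
The dtype changes below all follow one pattern: checks that previously round-tripped through `paddle::framework::proto::VarType` (via `TransToProtoVarType`) now compare `phi::DataType` values directly, and NCCL dtype lookups call `phi::ToNCCLDataType` on the `phi::DataType` instead. A minimal sketch of the new form, assuming an arbitrary index tensor (`CheckIndexDtype` and `index` are illustrative names, not from the patch):

```cpp
// Illustrative only: "CheckIndexDtype" and "index" are placeholder names.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"

void CheckIndexDtype(const phi::DenseTensor& index) {
  // Old form: TransToProtoVarType(index.dtype()) == proto::VarType::INT32
  // New form: compare phi::DataType directly, with no fluid dependency.
  const auto& index_type = index.dtype();
  bool ok = index_type == phi::DataType::INT32 ||
            index_type == phi::DataType::INT64;
  PADDLE_ENFORCE_EQ(ok,
                    true,
                    phi::errors::InvalidArgument(
                        "index dtype must be INT32 or INT64, but got %s",
                        phi::DataTypeToString(index_type)));
  // NCCL calls follow the same idea:
  //   phi::ToNCCLDataType(index.dtype()) instead of
  //   paddle::platform::ToNCCLDataType(TransToProtoVarType(index.dtype()))
}
```
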
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -18,6 +18,10 @@
#include "gtest/gtest.h"
#include "paddle/fluid/imperative/reducer.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/core/utils/data_type.h"
namespace paddle {
namespace imperative {
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/jit/kernels.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -14,6 +14,7 @@
#include <stack>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/tree2col.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace paddle {
......@@ -20,7 +21,7 @@ namespace operators {
namespace details {
TEST(test_reduce_rank_check, all) {
using EnforceNotMet = paddle::platform::EnforceNotMet;
using EnforceNotMet = phi::EnforceNotMet;
constexpr int kMaxRank = framework::DDim::kMaxRank;
for (int rank = 0; rank < kMaxRank; rank++) {
......@@ -42,7 +43,7 @@ TEST(test_reduce_rank_check, all) {
phi::funcs::details::CheckReduceRank(reduce_rank, rank);
} else {
ASSERT_THROW(phi::funcs::details::CheckReduceRank(reduce_rank, rank),
paddle::platform::EnforceNotMet);
EnforceNotMet);
}
}
}
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
// NOTE: add a dependency on phi_place when using phi::VisitPlace
template <typename Visitor>
typename Visitor::result_type VisitPlace(const phi::Place& place,
const Visitor& visitor) {
switch (place.GetType()) {
case phi::AllocationType::GPU: {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::GPUPlace p(place.GetDeviceId());
return visitor(p);
#else
PADDLE_THROW(phi::errors::Unavailable(
("Paddle is not compiled with CUDA. Cannot visit cuda_pinned")));
return typename Visitor::result_type();
#endif
}
case phi::AllocationType::GPUPINNED: {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::GPUPinnedPlace p;
return visitor(p);
#else
PADDLE_THROW(phi::errors::Unavailable(
("Paddle is not compiled with CUDA. Cannot visit cuda_pinned")));
return typename Visitor::result_type();
#endif
}
case phi::AllocationType::XPU: {
#ifdef PADDLE_WITH_XPU
phi::XPUPlace p(place.GetDeviceId());
return visitor(p);
#else
PADDLE_THROW(phi::errors::Unavailable(
("Paddle is not compiled with XPU. Cannot visit xpu device")));
return typename Visitor::result_type();
#endif
}
case phi::AllocationType::NPU: {
#ifdef PADDLE_WITH_ASCEND_CL
phi::NPUPlace p(place.GetDeviceId());
return visitor(p);
#else
PADDLE_THROW(phi::errors::Unavailable(
("Paddle is not compiled with NPU. Cannot visit npu_pinned")));
return typename Visitor::result_type();
#endif
}
case phi::AllocationType::NPUPINNED: {
#ifdef PADDLE_WITH_ASCEND_CL
phi::NPUPinnedPlace p;
return visitor(p);
#else
PADDLE_THROW(phi::errors::Unavailable(
("Paddle is not compiled with NPU. Cannot visit npu_pinned")));
return typename Visitor::result_type();
#endif
}
case phi::AllocationType::IPU: {
#ifdef PADDLE_WITH_IPU
phi::IPUPlace p(place.GetDeviceId());
return visitor(p);
#else
PADDLE_THROW(phi::errors::Unavailable(
("Paddle is not compiled with IPU. Cannot visit ipu device")));
return typename Visitor::result_type();
#endif
}
case phi::AllocationType::MLU: {
#ifdef PADDLE_WITH_MLU
phi::MLUPlace p(place.GetDeviceId());
return visitor(p);
#else
PADDLE_THROW(phi::errors::Unavailable(
("Paddle is not compiled with MLU. Cannot visit mlu device")));
#endif
}
case phi::AllocationType::CUSTOM: {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
phi::CustomPlace p(place.GetDeviceType(), place.GetDeviceId());
return visitor(p);
#else
PADDLE_THROW(phi::errors::Unavailable(
("Paddle is not compiled with CUSTOM. Cannot visit custom device")));
#endif
}
default: {
phi::CPUPlace p;
return visitor(p);
}
}
}
} // namespace phi
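
As a quick reference for the header above: a visitor passed to `phi::VisitPlace` only needs a nested `result_type` and call operators for the place types it may receive. A minimal sketch (the `PlaceName` functor and `NameOf` helper are illustrative, not part of the patch):

```cpp
// Illustrative only: "PlaceName" and "NameOf" are not part of the patch.
#include <string>

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/utils/visit_place.h"

struct PlaceName {
  using result_type = std::string;  // VisitPlace returns Visitor::result_type

  result_type operator()(const phi::CPUPlace&) const { return "cpu"; }
  result_type operator()(const phi::GPUPlace& p) const {
    return "gpu:" + std::to_string(p.GetDeviceId());
  }
  // Catch-all so the sketch compiles no matter which device branches the
  // build flags enable (XPU, IPU, custom devices, ...).
  template <typename PlaceT>
  result_type operator()(const PlaceT&) const { return "other"; }
};

std::string NameOf(const phi::Place& place) {
  // Dispatches on place.GetType() and calls the matching overload above.
  return phi::VisitPlace(place, PlaceName());
}
```
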
......@@ -45,29 +45,25 @@ void RepeatInterleaveWithTensorIndexGradKernel(
repeats_tensor.dims()[0],
x_grad->dims()[dim]));
const auto& index_type =
paddle::framework::TransToProtoVarType(repeats_tensor.dtype());
const auto& index_type = repeats_tensor.dtype();
bool index_type_match =
index_type == paddle::framework::proto::VarType::INT32 ||
index_type == paddle::framework::proto::VarType::INT64;
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Repeats) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64)));
phi::DataTypeToString(index_type),
phi::DataTypeToString(phi::DataType::INT32),
phi::DataTypeToString(phi::DataType::INT64)));
phi::DeviceContextPool::Instance().Get(repeats_tensor.place());
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == phi::DataType::INT32) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int>(
ctx, repeats_tensor, &index);
IndexSelectGradInner<Context, T, int>(ctx, out_grad, index, x_grad, dim);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == phi::DataType::INT64) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>(
ctx, repeats_tensor, &index);
IndexSelectGradInner<Context, T, int64_t>(
......
......@@ -45,8 +45,7 @@ struct EmbeddingCPUSparseFunctor {
int64_t row_width = table_t.value().dims()[1];
const auto* table = table_t.value().template data<T>();
auto* output = dev_ctx_.template Alloc<T>(output_t);
auto input_data_type =
paddle::framework::TransToProtoVarType(table_t.value().dtype());
auto input_data_type = table_t.value().dtype();
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
......@@ -66,7 +65,7 @@ struct EmbeddingCPUSparseFunctor {
phi::errors::InvalidArgument(
"the input key should be exists. But received %d.", id_index));
if (input_data_type == paddle::framework::proto::VarType::BF16) {
if (input_data_type == phi::DataType::BFLOAT16) {
memcpy(output + i * row_width,
table + id_index * row_width,
row_width * sizeof(T));
......
......@@ -43,16 +43,15 @@ void TakeAlongAxisGradKernel(const Context& dev_ctx,
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
const auto& index_type = index.dtype();
if (index_type == phi::DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad,
axis,
index,
out_grad,
dev_ctx); // the gradient of gather is scatter
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == phi::DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, index, out_grad, dev_ctx);
}
......
......@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unique_consecutive_kernel.h"
#include <climits>
#include "paddle/phi/kernels/cpu/unique_consecutive_functor.h"
#include "paddle/phi/kernels/unique_consecutive_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/errors.h"
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <climits>
#include "paddle/phi/kernels/unique_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/core/utils/visit_place.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/phi/backends/dynload/mklml.h"
......@@ -236,7 +237,7 @@ void set_constant(const phi::DeviceContext& context,
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// tensor->place().apply_visitor(func);
paddle::platform::VisitPlace(tensor->place(), func);
phi::VisitPlace(tensor->place(), func);
#elif defined(PADDLE_WITH_XPU)
func(phi::XPUPlace());
#else
......
......@@ -17,12 +17,10 @@ limitations under the License. */
#include <memory>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
namespace funcs {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/segment_pooling.h"
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <set>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
......
......@@ -109,9 +109,6 @@ void BincountCUDAInner(const Context& dev_ctx,
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
} else {
const auto& weights_type =
paddle::framework::TransToProtoVarType(weights->dtype());
if (weights->dtype() == DataType::FLOAT32) {
float* output_data = dev_ctx.template Alloc<float>(output);
phi::funcs::SetConstant<Context, float>()(
......
......@@ -375,9 +375,7 @@ void ClassCenterSampleKernel(const Context& dev_ctx,
num_classes_per_device_ptr,
num_classes_per_device_ptr,
num_classes_per_device.numel(),
paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(
num_classes_per_device.dtype())),
phi::ToNCCLDataType(num_classes_per_device.dtype()),
ncclSum,
comm->comm(),
calcu_stream));
......
......@@ -21,6 +21,7 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/mixed_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
......
......@@ -96,8 +96,7 @@ void GetClassInterval(const gpuStream_t& stream,
num_classes_per_device_ptr,
num_classes_per_device_ptr,
num_classes_per_device.numel(),
paddle::platform::ToNCCLDataType(paddle::framework::TransToProtoVarType(
num_classes_per_device.dtype())),
phi::ToNCCLDataType(num_classes_per_device.dtype()),
ncclSum,
comm->comm(),
calcu_stream));
......@@ -188,8 +187,7 @@ void MarginCrossEntropyGradKernel(const Context& dev_ctx,
int blocks = NumBlocks(N * D);
int threads = kNumCUDAThreads;
const auto& label_type =
paddle::framework::TransToProtoVarType(label.dtype());
const auto& label_type = label.dtype();
DenseTensor class_interval;
GetClassInterval<T, Context>(dev_ctx.stream(),
......@@ -201,7 +199,7 @@ void MarginCrossEntropyGradKernel(const Context& dev_ctx,
D,
&class_interval);
if (label_type == paddle::framework::proto::VarType::INT32) {
if (label_type == phi::DataType::INT32) {
typedef int32_t LabelT;
CalculateGrad<T, LabelT>
<<<blocks, threads, 0, dev_ctx.stream()>>>(logits_grad->data<T>(),
......@@ -215,7 +213,7 @@ void MarginCrossEntropyGradKernel(const Context& dev_ctx,
N,
D,
class_interval.data<int>());
} else if (label_type == paddle::framework::proto::VarType::INT64) {
} else if (label_type == phi::DataType::INT64) {
typedef int64_t LabelT;
CalculateGrad<T, LabelT>
<<<blocks, threads, 0, dev_ctx.stream()>>>(logits_grad->data<T>(),
......
......@@ -92,8 +92,7 @@ void GetClassInterval(const gpuStream_t& stream,
num_classes_per_device_ptr,
num_classes_per_device_ptr,
num_classes_per_device.numel(),
paddle::platform::ToNCCLDataType(paddle::framework::TransToProtoVarType(
num_classes_per_device.dtype())),
phi::ToNCCLDataType(num_classes_per_device.dtype()),
ncclSum,
comm->comm(),
calcu_stream));
......@@ -265,8 +264,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
int blocks = NumBlocks(N);
int threads = kNumCUDAThreads;
const auto& label_type =
paddle::framework::TransToProtoVarType(labels.dtype());
const auto& label_type = labels.dtype();
// copy logits to softmax variable since we can't modify logits,
// and it also be used when calculate grad
......@@ -291,7 +289,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
// theta = acos(x_i)
// (cos(m1 * theta + m2) - m3)
// save match_logits, used for gradient computation.
if (label_type == paddle::framework::proto::VarType::INT32) {
if (label_type == phi::DataType::INT32) {
typedef int32_t LabelT;
AddMarginToPositiveLogitsKernel<T>
<<<NumBlocks(N), threads, 0, dev_ctx.stream()>>>(
......@@ -305,7 +303,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
N,
D,
class_interval.data<int>());
} else if (label_type == paddle::framework::proto::VarType::INT64) {
} else if (label_type == phi::DataType::INT64) {
typedef int64_t LabelT;
AddMarginToPositiveLogitsKernel<T>
<<<NumBlocks(N), threads, 0, dev_ctx.stream()>>>(
......@@ -357,15 +355,14 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
logits_max_buff,
logits_max_buff,
logits_max.numel(),
paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(logits_max.dtype())),
ncclMax,
comm->comm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclAllReduce(logits_max_buff,
logits_max_buff,
logits_max.numel(),
phi::ToNCCLDataType(logits_max.dtype()),
ncclMax,
comm->comm(),
stream));
}
}
#endif
......@@ -403,8 +400,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
sum_exp_logits_buff,
sum_exp_logits_buff,
sum_exp_logits.numel(),
paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(sum_exp_logits.dtype())),
phi::ToNCCLDataType(sum_exp_logits.dtype()),
ncclSum,
comm->comm(),
stream));
......@@ -423,7 +419,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, loss, static_cast<T>(0.0));
if (label_type == paddle::framework::proto::VarType::INT32) {
if (label_type == phi::DataType::INT32) {
typedef int32_t LabelT;
HardLabelSoftmaxWithCrossEntropyKernel<T, LabelT>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_ptr,
......@@ -433,7 +429,7 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
N,
D,
class_interval.data<int>());
} else if (label_type == paddle::framework::proto::VarType::INT64) {
} else if (label_type == phi::DataType::INT64) {
typedef int64_t LabelT;
HardLabelSoftmaxWithCrossEntropyKernel<T, LabelT>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_ptr,
......@@ -458,15 +454,14 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
auto task = pg->AllReduce(in_tensor, out_tensor, opts);
task->Wait();
} else {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
loss_ptr,
loss_ptr,
loss->numel(),
paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(loss->dtype())),
ncclSum,
comm->comm(),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclAllReduce(loss_ptr,
loss_ptr,
loss->numel(),
phi::ToNCCLDataType(loss->dtype()),
ncclSum,
comm->comm(),
stream));
}
}
#endif
......
......@@ -108,8 +108,7 @@ void SyncBatchNormKernel(const Context &ctx,
}
if (comm) {
int dtype = paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(mean_out->dtype()));
int dtype = phi::ToNCCLDataType(mean_out->dtype());
// In-place operation
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclAllReduce(stats,
......
......@@ -131,32 +131,28 @@ void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
"But received: [%s], required: [%d].",
repeats_tensor.dims()[0],
x.dims()[dim]));
const auto& index_type =
paddle::framework::TransToProtoVarType(repeats_tensor.dtype());
const auto& index_type = repeats_tensor.dtype();
bool index_type_match =
index_type == paddle::framework::proto::VarType::INT32 ||
index_type == paddle::framework::proto::VarType::INT64;
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
phi::errors::InvalidArgument(
"Input(RepeatsTensor) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64)));
phi::DataTypeToString(index_type),
phi::DataTypeToString(phi::DataType::INT32),
phi::DataTypeToString(phi::DataType::INT64)));
if (place == cpu_place) {
auto x_copy = x;
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == phi::DataType::INT32) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int>(
ctx, repeats_tensor, &index);
auto output_dim = phi::vectorize(x.dims());
output_dim[dim] = index.dims()[0];
out->Resize(phi::make_ddim(output_dim));
IndexSelectInner<Context, T, int>(ctx, &x_copy, index, out, dim);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == phi::DataType::INT64) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>(
ctx, repeats_tensor, &index);
auto output_dim = phi::vectorize(x.dims());
......@@ -170,7 +166,7 @@ void RepeatInterleaveWithTensorIndexKernel(const Context& ctx,
int64_t stride = stride_dim[dim];
auto stream = ctx.stream();
auto* in_data = x.data<T>();
if (index_type == paddle::framework::proto::VarType::INT64) {
if (index_type == phi::DataType::INT64) {
phi::funcs::RepeatsTensor2IndexTensor<Context, int64_t>(
ctx, repeats_tensor, &index);
......
......@@ -188,21 +188,21 @@ void SetValueGradImpl(const Context& dev_ctx,
(value_grad_dims_size + decrease_axis_size - num_decrease));
fake_value_grad_dims[i] = value_grad_dims[index_grad];
PADDLE_ENFORCE_EQ((out_dims[i] == value_grad_dims[index_grad]) ||
(value_grad_dims[index_grad] == 1),
true,
errors::InvalidArgument(
"An error occurred while calculating %s: "
"[%s] can not be accumulated into [%s].",
paddle::framework::GradVarName("ValueTensor"),
out_dims,
value_grad_dims));
PADDLE_ENFORCE_EQ(
(out_dims[i] == value_grad_dims[index_grad]) ||
(value_grad_dims[index_grad] == 1),
true,
errors::InvalidArgument("An error occurred while calculating %s: "
"[%s] can not be accumulated into [%s].",
"ValueTensor@GRAD",
out_dims,
value_grad_dims));
}
}
VLOG(3) << "Dimensions of "
<< paddle::framework::GradVarName("ValueTensor") << "(["
<< value_grad_dims << "])is broadcasted into ["
<< "ValueTensor@GRAD"
<< "([" << value_grad_dims << "])is broadcasted into ["
<< fake_value_grad_dims << "].";
auto extent = Eigen::DSizes<Eigen::DenseIndex, RANK>();
......
......@@ -32,11 +32,11 @@ static void Sort(const XPUContext& dev_ctx,
scores_slice_cpu.Resize({value.numel()});
T* scores_slice_cpu_data = dev_ctx.template HostAlloc<T>(&scores_slice_cpu);
paddle::memory::Copy(cpu_place,
scores_slice_cpu_data,
place,
value_data,
sizeof(T) * value.numel());
memory_utils::Copy(cpu_place,
scores_slice_cpu_data,
place,
value_data,
sizeof(T) * value.numel());
// Sort index
DenseTensor index_t;
index_t.Resize({value.numel()});
......@@ -52,7 +52,7 @@ static void Sort(const XPUContext& dev_ctx,
std::sort(index, index + value.numel(), compare);
index_out->Resize({index_t.numel()});
int* idx_out = dev_ctx.template Alloc<int>(index_out);
paddle::memory::Copy(
memory_utils::Copy(
place, idx_out, cpu_place, index, sizeof(T) * index_t.numel());
}
......
......@@ -222,21 +222,21 @@ void SetValueGradImpl(const Context& dev_ctx,
(value_grad_dims_size + decrease_axis_size - num_decrease));
fake_value_grad_dims[i] = value_grad_dims[index_grad];
PADDLE_ENFORCE_EQ((out_dims[i] == value_grad_dims[index_grad]) ||
(value_grad_dims[index_grad] == 1),
true,
errors::InvalidArgument(
"An error occurred while calculating %s: "
"[%s] can not be accumulated into [%s].",
paddle::framework::GradVarName("ValueTensor"),
out_dims,
value_grad_dims));
PADDLE_ENFORCE_EQ(
(out_dims[i] == value_grad_dims[index_grad]) ||
(value_grad_dims[index_grad] == 1),
true,
errors::InvalidArgument("An error occurred while calculating %s: "
"[%s] can not be accumulated into [%s].",
"ValueTensor@GRAD",
out_dims,
value_grad_dims));
}
}
VLOG(3) << "Dimensions of "
<< paddle::framework::GradVarName("ValueTensor") << "(["
<< value_grad_dims << "])is broadcasted into ["
<< "ValueTensor@GRAD"
<< "([" << value_grad_dims << "])is broadcasted into ["
<< fake_value_grad_dims << "].";
std::vector<int64_t> slice_end(RANK, 0);
......