未验证 提交 19d98467 编写于 作者: S ShenLiang 提交者: GitHub

fix error message, test=develop (#24425) (#24547)

上级 55827199
......@@ -33,8 +33,9 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"AllReduce op can run on gpu place only for now.");
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"AllReduce op can run on gpu place only for now."));
#if defined(PADDLE_WITH_NCCL)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto in = ctx.Input<framework::Tensor>("X");
......@@ -49,7 +50,8 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
auto* comm = dev_ctx.nccl_comm();
// FIXME(typhoonzero): should use nccl stream here.
auto stream = dev_ctx.stream();
PADDLE_ENFORCE_NOT_NULL(stream, "Should initialize NCCL firstly.");
PADDLE_ENFORCE_NOT_NULL(
stream, platform::errors::NotFound("Should initialize NCCL firstly."));
int reduce_type = ctx.Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
......@@ -67,7 +69,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
red_type = ncclMin;
break;
}
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
comm, stream));
if (ctx.Attr<bool>("sync_mode")) {
......
......@@ -26,10 +26,12 @@ class BroadcastOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of BroadcastOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Output) of ConvOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of BroadcastOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Output) of ConvOp should not be null."));
}
};
......
......@@ -34,8 +34,10 @@ template <typename T>
class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"The place of ExecutionContext should be CUDAPlace.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"The place of ExecutionContext should be CUDAPlace."));
#if defined(PADDLE_WITH_NCCL)
int dev_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).device;
......@@ -43,19 +45,22 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
PADDLE_ENFORCE(out->IsInitialized(),
"Currently, the output of broadcast op must be initialized, "
"because this op can only be an In-Place operation.");
PADDLE_ENFORCE_EQ(
out->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"Currently, the output of broadcast op must be initialized,"
"because this op can only be an In-Place operation."));
void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_EQ(
send_recv_buffer, in->data<void>(),
"Currently, the broadcast op can only be an In-Place operation.");
platform::errors::PreconditionNotMet("Currently, the broadcast op can "
"only be an In-Place operation."));
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto comm = dev_ctx.nccl_comm();
auto stream = dev_ctx.stream();
PADDLE_ENFORCE(platform::dynload::ncclBcast(
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
send_recv_buffer, static_cast<size_t>(in->numel()),
platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));
......
......@@ -22,16 +22,20 @@ class EyeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of EyeOP should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of EyeOP should not be null."));
auto num_rows = ctx->Attrs().Get<int64_t>("num_rows");
PADDLE_ENFORCE(num_rows >= 0,
"The value of Input(num_rows) should be non-negative int.");
PADDLE_ENFORCE_EQ(
num_rows >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_rows) should be non-negative int."));
auto num_columns = ctx->Attrs().Get<int64_t>("num_columns");
if (num_columns == -1) num_columns = num_rows;
PADDLE_ENFORCE(
num_columns >= 0,
"The value of Input(num_columns) should be non-negative int.");
PADDLE_ENFORCE_EQ(
num_columns >= 0, true,
platform::errors::InvalidArgument(
"The value of Input(num_columns) should be non-negative int."));
ctx->SetOutputDim("Out", {num_rows, num_columns});
}
......
......@@ -78,12 +78,14 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
// check index of shape 1-D
if (index.dims().size() == 1) {
PADDLE_ENFORCE_GT(index.dims()[0], 0,
"The index of gather_op should not be empty when the "
"index's rank is 1.");
platform::errors::InvalidArgument(
"The index of gather_op should not be empty"
"when the index's rank is 1."));
} else if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
" If the index's rank of gather_op is 2, the second "
"dimension should be 1.");
platform::errors::InvalidArgument(
"If the index's rank of gather_op is 2,"
" the second dimension should be 1."));
}
int index_size = index.dims()[0];
......
......@@ -36,15 +36,23 @@ using framework::Tensor;
template <typename T, typename IndexT = int>
void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("It should be running on the CPU."));
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
"index.dims()[1] should be 1 when index.dims().size() == "
"2 in gather_op.");
PADDLE_ENFORCE_EQ(
index.dims()[1], 1,
platform::errors::InvalidArgument(
"index.dims()[1] should be 1 when index.dims().size() = 2"
"in gather_op, but received value is [%d].",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(), 1,
"index.dims().size() should be 1 or 2 in gather_op.");
platform::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in gather_op,"
"but received shape's size is [%d].",
index.dims().size()));
}
int64_t index_size = index.dims()[0];
......@@ -69,8 +77,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
template <typename T, typename IndexT = int>
void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"It should be running on the CPU");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("It should be running on the CPU."));
auto index_dims = index.dims();
auto index_dims_size = index_dims.size();
......@@ -98,11 +107,14 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
int64_t temp = 1;
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = p_index[i * end_size + j];
PADDLE_ENFORCE_LT(index_value, input_dims[j],
"Input(index[-1)] has wrong value, it is %d",
index_value);
PADDLE_ENFORCE_GE(index_value, 0UL,
"The value of Input(index) must be no less than 0");
PADDLE_ENFORCE_LT(
index_value, input_dims[j],
platform::errors::InvalidArgument(
"Input(index[-1)] has wrong value, it is [%d]", index_value));
PADDLE_ENFORCE_GE(
index_value, 0UL,
platform::errors::InvalidArgument(
"The value of Input(index) must be no less than 0"));
index_ += (index_value * temp);
temp *= input_dims[j];
......
......@@ -27,11 +27,14 @@ class GatherNdOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of GatherNdOp should not be null.");
platform::errors::InvalidArgument(
"Input(X) of GatherNdOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
"Input(Index) of GatherNdOp should not be null.");
platform::errors::InvalidArgument(
"Input(Index) of GatherNdOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of GatherNdOp should not be null.");
platform::errors::InvalidArgument(
"Output(Out) of GatherNdOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto x_dims_size = x_dims.size();
......@@ -40,9 +43,11 @@ class GatherNdOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE(
index_dims[index_dims_size - 1], x_dims_size,
"Input(Index).shape[-1] should be no greater than Input(X).rank");
platform::errors::InvalidArgument(
"Input(Index).shape[-1] should be no greater than Input(X).rank"));
PADDLE_ENFORCE_GE(index_dims_size, 2UL,
"The rank of Input(Index) should be greater than 1");
platform::errors::InvalidArgument(
"The rank of Input(Index) should be greater than 1"));
std::vector<int64_t> result_dims;
// The result dims is
......
......@@ -25,7 +25,8 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device.");
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
......@@ -35,12 +36,15 @@ class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s], but "
"desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUGatherNd<DeviceContext, T, int>(ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
......@@ -54,7 +58,8 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device.");
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
......@@ -70,12 +75,15 @@ class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterNdAdd<DeviceContext, T, int>(ctx, *dO, *index, dX);
......
......@@ -27,8 +27,9 @@ template <typename T>
class GatherNdOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"This kernel only runs on CPU.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
......@@ -40,12 +41,15 @@ class GatherNdOpKernel : public framework::OpKernel<T> {
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
CPUGatherNd<T, int>(ctx.device_context(), *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
......@@ -58,8 +62,9 @@ template <typename T>
class GatherNdGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"This kernel only runs on CPU.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
......@@ -73,12 +78,15 @@ class GatherNdGradOpKernel : public framework::OpKernel<T> {
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
ScatterNdAdd<T, int32_t>(ctx, *dO, *index, dX);
} else if (index_type == framework::proto::VarType::INT64) {
......
......@@ -26,12 +26,15 @@ class GatherOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GatherOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Index"),
"Input(Index) of GatherOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GatherOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of GatherOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of GatherOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of GatherOp should not be null."));
auto index_dims = ctx->GetInputDim("Index");
PADDLE_ENFORCE(index_dims.size() == 1 ||
......
......@@ -24,8 +24,9 @@ template <typename T>
class GatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
......@@ -35,12 +36,15 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE(
index_type_match,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUGather<T, int>(ctx.device_context(), *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
......@@ -53,8 +57,9 @@ template <typename T>
class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
......@@ -69,12 +74,15 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE(
index_type_match,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
......
......@@ -27,8 +27,9 @@ template <typename T>
class GatherOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
......@@ -40,12 +41,15 @@ class GatherOpKernel : public framework::OpKernel<T> {
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE(
index_type_match,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
CPUGather<T, int>(ctx.device_context(), *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
......@@ -58,8 +62,9 @@ template <typename T>
class GatherGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
......@@ -76,12 +81,15 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE(
index_type_match,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
if (overwrite) {
ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册