未验证 提交 008857be 编写于 作者: S ShenLiang 提交者: GitHub

fix error message for scatter and scatter_nd (#24514)

上级 14376486
......@@ -95,11 +95,17 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
const auto& ctx = context.device_context();
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
"index.dims()[1] should be 1 when index.dims().size() == "
"2 in scatter_op.");
platform::errors::InvalidArgument(
"index.dims()[1] should be 1 when "
"index.dims().size() = 2 in scatter_op."
"But received value is [%d]",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(), 1,
"index.dims().size() should be 1 or 2 in scatter_op.");
platform::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]",
index.dims().size()));
}
int index_size = index.dims()[0];
......
......@@ -73,15 +73,23 @@ elementwise_inner_add(const framework::ExecutionContext& ctx,
template <typename T, typename IndexT = int>
void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
// check index of shape 1-D
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
"index.dims()[1] should be 1 when index.dims().size() == "
"2 in scatter_op.");
platform::errors::InvalidArgument(
"index.dims()[1] should be 1 when "
"index.dims().size() =2 in scatter_op."
"But received value is [%d]",
index.dims()[1]));
} else {
PADDLE_ENFORCE_EQ(index.dims().size(), 1,
"index.dims().size() should be 1 or 2 in scatter_op.");
platform::errors::InvalidArgument(
"index.dims().size() should be 1 or 2 in scatter_op."
"But received value is [%d]",
index.dims().size()));
}
int index_size = index.dims()[0];
......@@ -94,7 +102,9 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
platform::errors::InvalidArgument(
"src shape and dst shape should match"));
// slice size
size_t slice_size = 1;
......@@ -111,12 +121,14 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
template <typename T, typename IndexT = int>
void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
true);
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.device_context().GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1 ||
PADDLE_ENFORCE_EQ(
index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1),
"");
true, platform::errors::InvalidArgument("index's shape is error."));
int index_size = index.dims()[0];
auto src_dims = src.dims();
......@@ -130,7 +142,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
platform::errors::InvalidArgument(
"src shape and dst shape should match"));
// slice size
size_t slice_size = 1;
......@@ -156,8 +170,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
template <typename T, typename IndexT = int>
void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
true, "It should be running on the CPU");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.device_context().GetPlace()), true,
platform::errors::PreconditionNotMet("It should be running on the CPU"));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
auto index_dims = index.dims();
......
......@@ -26,13 +26,19 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of ScatterNdAddOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
"Input(Index) of ScatterNdAddOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
"Input(Updates) of ScatterNdAddOp should not be null.");
platform::errors::InvalidArgument(
"Input(X) of ScatterNdAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of ScatterNdAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Updates"), true,
platform::errors::InvalidArgument(
"Input(Updates) of ScatterNdAddOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of ScatterNdAddOp should not be null.");
platform::errors::InvalidArgument(
"Output(Out) of ScatterNdAddOp should not be null."));
auto ref_dims = ctx->GetInputDim("X");
auto ref_dims_size = ref_dims.size();
......@@ -43,9 +49,11 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LE(
index_dims[index_dims_size - 1], ref_dims_size,
"Input(Index).shape[-1] should be no greater than Input(X).rank");
platform::errors::InvalidArgument(
"Input(Index).shape[-1] should be no greater than Input(X).rank"));
PADDLE_ENFORCE_GE(index_dims_size, 2UL,
"The rank of Input(Index) should be greater than 1");
platform::errors::InvalidArgument(
"The rank of Input(Index) should be greater than 1"));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
std::vector<int64_t> r_updates_dims;
......@@ -56,12 +64,14 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
r_updates_dims.emplace_back(ref_dims[i]);
}
PADDLE_ENFORCE_EQ(r_updates_dims.size(), updates_dims_size,
"Updates has wrong shape");
PADDLE_ENFORCE_EQ(
r_updates_dims.size(), updates_dims_size,
platform::errors::InvalidArgument("Updates has wrong shape"));
for (int64_t i = 0; i < updates_dims_size; ++i) {
PADDLE_ENFORCE_EQ(r_updates_dims[i], updates_dims[i],
"Updates has wrong shape");
PADDLE_ENFORCE_EQ(
r_updates_dims[i], updates_dims[i],
platform::errors::InvalidArgument("Updates has wrong shape"));
}
ctx->SetOutputDim("Out", ref_dims);
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -72,7 +82,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "Updates"),
"Ref and Updates must have same type");
platform::errors::InvalidArgument(
"Ref and Updates must have same type"));
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
......
......@@ -25,7 +25,8 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device.");
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *X = ctx.Input<Tensor>("X");
auto *Ids = ctx.Input<Tensor>("Index");
auto *Updates = ctx.Input<Tensor>("Updates");
......@@ -35,12 +36,15 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s], but "
"desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterNdAdd<DeviceContext, T, int32_t>(ctx, *Updates, *Ids, Out);
} else {
......@@ -54,7 +58,8 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device.");
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Index");
......
......@@ -27,8 +27,9 @@ template <typename T>
class ScatterNdAddOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"This kernel only runs on CPU.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *X = ctx.Input<Tensor>("X");
auto *Ids = ctx.Input<Tensor>("Index");
auto *Updates = ctx.Input<Tensor>("Updates");
......@@ -39,12 +40,15 @@ class ScatterNdAddOpKernel : public framework::OpKernel<T> {
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s], but "
"desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
ScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
......@@ -58,8 +62,9 @@ template <typename T>
class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"This kernel only runs on CPU.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Index");
......
......@@ -24,24 +24,32 @@ class ScatterOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ScatterOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
"Input(Ids) of ScatterOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Updates"),
"Input(Updates) of ScatterOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ScatterOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of ScatterOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true,
platform::errors::InvalidArgument(
"Input(Ids) of ScatterOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
platform::errors::InvalidArgument(
"Input(Updates) of ScatterOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of ScatterOp should not be null."));
auto updates_dims = ctx->GetInputDim("Updates");
auto ref_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Ids").size(), 1,
"Update Ids should be 1-D.");
PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(),
"Xerence and Updates should have the same shape size");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Ids").size(), 1,
platform::errors::InvalidArgument("Update Ids should be 1-D."));
PADDLE_ENFORCE_EQ(
ref_dims.size(), updates_dims.size(),
platform::errors::InvalidArgument(
"Rerence and Updates should have the same shape size."));
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
ctx->GetInputDim("Ids")[0],
"Updates and Ids should have same batch-size.");
platform::errors::InvalidArgument(
"Updates and Ids should have same batch-size."));
ctx->SetOutputDim("Out", ref_dims);
}
......
......@@ -24,8 +24,9 @@ template <typename T>
class ScatterOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *X = ctx.Input<Tensor>("X");
auto *Ids = ctx.Input<Tensor>("Ids");
auto *Updates = ctx.Input<Tensor>("Updates");
......@@ -39,11 +40,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"scatter_op Index holds the wrong type, it holds %s, but desires to be "
"%s or %s",
platform::errors::InvalidArgument(
"scatter_op Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
} else {
......@@ -56,8 +60,9 @@ template <typename T>
class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Ids");
......@@ -74,12 +79,14 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"scatter_op Index holds the wrong type, it holds %s, but desires to "
"be %s or %s",
platform::errors::InvalidArgument(
"scatter_op Index holds the wrong type, it holds [%s], "
"but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64));
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
// Gradient by Gather: dUpdates = dO[Ids]
if (index_type == framework::proto::VarType::INT32) {
GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
......
......@@ -27,8 +27,9 @@ template <typename T>
class ScatterOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *X = ctx.Input<Tensor>("X");
auto *Ids = ctx.Input<Tensor>("Ids");
auto *Updates = ctx.Input<Tensor>("Updates");
......@@ -41,12 +42,15 @@ class ScatterOpKernel : public framework::OpKernel<T> {
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE(
index_type_match,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (overwrite) {
if (index_type == framework::proto::VarType::INT32) {
ScatterAssign<T, int32_t>(ctx.device_context(), *Updates, *Ids, Out);
......@@ -67,8 +71,9 @@ template <typename T>
class ScatterGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"This kernel only runs on CPU.");
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Ids");
......@@ -86,12 +91,14 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"scatter_op index holds the wrong type, it holds %s, but desires to "
"be %s or %s",
platform::errors::InvalidArgument(
"scatter_op index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64));
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册