未验证 提交 f6050dac 编写于 作者: J Jiawei Wang 提交者: GitHub

SamplingID Op fix error print (#24521) (#24552)

* fix error print for sampling_id_op

* fix spell err

* fix spell err test=develop
上级 6f65b078
...@@ -24,16 +24,20 @@ class SamplingIdOp : public framework::OperatorWithKernel { ...@@ -24,16 +24,20 @@ class SamplingIdOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SampleIn");
"Input(X) of SamplingIdOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "X", "SampleOut");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE_LT(
"Output(Out) of SamplingIdOp should not be null."); ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
PADDLE_ENFORCE_LT(ctx->Attrs().Get<float>("min"), platform::errors::InvalidArgument(
ctx->Attrs().Get<float>("max"), "min must less then max"); "min must less then max, but here min is %f, max is %f",
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 2, PADDLE_ENFORCE_EQ(
"Input(X, Filter) should be 2-D tensor."); input_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X, Filter) should be 2-D tensor. But X dim is %d",
input_dims.size()));
auto dim0 = input_dims[0]; auto dim0 = input_dims[0];
framework::DDim dims = framework::make_ddim({dim0}); framework::DDim dims = framework::make_ddim({dim0});
......
...@@ -36,9 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -36,9 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
const int width = static_cast<int>(input->dims()[1]); const int width = static_cast<int>(input->dims()[1]);
PADDLE_ENFORCE_GE(batch_size, 0, PADDLE_ENFORCE_GE(
"batch_size(dims[0]) must be nonnegative."); batch_size, 0,
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); platform::errors::InvalidArgument(
"batch_size(dims[0]) must be nonnegative. but it is %d.",
batch_size));
PADDLE_ENFORCE_GE(
width, 0,
platform::errors::InvalidArgument(
"width(dims[1]) must be nonnegative. but it is %d.", width));
std::vector<T> ins_vector; std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector); framework::TensorToVector(*input, context.device_context(), &ins_vector);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册