提交 a4910288 编写于 作者: S seiriosPlus

fix error

上级 35074963
......@@ -24,10 +24,9 @@ class SplitByrefOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SplitOp should not be null.");
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
"Outputs(Out) of SplitOp should not be empty.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Ids", "SplitOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SplitOp");
auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out");
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
......@@ -51,9 +50,10 @@ class SplitByrefOp : public framework::OperatorWithKernel {
outs_dims.push_back(dim);
}
} else if (sections.size() > 0) {
PADDLE_ENFORCE_EQ(sections.size(), outs_number,
"tensor split sections size"
"should be equal to output size.");
PADDLE_ENFORCE_EQ(
sections.size(), outs_number,
platform::errors::InvalidArgument("tensor split sections size"
"should be equal to output size"));
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[0] = sections[i];
......
......@@ -52,13 +52,19 @@ class SplitIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"), "SplitIdsOp must have input Ids.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must have output Out.");
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "SplitIdsOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SplitIdsOp");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2);
PADDLE_ENFORCE_EQ(
ids_dims[0], 2,
platform::errors::InvalidArgument(
"ShapeError: The dimensions of the 'split_ids' must be 2. "
"But received split_ids's dimensions = %d, "
"split_ids's shape = [%s].",
ids_dims[0].size(), ids_dims[0]));
}
}
......
......@@ -30,12 +30,15 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("SplitIds do not support GPU kernel");
PADDLE_THROW(platform::errors::Unimplemented(
"SplitIds do not support GPU kernel"));
}
const auto ids_vars = ctx.MultiInputVar("Ids");
PADDLE_ENFORCE_GT(ids_vars.size(), 0, "The number of Ids should > 0");
PADDLE_ENFORCE_GT(platform::errors::InvalidArgument(
ids_vars.size(), 0, "The number of Ids expected > 0, but got %d",
ids_vars.size()));
auto *ids_var = ids_vars[0];
if (ids_var->IsType<framework::LoDTensor>()) {
......@@ -83,9 +86,6 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
} else if (ids_var->IsType<framework::SelectedRows>()) {
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
auto &ids_dims = ids_selected_rows->value().dims();
PADDLE_ENFORCE_EQ(ids_dims[0],
static_cast<int64_t>(ids_selected_rows->rows().size()),
"");
const T *ids_data = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
......@@ -114,9 +114,9 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
}
}
} else {
PADDLE_THROW(
PADDLE_THROW(platform::errors::InvalidArgument(
"% should be LoDTensor or SelectedRows, but the received type is %s",
ctx.InputNames("Ids")[0], framework::ToTypeName(ids_var->Type()));
ctx.InputNames("Ids")[0], framework::ToTypeName(ids_var->Type())));
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册