提交 5b9716d1 编写于 作者: T tangwei12

add dims check

上级 4cd504d3
......@@ -30,6 +30,8 @@ class SamplingIdOp : public framework::OperatorWithKernel {
"Output(Out) of SamplingIdOp should not be null.");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor.");
framework::DDim dims = input_dims;
ctx->SetOutputDim("Out", dims);
......@@ -46,10 +48,8 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "SamplingId data tensor.");
AddComment(R"DOC(
SamplingId Operator.
@brief A layer for sampling id from multinomial distribution from the
input layer. Sampling one id for one sample. The result is stored in
output_.ids.
)DOC");
A layer for sampling id from multinomial distribution from the
input layer. Sampling one id for one sample.)DOC");
}
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册