未验证 提交 3f2a6ab6 编写于 作者: H hutuxian 提交者: GitHub

fix error msg (#27887)

上级 426de255
......@@ -21,10 +21,14 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL,
"Inputs(Ids) of PullBoxSparseOp should not be empty.");
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
"Outputs(Out) of PullBoxSparseOp should not be empty.");
PADDLE_ENFORCE_GE(
ctx->Inputs("Ids").size(), 1UL,
platform::errors::InvalidArgument(
"Inputs(Ids) of PullBoxSparseOp should not be empty."));
PADDLE_ENFORCE_GE(
ctx->Outputs("Out").size(), 1UL,
platform::errors::InvalidArgument(
"Outputs(Out) of PullBoxSparseOp should not be empty."));
auto hidden_size = static_cast<int64_t>(ctx->Attrs().Get<int>("size"));
auto all_ids_dim = ctx->GetInputsDim("Ids");
const size_t n_ids = all_ids_dim.size();
......@@ -34,9 +38,10 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
const auto ids_dims = all_ids_dim[i];
int ids_rank = ids_dims.size();
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
platform::errors::InvalidArgument(
"Shape error in %lu id, the last dimension of the "
"'Ids' tensor must be 1.",
i);
i));
auto out_dim = framework::vectorize(
framework::slice_ddim(ids_dims, 0, ids_rank - 1));
out_dim.push_back(hidden_size);
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import numpy as np
......@@ -87,5 +88,19 @@ class TestRunCmd(unittest.TestCase):
self.assertTrue(ret2 == 0)
class TestPullBoxSparseOP(unittest.TestCase):
""" TestCases for _pull_box_sparse op"""
def test_pull_box_sparse_op(self):
paddle.enable_static()
program = fluid.Program()
with fluid.program_guard(program):
x = fluid.layers.data(
name='x', shape=[1], dtype='int64', lod_level=0)
y = fluid.layers.data(
name='y', shape=[1], dtype='int64', lod_level=0)
emb_x, emb_y = _pull_box_sparse([x, y], size=1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册