未验证 提交 3f2a6ab6 编写于 作者: H hutuxian 提交者: GitHub

fix error msg (#27887)

上级 426de255
...@@ -21,10 +21,14 @@ class PullBoxSparseOp : public framework::OperatorWithKernel { ...@@ -21,10 +21,14 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL, PADDLE_ENFORCE_GE(
"Inputs(Ids) of PullBoxSparseOp should not be empty."); ctx->Inputs("Ids").size(), 1UL,
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, platform::errors::InvalidArgument(
"Outputs(Out) of PullBoxSparseOp should not be empty."); "Inputs(Ids) of PullBoxSparseOp should not be empty."));
PADDLE_ENFORCE_GE(
ctx->Outputs("Out").size(), 1UL,
platform::errors::InvalidArgument(
"Outputs(Out) of PullBoxSparseOp should not be empty."));
auto hidden_size = static_cast<int64_t>(ctx->Attrs().Get<int>("size")); auto hidden_size = static_cast<int64_t>(ctx->Attrs().Get<int>("size"));
auto all_ids_dim = ctx->GetInputsDim("Ids"); auto all_ids_dim = ctx->GetInputsDim("Ids");
const size_t n_ids = all_ids_dim.size(); const size_t n_ids = all_ids_dim.size();
...@@ -34,9 +38,10 @@ class PullBoxSparseOp : public framework::OperatorWithKernel { ...@@ -34,9 +38,10 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
const auto ids_dims = all_ids_dim[i]; const auto ids_dims = all_ids_dim[i];
int ids_rank = ids_dims.size(); int ids_rank = ids_dims.size();
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1, PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
platform::errors::InvalidArgument(
"Shape error in %lu id, the last dimension of the " "Shape error in %lu id, the last dimension of the "
"'Ids' tensor must be 1.", "'Ids' tensor must be 1.",
i); i));
auto out_dim = framework::vectorize( auto out_dim = framework::vectorize(
framework::slice_ddim(ids_dims, 0, ids_rank - 1)); framework::slice_ddim(ids_dims, 0, ids_rank - 1));
out_dim.push_back(hidden_size); out_dim.push_back(hidden_size);
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
import numpy as np import numpy as np
...@@ -87,5 +88,19 @@ class TestRunCmd(unittest.TestCase): ...@@ -87,5 +88,19 @@ class TestRunCmd(unittest.TestCase):
self.assertTrue(ret2 == 0) self.assertTrue(ret2 == 0)
class TestPullBoxSparseOP(unittest.TestCase):
""" TestCases for _pull_box_sparse op"""
def test_pull_box_sparse_op(self):
paddle.enable_static()
program = fluid.Program()
with fluid.program_guard(program):
x = fluid.layers.data(
name='x', shape=[1], dtype='int64', lod_level=0)
y = fluid.layers.data(
name='y', shape=[1], dtype='int64', lod_level=0)
emb_x, emb_y = _pull_box_sparse([x, y], size=1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册