diff --git a/paddle/fluid/operators/pyramid_hash_op.cc b/paddle/fluid/operators/pyramid_hash_op.cc index b02561e2311c796c8edf72ecfdff15fb608b9f45..bb1abe3a891fcbffa9ce6eab108bc3ad5eaeec5e 100644 --- a/paddle/fluid/operators/pyramid_hash_op.cc +++ b/paddle/fluid/operators/pyramid_hash_op.cc @@ -163,8 +163,12 @@ class CPUPyramidHashOPKernel : public framework::OpKernel { int _space_len) const { for (unsigned int j = 0; j != _num_emb; j += _rand_len) { unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len; - memcpy(top_pos + j, const_cast(weights + pos), - _rand_len * sizeof(T)); + if (_rand_len == 16) { + memcpy(top_pos + j, const_cast(weights + pos), 16 * sizeof(T)); + } else { + memcpy(top_pos + j, const_cast(weights + pos), + _rand_len * sizeof(T)); + } } } @@ -322,6 +326,8 @@ class PyramidHashOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "Input(W) should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("DropPos"), true, "Input(DropPos) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X_Temp_Out"), true, + "Input(X_Temp_Out) should not be null."); PADDLE_ENFORCE_EQ( ctx->HasInput(framework::GradVarName("Out")), true, "Input(Out@GRAD) of PyramidHashGradOp should not be null."); @@ -347,6 +353,7 @@ class PyramidHashGradOpMaker : public framework::SingleGradOpMaker { op_desc_ptr->SetInput("X", this->Input("X")); op_desc_ptr->SetInput("W", this->Input("W")); op_desc_ptr->SetInput("DropPos", this->Output("DropPos")); + op_desc_ptr->SetInput("X_Temp_Out", this->Output("X_Temp_Out")); op_desc_ptr->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); @@ -380,13 +387,8 @@ class CPUPyramidHashOPGradKernel : public framework::OpKernel { int _space_len = ctx.Attr("space_len"); int _pyramid_layer = ctx.Attr("pyramid_layer"); - const auto* bottom_data_ori = bottom->data(); - Tensor buff; - buff.Resize(framework::make_ddim({bottom->dims()[0], bottom->dims()[1]})); - T* bottom_data = buff.mutable_data(ctx.GetPlace()); - for (size_t i = 0; i < bottom->dims()[0]; i++) { - bottom_data[i] = bottom_data_ori[i]; - } + auto* buff = ctx.Input("X_Temp_Out"); + auto* bottom_data = buff->data(); int _slot_len = bottom->dims()[0]; if (_slot_len == bottom->lod()[0].size() - 1 &&