提交 1cd67218 编写于 作者: A Aurelius84 提交者: Tao Luo

Optimizer mmcpy if _rand_len=16 and remove data copy in GradKernel (#21099)

上级 78cc1ca6
......@@ -163,10 +163,14 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
int _space_len) const {
for (unsigned int j = 0; j != _num_emb; j += _rand_len) {
unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len;
if (_rand_len == 16) {
memcpy(top_pos + j, const_cast<float*>(weights + pos), 16 * sizeof(T));
} else {
memcpy(top_pos + j, const_cast<float*>(weights + pos),
_rand_len * sizeof(T));
}
}
}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* bottom = ctx.Input<LoDTensor>("X");
......@@ -322,6 +326,8 @@ class PyramidHashOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "Input(W) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("DropPos"), true,
"Input(DropPos) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X_Temp_Out"), true,
"Input(X_Temp_Out) should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) of PyramidHashGradOp should not be null.");
......@@ -347,6 +353,7 @@ class PyramidHashGradOpMaker : public framework::SingleGradOpMaker<T> {
op_desc_ptr->SetInput("X", this->Input("X"));
op_desc_ptr->SetInput("W", this->Input("W"));
op_desc_ptr->SetInput("DropPos", this->Output("DropPos"));
op_desc_ptr->SetInput("X_Temp_Out", this->Output("X_Temp_Out"));
op_desc_ptr->SetInput(framework::GradVarName("Out"),
this->OutputGrad("Out"));
......@@ -380,13 +387,8 @@ class CPUPyramidHashOPGradKernel : public framework::OpKernel<T> {
int _space_len = ctx.Attr<int>("space_len");
int _pyramid_layer = ctx.Attr<int>("pyramid_layer");
const auto* bottom_data_ori = bottom->data<int32_t>();
Tensor buff;
buff.Resize(framework::make_ddim({bottom->dims()[0], bottom->dims()[1]}));
T* bottom_data = buff.mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < bottom->dims()[0]; i++) {
bottom_data[i] = bottom_data_ori[i];
}
auto* buff = ctx.Input<LoDTensor>("X_Temp_Out");
auto* bottom_data = buff->data<T>();
int _slot_len = bottom->dims()[0];
if (_slot_len == bottom->lod()[0].size() - 1 &&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册