From f10138b248601f4b42bb614c50c8a0aef77ed6c3 Mon Sep 17 00:00:00 2001 From: seiriosPlus Date: Thu, 27 Aug 2020 16:50:34 +0800 Subject: [PATCH] fix fuse --- .../lookup_sparse_table_fuse_adam_op.h | 14 +++++++------- .../lookup_sparse_table_fuse_sgd_op.h | 7 +++---- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h index 77dbe4f6072..d62254220d5 100644 --- a/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h +++ b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h @@ -114,10 +114,6 @@ class LargeScaleFuseAdamOpKernel platform::errors::InvalidArgument( "param_row should have the same size with grad_row")); - auto ¶ms = values[0]; - auto &moment_1 = values[1]; - auto &moment_2 = values[2]; - T lr_ = lr[0]; T beta1_ = beta1_pow->data()[0]; T beta2_ = beta2_pow->data()[0]; @@ -125,9 +121,13 @@ class LargeScaleFuseAdamOpKernel lr_ *= sqrt(1 - beta1_) / (1 - beta2_); for (size_t i = 0; i < in_rows.size(); i++) { - auto *m1_data = moment_1[i]->data(); - auto *m2_data = moment_2[i]->data(); - auto *p_data = params[i]->data(); + auto ¶ms = values[i][0]; + auto &moment_1 = values[i][1]; + auto &moment_2 = values[i][2]; + + auto *p_data = params->data(); + auto *m1_data = moment_1->data(); + auto *m2_data = moment_2->data(); for (int x = 0; x < grad_width; ++x) { auto g = grad_v.data()[grad_width * i + x]; diff --git a/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h index bda4df49f95..5d4bf1015fa 100644 --- a/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h +++ b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h @@ -87,8 +87,6 @@ class LargeScaleFuseSGDOpKernel platform::errors::InvalidArgument( "param_row should have the same size with grad_row")); - auto ¶ms = values[0]; - auto blas = math::GetBlas(ctx); std::vector grads; @@ -97,8 +95,9 @@ class LargeScaleFuseSGDOpKernel blas.SCAL(grads.size(), lr[0], grads.data()); for (int x = 0; x < static_cast(in_rows.size()); ++x) { - blas.VSUB(grad_width, params[x]->data(), grads.data() + grad_width * x, - params[x]->data()); + auto ¶ms = values[x][0]; + blas.VSUB(grad_width, params->data(), grads.data() + grad_width * x, + params->data()); } } }; -- GitLab