提交 f10138b2 编写于 作者: S seiriosPlus

fix fuse

上级 77ac2c2b
......@@ -114,10 +114,6 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
platform::errors::InvalidArgument(
"param_row should have the same size with grad_row"));
auto &params = values[0];
auto &moment_1 = values[1];
auto &moment_2 = values[2];
T lr_ = lr[0];
T beta1_ = beta1_pow->data<T>()[0];
T beta2_ = beta2_pow->data<T>()[0];
......@@ -125,9 +121,13 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
lr_ *= sqrt(1 - beta1_) / (1 - beta2_);
for (size_t i = 0; i < in_rows.size(); i++) {
auto *m1_data = moment_1[i]->data();
auto *m2_data = moment_2[i]->data();
auto *p_data = params[i]->data();
auto &params = values[i][0];
auto &moment_1 = values[i][1];
auto &moment_2 = values[i][2];
auto *p_data = params->data();
auto *m1_data = moment_1->data();
auto *m2_data = moment_2->data();
for (int x = 0; x < grad_width; ++x) {
auto g = grad_v.data<T>()[grad_width * i + x];
......
......@@ -87,8 +87,6 @@ class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
platform::errors::InvalidArgument(
"param_row should have the same size with grad_row"));
auto &params = values[0];
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
std::vector<T> grads;
......@@ -97,8 +95,9 @@ class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
blas.SCAL(grads.size(), lr[0], grads.data());
for (int x = 0; x < static_cast<int>(in_rows.size()); ++x) {
blas.VSUB(grad_width, params[x]->data(), grads.data() + grad_width * x,
params[x]->data());
auto &params = values[x][0];
blas.VSUB(grad_width, params->data(), grads.data() + grad_width * x,
params->data());
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册