提交 ee9ee56d 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #2972 from jacquesqiao/fix-sgd-op

update tensor usage in sgd-op
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
...@@ -30,8 +31,10 @@ public: ...@@ -30,8 +31,10 @@ public:
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
param_out->flat<T>().device(*(ctx.GetEigenDevice<Place>())) = framework::EigenVector<T>::Flatten(*param_out)
param.flat<T>() - lr * grad.flat<T>(); .device(*(ctx.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(param) -
lr * framework::EigenVector<T>::Flatten(grad);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册