提交 db77937e 编写于 作者: S sidgoyal78

Fix learning_rate usage for momentum

上级 c10da26c
......@@ -19,33 +19,35 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<Tensor>("ParamOut");
auto velocity_out = ctx.Output<Tensor>("VelocityOut");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
float mu = ctx.Attr<float>("mu");
auto param = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto grad = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
auto velocity = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Velocity"));
float learning_rate = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
auto p_out = EigenVector<T>::Flatten(*param_out);
auto v_out = EigenVector<T>::Flatten(*velocity_out);
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
auto p = framework::EigenVector<T>::Flatten(*param);
auto v = framework::EigenVector<T>::Flatten(*velocity);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
auto place = ctx.GetEigenDevice<Place>();
v_out.device(place) = velocity * mu + grad;
p_out.device(place) = param - learning_rate * v_out;
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
v_out.device(place) = v * mu + g;
p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册