提交 8b3d33a0 编写于 作者: Q qiaolongfei

fix-sgd

上级 e28e0073
...@@ -30,7 +30,7 @@ class SGDOpKernel : public framework::OpKernel { ...@@ -30,7 +30,7 @@ class SGDOpKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>(0); auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.op_.GetAttr<float>("learning_rate"); float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册