提交 e3b27d19 编写于 作者: Q Qiao Longfei 提交者: GitHub

Add sgd op (#2950)

* a simplest SGD op
上级 d8108493
......@@ -51,3 +51,5 @@ op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
softmax_op net)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
class SGDOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set");
PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set");
PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
"Two input of SGD Op's dimension must be same.");
outputs[0]->set_dims(inputs[0]->dims());
}
};
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter");
AddInput("grad", "input gradient");
AddOutput("param_out", "output parameter");
AddAttr<float>("learning_rate", "learning rate of sgd");
AddComment(R"DOC(
Simplest sgd algorithm.
param_out = param - learning_rate * grad;
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker);
typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float>
SGDOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float);
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& ctx) const override {
auto param = ctx.Input("param")->Get<framework::Tensor>();
auto grad = ctx.Input("grad")->Get<framework::Tensor>();
auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>();
float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace());
param_out->flat<T>().device(*(ctx.GetEigenDevice<Place>())) =
param.flat<T>() - lr * grad.flat<T>();
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/op_registry.h>
USE_OP(sgd);
TEST(SGDOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("sgd");
ASSERT_NE(it, protos.end());
}
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
add_op fc_op)
add_op fc_op sgd_op)
......@@ -28,6 +28,7 @@ namespace pd = paddle::framework;
USE_OP(add_two);
USE_OP_WITHOUT_KERNEL(fc);
USE_OP(sgd);
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of Paddle Paddle");
......
add_python_test(test_framework test_protobuf.py test_scope.py
test_default_scope_funcs.py test_op_creation_methods.py
test_tensor.py test_fc_op.py test_add_two_op.py)
test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py)
import unittest
import numpy
from op_test_util import OpTestMeta
class TestSGD(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "sgd"
self.param = numpy.random.random((342, 345)).astype("float32")
self.grad = numpy.random.random((342, 345)).astype("float32")
self.learning_rate = 0.1
self.param_out = self.param - self.learning_rate * self.grad
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册