You need to sign in or sign up before continuing.
提交 cc6c33b8 编写于 作者: Q Qiao Longfei 提交者: GitHub

export Backward to python (#3174)

* export Backward to python
上级 5cb29a8f
......@@ -50,10 +50,6 @@ The equation is: Out = X + Y
class AddOpGrad : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
}
};
} // namespace operators
......
cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python
DEPS pybind python backward
fc_op
sgd_op
add_op
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <fstream>
#include <vector>
#include "paddle/framework/backward.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
......@@ -45,6 +46,10 @@ template <typename ClassType>
void ExposeOperator(ClassType& m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("type",
[](const typename ClassType::type& op) -> std::string {
return op.type_;
})
.def("outputs",
[](const typename ClassType::type& op) -> std::vector<std::string> {
return op.outputs_;
......@@ -192,6 +197,13 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString());
return pd::OpRegistry::CreateOp(desc);
});
operator_base.def("backward",
[](const pd::OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) {
return pd::Backward(forwardOp, no_grad_vars);
});
ExposeOperator(operator_base);
py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
......
import unittest
from op_test_util import OpTestMeta
import numpy
import paddle.v2.framework.core as core
import paddle.v2.framework.create_op_creation_methods as creation
from op_test_util import OpTestMeta
class TestAddOp(unittest.TestCase):
......@@ -13,5 +17,14 @@ class TestAddOp(unittest.TestCase):
self.Out = self.X + self.Y
class TestAddGradOp(unittest.TestCase):
def test_add_grad(self):
op = creation.op_creations.add_two(X="X", Y="Y", Out="Out")
backward_op = core.Operator.backward(op, set())
self.assertEqual(backward_op.type(), "add_two_grad")
expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
self.assertEqual(expected, str(backward_op))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册