提交 b1f6fda5 编写于 作者: X Xin Pan

run forward

上级 a6d23083
......@@ -16,7 +16,9 @@ limitations under the License. */
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
......@@ -53,5 +55,20 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
return tensor;
}
LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name) {
Variable* var = scope.FindVar(var_name);
PADDLE_ENFORCE(var, "%s no in scope", var_name);
// TODO(panyx0718): hack, remove it once we run oprerator.
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int numel = 10;
float* data =
tensor->mutable_data<float>(framework::make_ddim({numel}),
platform::CPUPlace(), sizeof(float) * numel);
for (size_t i = 0; i < numel; ++i) data[i] = 1;
PADDLE_ENFORCE(var->IsType<LoDTensor>(), "Variable is not LoDTensor");
return *var->GetMutable<LoDTensor>();
}
} // namespace framework
} // namespace paddle
......@@ -27,5 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
size_t index);
LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name);
} // namespace framework
} // namespace paddle
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
// option optimize_for = LITE_RUNTIME;
package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should
......
......@@ -14,20 +14,31 @@
#pragma once
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace imperative {
class TensorFuture {
class VariableBase {
public:
VariableBase() {}
virtual ~VariableBase() {}
framework::VarDesc* var_desc_;
};
class Layer {
public:
virtual ~Layer() {}
virtual void Forward() { LOG(ERROR) << "forward at cpp."; }
virtual std::vector<VariableBase> Forward(
const std::vector<VariableBase>& inputs) {
std::vector<VariableBase> vars;
return vars;
}
virtual void Backward() { LOG(ERROR) << "backward at cpp."; }
};
......
......@@ -14,8 +14,10 @@ limitations under the License. */
#pragma once
#include <Python.h>
#include <vector>
#include "paddle/fluid/imperative/layer.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace paddle {
namespace pybind {
......@@ -24,8 +26,10 @@ class PyLayer : public imperative::Layer {
public:
using imperative::Layer::Layer; // Inherit constructors
void Forward() override {
PYBIND11_OVERLOAD(void, Layer, Forward, ); // NOLINT
std::vector<imperative::VariableBase> Forward(
const std::vector<imperative::VariableBase>& inputs) override {
PYBIND11_OVERLOAD(std::vector<imperative::VariableBase>, Layer, Forward,
inputs); // NOLINT
}
void Backward() override {
......@@ -33,7 +37,7 @@ class PyLayer : public imperative::Layer {
}
};
void BindTracer(pybind11::module *m);
void BindTracer(pybind11::module* m);
} // namespace pybind
} // namespace paddle
......@@ -101,9 +101,23 @@ PYBIND11_MODULE(core, m) {
BindException(&m);
py::class_<imperative::VariableBase>(m, "VariableBase",
R"DOC()DOC")
.def_property(
"desc",
[](const imperative::VariableBase &self) { return self.var_desc_; },
[](imperative::VariableBase &self, framework::VarDesc *var_desc) {
self.var_desc_ = var_desc;
},
py::return_value_policy::reference);
py::class_<imperative::Layer, PyLayer /* <--- trampoline*/> layer(m, "Layer");
layer.def(py::init<>())
.def("forward", &imperative::Layer::Forward)
.def("forward",
[](imperative::Layer &self,
const std::vector<imperative::VariableBase> &inputs) {
return self.Forward(inputs);
})
.def("backward", &imperative::Layer::Backward);
BindTracer(&m);
......@@ -608,6 +622,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("set_feed_variable", framework::SetFeedVariable);
m.def("get_fetch_variable", framework::GetFetchVariable);
m.def("get_variable_tensor", framework::GetVariableTensor);
m.def("_is_program_version_supported", IsProgramVersionSupported);
......
......@@ -18,6 +18,7 @@ import collections
import contextlib
import re
import six
import sys
import numpy as np
......@@ -211,7 +212,7 @@ def _debug_string_(proto, throw_on_error=True):
return proto.__str__()
class Variable(object):
class Variable(core.VariableBase):
"""
In Fluid, every input and output of an operator is a variable. In most
cases, variables are used for holding different kinds of data or training
......@@ -282,15 +283,20 @@ class Variable(object):
name = unique_name.generate('_generated_var')
is_new_var = False
name = cpt.to_text(name)
self.desc = self.block.desc.find_var(cpt.to_bytes(name))
desc = self.block.desc.find_var(cpt.to_bytes(name))
if self.desc is None:
if desc is None:
# sys.stderr.write('desc is None\n')
self.desc = self.block.desc.var(cpt.to_bytes(name))
is_new_var = True
else:
# sys.stderr.write('found var %s %s' % (name, self.desc))
self.desc = desc
if is_new_var:
self.desc.set_type(type)
elif self.desc.type() != type:
# sys.stderr.write('%s vs %s\n' % (self.desc.type(), type))
raise ValueError("Variable {0} has been created before. The "
"previous type is {1}; the new type is {2}. They"
" are not matched".format(self.name,
......@@ -355,6 +361,10 @@ class Variable(object):
self.stop_gradient = stop_gradient
self.is_data = is_data
def numpy(self, scope):
tensor = core.get_variable_tensor(scope, self.desc.name())
return np.array(tensor)
def __str__(self):
return self.to_string(True)
......
......@@ -12,14 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
__all__ = ['PyLayer']
class PyLayer(core.Layer):
def __init__(self):
pass
self._scope = core.Scope()
def __call__(self, inputs):
if not isinstance(inputs, list) and not isinstance(inputs, tuple):
inputs = [inputs]
var_inputs = []
for x in inputs:
if isinstance(x, np.ndarray):
tensor = core.LoDTensor()
tensor.set(x, core.CPUPlace())
x = framework.Variable(
framework.default_main_program().current_block(),
type=core.VarDesc.VarType.LOD_TENSOR,
name=None,
shape=x.shape,
dtype=x.dtype)
elif not isinstance(x, framework.Variable):
raise ValueError("not var or ndarray %s" % type(x))
self._scope.var(x.name)
var_inputs.append(x)
outputs = self.forward(var_inputs)
for out in outputs:
self._scope.var(out.name)
return outputs
def forward(self):
def forward(self, inputs):
print("at python.")
return []
......@@ -17,6 +17,8 @@ from __future__ import print_function
import copy
import itertools
import six
import sys
import numpy as np
from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating
from . import unique_name
......@@ -46,23 +48,43 @@ class LayerHelper(object):
def startup_program(self):
return default_startup_program()
def _np_to_variable(self, x):
tensor = core.LoDTensor()
sys.stderr.write('%s %s\n' % (tensor, x))
tensor.set(x, core.CPUPlace())
return Variable(
self.main_program.current_block(),
type=core.VarDesc.VarType.LOD_TENSOR,
name=None,
shape=x.shape,
dtype=x.dtype)
def to_variable(self, x):
if isinstance(x, Variable):
return x
elif isinstance(x, np.ndarray):
return self._np_to_variable(x)
else:
raise ValueError("inputs wrong type %s\n" % x)
def to_variables(self, inputs):
if isinstance(inputs, list) or isinstance(inputs, tuple):
return [self._to_variable(x) for x in inputs]
else:
return [self._to_variable(inputs)]
def append_op(self, *args, **kwargs):
return self.main_program.current_block().append_op(*args, **kwargs)
def multiple_input(self, input_param_name='input'):
inputs = self.kwargs.get(input_param_name, [])
type_error = TypeError(
"Input of {0} layer should be Variable or sequence of Variable".
format(self.layer_type))
if isinstance(inputs, Variable):
inputs = [inputs]
elif not isinstance(inputs, list) and not isinstance(inputs, tuple):
raise type_error
ret = []
if isinstance(inputs, list) or isinstance(inputs, tuple):
for inp in inputs:
ret.append(self.to_variable(inp))
else:
for each in inputs:
if not isinstance(each, Variable):
raise type_error
return inputs
ret.append(self.to_variable(inputs))
return ret
def input(self, input_param_name='input'):
inputs = self.multiple_input(input_param_name)
......
......@@ -6455,7 +6455,8 @@ def relu(x, name=None):
helper = LayerHelper('relu', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type="relu", inputs={"X": x}, outputs={"Out": out})
helper.append_op(
type="relu", inputs={"X": helper.input('x')}, outputs={"Out": out})
return out
......
......@@ -14,25 +14,43 @@
import unittest
import sys
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
class MyLayer(fluid.imperative.PyLayer):
def __init__(self):
super(MyLayer, self).__init__()
def forward(self, inputs):
x = fluid.layers.relu(inputs[0])
return [fluid.layers.elementwise_mul(x, x)]
class TestImperative(unittest.TestCase):
def test_layer(self):
cl = core.Layer()
cl.forward()
cl.forward([])
l = fluid.imperative.PyLayer()
l.forward()
l.forward([])
def test_imperative_trace(self):
with fluid.imperative.guard():
self.assertTrue(fluid.imperative.enabled())
x = fluid.layers.data(name='x', shape=[3, 4], dtype='float32')
x = fluid.layers.data(name='abc', shape=[3, 4], dtype='float32')
for _ in xrange(2):
x = fluid.layers.relu(x)
x = fluid.layers.elementwise_mul(x, x)
self.assertIsNotNone(x)
def test_layer_in_out(self):
l = MyLayer()
x = l(np.ones([1], np.float32))[0]
self.assertIsNotNone(x)
sys.stderr.write("%s output: %s\n" % (x, x.numpy(scope=l._scope)))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册