diff --git a/paddle/fluid/framework/feed_fetch_method.cc b/paddle/fluid/framework/feed_fetch_method.cc index 3e9353f5cf67d8de62c5551f12ea786e49190549..b13d0d38075c75868f86c04e2f614596d1b78b28 100644 --- a/paddle/fluid/framework/feed_fetch_method.cc +++ b/paddle/fluid/framework/feed_fetch_method.cc @@ -16,7 +16,9 @@ limitations under the License. */ #include #include #include "glog/logging.h" +#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/place.h" namespace paddle { namespace framework { @@ -53,5 +55,20 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, return tensor; } +LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name) { + Variable* var = scope.FindVar(var_name); + PADDLE_ENFORCE(var, "%s no in scope", var_name); + // TODO(panyx0718): hack, remove it once we run oprerator. + LoDTensor* tensor = var->GetMutable(); + int numel = 10; + float* data = + tensor->mutable_data(framework::make_ddim({numel}), + platform::CPUPlace(), sizeof(float) * numel); + for (size_t i = 0; i < numel; ++i) data[i] = 1; + + PADDLE_ENFORCE(var->IsType(), "Variable is not LoDTensor"); + return *var->GetMutable(); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/feed_fetch_method.h b/paddle/fluid/framework/feed_fetch_method.h index 7f504bfd232862c014cb59b6e8301eec74e0351f..031f8e01aa6128b803dcbfb990778e87d4fafc13 100644 --- a/paddle/fluid/framework/feed_fetch_method.h +++ b/paddle/fluid/framework/feed_fetch_method.h @@ -27,5 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input, LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, size_t index); +LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index efdabffb9b33ddf007c13008d0f3afb7a3961eda..6c60a041a191f1db4a755c1c5714724342053791 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ syntax = "proto2"; -option optimize_for = LITE_RUNTIME; +// option optimize_for = LITE_RUNTIME; package paddle.framework.proto; // Any incompatible changes to ProgramDesc and its dependencies should diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 42cc65ddc938b88b5ec50cf72edc9798e77d20d0..a83535af9c738100adb189d4421c1eb69754f64d 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -14,20 +14,31 @@ #pragma once +#include +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace imperative { -class TensorFuture { +class VariableBase { public: + VariableBase() {} + virtual ~VariableBase() {} + + framework::VarDesc* var_desc_; }; class Layer { public: virtual ~Layer() {} - virtual void Forward() { LOG(ERROR) << "forward at cpp."; } + virtual std::vector Forward( + const std::vector& inputs) { + std::vector vars; + return vars; + } virtual void Backward() { LOG(ERROR) << "backward at cpp."; } }; diff --git a/paddle/fluid/pybind/imperative.h b/paddle/fluid/pybind/imperative.h index bfab6bd9b90854aed782965d8a949be22f6ca95c..9a558fbdb8acc598878685bb54162ef23d54692d 100644 --- a/paddle/fluid/pybind/imperative.h +++ b/paddle/fluid/pybind/imperative.h @@ -14,8 +14,10 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/imperative/layer.h" #include "pybind11/pybind11.h" +#include "pybind11/stl.h" namespace paddle { namespace pybind { @@ -24,8 +26,10 @@ class PyLayer : public imperative::Layer { public: using imperative::Layer::Layer; // Inherit constructors - void Forward() override { - PYBIND11_OVERLOAD(void, Layer, Forward, ); // NOLINT + std::vector Forward( + const std::vector& inputs) override { + PYBIND11_OVERLOAD(std::vector, Layer, Forward, + inputs); // NOLINT } void Backward() override { @@ -33,7 +37,7 @@ class PyLayer : public imperative::Layer { } }; -void BindTracer(pybind11::module *m); +void BindTracer(pybind11::module* m); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index fa3e3835361f1a3a5cc552ecfe643016f02194ed..3cf1ec34a7a7642cd34bdc1f312ac1ffadaa3e5d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -101,9 +101,23 @@ PYBIND11_MODULE(core, m) { BindException(&m); + py::class_(m, "VariableBase", + R"DOC()DOC") + .def_property( + "desc", + [](const imperative::VariableBase &self) { return self.var_desc_; }, + [](imperative::VariableBase &self, framework::VarDesc *var_desc) { + self.var_desc_ = var_desc; + }, + py::return_value_policy::reference); + py::class_ layer(m, "Layer"); layer.def(py::init<>()) - .def("forward", &imperative::Layer::Forward) + .def("forward", + [](imperative::Layer &self, + const std::vector &inputs) { + return self.Forward(inputs); + }) .def("backward", &imperative::Layer::Backward); BindTracer(&m); @@ -608,6 +622,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("set_feed_variable", framework::SetFeedVariable); m.def("get_fetch_variable", framework::GetFetchVariable); + m.def("get_variable_tensor", framework::GetVariableTensor); m.def("_is_program_version_supported", IsProgramVersionSupported); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 525c36702b0bfdc5d4b1b2b1afdbae2ff9e81564..235d692afe84e813a57c27a3cc13eb286b1f5fc0 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -18,6 +18,7 @@ import collections import contextlib import re import six +import sys import numpy as np @@ -211,7 +212,7 @@ def _debug_string_(proto, throw_on_error=True): return proto.__str__() -class Variable(object): +class Variable(core.VariableBase): """ In Fluid, every input and output of an operator is a variable. In most cases, variables are used for holding different kinds of data or training @@ -282,15 +283,20 @@ class Variable(object): name = unique_name.generate('_generated_var') is_new_var = False name = cpt.to_text(name) - self.desc = self.block.desc.find_var(cpt.to_bytes(name)) + desc = self.block.desc.find_var(cpt.to_bytes(name)) - if self.desc is None: + if desc is None: + # sys.stderr.write('desc is None\n') self.desc = self.block.desc.var(cpt.to_bytes(name)) is_new_var = True + else: + # sys.stderr.write('found var %s %s' % (name, self.desc)) + self.desc = desc if is_new_var: self.desc.set_type(type) elif self.desc.type() != type: + # sys.stderr.write('%s vs %s\n' % (self.desc.type(), type)) raise ValueError("Variable {0} has been created before. The " "previous type is {1}; the new type is {2}. They" " are not matched".format(self.name, @@ -355,6 +361,10 @@ class Variable(object): self.stop_gradient = stop_gradient self.is_data = is_data + def numpy(self, scope): + tensor = core.get_variable_tensor(scope, self.desc.name()) + return np.array(tensor) + def __str__(self): return self.to_string(True) diff --git a/python/paddle/fluid/imperative/layers.py b/python/paddle/fluid/imperative/layers.py index ae96a0a6b2d24bc7e98076566600a72fe64b87d1..37b36f2cd0c999a3d8beaefcbdd34a4774335ccb 100644 --- a/python/paddle/fluid/imperative/layers.py +++ b/python/paddle/fluid/imperative/layers.py @@ -12,14 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys +import numpy as np + from paddle.fluid import core +from paddle.fluid import framework __all__ = ['PyLayer'] class PyLayer(core.Layer): def __init__(self): - pass + self._scope = core.Scope() + + def __call__(self, inputs): + if not isinstance(inputs, list) and not isinstance(inputs, tuple): + inputs = [inputs] + + var_inputs = [] + for x in inputs: + if isinstance(x, np.ndarray): + tensor = core.LoDTensor() + tensor.set(x, core.CPUPlace()) + x = framework.Variable( + framework.default_main_program().current_block(), + type=core.VarDesc.VarType.LOD_TENSOR, + name=None, + shape=x.shape, + dtype=x.dtype) + elif not isinstance(x, framework.Variable): + raise ValueError("not var or ndarray %s" % type(x)) + self._scope.var(x.name) + var_inputs.append(x) + outputs = self.forward(var_inputs) + for out in outputs: + self._scope.var(out.name) + return outputs - def forward(self): + def forward(self, inputs): print("at python.") + return [] diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index dc317de9abbd06f4021e64b87ea88ba6af8809c9..ceabb52215ee5bc8487b90f23a12d47e6838a198 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -17,6 +17,8 @@ from __future__ import print_function import copy import itertools import six +import sys +import numpy as np from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating from . import unique_name @@ -46,23 +48,43 @@ class LayerHelper(object): def startup_program(self): return default_startup_program() + def _np_to_variable(self, x): + tensor = core.LoDTensor() + sys.stderr.write('%s %s\n' % (tensor, x)) + tensor.set(x, core.CPUPlace()) + return Variable( + self.main_program.current_block(), + type=core.VarDesc.VarType.LOD_TENSOR, + name=None, + shape=x.shape, + dtype=x.dtype) + + def to_variable(self, x): + if isinstance(x, Variable): + return x + elif isinstance(x, np.ndarray): + return self._np_to_variable(x) + else: + raise ValueError("inputs wrong type %s\n" % x) + + def to_variables(self, inputs): + if isinstance(inputs, list) or isinstance(inputs, tuple): + return [self._to_variable(x) for x in inputs] + else: + return [self._to_variable(inputs)] + def append_op(self, *args, **kwargs): return self.main_program.current_block().append_op(*args, **kwargs) def multiple_input(self, input_param_name='input'): inputs = self.kwargs.get(input_param_name, []) - type_error = TypeError( - "Input of {0} layer should be Variable or sequence of Variable". - format(self.layer_type)) - if isinstance(inputs, Variable): - inputs = [inputs] - elif not isinstance(inputs, list) and not isinstance(inputs, tuple): - raise type_error + ret = [] + if isinstance(inputs, list) or isinstance(inputs, tuple): + for inp in inputs: + ret.append(self.to_variable(inp)) else: - for each in inputs: - if not isinstance(each, Variable): - raise type_error - return inputs + ret.append(self.to_variable(inputs)) + return ret def input(self, input_param_name='input'): inputs = self.multiple_input(input_param_name) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3c2975729c568a97bb17cef0876a2bd50f4c5e27..35232bd48981ce249fe662ca8bdfeb5393ffe7b3 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6455,7 +6455,8 @@ def relu(x, name=None): helper = LayerHelper('relu', **locals()) dtype = helper.input_dtype(input_param_name='x') out = helper.create_variable_for_type_inference(dtype) - helper.append_op(type="relu", inputs={"X": x}, outputs={"Out": out}) + helper.append_op( + type="relu", inputs={"X": helper.input('x')}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py index cdd90accc1d6de5554430b2233977f5c920338e3..a10b5b34aa52ad9b73fb71619552ad05d4cdf1a3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative.py +++ b/python/paddle/fluid/tests/unittests/test_imperative.py @@ -14,24 +14,42 @@ import unittest import sys +import numpy as np + import paddle.fluid as fluid from paddle.fluid import core +class MyLayer(fluid.imperative.PyLayer): + def __init__(self): + super(MyLayer, self).__init__() + + def forward(self, inputs): + x = fluid.layers.relu(inputs[0]) + return [fluid.layers.elementwise_mul(x, x)] + + class TestImperative(unittest.TestCase): def test_layer(self): cl = core.Layer() - cl.forward() + cl.forward([]) l = fluid.imperative.PyLayer() - l.forward() + l.forward([]) def test_imperative_trace(self): with fluid.imperative.guard(): self.assertTrue(fluid.imperative.enabled()) - x = fluid.layers.data(name='x', shape=[3, 4], dtype='float32') - x = fluid.layers.relu(x) - x = fluid.layers.elementwise_mul(x, x) - self.assertIsNotNone(x) + x = fluid.layers.data(name='abc', shape=[3, 4], dtype='float32') + for _ in xrange(2): + x = fluid.layers.relu(x) + x = fluid.layers.elementwise_mul(x, x) + self.assertIsNotNone(x) + + def test_layer_in_out(self): + l = MyLayer() + x = l(np.ones([1], np.float32))[0] + self.assertIsNotNone(x) + sys.stderr.write("%s output: %s\n" % (x, x.numpy(scope=l._scope))) if __name__ == '__main__':