Commit 3db52783 authored by Yu Yang, committed by GitHub

Feature/py executor test (#4922)

* Implement FC layer with helper

* Update LayerHelper

* Add debug string for Python ProtoBuf

and Rename `Sync` to `Flush`

* Add check of ProtoBuf initialization

* Layer wrapper for FC

* Fix unittest

* Fix CI

* Add code generator

* AttributeChecker: better error log and specialize bool

Since lots of types can be cast to bool
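A minimal standalone sketch of the issue (hypothetical names, not Paddle's actual AttributeChecker):

#include <iostream>

// Generic checker: handles an attribute of exactly the deduced type T.
template <typename T>
void CheckAttribute(const T& value) {
  std::cout << "generic attribute check\n";
}

// Dedicated bool path: ints, floats, and pointers all convert to bool
// implicitly, so bool gets its own specialization with a bool-specific
// error message instead of accepting anything that merely casts to bool.
template <>
void CheckAttribute<bool>(const bool& value) {
  std::cout << "bool attribute check: " << std::boolalpha << value << "\n";
}

int main() {
  CheckAttribute(true);  // resolves to the bool specialization
  CheckAttribute(42);    // deduced as int, stays on the generic path
  return 0;
}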

* Complete mlp, fit_a_line

* Expose get global scope

* Make global scope not thread-safe

1. There is no need to make the global scope thread-safe, since it will
only be invoked from the Python main thread.
2. Do not free the global scope when C++ exits. Let the OS free the memory;
otherwise, we would need to handle the destruction-order dependencies.

See
https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
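A minimal sketch of the resulting pattern (the empty Scope stand-in keeps it self-contained; this is illustrative, not the exact Paddle source):

// Stand-in for paddle::framework::Scope, only so the sketch compiles alone.
class Scope {};

// Deliberately not thread-safe: only the Python main thread calls this,
// so no locking is required.
// Deliberately leaked: freeing the scope during C++ shutdown would force
// us to order its destruction against everything it refers to, so the OS
// reclaims the memory at process exit instead.
Scope& GetGlobalScope() {
  static Scope* g_scope = nullptr;  // lazily initialized plain pointer
  if (g_scope == nullptr) {
    g_scope = new Scope();  // intentionally never deleted
  }
  return *g_scope;
}

int main() {
  Scope& s1 = GetGlobalScope();
  Scope& s2 = GetGlobalScope();
  return (&s1 == &s2) ? 0 : 1;  // always the same instance
}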

* Fix

* Implementation of simple conv_2d layer

* Stash

* Remove private data members in OpRegister

* Fix bugs

* Stash

* Expose FeedFetchList as VarType

* Change ProgramDesc to not be a global variable

* Polish code style

* Stash

* Correctly implement BlockDesc destructor

* Correctly implement BlockDesc destructor

* Unify `program` as the parameter name

* Fix bugs

* Add unittest

* Fix unit test error

* Remove unused functions

* Add clone for Python Program

* Working on executor

* Stash

* Add glog as dependencies of ops

* Use VLOG to log information that is helpful when debugging Paddle

* Expose VarDesc::persistable to Python

* Test executor

* Complete unittest

* Polish code

* Fix merge error

* Follow comment

* Polish Python Code
Parent 63ffe525
@@ -43,7 +43,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
 cc_library(backward SRCS backward.cc DEPS net_op)
 cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
-cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward)
+cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
@@ -68,9 +68,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
   for (auto& var : block.vars()) {
     if (var.persistable()) {
-      scope->Var(var.name());
+      auto* ptr = scope->Var(var.name());
+      VLOG(3) << "Create Variable " << var.name()
+              << " global, which pointer is " << ptr;
     } else {
-      local_scope.Var(var.name());
+      auto* ptr = local_scope.Var(var.name());
+      VLOG(3) << "Create Variable " << var.name()
+              << " locally, which pointer is " << ptr;
     }
   }
......
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include "glog/logging.h"
+#include "paddle/framework/feed_fetch_type.h"
 #include "paddle/framework/scope.h"
 #include "paddle/framework/variable.h"
@@ -24,6 +26,7 @@ void SetFeedVariable(const LoDTensor& input, const std::string& var_name,
                      size_t index) {
   // If var_name Variable is not found in GlobalScope, a new variable will
   // be created.
+  VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
   Variable* g_feed_value = GetGlobalScope().Var(var_name);
   auto& feed_inputs =
       *(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
@@ -40,10 +43,15 @@ LoDTensor& GetFetchVariable(const std::string& var_name, size_t index) {
   // Since we want to fetch a LoDTensor from a variable, the variable must
   // be created already.
   Variable* g_fetch_value = GetGlobalScope().FindVar(var_name);
-  auto& fetch_outputs =
-      *(g_fetch_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
+  PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(),
+                 "Only %s can be invoked by GetFetchVariable",
+                 typeid(FeedFetchList).name());
+  auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>();
+  auto& tensor = fetch_outputs[index];
+  VLOG(3) << "Fetch " << var_name << " with index " << index
+          << " shape= " << tensor.dims();
   PADDLE_ENFORCE_LT(index, fetch_outputs.size());
-  return fetch_outputs[index];
+  return tensor;
 }
} // namespace framework
......
@@ -112,6 +112,8 @@ message VarDesc {
   enum VarType {
     LOD_TENSOR = 1;
     SELECTED_ROWS = 2;
+    FEED_MINIBATCH = 3;
+    FETCH_LIST = 4;
   }
   required string name = 1;
   required VarType type = 2;
......
@@ -80,4 +80,4 @@ TEST(ProgramDesc, copy_ctor) {
   // different and it is correct.
 }
 }  // namespace framework
-}  // namespace paddle
\ No newline at end of file
+}  // namespace paddle
@@ -25,7 +25,10 @@ class Variable {
  public:
   template <typename T>
   const T& Get() const {
-    PADDLE_ENFORCE(IsType<T>(), "Variable must be type %s", typeid(T).name());
+    PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold something");
+    PADDLE_ENFORCE(IsType<T>(),
+                   "Variable must be type %s, the holding type is %s",
+                   typeid(T).name(), holder_->Type().name());
     return *static_cast<const T*>(holder_->Ptr());
   }
......
@@ -26,8 +26,9 @@ class FeedOp : public framework::OperatorBase {
       : OperatorBase(type, inputs, outputs, attrs) {}
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
-    auto feed_var_name = Input("Input");
+    auto feed_var_name = Input("X");
     auto *feed_var = scope.FindVar(feed_var_name);
     PADDLE_ENFORCE(feed_var != nullptr,
                    "Cannot find feed_var in scope, feed_var_name is %s",
                    feed_var_name);
@@ -40,6 +41,9 @@ class FeedOp : public framework::OperatorBase {
     auto col = Attr<int>("col");
     VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var "
             << out_name;
+    auto &feed_list = feed_var->Get<framework::FeedFetchList>();
+    auto &feed_item = feed_list.at(static_cast<size_t>(col));
+    auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
@@ -48,10 +52,21 @@ class FeedOp : public framework::OperatorBase {
   }
 };
+class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FeedOpInfoMaker(framework::OpProto *proto,
+                  framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of feed op");
+    AddOutput("Out", "The output of feed op");
+    AddComment("feed op, it should not be configured by users directly");
+    AddAttr<int>("col", "column of feed");
+  }
+};
 }  // namespace operators
 }  // namespace paddle
 // We do not need to register OpInfoMaker,
 // since feed operator will not be used by end users directly
 REGISTER_OPERATOR(feed, paddle::operators::FeedOp,
-                  paddle::framework::EmptyGradOpMaker);
+                  paddle::framework::EmptyGradOpMaker,
+                  paddle::operators::FeedOpInfoMaker);
@@ -27,7 +27,7 @@ class FetchOp : public framework::OperatorBase {
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
-    auto fetch_var_name = Input("Input");
+    auto fetch_var_name = Input("X");
     auto *fetch_var = scope.FindVar(fetch_var_name);
     PADDLE_ENFORCE(fetch_var != nullptr,
                    "Cannot find fetch variable in scope, fetch_var_name is %s",
@@ -52,13 +52,25 @@ class FetchOp : public framework::OperatorBase {
     // FIXME(yuyang18): Should we assume the fetch operator always generates
     // CPU outputs?
     dst_item.CopyFromTensor(src_item, platform::CPUPlace(), dev_ctx);
+    VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name;
   }
 };
+class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  FetchOpInfoMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of fetch op");
+    AddOutput("Out", "The output of fetch op");
+    AddComment("fetch op, it should not be configured by users directly");
+    AddAttr<int>("col", "column of fetch");
+  }
+};
 }  // namespace operators
 }  // namespace paddle
 // We do not need to register OpInfoMaker,
 // since fetch operator will not be used by end users directly
 REGISTER_OPERATOR(fetch, paddle::operators::FetchOp,
-                  paddle::framework::EmptyGradOpMaker);
+                  paddle::framework::EmptyGradOpMaker,
+                  paddle::operators::FetchOpInfoMaker);
@@ -222,7 +222,9 @@ void BindVarDsec(py::module &m) {
   py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
       .value("LOD_TENSOR", VarDesc::LOD_TENSOR)
-      .value("SELECTED_ROWS", VarDesc::SELECTED_ROWS);
+      .value("SELECTED_ROWS", VarDesc::SELECTED_ROWS)
+      .value("FEED_MINIBATCH", VarDesc::FEED_MINIBATCH)
+      .value("FETCH_LIST", VarDesc::FETCH_LIST);
 }
void BindOpDesc(py::module &m) {
......
@@ -111,6 +111,7 @@ PYBIND11_PLUGIN(core) {
             new (&instance) LoDTensor(new_lod);
 #endif
           })
+      .def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
       .def("set_lod",
           [](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
 #ifndef PADDLE_WITH_CUDA
@@ -216,7 +217,8 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<>())
       .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
            py::return_value_policy::reference)
-      .def("drop_kids", &Scope::DropKids);
+      .def("drop_kids", &Scope::DropKids)
+      .def_static("global_scope", &GetGlobalScope);
   //! @note: Be careful! PyBind will return std::string as a unicode, not
   //! Python str. If you want a str object, you should cast them in Python.
@@ -264,6 +266,17 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<>())
       .def("__str__", string::to_string<const platform::CPUPlace &>);
+  py::class_<platform::Place>(m, "Place")
+      .def(py::init<>())
+      .def("set_place",
+           [](platform::Place &self, const platform::CPUPlace &cpu_place) {
+             self = cpu_place;
+           })
+      .def("set_place",
+           [](platform::Place &self, const platform::GPUPlace &gpu_place) {
+             self = gpu_place;
+           });
   py::class_<OperatorBase>(m, "Operator")
       .def_static("create",
                   [](py::bytes protobin) {
@@ -437,14 +450,15 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<framework::Executor>(m, "Executor")
       .def(py::init<std::vector<platform::Place> &>())
       .def("run",
-           [](Executor &self, const ProgramDesc &program_desc, int block_id) {
+           [](Executor &self, ProgramDescBind *program_bind, int block_id) {
             framework::Scope &global_scope = GetGlobalScope();
-            self.Run(program_desc, &global_scope, block_id);
+            self.Run(*program_bind->Proto(), &global_scope, block_id);
           });
   m.def("unique_integer", UniqueIntegerGenerator);
   m.def("is_compile_gpu", IsCompileGPU);
+  //! FIXME: there is no need for `set_xxx_float/double/int` variants
   m.def("set_feed_variable_float", framework::SetFeedVariable<float>);
   m.def("set_feed_variable_double", framework::SetFeedVariable<double>);
   m.def("set_feed_variable_int", framework::SetFeedVariable<int>);
......
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program


class Executor(object):
    def __init__(self, places):
        if not isinstance(places, list) and not isinstance(places, tuple):
            places = [places]

        act_places = []
        for each in places:
            p = core.Place()
            p.set_place(each)
            act_places.append(p)

        self.executor = core.Executor(act_places)

    def run(self,
            program,
            feed,
            fetch_list,
            feed_var_name='feed',
            fetch_var_name='fetch'):
        if not isinstance(program, Program):
            raise TypeError()

        program = program.clone()
        global_block = program.global_block()
        feed_var = global_block.create_var(
            name=feed_var_name,
            type=core.VarDesc.VarType.FEED_MINIBATCH,
            persistable=True)

        for i, name in enumerate(feed):
            out = global_block.var(name)
            global_block.prepend_op(
                'feed',
                inputs={'X': [feed_var]},
                outputs={'Out': [out]},
                attrs={'col': i})
            # FIXME
            core.set_feed_variable_float(feed[name], feed_var.name, i)

        fetch_var = global_block.create_var(
            name=fetch_var_name,
            type=core.VarDesc.VarType.FETCH_LIST,
            persistable=True)
        for i, var in enumerate(fetch_list):
            global_block.append_op(
                type='fetch',
                inputs={'X': [var]},
                outputs={'Out': [fetch_var]},
                attrs={'col': i})

        self.executor.run(program.desc, 0)
        return [
            core.get_fetch_variable(fetch_var_name, i)
            for i in xrange(len(fetch_list))
        ]
@@ -256,7 +256,8 @@ class Operator(object):
                 self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
 
         self.desc.check_attrs()
-        self.desc.infer_shape(self.block.desc)
+        if type not in {'feed', 'fetch'}:
+            self.desc.infer_shape(self.block.desc)
 
     def __str__(self):
         protostr = self.desc.serialize_to_string()
@@ -323,9 +324,12 @@ class Block(object):
         return self.desc.id
 
     def var(self, name):
-        if name not in self.vars:
+        if not isinstance(name, basestring):
+            raise TypeError()
+        v = self.vars.get(name, None)
+        if v is None:
             raise ValueError("var %s not in this block" % name)
-        return self.vars[name]
+        return v
 
     def all_parameters(self):
         return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)}
......
@@ -55,9 +55,11 @@ def data(name,
          shape,
          data_type='float32',
          type=core.VarDesc.VarType.LOD_TENSOR,
+         append_batch_size=True,
          program=None):
     helper = LayerHelper('data', **locals())
-    shape = [-1] + shape  # append batch size as -1
+    if append_batch_size:
+        shape = [-1] + shape  # append batch size as -1
     return helper.create_global_variable(
         name=name, shape=shape, dtype=data_type, type=type)
@@ -112,6 +114,7 @@ def _create_op_func_(op_type):
 _create_op_func_('mean')
 _create_op_func_('mul')
+_create_op_func_('pool2d')
......
import unittest
from paddle.v2.framework.layers import mul, data
import paddle.v2.framework.core as core
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import g_program
import numpy


class TestExecutor(unittest.TestCase):
    def test_mul(self):
        a = data(name='a', shape=[784], data_type='float32')
        b = data(
            name='b',
            shape=[784, 100],
            data_type='float32',
            append_batch_size=False)
        out = mul(x=a, y=b)
        place = core.CPUPlace()
        a_np = numpy.random.random((100, 784)).astype('float32')
        tensor_a = core.LoDTensor()
        tensor_a.set(a_np, place)
        b_np = numpy.random.random((784, 100)).astype('float32')
        tensor_b = core.LoDTensor()
        tensor_b.set(b_np, place)
        exe = Executor(place)
        outs = exe.run(g_program,
                       feed={'a': tensor_a,
                             'b': tensor_b},
                       fetch_list=[out])
        out = numpy.array(outs[0])
        self.assertEqual((100, 100), out.shape)
        self.assertTrue(numpy.allclose(out, numpy.dot(a_np, b_np)))


if __name__ == '__main__':
    unittest.main()