Commit ef7e76fc authored by Yu Yang

Merge branch 'develop' into make_network_op

@@ -69,7 +69,7 @@ TEST(OpKernel, all) {
   net->Run(scope, dev_ctx);
   ASSERT_EQ(2, infer_shape_cnt);
   ASSERT_EQ(2, run_cnt);
-  ASSERT_THROW(net->AddOp(op2), std::runtime_error);
+  ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
 }
 
 TEST(AddBackwardOp, TestGradOp) {
   auto net = std::make_shared<PlainNet>();
......
@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) {
   bool caught = false;
   try {
     paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (std::runtime_error& err) {
+  } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "larger_than check fail";
     const char* err_msg = err.what();
@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
   bool caught = false;
   try {
     paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (std::runtime_error& err) {
+  } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "Attribute 'test_attr' is required!";
     const char* err_msg = err.what();
@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) {
   caught = false;
   try {
     paddle::framework::OpRegistry::CreateOp(op_desc);
-  } catch (std::runtime_error& err) {
+  } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "'test_attr' must be even!";
     const char* err_msg = err.what();
@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
   pd::OpProto op_proto;
   pd::OpAttrChecker op_checker;
   auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
+  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
 }
 
 class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
   pd::OpProto op_proto;
   pd::OpAttrChecker op_checker;
   auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
+  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
 }
@@ -56,7 +56,9 @@ class Scope {
     if (var) {
       return var;
     } else {
-      vars_[name] = std::unique_ptr<Variable>(new Variable());
+      auto ptr = new Variable();
+      name_to_var_[name] = std::unique_ptr<Variable>(ptr);
+      var_to_name_[ptr] = name;
       return GetVariable(name);
     }
   }
@@ -68,8 +70,8 @@ class Scope {
    * from it's parent scope. Return nullptr if not found.
    */
   Variable* GetVariable(const std::string& name) const {
-    auto it = vars_.find(name);
-    if (it != vars_.end()) {
+    auto it = name_to_var_.find(name);
+    if (it != name_to_var_.end()) {
       return it->second.get();
     } else if (parent_ != nullptr) {
       return parent_->GetVariable(name);
@@ -84,12 +86,21 @@ class Scope {
    * Find if there is a Variable in this scope and it's parent scope
    */
   bool HasVariable(const std::string& name) const {
-    return (vars_.find(name) != vars_.end() ||
+    return (name_to_var_.find(name) != name_to_var_.end() ||
             (parent_ && parent_->HasVariable(name)));
   }
 
+  std::string GetVariableName(Variable* const var) const {
+    try {
+      return var_to_name_.at(var);
+    } catch (...) {
+      return "";
+    }
+  }
+
  private:
-  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
+  std::unordered_map<Variable*, std::string> var_to_name_;
+  std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
   std::shared_ptr<Scope> parent_{nullptr};
 };
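Aside: the Scope change above pairs the owning name-to-variable map with a raw-pointer index, so pointer-to-name lookup is O(1) instead of a linear scan. A minimal, self-contained sketch of that pattern, where MiniScope and the bare Variable struct are illustrative stand-ins rather than the PaddlePaddle types:

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Illustrative stand-in for paddle::framework::Variable.
struct Variable {};

// name_to_var_ owns the Variable; var_to_name_ indexes the raw pointer
// for reverse lookup, mirroring the two maps added to Scope above.
class MiniScope {
 public:
  Variable* CreateVariable(const std::string& name) {
    auto ptr = new Variable();
    name_to_var_[name] = std::unique_ptr<Variable>(ptr);
    var_to_name_[ptr] = name;
    return ptr;
  }
  std::string GetVariableName(Variable* var) const {
    auto it = var_to_name_.find(var);
    return it == var_to_name_.end() ? "" : it->second;
  }

 private:
  std::unordered_map<Variable*, std::string> var_to_name_;
  std::unordered_map<std::string, std::unique_ptr<Variable>> name_to_var_;
};

int main() {
  MiniScope scope;
  Variable* v = scope.CreateVariable("a");
  assert(scope.GetVariableName(v) == "a");
  Variable unknown;
  assert(scope.GetVariableName(&unknown) == "");  // not owned by this scope
}

The sketch uses find instead of the at-plus-catch(...) form in the diff; both degrade to an empty string for a pointer the scope does not own.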
......
@@ -40,6 +40,11 @@ TEST(Scope, Create) {
   /// already exist.
   Variable* var4 = scope->CreateVariable("a");
   EXPECT_EQ(var4, var2);
+
+  EXPECT_EQ("a", scope->GetVariableName(var4));
+  Scope scope2;
+  auto var = scope2.CreateVariable("tmp");
+  EXPECT_EQ("", scope->GetVariableName(var));
 }
 
 TEST(Scope, Parent) {
......
@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
   bool caught = false;
   try {
     src_tensor.data<double>();
-  } catch (std::runtime_error& err) {
+  } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg =
         "Tenosr holds no memory. Call Tensor::mutable_data first.";
@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) {
     bool caught = false;
     try {
       dst_tensor.ShareDataWith<float>(src_tensor);
-    } catch (std::runtime_error& err) {
+    } catch (paddle::platform::EnforceNotMet err) {
       caught = true;
       std::string msg =
           "Tenosr holds no memory. Call Tensor::mutable_data first.";
......
@@ -36,6 +36,21 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
+struct EnforceNotMet : public std::exception {
+  std::exception_ptr exp_;
+  std::string err_str_;
+  EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
+    try {
+      std::rethrow_exception(exp_);
+    } catch (const std::exception& exp) {
+      err_str_ = string::Sprintf("%s at [%s:%d]", exp.what(), f, l);
+    }
+  }
+  const char* what() const noexcept { return err_str_.c_str(); }
+};
+
 // Because most enforce conditions would evaluate to true, we can use
 // __builtin_expect to instruct the C++ compiler to generate code that
 // always forces branch prediction of true.
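Aside: the new exception type captures the in-flight exception as a std::exception_ptr, rethrows it once to harvest its what(), and stores the message with the throw site appended. A hedged, self-contained sketch of the same mechanism, where NotMet and the snprintf formatting are illustrative stand-ins for the paddle types and string::Sprintf:

#include <cstdio>
#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>

// Illustrative analogue of EnforceNotMet: wrap any exception and stamp
// the file/line of the wrap site onto its message.
struct NotMet : public std::exception {
  std::exception_ptr exp_;
  std::string err_str_;
  NotMet(std::exception_ptr e, const char* f, int l) : exp_(e) {
    try {
      std::rethrow_exception(exp_);  // run the original what() once
    } catch (const std::exception& exp) {
      char buf[512];
      std::snprintf(buf, sizeof(buf), "%s at [%s:%d]", exp.what(), f, l);
      err_str_ = buf;
    }
  }
  const char* what() const noexcept override { return err_str_.c_str(); }
};

int main() {
  try {
    throw NotMet(std::make_exception_ptr(std::runtime_error("check failed")),
                 __FILE__, __LINE__);
  } catch (const NotMet& e) {
    std::cout << e.what() << "\n";  // "check failed at [<file>:<line>]"
  }
}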
@@ -52,9 +67,7 @@ template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     int stat, const Args&... args) {
   if (UNLIKELY(!(stat))) {
-    throw std::runtime_error(
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+    throw std::runtime_error(string::Sprintf(args...));
   }
 }
@@ -64,12 +77,8 @@ template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     cudaError_t e, const Args&... args) {
   if (UNLIKELY(e)) {
-    // clang-format off
-    throw thrust::system_error(
-        e, thrust::cuda_category(),
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
-    // clang-format on
+    throw thrust::system_error(e, thrust::cuda_category(),
+                               string::Sprintf(args...));
   }
 }
@@ -77,12 +86,8 @@ template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     curandStatus_t stat, const Args&... args) {
   if (stat != CURAND_STATUS_SUCCESS) {
-    // clang-format off
-    throw thrust::system_error(
-        cudaErrorLaunchFailure, thrust::cuda_category(),
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
-    // clang-format on
+    throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
+                               string::Sprintf(args...));
   }
 }
@@ -92,12 +97,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
   if (stat == CUDNN_STATUS_SUCCESS) {
     return;
   } else {
-    // clang-format off
-    throw std::runtime_error(
-        platform::dynload::cudnnGetErrorString(stat) +
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
-    // clang-format on
+    throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
+                             string::Sprintf(args...));
   }
 }
@@ -126,22 +127,27 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
   } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) {
     err = "CUBLAS: license error, ";
   }
-  throw std::runtime_error(err + string::Sprintf(args...) +
-                           string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+  throw std::runtime_error(err + string::Sprintf(args...));
 }
 
 #endif  // PADDLE_ONLY_CPU
 
-#define PADDLE_THROW(...)                                     \
-  do {                                                        \
-    throw std::runtime_error(                                 \
-        string::Sprintf(__VA_ARGS__) +                        \
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
-  } while (0)
+#define PADDLE_THROW(...)                                      \
+  do {                                                         \
+    throw ::paddle::platform::EnforceNotMet(                   \
+        std::make_exception_ptr(                               \
+            std::runtime_error(string::Sprintf(__VA_ARGS__))), \
+        __FILE__, __LINE__);                                   \
+  } while (0)
 
-#define PADDLE_ENFORCE(...)                          \
-  do {                                               \
-    ::paddle::platform::throw_on_error(__VA_ARGS__); \
-  } while (0)
+#define PADDLE_ENFORCE(...)                                             \
+  do {                                                                  \
+    try {                                                               \
+      ::paddle::platform::throw_on_error(__VA_ARGS__);                  \
+    } catch (...) {                                                     \
+      throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
+                                              __FILE__, __LINE__);      \
+    }                                                                   \
+  } while (0)
 
 }  // namespace platform
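Aside: the point of moving the __FILE__/__LINE__ stamping out of throw_on_error and into the macros is that the recorded location becomes the enforcement site in the caller, not a fixed line inside enforce.h. A hedged sketch of that wrap-and-rethrow shape, where ENFORCE and SiteError are illustrative names and the catch clause is simplified from the std::current_exception form used above:

#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>

// Illustrative analogue of EnforceNotMet, reduced to a tagged message.
struct SiteError : std::runtime_error {
  SiteError(const std::string& what, const char* f, int l)
      : std::runtime_error(what + " at [" + f + ":" + std::to_string(l) + "]") {}
};

// The macro catches whatever the check threw and rethrows it stamped with
// the *caller's* file and line, mirroring PADDLE_ENFORCE above.
#define ENFORCE(cond, msg)                           \
  do {                                               \
    try {                                            \
      if (!(cond)) throw std::runtime_error(msg);    \
    } catch (const std::exception& e) {              \
      throw SiteError(e.what(), __FILE__, __LINE__); \
    }                                                \
  } while (0)

int main() {
  try {
    ENFORCE(1 + 1 == 3, "arithmetic is broken");
  } catch (const SiteError& e) {
    std::cout << e.what() << "\n";  // "arithmetic is broken at [<file>:<line>]"
  }
}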
......
@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) {
   bool in_catch = false;
   try {
     PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
-  } catch (const std::runtime_error& error) {
+  } catch (paddle::platform::EnforceNotMet error) {
     // your error handling code here
     in_catch = true;
     std::string msg = "Enforce is not ok 123 at all";
......
@@ -48,6 +48,11 @@ void ExposeOperator(ClassType& m) {
       .def("__str__", &ClassType::type::DebugString);
 }
 
+static size_t UniqueIntegerGenerator() {
+  static std::atomic<size_t> generator;
+  return generator.fetch_add(1);
+}
+
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of PaddlePaddle");
 
@@ -98,7 +103,8 @@ All parameter, weight, gradient are variables in Paddle.
            py::return_value_policy::reference)
       .def("create_var",
            &pd::Scope::CreateVariable,
-           py::return_value_policy::reference);
+           py::return_value_policy::reference)
+      .def("get_var_name", &pd::Scope::GetVariableName);
 
   //! @note: Be careful! PyBind will return std::string as an unicode, not
   //! Python str. If you want a str object, you should cast them in Python.
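Aside: UniqueIntegerGenerator is a function-local std::atomic<size_t>, so fetch_add (an atomic read-modify-write) guarantees every caller a distinct value even under concurrency. A small self-contained check of that property; NextId and the thread layout are illustrative:

#include <atomic>
#include <cassert>
#include <cstddef>
#include <set>
#include <thread>
#include <vector>

// Same pattern as UniqueIntegerGenerator: a function-local atomic counter.
static std::size_t NextId() {
  static std::atomic<std::size_t> counter{0};
  return counter.fetch_add(1);
}

int main() {
  std::vector<std::size_t> ids(1000);
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&, t] {
      // Each thread fills disjoint slots, so only NextId() is contended.
      for (int i = t; i < 1000; i += 4) ids[i] = NextId();
    });
  }
  for (auto& w : workers) w.join();
  // All 1000 ids are distinct despite the four racing threads.
  assert(std::set<std::size_t>(ids.begin(), ids.end()).size() == 1000);
}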
@@ -141,23 +147,24 @@ All parameter, weight, gradient are variables in Paddle.
   ExposeOperator(operator_base);
 
   using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
-  py::class_<pd::PlainNet, PlainNetPtr> plain_net(m, "PlainNet");
-
-  plain_net
-      .def_static("create",
-                  []() -> std::shared_ptr<pd::PlainNet> {
-                    auto retv = std::make_shared<pd::PlainNet>();
-                    retv->type_ = "plain_net";
-                    return retv;
-                  })
+  py::class_<pd::PlainNet, PlainNetPtr> net(m, "Net");
+
+  net.def_static("create",
+                 []() -> std::shared_ptr<pd::PlainNet> {
+                   auto retv = std::make_shared<pd::PlainNet>();
+                   retv->type_ = "plain_net";
+                   return retv;
+                 })
       .def("add_op", &pd::PlainNet::AddOp)
       .def("add_op",
-           [](PlainNetPtr& self, const PlainNetPtr& plain_net) -> void {
-             self->AddOp(std::static_pointer_cast<pd::OperatorBase>(plain_net));
+           [](PlainNetPtr& self, const PlainNetPtr& net) -> void {
+             self->AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
            })
       .def("complete_add_op", &pd::PlainNet::CompleteAddOp)
       .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
-  ExposeOperator(plain_net);
+  ExposeOperator(net);
+
+  m.def("unique_integer", UniqueIntegerGenerator);
 
   return m.ptr();
 }
@@ -2055,8 +2055,7 @@ class BatchNormLayer(LayerBase):
         # Automatically select cudnn_batch_norm for GPU and batch_norm for CPU.
         # Also based on cudnn version.
         use_cudnn = use_gpu and batch_norm_type != "batch_norm" and \
-            ((not parallel_nn) or self.config.device > -1) and \
-            cudnn_version >= 4007
+            ((not parallel_nn) or self.config.device > -1)
         self.layer_type = "cudnn_batch_norm" if use_cudnn else "batch_norm"
         super(BatchNormLayer, self).__init__(
             name, self.layer_type, 0, inputs=inputs, **xargs)
......
@@ -34,6 +34,7 @@ import minibatch
 import plot
 import image
 import model
+import paddle.trainer.config_parser as cp
 
 __all__ = [
     'optimizer',
@@ -58,6 +59,8 @@ __all__ = [
     'model',
 ]
 
+cp.begin_parse()
+
 
 def init(**kwargs):
     import py_paddle.swig_paddle as api
@@ -73,6 +76,11 @@ def init(**kwargs):
     for key in args_dict.keys():
         args.append('--%s=%s' % (key, str(args_dict[key])))
 
+    if 'use_gpu' in kwargs:
+        cp.g_command_config_args['use_gpu'] = kwargs['use_gpu']
+    assert 'parallel_nn' not in kwargs, ("currently 'parallel_nn' is not "
+                                         "supported in v2 APIs.")
+
     api.initPaddle(*args)
......
@@ -242,9 +242,9 @@ def gen_list(querylist):
     if not isinstance(querylist, QueryList):
         querylist = QueryList(querylist)
     querylist._correct_ranking_()
-    relevance_score_list = [query.relevance_score for query in querylist]
+    relevance_score_list = [[query.relevance_score] for query in querylist]
     feature_vector_list = [query.feature_vector for query in querylist]
-    yield np.array(relevance_score_list).T, np.array(feature_vector_list)
+    yield np.array(relevance_score_list), np.array(feature_vector_list)
 
 def query_filter(querylists):
......
@@ -220,6 +220,9 @@ def create_op_creation_method(op_proto):
     __impl__.all_input_args = [var.name for var in op_proto.inputs]
     __impl__.all_output_args = [var.name for var in op_proto.outputs]
     __impl__.all_attr_args = [attr.name for attr in op_proto.attrs]
+    __impl__.all_not_temp_output_args = [
+        var.name for var in op_proto.outputs if not var.temporary
+    ]
 
     return __impl__
......
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import create_var, get_var, get_cur_scope

__all__ = ['Network']  # Only expose Network


class NetworkFunctor(object):
    """
    Network Op Creation Function. Used internally in this module.

    It converts string inputs to Variables; if a variable has not been
    created before, it is created in the current scope.

    It is a functor object, i.e., its instances are callable.

    :param func: The op creation function generated in Python.
    :param net: The Network instance.
    """

    def __init__(self, func, net):
        self.func = func
        self.net = net

    def __call__(self, *args, **kwargs):
        if len(args) != 0:
            raise ValueError("Paddle must use keyword argument")
        inputs = self.func.all_input_args
        for ipt in inputs:
            if ipt in kwargs:
                var = kwargs[ipt]
                if isinstance(var, basestring):
                    var = create_var(var)
                if not isinstance(var, core.Variable):
                    raise TypeError(
                        "Input of op creation must be string or variable")
                kwargs[ipt] = get_cur_scope().get_var_name(var)

        notemp_outputs = self.func.all_not_temp_output_args
        for name in notemp_outputs:
            if name not in kwargs:
                kwargs[name] = self.func.__name__ + "@OUT@%d" % core.unique_integer()

        outputs = self.func.all_output_args
        for opt in outputs:
            if opt in kwargs:
                var = kwargs[opt]
                if isinstance(var, basestring):
                    var = create_var(var)
                if not isinstance(var, core.Variable):
                    raise TypeError(
                        "Output of op creation must be string or variable")
                kwargs[opt] = get_cur_scope().get_var_name(var)

        op = self.func(**kwargs)
        self.net.net.add_op(op)

        lst = [get_var(kwargs[opt]) for opt in notemp_outputs]
        if len(lst) == 1:
            return lst[0]
        elif len(lst) == 0:
            return None
        else:
            return lst


class Network(object):
    """
    The network concept. It saves the user from manually creating operators
    and variables and combining them into a Net. `Network.xxx` creates the
    operator, creates its variables in the default scope, and adds the
    operator into `self.net`.

    For example:

    ..  code-block: python

        net = Network()
        out = net.add_two(X="a", Y="b")
        fc_out = net.fc(X="out", W="fc.w")

        net.run(...)
    """

    def __init__(self):
        self.net = core.Net.create()
        funcs = (func_name for func_name in dir(op_creations)
                 if not func_name.startswith("__"))

        # TODO(yuyang18): This code works, but it does not generate a good
        # docstring; try to find a better way to generate functions at
        # runtime later.
        for func_name in funcs:
            func = getattr(op_creations, func_name)
            impl = NetworkFunctor(func, self)
            setattr(self, func_name, impl.__call__)
        self.__complete_add_op__ = False

    def infer_shape(self):
        self.complete_add_op()
        self.net.infer_shape(get_cur_scope())

    def run(self, device_context):
        self.complete_add_op()
        self.net.run(get_cur_scope(), device_context)

    def __str__(self):
        return str(self.net)

    def complete_add_op(self):
        if not self.__complete_add_op__:
            self.net.complete_add_op()
            self.__complete_add_op__ = True


if __name__ == '__main__':
    net = Network()
    out = net.add_two(X="a", Y="b")
    fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax")
    net.complete_add_op()
    print net
@@ -3,7 +3,7 @@ add_python_test(test_framework
     test_scope.py
     test_default_scope_funcs.py
     test_op_creation_methods.py
-    test_plain_net.py
+    test_net.py
     test_tensor.py
     test_fc_op.py
     test_add_two_op.py
@@ -12,4 +12,5 @@ add_python_test(test_framework
     test_mul_op.py
     test_sigmoid_op.py
     test_softmax_op.py
-    test_rowwise_add_op.py)
+    test_rowwise_add_op.py
+    test_network.py)
@@ -5,11 +5,11 @@ import unittest
 
 class TestNet(unittest.TestCase):
     def test_net_all(self):
-        net = core.PlainNet.create()
+        net = core.Net.create()
         op1 = op_creations.add_two(X="X", Y="Y", Out="Out")
         net.add_op(op1)
 
-        net2 = core.PlainNet.create()
+        net2 = core.Net.create()
         net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
         net2.complete_add_op(True)
         net.add_op(net2)
......
from paddle.v2.framework.network import Network
import paddle.v2.framework.core as core
import unittest


class TestNet(unittest.TestCase):
    def test_net_all(self):
        net = Network()
        out = net.add_two(X="X", Y="Y")
        fc_out = net.fc(X=out, W="w")
        net.complete_add_op()
        self.assertTrue(isinstance(fc_out, core.Variable))
        self.assertEqual(
            '''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
''', str(net))

        net2 = Network()
        tmp = net2.add_two(X="X", Y="Y")
        self.assertTrue(isinstance(tmp, core.Variable))
        net2.complete_add_op()
        self.assertEqual(
            '''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
''', str(net2))


if __name__ == '__main__':
    unittest.main()
@@ -324,6 +324,3 @@ def parse_network(output_layers, extra_layers=None):
 
 def get_layer(name):
     return config_base.__layer_map__.get(name)
-
-
-cp.begin_parse()