提交 2f82d72e 编写于 作者: Y Yu Yang 提交者: emailweixu

Fix bug in yield dictionary in DataProvider. (#197)

* Fix bug in yield dictionary in DataProvider.
* Also make virtualenv work in Paddle.
上级 e4952ca6
...@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8) ...@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8)
project(paddle CXX C) project(paddle CXX C)
set(PADDLE_MAJOR_VERSION 0) set(PADDLE_MAJOR_VERSION 0)
set(PADDLE_MINOR_VERSION 8) set(PADDLE_MINOR_VERSION 8)
set(PADDLE_PATCH_VERSION 0b1) set(PADDLE_PATCH_VERSION 0b2)
set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION}) set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION})
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
......
...@@ -184,3 +184,20 @@ macro(add_paddle_culib TARGET_NAME) ...@@ -184,3 +184,20 @@ macro(add_paddle_culib TARGET_NAME)
cuda_add_library(${TARGET_NAME} STATIC ${ARGN}) cuda_add_library(${TARGET_NAME} STATIC ${ARGN})
set(CUDA_NVCC_FLAGS ${NVCC_FLAG}) set(CUDA_NVCC_FLAGS ${NVCC_FLAG})
endmacro() endmacro()
# Creates C resources file from files in given resource file
function(create_resources res_file output)
# Create empty output file
file(WRITE ${output} "")
# Get short filename
string(REGEX MATCH "([^/]+)$" filename ${res_file})
# Replace filename spaces & extension separator for C compatibility
string(REGEX REPLACE "\\.| |-" "_" filename ${filename})
# Read hex data from file
file(READ ${res_file} filedata HEX)
# Convert hex data for C compatibility
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," filedata ${filedata})
# Append data to output file
file(APPEND ${output} "const unsigned char ${filename}[] = {${filedata}};\nconst unsigned ${filename}_size = sizeof(${filename});\n")
endfunction()
文件模式从 100644 更改为 100755
...@@ -2,10 +2,10 @@ from paddle.trainer.PyDataProvider2 import * ...@@ -2,10 +2,10 @@ from paddle.trainer.PyDataProvider2 import *
# Define a py data provider # Define a py data provider
@provider(input_types=[ @provider(input_types={
dense_vector(28 * 28), 'pixel': dense_vector(28 * 28),
integer_value(10) 'label': integer_value(10)
]) })
def process(settings, filename): # settings is not used currently. def process(settings, filename): # settings is not used currently.
imgf = filename + "-images-idx3-ubyte" imgf = filename + "-images-idx3-ubyte"
labelf = filename + "-labels-idx1-ubyte" labelf = filename + "-labels-idx1-ubyte"
...@@ -14,20 +14,19 @@ def process(settings, filename): # settings is not used currently. ...@@ -14,20 +14,19 @@ def process(settings, filename): # settings is not used currently.
f.read(16) f.read(16)
l.read(8) l.read(8)
# Define number of samples for train/test # Define number of samples for train/test
if "train" in filename: if "train" in filename:
n = 60000 n = 60000
else: else:
n = 10000 n = 10000
for i in range(n): for i in range(n):
label = ord(l.read(1)) label = ord(l.read(1))
pixels = [] pixels = []
for j in range(28*28): for j in range(28 * 28):
pixels.append(float(ord(f.read(1))) / 255.0) pixels.append(float(ord(f.read(1))) / 255.0)
yield { "pixel": pixels, 'label': label } yield {"pixel": pixels, 'label': label}
f.close() f.close()
l.close() l.close()
\ No newline at end of file
...@@ -47,6 +47,7 @@ predict = small_vgg(input_image=img, ...@@ -47,6 +47,7 @@ predict = small_vgg(input_image=img,
if not is_predict: if not is_predict:
lbl = data_layer(name="label", size=label_size) lbl = data_layer(name="label", size=label_size)
inputs(img, lbl)
outputs(classification_cost(input=predict, label=lbl)) outputs(classification_cost(input=predict, label=lbl))
else: else:
outputs(predict) outputs(predict)
...@@ -2,10 +2,10 @@ from paddle.trainer.PyDataProvider2 import * ...@@ -2,10 +2,10 @@ from paddle.trainer.PyDataProvider2 import *
# Define a py data provider # Define a py data provider
@provider(input_types=[ @provider(input_types={
dense_vector(28 * 28), 'pixel': dense_vector(28 * 28),
integer_value(10) 'label': integer_value(10)
]) })
def process(settings, filename): # settings is not used currently. def process(settings, filename): # settings is not used currently.
f = open(filename, 'r') # open one of training file f = open(filename, 'r') # open one of training file
...@@ -20,6 +20,6 @@ def process(settings, filename): # settings is not used currently. ...@@ -20,6 +20,6 @@ def process(settings, filename): # settings is not used currently.
pixels_float.append(float(each_pixel_str)) pixels_float.append(float(each_pixel_str))
# give data to paddle. # give data to paddle.
yield { "pixel": pixels_float, 'label': int(label) } yield {"pixel": pixels_float, 'label': int(label)}
f.close() # close file f.close() # close file
...@@ -141,8 +141,6 @@ DataProvider创建的时候执行。这个初始化函数具有如下参数: ...@@ -141,8 +141,6 @@ DataProvider创建的时候执行。这个初始化函数具有如下参数:
是一个batch size,但是有时为了计算均衡性,可以将一条数据设置成多个batch size 是一个batch size,但是有时为了计算均衡性,可以将一条数据设置成多个batch size
* cache 是数据缓存的策略,参考 `cache`_ * cache 是数据缓存的策略,参考 `cache`_
* init_hook 是初始化时调用的函数,参考 `init_hook`_ * init_hook 是初始化时调用的函数,参考 `init_hook`_
* use_dynamic_order 如果是true的话,可以返回一个dict,key是data_layer的名字,value是特征值。同时,也可以
返回一个list或者tuple。如果是false的话,只能够返回list或者tuple
* check 设置成true的话,会根据input_types检查数据的合法性。 * check 设置成true的话,会根据input_types检查数据的合法性。
* check_fail_continue 如果设置成true的话,即使在check中数据不合法,也会扔到这条数据,继续训练。 如果 * check_fail_continue 如果设置成true的话,即使在check中数据不合法,也会扔到这条数据,继续训练。 如果
check是false的话,没有作用。 check是false的话,没有作用。
......
...@@ -246,8 +246,7 @@ private: ...@@ -246,8 +246,7 @@ private:
PyObjectPtr && kwargs) { PyObjectPtr && kwargs) {
LOG(INFO) << "loading dataprovider " << model <<"::" << className; LOG(INFO) << "loading dataprovider " << model <<"::" << className;
PyObjectPtr module(PyImport_ImportModule(model.c_str())); PyObjectPtr module = py::import(model);
CHECK_PY(module) << "Cannot imort module " << model.c_str();
PyObjectPtr moduleDict(PyModule_GetDict(module.get())); PyObjectPtr moduleDict(PyModule_GetDict(module.get()));
CHECK_PY(moduleDict) << "Invoke module.__dict__ error"; CHECK_PY(moduleDict) << "Invoke module.__dict__ error";
PyObjectPtr cls(PyDict_GetItemString(moduleDict.get(), PyObjectPtr cls(PyDict_GetItemString(moduleDict.get(),
......
...@@ -117,7 +117,7 @@ TEST(PyDataProvider2, index_no_seq) { ...@@ -117,7 +117,7 @@ TEST(PyDataProvider2, index_no_seq) {
} }
TEST(PyDataProvider2, init_hook) { TEST(PyDataProvider2, init_hook) {
paddle::PyObjectPtr pickle(PyImport_ImportModule("pickle")); paddle::PyObjectPtr pickle = paddle::py::import("pickle");
paddle::PyObjectPtr globals( paddle::PyObjectPtr globals(
PyModule_GetDict(PyImport_AddModule("__main__"))); PyModule_GetDict(PyImport_AddModule("__main__")));
PyDict_SetItemString(globals.get(), "pickle", pickle.get()); PyDict_SetItemString(globals.get(), "pickle", pickle.get());
......
...@@ -86,7 +86,7 @@ def test_can_over_batch_size(setting, filename): ...@@ -86,7 +86,7 @@ def test_can_over_batch_size(setting, filename):
yield [random.randint(0, 100 - 1) for _ in xrange(seq_len)] yield [random.randint(0, 100 - 1) for _ in xrange(seq_len)]
@provider(input_types=[index_slot(10), index_slot(10)]) @provider(input_types={'input1':index_slot(10), 'input2': index_slot(10)})
def test_input_order(setting, filename): def test_input_order(setting, filename):
for _ in xrange(1000): for _ in xrange(1000):
yield { yield {
......
enable_virtualenv.c
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
file(GLOB UTIL_HEADERS . *.h) file(GLOB UTIL_HEADERS . *.h)
file(GLOB UTIL_SOURCES . *.cpp) file(GLOB UTIL_SOURCES . *.cpp)
create_resources(enable_virtualenv.py enable_virtualenv.c)
set(UTIL_RES enable_virtualenv.c)
if(APPLE) if(APPLE)
file(GLOB UTIL_ARCH_SOURCES . arch/osx/*.cpp) file(GLOB UTIL_ARCH_SOURCES . arch/osx/*.cpp)
else() else()
...@@ -9,7 +12,8 @@ else() ...@@ -9,7 +12,8 @@ else()
endif() endif()
add_library(paddle_utils STATIC add_library(paddle_utils STATIC
${UTIL_SOURCES} ${UTIL_SOURCES}
${UTIL_ARCH_SOURCES}) ${UTIL_ARCH_SOURCES}
${UTIL_RES})
add_style_check_target(paddle_utils ${UTIL_HEADERS}) add_style_check_target(paddle_utils ${UTIL_HEADERS})
add_style_check_target(paddle_utils ${UTIL_SOURCES} add_style_check_target(paddle_utils ${UTIL_SOURCES}
${UTIL_ARCH_SOURCES}) ${UTIL_ARCH_SOURCES})
......
...@@ -77,11 +77,18 @@ static std::recursive_mutex g_pyMutex; ...@@ -77,11 +77,18 @@ static std::recursive_mutex g_pyMutex;
PyGuard::PyGuard() : guard_(g_pyMutex) {} PyGuard::PyGuard() : guard_(g_pyMutex) {}
static void printPyErrorStack(std::ostream& os, bool withEndl = false) { static void printPyErrorStack(std::ostream& os, bool withEndl = false,
bool withPyPath = true) {
PyObject * ptype, *pvalue, *ptraceback; PyObject * ptype, *pvalue, *ptraceback;
PyErr_Fetch(&ptype, &pvalue, &ptraceback); PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback); PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
PyErr_Clear(); PyErr_Clear();
if (withPyPath) {
os << "Current PYTHONPATH: " << py::repr(PySys_GetObject(strdup("path")));
if (withEndl) {
os << std::endl;
}
}
PyTracebackObject* obj = (PyTracebackObject*)ptraceback; PyTracebackObject* obj = (PyTracebackObject*)ptraceback;
os << "Python Error: " << PyString_AsString(PyObject_Str(ptype)) os << "Python Error: " << PyString_AsString(PyObject_Str(ptype))
...@@ -114,10 +121,7 @@ PyObjectPtr callPythonFuncRetPyObj(const std::string& moduleName, ...@@ -114,10 +121,7 @@ PyObjectPtr callPythonFuncRetPyObj(const std::string& moduleName,
const std::string& funcName, const std::string& funcName,
const std::vector<std::string>& args) { const std::vector<std::string>& args) {
PyGuard guard; PyGuard guard;
PyObjectPtr pyModuleName(PyString_FromString(moduleName.c_str())); PyObjectPtr pyModule = py::import(moduleName);
CHECK_PY(pyModuleName) << "Import PyModule failed" << moduleName;
PyObjectPtr pyModule(PyImport_Import(pyModuleName.get()));
CHECK_PY(pyModule) << "Import Python Module"<< moduleName << " failed.";
PyObjectPtr pyFunc(PyObject_GetAttrString(pyModule.get(), funcName.c_str())); PyObjectPtr pyFunc(PyObject_GetAttrString(pyModule.get(), funcName.c_str()));
CHECK_PY(pyFunc) << "GetAttrString failed."; CHECK_PY(pyFunc) << "GetAttrString failed.";
PyObjectPtr pyArgs(PyTuple_New(args.size())); PyObjectPtr pyArgs(PyTuple_New(args.size()));
...@@ -143,7 +147,7 @@ PyObjectPtr createPythonClass( ...@@ -143,7 +147,7 @@ PyObjectPtr createPythonClass(
const std::vector<std::string>& args, const std::vector<std::string>& args,
const std::map<std::string, std::string>& kwargs) { const std::map<std::string, std::string>& kwargs) {
PyGuard guard; PyGuard guard;
PyObjectPtr pyModule(PyImport_ImportModule(moduleName.c_str())); PyObjectPtr pyModule = py::import(moduleName);
LOG(INFO) << "createPythonClass moduleName.c_str:" << moduleName.c_str(); LOG(INFO) << "createPythonClass moduleName.c_str:" << moduleName.c_str();
CHECK_PY(pyModule) << "Import module " << moduleName << " failed."; CHECK_PY(pyModule) << "Import module " << moduleName << " failed.";
PyObjectPtr pyDict(PyModule_GetDict(pyModule.get())); PyObjectPtr pyDict(PyModule_GetDict(pyModule.get()));
...@@ -181,18 +185,29 @@ std::string getPyCallStack() { ...@@ -181,18 +185,29 @@ std::string getPyCallStack() {
printPyErrorStack(os, true); printPyErrorStack(os, true);
return os.str(); return os.str();
} }
PyObjectPtr import(const std::string &moduleName) {
auto module = PyImport_ImportModule(moduleName.c_str());
CHECK_PY(module) << "Import " << moduleName << "Error";
return PyObjectPtr(module);
}
} // namespace py } // namespace py
#endif #endif
extern "C" {
extern const char enable_virtualenv_py[];
}
void initPython(int argc, char** argv) { void initPython(int argc, char** argv) {
#ifndef PADDLE_NO_PYTHON #ifndef PADDLE_NO_PYTHON
Py_SetProgramName(argv[0]); Py_SetProgramName(argv[0]);
Py_Initialize(); Py_Initialize();
PySys_SetArgv(argc, argv); PySys_SetArgv(argc, argv);
// python blocks SIGINT. Need to enable it. // python blocks SIGINT. Need to enable it.
signal(SIGINT, SIG_DFL); signal(SIGINT, SIG_DFL);
// Manually activate virtualenv when user is using virtualenv
PyRun_SimpleString(enable_virtualenv_py);
#endif #endif
} }
......
...@@ -87,6 +87,8 @@ PyObjectPtr createPythonClass(const std::string& moduleName, ...@@ -87,6 +87,8 @@ PyObjectPtr createPythonClass(const std::string& moduleName,
CHECK((x) != nullptr) << ::paddle::py::getPyCallStack() CHECK((x) != nullptr) << ::paddle::py::getPyCallStack()
namespace py { namespace py {
PyObjectPtr import(const std::string& moduleName);
/** /**
* Cast a PyLong or PyInt to int type T. * Cast a PyLong or PyInt to int type T.
* @tparam T return type. * @tparam T return type.
......
import os
def __activate_virtual_env__():
__path__ = os.getenv('VIRTUAL_ENV')
if __path__ is None:
return
__script__ = os.path.join(__path__, 'bin', 'activate_this.py')
execfile(__script__, {'__file__': __script__})
__activate_virtual_env__()
...@@ -208,7 +208,6 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1, ...@@ -208,7 +208,6 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
calc_batch_size=None, calc_batch_size=None,
cache=CacheType.NO_CACHE, cache=CacheType.NO_CACHE,
check=False, check_fail_continue=False, check=False, check_fail_continue=False,
use_dynamic_order=True,
init_hook=None, **kwargs): init_hook=None, **kwargs):
""" """
Provider decorator. Use it to make a function into PyDataProvider2 object. Provider decorator. Use it to make a function into PyDataProvider2 object.
...@@ -228,9 +227,15 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1, ...@@ -228,9 +227,15 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
The configuration of data provider should be setup by\: The configuration of data provider should be setup by\:
:param input_types: Specify the input types, can also be set in init_hook. :param input_types: Specify the input types, can also be set in init_hook.
It is a list of InputType object. For example, input_types= \ It could be a list of InputType object. For example,
[dense_vector(9), integer_value(2)]. input_types=[dense_vector(9), integer_value(2)]. Or user
:type input_types: list|tuple can set a dict of InputType object, which key is
data_layer's name. For example, input_types=\
{'img': img_features, 'label': label}. when using dict of
InputType, user could yield a dict of feature values, which
key is also data_layer's name.
:type input_types: list|tuple|dict
:param should_shuffle: True if data should shuffle. Pass None means shuffle :param should_shuffle: True if data should shuffle. Pass None means shuffle
when is training and not to shuffle when is testing. when is training and not to shuffle when is testing.
...@@ -281,12 +286,6 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1, ...@@ -281,12 +286,6 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
drop the wrong format data when it is True. Has drop the wrong format data when it is True. Has
no effect when check set to False. no effect when check set to False.
:type check_fail_continue: bool :type check_fail_continue: bool
:param use_dynamic_order: Allow provider to yield a dictionary object, whose
key is a input data layer name, and value is the
feature value. The tuples are still allowed when
use_dynmaic_order is True.
:type use_dynamic_order: bool
""" """
def __wrapper__(generator): def __wrapper__(generator):
...@@ -340,6 +339,11 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1, ...@@ -340,6 +339,11 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
assert self.slots is not None assert self.slots is not None
assert self.generator is not None assert self.generator is not None
use_dynamic_order = False
if isinstance(self.slots, dict): # reorder input_types
self.slots = [self.slots[ipt] for ipt in self.input_order]
use_dynamic_order = True
if len(self.slots) == 1: if len(self.slots) == 1:
self.generator = SingleSlotWrapper(self.generator) self.generator = SingleSlotWrapper(self.generator)
......
...@@ -216,6 +216,10 @@ def Inputs(*args): ...@@ -216,6 +216,10 @@ def Inputs(*args):
if g_current_submodel is g_root_submodel: if g_current_submodel is g_root_submodel:
g_config.model_config.input_layer_names.append(name) g_config.model_config.input_layer_names.append(name)
@config_func
def HasInputsSet():
return len(g_config.model_config.input_layer_names) != 0
# Define the name of the output layers of the NeuralNetwork. # Define the name of the output layers of the NeuralNetwork.
# Usually the output is simply the cost layer. # Usually the output is simply the cost layer.
......
...@@ -30,7 +30,7 @@ __all__ = ['sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool", ...@@ -30,7 +30,7 @@ __all__ = ['sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
'lstmemory_unit', 'small_vgg', 'img_conv_group', 'vgg_16_network', 'lstmemory_unit', 'small_vgg', 'img_conv_group', 'vgg_16_network',
'gru_unit', 'gru_group', 'simple_gru', 'simple_attention', 'gru_unit', 'gru_group', 'simple_gru', 'simple_attention',
'text_conv_pool', 'text_conv_pool',
'bidirectional_lstm', 'outputs'] 'bidirectional_lstm', 'inputs', 'outputs']
###################################################### ######################################################
...@@ -372,8 +372,8 @@ def small_vgg(input_image, num_channels, num_classes): ...@@ -372,8 +372,8 @@ def small_vgg(input_image, num_channels, num_classes):
tmp = __vgg__(tmp, 128, 2, [0.4, 0]) tmp = __vgg__(tmp, 128, 2, [0.4, 0])
tmp = __vgg__(tmp, 256, 3, [0.4, 0.4, 0]) tmp = __vgg__(tmp, 256, 3, [0.4, 0.4, 0])
tmp = __vgg__(tmp, 512, 3, [0.4, 0.4, 0]) tmp = __vgg__(tmp, 512, 3, [0.4, 0.4, 0])
tmp = img_pool_layer(input = tmp, stride = 2, tmp = img_pool_layer(input=tmp, stride=2,
pool_size = 2, pool_type = MaxPooling()) pool_size=2, pool_type=MaxPooling())
tmp = dropout_layer(input=tmp, dropout_rate=0.5) tmp = dropout_layer(input=tmp, dropout_rate=0.5)
tmp = fc_layer(input=tmp, size=512, layer_attr=ExtraAttr(drop_rate=0.5), tmp = fc_layer(input=tmp, size=512, layer_attr=ExtraAttr(drop_rate=0.5),
act=LinearActivation()) act=LinearActivation())
...@@ -745,7 +745,6 @@ def gru_group(input, ...@@ -745,7 +745,6 @@ def gru_group(input,
gru_bias_attr=None, gru_bias_attr=None,
act=None, gate_act=None, act=None, gate_act=None,
gru_layer_attr=None): gru_layer_attr=None):
""" """
gru_group is a recurrent layer group version Gated Recurrent Unit. It gru_group is a recurrent layer group version Gated Recurrent Unit. It
does exactly the same calculation as the grumemory layer does. A promising does exactly the same calculation as the grumemory layer does. A promising
...@@ -919,12 +918,12 @@ def bidirectional_lstm(input, size, name=None, return_seq=False, ...@@ -919,12 +918,12 @@ def bidirectional_lstm(input, size, name=None, return_seq=False,
fw = simple_lstm(name='%s_fw' % name, input=input, size=size, fw = simple_lstm(name='%s_fw' % name, input=input, size=size,
**dict((k[len('fwd_'):], v) for k, v in args.iteritems() **dict((k[len('fwd_'):], v) for k, v in args.iteritems()
if k.startswith('fwd_'))) if k.startswith('fwd_')))
bw = simple_lstm(name="%s_bw" % name, input=input, size=size, bw = simple_lstm(name="%s_bw" % name, input=input, size=size,
reverse=True, reverse=True,
**dict((k[len('bwd_'):], v) for k, v in args.iteritems() **dict((k[len('bwd_'):], v) for k, v in args.iteritems()
if k.startswith('bwd_'))) if k.startswith('bwd_')))
if return_seq: if return_seq:
return concat_layer(name=name, input=[fw, bw], layer_attr=concat_attr, return concat_layer(name=name, input=[fw, bw], layer_attr=concat_attr,
...@@ -1052,14 +1051,30 @@ def dropout_layer(input, dropout_rate, name=None): ...@@ -1052,14 +1051,30 @@ def dropout_layer(input, dropout_rate, name=None):
layer_attr=ExtraAttr(drop_rate=dropout_rate)) layer_attr=ExtraAttr(drop_rate=dropout_rate))
def outputs(layers, *args): def inputs(layers, *args):
"""
Declare the inputs of network. The order of input should be as same as
the data provider's return order.
:param layers: Input Layers.
:type layers: list|tuple|LayerOutput.
:return:
""" """
Declare the end of network. Currently it will only calculate the
input/output order of network. It will calculate the predict network or
train network's output automatically.
if isinstance(layers, LayerOutput) or isinstance(layers, basestring):
layers = [layers]
if len(args) != 0:
layers.extend(args)
:param layers: Inputs(*[l.name for l in layers])
def outputs(layers, *args):
"""
Declare the outputs of network. If user have not defined the inputs of
network, this method will calculate the input order by dfs travel.
:param layers: Output layers.
:type layers: list|tuple|LayerOutput :type layers: list|tuple|LayerOutput
:return: :return:
""" """
...@@ -1093,6 +1108,11 @@ def outputs(layers, *args): ...@@ -1093,6 +1108,11 @@ def outputs(layers, *args):
layers.extend(args) layers.extend(args)
assert len(layers) > 0 assert len(layers) > 0
if HasInputsSet(): # input already set
Outputs(*[l.name for l in layers])
return # just return outputs.
if len(layers) != 1: if len(layers) != 1:
logger.warning("`outputs` routine try to calculate network's" logger.warning("`outputs` routine try to calculate network's"
" inputs and outputs order. It might not work well." " inputs and outputs order. It might not work well."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册