提交 2f82d72e 编写于 作者: Y Yu Yang 提交者: emailweixu

Fix bug in yield dictionary in DataProvider. (#197)

* Fix bug in yield dictionary in DataProvider.
* Also make virtualenv work in Paddle.
上级 e4952ca6
develop 2.0.1-rocm-post Ligoml-patch-1 OliverLPH-patch-1 OliverLPH-patch-2 PaddlePM-patch-1 PaddlePM-patch-2 ZHUI-patch-1 add_default_att add_model_benchmark_ci add_some_yaml_config addfile all_new_design_exec ascendrc ascendrelease cherry_undefined_var compile_windows delete_2.0.1-rocm-post delete_add_default_att delete_all_new_design_exec delete_ascendrc delete_compile_windows delete_delete_addfile delete_disable_iterable_dataset_unittest delete_fix_dataloader_memory_leak delete_fix_imperative_dygraph_error delete_fix_retry_ci delete_fix_undefined_var delete_improve_sccache delete_incubate/lite delete_paddle_tiny_install delete_paralleltest delete_prv-disable-more-cache delete_revert-31068-fix_conv3d_windows delete_revert-31562-mean delete_revert-33630-bug-fix delete_revert-34159-add_npu_bce_logical_dev delete_revert-34910-spinlocks_for_allocator delete_revert-35069-revert-34910-spinlocks_for_allocator delete_revert-36057-dev/read_flags_in_ut dingjiaweiww-patch-1 disable_iterable_dataset_unittest dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix_concat_slice fix_dataloader_memory_leak fix_imperative_dygraph_error fix_npu_ci fix_op_flops fix_retry_ci fix_rnn_docs fix_tensor_type fix_undefined_var fixiscan fixiscan1 fixiscan2 fixiscan3 github/fork/123malin/netifaces github/fork/123malin/tdm_abacus github/fork/AshburnLee/dev_unique github/fork/ForFishes/fix_memory_matmul github/fork/ForFishes/rm_fluid github/fork/LielinJiang/move-2.0-api github/fork/LielinJiang/visual-dl-cb github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api github/fork/LiuChiachi/fix-example-code-for-hapi-Model github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model github/fork/MrChengmo/fix_ps_profiler github/fork/MrChengmo/update_ps_heter github/fork/PWhiddy/patch-1 github/fork/Shixiaowei02/dev/save_load_upgrade github/fork/TCChenlong/fix_hapi github/fork/TCChenlong/fix_inden github/fork/Thunderbrook/xpu_slice github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2 github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3 github/fork/XieYunshen/timeout_20S_ut github/fork/ZeyuChen/remove-nltk github/fork/arlesniak/arlesniak/selective__mkldnn_flags github/fork/baiyfbupt/code_doc_mig github/fork/chalsliu/set_timeout github/fork/chen-zhiyu/develop github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads github/fork/chenwhql/saveload/add_get_inference_program github/fork/chenwhql/saveload/remove_save_load_config github/fork/cryoco/pass-compatibility-trt github/fork/danleifeng/isempty_api2.0 github/fork/frankwhzhang/api_transfer github/fork/hbwx24/error_msg/cuda_kernel_error_msg github/fork/heavengate/cherry_yolo_box github/fork/heavengate/update_yolo_box github/fork/iclementine/rnn_fix github/fork/iducn/testestse github/fork/jczaja/prv-25537-fix github/fork/jeff41404/release/1.8 github/fork/jiweibo/api_2.0 github/fork/jiweibo/fix_lite_resnet50_test github/fork/juncaipeng/fix_doc_1 github/fork/lfchener/sample_code github/fork/littletomatodonkey/fix_reg_doc github/fork/liym27/dy2stat_update_assign_to_rc20 github/fork/luotao1/profiler_ut github/fork/mapingshuo/add_wait github/fork/mapingshuo/doc_2.0 github/fork/mapingshuo/zero-0.5 github/fork/miraiwk/dev github/fork/pangyoki/add-Categorical-class-branch github/fork/pangyoki/add-multinomial-op-branch github/fork/pangyoki/fix-test_distritbution-CI github/fork/qjing666/doublegrad github/fork/qjing666/fix_hdfs_download github/fork/sandyhouse/add_gather_etc github/fork/sandyhouse/add_send_recv_alltoall_etc github/fork/sandyhouse/pipeline_exe_run github/fork/seiriosPlus/feature/large_scale_kv_save_delta github/fork/seiriosPlus/fix/paddle_errors_fix github/fork/seiriosPlus/fix/paddle_op_errors github/fork/shangzhizhou/fix_test_activation_op_random_bug github/fork/smallv0221/yxp0924 github/fork/smallv0221/yxp0925 github/fork/swtkiwi/del-matplotlib github/fork/tianshuo78520a/kunlun_test github/fork/tianshuo78520a/update_dockerfile github/fork/wanghaoshuang/bert_fuse github/fork/wanghaoshuang/label_smooth github/fork/wanghuancoder/develop_CUDASynchronize github/fork/wanghuancoder/develop_Layer_doc github/fork/wanghuancoder/develop_ParameterList_doc github/fork/wanghuancoder/develop_Sequential_doc github/fork/wanghuancoder/develop_bilinear_tensor_product github/fork/wanghuancoder/develop_coverage_build_sh github/fork/wanghuancoder/develop_in_dynamic_mode_doc github/fork/wanghuancoder/develop_unique_name_doc github/fork/wangxicoding/fleet_meta_combine github/fork/wawltor/error_message_fix_5 github/fork/willthefrog/remove_l2_norm github/fork/windstamp/momentum_op github/fork/windstamp/mv_op_5 github/fork/windstamp/normal_api github/fork/wojtuss/wojtuss/fusion_gru_quantization github/fork/wojtuss/wojtuss/quantization-with-shift github/fork/wzzju/fix_err_info github/fork/wzzju/pure_fp16 github/fork/xiemoyuan/op_error_message github/fork/xiemoyuan/optimize_error_message github/fork/yaoxuefeng6/fix_doc github/fork/yaoxuefeng6/mod_dataset_v2 github/fork/yongqiangma/lod github/fork/ysh329/fix-clip-by-norm-error github/fork/ysh329/fix-error-clip-by-value github/fork/yukavio/error_info github/fork/zhangting2020/conv_filter_grad github/fork/zhangting2020/is_compile_with_cuda github/fork/zhangting2020/place_doc github/fork/zhangting2020/program github/fork/zhhsplendid/fix_any github/fork/zhhsplendid/refine_api2 github/fork/zhhsplendid/refine_api2_test github/fork/zhhsplendid/refine_api_test_ptb_lm github/fork/zhhsplendid/refine_api_test_resnet github/fork/zhhsplendid/refine_api_test_simnet github/fork/zhiqiu/dev/refine_initializer github/fork/zhiqiu/dev/remove_inplace_argument github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11 improve_sccache incubate/infrt incubate/lite inplace_addto make_flag_adding_easier master move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc numel paddle_tiny_install paralleltest preln_ernie prv-disable-more-cache prv-md-even-more prv-onednn-2.5 pten_tensor_refactor release/0.10.0 release/0.11.0 release/0.12.0 release/0.13.0 release/0.14.0 release/0.15.0 release/1.0.0 release/1.1 release/1.2 release/1.3 release/1.4 release/1.5 release/1.6 release/1.7 release/1.8 release/2.0 release/2.0-alpha release/2.0-beta release/2.0-rc release/2.0-rc1 release/2.1 release/2.2 release/2.3 release/2.3-fc-ernie-fix release/2.4 release/lite-0.1 revert-24981-add_device_attr_for_regulization revert-26856-strategy_example2 revert-27520-disable_pr revert-31068-fix_conv3d_windows revert-31562-mean revert-32290-develop-hardlabel revert-33037-forci revert-33475-fix_cifar_label_dimension revert-33630-bug-fix revert-34159-add_npu_bce_logical_dev revert-34406-add_copy_from_tensor revert-34910-spinlocks_for_allocator revert-35069-revert-34910-spinlocks_for_allocator revert-36057-dev/read_flags_in_ut revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment rocm_dev_0217 support_weight_transpose test_benchmark_ci test_feature_precision_test_c test_model_benchmark test_model_benchmark_ci zhiqiu-patch-1 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0 v2.2.2 v2.2.1 v2.2.0 v2.2.0-rc0 v2.2.0-bak0 v2.1.3 v2.1.2 v2.1.1 v2.1.0 v2.1.0-rc0 v2.0.2 v2.0.1 v2.0.0 v2.0.0-rc1 v2.0.0-rc0 v2.0.0-beta0 v2.0.0-alpha0 v1.8.5 v1.8.4 v1.8.3 v1.8.2 v1.8.1 v1.8.0 v1.7.2 v1.7.1 v1.7.0 v1.6.3 v1.6.2 v1.6.1 v1.6.0 v1.6.0-rc0 v1.5.2 v1.5.1 v1.5.0 v1.4.1 v1.4.0 v1.3.2 v1.3.1 v1.3.0 v1.2.1 v1.2.0 v1.1.0 v1.0.2 v1.0.1 v1.0.0 v1.0.0-rc0 v0.15.0 v0.15.0-rc0 v0.14.0 v0.13.0 v0.12.0 v0.11.1a2 v0.11.1a1 v0.11.0 v0.10.0 v0.10.0rc4 v0.10.0rc v0.9.0 v0.9.0a0 lite-v0.1
无相关合并请求
......@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8)
project(paddle CXX C)
set(PADDLE_MAJOR_VERSION 0)
set(PADDLE_MINOR_VERSION 8)
set(PADDLE_PATCH_VERSION 0b1)
set(PADDLE_PATCH_VERSION 0b2)
set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION})
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
......
......@@ -184,3 +184,20 @@ macro(add_paddle_culib TARGET_NAME)
cuda_add_library(${TARGET_NAME} STATIC ${ARGN})
set(CUDA_NVCC_FLAGS ${NVCC_FLAG})
endmacro()
# Creates C resources file from files in given resource file
function(create_resources res_file output)
# Create empty output file
file(WRITE ${output} "")
# Get short filename
string(REGEX MATCH "([^/]+)$" filename ${res_file})
# Replace filename spaces & extension separator for C compatibility
string(REGEX REPLACE "\\.| |-" "_" filename ${filename})
# Read hex data from file
file(READ ${res_file} filedata HEX)
# Convert hex data for C compatibility
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," filedata ${filedata})
# Append data to output file
file(APPEND ${output} "const unsigned char ${filename}[] = {${filedata}};\nconst unsigned ${filename}_size = sizeof(${filename});\n")
endfunction()
文件模式从 100644 更改为 100755
......@@ -2,10 +2,10 @@ from paddle.trainer.PyDataProvider2 import *
# Define a py data provider
@provider(input_types=[
dense_vector(28 * 28),
integer_value(10)
])
@provider(input_types={
'pixel': dense_vector(28 * 28),
'label': integer_value(10)
})
def process(settings, filename): # settings is not used currently.
imgf = filename + "-images-idx3-ubyte"
labelf = filename + "-labels-idx1-ubyte"
......@@ -14,20 +14,19 @@ def process(settings, filename): # settings is not used currently.
f.read(16)
l.read(8)
# Define number of samples for train/test
if "train" in filename:
n = 60000
else:
n = 10000
for i in range(n):
label = ord(l.read(1))
pixels = []
for j in range(28*28):
for j in range(28 * 28):
pixels.append(float(ord(f.read(1))) / 255.0)
yield { "pixel": pixels, 'label': label }
yield {"pixel": pixels, 'label': label}
f.close()
l.close()
\ No newline at end of file
......@@ -47,6 +47,7 @@ predict = small_vgg(input_image=img,
if not is_predict:
lbl = data_layer(name="label", size=label_size)
inputs(img, lbl)
outputs(classification_cost(input=predict, label=lbl))
else:
outputs(predict)
......@@ -2,10 +2,10 @@ from paddle.trainer.PyDataProvider2 import *
# Define a py data provider
@provider(input_types=[
dense_vector(28 * 28),
integer_value(10)
])
@provider(input_types={
'pixel': dense_vector(28 * 28),
'label': integer_value(10)
})
def process(settings, filename): # settings is not used currently.
f = open(filename, 'r') # open one of training file
......@@ -20,6 +20,6 @@ def process(settings, filename): # settings is not used currently.
pixels_float.append(float(each_pixel_str))
# give data to paddle.
yield { "pixel": pixels_float, 'label': int(label) }
yield {"pixel": pixels_float, 'label': int(label)}
f.close() # close file
......@@ -141,8 +141,6 @@ DataProvider创建的时候执行。这个初始化函数具有如下参数:
是一个batch size,但是有时为了计算均衡性,可以将一条数据设置成多个batch size
* cache 是数据缓存的策略,参考 `cache`_
* init_hook 是初始化时调用的函数,参考 `init_hook`_
* use_dynamic_order 如果是true的话,可以返回一个dict,key是data_layer的名字,value是特征值。同时,也可以
返回一个list或者tuple。如果是false的话,只能够返回list或者tuple
* check 设置成true的话,会根据input_types检查数据的合法性。
* check_fail_continue 如果设置成true的话,即使在check中数据不合法,也会扔到这条数据,继续训练。 如果
check是false的话,没有作用。
......
......@@ -246,8 +246,7 @@ private:
PyObjectPtr && kwargs) {
LOG(INFO) << "loading dataprovider " << model <<"::" << className;
PyObjectPtr module(PyImport_ImportModule(model.c_str()));
CHECK_PY(module) << "Cannot imort module " << model.c_str();
PyObjectPtr module = py::import(model);
PyObjectPtr moduleDict(PyModule_GetDict(module.get()));
CHECK_PY(moduleDict) << "Invoke module.__dict__ error";
PyObjectPtr cls(PyDict_GetItemString(moduleDict.get(),
......
......@@ -117,7 +117,7 @@ TEST(PyDataProvider2, index_no_seq) {
}
TEST(PyDataProvider2, init_hook) {
paddle::PyObjectPtr pickle(PyImport_ImportModule("pickle"));
paddle::PyObjectPtr pickle = paddle::py::import("pickle");
paddle::PyObjectPtr globals(
PyModule_GetDict(PyImport_AddModule("__main__")));
PyDict_SetItemString(globals.get(), "pickle", pickle.get());
......
......@@ -86,7 +86,7 @@ def test_can_over_batch_size(setting, filename):
yield [random.randint(0, 100 - 1) for _ in xrange(seq_len)]
@provider(input_types=[index_slot(10), index_slot(10)])
@provider(input_types={'input1':index_slot(10), 'input2': index_slot(10)})
def test_input_order(setting, filename):
for _ in xrange(1000):
yield {
......
enable_virtualenv.c
......@@ -2,6 +2,9 @@
file(GLOB UTIL_HEADERS . *.h)
file(GLOB UTIL_SOURCES . *.cpp)
create_resources(enable_virtualenv.py enable_virtualenv.c)
set(UTIL_RES enable_virtualenv.c)
if(APPLE)
file(GLOB UTIL_ARCH_SOURCES . arch/osx/*.cpp)
else()
......@@ -9,7 +12,8 @@ else()
endif()
add_library(paddle_utils STATIC
${UTIL_SOURCES}
${UTIL_ARCH_SOURCES})
${UTIL_ARCH_SOURCES}
${UTIL_RES})
add_style_check_target(paddle_utils ${UTIL_HEADERS})
add_style_check_target(paddle_utils ${UTIL_SOURCES}
${UTIL_ARCH_SOURCES})
......
......@@ -77,11 +77,18 @@ static std::recursive_mutex g_pyMutex;
PyGuard::PyGuard() : guard_(g_pyMutex) {}
static void printPyErrorStack(std::ostream& os, bool withEndl = false) {
static void printPyErrorStack(std::ostream& os, bool withEndl = false,
bool withPyPath = true) {
PyObject * ptype, *pvalue, *ptraceback;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
PyErr_Clear();
if (withPyPath) {
os << "Current PYTHONPATH: " << py::repr(PySys_GetObject(strdup("path")));
if (withEndl) {
os << std::endl;
}
}
PyTracebackObject* obj = (PyTracebackObject*)ptraceback;
os << "Python Error: " << PyString_AsString(PyObject_Str(ptype))
......@@ -114,10 +121,7 @@ PyObjectPtr callPythonFuncRetPyObj(const std::string& moduleName,
const std::string& funcName,
const std::vector<std::string>& args) {
PyGuard guard;
PyObjectPtr pyModuleName(PyString_FromString(moduleName.c_str()));
CHECK_PY(pyModuleName) << "Import PyModule failed" << moduleName;
PyObjectPtr pyModule(PyImport_Import(pyModuleName.get()));
CHECK_PY(pyModule) << "Import Python Module"<< moduleName << " failed.";
PyObjectPtr pyModule = py::import(moduleName);
PyObjectPtr pyFunc(PyObject_GetAttrString(pyModule.get(), funcName.c_str()));
CHECK_PY(pyFunc) << "GetAttrString failed.";
PyObjectPtr pyArgs(PyTuple_New(args.size()));
......@@ -143,7 +147,7 @@ PyObjectPtr createPythonClass(
const std::vector<std::string>& args,
const std::map<std::string, std::string>& kwargs) {
PyGuard guard;
PyObjectPtr pyModule(PyImport_ImportModule(moduleName.c_str()));
PyObjectPtr pyModule = py::import(moduleName);
LOG(INFO) << "createPythonClass moduleName.c_str:" << moduleName.c_str();
CHECK_PY(pyModule) << "Import module " << moduleName << " failed.";
PyObjectPtr pyDict(PyModule_GetDict(pyModule.get()));
......@@ -181,18 +185,29 @@ std::string getPyCallStack() {
printPyErrorStack(os, true);
return os.str();
}
PyObjectPtr import(const std::string &moduleName) {
auto module = PyImport_ImportModule(moduleName.c_str());
CHECK_PY(module) << "Import " << moduleName << "Error";
return PyObjectPtr(module);
}
} // namespace py
#endif
extern "C" {
extern const char enable_virtualenv_py[];
}
void initPython(int argc, char** argv) {
#ifndef PADDLE_NO_PYTHON
Py_SetProgramName(argv[0]);
Py_Initialize();
PySys_SetArgv(argc, argv);
// python blocks SIGINT. Need to enable it.
signal(SIGINT, SIG_DFL);
// Manually activate virtualenv when user is using virtualenv
PyRun_SimpleString(enable_virtualenv_py);
#endif
}
......
......@@ -87,6 +87,8 @@ PyObjectPtr createPythonClass(const std::string& moduleName,
CHECK((x) != nullptr) << ::paddle::py::getPyCallStack()
namespace py {
PyObjectPtr import(const std::string& moduleName);
/**
* Cast a PyLong or PyInt to int type T.
* @tparam T return type.
......
import os
def __activate_virtual_env__():
__path__ = os.getenv('VIRTUAL_ENV')
if __path__ is None:
return
__script__ = os.path.join(__path__, 'bin', 'activate_this.py')
execfile(__script__, {'__file__': __script__})
__activate_virtual_env__()
......@@ -208,7 +208,6 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
calc_batch_size=None,
cache=CacheType.NO_CACHE,
check=False, check_fail_continue=False,
use_dynamic_order=True,
init_hook=None, **kwargs):
"""
Provider decorator. Use it to make a function into PyDataProvider2 object.
......@@ -228,9 +227,15 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
The configuration of data provider should be setup by\:
:param input_types: Specify the input types, can also be set in init_hook.
It is a list of InputType object. For example, input_types= \
[dense_vector(9), integer_value(2)].
:type input_types: list|tuple
It could be a list of InputType object. For example,
input_types=[dense_vector(9), integer_value(2)]. Or user
can set a dict of InputType object, which key is
data_layer's name. For example, input_types=\
{'img': img_features, 'label': label}. when using dict of
InputType, user could yield a dict of feature values, which
key is also data_layer's name.
:type input_types: list|tuple|dict
:param should_shuffle: True if data should shuffle. Pass None means shuffle
when is training and not to shuffle when is testing.
......@@ -281,12 +286,6 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
drop the wrong format data when it is True. Has
no effect when check set to False.
:type check_fail_continue: bool
:param use_dynamic_order: Allow provider to yield a dictionary object, whose
key is a input data layer name, and value is the
feature value. The tuples are still allowed when
use_dynmaic_order is True.
:type use_dynamic_order: bool
"""
def __wrapper__(generator):
......@@ -340,6 +339,11 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
assert self.slots is not None
assert self.generator is not None
use_dynamic_order = False
if isinstance(self.slots, dict): # reorder input_types
self.slots = [self.slots[ipt] for ipt in self.input_order]
use_dynamic_order = True
if len(self.slots) == 1:
self.generator = SingleSlotWrapper(self.generator)
......
......@@ -216,6 +216,10 @@ def Inputs(*args):
if g_current_submodel is g_root_submodel:
g_config.model_config.input_layer_names.append(name)
@config_func
def HasInputsSet():
return len(g_config.model_config.input_layer_names) != 0
# Define the name of the output layers of the NeuralNetwork.
# Usually the output is simply the cost layer.
......
......@@ -30,7 +30,7 @@ __all__ = ['sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
'lstmemory_unit', 'small_vgg', 'img_conv_group', 'vgg_16_network',
'gru_unit', 'gru_group', 'simple_gru', 'simple_attention',
'text_conv_pool',
'bidirectional_lstm', 'outputs']
'bidirectional_lstm', 'inputs', 'outputs']
######################################################
......@@ -372,8 +372,8 @@ def small_vgg(input_image, num_channels, num_classes):
tmp = __vgg__(tmp, 128, 2, [0.4, 0])
tmp = __vgg__(tmp, 256, 3, [0.4, 0.4, 0])
tmp = __vgg__(tmp, 512, 3, [0.4, 0.4, 0])
tmp = img_pool_layer(input = tmp, stride = 2,
pool_size = 2, pool_type = MaxPooling())
tmp = img_pool_layer(input=tmp, stride=2,
pool_size=2, pool_type=MaxPooling())
tmp = dropout_layer(input=tmp, dropout_rate=0.5)
tmp = fc_layer(input=tmp, size=512, layer_attr=ExtraAttr(drop_rate=0.5),
act=LinearActivation())
......@@ -745,7 +745,6 @@ def gru_group(input,
gru_bias_attr=None,
act=None, gate_act=None,
gru_layer_attr=None):
"""
gru_group is a recurrent layer group version Gated Recurrent Unit. It
does exactly the same calculation as the grumemory layer does. A promising
......@@ -919,12 +918,12 @@ def bidirectional_lstm(input, size, name=None, return_seq=False,
fw = simple_lstm(name='%s_fw' % name, input=input, size=size,
**dict((k[len('fwd_'):], v) for k, v in args.iteritems()
if k.startswith('fwd_')))
if k.startswith('fwd_')))
bw = simple_lstm(name="%s_bw" % name, input=input, size=size,
reverse=True,
**dict((k[len('bwd_'):], v) for k, v in args.iteritems()
if k.startswith('bwd_')))
if k.startswith('bwd_')))
if return_seq:
return concat_layer(name=name, input=[fw, bw], layer_attr=concat_attr,
......@@ -1052,14 +1051,30 @@ def dropout_layer(input, dropout_rate, name=None):
layer_attr=ExtraAttr(drop_rate=dropout_rate))
def outputs(layers, *args):
def inputs(layers, *args):
"""
Declare the inputs of network. The order of input should be as same as
the data provider's return order.
:param layers: Input Layers.
:type layers: list|tuple|LayerOutput.
:return:
"""
Declare the end of network. Currently it will only calculate the
input/output order of network. It will calculate the predict network or
train network's output automatically.
if isinstance(layers, LayerOutput) or isinstance(layers, basestring):
layers = [layers]
if len(args) != 0:
layers.extend(args)
:param layers:
Inputs(*[l.name for l in layers])
def outputs(layers, *args):
"""
Declare the outputs of network. If user have not defined the inputs of
network, this method will calculate the input order by dfs travel.
:param layers: Output layers.
:type layers: list|tuple|LayerOutput
:return:
"""
......@@ -1093,6 +1108,11 @@ def outputs(layers, *args):
layers.extend(args)
assert len(layers) > 0
if HasInputsSet(): # input already set
Outputs(*[l.name for l in layers])
return # just return outputs.
if len(layers) != 1:
logger.warning("`outputs` routine try to calculate network's"
" inputs and outputs order. It might not work well."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部