From 058eeac0fc94500137ce0192136d45a62ada0569 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Sun, 5 Feb 2017 01:17:00 +0000 Subject: [PATCH] Revert "Remove completely create_data_config_proto" This reverts commit ab279beed1725fae4fc4781ba891d80bde8eae0a. --- python/paddle/trainer/config_parser.py | 24 +++++++++++++++++++ .../trainer_config_helpers/data_sources.py | 3 +-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 65ebc96136..e3e2a6899e 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -894,6 +894,30 @@ class MaxOut(Cfg): self.add_keys(locals()) +def create_data_config_proto(async_load_data=False, + constant_slots=None, + data_ratio=1, + is_main_data=True, + usage_ratio=None): + # default: all sub dataproviders are treated as "main data". + # see proto/DataConfig.proto for is_main_data + data_config = DataConfig() + + data_config.async_load_data = async_load_data + + if constant_slots: + data_config.constant_slots.extend(constant_slots) + data_config.data_ratio = data_ratio + data_config.is_main_data = is_main_data + + usage_ratio = default(usage_ratio, settings_deprecated["usage_ratio"]) + config_assert(usage_ratio >= 0 and usage_ratio <= 1, + "The range of usage_ratio is [0, 1]") + data_config.usage_ratio = usage_ratio + + return data_config + + @config_func def SimpleData(files=None, feat_dim=None, diff --git a/python/paddle/trainer_config_helpers/data_sources.py b/python/paddle/trainer_config_helpers/data_sources.py index 6744dec5a8..0ea8fc77ee 100644 --- a/python/paddle/trainer_config_helpers/data_sources.py +++ b/python/paddle/trainer_config_helpers/data_sources.py @@ -14,7 +14,6 @@ """ Data Sources are helpers to define paddle training data or testing data. 
""" -import paddle.proto.DataConfig_pb2 from paddle.trainer.config_parser import * from .utils import deprecated @@ -196,7 +195,7 @@ def define_py_data_sources2(train_list, test_list, module, obj, args=None): def py_data2(files, load_data_module, load_data_object, load_data_args, **kwargs): - data = paddle.proto.DataConfig_pb2.DataConfig() + data = create_data_config_proto() data.type = 'py2' data.files = files data.load_data_module = load_data_module -- GitLab