diff --git a/fleet_rec/core/utils/dataloader_instance.py b/fleet_rec/core/utils/dataloader_instance.py index 8c43c610a29394b3b334d755bc95d21fb0f514ce..5af98e7c4ef2ae6e87291805d47716691a737cfd 100644 --- a/fleet_rec/core/utils/dataloader_instance.py +++ b/fleet_rec/core/utils/dataloader_instance.py @@ -18,6 +18,7 @@ import sys from fleetrec.core.utils.envs import lazy_instance from fleetrec.core.utils.envs import get_global_env +from fleetrec.core.utils.envs import get_runtime_environ def dataloader(readerclass, train, yaml_file): @@ -30,6 +31,11 @@ def dataloader(readerclass, train, yaml_file): reader_name = "EvaluateReader" data_path = get_global_env("test_data_path", None, namespace) + if data_path.startswith("fleetrec::"): + package_base = get_runtime_environ("PACKAGE_BASE") + assert package_base is not None + data_path = os.path.join(package_base, data_path.split("::")[1]) + files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] reader_class = lazy_instance(readerclass, reader_name) @@ -50,4 +56,5 @@ def dataloader(readerclass, train, yaml_file): for pased in parsed_line: values.append(pased[1]) yield values + return gen_reader