diff --git a/fleetrec/core/model.py b/fleetrec/core/model.py index 4d30627d80991860063015fe5a00649f54ecba0e..9f940b6c0a204a0d1270f9ea46933583dc6307e0 100644 --- a/fleetrec/core/model.py +++ b/fleetrec/core/model.py @@ -1,5 +1,5 @@ import abc - +from fleetrec.core.utils import envs class Model(object): """R @@ -15,6 +15,7 @@ class Model(object): self._data_loader = None self._fetch_interval = 20 self._namespace = "train.model" + self._platform = envs.get_platform() def get_inputs(self): return self._data_var diff --git a/fleetrec/models/ctr/dnn/model.py b/fleetrec/models/ctr/dnn/model.py index 7816e5a72077e0b8069fd3906aa5da7c5b1f4b7e..ac3887d70c9578892f2ebb5b4796d632685a5327 100644 --- a/fleetrec/models/ctr/dnn/model.py +++ b/fleetrec/models/ctr/dnn/model.py @@ -57,8 +57,10 @@ class Model(ModelBase): self._data_var.append(input) self._data_var.append(self.label_input) - self._data_loader = fluid.io.PyReader( - feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False) + + if self._platform != "LINUX": + self._data_loader = fluid.io.PyReader( + feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False) def net(self): trainer = envs.get_trainer()