提交 cccc61c0 编写于 作者: T tangwei12

fix ctr-dnn local training

上级 afec7a49
......@@ -34,7 +34,7 @@ train:
reader:
mode: "dataset"
batch_size: 32
batch_size: 2
pipe_command: "python /paddle/eleps/models/ctr_dnn/dataset.py"
train_data_path: "/paddle/eleps/models/ctr_dnn/data/train"
......@@ -45,10 +45,7 @@ train:
sparse_feature_number: 1000001
sparse_feature_dim: 8
dense_input_dim: 13
fc_sizes: [101, 512, 32]
# - 1024
# - 512
# - 32
fc_sizes: [512, 256, 128, 32]
learning_rate: 0.001
save:
......
......@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
import paddle.fluid.incubate.data_generator as dg
......@@ -37,6 +39,7 @@ class CriteoDataset(dg.MultiSlotDataGenerator):
This function needs to be implemented by the user, based on data format
"""
features = line.rstrip('\n').split('\t')
dense_feature = []
sparse_feature = []
for idx in continuous_range_:
......@@ -46,6 +49,7 @@ class CriteoDataset(dg.MultiSlotDataGenerator):
dense_feature.append(
(float(features[idx]) - cont_min_[idx - 1]) /
cont_diff_[idx - 1])
for idx in categorical_range_:
sparse_feature.append(
[hash(str(idx) + features[idx]) % hash_dim_])
......
......@@ -39,7 +39,7 @@ class Train(object):
fluid.layers.data(name="C" + str(i),
shape=[1],
lod_level=1,
dtype="int64") for i in range(ids)
dtype="int64") for i in range(1, ids)
]
return sparse_input_ids, [var.name for var in sparse_input_ids]
......@@ -60,7 +60,7 @@ class Train(object):
self.label_input, self.label_input_varname = label_input()
def input_vars(self):
return self.sparse_inputs + [self.dense_input] + [self.label_input]
return [self.dense_input] + self.sparse_inputs + [self.label_input]
def input_varnames(self):
return [input.name for input in self.input_vars()]
......
......@@ -169,7 +169,7 @@ class SingleTrainerWithDataset(SingleTrainer):
dataset=dataset,
fetch_list=self.metrics,
fetch_info=["auc ", "batch auc"],
print_period=100)
print_period=1)
context['status'] = 'infer_pass'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册