未验证 提交 f8c279b1 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #16454 from panyx0718/imperative2

polish deepCF model to support real dataset
......@@ -64,6 +64,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const T* p_src = src.data<T>();
// why must be int?
const int* p_index = index.data<int>();
T* p_output = output->data<T>();
......
......@@ -235,6 +235,7 @@ PYBIND11_MODULE(core, m) {
self.forward_id_ = forward_id;
},
py::return_value_policy::reference)
.def_property_readonly("type", &imperative::OpBase::Type)
.def_property(
"backward_id",
[](const imperative::OpBase &self) { return self.backward_id_; },
......
......@@ -744,7 +744,7 @@ class Operator(object):
if _in_imperative_mode():
if type is None:
raise ValueError(
"`type` to initilized an Operator can not be None.")
"`type` to initialized an Operator can not be None.")
self.iop = core.OpBase(type)
# TODO(minqiyang): remove these lines after we take apart all
......@@ -906,7 +906,10 @@ class Operator(object):
@property
def type(self):
return self.desc.type()
if _in_imperative_mode():
return self.iop.type
else:
return self.desc.type()
def input(self, name):
"""
......
......@@ -55,7 +55,8 @@ def to_variable(value, block=None, name=None):
type=core.VarDesc.VarType.LOD_TENSOR,
name=name,
shape=value.shape,
dtype=value.dtype)
dtype=value.dtype,
stop_gradient=True)
var = py_var._ivar.value()
tensor = var.get_tensor()
tensor.set(value, framework._current_expected_place())
......
......@@ -62,7 +62,7 @@ class Tracer(core.Tracer):
if len(backward_refs) > 0:
op.iop.register_backward_hooks(release_op)
# TODO(minqiyang): remove all inputs and outputs after seperate
# TODO(minqiyang): remove all inputs and outputs after separate
# var and grad
op.backward_refs = defaultdict(list)
for k, v in six.iteritems(op.inputs):
......
......@@ -212,7 +212,7 @@ class UniformInitializer(Initializer):
if self._seed == 0:
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
# to be compatible of fp16 initializers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
......@@ -756,7 +756,7 @@ class NumpyArrayInitializer(Initializer):
values = [int(v) for v in self._value.flat]
else:
raise ValueError("Unsupported dtype %s", self._value.dtype)
if self._value.size > 1024 * 1024 * 5:
if self._value.size > 1024 * 1024 * 1024:
raise ValueError("The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it")
op = block._prepend_op(
......
......@@ -165,6 +165,8 @@ class Optimizer(object):
name = self._name + "_" + name
if (name in self._accumulators and
param.name in self._accumulators[name]):
if framework._in_imperative_mode():
return self._accumulators[name][param.name]
raise Exception("Accumulator {} already exists for parameter {}".
format(name, param.name))
if shape == None:
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
import random
import os
import sys
import paddle
......@@ -23,16 +24,17 @@ import paddle.fluid.core as core
from test_imperative_base import new_program_scope
from paddle.fluid.imperative.base import to_variable
NUM_USERS = 100
NUM_ITEMS = 1000
# Can use Amusic dataset as the DeepCF describes.
DATA_PATH = os.environ.get('DATA_PATH', '')
BATCH_SIZE = 32
NUM_BATCHES = 2
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 128))
NUM_BATCHES = int(os.environ.get('NUM_BATCHES', 5))
NUM_EPOCHES = int(os.environ.get('NUM_EPOCHES', 1))
class MLP(fluid.imperative.Layer):
class DMF(fluid.imperative.Layer):
def __init__(self, name_scope):
super(MLP, self).__init__(name_scope)
super(DMF, self).__init__(name_scope)
self._user_latent = fluid.imperative.FC(self.full_name(), 256)
self._item_latent = fluid.imperative.FC(self.full_name(), 256)
......@@ -61,9 +63,9 @@ class MLP(fluid.imperative.Layer):
return fluid.layers.elementwise_mul(users, items)
class DMF(fluid.imperative.Layer):
class MLP(fluid.imperative.Layer):
def __init__(self, name_scope):
super(DMF, self).__init__(name_scope)
super(MLP, self).__init__(name_scope)
self._user_latent = fluid.imperative.FC(self.full_name(), 256)
self._item_latent = fluid.imperative.FC(self.full_name(), 256)
self._match_layers = []
......@@ -87,21 +89,30 @@ class DMF(fluid.imperative.Layer):
class DeepCF(fluid.imperative.Layer):
def __init__(self, name_scope):
def __init__(self, name_scope, num_users, num_items, matrix):
super(DeepCF, self).__init__(name_scope)
self._user_emb = fluid.imperative.Embedding(self.full_name(),
[NUM_USERS, 256])
self._item_emb = fluid.imperative.Embedding(self.full_name(),
[NUM_ITEMS, 256])
self._num_users = num_users
self._num_items = num_items
self._rating_matrix = self.create_parameter(
fluid.ParamAttr(trainable=False),
matrix.shape,
matrix.dtype,
is_bias=False,
default_initializer=fluid.initializer.NumpyArrayInitializer(matrix))
self._rating_matrix._stop_gradient = True
self._mlp = MLP(self.full_name())
self._dmf = DMF(self.full_name())
self._match_fc = fluid.imperative.FC(self.full_name(), 1, act='sigmoid')
def forward(self, users, items):
users_emb = self._user_emb(users)
items_emb = self._item_emb(items)
# users_emb = self._user_emb(users)
# items_emb = self._item_emb(items)
users_emb = fluid.layers.gather(self._rating_matrix, users)
items_emb = fluid.layers.gather(
fluid.layers.transpose(self._rating_matrix, [1, 0]), items)
users_emb.stop_gradient = True
items_emb.stop_gradient = True
mlp_predictive = self._mlp(users_emb, items_emb)
dmf_predictive = self._dmf(users_emb, items_emb)
......@@ -116,27 +127,79 @@ def get_data():
user_ids = []
item_ids = []
labels = []
NUM_USERS = 100
NUM_ITEMS = 1000
matrix = np.zeros([NUM_USERS, NUM_ITEMS], dtype=np.float32)
for uid in range(NUM_USERS):
for iid in range(NUM_ITEMS):
# 10% positive
label = float(random.randint(1, 10) == 1)
label = float(random.randint(1, 6) == 1)
user_ids.append(uid)
item_ids.append(iid)
labels.append(label)
indices = np.arange(NUM_USERS * NUM_ITEMS)
matrix[uid, iid] = label
indices = np.arange(len(user_ids))
np.random.shuffle(indices)
users_np = np.array(user_ids, dtype=np.int32)[indices]
items_np = np.array(item_ids, dtype=np.int32)[indices]
labels_np = np.array(labels, dtype=np.float32)[indices]
return np.expand_dims(users_np, -1), \
np.expand_dims(items_np, -1), \
np.expand_dims(labels_np, -1), NUM_USERS, NUM_ITEMS, matrix
def load_data(DATA_PATH):
sys.stderr.write('loading from %s\n' % DATA_PATH)
likes = dict()
num_users = -1
num_items = -1
with open(DATA_PATH, 'r') as f:
for l in f.readlines():
uid, iid, rating = [int(v) for v in l.split('\t')]
num_users = max(num_users, uid + 1)
num_items = max(num_items, iid + 1)
if float(rating) > 0.0:
likes[(uid, iid)] = 1.0
user_ids = []
item_ids = []
labels = []
matrix = np.zeros([num_users, num_items], dtype=np.float32)
for uid, iid in likes.keys():
user_ids.append(uid)
item_ids.append(iid)
labels.append(1.0)
matrix[uid, iid] = 1.0
negative = 0
while negative < 3:
nuid = random.randint(0, num_users - 1)
niid = random.randint(0, num_items - 1)
if (nuid, niid) not in likes:
negative += 1
user_ids.append(nuid)
item_ids.append(niid)
labels.append(0.0)
indices = np.arange(len(user_ids))
np.random.shuffle(indices)
users_np = np.array(user_ids, dtype=np.int64)[indices]
items_np = np.array(item_ids, dtype=np.int64)[indices]
users_np = np.array(user_ids, dtype=np.int32)[indices]
items_np = np.array(item_ids, dtype=np.int32)[indices]
labels_np = np.array(labels, dtype=np.float32)[indices]
return np.expand_dims(users_np, -1), \
np.expand_dims(items_np, -1), \
np.expand_dims(labels_np, -1)
np.expand_dims(labels_np, -1), num_users, num_items, matrix
class TestImperativeDeepCF(unittest.TestCase):
def test_gan_float32(self):
def test_deefcf(self):
seed = 90
users_np, items_np, labels_np = get_data()
if DATA_PATH:
(users_np, items_np, labels_np, num_users, num_items,
matrix) = load_data(DATA_PATH)
else:
(users_np, items_np, labels_np, num_users, num_items,
matrix) = get_data()
startup = fluid.Program()
startup.random_seed = seed
......@@ -145,11 +208,11 @@ class TestImperativeDeepCF(unittest.TestCase):
scope = fluid.core.Scope()
with new_program_scope(main=main, startup=startup, scope=scope):
users = fluid.layers.data('users', [1], dtype='int64')
items = fluid.layers.data('items', [1], dtype='int64')
users = fluid.layers.data('users', [1], dtype='int32')
items = fluid.layers.data('items', [1], dtype='int32')
labels = fluid.layers.data('labels', [1], dtype='float32')
deepcf = DeepCF('deepcf')
deepcf = DeepCF('deepcf', num_users, num_items, matrix)
prediction = deepcf(users, items)
loss = fluid.layers.reduce_sum(
fluid.layers.log_loss(prediction, labels))
......@@ -159,35 +222,44 @@ class TestImperativeDeepCF(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
exe.run(startup)
for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
static_loss = exe.run(
main,
feed={
users.name: users_np[slice:slice + BATCH_SIZE],
items.name: items_np[slice:slice + BATCH_SIZE],
labels.name: labels_np[slice:slice + BATCH_SIZE]
},
fetch_list=[loss])[0]
sys.stderr.write('static loss %s\n' % static_loss)
for e in range(NUM_EPOCHES):
sys.stderr.write('epoch %d\n' % e)
for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
if slice + BATCH_SIZE >= users_np.shape[0]:
break
static_loss = exe.run(
main,
feed={
users.name: users_np[slice:slice + BATCH_SIZE],
items.name: items_np[slice:slice + BATCH_SIZE],
labels.name: labels_np[slice:slice + BATCH_SIZE]
},
fetch_list=[loss])[0]
sys.stderr.write('static loss %s\n' % static_loss)
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
deepcf = DeepCF('deepcf')
for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
prediction = deepcf(
to_variable(users_np[slice:slice + BATCH_SIZE]),
to_variable(items_np[slice:slice + BATCH_SIZE]))
loss = fluid.layers.reduce_sum(
fluid.layers.log_loss(prediction,
to_variable(labels_np[slice:slice +
BATCH_SIZE])))
loss._backward()
adam = fluid.optimizer.AdamOptimizer(0.01)
adam.minimize(loss)
deepcf.clear_gradients()
dy_loss = loss._numpy()
deepcf = DeepCF('deepcf', num_users, num_items, matrix)
adam = fluid.optimizer.AdamOptimizer(0.01)
for e in range(NUM_EPOCHES):
sys.stderr.write('epoch %d\n' % e)
for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
if slice + BATCH_SIZE >= users_np.shape[0]:
break
prediction = deepcf(
to_variable(users_np[slice:slice + BATCH_SIZE]),
to_variable(items_np[slice:slice + BATCH_SIZE]))
loss = fluid.layers.reduce_sum(
fluid.layers.log_loss(prediction,
to_variable(labels_np[
slice:slice + BATCH_SIZE])))
loss._backward()
adam.minimize(loss)
deepcf.clear_gradients()
dy_loss = loss._numpy()
sys.stderr.write('dynamic loss: %s %s\n' % (slice, dy_loss))
self.assertEqual(static_loss, dy_loss)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册