diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index e4df59c5d51c390cf593add0c5562665c91f33f6..5bc2e63757f19c1dc8a7d41fae9621a2816ff31b 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -64,6 +64,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; const T* p_src = src.data(); + // why must be int? const int* p_index = index.data(); T* p_output = output->data(); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a37c04b5ffbaaa073fcfb05fb03caed77652d7e6..b703cc1a343b57bb0ad521700b96b19383505e39 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -235,6 +235,7 @@ PYBIND11_MODULE(core, m) { self.forward_id_ = forward_id; }, py::return_value_policy::reference) + .def_property_readonly("type", &imperative::OpBase::Type) .def_property( "backward_id", [](const imperative::OpBase &self) { return self.backward_id_; }, diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5ac2b50a9967ef2b9a0e891b0bfcc0d77c2791eb..3f71247630faaf9fe5c28c0940b7f26cf5cceb52 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -744,7 +744,7 @@ class Operator(object): if _in_imperative_mode(): if type is None: raise ValueError( - "`type` to initilized an Operator can not be None.") + "`type` to initialized an Operator can not be None.") self.iop = core.OpBase(type) # TODO(minqiyang): remove these lines after we take apart all @@ -906,7 +906,10 @@ class Operator(object): @property def type(self): - return self.desc.type() + if _in_imperative_mode(): + return self.iop.type + else: + return self.desc.type() def input(self, name): """ diff --git a/python/paddle/fluid/imperative/base.py b/python/paddle/fluid/imperative/base.py index d619c09b1bdd704700af219856148524d9d0d8db..097cd2be35b01aced30486b874f202381c4d9962 100644 --- a/python/paddle/fluid/imperative/base.py +++ b/python/paddle/fluid/imperative/base.py @@ -55,7 +55,8 @@ def to_variable(value, block=None, name=None): type=core.VarDesc.VarType.LOD_TENSOR, name=name, shape=value.shape, - dtype=value.dtype) + dtype=value.dtype, + stop_gradient=True) var = py_var._ivar.value() tensor = var.get_tensor() tensor.set(value, framework._current_expected_place()) diff --git a/python/paddle/fluid/imperative/tracer.py b/python/paddle/fluid/imperative/tracer.py index bd77de7424c4547ea71a3f757de37f47b990d616..28c8586813410f7349da7943a966eaa9cc3816d2 100644 --- a/python/paddle/fluid/imperative/tracer.py +++ b/python/paddle/fluid/imperative/tracer.py @@ -62,7 +62,7 @@ class Tracer(core.Tracer): if len(backward_refs) > 0: op.iop.register_backward_hooks(release_op) - # TODO(minqiyang): remove all inputs and outputs after seperate + # TODO(minqiyang): remove all inputs and outputs after separate # var and grad op.backward_refs = defaultdict(list) for k, v in six.iteritems(op.inputs): diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 482dfa6fac05bd914efa384bd0f5ec54cfab1dca..8358bb1aba98d8f5699cbda27e657ba6c470d333 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -212,7 +212,7 @@ class UniformInitializer(Initializer): if self._seed == 0: self._seed = block.program.random_seed - # to be compatible of fp16 initalizers + # to be compatible of fp16 initializers if var.dtype == VarDesc.VarType.FP16: out_dtype = VarDesc.VarType.FP32 out_var = block.create_var( @@ -756,7 +756,7 @@ class NumpyArrayInitializer(Initializer): values = [int(v) for v in self._value.flat] else: raise ValueError("Unsupported dtype %s", self._value.dtype) - if self._value.size > 1024 * 1024 * 5: + if self._value.size > 1024 * 1024 * 1024: raise ValueError("The size of input is too big. Please consider " "saving it to file and 'load_op' to load it") op = block._prepend_op( diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 505d9572a64e8c6f096764c4947a1fa554527e65..c0deb5eaccaefa52271b2c30e9f8b1d339624919 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -165,6 +165,8 @@ class Optimizer(object): name = self._name + "_" + name if (name in self._accumulators and param.name in self._accumulators[name]): + if framework._in_imperative_mode(): + return self._accumulators[name][param.name] raise Exception("Accumulator {} already exists for parameter {}". format(name, param.name)) if shape == None: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py index af80ca6ce77a4ec187dd52863c2fe2ba278d5023..ac123ee8db26ac23bbf9454e399a592a28c91c32 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py @@ -15,6 +15,7 @@ import unittest import numpy as np import random +import os import sys import paddle @@ -23,16 +24,17 @@ import paddle.fluid.core as core from test_imperative_base import new_program_scope from paddle.fluid.imperative.base import to_variable -NUM_USERS = 100 -NUM_ITEMS = 1000 +# Can use Amusic dataset as the DeepCF describes. +DATA_PATH = os.environ.get('DATA_PATH', '') -BATCH_SIZE = 32 -NUM_BATCHES = 2 +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 128)) +NUM_BATCHES = int(os.environ.get('NUM_BATCHES', 5)) +NUM_EPOCHES = int(os.environ.get('NUM_EPOCHES', 1)) -class MLP(fluid.imperative.Layer): +class DMF(fluid.imperative.Layer): def __init__(self, name_scope): - super(MLP, self).__init__(name_scope) + super(DMF, self).__init__(name_scope) self._user_latent = fluid.imperative.FC(self.full_name(), 256) self._item_latent = fluid.imperative.FC(self.full_name(), 256) @@ -61,9 +63,9 @@ class MLP(fluid.imperative.Layer): return fluid.layers.elementwise_mul(users, items) -class DMF(fluid.imperative.Layer): +class MLP(fluid.imperative.Layer): def __init__(self, name_scope): - super(DMF, self).__init__(name_scope) + super(MLP, self).__init__(name_scope) self._user_latent = fluid.imperative.FC(self.full_name(), 256) self._item_latent = fluid.imperative.FC(self.full_name(), 256) self._match_layers = [] @@ -87,21 +89,30 @@ class DMF(fluid.imperative.Layer): class DeepCF(fluid.imperative.Layer): - def __init__(self, name_scope): + def __init__(self, name_scope, num_users, num_items, matrix): super(DeepCF, self).__init__(name_scope) - - self._user_emb = fluid.imperative.Embedding(self.full_name(), - [NUM_USERS, 256]) - self._item_emb = fluid.imperative.Embedding(self.full_name(), - [NUM_ITEMS, 256]) + self._num_users = num_users + self._num_items = num_items + self._rating_matrix = self.create_parameter( + fluid.ParamAttr(trainable=False), + matrix.shape, + matrix.dtype, + is_bias=False, + default_initializer=fluid.initializer.NumpyArrayInitializer(matrix)) + self._rating_matrix._stop_gradient = True self._mlp = MLP(self.full_name()) self._dmf = DMF(self.full_name()) self._match_fc = fluid.imperative.FC(self.full_name(), 1, act='sigmoid') def forward(self, users, items): - users_emb = self._user_emb(users) - items_emb = self._item_emb(items) + # users_emb = self._user_emb(users) + # items_emb = self._item_emb(items) + users_emb = fluid.layers.gather(self._rating_matrix, users) + items_emb = fluid.layers.gather( + fluid.layers.transpose(self._rating_matrix, [1, 0]), items) + users_emb.stop_gradient = True + items_emb.stop_gradient = True mlp_predictive = self._mlp(users_emb, items_emb) dmf_predictive = self._dmf(users_emb, items_emb) @@ -116,27 +127,79 @@ def get_data(): user_ids = [] item_ids = [] labels = [] + NUM_USERS = 100 + NUM_ITEMS = 1000 + matrix = np.zeros([NUM_USERS, NUM_ITEMS], dtype=np.float32) + for uid in range(NUM_USERS): for iid in range(NUM_ITEMS): - # 10% positive - label = float(random.randint(1, 10) == 1) + label = float(random.randint(1, 6) == 1) user_ids.append(uid) item_ids.append(iid) labels.append(label) - indices = np.arange(NUM_USERS * NUM_ITEMS) + matrix[uid, iid] = label + indices = np.arange(len(user_ids)) + np.random.shuffle(indices) + users_np = np.array(user_ids, dtype=np.int32)[indices] + items_np = np.array(item_ids, dtype=np.int32)[indices] + labels_np = np.array(labels, dtype=np.float32)[indices] + return np.expand_dims(users_np, -1), \ + np.expand_dims(items_np, -1), \ + np.expand_dims(labels_np, -1), NUM_USERS, NUM_ITEMS, matrix + + +def load_data(DATA_PATH): + sys.stderr.write('loading from %s\n' % DATA_PATH) + likes = dict() + num_users = -1 + num_items = -1 + with open(DATA_PATH, 'r') as f: + for l in f.readlines(): + uid, iid, rating = [int(v) for v in l.split('\t')] + num_users = max(num_users, uid + 1) + num_items = max(num_items, iid + 1) + if float(rating) > 0.0: + likes[(uid, iid)] = 1.0 + + user_ids = [] + item_ids = [] + labels = [] + matrix = np.zeros([num_users, num_items], dtype=np.float32) + for uid, iid in likes.keys(): + user_ids.append(uid) + item_ids.append(iid) + labels.append(1.0) + matrix[uid, iid] = 1.0 + + negative = 0 + while negative < 3: + nuid = random.randint(0, num_users - 1) + niid = random.randint(0, num_items - 1) + if (nuid, niid) not in likes: + negative += 1 + user_ids.append(nuid) + item_ids.append(niid) + labels.append(0.0) + + indices = np.arange(len(user_ids)) np.random.shuffle(indices) - users_np = np.array(user_ids, dtype=np.int64)[indices] - items_np = np.array(item_ids, dtype=np.int64)[indices] + users_np = np.array(user_ids, dtype=np.int32)[indices] + items_np = np.array(item_ids, dtype=np.int32)[indices] labels_np = np.array(labels, dtype=np.float32)[indices] return np.expand_dims(users_np, -1), \ np.expand_dims(items_np, -1), \ - np.expand_dims(labels_np, -1) + np.expand_dims(labels_np, -1), num_users, num_items, matrix class TestImperativeDeepCF(unittest.TestCase): - def test_gan_float32(self): + def test_deefcf(self): seed = 90 - users_np, items_np, labels_np = get_data() + if DATA_PATH: + (users_np, items_np, labels_np, num_users, num_items, + matrix) = load_data(DATA_PATH) + else: + (users_np, items_np, labels_np, num_users, num_items, + matrix) = get_data() startup = fluid.Program() startup.random_seed = seed @@ -145,11 +208,11 @@ class TestImperativeDeepCF(unittest.TestCase): scope = fluid.core.Scope() with new_program_scope(main=main, startup=startup, scope=scope): - users = fluid.layers.data('users', [1], dtype='int64') - items = fluid.layers.data('items', [1], dtype='int64') + users = fluid.layers.data('users', [1], dtype='int32') + items = fluid.layers.data('items', [1], dtype='int32') labels = fluid.layers.data('labels', [1], dtype='float32') - deepcf = DeepCF('deepcf') + deepcf = DeepCF('deepcf', num_users, num_items, matrix) prediction = deepcf(users, items) loss = fluid.layers.reduce_sum( fluid.layers.log_loss(prediction, labels)) @@ -159,35 +222,44 @@ class TestImperativeDeepCF(unittest.TestCase): exe = fluid.Executor(fluid.CPUPlace( ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) exe.run(startup) - for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE): - static_loss = exe.run( - main, - feed={ - users.name: users_np[slice:slice + BATCH_SIZE], - items.name: items_np[slice:slice + BATCH_SIZE], - labels.name: labels_np[slice:slice + BATCH_SIZE] - }, - fetch_list=[loss])[0] - sys.stderr.write('static loss %s\n' % static_loss) + for e in range(NUM_EPOCHES): + sys.stderr.write('epoch %d\n' % e) + for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE): + if slice + BATCH_SIZE >= users_np.shape[0]: + break + static_loss = exe.run( + main, + feed={ + users.name: users_np[slice:slice + BATCH_SIZE], + items.name: items_np[slice:slice + BATCH_SIZE], + labels.name: labels_np[slice:slice + BATCH_SIZE] + }, + fetch_list=[loss])[0] + sys.stderr.write('static loss %s\n' % static_loss) with fluid.imperative.guard(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed - deepcf = DeepCF('deepcf') - for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE): - prediction = deepcf( - to_variable(users_np[slice:slice + BATCH_SIZE]), - to_variable(items_np[slice:slice + BATCH_SIZE])) - loss = fluid.layers.reduce_sum( - fluid.layers.log_loss(prediction, - to_variable(labels_np[slice:slice + - BATCH_SIZE]))) - loss._backward() - adam = fluid.optimizer.AdamOptimizer(0.01) - adam.minimize(loss) - deepcf.clear_gradients() - dy_loss = loss._numpy() + deepcf = DeepCF('deepcf', num_users, num_items, matrix) + adam = fluid.optimizer.AdamOptimizer(0.01) + for e in range(NUM_EPOCHES): + sys.stderr.write('epoch %d\n' % e) + for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE): + if slice + BATCH_SIZE >= users_np.shape[0]: + break + prediction = deepcf( + to_variable(users_np[slice:slice + BATCH_SIZE]), + to_variable(items_np[slice:slice + BATCH_SIZE])) + loss = fluid.layers.reduce_sum( + fluid.layers.log_loss(prediction, + to_variable(labels_np[ + slice:slice + BATCH_SIZE]))) + loss._backward() + adam.minimize(loss) + deepcf.clear_gradients() + dy_loss = loss._numpy() + sys.stderr.write('dynamic loss: %s %s\n' % (slice, dy_loss)) self.assertEqual(static_loss, dy_loss)