对fluid.data的输入先进行concat、reshape等预处理再进行embedding时,反向传播报错
Created by: heygrain
1)PaddlePaddle版本:1.6 2)GPU:GTX1070单卡
错误信息:
Traceback (most recent call last):
File "/home/grain/.vscode/extensions/ms-python.python-2019.11.50794/pythonFiles/ptvsd_launcher.py", line 43, in <module>
main(ptvsdArgs)
File "/home/grain/.vscode/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 432, in main
run()
File "/home/grain/.vscode/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 316, in run_file
runpy.run_path(target, run_name='__main__')
File "/usr/lib/python3.6/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/usr/lib/python3.6/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/media/grain/DATA/Grain_JGY/WorkSpace/graph_cross/tmp/tmp.py", line 98, in <module>
optimizer.minimize(model.loss)
File "</home/grain/ml/lib/python3.6/site-packages/decorator.py:decorator-gen-36>", line 2, in minimize
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/wrapped_decorator.py", line 25, in __impl__
return wrapped_func(*args, **kwargs)
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/dygraph/base.py", line 78, in __impl__
return func(*args, **kwargs)
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/optimizer.py", line 678, in minimize
no_grad_set=no_grad_set)
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/optimizer.py", line 551, in backward
no_grad_set, callbacks)
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/backward.py", line 1085, in append_backward
_append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/backward.py", line 891, in _append_backward_vars_
op_desc.infer_shape(block.desc)
paddle.fluid.core_avx.EnforceNotMet:
--------------------------------------------
C++ Call Stacks (More useful to developers):
--------------------------------------------
0 std::string paddle::platform::GetTraceBackString<std::string const&>(std::string const&, char const*, int)
1 paddle::platform::EnforceNotMet::EnforceNotMet(std::string const&, char const*, int)
2 paddle::operators::Reshape2GradOp::InferShape(paddle::framework::InferShapeContext*) const
3 paddle::framework::OpDesc::InferShape(paddle::framework::BlockDesc const&) const
------------------------------------------
Python Call Stacks (More useful to users):
------------------------------------------
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/framework.py", line 2488, in append_op
attrs=kwargs.get("attrs", None))
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/layer_helper.py", line 43, in append_op
return self.main_program.current_block().append_op(*args, **kwargs)
File "/home/grain/ml/lib/python3.6/site-packages/paddle/fluid/layers/nn.py", line 9022, in reshape
"XShape": x_shape})
File "/media/grain/DATA/Grain_JGY/WorkSpace/graph_cross/tmp/tmp.py", line 72, in __init__
concat_users = layers.reshape(layers.concat([self.users] * (n_negsamples + 1)), [-1, 1], inplace=False)
File "/media/grain/DATA/Grain_JGY/WorkSpace/graph_cross/tmp/tmp.py", line 88, in <module>
model = GraphCross(user_id_range=data['user_id_range'], item_id_range=data['item_id_range'], n_negsamples=3)
File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/usr/lib/python3.6/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/usr/lib/python3.6/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/home/grain/.vscode/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 316, in run_file
runpy.run_path(target, run_name='__main__')
File "/home/grain/.vscode/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 432, in main
run()
File "/home/grain/.vscode/extensions/ms-python.python-2019.11.50794/pythonFiles/ptvsd_launcher.py", line 43, in <module>
main(ptvsdArgs)
----------------------
Error Message Summary:
----------------------
Error: Input(Out@GRAD) shouldn't be null.
[Hint: Expected ctx->HasInput(framework::GradVarName("Out")) == true, but received ctx->HasInput(framework::GradVarName("Out")):0 != true:1.] at (/paddle/paddle/fluid/operators/reshape_op.cc:470)
[operator < reshape2_grad > error]
可供调试的代码:
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from scipy.sparse import coo_matrix
IS_SPARSE = True
# ===========这个函数是创建并加载一个调试用的模拟数据集==========
def load_toy_data():
item_demo = np.array([8, 8, 8, 8]).reshape((-1, 1))
train_user = np.array([0, 1, 2, 3, 0, 1, 2, 3]).reshape((-1, 1))
train_item = np.array([4, 5, 6, 7, 5, 4, 7, 6]).reshape((-1, 1))
data = dict()
data['train_user'] = train_user
data['train_item'] = train_item
data['item_demo'] = item_demo
data['user_demo'] = None
data['item_idx_min'] = 4
data['item_idx_max'] = 7
data['user_id_range'] = [0, 3]
data['item_id_range'] = [4, 7]
n_fields = 3
n_features = 9
# user-item graph (only use training interactions)
adj = np.hstack([data['train_user'], data['train_item']])
# user-demo graph
if data['user_demo'] is not None:
adj = np.vstack([adj, np.hstack([np.arange(0, data['item_idx_min']).reshape((-1, 1)), data['item_demo']])])
# item-demo graph
if data['item_demo'] is not None:
adj = np.vstack([adj, np.hstack([np.arange(data['item_idx_min'], data['item_idx_max'] + 1).reshape((-1, 1)), data['item_demo']])])
adj_mat = coo_matrix((np.ones(len(adj)), (adj[:, 0], adj[:, 1])), shape=(n_features, n_features))
adj_mat += adj_mat.transpose()
data['adj_mat'] = adj_mat
print('toy dataset:')
for key, val in data.items():
if (val is not None) and not (key in ['item_idx_min', 'item_idx_max', 'user_id_range', 'item_id_range']):
print('\t', key, ':', val.shape)
print('\tn_fields:', n_fields)
print('\tn_nodes/feats:', n_features)
return data, n_fields, n_features
class GraphCross(object):
def __init__(self, user_id_range, item_id_range, n_negsamples):
# ===============================model config===============================
self.user_id_range = user_id_range
self.item_id_range = item_id_range
self.n_users = user_id_range[1] - user_id_range[0] + 1
self.n_items = item_id_range[1] - item_id_range[0] + 1
# =================================inputs===============================
self.users = fluid.data(name='users', shape=[-1, 1], dtype='int64')
self.pos_items = fluid.data(name='pos_items', shape=[-1, 1], dtype='int64')
self.neg_items = fluid.data(name='neg_items', shape=[-1, n_negsamples], dtype='int64')
batch_size = layers.shape(self.users)[0]
self.batch_size = batch_size
# ===============================negative samples transfer===============================
concat_items = layers.reshape(layers.concat([self.pos_items, self.neg_items], axis=1), [-1, 1], inplace=False)
concat_items = layers.elementwise_sub(concat_items, layers.fill_constant(shape=layers.shape(concat_items), dtype='int64', value=self.item_id_range[0]))
concat_users = layers.reshape(layers.concat([self.users] * (n_negsamples + 1)), [-1, 1], inplace=False)
u_vec = fluid.embedding(concat_users, [self.n_users, 5], is_sparse=IS_SPARSE) # (N, embedding_dim=5)
i_vec = fluid.embedding(concat_items, [self.n_items, 5], is_sparse=IS_SPARSE)
scores = layers.reduce_sum(layers.elementwise_mul(u_vec, i_vec), dim=1) # of shape (N, )
scores = layers.reshape(scores, [batch_size, -1], inplace=False) # of shape (batch_size, (n_negsamples + 1))
# ===============================loss and metrics===============================
bpr = layers.bpr_loss(input=scores, label=layers.fill_constant(shape=[batch_size, 1], value=0, dtype='int64'))
self.loss = layers.mean(bpr)
# load data
data, n_fields, n_features = load_toy_data()
# create model
model = GraphCross(user_id_range=data['user_id_range'], item_id_range=data['item_id_range'], n_negsamples=3)
# test program
test_program = fluid.default_main_program().clone(for_test=True)
# creat optimizer
optimizer = fluid.optimizer.Adam(
learning_rate=0.001,
regularization=fluid.regularizer.L2DecayRegularizer(1e-5)
)
optimizer.minimize(model.loss)
# use cuda or not, and startup
use_cuda = True
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
train_user, train_item = data['train_user'], data['train_item']
neg_item = np.random.randint(low=data['item_idx_min'], high=data['item_idx_max'] + 1, size=[len(train_user), 3]) # [n_users, n_negsamples]
output = exe.run(feed={model.users.name: train_user.astype(np.int64),
model.pos_items.name: train_item.astype(np.int64),
model.neg_items.name: neg_item.astype(np.int64)},
fetch_list=[model.loss])
问题所在:
优化器那行,即:
optimizer.minimize(model.loss)
注释掉之后,可以跑通得到model.loss,
但是:
有optimizer,产生反向传播就不行。
应该是对输入self.users
、self.pos_item
、self.neg_item
进行concat、reshape等操作时导致了梯度计算的异常,请问应该怎么处理?
换句话说: 在自然语言处理(或者推荐系统)等场景,假设我的输入是词的id(或者用户和商品id),这时候我要先对fluid.data接受的id矩阵(dtype='int64')做一些concat和reshape的操作,然后再拿预处理完的id矩阵去做嵌入,也就是fluid.embedding。 反向传播梯度的问题就出在一开始预处理id的concat和reshape上面,实际上我并不需要梯度反向传播到这里,因为相当于预处理完了的数据才是真正的输入。所以,有适合用的api吗,不然我只能在cpu里用numpy先concat和reshape完了再feed模型了。。。