Unverified commit 672def6c authored by Weilong Wu, committed by GitHub

Support nce in eager mode (#39589)

Parent 5b5656d0
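
For context, a minimal usage sketch of what this change enables: running the fluid.dygraph NCE layer under eager mode. It mirrors the tests below; the sizes, seed, and uniform sampler are illustrative choices rather than values from this PR, and it assumes the legacy paddle.fluid API that this branch targets.

# Minimal sketch; toy shapes and values, not taken from the diff.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import base, nn

dict_size, window_size, label_word = 20, 5, 2
inp_word = np.array([[1], [2], [3], [4], [5]]).astype('int64')

with fluid.dygraph.guard():
    emb = nn.Embedding(size=[dict_size, 32], is_sparse=False)
    words = [base.to_variable(inp_word[i]) for i in range(window_size)]
    # Concatenate the context-word embeddings, skipping the label word.
    embs = fluid.layers.concat(
        [emb(words[i]) for i in range(window_size) if i != label_word], axis=1)
    nce = nn.NCE(num_total_classes=dict_size,
                 dim=embs.shape[1],
                 num_neg_samples=2,
                 sampler="uniform",
                 seed=1)
    label = fluid.layers.unsqueeze(words[label_word], axes=[0])
    loss = nce(embs, label)  # hits the new _C_ops.nce path under eager mode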
...@@ -84,6 +84,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
    {"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
    {"graph_khop_sampler", {"Row", "Eids", "Col_Ptr", "X"}},
    {"nce",
     {"Input", "Label", "Weight", "Bias", "SampleWeight", "CustomDistProbs",
      "CustomDistAlias", "CustomDistAliasProbs"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......
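
For readers new to this file: op_ins_map tells Paddle's Python-C binding generator which named inputs each op accepts and in what order, so the generated _C_ops.nce call in the layer change below can pass them positionally. A pure-Python mock of that mapping, illustrative only; the real binding is generated C++ code.

# Mock of how a generated binding could split positional args into the named
# inputs declared in op_ins_map plus trailing flat (name, value, ...) attrs.
# Illustrative only; not the real generated code.
NCE_INPUTS = ("Input", "Label", "Weight", "Bias", "SampleWeight",
              "CustomDistProbs", "CustomDistAlias", "CustomDistAliasProbs")

def mock_nce_binding(*args):
    inputs = dict(zip(NCE_INPUTS, args[:len(NCE_INPUTS)]))
    flat_attrs = args[len(NCE_INPUTS):]
    attrs = dict(zip(flat_attrs[0::2], flat_attrs[1::2]))
    return inputs, attrs

inputs, attrs = mock_nce_binding('x', 'y', 'w', 'b', None, None, None, None,
                                 'num_neg_samples', 2)
assert inputs["Label"] == 'y' and attrs["num_neg_samples"] == 2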
...@@ -2233,6 +2233,19 @@ class NCE(layers.Layer):
        self._inputs['Weight'] = self.weight

    def forward(self, input, label, sample_weight=None):
        if in_dygraph_mode():
            attrs = ('num_total_classes', self._attrs['num_total_classes'],
                     'num_neg_samples', self._attrs['num_neg_samples'], 'seed',
                     self._attrs['seed'], 'sampler', self._attrs['sampler'],
                     'is_sparse', self._attrs['is_sparse'], 'remote_prefetch',
                     self._attrs['remote_prefetch'])
            cost, _, _ = _C_ops.nce(
                input, label, self.weight, self.bias,
                self._inputs['SampleWeight'], self._inputs['CustomDistProbs'],
                self._inputs['CustomDistAlias'],
                self._inputs['CustomDistAliasProbs'], *attrs)
            return cost / (self._num_neg_samples + 1)
        check_variable_and_dtype(input, "input", ['float32', 'float64'], "NCE")
        check_variable_and_dtype(label, "label", ['int64'], "NCE")
        check_type(sample_weight, 'sample_weight', (Variable, type(None)),
......
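
Two details of the eager branch above are worth noting: _C_ops.nce returns the op's three outputs (Cost, SampleLogits, SampleLabels) and only the cost is kept, and the attrs travel as one flat (name, value, name, value, ...) tuple rather than a dict. The division by num_neg_samples + 1 (one true class plus the sampled negatives per example) mirrors what the pre-existing non-eager path returns. A toy numeric sketch of that normalization, with invented values:

# Toy values, not from the PR: forward() reports Cost averaged over the
# num_neg_samples + 1 classes scored per example.
import numpy as np

num_neg_samples = 2
cost = np.array([[0.9], [1.2]])      # stand-in for the nce op's Cost output
loss = cost / (num_neg_samples + 1)  # what forward() returns
assert np.allclose(loss, [[0.3], [0.4]])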
...@@ -1361,6 +1361,7 @@ class TestLayer(LayerTest):
                feed_dict['word_{0}'.format(i)] = inp_word[i]
            static_rlt = self.get_static_graph_result(
                feed=feed_dict, fetch_list=[nce_loss])[0]

        with self.static_graph():
            words = []
            for i in range(window_size):
...@@ -1401,7 +1402,41 @@ class TestLayer(LayerTest):
                feed=feed_dict, fetch_list=[nce_loss2])[0]

        with self.dynamic_graph():
            with _test_eager_guard():
                words = []
                for i in range(window_size):
                    words.append(base.to_variable(inp_word[i]))
                sample_weights = layers.fill_constant(
                    shape=[5, 1], dtype='float32', value=1)
                emb = nn.Embedding(
                    size=[dict_size, 32],
                    param_attr='eager_emb.w',
                    is_sparse=False)
                embs3 = []
                for i in range(window_size):
                    if i == label_word:
                        continue
                    emb_rlt = emb(words[i])
                    embs3.append(emb_rlt)
                embs3 = layers.concat(
                    input=embs3, axis=fluid.dygraph.to_variable(np.array([1])))
                nce = nn.NCE(num_total_classes=dict_size,
                             dim=embs3.shape[1],
                             num_neg_samples=2,
                             sampler="custom_dist",
                             custom_dist=nid_freq_arr.tolist(),
                             seed=seed,
                             param_attr='eager_nce.w',
                             bias_attr='eager_nce.b',
                             sample_weight=sample_weights)
                wl = fluid.layers.unsqueeze(words[label_word], axes=[0])
                dy_eager_rlt = nce(embs3, wl)
                dy_eager_rlt_value = dy_eager_rlt.numpy()
            words = []
            for i in range(window_size):
                words.append(base.to_variable(inp_word[i]))
...@@ -1436,9 +1471,75 @@ class TestLayer(LayerTest):
        self.assertTrue(np.allclose(static_rlt2, static_rlt))
        self.assertTrue(np.allclose(dy_rlt_value, static_rlt))
        self.assertTrue(np.allclose(dy_eager_rlt_value, static_rlt))

        with self.dynamic_graph():
            with _test_eager_guard():
                custom_weight = np.random.randn(dict_size,
                                                128).astype("float32")
                weight_attr = fluid.ParamAttr(
                    initializer=fluid.initializer.NumpyArrayInitializer(
                        custom_weight))
                words = []
                for i in range(window_size):
                    words.append(base.to_variable(inp_word[i]))
                sample_weights = layers.fill_constant(
                    shape=fluid.dygraph.to_variable(np.array([5, 1])),
                    dtype='float32',
                    value=1)
                emb = nn.Embedding(
                    size=[dict_size, 32],
                    param_attr='eager_emb.w',
                    is_sparse=False)
                embs3 = []
                for i in range(window_size):
                    if i == label_word:
                        continue
                    emb_rlt = emb(words[i])
                    embs3.append(emb_rlt)
                embs3 = layers.concat(input=embs3, axis=1)
                nce1 = nn.NCE(num_total_classes=dict_size,
                              dim=embs3.shape[1],
                              num_neg_samples=2,
                              sampler="custom_dist",
                              custom_dist=nid_freq_arr.tolist(),
                              seed=seed,
                              param_attr='eager_nce1.w',
                              bias_attr='eager_nce1.b',
                              sample_weight=sample_weights)
                nce2 = nn.NCE(num_total_classes=dict_size,
                              dim=embs3.shape[1],
                              num_neg_samples=2,
                              sampler="custom_dist",
                              custom_dist=nid_freq_arr.tolist(),
                              seed=seed,
                              param_attr=weight_attr,
                              bias_attr='eager_nce2.b',
                              sample_weight=sample_weights)
                wl = fluid.layers.unsqueeze(words[label_word], axes=[0])
                nce1_loss = nce1(embs3, wl)
                nce2_loss = nce2(embs3, wl)
                self.assertFalse(
                    np.array_equal(nce1_loss.numpy(), nce2_loss.numpy()))
                nce2.weight.set_value(nce1.weight.numpy())
                nce2.bias.set_value(nce1.bias)
                nce1_loss = nce1(embs3, wl)
                nce2_loss = nce2(embs3, wl)
                self.assertTrue(
                    np.array_equal(nce1_loss.numpy(), nce2_loss.numpy()))
                nce2.weight = nce1.weight
                nce2.bias = nce1.bias
                self.assertTrue(
                    np.array_equal(nce1.weight.numpy(), nce2.weight.numpy()))
                self.assertTrue(
                    np.array_equal(nce1.bias.numpy(), nce2.bias.numpy()))
            custom_weight = np.random.randn(dict_size, 128).astype("float32")
            weight_attr = fluid.ParamAttr(
                initializer=fluid.initializer.NumpyArrayInitializer(
......
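
The second eager-guard test above exercises two different parameter-sharing idioms: Parameter.set_value copies values into nce2's own parameter, while attribute assignment (nce2.weight = nce1.weight) makes both layers reference the same parameter object. A numpy-only sketch of that distinction follows; the analogy to Paddle's Parameter semantics is inferred from the test's assertions, not stated in this diff.

# Copy-by-value vs. rebinding, reduced to plain numpy arrays.
import numpy as np

w1 = np.ones(3)
w2 = np.zeros(3)

w2[:] = w1   # like nce2.weight.set_value(nce1.weight.numpy()): values copied
w1 += 1      # later in-place updates to w1 do not reach w2...
assert not np.array_equal(w1, w2)

w2 = w1      # like nce2.weight = nce1.weight: two names, one object
w1 += 1
assert np.array_equal(w1, w2)  # ...while rebinding keeps them in lockstep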