未验证 提交 5cfc40de 编写于 作者: T tangwei12 提交者: GitHub

nce add check sample lables, test=develop (#15463)

* nce add check sample lables, test=develop
上级 af07118d
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/operators/math/sampler.h"
#include <glog/logging.h>
#include <iostream> #include <iostream>
#include <queue> #include <queue>
#include <utility> #include <utility>
...@@ -77,7 +78,14 @@ int64_t CustomSampler::Sample() const { ...@@ -77,7 +78,14 @@ int64_t CustomSampler::Sample() const {
auto index = (*int_dist_)(*random_engine_); auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_); auto p = (*real_dist_)(*random_engine_);
if (p > alias_probs_[index]) { if (p > alias_probs_[index]) {
return alias_[index]; int alias = alias_[index];
if (alias == exceptional_val) {
LOG(WARNING) << "WARNING: CustomSampler get alias " << exceptional_val;
return index;
}
return alias;
} else { } else {
return index; return index;
} }
......
...@@ -116,6 +116,7 @@ class CustomSampler : public Sampler { ...@@ -116,6 +116,7 @@ class CustomSampler : public Sampler {
const float* alias_probs_; const float* alias_probs_;
const int* alias_; const int* alias_;
const float* probs_; const float* probs_;
const int exceptional_val = -1;
std::shared_ptr<std::mt19937> random_engine_; std::shared_ptr<std::mt19937> random_engine_;
std::shared_ptr<std::uniform_real_distribution<>> real_dist_; std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
std::shared_ptr<std::uniform_int_distribution<>> int_dist_; std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
......
...@@ -119,6 +119,11 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -119,6 +119,11 @@ class NCEKernel : public framework::OpKernel<T> {
PrepareSamples<DeviceContext, T>(context, sampler); PrepareSamples<DeviceContext, T>(context, sampler);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
const int64_t *sample_labels_data = sample_labels->data<int64_t>(); const int64_t *sample_labels_data = sample_labels->data<int64_t>();
for (int x = 0; x < sample_labels->numel(); x++) {
PADDLE_ENFORCE_GE(sample_labels_data[x], 0, "nce sample label %d", x);
}
auto sample_out = context.Output<Tensor>("SampleLogits"); auto sample_out = context.Output<Tensor>("SampleLogits");
T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace()); T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
......
...@@ -5146,9 +5146,9 @@ def nce(input, ...@@ -5146,9 +5146,9 @@ def nce(input,
littles = [] littles = []
for i in range(custom_dist_len): for i in range(custom_dist_len):
normal_prob = custom_dist[i] * custom_dist_len normal_prob = custom_dist[i] * custom_dist_len
if normal_prob - 1.0 > 1e-4: if normal_prob - 1.0 > 0:
bigs.append((i, normal_prob)) bigs.append((i, normal_prob))
elif 1.0 - normal_prob > 1e-4: elif 1.0 - normal_prob > 0:
littles.append((i, normal_prob)) littles.append((i, normal_prob))
else: else:
alias_probs_[i] = normal_prob alias_probs_[i] = normal_prob
...@@ -5164,9 +5164,9 @@ def nce(input, ...@@ -5164,9 +5164,9 @@ def nce(input,
alias_probs_[little[0]] = little[1] alias_probs_[little[0]] = little[1]
alias_[little[0]] = big_idx alias_[little[0]] = big_idx
big_left = big[1] + little[1] - 1 big_left = big[1] + little[1] - 1
if big_left - 1.0 > 1e-4: if big_left - 1.0 > 0:
bigs.append((big_idx, big_left)) bigs.append((big_idx, big_left))
elif 1.0 - big_left > 1e-4: elif 1.0 - big_left > 0:
littles.append((big_idx, big_left)) littles.append((big_idx, big_left))
else: else:
alias_probs_[big_idx] = big_left alias_probs_[big_idx] = big_left
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册