未验证 提交 5cfc40de 编写于 作者: T tangwei12 提交者: GitHub

nce add check sample lables, test=develop (#15463)

* nce add check sample lables, test=develop
上级 af07118d
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sampler.h"
#include <glog/logging.h>
#include <iostream>
#include <queue>
#include <utility>
......@@ -77,7 +78,14 @@ int64_t CustomSampler::Sample() const {
auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_);
if (p > alias_probs_[index]) {
return alias_[index];
int alias = alias_[index];
if (alias == exceptional_val) {
LOG(WARNING) << "WARNING: CustomSampler get alias " << exceptional_val;
return index;
}
return alias;
} else {
return index;
}
......
......@@ -116,6 +116,7 @@ class CustomSampler : public Sampler {
const float* alias_probs_;
const int* alias_;
const float* probs_;
const int exceptional_val = -1;
std::shared_ptr<std::mt19937> random_engine_;
std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
......
......@@ -119,6 +119,11 @@ class NCEKernel : public framework::OpKernel<T> {
PrepareSamples<DeviceContext, T>(context, sampler);
auto sample_labels = context.Output<Tensor>("SampleLabels");
const int64_t *sample_labels_data = sample_labels->data<int64_t>();
for (int x = 0; x < sample_labels->numel(); x++) {
PADDLE_ENFORCE_GE(sample_labels_data[x], 0, "nce sample label %d", x);
}
auto sample_out = context.Output<Tensor>("SampleLogits");
T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
auto label = context.Input<Tensor>("Label");
......
......@@ -5146,9 +5146,9 @@ def nce(input,
littles = []
for i in range(custom_dist_len):
normal_prob = custom_dist[i] * custom_dist_len
if normal_prob - 1.0 > 1e-4:
if normal_prob - 1.0 > 0:
bigs.append((i, normal_prob))
elif 1.0 - normal_prob > 1e-4:
elif 1.0 - normal_prob > 0:
littles.append((i, normal_prob))
else:
alias_probs_[i] = normal_prob
......@@ -5164,9 +5164,9 @@ def nce(input,
alias_probs_[little[0]] = little[1]
alias_[little[0]] = big_idx
big_left = big[1] + little[1] - 1
if big_left - 1.0 > 1e-4:
if big_left - 1.0 > 0:
bigs.append((big_idx, big_left))
elif 1.0 - big_left > 1e-4:
elif 1.0 - big_left > 0:
littles.append((big_idx, big_left))
else:
alias_probs_[big_idx] = big_left
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册