提交 d7495838 编写于 作者: Z zenghsh3

refine

上级 04a05d1d
...@@ -54,7 +54,7 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -54,7 +54,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
static_cast<T>(context.Attr<float>("max"))); static_cast<T>(context.Attr<float>("max")));
std::vector<int64_t> ids(batch_size); std::vector<int64_t> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
T r = dist(engine); T r = dist(engine);
int idx = width - 1; int idx = width - 1;
for (int j = 0; j < width; ++j) { for (int j = 0; j < width; ++j) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册