提交 08b73b68 编写于 作者: Z zenghsh3

fix bug of sampling_id_op

上级 823c4f87
...@@ -53,8 +53,13 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -53,8 +53,13 @@ class SamplingIdKernel : public framework::OpKernel<T> {
static_cast<T>(context.Attr<float>("min")), static_cast<T>(context.Attr<float>("min")),
static_cast<T>(context.Attr<float>("max"))); static_cast<T>(context.Attr<float>("max")));
<<<<<<< HEAD
std::vector<int64_t> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
=======
std::vector<T> ids(batch_size); std::vector<T> ids(batch_size);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
>>>>>>> 823c4f87beff04e4029e3f4a183658621ca8f01b
T r = dist(engine); T r = dist(engine);
int idx = width - 1; int idx = width - 1;
for (int j = 0; j < width; ++j) { for (int j = 0; j < width; ++j) {
...@@ -63,7 +68,11 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -63,7 +68,11 @@ class SamplingIdKernel : public framework::OpKernel<T> {
break; break;
} }
} }
<<<<<<< HEAD
ids[i] = int64_t(idx);
=======
ids[i] = ins_vector[idx]; ids[i] = ins_vector[idx];
>>>>>>> 823c4f87beff04e4029e3f4a183658621ca8f01b
} }
std::vector<int64_t> out_dim; std::vector<int64_t> out_dim;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册