提交 4cd504d3 编写于 作者: T tangwei12

bug fix

上级 da2cc99f
...@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <random> #include <random>
#include <sstream>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -34,17 +38,17 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -34,17 +38,17 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<T> ins_vector; std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector); framework::TensorToVector(*input, context.device_context(), &ins_vector);
std::vector<int> ids(batch_size); std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
double r = this->get_rand(); double r = this->get_rand();
int id = width - 1; int idx = width - 1;
for (int j = 0; j < width; ++j) { for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) { if ((r -= ins_vector[i * width + j]) < 0) {
id = j; idx = j;
break; break;
} }
} }
ids[i] = id; ids[i] = ins_vector[i * width + idx];
} }
std::vector<int64_t> out_dim; std::vector<int64_t> out_dim;
......
...@@ -25,17 +25,18 @@ class TestSamplingIdOp(OpTest): ...@@ -25,17 +25,18 @@ class TestSamplingIdOp(OpTest):
self.op_type = "sampling_id" self.op_type = "sampling_id"
self.use_mkldnn = False self.use_mkldnn = False
self.init_kernel_type() self.init_kernel_type()
X = np.random.random((3, 4)).astype('float32') self.X = np.random.random((8, 4)).astype('float32')
self.inputs = {"X": X} self.inputs = {"X": self.X}
Y = np.random.random(3).astype('float32') self.Y = np.random.random(8).astype('float32')
self.outputs = {'Out': Y} self.outputs = {'Out': self.Y}
self.attrs = {'use_mkldnn': self.use_mkldnn} self.attrs = {'use_mkldnn': self.use_mkldnn}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output_customized(self.verify_output)
def test_check_grad(self): def verify_output(self, outs):
self.check_grad(['X'], 'Out') out = np.array(outs[0])
self.assertEqual(len(out), len(self.Y))
def init_kernel_type(self): def init_kernel_type(self):
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册