提交 4cd504d3 编写于 作者: T tangwei12

bug fix

上级 da2cc99f
......@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <random>
#include <sstream>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -34,17 +38,17 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector);
std::vector<int> ids(batch_size);
std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
double r = this->get_rand();
int id = width - 1;
int idx = width - 1;
for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) {
id = j;
idx = j;
break;
}
}
ids[i] = id;
ids[i] = ins_vector[i * width + idx];
}
std::vector<int64_t> out_dim;
......
......@@ -25,17 +25,18 @@ class TestSamplingIdOp(OpTest):
self.op_type = "sampling_id"
self.use_mkldnn = False
self.init_kernel_type()
X = np.random.random((3, 4)).astype('float32')
self.inputs = {"X": X}
Y = np.random.random(3).astype('float32')
self.outputs = {'Out': Y}
self.X = np.random.random((8, 4)).astype('float32')
self.inputs = {"X": self.X}
self.Y = np.random.random(8).astype('float32')
self.outputs = {'Out': self.Y}
self.attrs = {'use_mkldnn': self.use_mkldnn}
def test_check_output(self):
self.check_output()
self.check_output_customized(self.verify_output)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def verify_output(self, outs):
out = np.array(outs[0])
self.assertEqual(len(out), len(self.Y))
def init_kernel_type(self):
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册