提交 47c02b5c 编写于 作者: F fengjiayi

Add unit tests

上级 12619fcf
...@@ -113,5 +113,4 @@ REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad); ...@@ -113,5 +113,4 @@ REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>, REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>,
ops::BilinearInterpKernel<uint8_t>); ops::BilinearInterpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
ops::BilinearInterpGradKernel<float>, ops::BilinearInterpGradKernel<float>);
ops::BilinearInterpGradKernel<uint8_t>);
...@@ -72,10 +72,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> { ...@@ -72,10 +72,10 @@ class BilinearInterpKernel : public framework::OpKernel<T> {
for (int c = 0; c < channels; ++c) { // loop for channels for (int c = 0; c < channels; ++c) { // loop for channels
// bilinear interpolation // bilinear interpolation
out_pos[0] = out_pos[0] = static_cast<T>(
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
h1lambda * (w2lambda * in_pos[hid * in_w] + h1lambda * (w2lambda * in_pos[hid * in_w] +
w1lambda * in_pos[hid * in_w + wid]); w1lambda * in_pos[hid * in_w + wid]));
in_pos += in_hw; in_pos += in_hw;
out_pos += out_hw; out_pos += out_hw;
} }
...@@ -143,10 +143,12 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> { ...@@ -143,10 +143,12 @@ class BilinearInterpGradKernel : public framework::OpKernel<T> {
const T* out_pos = &d_output[k * out_chw + i * out_w + j]; const T* out_pos = &d_output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels for (int c = 0; c < channels; ++c) { // loop for channels
in_pos[0] += h2lambda * w2lambda * out_pos[0]; in_pos[0] += static_cast<T>(h2lambda * w2lambda * out_pos[0]);
in_pos[wid] += h2lambda * w1lambda * out_pos[0]; in_pos[wid] += static_cast<T>(h2lambda * w1lambda * out_pos[0]);
in_pos[hid * in_w] += h1lambda * w2lambda * out_pos[0]; in_pos[hid * in_w] +=
in_pos[hid * in_w + wid] += h1lambda * w1lambda * out_pos[0]; static_cast<T>(h1lambda * w2lambda * out_pos[0]);
in_pos[hid * in_w + wid] +=
static_cast<T>(h1lambda * w1lambda * out_pos[0]);
in_pos += in_hw; in_pos += in_hw;
out_pos += out_hw; out_pos += out_hw;
} }
......
...@@ -97,7 +97,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -97,7 +97,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
inline pybind11::buffer_info CastToPyBuffer(const framework::Tensor &tensor) { inline pybind11::buffer_info CastToPyBuffer(const framework::Tensor &tensor) {
auto buffer_info = auto buffer_info =
details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool, details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool,
platform::float16>()(tensor); uint8_t, platform::float16>()(tensor);
return buffer_info; return buffer_info;
} }
......
...@@ -45,9 +45,9 @@ def bilinear_interp_np(input, out_h, out_w, out_size): ...@@ -45,9 +45,9 @@ def bilinear_interp_np(input, out_h, out_w, out_size):
out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] + out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] +
w1lambda*input[:, :, h, w+wid]) + \ w1lambda*input[:, :, h, w+wid]) + \
h1lambda*(w2lambda*input[:, :, h+hid, w] + h1lambda*(w2lambda*input[:, :, h+hid, w] +
w1lambda*input[:, :, h+hid, w+wid]) w1lambda*input[:, :, h+hid, w+wid])
return out.astype("float32") return out.astype(input.dtype)
class TestBilinearInterpOp(OpTest): class TestBilinearInterpOp(OpTest):
...@@ -122,5 +122,44 @@ class TestCase6(TestBilinearInterpOp): ...@@ -122,5 +122,44 @@ class TestCase6(TestBilinearInterpOp):
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
class TestBilinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.init_test_case()
self.op_type = "bilinear_interp"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
self.out_size)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output(atol=1)
def init_test_case(self):
self.input_shape = [1, 3, 9, 6]
self.out_h = 10
self.out_w = 9
class TestCase1Uint8(TestBilinearInterpOpUint8):
def init_test_case(self):
self.input_shape = [2, 3, 128, 64]
self.out_h = 120
self.out_w = 50
class TestCase2Uint8(TestBilinearInterpOpUint8):
def init_test_case(self):
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.out_size = np.array([6, 15]).astype("int32")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册