提交 2d7a6528 编写于 作者: W wanghaox

del framework test_maxout_op

上级 25d76bc7
...@@ -37,7 +37,7 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data, ...@@ -37,7 +37,7 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx; (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
T ele = static_cast<T>(-FLT_MAX); T ele = static_cast<T>(-FLT_MAX);
for (int g = 0; g < groups; ++g) { for (int g = 0; g < groups; ++g) {
T x=input_data[data_idx + g * feat_len]; T x = input_data[data_idx + g * feat_len];
ele = ele > x ? ele : x; ele = ele > x ? ele : x;
} }
output_data[i] = ele; output_data[i] = ele;
......
import unittest
import numpy as np
from op_test import OpTest
def maxout_forward_naive(input, groups,num_channels):
s0, s1, s2, s3 = input.shape
return np.ndarray([s0, s1 / groups, groups, s2, s3], \
buffer = input, dtype=input.dtype).max(axis=(2))
class TestMaxOutOp(OpTest):
def setUp(self):
self.op_type = "maxout"
self.init_test_case()
input = np.random.random(self.shape).astype("float32")
output = self.MaxOut_forward_naive(input, self.groups,
self.num_channels).astype("float32")
self.inputs = {'X': input}
self.attrs = {'groups': self.groups, 'num_channels': self.num_channels}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 6, 2, 2]
self.groups=2
self.num_channels=6
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册