import unittest

import numpy as np

from op_test_util import OpTestMeta


def stable_softmax(x, axis=None):
    """Compute the softmax of ``x`` in a numerically stable way.

    The maximum is subtracted before exponentiating so that ``np.exp`` does
    not overflow for large scores; softmax is invariant to a constant shift,
    so the result is unchanged.

    Args:
        x: array-like of scores.
        axis: axis (or tuple of axes) to normalize over. The default ``None``
            normalizes over all elements, which is the softmax of a vector
            for 1-D input. Pass e.g. ``axis=1`` to softmax each row of a 2-D
            array independently.

    Returns:
        ndarray with the same shape as ``x`` whose entries along ``axis``
        sum to 1.
    """
    # keepdims=True keeps the reduced axes as size-1 dims so the max and
    # the sum broadcast back against x for any axis choice.
    shiftx = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shiftx)
    return exps / np.sum(exps, axis=axis, keepdims=True)


def _with_metaclass(meta, *bases):
    """Return a temporary base that makes the subclass be built by ``meta``.

    A class-body ``__metaclass__`` attribute is only honored by Python 2, and
    the ``metaclass=`` keyword is a syntax error there. Subclassing the
    object returned here works on both: the temporary metaclass intercepts
    creation of the real subclass and delegates it to ``meta`` with the
    intended ``bases``. As a result, ``meta`` runs once, for the real class
    only, and the temporary class does not appear in its MRO.
    """

    class _Metaclass(type):
        def __new__(mcs, name, _temporary_bases, namespace):
            return meta(name, bases, namespace)

    return type.__new__(_Metaclass, "_TemporaryClass", (), {})


class TestSoftmaxOp(_with_metaclass(OpTestMeta, unittest.TestCase)):
    """Declares the ``softmax`` operator test case and its NumPy reference.

    ``setUp`` only records the op type, the input ``X`` and the expected
    output ``Y``. Presumably ``OpTestMeta`` (defined in ``op_test_util``,
    not shown here) generates the test methods that consume these
    attributes -- confirm against that module.
    """

    def setUp(self):
        self.type = "softmax"
        # 32 rows of 100 scores; the reference is the row-wise softmax, so
        # every row of Y sums to 1.
        self.X = np.random.random((32, 100)).astype("float32")
        self.Y = stable_softmax(self.X, axis=1)


if __name__ == "__main__":
    unittest.main()