提交 aa113c33 编写于 作者: P pkuliuliu

remove param name of GradOperation; and change predict logic of detector to avoid Memory Overriding

上级 d99219c2
...@@ -121,8 +121,8 @@ class ErrorBasedDetector(Detector): ...@@ -121,8 +121,8 @@ class ErrorBasedDetector(Detector):
float, the distance between reconstructed and original samples. float, the distance between reconstructed and original samples.
""" """
inputs = check_numpy_param('inputs', inputs) inputs = check_numpy_param('inputs', inputs)
x_trans = self._auto_encoder.predict(Tensor(inputs)) x_trans = self._auto_encoder.predict(Tensor(inputs)).asnumpy()
diff = np.abs(inputs - x_trans.asnumpy()) diff = np.abs(inputs - x_trans)
dims = tuple(np.arange(len(inputs.shape))[1:]) dims = tuple(np.arange(len(inputs.shape))[1:])
marks = np.mean(np.power(diff, 2), axis=dims) marks = np.mean(np.power(diff, 2), axis=dims)
return marks return marks
...@@ -138,10 +138,10 @@ class ErrorBasedDetector(Detector): ...@@ -138,10 +138,10 @@ class ErrorBasedDetector(Detector):
numpy.ndarray, reconstructed images. numpy.ndarray, reconstructed images.
""" """
inputs = check_numpy_param('inputs', inputs) inputs = check_numpy_param('inputs', inputs)
x_trans = self._auto_encoder.predict(Tensor(inputs)) x_trans = self._auto_encoder.predict(Tensor(inputs)).asnumpy()
if self._bounds is not None: if self._bounds is not None:
clip_min, clip_max = self._bounds clip_min, clip_max = self._bounds
x_trans = np.clip(x_trans.asnumpy(), clip_min, clip_max) x_trans = np.clip(x_trans, clip_min, clip_max)
return x_trans return x_trans
def set_threshold(self, threshold): def set_threshold(self, threshold):
...@@ -214,12 +214,12 @@ class DivergenceBasedDetector(ErrorBasedDetector): ...@@ -214,12 +214,12 @@ class DivergenceBasedDetector(ErrorBasedDetector):
""" """
inputs = check_numpy_param('inputs', inputs) inputs = check_numpy_param('inputs', inputs)
x_len = inputs.shape[0] x_len = inputs.shape[0]
x_transformed = self._auto_encoder.predict(Tensor(inputs)) x_transformed = self._auto_encoder.predict(Tensor(inputs)).asnumpy()
x_origin = self._model.predict(Tensor(inputs)) x_origin = self._model.predict(Tensor(inputs)).asnumpy()
x_trans = self._model.predict(x_transformed) x_trans = self._model.predict(Tensor(x_transformed)).asnumpy()
y_pred = softmax(x_origin.asnumpy() / self._t, axis=1) y_pred = softmax(x_origin / self._t, axis=1)
y_trans_pred = softmax(x_trans.asnumpy() / self._t, axis=1) y_trans_pred = softmax(x_trans / self._t, axis=1)
if self._option == 'jsd': if self._option == 'jsd':
marks = [_jsd(y_pred[i], y_trans_pred[i]) for i in range(x_len)] marks = [_jsd(y_pred[i], y_trans_pred[i]) for i in range(x_len)]
......
...@@ -83,10 +83,10 @@ class SpatialSmoothing(Detector): ...@@ -83,10 +83,10 @@ class SpatialSmoothing(Detector):
as positive, i.e. adversarial. as positive, i.e. adversarial.
""" """
inputs = check_numpy_param('inputs', inputs) inputs = check_numpy_param('inputs', inputs)
raw_pred = self._model.predict(Tensor(inputs)) raw_pred = self._model.predict(Tensor(inputs)).asnumpy()
smoothing_pred = self._model.predict(Tensor(self.transform(inputs))) smoothing_pred = self._model.predict(Tensor(self.transform(inputs))).asnumpy()
dist = self._dist(raw_pred.asnumpy(), smoothing_pred.asnumpy()) dist = self._dist(raw_pred, smoothing_pred)
index = int(len(dist)*(1 - self._fpr)) index = int(len(dist)*(1 - self._fpr))
threshold = np.sort(dist, axis=None)[index] threshold = np.sort(dist, axis=None)[index]
self._threshold = threshold self._threshold = threshold
...@@ -104,9 +104,9 @@ class SpatialSmoothing(Detector): ...@@ -104,9 +104,9 @@ class SpatialSmoothing(Detector):
input sample with index i is adversarial. input sample with index i is adversarial.
""" """
inputs = check_numpy_param('inputs', inputs) inputs = check_numpy_param('inputs', inputs)
raw_pred = self._model.predict(Tensor(inputs)) raw_pred = self._model.predict(Tensor(inputs)).asnumpy()
smoothing_pred = self._model.predict(Tensor(self.transform(inputs))) smoothing_pred = self._model.predict(Tensor(self.transform(inputs))).asnumpy()
dist = self._dist(raw_pred.asnumpy(), smoothing_pred.asnumpy()) dist = self._dist(raw_pred, smoothing_pred)
res = [0]*len(dist) res = [0]*len(dist)
for i, elem in enumerate(dist): for i, elem in enumerate(dist):
...@@ -127,9 +127,9 @@ class SpatialSmoothing(Detector): ...@@ -127,9 +127,9 @@ class SpatialSmoothing(Detector):
float, distance. float, distance.
""" """
inputs = check_numpy_param('inputs', inputs) inputs = check_numpy_param('inputs', inputs)
raw_pred = self._model.predict(Tensor(inputs)) raw_pred = self._model.predict(Tensor(inputs)).asnumpy()
smoothing_pred = self._model.predict(Tensor(self.transform(inputs))) smoothing_pred = self._model.predict(Tensor(self.transform(inputs))).asnumpy()
dist = self._dist(raw_pred.asnumpy(), smoothing_pred.asnumpy()) dist = self._dist(raw_pred, smoothing_pred)
return dist return dist
def transform(self, inputs): def transform(self, inputs):
......
...@@ -383,7 +383,7 @@ class _TrainOneStepWithLossScaleCell(Cell): ...@@ -383,7 +383,7 @@ class _TrainOneStepWithLossScaleCell(Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
if context.get_context("device_target") == "GPU": if context.get_context("device_target") == "GPU":
self.gpu_target = True self.gpu_target = True
...@@ -602,7 +602,7 @@ class _TrainOneStepCell(Cell): ...@@ -602,7 +602,7 @@ class _TrainOneStepCell(Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
......
...@@ -108,9 +108,7 @@ class GradWrapWithLoss(Cell): ...@@ -108,9 +108,7 @@ class GradWrapWithLoss(Cell):
def __init__(self, network): def __init__(self, network):
super(GradWrapWithLoss, self).__init__() super(GradWrapWithLoss, self).__init__()
self._grad_all = GradOperation(name="get_all", self._grad_all = GradOperation(get_all=True, sens_param=False)
get_all=True,
sens_param=False)
self._network = network self._network = network
def construct(self, inputs, labels): def construct(self, inputs, labels):
...@@ -149,8 +147,7 @@ class GradWrap(Cell): ...@@ -149,8 +147,7 @@ class GradWrap(Cell):
def __init__(self, network): def __init__(self, network):
super(GradWrap, self).__init__() super(GradWrap, self).__init__()
self.grad = GradOperation(name="grad", get_all=False, self.grad = GradOperation(get_all=False, sens_param=True)
sens_param=True)
self.network = network self.network = network
def construct(self, inputs, weight): def construct(self, inputs, weight):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册