MindSpore / mindarmour
Commit b8732097
Authored May 21, 2020 by ZhidanLiu

solve DI [MS][MindArmour][Doc] some example of mindarmour need added and useage is clears
https://gitee.com/mindspore/dashboard?issue_id=I1GSTN

Parent: e585e2b0

Showing 4 changed files with 83 additions and 14 deletions (+83 -14).
mindarmour/attacks/gradient_method.py           +24  -13
mindarmour/detectors/mag_net.py                 +23   -0
mindarmour/fuzzing/model_coverage_metrics.py    +11   -1
mindarmour/utils/util.py                        +25   -0
mindarmour/attacks/gradient_method.py
@@ -47,6 +47,12 @@ class GradientMethod(Attack):
        bounds (tuple): Upper and lower bounds of data, indicating the data range.
            In form of (clip_min, clip_max). Default: None.
        loss_fn (Loss): Loss function for optimization. Default: None.

    Examples:
        >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
        >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
        >>> attack = FastGradientMethod(network)
        >>> adv_x = attack.generate(inputs, labels)
    """

    def __init__(self, network, eps=0.07, alpha=None, bounds=None,
@@ -84,11 +90,6 @@ class GradientMethod(Attack):
        Returns:
            numpy.ndarray, generated adversarial examples.

        Examples:
            >>> adv_x = attack.generate([[0.1, 0.2, 0.6], [0.3, 0, 0.4]],
            >>> [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],[0, , 0, 1, 0, 0, 0, 0, 0, 0,
            >>> 0]])
        """
        inputs, labels = check_pair_numpy_param('inputs', inputs, 'labels', labels)
@@ -154,7 +155,10 @@ class FastGradientMethod(GradientMethod):
        loss_fn (Loss): Loss function for optimization. Default: None.

    Examples:
        >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
        >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
        >>> attack = FastGradientMethod(network)
        >>> adv_x = attack.generate(inputs, labels)
    """

    def __init__(self, network, eps=0.07, alpha=None, bounds=(0.0, 1.0),
@@ -178,10 +182,6 @@ class FastGradientMethod(GradientMethod):
        Returns:
            numpy.ndarray, gradient of inputs.

        Examples:
            >>> grad = self._gradient([[0.2, 0.3, 0.4]],
            >>> [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0])
        """
        out_grad = self._grad_all(Tensor(inputs), Tensor(labels))
        if isinstance(out_grad, tuple):
@@ -219,7 +219,10 @@ class RandomFastGradientMethod(FastGradientMethod):
        ValueError: eps is smaller than alpha!

    Examples:
        >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
        >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
        >>> attack = RandomFastGradientMethod(network)
        >>> adv_x = attack.generate(inputs, labels)
    """

    def __init__(self, network, eps=0.07, alpha=0.035, bounds=(0.0, 1.0),
@@ -257,7 +260,10 @@ class FastGradientSignMethod(GradientMethod):
        loss_fn (Loss): Loss function for optimization. Default: None.

    Examples:
        >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
        >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
        >>> attack = FastGradientSignMethod(network)
        >>> adv_x = attack.generate(inputs, labels)
    """

    def __init__(self, network, eps=0.07, alpha=None, bounds=(0.0, 1.0),
@@ -280,10 +286,6 @@ class FastGradientSignMethod(GradientMethod):
        Returns:
            numpy.ndarray, gradient of inputs.

        Examples:
            >>> grad = self._gradient([[0.2, 0.3, 0.4]],
            >>> [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0])
        """
        out_grad = self._grad_all(Tensor(inputs), Tensor(labels))
        if isinstance(out_grad, tuple):
@@ -318,7 +320,10 @@ class RandomFastGradientSignMethod(FastGradientSignMethod):
        ValueError: eps is smaller than alpha!

    Examples:
        >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
        >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
        >>> attack = RandomFastGradientSignMethod(network)
        >>> adv_x = attack.generate(inputs, labels)
    """

    def __init__(self, network, eps=0.07, alpha=0.035, bounds=(0.0, 1.0),
@@ -351,7 +356,10 @@ class LeastLikelyClassMethod(FastGradientSignMethod):
        loss_fn (Loss): Loss function for optimization. Default: None.

    Examples:
        >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
        >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
        >>> attack = LeastLikelyClassMethod(network)
        >>> adv_x = attack.generate(inputs, labels)
    """

    def __init__(self, network, eps=0.07, alpha=None, bounds=(0.0, 1.0),
@@ -385,7 +393,10 @@ class RandomLeastLikelyClassMethod(FastGradientSignMethod):
        ValueError: eps is smaller than alpha!

    Examples:
        >>> inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0, 0.4]])
        >>> labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]])
        >>> attack = RandomLeastLikelyClassMethod(network)
        >>> adv_x = attack.generate(inputs, labels)
    """

    def __init__(self, network, eps=0.07, alpha=0.035, bounds=(0.0, 1.0),
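All of the attack classes touched above share the two-step usage the new docstring examples sketch: construct with a network, then call generate. Filled out as a self-contained script, with `Net` a hypothetical stand-in for the user's trained model and shapes matching the example arrays, it might look like:

import numpy as np
from mindspore import nn
from mindarmour.attacks.gradient_method import FastGradientMethod

class Net(nn.Cell):
    """Hypothetical stand-in for a trained 5-class model over 3 features."""
    def __init__(self):
        super(Net, self).__init__()
        self._dense = nn.Dense(3, 5)

    def construct(self, x):
        return self._dense(x)

inputs = np.array([[0.1, 0.2, 0.6], [0.3, 0.0, 0.4]], dtype=np.float32)
labels = np.array([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]], dtype=np.float32)

# eps and bounds shown with their documented defaults
attack = FastGradientMethod(Net(), eps=0.07, bounds=(0.0, 1.0))
adv_x = attack.generate(inputs, labels)  # adversarial examples, same shape as inputs

The sign-based and least-likely-class variants follow the same construct-then-generate pattern; only the constructor differs (e.g. the random variants also take `alpha=0.035`).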
mindarmour/detectors/mag_net.py
@@ -47,6 +47,16 @@ class ErrorBasedDetector(Detector):
            Default: 0.01.
        bounds (tuple): (clip_min, clip_max). Default: (0.0, 1.0).

    Examples:
        >>> np.random.seed(5)
        >>> ori = np.random.rand(4, 4, 4).astype(np.float32)
        >>> np.random.seed(6)
        >>> adv = np.random.rand(4, 4, 4).astype(np.float32)
        >>> model = Model(Net())
        >>> detector = ErrorBasedDetector(model)
        >>> detector.fit(ori)
        >>> detected_res = detector.detect(adv)
        >>> adv_trans = detector.transform(adv)
    """

    def __init__(self, auto_encoder, false_positive_rate=0.01,
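Read as a runnable script, the new example needs an autoencoder `Model` defined elsewhere. A minimal sketch, assuming any shape-preserving `Cell` may stand in for a trained autoencoder (`AutoEncoderStub` is hypothetical):

import numpy as np
from mindspore import nn
from mindspore.train.model import Model
from mindarmour.detectors.mag_net import ErrorBasedDetector

class AutoEncoderStub(nn.Cell):
    """Hypothetical placeholder for a trained autoencoder; output mirrors input."""
    def __init__(self):
        super(AutoEncoderStub, self).__init__()
        self._relu = nn.ReLU()

    def construct(self, x):
        return self._relu(x)  # identity on the non-negative sample data

np.random.seed(5)
ori = np.random.rand(4, 4, 4).astype(np.float32)
np.random.seed(6)
adv = np.random.rand(4, 4, 4).astype(np.float32)

detector = ErrorBasedDetector(Model(AutoEncoderStub()), false_positive_rate=0.01)
detector.fit(ori)                    # calibrate the threshold on benign data
detected_res = detector.detect(adv)  # per-sample adversarial flags
adv_trans = detector.transform(adv)  # autoencoder reconstructions of the inputs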
@@ -159,6 +169,19 @@ class DivergenceBasedDetector(ErrorBasedDetector):
        t (int): Temperature used to overcome numerical problem. Default: 1.
        bounds (tuple): Upper and lower bounds of data.
            In form of (clip_min, clip_max). Default: (0.0, 1.0).

    Examples:
        >>> np.random.seed(5)
        >>> ori = np.random.rand(4, 4, 4).astype(np.float32)
        >>> np.random.seed(6)
        >>> adv = np.random.rand(4, 4, 4).astype(np.float32)
        >>> encoder = Model(Net())
        >>> model = Model(PredNet())
        >>> detector = DivergenceBasedDetector(encoder, model)
        >>> threshold = detector.fit(ori)
        >>> detector.set_threshold(threshold)
        >>> detected_res = detector.detect(adv)
        >>> adv_trans = detector.transform(adv)
    """

    def __init__(self, auto_encoder, model, option="jsd",
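The divergence-based variant additionally needs the target classifier (`PredNet` in the example) and an explicit threshold step. A self-contained sketch, with both networks as hypothetical stubs and the fit/threshold flow taken from the docstring above:

import numpy as np
from mindspore import nn
from mindspore.train.model import Model
from mindarmour.detectors.mag_net import DivergenceBasedDetector

class AutoEncoderStub(nn.Cell):
    """Hypothetical shape-preserving stand-in for a trained autoencoder."""
    def __init__(self):
        super(AutoEncoderStub, self).__init__()
        self._relu = nn.ReLU()

    def construct(self, x):
        return self._relu(x)

class PredNetStub(nn.Cell):
    """Hypothetical stand-in for the target model; emits class probabilities."""
    def __init__(self):
        super(PredNetStub, self).__init__()
        self._softmax = nn.Softmax()

    def construct(self, x):
        return self._softmax(x)

np.random.seed(5)
ori = np.random.rand(4, 4, 4).astype(np.float32)
np.random.seed(6)
adv = np.random.rand(4, 4, 4).astype(np.float32)

detector = DivergenceBasedDetector(Model(AutoEncoderStub()),
                                   Model(PredNetStub()), option="jsd")
threshold = detector.fit(ori)      # per the docstring, fit returns the threshold
detector.set_threshold(threshold)
detected_res = detector.detect(adv)
adv_trans = detector.transform(adv)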
mindarmour/fuzzing/model_coverage_metrics.py
@@ -37,6 +37,16 @@ class ModelCoverageMetrics:
        n (int): The number of testing neurons.
        train_dataset (numpy.ndarray): Training dataset used for determine
            the neurons' output boundaries.

    Examples:
        >>> train_images = np.random.random((10000, 128)).astype(np.float32)
        >>> test_images = np.random.random((5000, 128)).astype(np.float32)
        >>> model = Model(net)
        >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
        >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
        >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc())
        >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc())
        >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac())
    """

    def __init__(self, model, k, n, train_dataset):
@@ -163,7 +173,7 @@ class ModelCoverageMetrics:
        Get the metric of 'strong neuron activation coverage'.

        Returns:
-            float: the metric of 'strong neuron activation coverage'.
+            float, the metric of 'strong neuron activation coverage'.

        Examples:
            >>> model_fuzz_test.get_snac()
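Written out with a hypothetical classifier over 128-dimensional features (`k=10000` segments and `n=10` neurons mirror the docstring's arguments; the example's `print('... %s', value)` calls, which pass the value as a second print argument, are switched to `%` formatting):

import numpy as np
from mindspore import nn
from mindspore.train.model import Model
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics

class Net(nn.Cell):
    """Hypothetical 10-class classifier over 128-dimensional features."""
    def __init__(self):
        super(Net, self).__init__()
        self._dense = nn.Dense(128, 10)

    def construct(self, x):
        return self._dense(x)

train_images = np.random.random((10000, 128)).astype(np.float32)
test_images = np.random.random((5000, 128)).astype(np.float32)

model_fuzz_test = ModelCoverageMetrics(Model(Net()), 10000, 10, train_images)
model_fuzz_test.test_adequacy_coverage_calculate(test_images)
print('KMNC of this test is: %s' % model_fuzz_test.get_kmnc())
print('NBC of this test is: %s' % model_fuzz_test.get_nbc())
print('SNAC of this test is: %s' % model_fuzz_test.get_snac())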
mindarmour/utils/util.py
@@ -92,6 +92,18 @@ class GradWrapWithLoss(Cell):
    """
    Construct a network to compute the gradient of loss function in input space
    and weighted by `weight`.

    Args:
        network (Cell): The target network to wrap.

    Examples:
        >>> data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)*0.01)
        >>> label = Tensor(np.ones([1, 10]).astype(np.float32))
        >>> net = NET()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> loss_net = WithLossCell(net, loss_fn)
        >>> grad_all = GradWrapWithLoss(loss_net)
        >>> out_grad = grad_all(data, labels)
    """

    def __init__(self, network):
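One wrinkle in the added example: it defines `label` but then calls `grad_all(data, labels)`. A consistent, self-contained version, with `Net` a hypothetical classifier for 1x32x32 inputs, might read:

import numpy as np
from mindspore import Tensor, nn
from mindarmour.utils.util import GradWrapWithLoss

class Net(nn.Cell):
    """Hypothetical classifier: flattens a 1x32x32 input into 10 logits."""
    def __init__(self):
        super(Net, self).__init__()
        self._flatten = nn.Flatten()
        self._dense = nn.Dense(32 * 32, 10)

    def construct(self, x):
        return self._dense(self._flatten(x))

data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([1, 10]).astype(np.float32))

loss_net = nn.WithLossCell(Net(), nn.SoftmaxCrossEntropyWithLogits())
grad_all = GradWrapWithLoss(loss_net)
out_grad = grad_all(data, label)  # gradient of the loss w.r.t. data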
@@ -120,6 +132,19 @@ class GradWrap(Cell):
    """
    Construct a network to compute the gradient of network outputs in input
    space and weighted by `weight`, expressed as a jacobian matrix.

    Args:
        network (Cell): The target network to wrap.

    Examples:
        >>> data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)*0.01)
        >>> label = Tensor(np.ones([1, 10]).astype(np.float32))
        >>> num_classes = 10
        >>> sens = np.zeros((data.shape[0], num_classes)).astype(np.float32)
        >>> sens[:, 1] = 1.0
        >>> net = NET()
        >>> wrap_net = GradWrap(net)
        >>> wrap_net(data, Tensor(sens))
    """

    def __init__(self, network):
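`GradWrap` follows the same pattern, except the caller supplies a `sens` weighting that selects which output component's gradient (which row of the jacobian) to compute. Again a sketch with a hypothetical `Net`, repeated for completeness:

import numpy as np
from mindspore import Tensor, nn
from mindarmour.utils.util import GradWrap

class Net(nn.Cell):
    """Hypothetical classifier: flattens a 1x32x32 input into 10 logits."""
    def __init__(self):
        super(Net, self).__init__()
        self._flatten = nn.Flatten()
        self._dense = nn.Dense(32 * 32, 10)

    def construct(self, x):
        return self._dense(self._flatten(x))

data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
num_classes = 10
sens = np.zeros((data.shape[0], num_classes)).astype(np.float32)
sens[:, 1] = 1.0  # weight only class 1, i.e. select that row of the jacobian

wrap_net = GradWrap(Net())
grad_wrt_class1 = wrap_net(data, Tensor(sens))  # gradient of output[:, 1] w.r.t. data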