MindSpore / mindarmour
Commit 03b70823

Authored by mindspore-ci-bot on May 25, 2020; committed by Gitee on May 25, 2020.

!22 Add the Monitoring module for Differential privacy training.

Merge pull request !22 from jxlang910/master

Parents: 0268e2c6, c0433207

Showing 4 changed files with 548 additions and 0 deletions (+548, −0)
mindarmour/diff_privacy/__init__.py             +0    −0
mindarmour/diff_privacy/monitor/__init__.py     +0    −0
mindarmour/diff_privacy/monitor/monitor.py      +418  −0
tests/ut/python/diff_privacy/test_monitor.py    +130  −0
mindarmour/diff_privacy/__init__.py  0 → 100644
mindarmour/diff_privacy/monitor/__init__.py  0 → 100644
mindarmour/diff_privacy/monitor/monitor.py  0 → 100644
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Monitor module of differential privacy training. """
import math

import numpy as np
from scipy import special
from mindspore.train.callback import Callback

from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_int_positive, \
    check_value_positive

LOGGER = LogUtil.get_instance()
TAG = 'DP monitor'


class PrivacyMonitorFactory:
    """
    Factory class of DP training's privacy monitor.
    """

    def __init__(self):
        pass

    @staticmethod
    def create(policy, *args, **kwargs):
        """
        Create a privacy monitor class.

        Args:
            policy (str): Monitor policy; only 'rdp' is supported for now.
            args (Union[int, float, numpy.ndarray, list, str]): Parameters
                used for creating a privacy monitor.
            kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword
                parameters used for creating a privacy monitor.

        Returns:
            PrivacyMonitor, a privacy monitor.

        Examples:
            >>> rdp = PrivacyMonitorFactory.create(policy='rdp',
            >>>     num_samples=60000, batch_size=32)
        """
        if policy == 'rdp':
            return RDPMonitor(*args, **kwargs)
        raise ValueError("Only RDP-policy is supported by now")


class RDPMonitor(Callback):
    """
    Compute the privacy budget of DP training based on Renyi differential
    privacy theory.

    Reference: `Rényi Differential Privacy of the Sampled Gaussian Mechanism
    <https://arxiv.org/abs/1908.10530>`_

    Args:
        num_samples (int): The total number of samples in training data sets.
        batch_size (int): The number of samples in a batch while training.
        initial_noise_multiplier (Union[float, int]): The initial
            multiplier of added noise. Default: 0.4.
        max_eps (Union[float, int, None]): The maximum acceptable epsilon
            budget for DP training. Default: 3.0.
        target_delta (Union[float, int, None]): Target delta budget for DP
            training. Default: 1e-5.
        max_delta (Union[float, int, None]): The maximum acceptable delta
            budget for DP training. max_delta must be less than 1 and is
            suggested to be less than 1e-3, otherwise overflow would be
            encountered. Default: None.
        target_eps (Union[float, int, None]): Target epsilon budget for DP
            training. Default: None.
        orders (Union[None, list[int, float]]): Finite orders used for
            computing rdp, which must be greater than 1. Default: None.
        noise_decay_mode (str): Decay mode of adding noise while training,
            which can be 'no_decay', 'time' or 'step'. Default: 'step'.
        noise_decay_rate (Union[float, None]): Decay rate of noise while
            training. Default: 6e-4.
        per_print_times (int): The interval steps of computing and printing
            the privacy budget. Default: 50.

    Examples:
        >>> rdp = PrivacyMonitorFactory.create(policy='rdp',
        >>>     num_samples=60000, batch_size=32)
        >>> network = Net()
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
        >>> model = Model(network, net_loss, net_opt)
        >>> model.train(epochs, ds, callbacks=[rdp], dataset_sink_mode=False)
    """

    def __init__(self, num_samples, batch_size, initial_noise_multiplier=0.4,
                 max_eps=3.0, target_delta=1e-5, max_delta=None,
                 target_eps=None, orders=None, noise_decay_mode='step',
                 noise_decay_rate=6e-4, per_print_times=50):
        super(RDPMonitor, self).__init__()
        check_int_positive('num_samples', num_samples)
        check_int_positive('batch_size', batch_size)
        if batch_size >= num_samples:
            msg = 'Batch_size must be less than num_samples.'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        check_value_positive('initial_noise_multiplier',
                             initial_noise_multiplier)
        if max_eps is not None:
            check_value_positive('max_eps', max_eps)
        if target_delta is not None:
            check_value_positive('target_delta', target_delta)
        if max_delta is not None:
            check_value_positive('max_delta', max_delta)
            if max_delta >= 1:
                msg = 'max_delta must be less than 1.'
                LOGGER.error(TAG, msg)
                raise ValueError(msg)
        if target_eps is not None:
            check_value_positive('target_eps', target_eps)
        if orders is not None:
            for item in orders:
                check_value_positive('order', item)
                if item <= 1:
                    msg = 'orders must be greater than 1'
                    LOGGER.error(TAG, msg)
                    raise ValueError(msg)
        if noise_decay_mode not in ('no_decay', 'step', 'time'):
            msg = 'Noise decay mode must be in (no_decay, step, time)'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        if noise_decay_rate is not None:
            check_value_positive('noise_decay_rate', noise_decay_rate)
            if noise_decay_rate >= 1:
                msg = 'Noise decay rate must be less than 1'
                LOGGER.error(TAG, msg)
                raise ValueError(msg)
        check_int_positive('per_print_times', per_print_times)

        self._total_echo_privacy = None
        self._num_samples = num_samples
        self._batch_size = batch_size
        self._initial_noise_multiplier = initial_noise_multiplier
        self._max_eps = max_eps
        self._target_delta = target_delta
        self._max_delta = max_delta
        self._target_eps = target_eps
        self._orders = orders
        self._noise_decay_mode = noise_decay_mode
        self._noise_decay_rate = noise_decay_rate
        self._rdp = 0
        self._per_print_times = per_print_times

    def max_epoch_suggest(self):
        """
        Estimate the maximum training epochs to satisfy the predefined
        privacy budget.

        Returns:
            int, the recommended maximum training epochs.

        Examples:
            >>> rdp = PrivacyMonitorFactory.create(policy='rdp',
            >>>     num_samples=60000, batch_size=32)
            >>> suggest_epoch = rdp.max_epoch_suggest()
        """
        epoch = 1
        while epoch < 10000:
            steps = self._num_samples // self._batch_size
            eps, delta = self._compute_privacy_steps(
                list(np.arange((epoch - 1) * steps, epoch * steps + 1)))
            if self._max_eps is not None:
                if eps <= self._max_eps:
                    epoch += 1
                else:
                    break
            if self._max_delta is not None:
                if delta <= self._max_delta:
                    epoch += 1
                else:
                    break
        self._rdp = 0
        return epoch

    def step_end(self, run_context):
        """
        Compute privacy budget after each training step.

        Args:
            run_context (RunContext): Includes some information of the model.
        """
        cb_params = run_context.original_args()
        cur_step = cb_params.cur_step_num
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % \
                            cb_params.batch_num + 1

        if cb_params.cur_step_num % self._per_print_times == 0:
            steps = np.arange(cur_step - self._per_print_times, cur_step + 1)
            eps, delta = self._compute_privacy_steps(list(steps))
            if np.isnan(eps) or np.isinf(eps) or np.isnan(delta) \
                    or np.isinf(delta):
                msg = 'epoch: {} step: {}, invalid eps, terminating ' \
                      'training.'.format(cb_params.cur_epoch_num,
                                         cur_step_in_epoch)
                LOGGER.error(TAG, msg)
                raise ValueError(msg)
            if np.isnan(delta) or np.isinf(delta):
                msg = 'epoch: {} step: {}, invalid delta, terminating ' \
                      'training.'.format(cb_params.cur_epoch_num,
                                         cur_step_in_epoch)
                LOGGER.error(TAG, msg)
                raise ValueError(msg)
            print("epoch: %s step: %s, delta is %s, eps is %s" % (
                cb_params.cur_epoch_num, cur_step_in_epoch, delta, eps))

    def _compute_privacy_steps(self, steps):
        """
        Compute privacy budget corresponding to steps.

        Args:
            steps (list): Training steps.

        Returns:
            tuple, the computed eps and delta budgets.
        """
        if self._target_eps is None and self._target_delta is None:
            msg = 'target eps and target delta cannot both be None'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        if self._target_eps is not None and self._target_delta is not None:
            msg = 'One of target eps and target delta must be None'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)

        if self._orders is None:
            self._orders = (
                [1.005, 1.01, 1.02, 1.08, 1.2, 2, 5, 10, 20, 40, 80])
        sampling_rate = self._batch_size / self._num_samples
        noise_step = self._initial_noise_multiplier

        if self._noise_decay_mode == 'no_decay':
            self._rdp += self._compute_rdp(sampling_rate, noise_step) * \
                         len(steps)
        else:
            if self._noise_decay_rate is None:
                msg = 'noise_decay_rate in decay-mode cannot be None'
                LOGGER.error(TAG, msg)
                raise ValueError(msg)
            if self._noise_decay_mode == 'time':
                noise_step = [self._initial_noise_multiplier /
                              (1 + self._noise_decay_rate * step)
                              for step in steps]
            elif self._noise_decay_mode == 'step':
                noise_step = [self._initial_noise_multiplier *
                              (1 - self._noise_decay_rate) ** step
                              for step in steps]
            self._rdp += sum([self._compute_rdp(sampling_rate, noise)
                              for noise in noise_step])
        eps, delta = self._compute_privacy_budget(self._rdp)
        return eps, delta
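
    # A note on the decay schedules handled above: with decay rate r and
    # global step t, 'time' decay uses sigma_t = sigma_0 / (1 + r * t),
    # 'step' decay uses sigma_t = sigma_0 * (1 - r) ** t, and 'no_decay'
    # keeps sigma_0 fixed, scaling the per-step RDP by the number of steps.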

    def _compute_rdp(self, q, noise):
        """
        Compute rdp according to sampling rate, added noise and Renyi
        divergence orders.

        Args:
            q (float): Sampling rate of each batch of samples.
            noise (float): Noise multiplier.

        Returns:
            float or numpy.ndarray, rdp values.
        """
        rdp = np.array(
            [_compute_rdp_order(q, noise, order) for order in self._orders])
        return rdp

    def _compute_privacy_budget(self, rdp):
        """
        Compute delta or eps for given rdp.

        Args:
            rdp (Union[float, numpy.ndarray]): Renyi differential privacy.

        Returns:
            tuple, the eps budget and the delta budget.
        """
        if self._target_eps is not None:
            delta = self._compute_delta(rdp)
            return self._target_eps, delta
        eps = self._compute_eps(rdp)
        return eps, self._target_delta

    def _compute_delta(self, rdp):
        """
        Compute delta for given rdp and eps.

        Args:
            rdp (Union[float, numpy.ndarray]): Renyi differential privacy.

        Returns:
            float, delta budget.
        """
        orders = np.atleast_1d(self._orders)
        rdps = np.atleast_1d(rdp)
        if len(orders) != len(rdps):
            msg = 'rdp lists and orders list must have the same length.'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        deltas = np.exp((rdps - self._target_eps) * (orders - 1))
        min_delta = min(deltas)
        return min(min_delta, 1.)

    def _compute_eps(self, rdp):
        """
        Compute eps for given rdp and delta.

        Args:
            rdp (Union[float, numpy.ndarray]): Renyi differential privacy.

        Returns:
            float, eps budget.
        """
        orders = np.atleast_1d(self._orders)
        rdps = np.atleast_1d(rdp)
        if len(orders) != len(rdps):
            msg = 'rdp lists and orders list must have the same length.'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        eps = rdps - math.log(self._target_delta) / (orders - 1)
        return min(eps)


def _compute_rdp_order(q, sigma, alpha):
    """
    Compute rdp for each order.

    Args:
        q (float): Sampling probability.
        sigma (float): Noise multiplier.
        alpha: The order used for computing rdp.

    Returns:
        float, rdp value.
    """
    if float(alpha).is_integer():
        log_integrate = -np.inf
        for k in range(alpha + 1):
            term_k = (math.log(special.binom(alpha, k)) +
                      k * math.log(q) + (alpha - k) * math.log(1 - q)) + \
                     (k * k - k) / (2 * (sigma ** 2))
            log_integrate = _log_add(log_integrate, term_k)
        return float(log_integrate) / (alpha - 1)

    log_part_0, log_part_1 = -np.inf, -np.inf
    k = 0
    z0 = sigma ** 2 * math.log(1 / q - 1) + 1 / 2
    while True:
        bi_coef = special.binom(alpha, k)
        log_coef = math.log(abs(bi_coef))
        j = alpha - k

        term_k_part_0 = log_coef + k * math.log(q) + j * math.log(1 - q) + \
                        (k * k - k) / (2 * (sigma ** 2)) + \
                        special.log_ndtr((z0 - k) / sigma)
        term_k_part_1 = log_coef + j * math.log(q) + k * math.log(1 - q) + \
                        (j * j - j) / (2 * (sigma ** 2)) + \
                        special.log_ndtr((j - z0) / sigma)

        if bi_coef > 0:
            log_part_0 = _log_add(log_part_0, term_k_part_0)
            log_part_1 = _log_add(log_part_1, term_k_part_1)
        else:
            log_part_0 = _log_subtract(log_part_0, term_k_part_0)
            log_part_1 = _log_subtract(log_part_1, term_k_part_1)

        k += 1
        if max(term_k_part_0, term_k_part_1) < -30:
            break

    return _log_add(log_part_0, log_part_1) / (alpha - 1)


def _log_add(x, y):
    """
    Add x and y in log space.
    """
    if x == -np.inf:
        return y
    if y == -np.inf:
        return x
    return max(x, y) + math.log1p(math.exp(-abs(x - y)))


def _log_subtract(x, y):
    """
    Subtract y from x in log space; x must be greater than y.
    """
    if x <= y:
        msg = 'The antilog of log functions must be positive'
        LOGGER.error(TAG, msg)
        raise ValueError(msg)
    if y == -np.inf:
        return x
    return math.log1p(math.exp(y - x)) + x
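
For reference, an editorial restatement (following the referenced paper) of the math the module implements: for an integer order alpha, _compute_rdp_order evaluates the RDP of the sampled Gaussian mechanism with sampling rate q and noise multiplier sigma as

\[
\mathrm{RDP}(\alpha) = \frac{1}{\alpha - 1} \log \sum_{k=0}^{\alpha} \binom{\alpha}{k} q^{k} (1 - q)^{\alpha - k} \exp\!\left(\frac{k^{2} - k}{2\sigma^{2}}\right),
\]

and _compute_eps / _compute_delta convert the accumulated RDP into an (epsilon, delta) guarantee by minimizing over the configured orders:

\[
\varepsilon(\delta) = \min_{\alpha}\left(\mathrm{RDP}(\alpha) - \frac{\log \delta}{\alpha - 1}\right), \qquad
\delta(\varepsilon) = \min\!\left(1,\ \min_{\alpha} \exp\big((\alpha - 1)(\mathrm{RDP}(\alpha) - \varepsilon)\big)\right).
\]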
tests/ut/python/diff_privacy/test_monitor.py  0 → 100644
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DP-Monitor test.
"""
import pytest
import numpy as np

import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore.train import Model
import mindspore.context as context
from mindspore.model_zoo.lenet import LeNet5

from mindarmour.diff_privacy.monitor.monitor import PrivacyMonitorFactory
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'DP-Monitor Test'


def dataset_generator(batch_size, batches):
    data = np.random.random((batches * batch_size, 1, 32, 32)).astype(
        np.float32)
    label = np.random.randint(0, 10, batches * batch_size).astype(np.int32)
    for i in range(batches):
        yield data[i * batch_size:(i + 1) * batch_size], \
              label[i * batch_size:(i + 1) * batch_size]


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_dp_monitor():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    batch_size = 16
    batches = 128
    epochs = 1
    rdp = PrivacyMonitorFactory.create(policy='rdp', num_samples=60000,
                                       batch_size=batch_size,
                                       initial_noise_multiplier=0.4,
                                       noise_decay_rate=6e-5)
    suggest_epoch = rdp.max_epoch_suggest()
    LOGGER.info(TAG, 'The recommended maximum training epochs is: %s',
                suggest_epoch)
    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt)

    LOGGER.info(TAG, "============== Starting Training ==============")
    ds1 = ds.GeneratorDataset(dataset_generator(batch_size, batches),
                              ["data", "label"])
    ds1.set_dataset_size(batch_size * batches)
    model.train(epochs, ds1, callbacks=[rdp], dataset_sink_mode=False)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_inference
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_dp_monitor_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    batch_size = 16
    batches = 128
    epochs = 1
    rdp = PrivacyMonitorFactory.create(policy='rdp', num_samples=60000,
                                       batch_size=batch_size,
                                       initial_noise_multiplier=0.4,
                                       noise_decay_rate=6e-5)
    suggest_epoch = rdp.max_epoch_suggest()
    LOGGER.info(TAG, 'The recommended maximum training epochs is: %s',
                suggest_epoch)
    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt)

    LOGGER.info(TAG, "============== Starting Training ==============")
    ds1 = ds.GeneratorDataset(dataset_generator(batch_size, batches),
                              ["data", "label"])
    ds1.set_dataset_size(batch_size * batches)
    model.train(epochs, ds1, callbacks=[rdp], dataset_sink_mode=False)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_dp_monitor_cpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    batch_size = 16
    batches = 128
    epochs = 1
    rdp = PrivacyMonitorFactory.create(policy='rdp', num_samples=60000,
                                       batch_size=batch_size,
                                       initial_noise_multiplier=0.4,
                                       noise_decay_rate=6e-5)
    suggest_epoch = rdp.max_epoch_suggest()
    LOGGER.info(TAG, 'The recommended maximum training epochs is: %s',
                suggest_epoch)
    network = LeNet5()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt)

    LOGGER.info(TAG, "============== Starting Training ==============")
    ds1 = ds.GeneratorDataset(dataset_generator(batch_size, batches),
                              ["data", "label"])
    ds1.set_dataset_size(batch_size * batches)
    model.train(epochs, ds1, callbacks=[rdp], dataset_sink_mode=False)
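
Beyond the full training runs above, the accounting path can be exercised on its own. A minimal sketch, assuming mindarmour (with numpy and scipy) is installed; the parameter values here are illustrative, not taken from the tests:

from mindarmour.diff_privacy.monitor.monitor import PrivacyMonitorFactory

# Budget accounting only: no model or dataset is needed to ask the monitor
# how many epochs fit inside the default max_eps = 3.0 budget.
rdp = PrivacyMonitorFactory.create(policy='rdp',
                                   num_samples=60000,  # illustrative value
                                   batch_size=256,  # illustrative value
                                   initial_noise_multiplier=1.1,
                                   noise_decay_mode='no_decay')
print('suggested max epochs:', rdp.max_epoch_suggest())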