提交 103b25ad 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4215 Add uncertainty toolbox to nn/probability

Merge pull request !4215 from zhangxinfeng3/uncer
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Uncertainty toolbox.
"""
from .uncertainty_evaluation import UncertaintyEvaluation
__all__ = ['UncertaintyEvaluation']
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Toolbox for Uncertainty Evaluation."""
import numpy as np
from mindspore._checkparam import check_int_positive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from ...cell import Cell
from ...layer.basic import Dense, Flatten, Dropout
from ...layer.conv import Conv2d
from ...layer.container import SequentialCell
from ...loss import SoftmaxCrossEntropyWithLogits, MSELoss
from ...metrics import Accuracy, MSE
from ...optim import Adam
class UncertaintyEvaluation:
r"""
Toolbox for Uncertainty Evaluation.
Args:
model (Cell): The model for uncertainty evaluation.
train_dataset (Dataset): A dataset iterator.
task_type (str): Option for the task types of model
- regression: A regression model.
- classification: A classification model.
num_classes (int): The number of labels of classification.
If the task type is classification, it must be set; if not classification, it need not to be set.
Default: None.
epochs (int): Total number of iterations on the data. Default: None.
uncertainty_model_path (str): The save or read path of the uncertainty model.
Examples:
>>> network = LeNet()
>>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
>>> load_param_into_net(network, param_dict)
>>> ds_train = create_dataset('workspace/mnist/train')
>>> evaluation = UncertaintyEvaluation(model=network,
>>> train_dataset=ds_train,
>>> task_type='classification',
>>> num_classes=10,
>>> epochs=5,
>>> uncertainty_model_path=None)
>>> epistemic_uncertainty = evaluation.eval_epistemic(eval_data)
>>> aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data)
"""
def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=None,
uncertainty_model_path=None):
self.model = model
self.train_dataset = train_dataset
self.task_type = task_type
self.num_classes = check_int_positive(num_classes)
self.epochs = epochs
self.uncer_model_path = uncertainty_model_path
self.uncer_model = None
self.concat = P.Concat(axis=0)
self.sum = P.ReduceSum()
self.pow = P.Pow()
if self.task_type not in ('regression', 'classification'):
raise ValueError('The task should be regression or classification.')
if self.task_type == 'classification':
if self.num_classes is None:
raise ValueError("Classification task needs to input labels.")
def _uncertainty_normalize(self, data):
area = np.max(data) - np.min(data)
return (data - np.min(data)) / area
def _get_epistemic_uncertainty_model(self):
"""
Get the model which can obtain the epistemic uncertainty.
"""
if self.uncer_model and self.uncer_model_path is None:
self.uncer_model = EpistemicUncertaintyModel(self.model)
if self.uncer_model.drop_count == 0:
if self.task_type == 'classification':
net_loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = Adam(self.uncer_model.trainable_params())
model = Model(self.uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
else:
net_loss = MSELoss()
net_opt = Adam(self.uncer_model.trainable_params())
model = Model(self.uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
elif self.uncer_model is None:
uncer_param_dict = load_checkpoint(self.uncer_model_path)
load_param_into_net(self.uncer_model, uncer_param_dict)
def _eval_epistemic_uncertainty(self, eval_data, mc=10):
"""
Evaluate the epistemic uncertainty of classification and regression models using MC dropout.
"""
self._get_epistemic_uncertainty_model()
self.uncer_model.set_train(True)
outputs = [None] * mc
for i in range(mc):
pred = self.uncer_model(eval_data)
outputs[i] = pred.asnumpy()
if self.task_type == 'classification':
outputs = np.stack(outputs, axis=2)
epi_uncertainty = outputs.var(axis=2)
else:
outputs = np.stack(outputs, axis=1)
epi_uncertainty = outputs.var(axis=1)
epi_uncertainty = self._uncertainty_normalize(np.array(epi_uncertainty))
return epi_uncertainty
def _get_aleatoric_uncertainty_model(self):
"""
Get the model which can obtain the aleatoric uncertainty.
"""
if self.uncer_model and self.uncer_model_path is None:
self.uncer_model = AleatoricUncertaintyModel(self.model, self.num_classes, self.task_type)
net_loss = AleatoricLoss(self.task_type)
net_opt = Adam(self.uncer_model.trainable_params())
if self.task_type == 'classification':
model = Model(self.uncer_model, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
else:
model = Model(self.uncer_model, net_loss, net_opt, metrics={"MSE": MSE()})
model.train(self.epochs, self.train_dataset, callbacks=[LossMonitor()])
elif self.uncer_model is None:
uncer_param_dict = load_checkpoint(self.uncer_model_path)
load_param_into_net(self.uncer_model, uncer_param_dict)
def _eval_aleatoric_uncertainty(self, eval_data):
"""
Evaluate the aleatoric uncertainty of classification and regression models.
"""
self._get_aleatoric_uncertainty_model()
_, var = self.uncer_model(eval_data)
ale_uncertainty = self.sum(self.pow(var, 2), 1)
ale_uncertainty = self._uncertainty_normalize(ale_uncertainty.asnumpy())
return ale_uncertainty
def eval_epistemic(self, eval_data):
"""
Evaluate the epistemic uncertainty of inference results, which also called model uncertainty.
Args:
eval_data (Tensor): The data samples to be evaluated, the shape should be (N,C,H,W).
Returns:
numpy.dtype, the epistemic uncertainty of inference results of data samples.
"""
uncertainty = self._eval_aleatoric_uncertainty(eval_data)
return uncertainty
def eval_aleatoric(self, eval_data):
"""
Evaluate the aleatoric uncertainty of inference results, which also called data uncertainty.
Args:
eval_data (Tensor): The data samples to be evaluated, the shape should be (N,C,H,W).
Returns:
numpy.dtype, the aleatoric uncertainty of inference results of data samples.
"""
uncertainty = self._eval_epistemic_uncertainty(eval_data)
return uncertainty
class EpistemicUncertaintyModel(Cell):
"""
Using dropout during training and eval time which is approximate bayesian inference. In this way,
we can obtain the epistemic uncertainty (also called model uncertainty).
If the original model has Dropout layer, just use dropout when eval time, if not, add dropout layer
after Dense layer or Conv layer, then use dropout during train and eval time.
See more details in `Dropout as a Bayesian Approximation: Representing Model uncertainty in Deep Learning
<https://arxiv.org/abs/1506.02142>`.
"""
def __init__(self, model):
super(EpistemicUncertaintyModel, self).__init__()
self.drop_count = 0
self.model = self._make_epistemic(model)
def construct(self, x):
x = self.model(x)
return x
def _make_epistemic(self, model, dropout_rate=0.5):
"""
The dropout rate is set to 0.5 by default.
"""
for (name, layer) in model.name_cells().items():
if isinstance(layer, Dropout):
self.drop_count += 1
return model
for (name, layer) in model.name_cells().items():
if isinstance(layer, (Conv2d, Dense)):
uncertainty_layer = layer
uncertainty_name = name
drop = Dropout(keep_prob=dropout_rate)
bnn_drop = SequentialCell([uncertainty_layer, drop])
setattr(model, uncertainty_name, bnn_drop)
return model
raise ValueError("The model has not Dense Layer or Convolution Layer, "
"it can not evaluate epistemic uncertainty so far.")
class AleatoricUncertaintyModel(Cell):
"""
The aleatoric uncertainty (also called data uncertainty) is caused by input data, to obtain this
uncertainty, the loss function should be modified in order to add variance into loss.
See more details in `What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?
<https://arxiv.org/abs/1703.04977>`.
"""
def __init__(self, model, labels, task):
super(AleatoricUncertaintyModel, self).__init__()
self.task = task
if task == 'classification':
self.model = model
self.var_layer = Dense(labels, labels)
else:
self.model, self.var_layer, self.pred_layer = self._make_aleatoric(model)
def construct(self, x):
if self.task == 'classification':
pred = self.model(x)
var = self.var_layer(pred)
else:
x = self.model(x)
pred = self.pred_layer(x)
var = self.var_layer(x)
return pred, var
def _make_aleatoric(self, model):
"""
In order to add variance into original loss, add var Layer after the original network.
"""
dense_layer = dense_name = None
for (name, layer) in model.name_cells().items():
if isinstance(layer, Dense):
dense_layer = layer
dense_name = name
if dense_layer is None:
raise ValueError("The model has not Dense Layer, "
"it can not evaluate aleatoric uncertainty so far.")
setattr(model, dense_name, Flatten())
var_layer = Dense(dense_layer.in_channels, dense_layer.out_channels)
return model, var_layer, dense_layer
class AleatoricLoss(Cell):
"""
The loss function of aleatoric model, different modification methods are adopted for
classification and regression.
"""
def __init__(self, task):
super(AleatoricLoss, self).__init__()
self.task = task
if self.task == 'classification':
self.sum = P.ReduceSum()
self.exp = P.Exp()
self.normal = C.normal
self.to_tensor = P.ScalarToArray()
self.entropy = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
else:
self.mean = P.ReduceMean()
self.exp = P.Exp()
self.pow = P.Pow()
def construct(self, data_pred, y):
y_pred, var = data_pred
if self.task == 'classification':
sample_times = 10
epsilon = self.normal((1, sample_times), self.to_tensor(0.0), self.to_tensor(1.0), 0)
total_loss = 0
for i in range(sample_times):
y_pred_i = y_pred + epsilon[0][i] * var
loss = self.entropy(y_pred_i, y)
total_loss += loss
avg_loss = total_loss / sample_times
return avg_loss
loss = self.mean(0.5 * self.exp(-var) * self.pow(y - y_pred, 2) + 0.5 * var)
return loss
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test uncertainty toolbox """
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.dataset.transforms.vision import Inter
from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation
from mindspore.train.serialization import load_checkpoint, load_param_into_net
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(channel, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
if __name__ == '__main__':
# get trained model
network = LeNet5()
param_dict = load_checkpoint('checkpoint_lenet.ckpt')
load_param_into_net(network, param_dict)
# get train and eval dataset
ds_train = create_dataset('workspace/mnist/train')
ds_eval = create_dataset('workspace/mnist/test')
evaluation = UncertaintyEvaluation(model=network,
train_dataset=ds_train,
task_type='classification',
num_classes=10,
epochs=5,
uncertainty_model_path=None)
for eval_data in ds_eval.create_dict_iterator():
eval_data = Tensor(eval_data['image'], mstype.float32)
epistemic_uncertainty = evaluation.eval_epistemic(eval_data)
aleatoric_uncertainty = evaluation.eval_aleatoric(eval_data)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册