未验证 提交 8906c694 编写于 作者: T Teng Xi 提交者: GitHub

add gp_nas (#723)

* add gp_nas
上级 c0f8a887
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import copy
import numpy as np
from paddleslim.nas import GPNAS
# 使用GP-NAS参加[CVPR 2021 NAS国际比赛](https://www.cvpr21-nas.com/competition) Track2 demo
# [CVPR 2021 NAS国际比赛Track2 studio地址](https://aistudio.baidu.com/aistudio/competition/detail/71?lang=en)
# [AI studio GP-NAS demo](https://aistudio.baidu.com/aistudio/projectdetail/1824958)
# demo 基于paddleslim自研NAS算法GP-NAS:Gaussian Process based Neural Architecture Search
# 基于本demo的改进版可以获得双倍奖金
def preprare_trainning_data(file_name, t_flag):
## t_flag ==1 using all trainning data
## t_flag ==2 using half trainning data
with open(file_name, 'r') as f:
arch_dict = json.load(f)
Y_all = []
X_all = []
for sub_dict in arch_dict.items():
Y_all.append(sub_dict[1]['acc'] * 100)
X_all.append(np.array(sub_dict[1]['arch']).T.reshape(4, 16)[2])
X_all, Y_all = np.array(X_all), np.array(Y_all)
X_train, Y_train, X_test, Y_test = X_all[0::t_flag], Y_all[
0::t_flag], X_all[1::t_flag], Y_all[1::t_flag]
return X_train, Y_train, X_test, Y_test
if __name__ == '__main__':
stage1_file = './datasets/Track2_stage1_trainning.json'
stage2_file = './datasets/Track2_stage2_few_show_trainning.json'
X_train_stage1, Y_train_stage1, X_test_stage1, Y_test_stage1 = preprare_trainning_data(
stage1_file, 1)
X_train_stage2, Y_train_stage2, X_test_stage2, Y_test_stage2 = preprare_trainning_data(
stage2_file, 2)
gpnas = GPNAS()
w = gpnas.get_initial_mean(X_test_stage1, Y_test_stage1)
init_cov = gpnas.get_initial_cov(X_train_stage1)
error_list = np.array(
Y_test_stage2.reshape(len(Y_test_stage2), 1) - gpnas.get_predict(
X_test_stage2))
print('RMSE trainning on stage1 testing on stage2:',
np.sqrt(np.dot(error_list.T, error_list) / len(error_list)))
gpnas.get_posterior_mean(X_train_stage2[0::3], Y_train_stage2[0::3])
gpnas.get_posterior_mean(X_train_stage2[1::3], Y_train_stage2[1::3])
gpnas.get_posterior_cov(X_train_stage2[1::3], Y_train_stage2[1::3])
error_list = np.array(
Y_test_stage2.reshape(len(Y_test_stage2), 1) - gpnas.get_predict_jiont(
X_test_stage2, X_train_stage2[::1], Y_train_stage2[::1]))
print('RMSE using stage1 as prior:',
np.sqrt(np.dot(error_list.T, error_list) / len(error_list)))
# GP-NAS使用示例
CVPR2021_NAS_competition_gpnas_demo.py演示如何使用GP-NAS参加[CVPR 2021 NAS国际比赛](https://www.cvpr21-nas.com/competition) Track2 demo
[CVPR 2021 NAS国际比赛Track2 studio地址](https://aistudio.baidu.com/aistudio/competition/detail/71?lang=en)
[AI studio GP-NAS demo](https://aistudio.baidu.com/aistudio/projectdetail/1824958)
基于GP-NAS的改进版方案可以获得双倍奖金
# CVPR 2021 NAS国际比赛背景
在不训练的情况下,准确的预测任意模型结构性能非常重要。基于此,我们不仅可以深度的分析怎样的模型结构会有很好的性能,怎样的模型性能会很差。同时还能够预测出满足任意硬件延时约束下的最优的模型结构。本赛事提供了部分(小样本)模型结构与模型精度之间对应关系的bench mark,参赛选手既可以通过黑盒的方式直接进行训练,也可以使用白盒的方式进行参数估计。
本赛道采用Mobilenet-like搜索空间,其中16个block可以搜索,每层的搜索空间由一个4元组[layer_index1, layer_index2, OP1, OP2]构成, layer1_index取值范围在[2,17]或为layer2_index除了[2,17]之外还可以为0,为0则表示本层只有一个后序节点,每层可以选择与1到2个编号大于该层数的后序节点相连接,OP1取值范围在[1,6]表示6种(kernel size三种选择,膨胀系数2种选择)不同的链接方式,OP2取值范围除了[1,6]之外还可以为0,为0则表示本层只有一个后序节点。
本赛道分为两个阶段,第一阶段为线性拓扑,即每层只与编号比本层多1层的后续节点相连接,第二阶段在第一阶段基础至上考察模型的few shot能力。
第二阶段赛题背景: 为什么关注acc而非ranking? 在很多场景,我们需要搜索到在特定硬件上精度不低于特定指标的最优的模型结构,只预测相对排序无法保证搜索结构可以满足精度的约束条件。 few shot背景: predictor based模型结构搜索, 需要采样足够多的模型结构来训练预测器,将模型结构训到较高的指标需要加很多trick并且需要训练非常久,从而限制采样子网络的数量。代理任务可以快速获得模型的精度,但是代理任务的精度分布与加入trick并且训练更久的精度分布之间会有diff。第二阶段的目标就是,基于第一阶段的代理任务采样的模型结构与模型精度之间的关联性,在只采样非常少量模型结构在非代理任务(加入trick并且训练更久)上的精度情况下,就可以准确的预测任意模型结构在非代理任务上的精度。
本demo基于paddleslim自研NAS算法[GP-NAS:Gaussian Process based Neural Architecture Search](https://openaccess.thecvf.com/content_CVPR_2020/papers/Li_GP-NAS_Gaussian_Process_Based_Neural_Architecture_Search_CVPR_2020_paper.pdf)(CVPR2020)
此差异已折叠。
{"arch_few_shot_1": {"acc": 0.929999983907, "arch": [[2, 0, 1, 0], [3, 0, 5, 0], [4, 0, 5, 0], [5, 0, 5, 0], [6, 0, 3, 0], [7, 0, 3, 0], [8, 0, 1, 0], [9, 0, 3, 0], [10, 0, 1, 0], [11, 0, 5, 0], [12, 0, 6, 0], [13, 0, 6, 0], [14, 0, 6, 0], [15, 0, 4, 0], [16, 0, 5, 0], [17, 0, 4, 0]]}, "arch_few_shot_2": {"acc": 0.934199981689, "arch": [[2, 0, 5, 0], [3, 0, 3, 0], [4, 0, 2, 0], [5, 0, 1, 0], [6, 0, 5, 0], [7, 0, 3, 0], [8, 0, 5, 0], [9, 0, 4, 0], [10, 0, 1, 0], [11, 0, 2, 0], [12, 0, 3, 0], [13, 0, 1, 0], [14, 0, 3, 0], [15, 0, 4, 0], [16, 0, 3, 0], [17, 0, 4, 0]]}, "arch_few_shot_3": {"acc": 0.934699984789, "arch": [[2, 0, 4, 0], [3, 0, 3, 0], [4, 0, 6, 0], [5, 0, 3, 0], [6, 0, 6, 0], [7, 0, 1, 0], [8, 0, 3, 0], [9, 0, 6, 0], [10, 0, 2, 0], [11, 0, 5, 0], [12, 0, 3, 0], [13, 0, 3, 0], [14, 0, 6, 0], [15, 0, 6, 0], [16, 0, 3, 0], [17, 0, 1, 0]]}, "arch_few_shot_4": {"acc": 0.935499984026, "arch": [[2, 0, 5, 0], [3, 0, 2, 0], [4, 0, 2, 0], [5, 0, 5, 0], [6, 0, 6, 0], [7, 0, 1, 0], [8, 0, 3, 0], [9, 0, 5, 0], [10, 0, 6, 0], [11, 0, 5, 0], [12, 0, 4, 0], [13, 0, 2, 0], [14, 0, 2, 0], [15, 0, 4, 0], [16, 0, 6, 0], [17, 0, 4, 0]]}, "arch_few_shot_5": {"acc": 0.935999985337, "arch": [[2, 0, 2, 0], [3, 0, 6, 0], [4, 0, 5, 0], [5, 0, 5, 0], [6, 0, 6, 0], [7, 0, 5, 0], [8, 0, 3, 0], [9, 0, 2, 0], [10, 0, 5, 0], [11, 0, 3, 0], [12, 0, 3, 0], [13, 0, 2, 0], [14, 0, 6, 0], [15, 0, 1, 0], [16, 0, 1, 0], [17, 0, 1, 0]]}, "arch_few_shot_6": {"acc": 0.93649998188, "arch": [[2, 0, 6, 0], [3, 0, 6, 0], [4, 0, 3, 0], [5, 0, 3, 0], [6, 0, 4, 0], [7, 0, 1, 0], [8, 0, 4, 0], [9, 0, 3, 0], [10, 0, 4, 0], [11, 0, 4, 0], [12, 0, 6, 0], [13, 0, 6, 0], [14, 0, 3, 0], [15, 0, 4, 0], [16, 0, 6, 0], [17, 0, 5, 0]]}, "arch_few_shot_7": {"acc": 0.936799981594, "arch": [[2, 0, 3, 0], [3, 0, 6, 0], [4, 0, 6, 0], [5, 0, 6, 0], [6, 0, 1, 0], [7, 0, 5, 0], [8, 0, 1, 0], [9, 0, 4, 0], [10, 0, 4, 0], [11, 0, 1, 0], [12, 0, 5, 0], [13, 0, 6, 0], [14, 0, 4, 0], [15, 0, 3, 0], [16, 0, 2, 0], [17, 0, 4, 0]]}, "arch_few_shot_8": {"acc": 0.937099980712, "arch": [[2, 0, 3, 0], [3, 0, 3, 0], [4, 0, 4, 0], [5, 0, 3, 0], [6, 0, 1, 0], [7, 0, 5, 0], [8, 0, 6, 0], [9, 0, 4, 0], [10, 0, 3, 0], [11, 0, 5, 0], [12, 0, 2, 0], [13, 0, 5, 0], [14, 0, 3, 0], [15, 0, 6, 0], [16, 0, 1, 0], [17, 0, 5, 0]]}, "arch_few_shot_9": {"acc": 0.937299979925, "arch": [[2, 0, 4, 0], [3, 0, 6, 0], [4, 0, 1, 0], [5, 0, 5, 0], [6, 0, 2, 0], [7, 0, 6, 0], [8, 0, 1, 0], [9, 0, 3, 0], [10, 0, 3, 0], [11, 0, 6, 0], [12, 0, 2, 0], [13, 0, 4, 0], [14, 0, 4, 0], [15, 0, 3, 0], [16, 0, 1, 0], [17, 0, 1, 0]]}, "arch_few_shot_10": {"acc": 0.937499979734, "arch": [[2, 0, 4, 0], [3, 0, 2, 0], [4, 0, 1, 0], [5, 0, 2, 0], [6, 0, 6, 0], [7, 0, 6, 0], [8, 0, 1, 0], [9, 0, 3, 0], [10, 0, 6, 0], [11, 0, 3, 0], [12, 0, 5, 0], [13, 0, 1, 0], [14, 0, 6, 0], [15, 0, 5, 0], [16, 0, 6, 0], [17, 0, 5, 0]]}, "arch_few_shot_11": {"acc": 0.937599983811, "arch": [[2, 0, 3, 0], [3, 0, 5, 0], [4, 0, 6, 0], [5, 0, 4, 0], [6, 0, 5, 0], [7, 0, 5, 0], [8, 0, 6, 0], [9, 0, 6, 0], [10, 0, 6, 0], [11, 0, 1, 0], [12, 0, 2, 0], [13, 0, 3, 0], [14, 0, 5, 0], [15, 0, 4, 0], [16, 0, 1, 0], [17, 0, 2, 0]]}, "arch_few_shot_12": {"acc": 0.937799983025, "arch": [[2, 0, 2, 0], [3, 0, 5, 0], [4, 0, 4, 0], [5, 0, 3, 0], [6, 0, 3, 0], [7, 0, 5, 0], [8, 0, 6, 0], [9, 0, 2, 0], [10, 0, 2, 0], [11, 0, 2, 0], [12, 0, 1, 0], [13, 0, 1, 0], [14, 0, 4, 0], [15, 0, 4, 0], [16, 0, 3, 0], [17, 0, 6, 0]]}, "arch_few_shot_13": {"acc": 0.937999982834, "arch": [[2, 0, 2, 0], [3, 0, 6, 0], [4, 0, 4, 0], [5, 0, 2, 0], [6, 0, 1, 0], [7, 0, 1, 0], [8, 0, 5, 0], [9, 0, 4, 0], [10, 0, 3, 0], [11, 0, 2, 0], [12, 0, 6, 0], [13, 0, 6, 0], [14, 0, 6, 0], [15, 0, 1, 0], [16, 0, 4, 0], [17, 0, 4, 0]]}, "arch_few_shot_14": {"acc": 0.938199983239, "arch": [[2, 0, 1, 0], [3, 0, 6, 0], [4, 0, 1, 0], [5, 0, 2, 0], [6, 0, 2, 0], [7, 0, 1, 0], [8, 0, 2, 0], [9, 0, 6, 0], [10, 0, 4, 0], [11, 0, 3, 0], [12, 0, 6, 0], [13, 0, 3, 0], [14, 0, 1, 0], [15, 0, 1, 0], [16, 0, 6, 0], [17, 0, 1, 0]]}, "arch_few_shot_15": {"acc": 0.938399982452, "arch": [[2, 0, 4, 0], [3, 0, 6, 0], [4, 0, 1, 0], [5, 0, 2, 0], [6, 0, 1, 0], [7, 0, 1, 0], [8, 0, 5, 0], [9, 0, 6, 0], [10, 0, 5, 0], [11, 0, 3, 0], [12, 0, 3, 0], [13, 0, 4, 0], [14, 0, 1, 0], [15, 0, 6, 0], [16, 0, 4, 0], [17, 0, 3, 0]]}, "arch_few_shot_16": {"acc": 0.938699982762, "arch": [[2, 0, 1, 0], [3, 0, 3, 0], [4, 0, 5, 0], [5, 0, 1, 0], [6, 0, 6, 0], [7, 0, 4, 0], [8, 0, 4, 0], [9, 0, 4, 0], [10, 0, 4, 0], [11, 0, 3, 0], [12, 0, 2, 0], [13, 0, 3, 0], [14, 0, 6, 0], [15, 0, 2, 0], [16, 0, 5, 0], [17, 0, 2, 0]]}, "arch_few_shot_17": {"acc": 0.938899983168, "arch": [[2, 0, 2, 0], [3, 0, 5, 0], [4, 0, 3, 0], [5, 0, 6, 0], [6, 0, 3, 0], [7, 0, 5, 0], [8, 0, 4, 0], [9, 0, 3, 0], [10, 0, 1, 0], [11, 0, 2, 0], [12, 0, 5, 0], [13, 0, 5, 0], [14, 0, 6, 0], [15, 0, 6, 0], [16, 0, 3, 0], [17, 0, 5, 0]]}, "arch_few_shot_18": {"acc": 0.939099984765, "arch": [[2, 0, 4, 0], [3, 0, 3, 0], [4, 0, 4, 0], [5, 0, 6, 0], [6, 0, 6, 0], [7, 0, 3, 0], [8, 0, 3, 0], [9, 0, 4, 0], [10, 0, 6, 0], [11, 0, 4, 0], [12, 0, 1, 0], [13, 0, 5, 0], [14, 0, 5, 0], [15, 0, 3, 0], [16, 0, 2, 0], [17, 0, 1, 0]]}, "arch_few_shot_19": {"acc": 0.939399982691, "arch": [[2, 0, 5, 0], [3, 0, 5, 0], [4, 0, 6, 0], [5, 0, 1, 0], [6, 0, 6, 0], [7, 0, 6, 0], [8, 0, 4, 0], [9, 0, 3, 0], [10, 0, 1, 0], [11, 0, 2, 0], [12, 0, 2, 0], [13, 0, 5, 0], [14, 0, 5, 0], [15, 0, 4, 0], [16, 0, 3, 0], [17, 0, 6, 0]]}, "arch_few_shot_20": {"acc": 0.9395999825, "arch": [[2, 0, 1, 0], [3, 0, 5, 0], [4, 0, 6, 0], [5, 0, 4, 0], [6, 0, 2, 0], [7, 0, 2, 0], [8, 0, 4, 0], [9, 0, 2, 0], [10, 0, 3, 0], [11, 0, 3, 0], [12, 0, 5, 0], [13, 0, 1, 0], [14, 0, 6, 0], [15, 0, 5, 0], [16, 0, 2, 0], [17, 0, 3, 0]]}, "arch_few_shot_21": {"acc": 0.939799980521, "arch": [[2, 0, 2, 0], [3, 0, 4, 0], [4, 0, 5, 0], [5, 0, 1, 0], [6, 0, 2, 0], [7, 0, 2, 0], [8, 0, 2, 0], [9, 0, 3, 0], [10, 0, 2, 0], [11, 0, 3, 0], [12, 0, 2, 0], [13, 0, 5, 0], [14, 0, 6, 0], [15, 0, 1, 0], [16, 0, 5, 0], [17, 0, 2, 0]]}, "arch_few_shot_22": {"acc": 0.939999982119, "arch": [[2, 0, 2, 0], [3, 0, 3, 0], [4, 0, 3, 0], [5, 0, 3, 0], [6, 0, 2, 0], [7, 0, 1, 0], [8, 0, 6, 0], [9, 0, 5, 0], [10, 0, 4, 0], [11, 0, 5, 0], [12, 0, 6, 0], [13, 0, 5, 0], [14, 0, 4, 0], [15, 0, 6, 0], [16, 0, 3, 0], [17, 0, 2, 0]]}, "arch_few_shot_23": {"acc": 0.940299983025, "arch": [[2, 0, 2, 0], [3, 0, 3, 0], [4, 0, 5, 0], [5, 0, 1, 0], [6, 0, 6, 0], [7, 0, 2, 0], [8, 0, 1, 0], [9, 0, 3, 0], [10, 0, 4, 0], [11, 0, 5, 0], [12, 0, 2, 0], [13, 0, 5, 0], [14, 0, 1, 0], [15, 0, 2, 0], [16, 0, 1, 0], [17, 0, 4, 0]]}, "arch_few_shot_24": {"acc": 0.940499982834, "arch": [[2, 0, 5, 0], [3, 0, 1, 0], [4, 0, 6, 0], [5, 0, 2, 0], [6, 0, 4, 0], [7, 0, 5, 0], [8, 0, 1, 0], [9, 0, 1, 0], [10, 0, 6, 0], [11, 0, 4, 0], [12, 0, 5, 0], [13, 0, 1, 0], [14, 0, 4, 0], [15, 0, 6, 0], [16, 0, 5, 0], [17, 0, 5, 0]]}, "arch_few_shot_25": {"acc": 0.940799985528, "arch": [[2, 0, 3, 0], [3, 0, 4, 0], [4, 0, 6, 0], [5, 0, 2, 0], [6, 0, 2, 0], [7, 0, 3, 0], [8, 0, 5, 0], [9, 0, 5, 0], [10, 0, 5, 0], [11, 0, 2, 0], [12, 0, 4, 0], [13, 0, 3, 0], [14, 0, 6, 0], [15, 0, 3, 0], [16, 0, 3, 0], [17, 0, 2, 0]]}, "arch_few_shot_26": {"acc": 0.941099981666, "arch": [[2, 0, 3, 0], [3, 0, 2, 0], [4, 0, 4, 0], [5, 0, 2, 0], [6, 0, 2, 0], [7, 0, 2, 0], [8, 0, 5, 0], [9, 0, 2, 0], [10, 0, 5, 0], [11, 0, 1, 0], [12, 0, 5, 0], [13, 0, 3, 0], [14, 0, 5, 0], [15, 0, 3, 0], [16, 0, 5, 0], [17, 0, 6, 0]]}, "arch_few_shot_27": {"acc": 0.941499980092, "arch": [[2, 0, 1, 0], [3, 0, 1, 0], [4, 0, 3, 0], [5, 0, 5, 0], [6, 0, 2, 0], [7, 0, 1, 0], [8, 0, 2, 0], [9, 0, 4, 0], [10, 0, 6, 0], [11, 0, 3, 0], [12, 0, 1, 0], [13, 0, 4, 0], [14, 0, 5, 0], [15, 0, 5, 0], [16, 0, 5, 0], [17, 0, 6, 0]]}, "arch_few_shot_28": {"acc": 0.941899983883, "arch": [[2, 0, 2, 0], [3, 0, 4, 0], [4, 0, 6, 0], [5, 0, 4, 0], [6, 0, 4, 0], [7, 0, 5, 0], [8, 0, 6, 0], [9, 0, 6, 0], [10, 0, 6, 0], [11, 0, 1, 0], [12, 0, 6, 0], [13, 0, 1, 0], [14, 0, 2, 0], [15, 0, 6, 0], [16, 0, 5, 0], [17, 0, 2, 0]]}, "arch_few_shot_29": {"acc": 0.942299984694, "arch": [[2, 0, 3, 0], [3, 0, 5, 0], [4, 0, 4, 0], [5, 0, 5, 0], [6, 0, 2, 0], [7, 0, 3, 0], [8, 0, 2, 0], [9, 0, 4, 0], [10, 0, 5, 0], [11, 0, 3, 0], [12, 0, 4, 0], [13, 0, 2, 0], [14, 0, 2, 0], [15, 0, 1, 0], [16, 0, 2, 0], [17, 0, 6, 0]]}, "arch_few_shot_30": {"acc": 0.942899983525, "arch": [[2, 0, 1, 0], [3, 0, 4, 0], [4, 0, 5, 0], [5, 0, 3, 0], [6, 0, 6, 0], [7, 0, 4, 0], [8, 0, 5, 0], [9, 0, 1, 0], [10, 0, 2, 0], [11, 0, 3, 0], [12, 0, 3, 0], [13, 0, 6, 0], [14, 0, 2, 0], [15, 0, 3, 0], [16, 0, 5, 0], [17, 0, 2, 0]]}, "arch_few_shot_31": {"acc": 0.944699983001, "arch": [[2, 0, 6, 0], [3, 0, 6, 0], [4, 0, 4, 0], [5, 0, 2, 0], [6, 0, 1, 0], [7, 0, 6, 0], [8, 0, 2, 0], [9, 0, 2, 0], [10, 0, 6, 0], [11, 0, 4, 0], [12, 0, 4, 0], [13, 0, 3, 0], [14, 0, 3, 0], [15, 0, 2, 0], [16, 0, 1, 0], [17, 0, 1, 0]]}}
\ No newline at end of file
...@@ -20,9 +20,11 @@ from .rl_nas import * ...@@ -20,9 +20,11 @@ from .rl_nas import *
from ..nas import darts from ..nas import darts
from .darts import * from .darts import *
from .ofa import * from .ofa import *
from .gp_nas import *
__all__ = [] __all__ = []
__all__ += sa_nas.__all__ __all__ += sa_nas.__all__
__all__ += search_space.__all__ __all__ += search_space.__all__
__all__ += rl_nas.__all__ __all__ += rl_nas.__all__
__all__ += darts.__all__ __all__ += darts.__all__
__all__ += gp_nas.__all__
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import copy
__all__ = ["GPNAS"]
class GPNAS(object):
"""
GPNAS(Gaussian Process based Neural Architecture Search) is a neural architecture search algorithm.
We model the correlation between architectue and performance from a Bayesian perspective. Specifically, by introducing a novel Gaussian Process based
NAS (GP-NAS) method, the correlations are modeled by the kernel function and mean function. The kernel function is also learnable to enable adaptive modeling for complex
correlations in different search spaces. Furthermore, by in-corporating a mutual information based sampling method, we can theoretically ensure the high-performance
architecture with only a small set of samples. After addressing these problems, training GP-NAS once enables direct performance prediction of any architecture in different
scenarios and may obtain efficient networks for different deployment platforms.
"""
def __init__(self, c_flag=2, m_flag=2):
self.hp_mat = 0.0000001
self.hp_cov = 0.01
self.cov_w = None
self.w = None
self.c_flag = c_flag
self.m_flag = m_flag
def _get_corelation(self, mat1, mat2):
"""
give two typical kernel function
Auto kernel hyperparameters estimation to be updated
"""
mat_diff = abs(mat1 - mat2)
if self.c_flag == 1:
return 0.5 * np.exp(-np.dot(mat_diff, mat_diff) / 16)
elif self.c_flag == 2:
return 1 * np.exp(-np.sqrt(np.dot(mat_diff, mat_diff)) / 12)
def _preprocess_X(self, X):
"""
preprocess of input feature/ tokens of architecture
more complicated preprocess can be added such as nonlineaer transformation
"""
X = X.tolist()
p_X = copy.deepcopy(X)
for feature in p_X:
feature.append(1)
return p_X
def _get_cor_mat(self, X):
"""
get kernel matrix
"""
X = np.array(X)
l = X.shape[0]
cor_mat = []
for c_idx in range(l):
col = []
c_mat = X[c_idx].copy()
for r_idx in range(l):
r_mat = X[r_idx].copy()
temp_cor = self._get_corelation(c_mat, r_mat)
col.append(temp_cor)
cor_mat.append(col)
return np.mat(cor_mat)
def _get_cor_mat_joint(self, X, X_train):
"""
get kernel matrix
"""
X = np.array(X)
X_train = np.array(X_train)
l_c = X.shape[0]
l_r = X_train.shape[0]
cor_mat = []
for c_idx in range(l_c):
col = []
c_mat = X[c_idx].copy()
for r_idx in range(l_r):
r_mat = X_train[r_idx].copy()
temp_cor = self._get_corelation(c_mat, r_mat)
col.append(temp_cor)
cor_mat.append(col)
return np.mat(cor_mat)
def get_predict(self, X):
"""
get the prediction of network architecture X
"""
X = self._preprocess_X(X)
X = np.mat(X)
return X * self.w
def get_predict_jiont(self, X, X_train, Y_train):
"""
get the prediction of network architecture X based on X_train and Y_train
"""
X = np.mat(X)
X_train = np.mat(X_train)
Y_train = np.mat(Y_train)
m_X = self.get_predict(X)
m_X_train = self.get_predict(X_train)
mat_train = self._get_cor_mat(X_train)
mat_joint = self._get_cor_mat_joint(X, X_train)
return m_X + mat_joint * np.linalg.inv(mat_train + self.hp_mat * np.eye(
X_train.shape[0])) * (Y_train.T - m_X_train)
def get_initial_mean(self, X, Y):
"""
get initial mean of w
"""
X = self._preprocess_X(X)
X = np.mat(X)
Y = np.mat(Y)
self.w = np.linalg.inv(X.T * X + self.hp_mat * np.eye(X.shape[
1])) * X.T * Y.T
return self.w
def get_initial_cov(self, X):
"""
get initial coviarnce matrix of w
"""
X = self._preprocess_X(X)
X = np.mat(X)
self.cov_w = self.hp_cov * np.eye(X.shape[1])
return self.cov_w
def get_posterior_mean(self, X, Y):
"""
get posterior mean of w
"""
X = self._preprocess_X(X)
X = np.mat(X)
Y = np.mat(Y)
cov_mat = self._get_cor_mat(X)
if self.m_flag == 1:
self.w = self.w + self.cov_w * X.T * np.linalg.inv(
np.linalg.inv(cov_mat + self.hp_mat * np.eye(X.shape[0])) + X *
self.cov_w * X.T + self.hp_mat * np.eye(X.shape[0])) * (
Y.T - X * self.w)
else:
self.w = np.linalg.inv(X.T * np.linalg.inv(
cov_mat + self.hp_mat * np.eye(X.shape[0])) * X + np.linalg.inv(
self.cov_w + self.hp_mat * np.eye(X.shape[
1])) + self.hp_mat * np.eye(X.shape[1])) * (
X.T * np.linalg.inv(cov_mat + self.hp_mat * np.eye(
X.shape[0])) * Y.T +
np.linalg.inv(self.cov_w + self.hp_mat * np.eye(
X.shape[1])) * self.w)
return self.w
def get_posterior_cov(self, X, Y):
"""
get posterior coviarnce matrix of w
"""
X = self._preprocess_X(X)
X = np.mat(X)
Y = np.mat(Y)
cov_mat = self._get_cor_mat(X)
self.cov_mat = np.linalg.inv(
np.linalg.inv(X.T * cov_mat * X + self.hp_mat * np.eye(X.shape[1]))
+ np.linalg.inv(self.cov_w + self.hp_mat * np.eye(X.shape[
1])) + self.hp_mat * np.eye(X.shape[1]))
return self.cov_mat
此差异已折叠。
{"arch_few_shot_1": {"acc": 0.929999983907, "arch": [[2, 0, 1, 0], [3, 0, 5, 0], [4, 0, 5, 0], [5, 0, 5, 0], [6, 0, 3, 0], [7, 0, 3, 0], [8, 0, 1, 0], [9, 0, 3, 0], [10, 0, 1, 0], [11, 0, 5, 0], [12, 0, 6, 0], [13, 0, 6, 0], [14, 0, 6, 0], [15, 0, 4, 0], [16, 0, 5, 0], [17, 0, 4, 0]]}, "arch_few_shot_2": {"acc": 0.934199981689, "arch": [[2, 0, 5, 0], [3, 0, 3, 0], [4, 0, 2, 0], [5, 0, 1, 0], [6, 0, 5, 0], [7, 0, 3, 0], [8, 0, 5, 0], [9, 0, 4, 0], [10, 0, 1, 0], [11, 0, 2, 0], [12, 0, 3, 0], [13, 0, 1, 0], [14, 0, 3, 0], [15, 0, 4, 0], [16, 0, 3, 0], [17, 0, 4, 0]]}, "arch_few_shot_3": {"acc": 0.934699984789, "arch": [[2, 0, 4, 0], [3, 0, 3, 0], [4, 0, 6, 0], [5, 0, 3, 0], [6, 0, 6, 0], [7, 0, 1, 0], [8, 0, 3, 0], [9, 0, 6, 0], [10, 0, 2, 0], [11, 0, 5, 0], [12, 0, 3, 0], [13, 0, 3, 0], [14, 0, 6, 0], [15, 0, 6, 0], [16, 0, 3, 0], [17, 0, 1, 0]]}, "arch_few_shot_4": {"acc": 0.935499984026, "arch": [[2, 0, 5, 0], [3, 0, 2, 0], [4, 0, 2, 0], [5, 0, 5, 0], [6, 0, 6, 0], [7, 0, 1, 0], [8, 0, 3, 0], [9, 0, 5, 0], [10, 0, 6, 0], [11, 0, 5, 0], [12, 0, 4, 0], [13, 0, 2, 0], [14, 0, 2, 0], [15, 0, 4, 0], [16, 0, 6, 0], [17, 0, 4, 0]]}, "arch_few_shot_5": {"acc": 0.935999985337, "arch": [[2, 0, 2, 0], [3, 0, 6, 0], [4, 0, 5, 0], [5, 0, 5, 0], [6, 0, 6, 0], [7, 0, 5, 0], [8, 0, 3, 0], [9, 0, 2, 0], [10, 0, 5, 0], [11, 0, 3, 0], [12, 0, 3, 0], [13, 0, 2, 0], [14, 0, 6, 0], [15, 0, 1, 0], [16, 0, 1, 0], [17, 0, 1, 0]]}, "arch_few_shot_6": {"acc": 0.93649998188, "arch": [[2, 0, 6, 0], [3, 0, 6, 0], [4, 0, 3, 0], [5, 0, 3, 0], [6, 0, 4, 0], [7, 0, 1, 0], [8, 0, 4, 0], [9, 0, 3, 0], [10, 0, 4, 0], [11, 0, 4, 0], [12, 0, 6, 0], [13, 0, 6, 0], [14, 0, 3, 0], [15, 0, 4, 0], [16, 0, 6, 0], [17, 0, 5, 0]]}, "arch_few_shot_7": {"acc": 0.936799981594, "arch": [[2, 0, 3, 0], [3, 0, 6, 0], [4, 0, 6, 0], [5, 0, 6, 0], [6, 0, 1, 0], [7, 0, 5, 0], [8, 0, 1, 0], [9, 0, 4, 0], [10, 0, 4, 0], [11, 0, 1, 0], [12, 0, 5, 0], [13, 0, 6, 0], [14, 0, 4, 0], [15, 0, 3, 0], [16, 0, 2, 0], [17, 0, 4, 0]]}, "arch_few_shot_8": {"acc": 0.937099980712, "arch": [[2, 0, 3, 0], [3, 0, 3, 0], [4, 0, 4, 0], [5, 0, 3, 0], [6, 0, 1, 0], [7, 0, 5, 0], [8, 0, 6, 0], [9, 0, 4, 0], [10, 0, 3, 0], [11, 0, 5, 0], [12, 0, 2, 0], [13, 0, 5, 0], [14, 0, 3, 0], [15, 0, 6, 0], [16, 0, 1, 0], [17, 0, 5, 0]]}, "arch_few_shot_9": {"acc": 0.937299979925, "arch": [[2, 0, 4, 0], [3, 0, 6, 0], [4, 0, 1, 0], [5, 0, 5, 0], [6, 0, 2, 0], [7, 0, 6, 0], [8, 0, 1, 0], [9, 0, 3, 0], [10, 0, 3, 0], [11, 0, 6, 0], [12, 0, 2, 0], [13, 0, 4, 0], [14, 0, 4, 0], [15, 0, 3, 0], [16, 0, 1, 0], [17, 0, 1, 0]]}, "arch_few_shot_10": {"acc": 0.937499979734, "arch": [[2, 0, 4, 0], [3, 0, 2, 0], [4, 0, 1, 0], [5, 0, 2, 0], [6, 0, 6, 0], [7, 0, 6, 0], [8, 0, 1, 0], [9, 0, 3, 0], [10, 0, 6, 0], [11, 0, 3, 0], [12, 0, 5, 0], [13, 0, 1, 0], [14, 0, 6, 0], [15, 0, 5, 0], [16, 0, 6, 0], [17, 0, 5, 0]]}, "arch_few_shot_11": {"acc": 0.937599983811, "arch": [[2, 0, 3, 0], [3, 0, 5, 0], [4, 0, 6, 0], [5, 0, 4, 0], [6, 0, 5, 0], [7, 0, 5, 0], [8, 0, 6, 0], [9, 0, 6, 0], [10, 0, 6, 0], [11, 0, 1, 0], [12, 0, 2, 0], [13, 0, 3, 0], [14, 0, 5, 0], [15, 0, 4, 0], [16, 0, 1, 0], [17, 0, 2, 0]]}, "arch_few_shot_12": {"acc": 0.937799983025, "arch": [[2, 0, 2, 0], [3, 0, 5, 0], [4, 0, 4, 0], [5, 0, 3, 0], [6, 0, 3, 0], [7, 0, 5, 0], [8, 0, 6, 0], [9, 0, 2, 0], [10, 0, 2, 0], [11, 0, 2, 0], [12, 0, 1, 0], [13, 0, 1, 0], [14, 0, 4, 0], [15, 0, 4, 0], [16, 0, 3, 0], [17, 0, 6, 0]]}, "arch_few_shot_13": {"acc": 0.937999982834, "arch": [[2, 0, 2, 0], [3, 0, 6, 0], [4, 0, 4, 0], [5, 0, 2, 0], [6, 0, 1, 0], [7, 0, 1, 0], [8, 0, 5, 0], [9, 0, 4, 0], [10, 0, 3, 0], [11, 0, 2, 0], [12, 0, 6, 0], [13, 0, 6, 0], [14, 0, 6, 0], [15, 0, 1, 0], [16, 0, 4, 0], [17, 0, 4, 0]]}, "arch_few_shot_14": {"acc": 0.938199983239, "arch": [[2, 0, 1, 0], [3, 0, 6, 0], [4, 0, 1, 0], [5, 0, 2, 0], [6, 0, 2, 0], [7, 0, 1, 0], [8, 0, 2, 0], [9, 0, 6, 0], [10, 0, 4, 0], [11, 0, 3, 0], [12, 0, 6, 0], [13, 0, 3, 0], [14, 0, 1, 0], [15, 0, 1, 0], [16, 0, 6, 0], [17, 0, 1, 0]]}, "arch_few_shot_15": {"acc": 0.938399982452, "arch": [[2, 0, 4, 0], [3, 0, 6, 0], [4, 0, 1, 0], [5, 0, 2, 0], [6, 0, 1, 0], [7, 0, 1, 0], [8, 0, 5, 0], [9, 0, 6, 0], [10, 0, 5, 0], [11, 0, 3, 0], [12, 0, 3, 0], [13, 0, 4, 0], [14, 0, 1, 0], [15, 0, 6, 0], [16, 0, 4, 0], [17, 0, 3, 0]]}, "arch_few_shot_16": {"acc": 0.938699982762, "arch": [[2, 0, 1, 0], [3, 0, 3, 0], [4, 0, 5, 0], [5, 0, 1, 0], [6, 0, 6, 0], [7, 0, 4, 0], [8, 0, 4, 0], [9, 0, 4, 0], [10, 0, 4, 0], [11, 0, 3, 0], [12, 0, 2, 0], [13, 0, 3, 0], [14, 0, 6, 0], [15, 0, 2, 0], [16, 0, 5, 0], [17, 0, 2, 0]]}, "arch_few_shot_17": {"acc": 0.938899983168, "arch": [[2, 0, 2, 0], [3, 0, 5, 0], [4, 0, 3, 0], [5, 0, 6, 0], [6, 0, 3, 0], [7, 0, 5, 0], [8, 0, 4, 0], [9, 0, 3, 0], [10, 0, 1, 0], [11, 0, 2, 0], [12, 0, 5, 0], [13, 0, 5, 0], [14, 0, 6, 0], [15, 0, 6, 0], [16, 0, 3, 0], [17, 0, 5, 0]]}, "arch_few_shot_18": {"acc": 0.939099984765, "arch": [[2, 0, 4, 0], [3, 0, 3, 0], [4, 0, 4, 0], [5, 0, 6, 0], [6, 0, 6, 0], [7, 0, 3, 0], [8, 0, 3, 0], [9, 0, 4, 0], [10, 0, 6, 0], [11, 0, 4, 0], [12, 0, 1, 0], [13, 0, 5, 0], [14, 0, 5, 0], [15, 0, 3, 0], [16, 0, 2, 0], [17, 0, 1, 0]]}, "arch_few_shot_19": {"acc": 0.939399982691, "arch": [[2, 0, 5, 0], [3, 0, 5, 0], [4, 0, 6, 0], [5, 0, 1, 0], [6, 0, 6, 0], [7, 0, 6, 0], [8, 0, 4, 0], [9, 0, 3, 0], [10, 0, 1, 0], [11, 0, 2, 0], [12, 0, 2, 0], [13, 0, 5, 0], [14, 0, 5, 0], [15, 0, 4, 0], [16, 0, 3, 0], [17, 0, 6, 0]]}, "arch_few_shot_20": {"acc": 0.9395999825, "arch": [[2, 0, 1, 0], [3, 0, 5, 0], [4, 0, 6, 0], [5, 0, 4, 0], [6, 0, 2, 0], [7, 0, 2, 0], [8, 0, 4, 0], [9, 0, 2, 0], [10, 0, 3, 0], [11, 0, 3, 0], [12, 0, 5, 0], [13, 0, 1, 0], [14, 0, 6, 0], [15, 0, 5, 0], [16, 0, 2, 0], [17, 0, 3, 0]]}, "arch_few_shot_21": {"acc": 0.939799980521, "arch": [[2, 0, 2, 0], [3, 0, 4, 0], [4, 0, 5, 0], [5, 0, 1, 0], [6, 0, 2, 0], [7, 0, 2, 0], [8, 0, 2, 0], [9, 0, 3, 0], [10, 0, 2, 0], [11, 0, 3, 0], [12, 0, 2, 0], [13, 0, 5, 0], [14, 0, 6, 0], [15, 0, 1, 0], [16, 0, 5, 0], [17, 0, 2, 0]]}, "arch_few_shot_22": {"acc": 0.939999982119, "arch": [[2, 0, 2, 0], [3, 0, 3, 0], [4, 0, 3, 0], [5, 0, 3, 0], [6, 0, 2, 0], [7, 0, 1, 0], [8, 0, 6, 0], [9, 0, 5, 0], [10, 0, 4, 0], [11, 0, 5, 0], [12, 0, 6, 0], [13, 0, 5, 0], [14, 0, 4, 0], [15, 0, 6, 0], [16, 0, 3, 0], [17, 0, 2, 0]]}, "arch_few_shot_23": {"acc": 0.940299983025, "arch": [[2, 0, 2, 0], [3, 0, 3, 0], [4, 0, 5, 0], [5, 0, 1, 0], [6, 0, 6, 0], [7, 0, 2, 0], [8, 0, 1, 0], [9, 0, 3, 0], [10, 0, 4, 0], [11, 0, 5, 0], [12, 0, 2, 0], [13, 0, 5, 0], [14, 0, 1, 0], [15, 0, 2, 0], [16, 0, 1, 0], [17, 0, 4, 0]]}, "arch_few_shot_24": {"acc": 0.940499982834, "arch": [[2, 0, 5, 0], [3, 0, 1, 0], [4, 0, 6, 0], [5, 0, 2, 0], [6, 0, 4, 0], [7, 0, 5, 0], [8, 0, 1, 0], [9, 0, 1, 0], [10, 0, 6, 0], [11, 0, 4, 0], [12, 0, 5, 0], [13, 0, 1, 0], [14, 0, 4, 0], [15, 0, 6, 0], [16, 0, 5, 0], [17, 0, 5, 0]]}, "arch_few_shot_25": {"acc": 0.940799985528, "arch": [[2, 0, 3, 0], [3, 0, 4, 0], [4, 0, 6, 0], [5, 0, 2, 0], [6, 0, 2, 0], [7, 0, 3, 0], [8, 0, 5, 0], [9, 0, 5, 0], [10, 0, 5, 0], [11, 0, 2, 0], [12, 0, 4, 0], [13, 0, 3, 0], [14, 0, 6, 0], [15, 0, 3, 0], [16, 0, 3, 0], [17, 0, 2, 0]]}, "arch_few_shot_26": {"acc": 0.941099981666, "arch": [[2, 0, 3, 0], [3, 0, 2, 0], [4, 0, 4, 0], [5, 0, 2, 0], [6, 0, 2, 0], [7, 0, 2, 0], [8, 0, 5, 0], [9, 0, 2, 0], [10, 0, 5, 0], [11, 0, 1, 0], [12, 0, 5, 0], [13, 0, 3, 0], [14, 0, 5, 0], [15, 0, 3, 0], [16, 0, 5, 0], [17, 0, 6, 0]]}, "arch_few_shot_27": {"acc": 0.941499980092, "arch": [[2, 0, 1, 0], [3, 0, 1, 0], [4, 0, 3, 0], [5, 0, 5, 0], [6, 0, 2, 0], [7, 0, 1, 0], [8, 0, 2, 0], [9, 0, 4, 0], [10, 0, 6, 0], [11, 0, 3, 0], [12, 0, 1, 0], [13, 0, 4, 0], [14, 0, 5, 0], [15, 0, 5, 0], [16, 0, 5, 0], [17, 0, 6, 0]]}, "arch_few_shot_28": {"acc": 0.941899983883, "arch": [[2, 0, 2, 0], [3, 0, 4, 0], [4, 0, 6, 0], [5, 0, 4, 0], [6, 0, 4, 0], [7, 0, 5, 0], [8, 0, 6, 0], [9, 0, 6, 0], [10, 0, 6, 0], [11, 0, 1, 0], [12, 0, 6, 0], [13, 0, 1, 0], [14, 0, 2, 0], [15, 0, 6, 0], [16, 0, 5, 0], [17, 0, 2, 0]]}, "arch_few_shot_29": {"acc": 0.942299984694, "arch": [[2, 0, 3, 0], [3, 0, 5, 0], [4, 0, 4, 0], [5, 0, 5, 0], [6, 0, 2, 0], [7, 0, 3, 0], [8, 0, 2, 0], [9, 0, 4, 0], [10, 0, 5, 0], [11, 0, 3, 0], [12, 0, 4, 0], [13, 0, 2, 0], [14, 0, 2, 0], [15, 0, 1, 0], [16, 0, 2, 0], [17, 0, 6, 0]]}, "arch_few_shot_30": {"acc": 0.942899983525, "arch": [[2, 0, 1, 0], [3, 0, 4, 0], [4, 0, 5, 0], [5, 0, 3, 0], [6, 0, 6, 0], [7, 0, 4, 0], [8, 0, 5, 0], [9, 0, 1, 0], [10, 0, 2, 0], [11, 0, 3, 0], [12, 0, 3, 0], [13, 0, 6, 0], [14, 0, 2, 0], [15, 0, 3, 0], [16, 0, 5, 0], [17, 0, 2, 0]]}, "arch_few_shot_31": {"acc": 0.944699983001, "arch": [[2, 0, 6, 0], [3, 0, 6, 0], [4, 0, 4, 0], [5, 0, 2, 0], [6, 0, 1, 0], [7, 0, 6, 0], [8, 0, 2, 0], [9, 0, 2, 0], [10, 0, 6, 0], [11, 0, 4, 0], [12, 0, 4, 0], [13, 0, 3, 0], [14, 0, 3, 0], [15, 0, 2, 0], [16, 0, 1, 0], [17, 0, 1, 0]]}}
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import json
import copy
import unittest
import numpy as np
from paddleslim.nas import GPNAS
from static_case import StaticCase
# 使用GP-NAS参加[CVPR 2021 NAS国际比赛](https://www.cvpr21-nas.com/competition) Track2 demo
# [CVPR 2021 NAS国际比赛Track2 studio地址](https://aistudio.baidu.com/aistudio/competition/detail/71?lang=en)
# [AI studio GP-NAS demo](https://aistudio.baidu.com/aistudio/projectdetail/1824958)
# demo 基于paddleslim自研NAS算法GP-NAS:Gaussian Process based Neural Architecture Search
# 基于本demo的改进版可以获得双倍奖金
class TestGPNAS(StaticCase):
def test_gpnas(self):
def preprare_trainning_data(file_name, t_flag):
## t_flag ==1 using all trainning data
## t_flag ==2 using half trainning data
with open(file_name, 'r') as f:
arch_dict = json.load(f)
Y_all = []
X_all = []
for sub_dict in arch_dict.items():
Y_all.append(sub_dict[1]['acc'] * 100)
X_all.append(np.array(sub_dict[1]['arch']).T.reshape(4, 16)[2])
X_all, Y_all = np.array(X_all), np.array(Y_all)
X_train, Y_train, X_test, Y_test = X_all[0::t_flag], Y_all[
0::t_flag], X_all[1::t_flag], Y_all[1::t_flag]
return X_train, Y_train, X_test, Y_test
stage1_file = './Track2_stage1_trainning.json'
stage2_file = './Track2_stage2_few_show_trainning.json'
X_train_stage1, Y_train_stage1, X_test_stage1, Y_test_stage1 = preprare_trainning_data(
stage1_file, 1)
X_train_stage2, Y_train_stage2, X_test_stage2, Y_test_stage2 = preprare_trainning_data(
stage2_file, 2)
gpnas = GPNAS(1, 1)
w = gpnas.get_initial_mean(X_test_stage1, Y_test_stage1)
init_cov = gpnas.get_initial_cov(X_train_stage1)
error_list = np.array(
Y_test_stage2.reshape(len(Y_test_stage2), 1) - gpnas.get_predict(
X_test_stage2))
print('RMSE trainning on stage1 testing on stage2:',
np.sqrt(np.dot(error_list.T, error_list) / len(error_list)))
gpnas.get_posterior_mean(X_train_stage2[0::3], Y_train_stage2[0::3])
gpnas.get_posterior_mean(X_train_stage2[1::3], Y_train_stage2[1::3])
gpnas.get_posterior_cov(X_train_stage2[1::3], Y_train_stage2[1::3])
error_list = np.array(
Y_test_stage2.reshape(len(Y_test_stage2), 1) -
gpnas.get_predict_jiont(X_test_stage2, X_train_stage2[::1],
Y_train_stage2[::1]))
print('RMSE using stage1 as prior:',
np.sqrt(np.dot(error_list.T, error_list) / len(error_list)))
gpnas = GPNAS(2, 2)
w = gpnas.get_initial_mean(X_test_stage1, Y_test_stage1)
init_cov = gpnas.get_initial_cov(X_train_stage1)
error_list = np.array(
Y_test_stage2.reshape(len(Y_test_stage2), 1) - gpnas.get_predict(
X_test_stage2))
print('RMSE trainning on stage1 testing on stage2:',
np.sqrt(np.dot(error_list.T, error_list) / len(error_list)))
gpnas.get_posterior_mean(X_train_stage2[0::3], Y_train_stage2[0::3])
gpnas.get_posterior_mean(X_train_stage2[1::3], Y_train_stage2[1::3])
gpnas.get_posterior_cov(X_train_stage2[1::3], Y_train_stage2[1::3])
error_list = np.array(
Y_test_stage2.reshape(len(Y_test_stage2), 1) -
gpnas.get_predict_jiont(X_test_stage2, X_train_stage2[::1],
Y_train_stage2[::1]))
print('RMSE using stage1 as prior:',
np.sqrt(np.dot(error_list.T, error_list) / len(error_list)))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册