diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 70a4d7b40b154ff80ff6d30adaa147556749e905..cfe6730e0ca96020932880fd11b292469349fdf4 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -222,6 +222,7 @@ paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label' paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '431a4301c35032166ec029f7432c80a7')) paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607')) paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329')) +paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc')) paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e')) diff --git a/paddle/fluid/operators/fsp_op.cc b/paddle/fluid/operators/fsp_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fbe8e56a6160219175bd573a2ff186eb35e56fdf --- /dev/null +++ b/paddle/fluid/operators/fsp_op.cc @@ -0,0 +1,128 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/fsp_op.h" + +namespace paddle { +namespace operators { + +class FSPOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FSPOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of FSPOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FSPOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE( + x_dims.size() == 4, + "The Input(X) must have shape [batch_size, channel, height, width]."); + PADDLE_ENFORCE( + y_dims.size() == 4, + "The Input(Y) must have shape [batch_size, channel, height, width]."); + PADDLE_ENFORCE( + (x_dims[2] == y_dims[2]) && (x_dims[3] == y_dims[3]), + "The Input(X) and Input(Y) should have the same height and width."); + + ctx->SetOutputDim("Out", {x_dims[0], x_dims[1], y_dims[1]}); + ctx->ShareLoD("X", "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; + return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), + ctx.device_context(), layout_, library_); + } +}; + +class FSPOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor) The input of FSP op with shape [batch_size, x_channel, " + "height, width]"); + AddInput("Y", + "(Tensor) The input of FSP op with shape " + "[batch_size, y_channel, height, width]. " + "The y_channel can be different from the x_channel of Input(X)" + " while the other dimensions must be the same as Input(X)'s."); + AddOutput( + "Out", + "(Tensor) The output of FSP op with shape " + "[batch_size, x_channel, y_channel]. The x_channel is the channel " + "of Input(X) and the y_channel is the channel of Input(Y)."); + AddComment(R"DOC( + This op is used to calculate the flow of solution procedure (FSP) matrix of two feature maps. + Given feature map x with shape [x_channel, h, w] and feature map y with shape + [y_channel, h, w], we can get the fsp matrix of x and y in two steps: + + step 1: reshape x into a matrix with shape [x_channel, h * w] and reshape and + transpose y into a matrix with shape [h * w, y_channel] + step 2: multiply x and y to get the fsp matrix with shape [x_channel, y_channel] + + The output is a batch of fsp matrices.
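+ For example, given x with shape [2, 16, 32, 31] and y with shape [2, 28, 32, 31], + the output is a batch of 2 fsp matrices, each with shape [16, 28].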
+ )DOC"); + } +}; + +class FSPOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fsp, ops::FSPOp, ops::FSPOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad); +REGISTER_OP_CPU_KERNEL( + fsp, ops::FSPOpKernel, + ops::FSPOpKernel); +REGISTER_OP_CPU_KERNEL( + fsp_grad, ops::FSPGradOpKernel, + ops::FSPGradOpKernel); diff --git a/paddle/fluid/operators/fsp_op.cu b/paddle/fluid/operators/fsp_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..4fd7ba04ff9af1806963427ad58c68fc216e82ac --- /dev/null +++ b/paddle/fluid/operators/fsp_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/fsp_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fsp, ops::FSPOpKernel, + ops::FSPOpKernel); +REGISTER_OP_CUDA_KERNEL(fsp_grad, + ops::FSPGradOpKernel, + ops::FSPGradOpKernel); diff --git a/paddle/fluid/operators/fsp_op.h b/paddle/fluid/operators/fsp_op.h new file mode 100644 index 0000000000000000000000000000000000000000..544af2b7d9b9729fe5dce08793da6c983fbcc6fa --- /dev/null +++ b/paddle/fluid/operators/fsp_op.h @@ -0,0 +1,136 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template <typename DeviceContext, typename T> +class FSPOpKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input<Tensor>("X"); + auto* y = context.Input<Tensor>("Y"); + auto* output = context.Output<Tensor>("Out"); + output->mutable_data<T>(context.GetPlace()); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + + auto batch_size = x_dims[0]; + auto x_channel = x_dims[1]; + auto y_channel = y_dims[1]; + auto height = x_dims[2]; + auto width = x_dims[3]; + + auto blas = math::GetBlas<DeviceContext, T>(context); + + math::MatDescriptor x_mat_desc; + x_mat_desc.height_ = x_channel; + x_mat_desc.width_ = height * width; + x_mat_desc.batch_size_ = batch_size; + x_mat_desc.stride_ = x_channel * height * width; + + math::MatDescriptor y_mat_desc; + y_mat_desc.height_ = height * width; + y_mat_desc.width_ = y_channel; + y_mat_desc.batch_size_ = batch_size; + y_mat_desc.stride_ = y_channel * height * width; + y_mat_desc.trans_ = true; + + blas.MatMul(*x, x_mat_desc, *y, y_mat_desc, + static_cast<T>(1.0 / (height * width)), output, + static_cast<T>(0.0)); + } +}; + +template <typename DeviceContext, typename T> +class FSPGradOpKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* d_x = context.Output<Tensor>(framework::GradVarName("X")); + auto* d_y = context.Output<Tensor>(framework::GradVarName("Y")); + if (d_x == nullptr && d_y == nullptr) { + return; + } + auto* d_out = context.Input<Tensor>(framework::GradVarName("Out")); + auto d_out_dims = d_out->dims(); + auto batch_size = d_out_dims[0]; + auto x_channel = d_out_dims[1]; + auto y_channel = d_out_dims[2]; + int64_t h = 0; + int64_t w = 0; + + auto blas = math::GetBlas<DeviceContext, T>(context); + math::SetConstant<DeviceContext, T> set_zero; + if (d_x != nullptr) { + d_x->mutable_data<T>(context.GetPlace()); + set_zero(context.template device_context<DeviceContext>(), d_x, + static_cast<T>(0)); + auto* y = context.Input<Tensor>("Y"); + auto y_dims = y->dims(); + h = y_dims[2]; + w = y_dims[3]; + + math::MatDescriptor d_out_mat_desc; + d_out_mat_desc.height_ = x_channel; + d_out_mat_desc.width_ = y_channel; + d_out_mat_desc.batch_size_ = batch_size; + d_out_mat_desc.stride_ = x_channel * y_channel; + + math::MatDescriptor y_mat_desc; + y_mat_desc.height_ = y_channel; + y_mat_desc.width_ = h * w; + y_mat_desc.batch_size_ = batch_size; + y_mat_desc.stride_ = y_channel * h * w; + + blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc, + static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0)); + } + + if (d_y != nullptr) { + d_y->mutable_data<T>(context.GetPlace()); + set_zero(context.template device_context<DeviceContext>(), d_y, + static_cast<T>(0)); + auto* x = context.Input<Tensor>("X"); + auto x_dims = x->dims(); + h = x_dims[2]; + w = x_dims[3]; + + math::MatDescriptor d_out_mat_desc; + d_out_mat_desc.height_ = y_channel; + d_out_mat_desc.width_ = x_channel; + d_out_mat_desc.batch_size_ = batch_size; + d_out_mat_desc.stride_ = x_channel * y_channel; + d_out_mat_desc.trans_ = true; + + math::MatDescriptor x_mat_desc; + x_mat_desc.height_ = x_channel; + x_mat_desc.width_ = h * w; + x_mat_desc.batch_size_ = batch_size; + x_mat_desc.stride_ = x_channel * h * w; + + blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc, + static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git
a/python/paddle/fluid/contrib/slim/core/compressor.py b/python/paddle/fluid/contrib/slim/core/compressor.py index 832ade497c67ee16b6068cad4f0edace94128989..1547b6abbe660b6be7a681a4e270e3080a5dac36 100644 --- a/python/paddle/fluid/contrib/slim/core/compressor.py +++ b/python/paddle/fluid/contrib/slim/core/compressor.py @@ -271,7 +271,7 @@ class Compressor(object): self.eval_reader = eval_reader self.teacher_graphs = [] for teacher in teacher_programs: - self.teacher_graphs.append(ImitationGraph(teacher, scope=scope)) + self.teacher_graphs.append(GraphWrapper(teacher)) self.checkpoint = None self.checkpoint_path = checkpoint_path diff --git a/python/paddle/fluid/contrib/slim/core/config.py b/python/paddle/fluid/contrib/slim/core/config.py index 12df9fcd1b0042c26aabac88d6ecba5fb827cba0..9bb395aee95b5236850ca51096ed870ab1d27b62 100644 --- a/python/paddle/fluid/contrib/slim/core/config.py +++ b/python/paddle/fluid/contrib/slim/core/config.py @@ -19,6 +19,7 @@ from collections import OrderedDict from ..prune import * from ..quantization import * from .strategy import * +from ..distillation import * __all__ = ['ConfigFactory'] """This factory is used to create instances by loading and parsing configure file with yaml format. diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/__init__.py b/python/paddle/fluid/contrib/slim/distillation/__init__.py similarity index 68% rename from python/paddle/fluid/contrib/slim/tests/filter_pruning/__init__.py rename to python/paddle/fluid/contrib/slim/distillation/__init__.py index d0c32e26092f6ea25771279418582a24ea449ab2..455c7c563318daec42892e71dcf0a48f22f376a1 100644 --- a/python/paddle/fluid/contrib/slim/tests/filter_pruning/__init__.py +++ b/python/paddle/fluid/contrib/slim/distillation/__init__.py @@ -11,3 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from . import distiller +from .distiller import * +from . import distillation_strategy +from .distillation_strategy import * + +__all__ = distiller.__all__ +__all__ += distillation_strategy.__all__ diff --git a/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..1f11f07a51e713d42cee5e63bd2a9a02d82232f7 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py @@ -0,0 +1,94 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..core.strategy import Strategy +from ....framework import Program, program_guard +from ....
import Executor +import logging + +__all__ = ['DistillationStrategy'] + +logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +class DistillationStrategy(Strategy): + def __init__(self, distillers=None, start_epoch=0, end_epoch=0): + """ + Args: + distillers(list): A list of distillers used to combine the student graph and the teacher graph + by adding some loss. + start_epoch(int): The epoch at which to merge the student graph and the teacher graph for + distillation training. default: 0 + end_epoch(int): The epoch at which to finish distillation training. default: 0 + + """ + super(DistillationStrategy, self).__init__(start_epoch, end_epoch) + self.distillers = distillers + + def on_compression_begin(self, context): + # load from checkpoint + if context.epoch_id > 0: + if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch: + _logger.info('Restore DistillationStrategy') + self._create_distillation_graph(context) + _logger.info('Restore DistillationStrategy finished.') + + def on_epoch_begin(self, context): + if self.start_epoch == context.epoch_id: + _logger.info('DistillationStrategy::on_epoch_begin.') + self._create_distillation_graph(context) + _logger.info('DistillationStrategy set optimize_graph.') + + def _create_distillation_graph(self, context): + """ + step 1: Merge the student graph and the teacher graph into the distillation graph. + step 2: Add losses into the distillation graph by distillers. + step 3: Append backward ops and optimize ops into the distillation graph for training. + """ + # step 1 + teacher = context.teacher_graphs[0] + for var in teacher.program.list_vars(): + var.stop_gradient = True + graph = context.train_graph.clone() + graph.merge(teacher) + graph.out_nodes['student_loss'] = graph.out_nodes['loss'] + + # step 2 + for distiller in self.distillers: + graph = distiller.distiller_loss(graph) + + # step 3 + startup_program = Program() + with program_guard(graph.program, startup_program): + context.distiller_optimizer._name = 'distillation_optimizer' + context.distiller_optimizer.minimize( + graph.var(graph.out_nodes['loss'])._var) + exe = Executor(context.place) + exe.run(startup_program, scope=context.scope) + + # back up the graph for fine-tuning after distillation + context.put('distillation_backup_optimize_graph', + context.optimize_graph) + context.optimize_graph = graph + + def on_epoch_end(self, context): + if context.epoch_id == (self.end_epoch - 1): + _logger.info('DistillationStrategy::on_epoch_end.') + # restore optimize_graph for fine-tuning or other strategies in the next stage. + context.optimize_graph = context.get( + 'distillation_backup_optimize_graph') + _logger.info( + 'DistillationStrategy restored context.optimize_graph.') diff --git a/python/paddle/fluid/contrib/slim/distillation/distiller.py b/python/paddle/fluid/contrib/slim/distillation/distiller.py new file mode 100644 index 0000000000000000000000000000000000000000..13bb35a8be73ed29e907308d08a33cdc13dee069 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/distillation/distiller.py @@ -0,0 +1,188 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .... import layers +from .... import optimizer +from .... import Executor +from .... import Program +from .... import program_guard +from .... import regularizer + +__all__ = ['FSPDistiller', 'L2Distiller'] + + +class L2Distiller(object): + """ + Combine two layers from the student net and the teacher net by l2-loss, + and add that loss into the total loss used for distillation training. + """ + + def __init__(self, + student_feature_map, + teacher_feature_map, + distillation_loss_weight=1): + """ + Args: + student_feature_map(str): The name of the feature map from the student network. + teacher_feature_map(str): The name of the feature map from the teacher network. + Its shape should be the same as that of the student network. + distillation_loss_weight(float): The weight of the l2-loss. + """ + self.student_feature_map = student_feature_map + self.teacher_feature_map = teacher_feature_map + self.distillation_loss_weight = distillation_loss_weight + + def distiller_loss(self, graph): + """ + Modify the graph in place to add the l2-loss. + Args: + graph(GraphWrapper): The graph to be modified. + Returns: + GraphWrapper: The modified graph. + """ + distiller_pass = L2DistillerPass(self.student_feature_map, + self.teacher_feature_map, + self.distillation_loss_weight) + dis_graph = distiller_pass.apply(graph) + return dis_graph + + +class L2DistillerPass(object): + """ + The pass used to add the l2-loss. + """ + + def __init__(self, + student_feature_map, + teacher_feature_map, + distillation_loss_weight=1): + """ + Args: + student_feature_map(str): The name of the feature map from the student network. + teacher_feature_map(str): The name of the feature map from the teacher network. + Its shape should be the same as that of the student network. + distillation_loss_weight(float): The weight of the l2-loss. + """ + self.student_feature_map = student_feature_map + self.teacher_feature_map = teacher_feature_map + self.distillation_loss_weight = distillation_loss_weight + + def apply(self, graph): + ret_graph = graph + with program_guard(ret_graph.program): + + student_feature_map = ret_graph.var(self.student_feature_map)._var + teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var + l2loss = layers.reduce_mean( + layers.square(student_feature_map - teacher_feature_map)) + + distillation_loss = l2loss * self.distillation_loss_weight + student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var + loss = distillation_loss + student_loss + + ret_graph.out_nodes[ + 'l2loss_' + self.student_feature_map + "_" + + self.teacher_feature_map] = distillation_loss.name + ret_graph.out_nodes['loss'] = loss.name + return ret_graph + + +class FSPDistiller(object): + """ + Combine layers from the student net and the teacher net by fsp-loss. + """ + + def __init__(self, student_pairs, teacher_pairs, + distillation_loss_weight=1): + """ + Args: + student_pairs(list): Each tuple, with two variable names, in student_pairs indicates + a section in the student network. The variables in a tuple should + have the same feature map size. + teacher_pairs(list): Each tuple, with two variable names, in teacher_pairs indicates + a section in the teacher network.
The variables in a tuple should + have the same feature map size. The variable named teacher_pairs[i][j] + should have the same channel number as the variable named + student_pairs[i][j]. + + distillation_loss_weight(float): The weight of the fsp-loss. default: 1. + """ + self.student_pairs = student_pairs + self.teacher_pairs = teacher_pairs + self.distillation_loss_weight = distillation_loss_weight + + def distiller_loss(self, graph): + """ + Modify the graph in place to add the fsp-loss. + Args: + graph(GraphWrapper): The graph to be modified. + Returns: + GraphWrapper: The modified graph. + """ + distiller_pass = FSPDistillerPass(self.student_pairs, + self.teacher_pairs, + self.distillation_loss_weight) + dis_graph = distiller_pass.apply(graph) + return dis_graph + + +class FSPDistillerPass(object): + ''' + Combine layers from the student net and the teacher net by fsp-loss. + ''' + + def __init__(self, s_pairs, t_pairs, distillation_loss_weight=1): + """ + Args: + s_pairs(list): Each tuple, with two variable names, in s_pairs indicates + a section in the student network. The variables in a tuple should + have the same feature map size. + t_pairs(list): Each tuple, with two variable names, in t_pairs indicates + a section in the teacher network. The variables in a tuple should + have the same feature map size. The variable named t_pairs[i][j] + should have the same channel number as the variable named + s_pairs[i][j]. + + distillation_loss_weight(float): The weight of the fsp-loss. default: 1. + """ + self.s_pairs = s_pairs + self.t_pairs = t_pairs + self.distillation_loss_weight = distillation_loss_weight + + def apply(self, graph): + ret_graph = graph + with program_guard(ret_graph.program): + losses = [] + for s_pair, t_pair in zip(self.s_pairs, self.t_pairs): + s_pair_start = ret_graph.var(s_pair[0])._var + s_pair_end = ret_graph.var(s_pair[1])._var + s_fsp_matrix = self._fsp_matrix(s_pair_start, s_pair_end) + t_pair_start = ret_graph.var(t_pair[0])._var + t_pair_end = ret_graph.var(t_pair[1])._var + t_fsp_matrix = self._fsp_matrix(t_pair_start, t_pair_end) + l2_loss = layers.reduce_mean( + layers.square(s_fsp_matrix - t_fsp_matrix)) + losses.append(l2_loss) + distillation_loss = layers.sum( + losses) * self.distillation_loss_weight + student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var + loss = distillation_loss + student_loss + + ret_graph.out_nodes[ + 'fsp_distillation_loss'] = distillation_loss.name + ret_graph.out_nodes['loss'] = loss.name + return ret_graph + + def _fsp_matrix(self, fea_map_0, fea_map_1): + return layers.fsp_matrix(fea_map_0, fea_map_1) diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py index 8694be782708a6d47b3e1450305975d34fd3bd7f..c208553fd811c7b18f9168b8fcae4da6e5856070 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py @@ -300,7 +300,9 @@ class GraphWrapper(object): graph(GraphWrapper): The graph to be merged into the current graph.
""" for var in graph.program.list_vars(): - self.program.global_block()._clone_variable(var) + new_var = self.program.global_block()._clone_variable( + var, force_persistable=False) + new_var.stop_gradient = var.stop_gradient # TODO: parameters should be cloned for op in graph.ops(): op = op._op @@ -309,12 +311,12 @@ class GraphWrapper(object): attrs = {} for input_name in op.input_names: inputs[input_name] = [ - self.var(in_var_name) - for in_var_name in op.inputs(input_name) + self.var(in_var_name)._var + for in_var_name in op.input(input_name) ] for output_name in op.output_names: outputs[output_name] = [ - self.var(out_var_name) + self.var(out_var_name)._var for out_var_name in op.output(output_name) ] for attr_name in op.attr_names: diff --git a/python/paddle/fluid/contrib/slim/quantization/__init__.py b/python/paddle/fluid/contrib/slim/quantization/__init__.py index 6c26475f48855674d97abf5778a631646734fcf8..1c51aa15373779b06273296a27d913c070079f41 100644 --- a/python/paddle/fluid/contrib/slim/quantization/__init__.py +++ b/python/paddle/fluid/contrib/slim/quantization/__init__.py @@ -16,5 +16,7 @@ from __future__ import print_function from . import quantization_pass from .quantization_pass import * +from . import quantization_strategy +from .quantization_strategy import * -__all__ = quantization_pass.__all__ +__all__ = quantization_pass.__all__ + quantization_strategy.__all__ diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..6812b4c633d5b55d84fff935b696297f30b18c6b --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py @@ -0,0 +1,209 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +import numpy as np +from .... import Executor +from .... import io +from .... import core +from ....compiler import CompiledProgram +from ....compiler import BuildStrategy +from ....framework import IrGraph +from ..core.strategy import Strategy +from .quantization_pass import * + +__all__ = ['QuantizationStrategy'] + +logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +class QuantizationStrategy(Strategy): + """ + The strategy for Quantization. + """ + + def __init__(self, + start_epoch=0, + end_epoch=0, + float_model_save_path=None, + mobile_model_save_path=None, + int8_model_save_path=None, + activation_bits=8, + weight_bits=8, + activation_quantize_type='abs_max', + save_in_nodes=None, + save_out_nodes=None): + """ + Args: + start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0 + end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0 + float_model_save_path(str): The path to save model with float weights. + None means it doesn't save float model. 
default: None. + mobile_model_save_path(str): The path to save model for paddle-mobile execution. + None means it doesn't save mobile model. default: None. + int8_model_save_path(str): The path to save model with int8_t weight. + None means it doesn't save int8 model. default: None. + activation_bits(int): quantization bit number for activation. default: 8. + weight_bits(int): quantization bit number for weights. The bias is not quantized. + default: 8. + activation_quantize_type(str): quantization type for activation, + now supporting 'abs_max', 'range_abs_max' and 'moving_average_abs_max'. + In 'abs_max' mode, the quantization scale will be calculated + dynamically at each step in both the training and testing periods. In + 'range_abs_max' mode, a static quantization scale will be calculated + during training and used in inference. + save_in_nodes(list): A list of variable names used to prune graph + for saving inference model. + save_out_nodes(list): A list of variable names used to prune graph + for saving inference model. + + """ + super(QuantizationStrategy, self).__init__(start_epoch, end_epoch) + self.start_epoch = start_epoch + self.end_epoch = end_epoch + self.float_model_save_path = float_model_save_path + self.mobile_model_save_path = mobile_model_save_path + self.int8_model_save_path = int8_model_save_path + self.activation_bits = activation_bits + self.weight_bits = weight_bits + self.activation_quantize_type = activation_quantize_type + self.save_out_nodes = save_out_nodes + self.save_in_nodes = save_in_nodes + + def on_epoch_begin(self, context): + """ + Insert fake_quantize_op and fake_dequantize_op before training and testing. + """ + super(QuantizationStrategy, self).on_compression_begin(context) + if self.start_epoch == context.epoch_id: + _logger.info('QuantizationStrategy::on_epoch_begin') + train_ir_graph = IrGraph( + core.Graph(context.optimize_graph.program.desc), for_test=False) + test_ir_graph = IrGraph( + core.Graph(context.eval_graph.program.desc), for_test=True) + transform_pass = QuantizationTransformPass( + scope=context.scope, + place=context.place, + weight_bits=self.weight_bits, + activation_bits=self.activation_bits, + activation_quantize_type=self.activation_quantize_type) + transform_pass.apply(train_ir_graph) + transform_pass.apply(test_ir_graph) + + build_strategy = BuildStrategy() + build_strategy.enable_inplace = False + build_strategy.memory_optimize = False + # for quantization training + context.optimize_graph.compiled_graph = CompiledProgram( + train_ir_graph.graph).with_data_parallel( + loss_name=context.optimize_graph.out_nodes['loss'], + build_strategy=build_strategy) + # for evaluation. The program compiled from the ir graph must be compiled with data parallel. + context.eval_graph.compiled_graph = CompiledProgram( + test_ir_graph.graph).with_data_parallel( + build_strategy=build_strategy) + # for saving inference model after training + context.put('quantization_test_ir_graph_backup', test_ir_graph) + _logger.info('Finish QuantizationStrategy::on_epoch_begin') + + def on_epoch_end(self, context): + """ + Freeze the graph and save the inference model.
+ """ + super(QuantizationStrategy, self).on_compression_end(context) + + if context.epoch_id == self.end_epoch: + _logger.info('QuantizationStrategy::on_epoch_end') + test_ir_graph = context.get('quantization_test_ir_graph_backup') + # freeze the graph after training + freeze_pass = QuantizationFreezePass( + scope=context.scope, + place=context.place, + weight_bits=self.weight_bits, + activation_bits=self.activation_bits) + freeze_pass.apply(test_ir_graph) + + # for other strategies + context.eval_graph.program = test_ir_graph.to_program() + + if self.save_out_nodes == None: + out_vars = [ + context.eval_graph.var(var_name)._var + for var_name in context.eval_graph.out_nodes.values() + ] + else: + out_vars = [ + context.eval_graph.var(var_name)._var + for var_name in self.save_out_nodes + ] + + if self.save_in_nodes == None: + in_vars = list(context.eval_graph.out_nodes.values()) + else: + in_vars = self.save_in_nodes + + # save float model + if self.float_model_save_path: + executor = Executor(context.place) + io.save_inference_model( + self.float_model_save_path, + in_vars, + out_vars, + executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) + + # save int8 model + if self.int8_model_save_path: + convert_int8_pass = ConvertToInt8Pass( + scope=context.scope, place=context.place) + convert_int8_pass.apply(test_ir_graph) + + executor = Executor(context.place) + io.save_inference_model( + self.int8_model_save_path, + in_vars, + out_vars, + executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) + + # save mobile model + if self.mobile_model_save_path: + if not self.int8_model_save_path: + # convert the weights as int8_t type + convert_int8_pass = ConvertToInt8Pass( + scope=context.scope, place=context.place) + convert_int8_pass.apply(test_ir_graph) + # make some changes on the graph for the mobile inference + mobile_pass = TransformForMobilePass() + mobile_pass.apply(test_ir_graph) + executor = Executor(context.place) + io.save_inference_model( + self.mobile_model_save_path, + in_vars, + out_vars, + executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) + _logger.info('Finish QuantizationStrategy::on_epoch_end') diff --git a/python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml b/python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef89dfb7801e6df8a2cf842a5fcc745d70254977 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/distillation/compress.yaml @@ -0,0 +1,46 @@ +#start_epoch(int): The epoch when to merge student graph and teacher graph for +# distillation training. default: 0 +# +#end_epoch(int): The epoch when to finish distillation training. default: 0 +# +#student_feature_map(str): The name of feature map from student network. +# +#teacher_feature_map(str): The name of feature map from teacher network. +# It's shape should be the same with student network. +# +#student_pairs(list): Each tuple, with two variable names, in student_pairs indicates +# a section in student network. The variables in a tuple should +# have the same feature map size. +# +#teacher_pairs(list): Each tuple, with two variable names, in teacher_pairs indicates +# a section in teacher network. The variables in a tuple should +# have the same feature map size. 
The variable named teacher_pairs[i][j] +# should have the same channel number as the variable named +# student_pairs[i][j]. +# +#distillation_loss_weight(float): The weight of the loss. +version: 1.0 +distillers: + fsp_distiller: + class: 'FSPDistiller' +# teacher_pairs: [['teacher_depthwise_conv2d_1.tmp_0', 'teacher_conv2d_3.tmp_0']] +# student_pairs: [['student_depthwise_conv2d_1.tmp_0', 'student_conv2d_3.tmp_0']] + teacher_pairs: [['teacher_conv2_1_dw.tmp_0', 'teacher_conv1.tmp_0']] + student_pairs: [['student_conv2_1_dw.tmp_0', 'student_conv1.tmp_0']] + distillation_loss_weight: 1 + l2_distiller: + class: 'L2Distiller' + teacher_feature_map: 'teacher.tmp_2' + student_feature_map: 'student.tmp_2' + distillation_loss_weight: 1 +strategies: + distillation_strategy: + class: 'DistillationStrategy' + distillers: ['fsp_distiller', 'l2_distiller'] + start_epoch: 0 + end_epoch: 1 +compressor: + epoch: 1 + checkpoint_path: './distillation_checkpoints/' + strategies: + - distillation_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml index 232276feac5023c45d594015cf7084b000cd5b4a..5f747a049e95a5920236336c69a80a9492e6190d 100644 --- a/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml @@ -29,6 +29,6 @@ strategies: metric_name: 'acc_top1' compressor: epoch: 2 - checkpoint_path: './checkpoints/' + checkpoint_path: './checkpoints_pruning/' strategies: - sensitive_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/mobilenet.py b/python/paddle/fluid/contrib/slim/tests/mobilenet.py similarity index 86% rename from python/paddle/fluid/contrib/slim/tests/filter_pruning/mobilenet.py rename to python/paddle/fluid/contrib/slim/tests/mobilenet.py index 0148325a642a2bcbebd3d7794056ff2778a3992d..f5dbef17e8d4a7c474881d88b6619061a3424177 100644 --- a/python/paddle/fluid/contrib/slim/tests/filter_pruning/mobilenet.py +++ b/python/paddle/fluid/contrib/slim/tests/mobilenet.py @@ -35,8 +35,9 @@ train_parameters = { class MobileNet(): - def __init__(self): + def __init__(self, name=""): self.params = train_parameters + self.name = name def net(self, input, class_dim=1000, scale=1.0): # conv1: 112x112 @@ -47,7 +48,7 @@ num_filters=int(32 * scale), stride=2, padding=1, - name="conv1") + name=self.name + "_conv1") # 56x56 input = self.depthwise_separable( @@ -57,7 +58,7 @@ num_groups=32, stride=1, scale=scale, - name="conv2_1") + name=self.name + "_conv2_1") input = self.depthwise_separable( input, @@ -66,7 +67,7 @@ num_groups=64, stride=2, scale=scale, - name="conv2_2") + name=self.name + "_conv2_2") # 28x28 input = self.depthwise_separable( @@ -76,7 +77,7 @@ num_groups=128, stride=1, scale=scale, - name="conv3_1") + name=self.name + "_conv3_1") input = self.depthwise_separable( input, @@ -85,7 +86,7 @@ num_groups=128, stride=2, scale=scale, - name="conv3_2") + name=self.name + "_conv3_2") # 14x14 input = self.depthwise_separable( @@ -95,7 +96,7 @@ num_groups=256, stride=1, scale=scale, - name="conv4_1") + name=self.name + "_conv4_1") input = self.depthwise_separable( input, @@ -104,7 +105,7 @@ num_groups=256, stride=2, scale=scale, - name="conv4_2") + name=self.name + "_conv4_2") # 14x14 for i in range(5): @@ -115,7 +116,7 @@
num_groups=512, stride=1, scale=scale, - name="conv5" + "_" + str(i + 1)) + name=self.name + "_conv5" + "_" + str(i + 1)) # 7x7 input = self.depthwise_separable( input, @@ -124,7 +125,7 @@ num_groups=512, stride=2, scale=scale, - name="conv5_6") + name=self.name + "_conv5_6") input = self.depthwise_separable( input, @@ -133,7 +134,7 @@ num_groups=1024, stride=1, scale=scale, - name="conv6") + name=self.name + "_conv6") input = fluid.layers.pool2d( input=input, @@ -142,12 +143,14 @@ pool_type='avg', global_pooling=True) - output = fluid.layers.fc(input=input, - size=class_dim, - act='softmax', - param_attr=ParamAttr( - initializer=MSRA(), name="fc7_weights"), - bias_attr=ParamAttr(name="fc7_offset")) + output = fluid.layers.fc( + input=input, + size=class_dim, + act='softmax', + param_attr=ParamAttr( + initializer=MSRA(), name=self.name + "_fc7_weights"), + bias_attr=ParamAttr(name=self.name + "_fc7_offset"), + name=self.name) return output def conv_bn_layer(self, @@ -172,11 +175,13 @@ use_cudnn=use_cudnn, param_attr=ParamAttr( initializer=MSRA(), name=name + "_weights"), + name=name, bias_attr=False) bn_name = name + "_bn" return fluid.layers.batch_norm( input=conv, act=act, + name=name, param_attr=ParamAttr(name=bn_name + "_scale"), bias_attr=ParamAttr(name=bn_name + "_offset"), moving_mean_name=bn_name + '_mean', diff --git a/python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml b/python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f29eb53f88d22d87b61f82279b676af5ec1ef497 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml @@ -0,0 +1,48 @@ +#start_epoch(int): The epoch to insert quantization operators. default: 0 +# +#end_epoch(int): The epoch to save the inference model. default: 0 +# +#float_model_save_path(str): The path to save model with float weights. +# None means it doesn't save float model. default: None. +# +#mobile_model_save_path(str): The path to save model for paddle-mobile execution. +# None means it doesn't save mobile model. default: None. +# +#int8_model_save_path(str): The path to save model with int8_t weight. +# None means it doesn't save int8 model. default: None. +# +#activation_bits(int): quantization bit number for activation. default: 8. +# +#weight_bits(int): quantization bit number for weights. The bias is not quantized. +# default: 8. +# +#activation_quantize_type(str): quantization type for activation, +# now supporting 'abs_max', 'range_abs_max' and 'moving_average_abs_max'. +# In 'abs_max' mode, the quantization scale will be calculated +# dynamically at each step in both the training and testing periods. In +# 'range_abs_max' mode, a static quantization scale will be calculated +# during training and used in inference. +# +#save_in_nodes(list): A list of variable names used to prune graph +# for saving inference model. +# +#save_out_nodes(list): A list of variable names used to prune graph +# for saving inference model.
+version: 1.0 +strategies: + quantization_strategy: + class: 'QuantizationStrategy' + start_epoch: 0 + end_epoch: 0 + float_model_save_path: './output/float' + weight_bits: 8 + activation_bits: 8 + weight_quantize_type: 'abs_max' + activation_quantize_type: 'abs_max' + save_in_nodes: ['image'] + save_out_nodes: ['quan.tmp_2'] +compressor: + epoch: 1 + checkpoint_path: './checkpoints_quan/' + strategies: + - quantization_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/test_distillation_strategy.py b/python/paddle/fluid/contrib/slim/tests/test_distillation_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..9b967c0ac7d2bfdab23d4557ef0b7d28f4118ff7 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_distillation_strategy.py @@ -0,0 +1,94 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import paddle.fluid as fluid +from mobilenet import MobileNet +from paddle.fluid.contrib.slim.core import Compressor +from paddle.fluid.contrib.slim.graph import GraphWrapper + + +class TestDistillationStrategy(unittest.TestCase): + """ + Test API of distillation strategy. + """ + + def test_compression(self): + if not fluid.core.is_compiled_with_cuda(): + return + class_dim = 10 + image_shape = [1, 28, 28] + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + out = MobileNet(name="student").net(input=image, class_dim=class_dim) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + val_program = fluid.default_main_program().clone(for_test=False) + + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + optimizer = fluid.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + regularization=fluid.regularizer.L2Decay(4e-5)) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + val_feed_list = [('img', image.name), ('label', label.name)] + val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', + acc_top5.name)] + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128) + train_feed_list = [('img', image.name), ('label', label.name)] + train_fetch_list = [('loss', avg_cost.name)] + + # define teacher program + teacher_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(teacher_program, startup_program): + img = teacher_program.global_block()._clone_variable( + image, force_persistable=False) + predict = MobileNet(name="teacher").net(input=img, + class_dim=class_dim) + + exe.run(startup_program) + + com_pass = Compressor( + place, + fluid.global_scope(), + fluid.default_main_program(), + train_reader=train_reader, +
train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=val_program, + eval_reader=val_reader, + eval_feed_list=val_feed_list, + eval_fetch_list=val_fetch_list, + teacher_programs=[teacher_program.clone(for_test=True)], + train_optimizer=optimizer, + distiller_optimizer=optimizer) + com_pass.config('./distillation/compress.yaml') + eval_graph = com_pass.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py b/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py index d73ee27779a0d17a0f60df645a6d2946d665c01e..e1763039b3a962a43f2fe3a22c05cb32cba596ed 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py +++ b/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py @@ -15,7 +15,7 @@ import paddle import unittest import paddle.fluid as fluid -from filter_pruning.mobilenet import MobileNet +from mobilenet import MobileNet from paddle.fluid.contrib.slim.core import Compressor from paddle.fluid.contrib.slim.graph import GraphWrapper diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..92afd892afed86e69266c9ab9c97d90daebb86d5 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py @@ -0,0 +1,82 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import paddle.fluid as fluid +from mobilenet import MobileNet +from paddle.fluid.contrib.slim.core import Compressor +from paddle.fluid.contrib.slim.graph import GraphWrapper + + +class TestQuantizationStrategy(unittest.TestCase): + """ + Test API of quantization strategy.
+ """ + + def test_compression(self): + if not fluid.core.is_compiled_with_cuda(): + return + class_dim = 10 + image_shape = [1, 28, 28] + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + out = MobileNet(name='quan').net(input=image, class_dim=class_dim) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + val_program = fluid.default_main_program().clone(for_test=False) + + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + + optimizer = fluid.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + regularization=fluid.regularizer.L2Decay(4e-5)) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + val_feed_list = [('img', image.name), ('label', label.name)] + val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', + acc_top5.name)] + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128) + train_feed_list = [('img', image.name), ('label', label.name)] + train_fetch_list = [('loss', avg_cost.name)] + + com_pass = Compressor( + place, + fluid.global_scope(), + fluid.default_main_program(), + train_reader=train_reader, + train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=val_program, + eval_reader=val_reader, + eval_feed_list=val_feed_list, + eval_fetch_list=val_fetch_list, + train_optimizer=optimizer) + com_pass.config('./quantization/compress.yaml') + eval_graph = com_pass.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index e4169c247f40f1944f98ddd185e55b404bdbf9e3..f3d876f141763beec940899e8ab5ed464328b06e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1559,12 +1559,15 @@ class Block(object): name=v.name) self.vars[new_p.name] = new_p - def _clone_variable(self, var): + def _clone_variable(self, var, force_persistable=True): """ Clone a variable into current block. Args: var: the variable to be cloned. + force_persistable(bool): True means setting the result variable to being persistable. + False means setting the persistable the same with that of input var. + default: True. Returns: Variable: the new variable cloned from 'var' in current block. 
@@ -1584,7 +1587,7 @@ class Block(object): shape=var.shape, dtype=var.dtype, type=var.type, - persistable=True, + persistable=True if force_persistable else var.persistable, is_data=var.is_data) else: ret_var = self.create_var( @@ -1593,7 +1596,7 @@ name=var.name, shape=var.shape, dtype=var.dtype, type=var.type, lod_level=var.lod_level, - persistable=True, + persistable=True if force_persistable else var.persistable, is_data=var.is_data) return ret_var diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e2c8be613fb2b27d33acbcafdabbf4c8a526f5d5..c4e6053fec0514479ec4b0c110dfaf4610e677f5 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -189,6 +189,7 @@ __all__ = [ 'huber_loss', 'tree_conv', 'npair_loss', + 'fsp_matrix', ] kIgnoreIndex = -100 @@ -10790,3 +10791,46 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002): celoss = reduce_mean(cross_entropy) return l2loss + celoss + + +def fsp_matrix(x, y): + """ + + **FSP matrix op** + + This op is used to calculate the flow of solution procedure (FSP) matrix of two feature maps. + Given feature map x with shape [x_channel, h, w] and feature map y with shape + [y_channel, h, w], we can get the fsp matrix of x and y in two steps: + + 1. reshape x into a matrix with shape [x_channel, h * w] and reshape and + transpose y into a matrix with shape [h * w, y_channel]. + 2. multiply x and y to get the fsp matrix with shape [x_channel, y_channel]. + + The output is a batch of fsp matrices. + + Args: + + x (Variable): A feature map with shape [batch_size, x_channel, height, width]. + y (Variable): A feature map with shape [batch_size, y_channel, height, width]. + The y_channel can be different from the x_channel of Input(X) + while the other dimensions must be the same as Input(X)'s. + + Returns: + + fsp matrix (Variable): The output of FSP op with shape [batch_size, x_channel, y_channel]. + The x_channel is the channel of x and the y_channel is the channel of y. + + Examples: + + .. code-block:: python + + feature_map_0 = fluid.layers.conv2d(x, num_filters=2, filter_size=3, padding=1) + feature_map_1 = fluid.layers.conv2d(feature_map_0, num_filters=2, filter_size=3, padding=1) + loss = fluid.layers.fsp_matrix(feature_map_0, feature_map_1) + + """ + helper = LayerHelper('fsp_matrix', **locals()) + out = helper.create_variable_for_type_inference(dtype=helper.input_dtype( + input_param_name='x')) + helper.append_op(type='fsp', inputs={'X': x, 'Y': y}, outputs={'Out': out}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_fsp_op.py b/python/paddle/fluid/tests/unittests/test_fsp_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad7418447b4bac5e6a6034f94540091590fa189 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fsp_op.py @@ -0,0 +1,60 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import unittest +import numpy as np +from op_test import OpTest + + +def fsp_matrix(a, b): + batch = a.shape[0] + a_channel = a.shape[1] + b_channel = b.shape[1] + h = a.shape[2] + w = a.shape[3] + a_t = a.transpose([0, 2, 3, 1]) + a_t = a_t.reshape([batch, h * w, a_channel]) + b_t = b.transpose([0, 2, 3, 1]).reshape([batch, h * w, b_channel]) + a_r = a_t.repeat( + b_channel, axis=1).reshape( + [batch, h * w, b_channel, a_channel]).transpose([0, 1, 3, 2]) + b_r = b_t.repeat( + a_channel, axis=1).reshape([batch, h * w, a_channel, b_channel]) + return np.mean(a_r * b_r, axis=1) + + +class TestFSPOp(OpTest): + def setUp(self): + self.op_type = "fsp" + self.initTestCase() + + feature_map_0 = np.random.uniform(0, 10, self.a_shape).astype('float32') + feature_map_1 = np.random.uniform(0, 10, self.b_shape).astype('float32') + + self.inputs = {'X': feature_map_0, 'Y': feature_map_1} + self.outputs = {'Out': fsp_matrix(feature_map_0, feature_map_1)} + + def initTestCase(self): + self.a_shape = (2, 16, 32, 31) + self.b_shape = (2, 28, 32, 31) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 1672c3600f389d87e85f965f96122065137cf0ac..f343ed4e87edb260fc79c87921020e79cef93325 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1269,6 +1269,15 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_fsp(self): + program = Program() + with program_guard(program): + x = layers.data(name="X", shape=[16, 4, 4], dtype="float32") + y = layers.data(name="Y", shape=[8, 4, 4], dtype="float32") + out = layers.fsp_matrix(x, y) + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index a7c1e91f9c3a9597d799659a0abe3c9f56e54a57..9f87f5644fc969f3f55fd08689f3e2bbaf36dc39 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -117,6 +117,7 @@ packages=['paddle', 'paddle.fluid.contrib.slim.graph', 'paddle.fluid.contrib.slim.prune', 'paddle.fluid.contrib.slim.quantization', + 'paddle.fluid.contrib.slim.distillation', 'paddle.fluid.contrib.utils', 'paddle.fluid.transpiler', 'paddle.fluid.transpiler.details']
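For reference, the two-step reshape-and-matmul description in the fsp op comment and the averaged outer product used by fsp_matrix() in test_fsp_op.py compute the same quantity. The NumPy sketch below is illustrative only and not part of the patch (the name fsp_matrix_ref is introduced here); it follows the kernel's formulation, a batched [x_channel, h*w] x [h*w, y_channel] product scaled by 1/(h*w):

import numpy as np

def fsp_matrix_ref(x, y):
    # x: [batch, x_channel, h, w]; y: [batch, y_channel, h, w]; h and w must match.
    batch, x_channel, h, w = x.shape
    y_channel = y.shape[1]
    # step 1: flatten each channel into a row of length h*w
    x_mat = x.reshape(batch, x_channel, h * w)
    # reshape and transpose y into [batch, h*w, y_channel]
    y_mat = y.reshape(batch, y_channel, h * w).transpose(0, 2, 1)
    # step 2: batched matmul, averaged over the h*w spatial positions
    return np.matmul(x_mat, y_mat) / (h * w)

x = np.random.uniform(0, 10, (2, 16, 32, 31)).astype('float32')
y = np.random.uniform(0, 10, (2, 28, 32, 31)).astype('float32')
print(fsp_matrix_ref(x, y).shape)  # (2, 16, 28), i.e. [batch_size, x_channel, y_channel]

The resulting shape matches what FSPOp::InferShape sets for Out, and the 1/(h*w) factor corresponds to the static_cast<T>(1.0 / (height * width)) alpha passed to blas.MatMul in fsp_op.h.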