未验证 提交 e9bec936 编写于 作者: W whs 提交者: GitHub

[slim] Add quantization strategy and distillation strategy. (#16408)

* Add fsp operator.
1 Add unitest.
2. Add python API.
3. Add layer test.

* Add quantization strategy.
1. Add API.
2. Add unitest.

* Add distillatoin strategy.

* Add unitest config file for quantization

* Fix Copyright
test=develop

* Fix setup.py

* Fix document of layers.py.
test=develop

* Fix unitest in python3.
test=develop

* Fix documents.
test=develop

* 1. refine fsp op by batched gemm
2. remove unused import
test=develop

* Fix test_dist_se_resnext.
1. disable test distillation.
2. reset framework.py
test=develop

* Enable unitest of distillation after fixing Block._clone_variable
test=develop

* Fix cdn issue.
test=develop
上级 de3b70a1
......@@ -222,6 +222,7 @@ paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label'
paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '431a4301c35032166ec029f7432c80a7'))
paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607'))
paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fsp_op.h"
namespace paddle {
namespace operators {
class FSPOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FSPOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of FSPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FSPOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE(
x_dims.size() == 4,
"The Input(X) must have shape [batch_size, channel, height, width].");
PADDLE_ENFORCE(
y_dims.size() == 4,
"The Input(Y) must have shape [batch_size, channel, height, width].");
PADDLE_ENFORCE(
(x_dims[2] == y_dims[2]) && (x_dims[3] == y_dims[3]),
"The Input(X) and Input(Y) should have the same height and width.");
ctx->SetOutputDim("Out", {x_dims[0], x_dims[1], y_dims[1]});
ctx->ShareLoD("X", "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context(), layout_, library_);
}
};
class FSPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input of FSP op with shape [batch_size, x_channel, "
"height, width]");
AddInput("Y",
"(Tensor) The input of FSP op with shape"
"[batch_size, y_channel, height, width]."
"The y_channel can be different with the x_channel of Input(X)"
" while the other dimensions must be the same with Input(X)'s.");
AddOutput(
"Out",
"(Tensor) The output of FSP op with shape "
"[batch_size, x_channel, y_channel]. The x_channel is the channel "
"of Input(X) and the y_channel is the channel of Input(Y).");
AddComment(R"DOC(
This op is used to calculate the flow of solution procedure (FSP) matrix of two feature maps.
Given feature map x with shape [x_channel, h, w] and feature map y with shape
[y_channel, h, w], we can get the fsp matrix of x and y in two steps:
step 1: reshape x into matrix with shape [x_channel, h * w] and reshape and
transpose y into matrix with shape [h * w, y_channel]
step 2: multiply x and y to get fsp matrix with shape [x_channel, y_channel]
The output is a batch of fsp matrices.
)DOC");
}
};
class FSPOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fsp, ops::FSPOp, ops::FSPOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad);
REGISTER_OP_CPU_KERNEL(
fsp, ops::FSPOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::FSPOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
fsp_grad, ops::FSPGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::FSPGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fsp_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fsp, ops::FSPOpKernel<plat::CUDADeviceContext, float>,
ops::FSPOpKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(fsp_grad,
ops::FSPGradOpKernel<plat::CUDADeviceContext, float>,
ops::FSPGradOpKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class FSPOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto x_dims = x->dims();
auto y_dims = y->dims();
auto batch_size = x_dims[0];
auto x_channel = x_dims[1];
auto y_channel = y_dims[1];
auto height = x_dims[2];
auto width = x_dims[3];
auto blas = math::GetBlas<DeviceContext, T>(context);
math::MatDescriptor x_mat_desc;
x_mat_desc.height_ = x_channel;
x_mat_desc.width_ = height * width;
x_mat_desc.batch_size_ = batch_size;
x_mat_desc.stride_ = x_channel * height * width;
math::MatDescriptor y_mat_desc;
y_mat_desc.height_ = height * width;
y_mat_desc.width_ = y_channel;
y_mat_desc.batch_size_ = batch_size;
y_mat_desc.stride_ = y_channel * height * width;
y_mat_desc.trans_ = true;
blas.MatMul(*x, x_mat_desc, *y, y_mat_desc,
static_cast<T>(1.0 / (height * width)), output,
static_cast<T>(0.0));
}
};
template <typename DeviceContext, typename T>
class FSPGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* d_y = context.Output<Tensor>(framework::GradVarName("Y"));
if (d_x == nullptr && d_y == nullptr) {
return;
}
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_out_dims = d_out->dims();
auto batch_size = d_out_dims[0];
auto x_channel = d_out_dims[1];
auto y_channel = d_out_dims[2];
int64_t h = 0;
int64_t w = 0;
auto blas = math::GetBlas<DeviceContext, T>(context);
math::SetConstant<DeviceContext, T> set_zero;
if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace());
set_zero(context.template device_context<DeviceContext>(), d_x,
static_cast<T>(0));
auto* y = context.Input<Tensor>("Y");
auto y_dims = y->dims();
h = y_dims[2];
w = y_dims[3];
math::MatDescriptor d_out_mat_desc;
d_out_mat_desc.height_ = x_channel;
d_out_mat_desc.width_ = y_channel;
d_out_mat_desc.batch_size_ = batch_size;
d_out_mat_desc.stride_ = x_channel * y_channel;
math::MatDescriptor y_mat_desc;
y_mat_desc.height_ = y_channel;
y_mat_desc.width_ = h * w;
y_mat_desc.batch_size_ = batch_size;
y_mat_desc.stride_ = y_channel * h * w;
blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc,
static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0));
}
if (d_y != nullptr) {
d_y->mutable_data<T>(context.GetPlace());
set_zero(context.template device_context<DeviceContext>(), d_y,
static_cast<T>(0));
auto* x = context.Input<Tensor>("X");
auto x_dims = x->dims();
h = x_dims[2];
w = x_dims[3];
math::MatDescriptor d_out_mat_desc;
d_out_mat_desc.height_ = y_channel;
d_out_mat_desc.width_ = x_channel;
d_out_mat_desc.batch_size_ = batch_size;
d_out_mat_desc.stride_ = x_channel * y_channel;
d_out_mat_desc.trans_ = true;
math::MatDescriptor x_mat_desc;
x_mat_desc.height_ = x_channel;
x_mat_desc.width_ = h * w;
x_mat_desc.batch_size_ = batch_size;
x_mat_desc.stride_ = x_channel * h * w;
blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc,
static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -271,7 +271,7 @@ class Compressor(object):
self.eval_reader = eval_reader
self.teacher_graphs = []
for teacher in teacher_programs:
self.teacher_graphs.append(ImitationGraph(teacher, scope=scope))
self.teacher_graphs.append(GraphWrapper(teacher))
self.checkpoint = None
self.checkpoint_path = checkpoint_path
......
......@@ -19,6 +19,7 @@ from collections import OrderedDict
from ..prune import *
from ..quantization import *
from .strategy import *
from ..distillation import *
__all__ = ['ConfigFactory']
"""This factory is used to create instances by loading and parsing configure file with yaml format.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,3 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import distiller
from .distiller import *
from . import distillation_strategy
from .distillation_strategy import *
__all__ = distiller.__all__
__all__ += distillation_strategy.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core.strategy import Strategy
from ....framework import Program, program_guard
from .... import Executor
import logging
__all__ = ['DistillationStrategy']
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class DistillationStrategy(Strategy):
def __init__(self, distillers=None, start_epoch=0, end_epoch=0):
"""
Args:
distillers(list): A list of distiller used to combine student graph and teacher graph
by adding some loss.
start_epoch(int): The epoch when to merge student graph and teacher graph for
distillation training. default: 0
end_epoch(int): The epoch when to finish distillation training. default: 0
"""
super(DistillationStrategy, self).__init__(start_epoch, end_epoch)
self.distillers = distillers
def on_compression_begin(self, context):
# load from checkpoint
if context.epoch_id > 0:
if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch:
_logger.info('Restore DistillationStrategy')
self._create_distillation_graph(context)
_logger.info('Restore DistillationStrategy finish.')
def on_epoch_begin(self, context):
if self.start_epoch == context.epoch_id:
_logger.info('DistillationStrategy::on_epoch_begin.')
self._create_distillation_graph(context)
_logger.info('DistillationStrategy set optimize_graph.')
def _create_distillation_graph(self, context):
"""
step 1: Merge student graph and teacher graph into distillation graph.
step 2: Add loss into distillation graph by distillers.
step 3: Append backward ops and optimize ops into distillation graph for training.
"""
# step 1
teacher = context.teacher_graphs[0]
for var in teacher.program.list_vars():
var.stop_gradient = True
graph = context.train_graph.clone()
graph.merge(teacher)
graph.out_nodes['student_loss'] = graph.out_nodes['loss']
# step 2
for distiller in self.distillers:
graph = distiller.distiller_loss(graph)
# step 3
startup_program = Program()
with program_guard(graph.program, startup_program):
context.distiller_optimizer._name = 'distillation_optimizer'
context.distiller_optimizer.minimize(
graph.var(graph.out_nodes['loss'])._var)
exe = Executor(context.place)
exe.run(startup_program, scope=context.scope)
# backup graph for fine-tune after distillation
context.put('distillation_backup_optimize_graph',
context.optimize_graph)
context.optimize_graph = graph
def on_epoch_end(self, context):
if context.epoch_id == (self.end_epoch - 1):
_logger.info('DistillationStrategy::on_epoch_end.')
# restore optimize_graph for fine-tune or other strategy in next stage.
context.optimize_graph = context.get(
'distillation_backup_optimize_graph')
_logger.info(
'DistillationStrategy set context.optimize_graph to None.')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .... import layers
from .... import optimizer
from .... import Executor
from .... import Program
from .... import program_guard
from .... import regularizer
__all__ = ['FSPDistiller', 'L2Distiller']
class L2Distiller(object):
"""
Combine two layers from student net and teacher net by l2-loss.
And add the loss into the total loss using for distillation training.
"""
def __init__(self,
student_feature_map,
teacher_feature_map,
distillation_loss_weight=1):
"""
Args:
student_feature_map(str): The name of feature map from student network.
teacher_feature_map(str): The name of feature map from teacher network.
It's shape should be the same with student network.
distillation_loss_weight(float): The weight of the l2-loss.
"""
self.student_feature_map = student_feature_map
self.teacher_feature_map = teacher_feature_map
self.distillation_loss_weight = distillation_loss_weight
def distiller_loss(self, graph):
"""
Modify graph inplace to add l2-loss.
Args:
graph(GraphWrapper): The graph to be modified.
Returns:
GraphWrapper: The modified graph.
"""
distiller_pass = L2DistillerPass(self.student_feature_map,
self.teacher_feature_map,
self.distillation_loss_weight)
dis_graph = distiller_pass.apply(graph)
return dis_graph
class L2DistillerPass(object):
"""
The pass used to add l2-loss.
"""
def __init__(self,
student_feature_map,
teacher_feature_map,
distillation_loss_weight=1):
"""
Args:
student_feature_map(str): The name of feature map from student network.
teacher_feature_map(str): The name of feature map from teacher network.
It's shape should be the same with student network.
distillation_loss_weight(float): The weight of the l2-loss.
"""
self.student_feature_map = student_feature_map
self.teacher_feature_map = teacher_feature_map
self.distillation_loss_weight = distillation_loss_weight
def apply(self, graph):
ret_graph = graph
with program_guard(ret_graph.program):
student_feature_map = ret_graph.var(self.student_feature_map)._var
teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
l2loss = layers.reduce_mean(
layers.square(student_feature_map - teacher_feature_map))
distillation_loss = l2loss * self.distillation_loss_weight
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
loss = distillation_loss + student_loss
ret_graph.out_nodes[
'l2loss_' + self.student_feature_map + "_" +
self.teacher_feature_map] = distillation_loss.name
ret_graph.out_nodes['loss'] = loss.name
return ret_graph
class FSPDistiller(object):
"""
Combine layers from student net and teacher net by fsp-loss.
"""
def __init__(self, student_pairs, teacher_pairs,
distillation_loss_weight=1):
"""
Args:
student_pairs(list<tuple>): Each tuple, with two variable names, in student_pairs indicates
a section in student network. The variables in a tuple should
have the same feature map size.
teacher_pairs(list<tuple>): Each tuple, with two variable names, in teacher_pairs indicates
a section in teacher network. The variables in a tuple should
have the same feature map size. Varibale named teacher_pairs[i][j]
should has the save channel number with that of variable named
student_pairs[i][j].
distillation_loss_weight(float): The weight of the fsp-loss. default: 1.
"""
self.student_pairs = student_pairs
self.teacher_pairs = teacher_pairs
self.distillation_loss_weight = distillation_loss_weight
def distiller_loss(self, graph):
"""
Modify graph inplace to add fsp-loss.
Args:
graph(GraphWrapper): The graph to be modified.
Returns:
GraphWrapper: The modified graph.
"""
distiller_pass = FSPDistillerPass(self.student_pairs,
self.teacher_pairs,
self.distillation_loss_weight)
dis_graph = distiller_pass.apply(graph)
return dis_graph
class FSPDistillerPass(object):
'''
Combine layers from student net and teacher net by fsp-loss.
'''
def __init__(self, s_pairs, t_pairs, distillation_loss_weight=1):
"""
Args:
s_pairs(list<tuple>): Each tuple, with two variable names, in student_pairs indicates
a section in student network. The variables in a tuple should
have the same feature map size.
t_pairs(list<tuple>): Each tuple, with two variable names, in teacher_pairs indicates
a section in teacher network. The variables in a tuple should
have the same feature map size. Varibale named teacher_pairs[i][j]
should has the save channel number with that of variable named
student_pairs[i][j].
distillation_loss_weight(float): The weight of the fsp-loss. default: 1.
"""
self.s_pairs = s_pairs
self.t_pairs = t_pairs
self.distillation_loss_weight = distillation_loss_weight
def apply(self, graph):
ret_graph = graph
with program_guard(ret_graph.program):
losses = []
for s_pair, t_pair in zip(self.s_pairs, self.t_pairs):
s_pair_start = ret_graph.var(s_pair[0])._var
s_pair_end = ret_graph.var(s_pair[1])._var
s_fsp_matrix = self._fsp_matrix(s_pair_start, s_pair_end)
t_pair_start = ret_graph.var(t_pair[0])._var
t_pair_end = ret_graph.var(t_pair[1])._var
t_fsp_matrix = self._fsp_matrix(t_pair_start, t_pair_end)
l2_loss = layers.reduce_mean(
layers.square(s_fsp_matrix - t_fsp_matrix))
losses.append(l2_loss)
distillation_loss = layers.sum(
losses) * self.distillation_loss_weight
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
loss = distillation_loss + student_loss
ret_graph.out_nodes[
'fsp_distillation_loss'] = distillation_loss.name
ret_graph.out_nodes['loss'] = loss.name
return ret_graph
def _fsp_matrix(self, fea_map_0, fea_map_1):
return layers.fsp_matrix(fea_map_0, fea_map_1)
......@@ -300,7 +300,9 @@ class GraphWrapper(object):
graph(GraphWrapper): The graph to be merged by current graph.
"""
for var in graph.program.list_vars():
self.program.global_block()._clone_variable(var)
new_var = self.program.global_block()._clone_variable(
var, force_persistable=False)
new_var.stop_gradient = var.stop_gradient
# TODO: parameters should be cloned
for op in graph.ops():
op = op._op
......@@ -309,12 +311,12 @@ class GraphWrapper(object):
attrs = {}
for input_name in op.input_names:
inputs[input_name] = [
self.var(in_var_name)
for in_var_name in op.inputs(input_name)
self.var(in_var_name)._var
for in_var_name in op.input(input_name)
]
for output_name in op.output_names:
outputs[output_name] = [
self.var(out_var_name)
self.var(out_var_name)._var
for out_var_name in op.output(output_name)
]
for attr_name in op.attr_names:
......
......@@ -16,5 +16,7 @@ from __future__ import print_function
from . import quantization_pass
from .quantization_pass import *
from . import quantization_strategy
from .quantization_strategy import *
__all__ = quantization_pass.__all__
__all__ = quantization_pass.__all__ + quantization_strategy.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import numpy as np
from .... import Executor
from .... import io
from .... import core
from ....compiler import CompiledProgram
from ....compiler import BuildStrategy
from ....framework import IrGraph
from ..core.strategy import Strategy
from .quantization_pass import *
__all__ = ['QuantizationStrategy']
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class QuantizationStrategy(Strategy):
"""
The strategy for Quantization.
"""
def __init__(self,
start_epoch=0,
end_epoch=0,
float_model_save_path=None,
mobile_model_save_path=None,
int8_model_save_path=None,
activation_bits=8,
weight_bits=8,
activation_quantize_type='abs_max',
save_in_nodes=None,
save_out_nodes=None):
"""
Args:
start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
float_model_save_path(str): The path to save model with float weights.
None means it doesn't save float model. defalut: None.
mobile_model_save_path(str): The path to save model for paddle-mobile execution.
None means it doesn't save mobile model. defalut: None.
int8_model_save_path(str): The path to save model with int8_t weight.
None means it doesn't save int8 model. defalut: None.
activation_bits(int): quantization bit number for activation. default: 8.
weight_bits(int): quantization bit number for weights. The bias is not quantized.
default: 8.
activation_quantize_type(str): quantization type for activation,
now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
If use 'abs_max' mode, the quantization scale will be calculated
dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated
during training and used in inference.
save_in_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
save_out_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
"""
super(QuantizationStrategy, self).__init__(start_epoch, end_epoch)
self.start_epoch = start_epoch
self.end_epoch = end_epoch
self.float_model_save_path = float_model_save_path
self.mobile_model_save_path = mobile_model_save_path
self.int8_model_save_path = int8_model_save_path
self.activation_bits = activation_bits
self.weight_bits = weight_bits
self.activation_quantize_type = activation_quantize_type
self.save_out_nodes = save_out_nodes
self.save_in_nodes = save_in_nodes
def on_epoch_begin(self, context):
"""
Insert fake_quantize_op and fake_dequantize_op before trainging and testing.
"""
super(QuantizationStrategy, self).on_compression_begin(context)
if self.start_epoch == context.epoch_id:
_logger.info('QuantizationStrategy::on_epoch_begin')
train_ir_graph = IrGraph(
core.Graph(context.optimize_graph.program.desc), for_test=False)
test_ir_graph = IrGraph(
core.Graph(context.eval_graph.program.desc), for_test=True)
transform_pass = QuantizationTransformPass(
scope=context.scope,
place=context.place,
weight_bits=self.weight_bits,
activation_bits=self.activation_bits,
activation_quantize_type=self.activation_quantize_type)
transform_pass.apply(train_ir_graph)
transform_pass.apply(test_ir_graph)
build_strategy = BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
# for quantization training
context.optimize_graph.compiled_graph = CompiledProgram(
train_ir_graph.graph).with_data_parallel(
loss_name=context.optimize_graph.out_nodes['loss'],
build_strategy=build_strategy)
# for evaluation. And program compiled from ir graph must be with data parallel.
context.eval_graph.compiled_graph = CompiledProgram(
test_ir_graph.graph).with_data_parallel(
build_strategy=build_strategy)
# for saving inference model after training
context.put('quantization_test_ir_graph_backup', test_ir_graph)
_logger.info('Finish QuantizationStrategy::on_epoch_begin')
def on_epoch_end(self, context):
"""
Free and save inference model.
"""
super(QuantizationStrategy, self).on_compression_end(context)
if context.epoch_id == self.end_epoch:
_logger.info('QuantizationStrategy::on_epoch_end')
test_ir_graph = context.get('quantization_test_ir_graph_backup')
# freeze the graph after training
freeze_pass = QuantizationFreezePass(
scope=context.scope,
place=context.place,
weight_bits=self.weight_bits,
activation_bits=self.activation_bits)
freeze_pass.apply(test_ir_graph)
# for other strategies
context.eval_graph.program = test_ir_graph.to_program()
if self.save_out_nodes == None:
out_vars = [
context.eval_graph.var(var_name)._var
for var_name in context.eval_graph.out_nodes.values()
]
else:
out_vars = [
context.eval_graph.var(var_name)._var
for var_name in self.save_out_nodes
]
if self.save_in_nodes == None:
in_vars = list(context.eval_graph.out_nodes.values())
else:
in_vars = self.save_in_nodes
# save float model
if self.float_model_save_path:
executor = Executor(context.place)
io.save_inference_model(
self.float_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
# save int8 model
if self.int8_model_save_path:
convert_int8_pass = ConvertToInt8Pass(
scope=context.scope, place=context.place)
convert_int8_pass.apply(test_ir_graph)
executor = Executor(context.place)
io.save_inference_model(
self.int8_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
# save mobile model
if self.mobile_model_save_path:
if not self.int8_model_save_path:
# convert the weights as int8_t type
convert_int8_pass = ConvertToInt8Pass(
scope=context.scope, place=context.place)
convert_int8_pass.apply(test_ir_graph)
# make some changes on the graph for the mobile inference
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_ir_graph)
executor = Executor(context.place)
io.save_inference_model(
self.mobile_model_save_path,
in_vars,
out_vars,
executor,
main_program=test_ir_graph.to_program(),
model_filename='model',
params_filename='weights',
export_for_deployment=True)
_logger.info('Finish QuantizationStrategy::on_epoch_end')
#start_epoch(int): The epoch when to merge student graph and teacher graph for
# distillation training. default: 0
#
#end_epoch(int): The epoch when to finish distillation training. default: 0
#
#student_feature_map(str): The name of feature map from student network.
#
#teacher_feature_map(str): The name of feature map from teacher network.
# It's shape should be the same with student network.
#
#student_pairs(list<tuple>): Each tuple, with two variable names, in student_pairs indicates
# a section in student network. The variables in a tuple should
# have the same feature map size.
#
#teacher_pairs(list<tuple>): Each tuple, with two variable names, in teacher_pairs indicates
# a section in teacher network. The variables in a tuple should
# have the same feature map size. Varibale named teacher_pairs[i][j]
# should has the save channel number with that of variable named
# student_pairs[i][j].
#
#distillation_loss_weight(float): The weight of the loss.
version: 1.0
distillers:
fsp_distiller:
class: 'FSPDistiller'
# teacher_pairs: [['teacher_depthwise_conv2d_1.tmp_0', 'teacher_conv2d_3.tmp_0']]
# student_pairs: [['student_depthwise_conv2d_1.tmp_0', 'student_conv2d_3.tmp_0']]
teacher_pairs: [['teacher_conv2_1_dw.tmp_0', 'teacher_conv1.tmp_0']]
student_pairs: [['student_conv2_1_dw.tmp_0', 'student_conv1.tmp_0']]
distillation_loss_weight: 1
l2_distiller:
class: 'L2Distiller'
teacher_feature_map: 'teacher.tmp_2'
student_feature_map: 'student.tmp_2'
distillation_loss_weight: 1
strategies:
distillation_strategy:
class: 'DistillationStrategy'
distillers: ['fsp_distiller', 'l2_distiller']
start_epoch: 0
end_epoch: 1
compressor:
epoch: 1
checkpoint_path: './distillation_checkpoints/'
strategies:
- distillation_strategy
......@@ -29,6 +29,6 @@ strategies:
metric_name: 'acc_top1'
compressor:
epoch: 2
checkpoint_path: './checkpoints/'
checkpoint_path: './checkpoints_pruning/'
strategies:
- sensitive_pruning_strategy
......@@ -35,8 +35,9 @@ train_parameters = {
class MobileNet():
def __init__(self):
def __init__(self, name=""):
self.params = train_parameters
self.name = name
def net(self, input, class_dim=1000, scale=1.0):
# conv1: 112x112
......@@ -47,7 +48,7 @@ class MobileNet():
num_filters=int(32 * scale),
stride=2,
padding=1,
name="conv1")
name=self.name + "_conv1")
# 56x56
input = self.depthwise_separable(
......@@ -57,7 +58,7 @@ class MobileNet():
num_groups=32,
stride=1,
scale=scale,
name="conv2_1")
name=self.name + "_conv2_1")
input = self.depthwise_separable(
input,
......@@ -66,7 +67,7 @@ class MobileNet():
num_groups=64,
stride=2,
scale=scale,
name="conv2_2")
name=self.name + "_conv2_2")
# 28x28
input = self.depthwise_separable(
......@@ -76,7 +77,7 @@ class MobileNet():
num_groups=128,
stride=1,
scale=scale,
name="conv3_1")
name=self.name + "_conv3_1")
input = self.depthwise_separable(
input,
......@@ -85,7 +86,7 @@ class MobileNet():
num_groups=128,
stride=2,
scale=scale,
name="conv3_2")
name=self.name + "_conv3_2")
# 14x14
input = self.depthwise_separable(
......@@ -95,7 +96,7 @@ class MobileNet():
num_groups=256,
stride=1,
scale=scale,
name="conv4_1")
name=self.name + "_conv4_1")
input = self.depthwise_separable(
input,
......@@ -104,7 +105,7 @@ class MobileNet():
num_groups=256,
stride=2,
scale=scale,
name="conv4_2")
name=self.name + "_conv4_2")
# 14x14
for i in range(5):
......@@ -115,7 +116,7 @@ class MobileNet():
num_groups=512,
stride=1,
scale=scale,
name="conv5" + "_" + str(i + 1))
name=self.name + "_conv5" + "_" + str(i + 1))
# 7x7
input = self.depthwise_separable(
input,
......@@ -124,7 +125,7 @@ class MobileNet():
num_groups=512,
stride=2,
scale=scale,
name="conv5_6")
name=self.name + "_conv5_6")
input = self.depthwise_separable(
input,
......@@ -133,7 +134,7 @@ class MobileNet():
num_groups=1024,
stride=1,
scale=scale,
name="conv6")
name=self.name + "_conv6")
input = fluid.layers.pool2d(
input=input,
......@@ -142,12 +143,14 @@ class MobileNet():
pool_type='avg',
global_pooling=True)
output = fluid.layers.fc(input=input,
size=class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=MSRA(), name="fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
output = fluid.layers.fc(
input=input,
size=class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=MSRA(), name=self.name + "_fc7_weights"),
bias_attr=ParamAttr(name=self.name + "_fc7_offset"),
name=self.name)
return output
def conv_bn_layer(self,
......@@ -172,11 +175,13 @@ class MobileNet():
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=name + "_weights"),
name=name,
bias_attr=False)
bn_name = name + "_bn"
return fluid.layers.batch_norm(
input=conv,
act=act,
name=name,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
......
#start_epoch(int): The epoch to insert quantization operators. default: 0
#
#end_epoch(int): The epoch to save inferecne model. default: 0
#
#float_model_save_path(str): The path to save model with float weights.
# None means it doesn't save float model. defalut: None.
#
#mobile_model_save_path(str): The path to save model for paddle-mobile execution.
# None means it doesn't save mobile model. defalut: None.
#
#int8_model_save_path(str): The path to save model with int8_t weight.
# None means it doesn't save int8 model. defalut: None.
#
#activation_bits(int): quantization bit number for activation. default: 8.
#
#weight_bits(int): quantization bit number for weights. The bias is not quantized.
# default: 8.
#
#activation_quantize_type(str): quantization type for activation,
# now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
# If use 'abs_max' mode, the quantization scale will be calculated
# dynamically each step in both training and testing period. If use
# 'range_abs_max', a static quantization scale will be calculated
# during training and used in inference.
#
#save_in_nodes(list<str>): A list of variable names used to prune graph
# for saving inference model.
#
#save_out_nodes(list<str>): A list of variable names used to prune graph
# for saving inference model.
version: 1.0
strategies:
quantization_strategy:
class: 'QuantizationStrategy'
start_epoch: 0
end_epoch: 0
float_model_save_path: './output/float'
weight_bits: 8
activation_bits: 8
weight_quantize_type: 'abs_max'
activation_quantize_type: 'abs_max'
save_in_nodes: ['image']
save_out_nodes: ['quan.tmp_2']
compressor:
epoch: 1
checkpoint_path: './checkpoints_quan/'
strategies:
- quantization_strategy
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import paddle
import unittest
import paddle.fluid as fluid
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
class TestDistillationStrategy(unittest.TestCase):
"""
Test API of distillation strategy.
"""
def test_compression(self):
if not fluid.core.is_compiled_with_cuda():
return
class_dim = 10
image_shape = [1, 28, 28]
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = MobileNet(name="student").net(input=image, class_dim=class_dim)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
val_program = fluid.default_main_program().clone(for_test=False)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
regularization=fluid.regularizer.L2Decay(4e-5))
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
val_feed_list = [('img', image.name), ('label', label.name)]
val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5',
acc_top5.name)]
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128)
train_feed_list = [('img', image.name), ('label', label.name)]
train_fetch_list = [('loss', avg_cost.name)]
# define teacher program
teacher_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(teacher_program, startup_program):
img = teacher_program.global_block()._clone_variable(
image, force_persistable=False)
predict = MobileNet(name="teacher").net(input=img,
class_dim=class_dim)
exe.run(startup_program)
com_pass = Compressor(
place,
fluid.global_scope(),
fluid.default_main_program(),
train_reader=train_reader,
train_feed_list=train_feed_list,
train_fetch_list=train_fetch_list,
eval_program=val_program,
eval_reader=val_reader,
eval_feed_list=val_feed_list,
eval_fetch_list=val_fetch_list,
teacher_programs=[teacher_program.clone(for_test=True)],
train_optimizer=optimizer,
distiller_optimizer=optimizer)
com_pass.config('./distillation/compress.yaml')
eval_graph = com_pass.run()
if __name__ == '__main__':
unittest.main()
......@@ -15,7 +15,7 @@
import paddle
import unittest
import paddle.fluid as fluid
from filter_pruning.mobilenet import MobileNet
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
......
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import paddle
import unittest
import paddle.fluid as fluid
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
class TestQuantizationStrategy(unittest.TestCase):
"""
Test API of quantization strategy.
"""
def test_compression(self):
if not fluid.core.is_compiled_with_cuda():
return
class_dim = 10
image_shape = [1, 28, 28]
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = MobileNet(name='quan').net(input=image, class_dim=class_dim)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
val_program = fluid.default_main_program().clone(for_test=False)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
regularization=fluid.regularizer.L2Decay(4e-5))
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
val_feed_list = [('img', image.name), ('label', label.name)]
val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5',
acc_top5.name)]
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128)
train_feed_list = [('img', image.name), ('label', label.name)]
train_fetch_list = [('loss', avg_cost.name)]
com_pass = Compressor(
place,
fluid.global_scope(),
fluid.default_main_program(),
train_reader=train_reader,
train_feed_list=train_feed_list,
train_fetch_list=train_fetch_list,
eval_program=val_program,
eval_reader=val_reader,
eval_feed_list=val_feed_list,
eval_fetch_list=val_fetch_list,
train_optimizer=optimizer)
com_pass.config('./quantization/compress.yaml')
eval_graph = com_pass.run()
if __name__ == '__main__':
unittest.main()
......@@ -1559,12 +1559,15 @@ class Block(object):
name=v.name)
self.vars[new_p.name] = new_p
def _clone_variable(self, var):
def _clone_variable(self, var, force_persistable=True):
"""
Clone a variable into current block.
Args:
var: the variable to be cloned.
force_persistable(bool): True means setting the result variable to being persistable.
False means setting the persistable the same with that of input var.
default: True.
Returns:
Variable: the new variable cloned from 'var' in current block.
......@@ -1584,7 +1587,7 @@ class Block(object):
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=True,
persistable=True if force_persistable else var.persistable,
is_data=var.is_data)
else:
ret_var = self.create_var(
......@@ -1593,7 +1596,7 @@ class Block(object):
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=True,
persistable=True if force_persistable else var.persistable,
is_data=var.is_data)
return ret_var
......
......@@ -189,6 +189,7 @@ __all__ = [
'huber_loss',
'tree_conv',
'npair_loss',
'fsp_matrix',
]
kIgnoreIndex = -100
......@@ -10790,3 +10791,46 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
celoss = reduce_mean(cross_entropy)
return l2loss + celoss
def fsp_matrix(x, y):
"""
**FSP matrix op**
This op is used to calculate the flow of solution procedure (FSP) matrix of two feature maps.
Given feature map x with shape [x_channel, h, w] and feature map y with shape
[y_channel, h, w], we can get the fsp matrix of x and y in two steps:
1. reshape x into matrix with shape [x_channel, h * w] and reshape and
transpose y into matrix with shape [h * w, y_channel].
2. multiply x and y to get fsp matrix with shape [x_channel, y_channel].
The output is a batch of fsp matrices.
Args:
x (Variable): A feature map with shape [batch_size, x_channel, height, width].
y (Variable): A feature map with shape [batch_size, y_channel, height, width].
The y_channel can be different with the x_channel of Input(X)
while the other dimensions must be the same with Input(X)'s.
Returns:
fsp matrix (Variable): The output of FSP op with shape [batch_size, x_channel, y_channel].
The x_channel is the channel of x and the y_channel is the channel of y.
Examples:
.. code-block:: python
feature_map_0 = fluid.layers.conv2d(x)
feature_map_1 = fluid.layers.conv2d(feature_map_0)
loss = fluid.layers.fsp_matrix(feature_map_0, feature_map_1)
"""
helper = LayerHelper('fsp_matrix', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype(
input_param_name='x'))
helper.append_op(type='fsp', inputs={'X': x, 'Y': y}, outputs={'Out': out})
return out
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
def fsp_matrix(a, b):
batch = a.shape[0]
a_channel = a.shape[1]
b_channel = b.shape[1]
h = a.shape[2]
w = a.shape[3]
a_t = a.transpose([0, 2, 3, 1])
a_t = a_t.reshape([batch, h * w, a_channel])
b_t = b.transpose([0, 2, 3, 1]).reshape([batch, h * w, b_channel])
a_r = a_t.repeat(
b_channel, axis=1).reshape(
[batch, h * w, b_channel, a_channel]).transpose([0, 1, 3, 2])
b_r = b_t.repeat(
a_channel, axis=1).reshape([batch, h * w, a_channel, b_channel])
return np.mean(a_r * b_r, axis=1)
class TestFSPOp(OpTest):
def setUp(self):
self.op_type = "fsp"
self.initTestCase()
feature_map_0 = np.random.uniform(0, 10, self.a_shape).astype('float32')
feature_map_1 = np.random.uniform(0, 10, self.b_shape).astype('float32')
self.inputs = {'X': feature_map_0, 'Y': feature_map_1}
self.outputs = {'Out': fsp_matrix(feature_map_0, feature_map_1)}
def initTestCase(self):
self.a_shape = (2, 16, 32, 31)
self.b_shape = (2, 28, 32, 31)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)
if __name__ == '__main__':
unittest.main()
......@@ -1269,6 +1269,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_fsp(self):
program = Program()
with program_guard(program):
x = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
y = layers.data(name="Y", shape=[8, 4, 4], dtype="float32")
out = layers.fsp_matrix(x, y)
self.assertIsNotNone(out)
print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -117,6 +117,7 @@ packages=['paddle',
'paddle.fluid.contrib.slim.graph',
'paddle.fluid.contrib.slim.prune',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.slim.distillation',
'paddle.fluid.contrib.utils',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册