Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
50afdd83
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
50afdd83
编写于
3月 02, 2018
作者:
Q
qingqing01
提交者:
GitHub
3月 02, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #617 from dragonwarrior/caffe2fluid
Caffe2fluid
上级
243ee52b
39daecc2
变更
20
展开全部
隐藏空白更改
内联
并排
Showing
20 changed file
with
3529 addition
and
0 deletion
+3529
-0
fluid/image_classification/caffe2fluid/README.md
fluid/image_classification/caffe2fluid/README.md
+25
-0
fluid/image_classification/caffe2fluid/convert.py
fluid/image_classification/caffe2fluid/convert.py
+72
-0
fluid/image_classification/caffe2fluid/kaffe/__init__.py
fluid/image_classification/caffe2fluid/kaffe/__init__.py
+5
-0
fluid/image_classification/caffe2fluid/kaffe/caffe/__init__.py
.../image_classification/caffe2fluid/kaffe/caffe/__init__.py
+1
-0
fluid/image_classification/caffe2fluid/kaffe/caffe/resolver.py
.../image_classification/caffe2fluid/kaffe/caffe/resolver.py
+61
-0
fluid/image_classification/caffe2fluid/kaffe/errors.py
fluid/image_classification/caffe2fluid/kaffe/errors.py
+34
-0
fluid/image_classification/caffe2fluid/kaffe/graph.py
fluid/image_classification/caffe2fluid/kaffe/graph.py
+302
-0
fluid/image_classification/caffe2fluid/kaffe/layers.py
fluid/image_classification/caffe2fluid/kaffe/layers.py
+152
-0
fluid/image_classification/caffe2fluid/kaffe/paddle/__init__.py
...image_classification/caffe2fluid/kaffe/paddle/__init__.py
+2
-0
fluid/image_classification/caffe2fluid/kaffe/paddle/network.py
.../image_classification/caffe2fluid/kaffe/paddle/network.py
+260
-0
fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py
...ge_classification/caffe2fluid/kaffe/paddle/transformer.py
+353
-0
fluid/image_classification/caffe2fluid/kaffe/shapes.py
fluid/image_classification/caffe2fluid/kaffe/shapes.py
+88
-0
fluid/image_classification/caffe2fluid/kaffe/transformers.py
fluid/image_classification/caffe2fluid/kaffe/transformers.py
+303
-0
fluid/image_classification/caffe2fluid/proto/caffe.proto
fluid/image_classification/caffe2fluid/proto/caffe.proto
+1411
-0
fluid/image_classification/caffe2fluid/proto/compile.sh
fluid/image_classification/caffe2fluid/proto/compile.sh
+28
-0
fluid/image_classification/caffe2fluid/tests/lenet/README.md
fluid/image_classification/caffe2fluid/tests/lenet/README.md
+28
-0
fluid/image_classification/caffe2fluid/tests/lenet/convert.sh
...d/image_classification/caffe2fluid/tests/lenet/convert.sh
+33
-0
fluid/image_classification/caffe2fluid/tests/lenet/lenet.npy
fluid/image_classification/caffe2fluid/tests/lenet/lenet.npy
+0
-0
fluid/image_classification/caffe2fluid/tests/lenet/lenet.py
fluid/image_classification/caffe2fluid/tests/lenet/lenet.py
+297
-0
fluid/image_classification/caffe2fluid/tests/lenet/predict.py
...d/image_classification/caffe2fluid/tests/lenet/predict.py
+74
-0
未找到文件。
fluid/image_classification/caffe2fluid/README.md
0 → 100644
浏览文件 @
50afdd83
### Caffe2Fluid
This tool is used to convert a Caffe model to Fluid model
### Howto
1, Prepare caffepb.py in ./proto, two options provided
1) generate it from caffe.proto using protoc
bash ./proto/compile.sh
2) download one from github directly
cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
2, Convert the caffe model using 'convert.py' which will generate a python script and a weight(in .npy) file
3, Use the converted model to predict
see more detail info in 'tests/lenet/README.md'
### Supported models
-
Lenet on mnist dataset
-
ResNets:(ResNet-50, ResNet-101, ResNet-152)
model addrs:(https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
### Notes
Some of this code come from here: https://github.com/ethereon/caffe-tensorflow
fluid/image_classification/caffe2fluid/convert.py
0 → 100755
浏览文件 @
50afdd83
#!/usr/bin/env python
import
os
import
sys
import
numpy
as
np
import
argparse
from
kaffe
import
KaffeError
,
print_stderr
from
kaffe.paddle
import
Transformer
def
fatal_error
(
msg
):
""" fatal error encounted
"""
print_stderr
(
msg
)
exit
(
-
1
)
def
validate_arguments
(
args
):
""" validate args
"""
if
(
args
.
data_output_path
is
not
None
)
and
(
args
.
caffemodel
is
None
):
fatal_error
(
'No input data path provided.'
)
if
(
args
.
caffemodel
is
not
None
)
and
(
args
.
data_output_path
is
None
):
fatal_error
(
'No output data path provided.'
)
if
(
args
.
code_output_path
is
None
)
and
(
args
.
data_output_path
is
None
):
fatal_error
(
'No output path specified.'
)
def
convert
(
def_path
,
caffemodel_path
,
data_output_path
,
code_output_path
,
phase
):
""" convert caffe model to tf/paddle models
"""
try
:
transformer
=
Transformer
(
def_path
,
caffemodel_path
,
phase
=
phase
)
print_stderr
(
'Converting data...'
)
if
caffemodel_path
is
not
None
:
data
=
transformer
.
transform_data
()
print_stderr
(
'Saving data...'
)
with
open
(
data_output_path
,
'wb'
)
as
data_out
:
np
.
save
(
data_out
,
data
)
if
code_output_path
:
print_stderr
(
'Saving source...'
)
with
open
(
code_output_path
,
'wb'
)
as
src_out
:
src_out
.
write
(
transformer
.
transform_source
())
print_stderr
(
'Done.'
)
except
KaffeError
as
err
:
fatal_error
(
'Error encountered: {}'
.
format
(
err
))
def
main
():
""" main
"""
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'def_path'
,
help
=
'Model definition (.prototxt) path'
)
parser
.
add_argument
(
'--caffemodel'
,
help
=
'Model data (.caffemodel) path'
)
parser
.
add_argument
(
'--data-output-path'
,
help
=
'Converted data output path'
)
parser
.
add_argument
(
'--code-output-path'
,
help
=
'Save generated source to this path'
)
parser
.
add_argument
(
'-p'
,
'--phase'
,
default
=
'test'
,
help
=
'The phase to convert: test (default) or train'
)
args
=
parser
.
parse_args
()
validate_arguments
(
args
)
convert
(
args
.
def_path
,
args
.
caffemodel
,
args
.
data_output_path
,
args
.
code_output_path
,
args
.
phase
)
if
__name__
==
'__main__'
:
main
()
fluid/image_classification/caffe2fluid/kaffe/__init__.py
0 → 100644
浏览文件 @
50afdd83
from
.graph
import
GraphBuilder
,
NodeMapper
from
.errors
import
KaffeError
,
print_stderr
import
os
from
.
import
paddle
fluid/image_classification/caffe2fluid/kaffe/caffe/__init__.py
0 → 100644
浏览文件 @
50afdd83
from
.resolver
import
get_caffe_resolver
,
has_pycaffe
fluid/image_classification/caffe2fluid/kaffe/caffe/resolver.py
0 → 100644
浏览文件 @
50afdd83
import
os
import
sys
SHARED_CAFFE_RESOLVER
=
None
def
import_caffepb
():
p
=
os
.
path
.
realpath
(
__file__
)
p
=
os
.
path
.
dirname
(
p
)
p
=
os
.
path
.
join
(
p
,
'../../proto'
)
sys
.
path
.
insert
(
0
,
p
)
import
caffepb
return
caffepb
class
CaffeResolver
(
object
):
def
__init__
(
self
):
self
.
import_caffe
()
def
import_caffe
(
self
):
self
.
caffe
=
None
try
:
# Try to import PyCaffe first
import
caffe
self
.
caffe
=
caffe
except
ImportError
:
# Fall back to the protobuf implementation
self
.
caffepb
=
import_caffepb
()
show_fallback_warning
()
if
self
.
caffe
:
# Use the protobuf code from the imported distribution.
# This way, Caffe variants with custom layers will work.
self
.
caffepb
=
self
.
caffe
.
proto
.
caffe_pb2
self
.
NetParameter
=
self
.
caffepb
.
NetParameter
def
has_pycaffe
(
self
):
return
self
.
caffe
is
not
None
def
get_caffe_resolver
():
global
SHARED_CAFFE_RESOLVER
if
SHARED_CAFFE_RESOLVER
is
None
:
SHARED_CAFFE_RESOLVER
=
CaffeResolver
()
return
SHARED_CAFFE_RESOLVER
def
has_pycaffe
():
return
get_caffe_resolver
().
has_pycaffe
()
def
show_fallback_warning
():
msg
=
'''
------------------------------------------------------------
WARNING: PyCaffe not found!
Falling back to a pure protocol buffer implementation.
* Conversions will be drastically slower.
* This backend is UNTESTED!
------------------------------------------------------------
'''
sys
.
stderr
.
write
(
msg
)
fluid/image_classification/caffe2fluid/kaffe/errors.py
0 → 100644
浏览文件 @
50afdd83
import
sys
#debug level, can be 'warn', 'verbose'
log_level
=
'warn'
class
KaffeError
(
Exception
):
pass
def
print_stderr
(
msg
):
sys
.
stderr
.
write
(
'%s
\n
'
%
msg
)
def
debug
(
msg
):
if
log_level
==
'verbose'
:
print_stderr
(
'[DEBUG]'
+
msg
)
def
notice
(
msg
):
print_stderr
(
'[NOTICE]'
+
msg
)
def
warn
(
msg
):
print_stderr
(
'[WARNING]'
+
msg
)
def
set_loglevel
(
level
):
global
log_level
if
'warn'
!=
level
and
'verbose'
!=
level
:
raise
Exception
(
'not supported log level[%s]'
%
(
level
))
log_level
=
level
fluid/image_classification/caffe2fluid/kaffe/graph.py
0 → 100644
浏览文件 @
50afdd83
from
google.protobuf
import
text_format
from
.caffe
import
get_caffe_resolver
from
.errors
import
KaffeError
,
print_stderr
from
.layers
import
LayerAdapter
,
LayerType
,
NodeKind
,
NodeDispatch
from
.shapes
import
TensorShape
class
Node
(
object
):
def
__init__
(
self
,
name
,
kind
,
layer
=
None
):
self
.
name
=
name
self
.
kind
=
kind
self
.
layer
=
LayerAdapter
(
layer
,
kind
)
if
layer
else
None
self
.
parents
=
[]
self
.
children
=
[]
self
.
data
=
None
self
.
output_shape
=
None
self
.
metadata
=
{}
def
add_parent
(
self
,
parent_node
):
assert
parent_node
not
in
self
.
parents
self
.
parents
.
append
(
parent_node
)
if
self
not
in
parent_node
.
children
:
parent_node
.
children
.
append
(
self
)
def
add_child
(
self
,
child_node
):
assert
child_node
not
in
self
.
children
self
.
children
.
append
(
child_node
)
if
self
not
in
child_node
.
parents
:
child_node
.
parents
.
append
(
self
)
def
get_only_parent
(
self
):
if
len
(
self
.
parents
)
!=
1
:
raise
KaffeError
(
'Node (%s) expected to have 1 parent. Found %s.'
%
(
self
,
len
(
self
.
parents
)))
return
self
.
parents
[
0
]
@
property
def
parameters
(
self
):
if
self
.
layer
is
not
None
:
return
self
.
layer
.
parameters
return
None
def
__str__
(
self
):
return
'[%s] %s'
%
(
self
.
kind
,
self
.
name
)
def
__repr__
(
self
):
return
'%s (0x%x)'
%
(
self
.
name
,
id
(
self
))
class
Graph
(
object
):
def
__init__
(
self
,
nodes
=
None
,
name
=
None
):
self
.
nodes
=
nodes
or
[]
self
.
node_lut
=
{
node
.
name
:
node
for
node
in
self
.
nodes
}
self
.
name
=
name
def
add_node
(
self
,
node
):
self
.
nodes
.
append
(
node
)
self
.
node_lut
[
node
.
name
]
=
node
def
get_node
(
self
,
name
):
try
:
return
self
.
node_lut
[
name
]
except
KeyError
:
raise
KaffeError
(
'Layer not found: %s'
%
name
)
def
get_input_nodes
(
self
):
return
[
node
for
node
in
self
.
nodes
if
len
(
node
.
parents
)
==
0
]
def
get_output_nodes
(
self
):
return
[
node
for
node
in
self
.
nodes
if
len
(
node
.
children
)
==
0
]
def
topologically_sorted
(
self
):
sorted_nodes
=
[]
unsorted_nodes
=
list
(
self
.
nodes
)
temp_marked
=
set
()
perm_marked
=
set
()
def
visit
(
node
):
if
node
in
temp_marked
:
raise
KaffeError
(
'Graph is not a DAG.'
)
if
node
in
perm_marked
:
return
temp_marked
.
add
(
node
)
for
child
in
node
.
children
:
visit
(
child
)
perm_marked
.
add
(
node
)
temp_marked
.
remove
(
node
)
sorted_nodes
.
insert
(
0
,
node
)
while
len
(
unsorted_nodes
):
visit
(
unsorted_nodes
.
pop
())
return
sorted_nodes
def
compute_output_shapes
(
self
):
sorted_nodes
=
self
.
topologically_sorted
()
for
node
in
sorted_nodes
:
node
.
output_shape
=
TensorShape
(
*
NodeKind
.
compute_output_shape
(
node
))
def
replaced
(
self
,
new_nodes
):
return
Graph
(
nodes
=
new_nodes
,
name
=
self
.
name
)
def
transformed
(
self
,
transformers
):
graph
=
self
for
transformer
in
transformers
:
graph
=
transformer
(
graph
)
if
graph
is
None
:
raise
KaffeError
(
'Transformer failed: {}'
.
format
(
transformer
))
assert
isinstance
(
graph
,
Graph
)
return
graph
def
__contains__
(
self
,
key
):
return
key
in
self
.
node_lut
def
__str__
(
self
):
hdr
=
'{:<20} {:<30} {:>20} {:>20}'
.
format
(
'Type'
,
'Name'
,
'Param'
,
'Output'
)
s
=
[
hdr
,
'-'
*
94
]
for
node
in
self
.
topologically_sorted
():
# If the node has learned parameters, display the first one's shape.
# In case of convolutions, this corresponds to the weights.
data_shape
=
node
.
data
[
0
].
shape
if
node
.
data
else
'--'
out_shape
=
node
.
output_shape
or
'--'
s
.
append
(
'{:<20} {:<30} {:>20} {:>20}'
.
format
(
node
.
kind
,
node
.
name
,
data_shape
,
tuple
(
out_shape
)))
return
'
\n
'
.
join
(
s
)
class
GraphBuilder
(
object
):
'''Constructs a model graph from a Caffe protocol buffer definition.'''
def
__init__
(
self
,
def_path
,
phase
=
'test'
):
'''
def_path: Path to the model definition (.prototxt)
data_path: Path to the model data (.caffemodel)
phase: Either 'test' or 'train'. Used for filtering phase-specific nodes.
'''
self
.
def_path
=
def_path
self
.
phase
=
phase
self
.
load
()
def
load
(
self
):
'''Load the layer definitions from the prototxt.'''
self
.
params
=
get_caffe_resolver
().
NetParameter
()
with
open
(
self
.
def_path
,
'rb'
)
as
def_file
:
text_format
.
Merge
(
def_file
.
read
(),
self
.
params
)
def
filter_layers
(
self
,
layers
):
'''Filter out layers based on the current phase.'''
phase_map
=
{
0
:
'train'
,
1
:
'test'
}
filtered_layer_names
=
set
()
filtered_layers
=
[]
for
layer
in
layers
:
phase
=
self
.
phase
if
len
(
layer
.
include
):
phase
=
phase_map
[
layer
.
include
[
0
].
phase
]
if
len
(
layer
.
exclude
):
phase
=
phase_map
[
1
-
layer
.
include
[
0
].
phase
]
exclude
=
(
phase
!=
self
.
phase
)
# Dropout layers appear in a fair number of Caffe
# test-time networks. These are just ignored. We'll
# filter them out here.
if
(
not
exclude
)
and
(
phase
==
'test'
):
exclude
=
(
layer
.
type
==
LayerType
.
Dropout
)
if
not
exclude
:
filtered_layers
.
append
(
layer
)
# Guard against dupes.
assert
layer
.
name
not
in
filtered_layer_names
filtered_layer_names
.
add
(
layer
.
name
)
return
filtered_layers
def
make_node
(
self
,
layer
):
'''Create a graph node for the given layer.'''
kind
=
NodeKind
.
map_raw_kind
(
layer
.
type
)
if
kind
is
None
:
raise
KaffeError
(
'Unknown layer type encountered: %s'
%
layer
.
type
)
# We want to use the layer's top names (the "output" names), rather than the
# name attribute, which is more of readability thing than a functional one.
# Other layers will refer to a node by its "top name".
return
Node
(
layer
.
name
,
kind
,
layer
=
layer
)
def
make_input_nodes
(
self
):
'''
Create data input nodes.
This method is for old-style inputs, where the input specification
was not treated as a first-class layer in the prototext.
Newer models use the "Input layer" type.
'''
nodes
=
[
Node
(
name
,
NodeKind
.
Data
)
for
name
in
self
.
params
.
input
]
if
len
(
nodes
):
input_dim
=
map
(
int
,
self
.
params
.
input_dim
)
if
not
input_dim
:
if
len
(
self
.
params
.
input_shape
)
>
0
:
input_dim
=
map
(
int
,
self
.
params
.
input_shape
[
0
].
dim
)
else
:
raise
KaffeError
(
'Dimensions for input not specified.'
)
for
node
in
nodes
:
node
.
output_shape
=
tuple
(
input_dim
)
return
nodes
def
build
(
self
):
'''
Builds the graph from the Caffe layer definitions.
'''
# Get the layers
layers
=
self
.
params
.
layers
or
self
.
params
.
layer
# Filter out phase-excluded layers
layers
=
self
.
filter_layers
(
layers
)
# Get any separately-specified input layers
nodes
=
self
.
make_input_nodes
()
nodes
+=
[
self
.
make_node
(
layer
)
for
layer
in
layers
]
# Initialize the graph
graph
=
Graph
(
nodes
=
nodes
,
name
=
self
.
params
.
name
)
# Connect the nodes
#
# A note on layers and outputs:
# In Caffe, each layer can produce multiple outputs ("tops") from a set of inputs
# ("bottoms"). The bottoms refer to other layers' tops. The top can rewrite a bottom
# (in case of in-place operations). Note that the layer's name is not used for establishing
# any connectivity. It's only used for data association. By convention, a layer with a
# single top will often use the same name (although this is not required).
#
# The current implementation only supports single-output nodes (note that a node can still
# have multiple children, since multiple child nodes can refer to the single top's name).
node_outputs
=
{}
for
layer
in
layers
:
node
=
graph
.
get_node
(
layer
.
name
)
for
input_name
in
layer
.
bottom
:
assert
input_name
!=
layer
.
name
parent_node
=
node_outputs
.
get
(
input_name
)
if
(
parent_node
is
None
)
or
(
parent_node
==
node
):
parent_node
=
graph
.
get_node
(
input_name
)
node
.
add_parent
(
parent_node
)
if
len
(
layer
.
top
)
>
1
:
raise
KaffeError
(
'Multiple top nodes are not supported.'
)
for
output_name
in
layer
.
top
:
if
output_name
==
layer
.
name
:
# Output is named the same as the node. No further action required.
continue
# There are two possibilities here:
#
# Case 1: output_name refers to another node in the graph.
# This is an "in-place operation" that overwrites an existing node.
# This would create a cycle in the graph. We'll undo the in-placing
# by substituting this node wherever the overwritten node is referenced.
#
# Case 2: output_name violates the convention layer.name == output_name.
# Since we are working in the single-output regime, we will can rename it to
# match the layer name.
#
# For both cases, future references to this top re-routes to this node.
node_outputs
[
output_name
]
=
node
graph
.
compute_output_shapes
()
return
graph
class
NodeMapper
(
NodeDispatch
):
def
__init__
(
self
,
graph
):
self
.
graph
=
graph
def
map
(
self
):
nodes
=
self
.
graph
.
topologically_sorted
()
# Remove input nodes - we'll handle them separately.
input_nodes
=
self
.
graph
.
get_input_nodes
()
nodes
=
[
t
for
t
in
nodes
if
t
not
in
input_nodes
]
# Decompose DAG into chains.
chains
=
[]
for
node
in
nodes
:
attach_to_chain
=
None
if
len
(
node
.
parents
)
==
1
:
parent
=
node
.
get_only_parent
()
for
chain
in
chains
:
if
chain
[
-
1
]
==
parent
:
# Node is part of an existing chain.
attach_to_chain
=
chain
break
if
attach_to_chain
is
None
:
# Start a new chain for this node.
attach_to_chain
=
[]
chains
.
append
(
attach_to_chain
)
attach_to_chain
.
append
(
node
)
# Map each chain.
mapped_chains
=
[]
for
chain
in
chains
:
mapped_chains
.
append
(
self
.
map_chain
(
chain
))
return
self
.
commit
(
mapped_chains
)
def
map_chain
(
self
,
chain
):
return
[
self
.
map_node
(
node
)
for
node
in
chain
]
def
map_node
(
self
,
node
):
map_func
=
self
.
get_handler
(
node
.
kind
,
'map'
)
mapped_node
=
map_func
(
node
)
assert
mapped_node
is
not
None
mapped_node
.
node
=
node
return
mapped_node
def
commit
(
self
,
mapped_chains
):
raise
NotImplementedError
(
'Must be implemented by subclass.'
)
fluid/image_classification/caffe2fluid/kaffe/layers.py
0 → 100644
浏览文件 @
50afdd83
import
re
import
numbers
from
collections
import
namedtuple
from
.shapes
import
*
LAYER_DESCRIPTORS
=
{
# Caffe Types
'AbsVal'
:
shape_identity
,
'Accuracy'
:
shape_scalar
,
'ArgMax'
:
shape_not_implemented
,
'BatchNorm'
:
shape_identity
,
'BNLL'
:
shape_not_implemented
,
'Concat'
:
shape_concat
,
'ContrastiveLoss'
:
shape_scalar
,
'Convolution'
:
shape_convolution
,
'Deconvolution'
:
shape_not_implemented
,
'Data'
:
shape_data
,
'Dropout'
:
shape_identity
,
'DummyData'
:
shape_data
,
'EuclideanLoss'
:
shape_scalar
,
'Eltwise'
:
shape_identity
,
'Exp'
:
shape_identity
,
'Flatten'
:
shape_not_implemented
,
'HDF5Data'
:
shape_data
,
'HDF5Output'
:
shape_identity
,
'HingeLoss'
:
shape_scalar
,
'Im2col'
:
shape_not_implemented
,
'ImageData'
:
shape_data
,
'InfogainLoss'
:
shape_scalar
,
'InnerProduct'
:
shape_inner_product
,
'Input'
:
shape_data
,
'LRN'
:
shape_identity
,
'MemoryData'
:
shape_mem_data
,
'MultinomialLogisticLoss'
:
shape_scalar
,
'MVN'
:
shape_not_implemented
,
'Pooling'
:
shape_pool
,
'Power'
:
shape_identity
,
'ReLU'
:
shape_identity
,
'Scale'
:
shape_identity
,
'Sigmoid'
:
shape_identity
,
'SigmoidCrossEntropyLoss'
:
shape_scalar
,
'Silence'
:
shape_not_implemented
,
'Softmax'
:
shape_identity
,
'SoftmaxWithLoss'
:
shape_scalar
,
'Split'
:
shape_not_implemented
,
'Slice'
:
shape_not_implemented
,
'TanH'
:
shape_identity
,
'WindowData'
:
shape_not_implemented
,
'Threshold'
:
shape_identity
,
}
LAYER_TYPES
=
LAYER_DESCRIPTORS
.
keys
()
LayerType
=
type
(
'LayerType'
,
(),
{
t
:
t
for
t
in
LAYER_TYPES
})
class
NodeKind
(
LayerType
):
@
staticmethod
def
map_raw_kind
(
kind
):
if
kind
in
LAYER_TYPES
:
return
kind
return
None
@
staticmethod
def
compute_output_shape
(
node
):
try
:
val
=
LAYER_DESCRIPTORS
[
node
.
kind
](
node
)
return
val
except
NotImplementedError
:
raise
KaffeError
(
'Output shape computation not implemented for type: %s'
%
node
.
kind
)
class
NodeDispatchError
(
KaffeError
):
pass
class
NodeDispatch
(
object
):
@
staticmethod
def
get_handler_name
(
node_kind
):
if
len
(
node_kind
)
<=
4
:
# A catch-all for things like ReLU and tanh
return
node_kind
.
lower
()
# Convert from CamelCase to under_scored
name
=
re
.
sub
(
'(.)([A-Z][a-z]+)'
,
r
'\1_\2'
,
node_kind
)
return
re
.
sub
(
'([a-z0-9])([A-Z])'
,
r
'\1_\2'
,
name
).
lower
()
def
get_handler
(
self
,
node_kind
,
prefix
):
name
=
self
.
get_handler_name
(
node_kind
)
name
=
'_'
.
join
((
prefix
,
name
))
try
:
return
getattr
(
self
,
name
)
except
AttributeError
:
raise
NodeDispatchError
(
'No handler found for node kind: %s (expected: %s)'
%
(
node_kind
,
name
))
class
LayerAdapter
(
object
):
def
__init__
(
self
,
layer
,
kind
):
self
.
layer
=
layer
self
.
kind
=
kind
@
property
def
parameters
(
self
):
name
=
NodeDispatch
.
get_handler_name
(
self
.
kind
)
name
=
'_'
.
join
((
name
,
'param'
))
try
:
return
getattr
(
self
.
layer
,
name
)
except
AttributeError
:
raise
NodeDispatchError
(
'Caffe parameters not found for layer kind: %s'
%
(
self
.
kind
))
@
staticmethod
def
get_kernel_value
(
scalar
,
repeated
,
idx
,
default
=
None
):
if
scalar
:
return
scalar
if
repeated
:
if
isinstance
(
repeated
,
numbers
.
Number
):
return
repeated
if
len
(
repeated
)
==
1
:
# Same value applies to all spatial dimensions
return
int
(
repeated
[
0
])
assert
idx
<
len
(
repeated
)
# Extract the value for the given spatial dimension
return
repeated
[
idx
]
if
default
is
None
:
raise
ValueError
(
'Unable to determine kernel parameter!'
)
return
default
@
property
def
kernel_parameters
(
self
):
assert
self
.
kind
in
(
NodeKind
.
Convolution
,
NodeKind
.
Pooling
)
params
=
self
.
parameters
k_h
=
self
.
get_kernel_value
(
params
.
kernel_h
,
params
.
kernel_size
,
0
)
k_w
=
self
.
get_kernel_value
(
params
.
kernel_w
,
params
.
kernel_size
,
1
)
s_h
=
self
.
get_kernel_value
(
params
.
stride_h
,
params
.
stride
,
0
,
default
=
1
)
s_w
=
self
.
get_kernel_value
(
params
.
stride_w
,
params
.
stride
,
1
,
default
=
1
)
p_h
=
self
.
get_kernel_value
(
params
.
pad_h
,
params
.
pad
,
0
,
default
=
0
)
p_w
=
self
.
get_kernel_value
(
params
.
pad_h
,
params
.
pad
,
1
,
default
=
0
)
return
KernelParameters
(
k_h
,
k_w
,
s_h
,
s_w
,
p_h
,
p_w
)
KernelParameters
=
namedtuple
(
'KernelParameters'
,
[
'kernel_h'
,
'kernel_w'
,
'stride_h'
,
'stride_w'
,
'pad_h'
,
'pad_w'
])
fluid/image_classification/caffe2fluid/kaffe/paddle/__init__.py
0 → 100644
浏览文件 @
50afdd83
from
.transformer
import
Transformer
from
.network
import
Network
fluid/image_classification/caffe2fluid/kaffe/paddle/network.py
0 → 100644
浏览文件 @
50afdd83
import
math
import
os
import
numpy
as
np
def
import_fluid
():
import
paddle.v2.fluid
as
fluid
return
fluid
def
layer
(
op
):
'''Decorator for composable network layers.'''
def
layer_decorated
(
self
,
*
args
,
**
kwargs
):
# Automatically set a name if not provided.
name
=
kwargs
.
setdefault
(
'name'
,
self
.
get_unique_name
(
op
.
__name__
))
# Figure out the layer inputs.
if
len
(
self
.
terminals
)
==
0
:
raise
RuntimeError
(
'No input variables found for layer %s.'
%
name
)
elif
len
(
self
.
terminals
)
==
1
:
layer_input
=
self
.
terminals
[
0
]
else
:
layer_input
=
list
(
self
.
terminals
)
# Perform the operation and get the output.
layer_output
=
op
(
self
,
layer_input
,
*
args
,
**
kwargs
)
# Add to layer LUT.
self
.
layers
[
name
]
=
layer_output
# This output is now the input for the next layer.
self
.
feed
(
layer_output
)
# Return self for chained calls.
return
self
return
layer_decorated
class
Network
(
object
):
def
__init__
(
self
,
inputs
,
trainable
=
True
):
# The input nodes for this network
self
.
inputs
=
inputs
# The current list of terminal nodes
self
.
terminals
=
[]
# Mapping from layer names to layers
self
.
layers
=
dict
(
inputs
)
# If true, the resulting variables are set as trainable
self
.
trainable
=
trainable
# Switch variable for dropout
self
.
paddle_env
=
None
self
.
setup
()
def
setup
(
self
):
'''Construct the network. '''
raise
NotImplementedError
(
'Must be implemented by the subclass.'
)
def
load
(
self
,
data_path
,
exe
=
None
,
place
=
None
,
ignore_missing
=
False
):
'''Load network weights.
data_path: The path to the numpy-serialized network weights
ignore_missing: If true, serialized weights for missing layers are ignored.
'''
fluid
=
import_fluid
()
#load fluid mode directly
if
os
.
path
.
isdir
(
data_path
):
assert
(
exe
is
not
None
),
\
'must provide a executor to load fluid model'
fluid
.
io
.
load_persistables_if_exist
(
executor
=
exe
,
dirname
=
data_path
)
return
True
#load model from a npy file
if
exe
is
None
or
place
is
None
:
if
self
.
paddle_env
is
None
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
self
.
paddle_env
=
{
'place'
:
place
,
'exe'
:
exe
}
exe
=
exe
.
run
(
fluid
.
default_startup_program
())
else
:
place
=
self
.
paddle_env
[
'place'
]
exe
=
self
.
paddle_env
[
'exe'
]
data_dict
=
np
.
load
(
data_path
).
item
()
for
op_name
in
data_dict
:
layer
=
self
.
layers
[
op_name
]
for
param_name
,
data
in
data_dict
[
op_name
].
iteritems
():
try
:
name
=
'%s_%s'
%
(
op_name
,
param_name
)
v
=
fluid
.
global_scope
().
find_var
(
name
)
w
=
v
.
get_tensor
()
w
.
set
(
data
,
place
)
except
ValueError
:
if
not
ignore_missing
:
raise
return
True
def
feed
(
self
,
*
args
):
'''Set the input(s) for the next operation by replacing the terminal nodes.
The arguments can be either layer names or the actual layers.
'''
assert
len
(
args
)
!=
0
self
.
terminals
=
[]
for
fed_layer
in
args
:
if
isinstance
(
fed_layer
,
basestring
):
try
:
fed_layer
=
self
.
layers
[
fed_layer
]
except
KeyError
:
raise
KeyError
(
'Unknown layer name fed: %s'
%
fed_layer
)
self
.
terminals
.
append
(
fed_layer
)
return
self
def
get_output
(
self
):
'''Returns the current network output.'''
return
self
.
terminals
[
-
1
]
def
get_unique_name
(
self
,
prefix
):
'''Returns an index-suffixed unique name for the given prefix.
This is used for auto-generating layer names based on the type-prefix.
'''
ident
=
sum
(
t
.
startswith
(
prefix
)
for
t
,
_
in
self
.
layers
.
items
())
+
1
return
'%s_%d'
%
(
prefix
,
ident
)
@
layer
def
conv
(
self
,
input
,
k_h
,
k_w
,
c_o
,
s_h
,
s_w
,
name
,
relu
=
True
,
padding
=
None
,
group
=
1
,
biased
=
True
):
if
padding
is
None
:
padding
=
[
0
,
0
]
# Get the number of channels in the input
c_i
,
h_i
,
w_i
=
input
.
shape
[
1
:]
# Verify that the grouping parameter is valid
assert
c_i
%
group
==
0
assert
c_o
%
group
==
0
fluid
=
import_fluid
()
prefix
=
name
+
'_'
output
=
fluid
.
layers
.
conv2d
(
input
=
input
,
filter_size
=
[
k_h
,
k_w
],
num_filters
=
c_o
,
stride
=
[
s_h
,
s_w
],
padding
=
padding
,
groups
=
group
,
param_attr
=
fluid
.
ParamAttr
(
name
=
prefix
+
"weights"
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
prefix
+
"biases"
),
act
=
"relu"
if
relu
is
True
else
None
)
return
output
@
layer
def
relu
(
self
,
input
,
name
):
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
relu
(
x
=
input
)
return
output
@
layer
def
max_pool
(
self
,
input
,
k_h
,
k_w
,
s_h
,
s_w
,
name
,
padding
=
None
):
if
padding
is
None
:
padding
=
[
0
,
0
]
# Get the number of channels in the input
h_i
,
w_i
=
input
.
shape
[
2
:]
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
[
k_h
,
k_w
],
pool_stride
=
[
s_h
,
s_w
],
pool_padding
=
padding
,
pool_type
=
'max'
)
return
output
@
layer
def
avg_pool
(
self
,
input
,
k_h
,
k_w
,
s_h
,
s_w
,
name
,
padding
=
None
):
if
padding
is
None
:
padding
=
[
0
,
0
]
# Get the number of channels in the input
h_i
,
w_i
=
input
.
shape
[
2
:]
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
[
k_h
,
k_w
],
pool_stride
=
[
s_h
,
s_w
],
pool_padding
=
padding
,
pool_type
=
'avg'
)
return
output
@
layer
def
lrn
(
self
,
input
,
radius
,
alpha
,
beta
,
name
,
bias
=
1.0
):
raise
Exception
(
'lrn() not implemented yet'
)
@
layer
def
concat
(
self
,
inputs
,
axis
,
name
):
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
concat
(
input
=
inputs
,
axis
=
axis
)
return
output
@
layer
def
add
(
self
,
inputs
,
name
):
fluid
=
import_fluid
()
output
=
inputs
[
0
]
for
i
in
inputs
[
1
:]:
output
=
fluid
.
layers
.
elementwise_add
(
x
=
output
,
y
=
i
)
return
output
@
layer
def
fc
(
self
,
input
,
num_out
,
name
,
relu
=
True
,
act
=
None
):
fluid
=
import_fluid
()
if
act
is
None
:
act
=
'relu'
if
relu
is
True
else
None
prefix
=
name
+
'_'
output
=
fluid
.
layers
.
fc
(
name
=
name
,
input
=
input
,
size
=
num_out
,
act
=
act
,
param_attr
=
fluid
.
ParamAttr
(
name
=
prefix
+
'weights'
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
prefix
+
'biases'
))
return
output
@
layer
def
softmax
(
self
,
input
,
name
):
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
softmax
(
x
=
input
,
name
=
name
)
return
output
@
layer
def
batch_normalization
(
self
,
input
,
name
,
scale_offset
=
True
,
relu
=
False
):
# NOTE: Currently, only inference is supported
fluid
=
import_fluid
()
prefix
=
name
+
'_'
param_attr
=
None
if
scale_offset
is
False
else
fluid
.
ParamAttr
(
name
=
prefix
+
'scale'
)
bias_attr
=
None
if
scale_offset
is
False
else
fluid
.
ParamAttr
(
name
=
prefix
+
'offset'
)
mean_name
=
prefix
+
'mean'
variance_name
=
prefix
+
'variance'
output
=
fluid
.
layers
.
batch_norm
(
name
=
name
,
input
=
input
,
is_test
=
True
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
moving_mean_name
=
mean_name
,
moving_variance_name
=
variance_name
,
epsilon
=
1e-5
,
act
=
'relu'
if
relu
is
True
else
None
)
return
output
@
layer
def
dropout
(
self
,
input
,
keep_prob
,
name
):
raise
Exception
(
'dropout() not implemented yet'
)
fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py
0 → 100644
浏览文件 @
50afdd83
import
numpy
as
np
from
..errors
import
KaffeError
,
print_stderr
from
..graph
import
GraphBuilder
,
NodeMapper
from
..layers
import
NodeKind
from
..transformers
import
(
DataInjector
,
DataReshaper
,
NodeRenamer
,
ReLUFuser
,
BatchNormScaleBiasFuser
,
BatchNormPreprocessor
,
ParameterNamer
)
from
.
import
network
def
get_padding_type
(
kernel_params
,
input_shape
,
output_shape
):
'''Translates Caffe's numeric padding to one of ('SAME', 'VALID').
Caffe supports arbitrary padding values, while TensorFlow only
supports 'SAME' and 'VALID' modes. So, not all Caffe paddings
can be translated to TensorFlow. There are some subtleties to
how the padding edge-cases are handled. These are described here:
https://github.com/Yangqing/caffe2/blob/master/caffe2/proto/caffe2_legacy.proto
'''
k_h
,
k_w
,
s_h
,
s_w
,
p_h
,
p_w
=
kernel_params
if
p_h
*
p_w
>
0
:
return
[
p_h
,
p_w
]
else
:
return
None
class
TensorFlowNode
(
object
):
'''An intermediate representation for TensorFlow operations.'''
def
__init__
(
self
,
op
,
*
args
,
**
kwargs
):
# A string corresponding to the TensorFlow operation
self
.
op
=
op
# Positional arguments for the operation
self
.
args
=
args
# Keyword arguments for the operation
self
.
kwargs
=
list
(
kwargs
.
items
())
# The source Caffe node
self
.
node
=
None
def
format
(
self
,
arg
):
'''Returns a string representation for the given value.'''
return
"'%s'"
%
arg
if
isinstance
(
arg
,
basestring
)
else
str
(
arg
)
def
pair
(
self
,
key
,
value
):
'''Returns key=formatted(value).'''
return
'%s=%s'
%
(
key
,
self
.
format
(
value
))
def
emit
(
self
):
'''Emits the Python source for this node.'''
# Format positional arguments
args
=
map
(
self
.
format
,
self
.
args
)
# Format any keyword arguments
if
self
.
kwargs
:
args
+=
[
self
.
pair
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
]
# Set the node name
args
.
append
(
self
.
pair
(
'name'
,
self
.
node
.
name
))
args
=
', '
.
join
(
args
)
return
'%s(%s)'
%
(
self
.
op
,
args
)
class
MaybeActivated
(
object
):
def
__init__
(
self
,
node
,
default
=
True
):
self
.
inject_kwargs
=
{}
if
node
.
metadata
.
get
(
'relu'
,
False
)
!=
default
:
self
.
inject_kwargs
[
'relu'
]
=
not
default
def
__call__
(
self
,
*
args
,
**
kwargs
):
kwargs
.
update
(
self
.
inject_kwargs
)
return
TensorFlowNode
(
*
args
,
**
kwargs
)
class
TensorFlowMapper
(
NodeMapper
):
def
get_kernel_params
(
self
,
node
):
kernel_params
=
node
.
layer
.
kernel_parameters
input_shape
=
node
.
get_only_parent
().
output_shape
padding
=
get_padding_type
(
kernel_params
,
input_shape
,
node
.
output_shape
)
# Only emit the padding if it's not the default value.
padding
=
{
'padding'
:
padding
}
if
padding
is
not
None
else
{}
return
(
kernel_params
,
padding
)
def
map_convolution
(
self
,
node
):
(
kernel_params
,
kwargs
)
=
self
.
get_kernel_params
(
node
)
h
=
kernel_params
.
kernel_h
w
=
kernel_params
.
kernel_w
c_o
=
node
.
output_shape
[
1
]
c_i
=
node
.
parents
[
0
].
output_shape
[
1
]
group
=
node
.
parameters
.
group
if
group
!=
1
:
kwargs
[
'group'
]
=
group
if
not
node
.
parameters
.
bias_term
:
kwargs
[
'biased'
]
=
False
assert
kernel_params
.
kernel_h
==
h
assert
kernel_params
.
kernel_w
==
w
return
MaybeActivated
(
node
)(
'conv'
,
kernel_params
.
kernel_h
,
kernel_params
.
kernel_w
,
c_o
,
kernel_params
.
stride_h
,
kernel_params
.
stride_w
,
**
kwargs
)
def
map_relu
(
self
,
node
):
return
TensorFlowNode
(
'relu'
)
def
map_pooling
(
self
,
node
):
pool_type
=
node
.
parameters
.
pool
if
pool_type
==
0
:
pool_op
=
'max_pool'
elif
pool_type
==
1
:
pool_op
=
'avg_pool'
else
:
# Stochastic pooling, for instance.
raise
KaffeError
(
'Unsupported pooling type.'
)
(
kernel_params
,
padding
)
=
self
.
get_kernel_params
(
node
)
return
TensorFlowNode
(
pool_op
,
kernel_params
.
kernel_h
,
kernel_params
.
kernel_w
,
kernel_params
.
stride_h
,
kernel_params
.
stride_w
,
**
padding
)
def
map_inner_product
(
self
,
node
):
#TODO: Axis
assert
node
.
parameters
.
axis
==
1
#TODO: Unbiased
assert
node
.
parameters
.
bias_term
==
True
return
MaybeActivated
(
node
)(
'fc'
,
node
.
parameters
.
num_output
)
def
map_softmax
(
self
,
node
):
return
TensorFlowNode
(
'softmax'
)
def
map_lrn
(
self
,
node
):
params
=
node
.
parameters
# The window size must be an odd value. For a window
# size of (2*n+1), TensorFlow defines depth_radius = n.
assert
params
.
local_size
%
2
==
1
# Caffe scales by (alpha/(2*n+1)), whereas TensorFlow
# just scales by alpha (as does Krizhevsky's paper).
# We'll account for that here.
alpha
=
params
.
alpha
/
float
(
params
.
local_size
)
return
TensorFlowNode
(
'lrn'
,
int
(
params
.
local_size
/
2
),
alpha
,
params
.
beta
)
def
map_concat
(
self
,
node
):
return
TensorFlowNode
(
'concat'
,
node
.
parameters
.
axis
)
def
map_dropout
(
self
,
node
):
return
TensorFlowNode
(
'dropout'
,
node
.
parameters
.
dropout_ratio
)
def
map_batch_norm
(
self
,
node
):
scale_offset
=
len
(
node
.
data
)
==
4
kwargs
=
{}
if
scale_offset
else
{
'scale_offset'
:
False
}
return
MaybeActivated
(
node
,
default
=
False
)(
'batch_normalization'
,
**
kwargs
)
def
map_eltwise
(
self
,
node
):
operations
=
{
0
:
'multiply'
,
1
:
'add'
,
2
:
'max'
}
op_code
=
node
.
parameters
.
operation
try
:
return
TensorFlowNode
(
operations
[
op_code
])
except
KeyError
:
raise
KaffeError
(
'Unknown elementwise operation: {}'
.
format
(
op_code
))
def
commit
(
self
,
chains
):
return
chains
class
TensorFlowEmitter
(
object
):
def
__init__
(
self
,
tab
=
None
):
self
.
tab
=
tab
or
' '
*
4
self
.
prefix
=
''
self
.
net_name
=
''
def
indent
(
self
):
self
.
prefix
+=
self
.
tab
def
outdent
(
self
):
self
.
prefix
=
self
.
prefix
[:
-
len
(
self
.
tab
)]
def
statement
(
self
,
s
):
return
self
.
prefix
+
s
+
'
\n
'
def
emit_imports
(
self
):
import
inspect
codes
=
[]
codes
.
append
(
'### generated by caffe2fluid, your net is in class "%s" ###
\n
'
%
(
self
.
net_name
))
network_source
=
inspect
.
getsource
(
network
)
codes
.
append
(
network_source
+
'
\n
'
)
return
self
.
statement
(
'
\n
'
.
join
(
codes
))
def
emit_class_def
(
self
,
name
):
return
self
.
statement
(
'class %s(Network):'
%
(
name
))
def
emit_setup_def
(
self
):
return
self
.
statement
(
'def setup(self):'
)
def
emit_convert_def
(
self
,
input_nodes
):
def
data_layer_def
(
name
,
shape
,
dtype
=
None
):
if
dtype
is
None
:
dtype
=
'float32'
layer_var
=
name
+
'_layer'
shape
=
[
str
(
s
)
for
s
in
shape
[
1
:]]
layer_def
=
'%s = fluid.layers.data(name="%s", shape=[%s], dtype="%s")'
\
%
(
layer_var
,
name
,
','
.
join
(
shape
),
dtype
)
return
layer_var
,
layer_def
codes
=
[]
inputs
=
{}
for
n
in
input_nodes
:
name
=
n
.
name
layer_var
,
layer_def
=
data_layer_def
(
n
.
name
,
n
.
output_shape
)
codes
.
append
(
layer_def
)
inputs
[
name
]
=
layer_var
input_dict
=
','
.
join
([
'"%s": %s'
%
(
n
,
l
)
for
n
,
l
in
inputs
.
items
()])
codes
.
append
(
'feed_data = {'
+
input_dict
+
'}'
)
codes
.
append
(
'net = cls(feed_data)'
)
codes
.
append
(
"place = fluid.CPUPlace()"
)
codes
.
append
(
"exe = fluid.Executor(place)"
)
codes
.
append
(
"exe.run(fluid.default_startup_program())"
)
codes
.
append
(
"net.load(data_path=npy_model, exe=exe, place=place)"
)
codes
.
append
(
"fluid.io.save_persistables(executor=exe, dirname=fluid_path)"
)
self
.
outdent
()
func_def
=
self
.
statement
(
'@classmethod'
)
func_def
+=
self
.
statement
(
'def convert(cls, npy_model, fluid_path):'
)
self
.
indent
()
func_def
+=
self
.
statement
(
'import paddle.v2.fluid as fluid'
)
for
l
in
codes
:
func_def
+=
self
.
statement
(
l
)
return
'
\n\n
'
+
func_def
def
emit_main_def
(
self
,
name
):
if
name
is
None
:
return
''
self
.
prefix
=
''
main_def
=
self
.
statement
(
'if __name__ == "__main__":'
)
self
.
indent
()
main_def
+=
self
.
statement
(
"#usage: python xxxnet.py xxx.npy ./model
\n
"
)
main_def
+=
self
.
statement
(
"import sys"
)
main_def
+=
self
.
statement
(
"npy_weight = sys.argv[1]"
)
main_def
+=
self
.
statement
(
"fluid_model = sys.argv[2]"
)
main_def
+=
self
.
statement
(
"%s.convert(npy_weight, fluid_model)"
%
(
name
))
main_def
+=
self
.
statement
(
"exit(0)"
)
return
'
\n\n
'
+
main_def
def
emit_parents
(
self
,
chain
):
assert
len
(
chain
)
s
=
'self.feed('
sep
=
',
\n
'
+
self
.
prefix
+
(
' '
*
len
(
s
))
s
+=
sep
.
join
(
[
"'%s'"
%
parent
.
name
for
parent
in
chain
[
0
].
node
.
parents
])
return
self
.
statement
(
s
+
')'
)
def
emit_node
(
self
,
node
):
return
self
.
statement
(
'self.'
+
node
.
emit
())
def
emit
(
self
,
name
,
chains
,
input_nodes
=
None
):
self
.
net_name
=
name
s
=
self
.
emit_imports
()
s
+=
self
.
emit_class_def
(
name
)
self
.
indent
()
s
+=
self
.
emit_setup_def
()
self
.
indent
()
blocks
=
[]
for
chain
in
chains
:
b
=
''
b
+=
self
.
emit_parents
(
chain
)
for
node
in
chain
:
b
+=
self
.
emit_node
(
node
)
blocks
.
append
(
b
[:
-
1
])
s
=
s
+
'
\n\n
'
.
join
(
blocks
)
s
+=
self
.
emit_convert_def
(
input_nodes
)
s
+=
self
.
emit_main_def
(
name
)
return
s
class
Transformer
(
object
):
def
__init__
(
self
,
def_path
,
data_path
,
verbose
=
True
,
phase
=
'test'
):
self
.
verbose
=
verbose
self
.
phase
=
phase
self
.
load
(
def_path
,
data_path
,
phase
)
self
.
params
=
None
self
.
source
=
None
def
load
(
self
,
def_path
,
data_path
,
phase
):
# Build the graph
graph
=
GraphBuilder
(
def_path
,
phase
).
build
()
if
data_path
is
not
None
:
# Load and associate learned parameters
graph
=
DataInjector
(
def_path
,
data_path
)(
graph
)
# Transform the graph
transformers
=
[
# Fuse split batch normalization layers
BatchNormScaleBiasFuser
(),
# Fuse ReLUs
# TODO: Move non-linearity application to layer wrapper, allowing
# any arbitrary operation to be optionally activated.
ReLUFuser
(
allowed_parent_types
=
[
NodeKind
.
Convolution
,
NodeKind
.
InnerProduct
,
NodeKind
.
BatchNorm
]),
# Rename nodes
# Slashes are used for scoping in TensorFlow. Replace slashes
# in node names with underscores.
# (Caffe's GoogLeNet implementation uses slashes)
NodeRenamer
(
lambda
node
:
node
.
name
.
replace
(
'/'
,
'_'
))
]
self
.
graph
=
graph
.
transformed
(
transformers
)
# Display the graph
if
self
.
verbose
:
print_stderr
(
self
.
graph
)
def
transform_data
(
self
):
if
self
.
params
is
None
:
transformers
=
[
# Reshape the parameters to TensorFlow's ordering
DataReshaper
({
# (c_o, c_i, h, w) -> (h, w, c_i, c_o) for TF
NodeKind
.
Convolution
:
(
0
,
1
,
2
,
3
),
# (c_o, c_i) -> (c_i, c_o)
NodeKind
.
InnerProduct
:
(
1
,
0
)
}),
# Pre-process batch normalization data
BatchNormPreprocessor
(),
# Convert parameters to dictionaries
ParameterNamer
(),
]
self
.
graph
=
self
.
graph
.
transformed
(
transformers
)
self
.
params
=
{
node
.
name
:
node
.
data
for
node
in
self
.
graph
.
nodes
if
node
.
data
}
return
self
.
params
def
transform_source
(
self
):
if
self
.
source
is
None
:
mapper
=
TensorFlowMapper
(
self
.
graph
)
chains
=
mapper
.
map
()
emitter
=
TensorFlowEmitter
()
input_nodes
=
self
.
graph
.
get_input_nodes
()
self
.
source
=
emitter
.
emit
(
self
.
graph
.
name
,
chains
,
input_nodes
)
return
self
.
source
fluid/image_classification/caffe2fluid/kaffe/shapes.py
0 → 100644
浏览文件 @
50afdd83
import
math
from
collections
import
namedtuple
from
.errors
import
KaffeError
TensorShape
=
namedtuple
(
'TensorShape'
,
[
'batch_size'
,
'channels'
,
'height'
,
'width'
])
def
get_filter_output_shape
(
i_h
,
i_w
,
params
,
round_func
):
o_h
=
(
i_h
+
2
*
params
.
pad_h
-
params
.
kernel_h
)
/
float
(
params
.
stride_h
)
+
1
o_w
=
(
i_w
+
2
*
params
.
pad_w
-
params
.
kernel_w
)
/
float
(
params
.
stride_w
)
+
1
return
(
int
(
round_func
(
o_h
)),
int
(
round_func
(
o_w
)))
def
get_strided_kernel_output_shape
(
node
,
round_func
):
assert
node
.
layer
is
not
None
input_shape
=
node
.
get_only_parent
().
output_shape
o_h
,
o_w
=
get_filter_output_shape
(
input_shape
.
height
,
input_shape
.
width
,
node
.
layer
.
kernel_parameters
,
round_func
)
params
=
node
.
layer
.
parameters
has_c_o
=
hasattr
(
params
,
'num_output'
)
c
=
params
.
num_output
if
has_c_o
else
input_shape
.
channels
return
TensorShape
(
input_shape
.
batch_size
,
c
,
o_h
,
o_w
)
def
shape_not_implemented
(
node
):
raise
NotImplementedError
def
shape_identity
(
node
):
assert
len
(
node
.
parents
)
>
0
return
node
.
parents
[
0
].
output_shape
def
shape_scalar
(
node
):
return
TensorShape
(
1
,
1
,
1
,
1
)
def
shape_data
(
node
):
if
node
.
output_shape
:
# Old-style input specification
return
node
.
output_shape
try
:
# New-style input specification
return
map
(
int
,
node
.
parameters
.
shape
[
0
].
dim
)
except
:
# We most likely have a data layer on our hands. The problem is,
# Caffe infers the dimensions of the data from the source (eg: LMDB).
# We want to avoid reading datasets here. Fail for now.
# This can be temporarily fixed by transforming the data layer to
# Caffe's "input" layer (as is usually used in the "deploy" version).
# TODO: Find a better solution for this.
raise
KaffeError
(
'Cannot determine dimensions of data layer.
\n
'
'See comments in function shape_data for more info.'
)
def
shape_mem_data
(
node
):
params
=
node
.
parameters
return
TensorShape
(
params
.
batch_size
,
params
.
channels
,
params
.
height
,
params
.
width
)
def
shape_concat
(
node
):
axis
=
node
.
layer
.
parameters
.
axis
output_shape
=
None
for
parent
in
node
.
parents
:
if
output_shape
is
None
:
output_shape
=
list
(
parent
.
output_shape
)
else
:
output_shape
[
axis
]
+=
parent
.
output_shape
[
axis
]
return
tuple
(
output_shape
)
def
shape_convolution
(
node
):
return
get_strided_kernel_output_shape
(
node
,
math
.
floor
)
def
shape_pool
(
node
):
return
get_strided_kernel_output_shape
(
node
,
math
.
ceil
)
def
shape_inner_product
(
node
):
input_shape
=
node
.
get_only_parent
().
output_shape
return
TensorShape
(
input_shape
.
batch_size
,
node
.
layer
.
parameters
.
num_output
,
1
,
1
)
fluid/image_classification/caffe2fluid/kaffe/transformers.py
0 → 100644
浏览文件 @
50afdd83
'''
A collection of graph transforms.
A transformer is a callable that accepts a graph and returns a transformed version.
'''
import
os
import
numpy
as
np
from
.caffe
import
get_caffe_resolver
,
has_pycaffe
from
.errors
import
KaffeError
,
debug
,
notice
,
warn
from
.layers
import
NodeKind
class
DataInjector
(
object
):
'''
Associates parameters loaded from a .caffemodel file with their corresponding nodes.
'''
def
__init__
(
self
,
def_path
,
data_path
):
# The .prototxt file defining the graph
self
.
def_path
=
def_path
# The .caffemodel file containing the learned parameters
self
.
data_path
=
data_path
# Set to true if the fallback protocol-buffer based backend was used
self
.
did_use_pb
=
False
# A list containing (layer name, parameters) tuples
self
.
params
=
None
# Load the parameters
self
.
load
()
def
load
(
self
):
if
has_pycaffe
():
self
.
load_using_caffe
()
else
:
self
.
load_using_pb
()
def
load_using_caffe
(
self
):
caffe
=
get_caffe_resolver
().
caffe
net
=
caffe
.
Net
(
self
.
def_path
,
self
.
data_path
,
caffe
.
TEST
)
data
=
lambda
blob
:
blob
.
data
self
.
params
=
[(
k
,
map
(
data
,
v
))
for
k
,
v
in
net
.
params
.
items
()]
def
load_using_pb
(
self
):
data
=
get_caffe_resolver
().
NetParameter
()
data
.
MergeFromString
(
open
(
self
.
data_path
,
'rb'
).
read
())
pair
=
lambda
layer
:
(
layer
.
name
,
self
.
normalize_pb_data
(
layer
))
layers
=
data
.
layers
or
data
.
layer
self
.
params
=
[
pair
(
layer
)
for
layer
in
layers
if
layer
.
blobs
]
self
.
did_use_pb
=
True
def
normalize_pb_data
(
self
,
layer
):
transformed
=
[]
for
blob
in
layer
.
blobs
:
if
len
(
blob
.
shape
.
dim
):
dims
=
blob
.
shape
.
dim
c_o
,
c_i
,
h
,
w
=
map
(
int
,
[
1
]
*
(
4
-
len
(
dims
))
+
list
(
dims
))
else
:
c_o
=
blob
.
num
c_i
=
blob
.
channels
h
=
blob
.
height
w
=
blob
.
width
data
=
np
.
array
(
blob
.
data
,
dtype
=
np
.
float32
).
reshape
(
c_o
,
c_i
,
h
,
w
)
transformed
.
append
(
data
)
return
transformed
def
adjust_parameters
(
self
,
node
,
data
):
if
not
self
.
did_use_pb
:
return
data
# When using the protobuf-backend, each parameter initially has four dimensions.
# In certain cases (like FC layers), we want to eliminate the singleton dimensions.
# This implementation takes care of the common cases. However, it does leave the
# potential for future issues.
# The Caffe-backend does not suffer from this problem.
data
=
list
(
data
)
squeeze_indices
=
[
1
]
# Squeeze biases.
if
node
.
kind
==
NodeKind
.
InnerProduct
:
squeeze_indices
.
append
(
0
)
# Squeeze FC.
for
idx
in
squeeze_indices
:
if
idx
>=
len
(
data
):
continue
shape_old
=
data
[
idx
].
shape
data
[
idx
]
=
np
.
squeeze
(
data
[
idx
])
shape_new
=
data
[
idx
].
shape
if
len
(
shape_old
)
!=
shape_new
:
debug
(
'squeeze idx:%d, with kind:%s,name:%s'
%
\
(
idx
,
node
.
kind
,
node
.
name
))
return
data
def
__call__
(
self
,
graph
):
for
layer_name
,
data
in
self
.
params
:
if
layer_name
in
graph
:
node
=
graph
.
get_node
(
layer_name
)
node
.
data
=
self
.
adjust_parameters
(
node
,
data
)
else
:
notice
(
'Ignoring parameters for non-existent layer: %s'
%
\
layer_name
)
return
graph
class
DataReshaper
(
object
):
def
__init__
(
self
,
mapping
,
replace
=
True
):
# A dictionary mapping NodeKind to the transposed order.
self
.
mapping
=
mapping
# The node kinds eligible for reshaping
self
.
reshaped_node_types
=
self
.
mapping
.
keys
()
# If true, the reshaped data will replace the old one.
# Otherwise, it's set to the reshaped_data attribute.
self
.
replace
=
replace
def
has_spatial_parent
(
self
,
node
):
try
:
parent
=
node
.
get_only_parent
()
s
=
parent
.
output_shape
return
s
.
height
>
1
or
s
.
width
>
1
except
KaffeError
:
return
False
def
map
(
self
,
node_kind
):
try
:
return
self
.
mapping
[
node_kind
]
except
KeyError
:
raise
#raise KaffeError('Ordering not found for node kind: {}'.format(node_kind))
def
__call__
(
self
,
graph
):
for
node
in
graph
.
nodes
:
if
node
.
data
is
None
:
continue
if
node
.
kind
not
in
self
.
reshaped_node_types
:
# Check for 2+ dimensional data
if
any
(
len
(
tensor
.
shape
)
>
1
for
tensor
in
node
.
data
):
notice
(
'parmaters not reshaped for node: {}'
.
format
(
node
))
continue
transpose_order
=
self
.
map
(
node
.
kind
)
weights
=
node
.
data
[
0
]
if
(
node
.
kind
==
NodeKind
.
InnerProduct
)
and
self
.
has_spatial_parent
(
node
):
# The FC layer connected to the spatial layer needs to be
# re-wired to match the new spatial ordering.
in_shape
=
node
.
get_only_parent
().
output_shape
fc_shape
=
weights
.
shape
output_channels
=
fc_shape
[
0
]
weights
=
weights
.
reshape
((
output_channels
,
-
1
))
weights
=
weights
.
transpose
(
transpose_order
)
node
.
reshaped_data
=
weights
else
:
node
.
reshaped_data
=
weights
.
transpose
(
transpose_order
)
if
self
.
replace
:
for
node
in
graph
.
nodes
:
if
hasattr
(
node
,
'reshaped_data'
):
# Set the weights
node
.
data
[
0
]
=
node
.
reshaped_data
del
node
.
reshaped_data
return
graph
class
SubNodeFuser
(
object
):
'''
An abstract helper for merging a single-child with its single-parent.
'''
def
__call__
(
self
,
graph
):
nodes
=
graph
.
nodes
fused_nodes
=
[]
for
node
in
nodes
:
if
len
(
node
.
parents
)
!=
1
:
# We're only fusing nodes with single parents
continue
parent
=
node
.
get_only_parent
()
if
len
(
parent
.
children
)
!=
1
:
# We can only fuse a node if its parent's
# value isn't used by any other node.
continue
if
not
self
.
is_eligible_pair
(
parent
,
node
):
continue
# Rewrite the fused node's children to its parent.
for
child
in
node
.
children
:
child
.
parents
.
remove
(
node
)
parent
.
add_child
(
child
)
# Disconnect the fused node from the graph.
parent
.
children
.
remove
(
node
)
fused_nodes
.
append
(
node
)
# Let the sub-class merge the fused node in any arbitrary way.
self
.
merge
(
parent
,
node
)
transformed_nodes
=
[
node
for
node
in
nodes
if
node
not
in
fused_nodes
]
return
graph
.
replaced
(
transformed_nodes
)
def
is_eligible_pair
(
self
,
parent
,
child
):
'''Returns true if this parent/child pair is eligible for fusion.'''
raise
NotImplementedError
(
'Must be implemented by subclass.'
)
def
merge
(
self
,
parent
,
child
):
'''Merge the child node into the parent.'''
raise
NotImplementedError
(
'Must be implemented by subclass'
)
class
ReLUFuser
(
SubNodeFuser
):
'''
Fuses rectified linear units with their parent nodes.
'''
def
__init__
(
self
,
allowed_parent_types
=
None
):
# Fuse ReLUs when the parent node is one of the given types.
# If None, all node types are eligible.
self
.
allowed_parent_types
=
allowed_parent_types
def
is_eligible_pair
(
self
,
parent
,
child
):
return
((
self
.
allowed_parent_types
is
None
or
\
parent
.
kind
in
self
.
allowed_parent_types
)
and
\
child
.
kind
==
NodeKind
.
ReLU
)
def
merge
(
self
,
parent
,
_
):
parent
.
metadata
[
'relu'
]
=
True
class
BatchNormScaleBiasFuser
(
SubNodeFuser
):
'''
The original batch normalization paper includes two learned
parameters: a scaling factor \gamma and a bias
\b
eta.
Caffe's implementation does not include these two. However, it is commonly
replicated by adding a scaling+bias layer immidiately after the batch norm.
This fuser merges the scaling+bias layer with the batch norm.
'''
def
is_eligible_pair
(
self
,
parent
,
child
):
return
(
parent
.
kind
==
NodeKind
.
BatchNorm
and
\
child
.
kind
==
NodeKind
.
Scale
and
\
child
.
parameters
.
axis
==
1
and
\
child
.
parameters
.
bias_term
==
True
)
def
merge
(
self
,
parent
,
child
):
parent
.
scale_bias_node
=
child
class
BatchNormPreprocessor
(
object
):
'''
Prescale batch normalization parameters.
Concatenate gamma (scale) and beta (bias) terms if set.
'''
def
__call__
(
self
,
graph
):
for
node
in
graph
.
nodes
:
if
node
.
kind
!=
NodeKind
.
BatchNorm
:
continue
assert
node
.
data
is
not
None
assert
len
(
node
.
data
)
==
3
node
.
data
=
[
np
.
squeeze
(
i
)
for
i
in
node
.
data
]
mean
,
variance
,
scale
=
node
.
data
# Prescale the stats
scaling_factor
=
1.0
/
scale
if
scale
!=
0
else
0
mean
*=
scaling_factor
variance
*=
scaling_factor
# Replace with the updated values
node
.
data
=
[
mean
,
variance
]
if
hasattr
(
node
,
'scale_bias_node'
):
# Include the scale and bias terms
gamma
,
beta
=
node
.
scale_bias_node
.
data
node
.
data
+=
[
np
.
squeeze
(
i
)
for
i
in
[
gamma
,
beta
]]
return
graph
class
NodeRenamer
(
object
):
'''
Renames nodes in the graph using a given unary function that
accepts a node and returns its new name.
'''
def
__init__
(
self
,
renamer
):
self
.
renamer
=
renamer
def
__call__
(
self
,
graph
):
for
node
in
graph
.
nodes
:
node
.
name
=
self
.
renamer
(
node
)
return
graph
class
ParameterNamer
(
object
):
'''
Convert layer data arrays to a dictionary mapping parameter names to their values.
'''
def
__call__
(
self
,
graph
):
for
node
in
graph
.
nodes
:
if
node
.
data
is
None
:
continue
if
node
.
kind
in
(
NodeKind
.
Convolution
,
NodeKind
.
InnerProduct
):
names
=
(
'weights'
,
)
if
node
.
parameters
.
bias_term
:
names
+=
(
'biases'
,
)
elif
node
.
kind
==
NodeKind
.
BatchNorm
:
names
=
(
'mean'
,
'variance'
)
if
len
(
node
.
data
)
==
4
:
names
+=
(
'scale'
,
'offset'
)
else
:
warn
(
'Unhandled parameters: {}'
.
format
(
node
.
kind
))
continue
assert
len
(
names
)
==
len
(
node
.
data
)
node
.
data
=
dict
(
zip
(
names
,
node
.
data
))
return
graph
fluid/image_classification/caffe2fluid/proto/caffe.proto
0 → 100644
浏览文件 @
50afdd83
此差异已折叠。
点击以展开。
fluid/image_classification/caffe2fluid/proto/compile.sh
0 → 100644
浏览文件 @
50afdd83
#!/bin/bash
#function:
# script used to generate caffepb.py from caffe.proto using protoc
#
PROTOC
=
`
which protoc
`
if
[[
-z
$PROTOC
]]
;
then
echo
"not found protoc, you should first install it following this[https://github.com/google/protobuf/releases]"
exit
1
fi
WORK_ROOT
=
$(
dirname
`
readlink
-f
"
$BASH_SOURCE
[0]"
`
)
PY_NAME
=
"
$WORK_ROOT
/caffepb.py"
$PROTOC
--proto_path
=
$WORK_ROOT
--python_out
=
$WORK_ROOT
$WORK_ROOT
/caffe.proto
ret
=
$?
if
[
$ret
-eq
0
]
;
then
mv
$WORK_ROOT
/caffe_pb2.py
$PY_NAME
fi
if
[
-e
"
$PY_NAME
"
]
;
then
echo
"succeed to generate [
$PY_NAME
]"
exit
0
else
echo
"failed to generate [
$PY_NAME
]"
fi
exit
$ret
fluid/image_classification/caffe2fluid/tests/lenet/README.md
0 → 100644
浏览文件 @
50afdd83
### Convert lenet model from caffe format into paddle format(fluid api)
### Howto
1, Prepare your caffepb.py
2, Download a lenet caffe-model
lenet_iter_10000.caffemodel
download address: https://github.com/ethereon/caffe-tensorflow/raw/master/examples/mnist/lenet_iter_10000.caffemodel
md5: cbec75c1c374b6c1981c4a1eb024ae01
lenet.prototxt
download address: https://raw.githubusercontent.com/BVLC/caffe/master/examples/mnist/lenet.prototxt
md5: 27384af843338ab90b00c8d1c81de7d5
2, Convert this model(make sure caffepb.py is ready in ../../proto)
convert to npy format
bash ./convert.sh lenet.prototxt lenet.caffemodel lenet.py lenet.npy
save to fluid format(optional)
bash ./convert.sh lenet.prototxt lenet.caffemodel lenet.py lenet.npy && python ./lenet.py ./lenet.npy ./fluid.model
4, Use this new model(paddle installed in this python)
use fluid format
python ./predict.py ./fluid.model
use npy format
python ./predict.py ./lenet.npy
fluid/image_classification/caffe2fluid/tests/lenet/convert.sh
0 → 100755
浏览文件 @
50afdd83
#!/bin/bash
#function:
# convert a caffe model
# eg:
# bash ./convert.sh ./model.caffe/lenet.prototxt ./model.caffe/lenet.caffemodel lenet.py lenet.npy
if
[[
$#
-ne
4
]]
;
then
echo
"usage:"
echo
" bash
$0
[PROTOTXT] [CAFFEMODEL] [PY_NAME] [WEIGHT_NAME]"
echo
" eg: bash
$0
lenet.prototxt lenet.caffemodel lenet.py lenet.npy"
exit
1
fi
WORK_ROOT
=
$(
dirname
`
readlink
-f
${
BASH_SOURCE
[0]
}
`
)
if
[[
-z
$PYTHON
]]
;
then
PYTHON
=
`
which python
`
fi
PROTOTXT
=
$1
CAFFEMODEL
=
$2
PY_NAME
=
$3
WEIGHT_NAME
=
$4
CONVERTER_PY
=
"
$WORK_ROOT
/../../convert.py"
$PYTHON
$CONVERTER_PY
$PROTOTXT
--caffemodel
$CAFFEMODEL
--code-output-path
=
$PY_NAME
--data-output-path
=
$WEIGHT_NAME
ret
=
$?
if
[[
$ret
-eq
0
]]
;
then
echo
"succeed to convert caffe model[
$CAFFEMODEL
,
$PROTOTXT
] to paddle model[
$PY_NAME
,
$WEIGHT_NAME
]"
else
echo
"failed to convert caffe model[
$CAFFEMODEL
,
$PROTOTXT
]"
fi
exit
$ret
fluid/image_classification/caffe2fluid/tests/lenet/lenet.npy
0 → 100644
浏览文件 @
50afdd83
文件已添加
fluid/image_classification/caffe2fluid/tests/lenet/lenet.py
0 → 100644
浏览文件 @
50afdd83
### generated by caffe2fluid, your net is in class "LeNet" ###
import
math
import
os
import
numpy
as
np
def
import_fluid
():
import
paddle.v2.fluid
as
fluid
return
fluid
def
layer
(
op
):
'''Decorator for composable network layers.'''
def
layer_decorated
(
self
,
*
args
,
**
kwargs
):
# Automatically set a name if not provided.
name
=
kwargs
.
setdefault
(
'name'
,
self
.
get_unique_name
(
op
.
__name__
))
# Figure out the layer inputs.
if
len
(
self
.
terminals
)
==
0
:
raise
RuntimeError
(
'No input variables found for layer %s.'
%
name
)
elif
len
(
self
.
terminals
)
==
1
:
layer_input
=
self
.
terminals
[
0
]
else
:
layer_input
=
list
(
self
.
terminals
)
# Perform the operation and get the output.
layer_output
=
op
(
self
,
layer_input
,
*
args
,
**
kwargs
)
# Add to layer LUT.
self
.
layers
[
name
]
=
layer_output
# This output is now the input for the next layer.
self
.
feed
(
layer_output
)
# Return self for chained calls.
return
self
return
layer_decorated
class
Network
(
object
):
def
__init__
(
self
,
inputs
,
trainable
=
True
):
# The input nodes for this network
self
.
inputs
=
inputs
# The current list of terminal nodes
self
.
terminals
=
[]
# Mapping from layer names to layers
self
.
layers
=
dict
(
inputs
)
# If true, the resulting variables are set as trainable
self
.
trainable
=
trainable
# Switch variable for dropout
self
.
paddle_env
=
None
self
.
setup
()
def
setup
(
self
):
'''Construct the network. '''
raise
NotImplementedError
(
'Must be implemented by the subclass.'
)
def
load
(
self
,
data_path
,
exe
=
None
,
place
=
None
,
ignore_missing
=
False
):
'''Load network weights.
data_path: The path to the numpy-serialized network weights
ignore_missing: If true, serialized weights for missing layers are ignored.
'''
fluid
=
import_fluid
()
#load fluid mode directly
if
os
.
path
.
isdir
(
data_path
):
assert
(
exe
is
not
None
),
\
'must provide a executor to load fluid model'
fluid
.
io
.
load_persistables_if_exist
(
executor
=
exe
,
dirname
=
data_path
)
return
True
#load model from a npy file
if
exe
is
None
or
place
is
None
:
if
self
.
paddle_env
is
None
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
self
.
paddle_env
=
{
'place'
:
place
,
'exe'
:
exe
}
exe
=
exe
.
run
(
fluid
.
default_startup_program
())
else
:
place
=
self
.
paddle_env
[
'place'
]
exe
=
self
.
paddle_env
[
'exe'
]
data_dict
=
np
.
load
(
data_path
).
item
()
for
op_name
in
data_dict
:
layer
=
self
.
layers
[
op_name
]
for
param_name
,
data
in
data_dict
[
op_name
].
iteritems
():
try
:
name
=
'%s_%s'
%
(
op_name
,
param_name
)
v
=
fluid
.
global_scope
().
find_var
(
name
)
w
=
v
.
get_tensor
()
w
.
set
(
data
,
place
)
except
ValueError
:
if
not
ignore_missing
:
raise
return
True
def
feed
(
self
,
*
args
):
'''Set the input(s) for the next operation by replacing the terminal nodes.
The arguments can be either layer names or the actual layers.
'''
assert
len
(
args
)
!=
0
self
.
terminals
=
[]
for
fed_layer
in
args
:
if
isinstance
(
fed_layer
,
basestring
):
try
:
fed_layer
=
self
.
layers
[
fed_layer
]
except
KeyError
:
raise
KeyError
(
'Unknown layer name fed: %s'
%
fed_layer
)
self
.
terminals
.
append
(
fed_layer
)
return
self
def
get_output
(
self
):
'''Returns the current network output.'''
return
self
.
terminals
[
-
1
]
def
get_unique_name
(
self
,
prefix
):
'''Returns an index-suffixed unique name for the given prefix.
This is used for auto-generating layer names based on the type-prefix.
'''
ident
=
sum
(
t
.
startswith
(
prefix
)
for
t
,
_
in
self
.
layers
.
items
())
+
1
return
'%s_%d'
%
(
prefix
,
ident
)
@
layer
def
conv
(
self
,
input
,
k_h
,
k_w
,
c_o
,
s_h
,
s_w
,
name
,
relu
=
True
,
padding
=
None
,
group
=
1
,
biased
=
True
):
if
padding
is
None
:
padding
=
[
0
,
0
]
# Get the number of channels in the input
c_i
,
h_i
,
w_i
=
input
.
shape
[
1
:]
# Verify that the grouping parameter is valid
assert
c_i
%
group
==
0
assert
c_o
%
group
==
0
fluid
=
import_fluid
()
prefix
=
name
+
'_'
output
=
fluid
.
layers
.
conv2d
(
input
=
input
,
filter_size
=
[
k_h
,
k_w
],
num_filters
=
c_o
,
stride
=
[
s_h
,
s_w
],
padding
=
padding
,
groups
=
group
,
param_attr
=
fluid
.
ParamAttr
(
name
=
prefix
+
"weights"
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
prefix
+
"biases"
),
act
=
"relu"
if
relu
is
True
else
None
)
return
output
@
layer
def
relu
(
self
,
input
,
name
):
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
relu
(
x
=
input
)
return
output
@
layer
def
max_pool
(
self
,
input
,
k_h
,
k_w
,
s_h
,
s_w
,
name
,
padding
=
None
):
if
padding
is
None
:
padding
=
[
0
,
0
]
# Get the number of channels in the input
h_i
,
w_i
=
input
.
shape
[
2
:]
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
[
k_h
,
k_w
],
pool_stride
=
[
s_h
,
s_w
],
pool_padding
=
padding
,
pool_type
=
'max'
)
return
output
@
layer
def
avg_pool
(
self
,
input
,
k_h
,
k_w
,
s_h
,
s_w
,
name
,
padding
=
None
):
if
padding
is
None
:
padding
=
[
0
,
0
]
# Get the number of channels in the input
h_i
,
w_i
=
input
.
shape
[
2
:]
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
[
k_h
,
k_w
],
pool_stride
=
[
s_h
,
s_w
],
pool_padding
=
padding
,
pool_type
=
'avg'
)
return
output
@
layer
def
lrn
(
self
,
input
,
radius
,
alpha
,
beta
,
name
,
bias
=
1.0
):
raise
Exception
(
'lrn() not implemented yet'
)
@
layer
def
concat
(
self
,
inputs
,
axis
,
name
):
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
concat
(
input
=
inputs
,
axis
=
axis
)
return
output
@
layer
def
add
(
self
,
inputs
,
name
):
fluid
=
import_fluid
()
output
=
inputs
[
0
]
for
i
in
inputs
[
1
:]:
output
=
fluid
.
layers
.
elementwise_add
(
x
=
output
,
y
=
i
)
return
output
@
layer
def
fc
(
self
,
input
,
num_out
,
name
,
relu
=
True
,
act
=
None
):
fluid
=
import_fluid
()
if
act
is
None
:
act
=
'relu'
if
relu
is
True
else
None
prefix
=
name
+
'_'
output
=
fluid
.
layers
.
fc
(
name
=
name
,
input
=
input
,
size
=
num_out
,
act
=
act
,
param_attr
=
fluid
.
ParamAttr
(
name
=
prefix
+
'weights'
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
prefix
+
'biases'
))
return
output
@
layer
def
softmax
(
self
,
input
,
name
):
fluid
=
import_fluid
()
output
=
fluid
.
layers
.
softmax
(
x
=
input
,
name
=
name
)
return
output
@
layer
def
batch_normalization
(
self
,
input
,
name
,
scale_offset
=
True
,
relu
=
False
):
# NOTE: Currently, only inference is supported
fluid
=
import_fluid
()
prefix
=
name
+
'_'
param_attr
=
None
if
scale_offset
is
False
else
fluid
.
ParamAttr
(
name
=
prefix
+
'scale'
)
bias_attr
=
None
if
scale_offset
is
False
else
fluid
.
ParamAttr
(
name
=
prefix
+
'offset'
)
mean_name
=
prefix
+
'mean'
variance_name
=
prefix
+
'variance'
output
=
fluid
.
layers
.
batch_norm
(
name
=
name
,
input
=
input
,
is_test
=
True
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
moving_mean_name
=
mean_name
,
moving_variance_name
=
variance_name
,
epsilon
=
1e-5
,
act
=
'relu'
if
relu
is
True
else
None
)
return
output
@
layer
def
dropout
(
self
,
input
,
keep_prob
,
name
):
raise
Exception
(
'dropout() not implemented yet'
)
class
LeNet
(
Network
):
def
setup
(
self
):
self
.
feed
(
'data'
)
self
.
conv
(
5
,
5
,
20
,
1
,
1
,
relu
=
False
,
name
=
'conv1'
)
self
.
max_pool
(
2
,
2
,
2
,
2
,
name
=
'pool1'
)
self
.
conv
(
5
,
5
,
50
,
1
,
1
,
relu
=
False
,
name
=
'conv2'
)
self
.
max_pool
(
2
,
2
,
2
,
2
,
name
=
'pool2'
)
self
.
fc
(
500
,
name
=
'ip1'
)
self
.
fc
(
10
,
relu
=
False
,
name
=
'ip2'
)
self
.
softmax
(
name
=
'prob'
)
@
classmethod
def
convert
(
cls
,
npy_model
,
fluid_path
):
import
paddle.v2.fluid
as
fluid
data_layer
=
fluid
.
layers
.
data
(
name
=
"data"
,
shape
=
[
1
,
28
,
28
],
dtype
=
"float32"
)
feed_data
=
{
"data"
:
data_layer
}
net
=
cls
(
feed_data
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
net
.
load
(
data_path
=
npy_model
,
exe
=
exe
,
place
=
place
)
fluid
.
io
.
save_persistables
(
executor
=
exe
,
dirname
=
fluid_path
)
if
__name__
==
"__main__"
:
#usage: python xxxnet.py xxx.npy ./model
import
sys
npy_weight
=
sys
.
argv
[
1
]
fluid_model
=
sys
.
argv
[
2
]
LeNet
.
convert
(
npy_weight
,
fluid_model
)
exit
(
0
)
fluid/image_classification/caffe2fluid/tests/lenet/predict.py
0 → 100644
浏览文件 @
50afdd83
#!/bin/env python
#function:
# demo to show how to use converted model using caffe2fluid
#
import
numpy
as
np
import
paddle.v2
as
paddle
import
paddle.v2.fluid
as
fluid
from
lenet
import
LeNet
as
MyNet
def
test_model
(
exe
,
test_program
,
fetch_list
,
test_reader
,
feeder
):
acc_set
=
[]
for
data
in
test_reader
():
acc_np
,
pred
=
exe
.
run
(
program
=
test_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
acc_set
.
append
(
float
(
acc_np
))
acc_val
=
np
.
array
(
acc_set
).
mean
()
return
float
(
acc_val
)
def
main
(
model_path
):
""" main
"""
print
(
'load fluid model in %s'
%
(
model_path
))
with_gpu
=
False
paddle
.
init
(
use_gpu
=
with_gpu
)
#1, define network topology
images
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
net
=
MyNet
({
'data'
:
images
})
prediction
=
net
.
layers
[
'prob'
]
acc
=
fluid
.
layers
.
accuracy
(
input
=
prediction
,
label
=
label
)
place
=
fluid
.
CUDAPlace
(
0
)
if
with_gpu
is
True
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
#2, load weights
if
model_path
.
find
(
'.npy'
)
>
0
:
net
.
load
(
data_path
=
model_path
,
exe
=
exe
,
place
=
place
)
else
:
net
.
load
(
data_path
=
model_path
,
exe
=
exe
)
#3, test this model
test_program
=
fluid
.
default_main_program
().
clone
()
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
128
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
images
,
label
],
place
=
place
)
fetch_list
=
[
acc
,
prediction
]
print
(
'go to test model using test set'
)
acc_val
=
test_model
(
exe
,
test_program
,
\
fetch_list
,
test_reader
,
feeder
)
print
(
'test accuracy is [%.4f], expected value[0.919]'
%
(
acc_val
))
if
__name__
==
"__main__"
:
import
sys
if
len
(
sys
.
argv
)
==
2
:
fluid_model_path
=
sys
.
argv
[
1
]
else
:
fluid_model_path
=
'./model.fluid'
main
(
fluid_model_path
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录