Commit b3943b0c authored by sandyhouse

resolve conflict

# This demo shows how to use a user-defined training dataset.
# The following steps are needed to use a user-defined training dataset:
# 1. Build a reader, which preprocesses images and yields one sample in the
#    format (data, label) each time, where data is the decoded image data;
# 2. Batch the above samples;
# 3. Set the batch reader as the reader used during training.
import argparse

import paddle
from plsc import Entry
from plsc.utils import jpeg_reader as reader

parser = argparse.ArgumentParser()
parser.add_argument("--model_save_dir",
                    type=str,
                    default="./saved_model",
                    help="Directory to save models.")
parser.add_argument("--data_dir",
                    type=str,
                    default="./data",
                    help="Directory for datasets.")
parser.add_argument("--num_epochs",
                    type=int,
                    default=2,
                    help="Number of epochs to run.")
parser.add_argument("--loss_type",
                    type=str,
                    default='arcface',
                    help="Loss type to use.")
args = parser.parse_args()


def main():
    global args
    ins = Entry()
    ins.set_model_save_dir(args.model_save_dir)
    ins.set_dataset_dir(args.data_dir)
    ins.set_train_epochs(args.num_epochs)
    ins.set_loss_type(args.loss_type)
    # 1. Build a reader, which yields one sample in the format (data, label)
    #    each time, where data is the decoded image data;
    train_reader = reader.arc_train(args.data_dir,
                                    ins.num_classes)
    # 2. Batch the above samples;
    batched_train_reader = paddle.batch(train_reader,
                                        ins.train_batch_size)
    # 3. Set the batch reader as the reader used during training.
    ins.train_reader = batched_train_reader
    ins.train()


if __name__ == "__main__":
    main()
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
#!/usr/bin/env python3
"""
markdown to rst
"""
# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals
import os
import os.path
import re
import sys
from argparse import ArgumentParser, Namespace
from docutils import statemachine, nodes, io, utils
from docutils.parsers import rst
from docutils.core import ErrorString
from docutils.utils import SafeString, column_width
import mistune
if sys.version_info < (3, ):
from codecs import open as _open
from urlparse import urlparse
else:
_open = open
from urllib.parse import urlparse
__version__ = '0.2.1'
_is_sphinx = False
prolog = '''\
.. role:: raw-html-m2r(raw)
:format: html
'''
# for command-line use
parser = ArgumentParser()
options = Namespace()
parser.add_argument(
'input_file', nargs='*', help='files to convert to reST format')
parser.add_argument(
'--overwrite',
action='store_true',
default=False,
help='overwrite output file without confirmation')
parser.add_argument(
'--dry-run',
action='store_true',
default=False,
help='print conversion result and not save output file')
parser.add_argument(
'--no-underscore-emphasis',
action='store_true',
default=False,
help='do not use underscore (_) for emphasis')
parser.add_argument(
'--parse-relative-links',
action='store_true',
default=False,
help='parse relative links into ref or doc directives')
parser.add_argument(
'--anonymous-references',
action='store_true',
default=False,
help='use anonymous references in generated rst')
parser.add_argument(
'--disable-inline-math',
action='store_true',
default=False,
help='disable parsing inline math')
def parse_options():
"""parse_options"""
parser.parse_known_args(namespace=options)
class RestBlockGrammar(mistune.BlockGrammar):
"""RestBlockGrammar"""
directive = re.compile(
r'^( *\.\..*?)\n(?=\S)',
re.DOTALL | re.MULTILINE, )
oneline_directive = re.compile(
r'^( *\.\..*?)$',
re.DOTALL | re.MULTILINE, )
rest_code_block = re.compile(
r'^::\s*$',
re.DOTALL | re.MULTILINE, )
class RestBlockLexer(mistune.BlockLexer):
"""RestBlockLexer"""
grammar_class = RestBlockGrammar
default_rules = [
'directive',
'oneline_directive',
'rest_code_block',
] + mistune.BlockLexer.default_rules
def parse_directive(self, m):
"""parse_directive"""
self.tokens.append({
'type': 'directive',
'text': m.group(1),
})
def parse_oneline_directive(self, m):
"""parse_oneline_directive"""
# reuse directive output
self.tokens.append({
'type': 'directive',
'text': m.group(1),
})
def parse_rest_code_block(self, m):
"""parse_rest_code_block"""
self.tokens.append({'type': 'rest_code_block', })
class RestInlineGrammar(mistune.InlineGrammar):
"""RestInlineGrammar"""
image_link = re.compile(
r'\[!\[(?P<alt>.*?)\]\((?P<url>.*?)\).*?\]\((?P<target>.*?)\)')
rest_role = re.compile(r':.*?:`.*?`|`[^`]+`:.*?:')
rest_link = re.compile(r'`[^`]*?`_')
inline_math = re.compile(r'.*\$(.*)?\$')
eol_literal_marker = re.compile(r'(\s+)?::\s*$')
# add colon and space as special text
text = re.compile(r'^[\s\S]+?(?=[\\<!\[:_*`~ ]|https?://| {2,}\n|$)')
# __word__ or **word**
double_emphasis = re.compile(r'^([_*]){2}(?P<text>[\s\S]+?)\1{2}(?!\1)')
# _word_ or *word*
emphasis = re.compile(r'^\b_((?:__|[^_])+?)_\b' # _word_
r'|'
r'^\*(?P<text>(?:\*\*|[^\*])+?)\*(?!\*)' # *word*
)
def no_underscore_emphasis(self):
"""no_underscore_emphasis"""
self.double_emphasis = re.compile(
r'^\*{2}(?P<text>[\s\S]+?)\*{2}(?!\*)' # **word**
)
self.emphasis = re.compile(
r'^\*(?P<text>(?:\*\*|[^\*])+?)\*(?!\*)' # *word*
)
class RestInlineLexer(mistune.InlineLexer):
"""RestInlineLexer"""
grammar_class = RestInlineGrammar
default_rules = [
'image_link',
'rest_role',
'rest_link',
'eol_literal_marker',
] + mistune.InlineLexer.default_rules
def __init__(self, *args, **kwargs):
no_underscore_emphasis = kwargs.pop('no_underscore_emphasis', False)
disable_inline_math = kwargs.pop('disable_inline_math', False)
super(RestInlineLexer, self).__init__(*args, **kwargs)
if not _is_sphinx:
parse_options()
if no_underscore_emphasis or getattr(options, 'no_underscore_emphasis',
False):
self.rules.no_underscore_emphasis()
inline_maths = 'inline_math' in self.default_rules
if disable_inline_math or getattr(options, 'disable_inline_math',
False):
if inline_maths:
self.default_rules.remove('inline_math')
elif not inline_maths:
self.default_rules.insert(0, 'inline_math')
def output_double_emphasis(self, m):
"""output_double_emphasis"""
# may include code span
text = self.output(m.group('text'))
return self.renderer.double_emphasis(text)
def output_emphasis(self, m):
"""output_emphasis"""
# may include code span
text = self.output(m.group('text') or m.group(1))
return self.renderer.emphasis(text)
def output_image_link(self, m):
"""Pass through rest role."""
return self.renderer.image_link(
m.group('url'), m.group('target'), m.group('alt'))
def output_rest_role(self, m):
"""Pass through rest role."""
return self.renderer.rest_role(m.group(0))
def output_rest_link(self, m):
"""Pass through rest link."""
return self.renderer.rest_link(m.group(0))
def output_inline_math(self, m):
"""Pass through rest link."""
return self.renderer.inline_math(m.group(0))
def output_eol_literal_marker(self, m):
"""Pass through rest link."""
marker = ':' if m.group(1) is None else ''
return self.renderer.eol_literal_marker(marker)
class RestRenderer(mistune.Renderer):
"""RestRenderer"""
_include_raw_html = False
list_indent_re = re.compile(r'^(\s*(#\.|\*)\s)')
indent = ' ' * 3
list_marker = '{#__rest_list_mark__#}'
hmarks = {
1: '=',
2: '-',
3: '^',
4: '~',
5: '"',
6: '#',
}
def __init__(self, *args, **kwargs):
self.parse_relative_links = kwargs.pop('parse_relative_links', False)
self.anonymous_references = kwargs.pop('anonymous_references', False)
super(RestRenderer, self).__init__(*args, **kwargs)
if not _is_sphinx:
parse_options()
if getattr(options, 'parse_relative_links', False):
self.parse_relative_links = options.parse_relative_links
if getattr(options, 'anonymous_references', False):
self.anonymous_references = options.anonymous_references
def _indent_block(self, block):
return '\n'.join(self.indent + line if line else ''
for line in block.splitlines())
def _raw_html(self, html):
self._include_raw_html = True
return '\ :raw-html-m2r:`{}`\ '.format(html)
def block_code(self, code, lang=None):
"""block_code"""
if lang == 'math':
first_line = '\n.. math::\n\n'
elif lang:
first_line = '\n.. code-block:: {}\n\n'.format(lang)
elif _is_sphinx:
first_line = '\n.. code-block:: guess\n\n'
else:
first_line = '\n.. code-block::\n\n'
return first_line + self._indent_block(code) + '\n'
def block_quote(self, text):
"""block_quote"""
# text includes some empty line
return '\n..\n\n{}\n\n'.format(self._indent_block(text.strip('\n')))
def block_html(self, html):
"""Rendering block level pure html content.
:param html: text content of the html snippet.
"""
return '\n\n.. raw:: html\n\n' + self._indent_block(html) + '\n\n'
def header(self, text, level, raw=None):
"""Rendering header/heading tags like ``<h1>`` ``<h2>``.
:param text: rendered text content for the header.
:param level: a number for the header level, for example: 1.
:param raw: raw text content of the header.
"""
return '\n{0}\n{1}\n'.format(text,
self.hmarks[level] * column_width(text))
def hrule(self):
"""Rendering method for ``<hr>`` tag."""
return '\n----\n'
def list(self, body, ordered=True):
"""Rendering list tags like ``<ul>`` and ``<ol>``.
:param body: body contents of the list.
:param ordered: whether this list is ordered or not.
"""
mark = '#. ' if ordered else '* '
lines = body.splitlines()
for i, line in enumerate(lines):
if line and not line.startswith(self.list_marker):
lines[i] = ' ' * len(mark) + line
return '\n{}\n'.format('\n'.join(lines)).replace(self.list_marker,
mark)
def list_item(self, text):
"""Rendering list item snippet. Like ``<li>``."""
return '\n' + self.list_marker + text
def paragraph(self, text):
"""Rendering paragraph tags. Like ``<p>``."""
return '\n' + text + '\n'
def table(self, header, body):
"""Rendering table element. Wrap header and body in it.
:param header: header part of the table.
:param body: body part of the table.
"""
table = '\n.. list-table::\n'
if header and not header.isspace():
table = (table + self.indent + ':header-rows: 1\n\n' +
self._indent_block(header) + '\n')
else:
table = table + '\n'
table = table + self._indent_block(body) + '\n\n'
return table
def table_row(self, content):
"""Rendering a table row. Like ``<tr>``.
:param content: content of current table row.
"""
contents = content.splitlines()
if not contents:
return ''
clist = ['* ' + contents[0]]
if len(contents) > 1:
for c in contents[1:]:
clist.append(' ' + c)
return '\n'.join(clist) + '\n'
def table_cell(self, content, **flags):
"""Rendering a table cell. Like ``<th>`` ``<td>``.
:param content: content of current table cell.
:param header: whether this is header or not.
:param align: align of current table cell.
"""
return '- ' + content + '\n'
def double_emphasis(self, text):
"""Rendering **strong** text.
:param text: text content for emphasis.
"""
return '\ **{}**\ '.format(text)
def emphasis(self, text):
"""Rendering *emphasis* text.
:param text: text content for emphasis.
"""
return '\ *{}*\ '.format(text)
def codespan(self, text):
"""Rendering inline `code` text.
:param text: text content for inline code.
"""
if '``' not in text:
return '\ ``{}``\ '.format(text)
else:
# actually, docutils split spaces in literal
return self._raw_html('<code class="docutils literal">'
'<span class="pre">{}</span>'
'</code>'.format(
text.replace('`', '&#96;')))
def linebreak(self):
"""Rendering line break like ``<br>``."""
if self.options.get('use_xhtml'):
return self._raw_html('<br />') + '\n'
return self._raw_html('<br>') + '\n'
def strikethrough(self, text):
"""Rendering ~~strikethrough~~ text.
:param text: text content for strikethrough.
"""
return self._raw_html('<del>{}</del>'.format(text))
def text(self, text):
"""Rendering unformatted text.
:param text: text content.
"""
return text
def autolink(self, link, is_email=False):
"""Rendering a given link or email address.
:param link: link content or email address.
:param is_email: whether this is an email or not.
"""
return link
def link(self, link, title, text):
"""Rendering a given link with content and title.
:param link: href link for ``<a>`` tag.
:param title: title content for `title` attribute.
:param text: text content for description.
"""
if self.anonymous_references:
underscore = '__'
else:
underscore = '_'
if title:
return self._raw_html(
'<a href="{link}" title="{title}">{text}</a>'.format(
link=link, title=title, text=text))
if not self.parse_relative_links:
return '\ `{text} <{target}>`{underscore}\ '.format(
target=link, text=text, underscore=underscore)
else:
url_info = urlparse(link)
if url_info.scheme:
return '\ `{text} <{target}>`{underscore}\ '.format(
target=link, text=text, underscore=underscore)
else:
link_type = 'doc'
anchor = url_info.fragment
if url_info.fragment:
if url_info.path:
# Can't link to anchors via doc directive.
anchor = ''
else:
# Example: [text](#anchor)
link_type = 'ref'
doc_link = '{doc_name}{anchor}'.format(
# splitext approach works whether or not path is set. It
# will return an empty string if unset, which leads to
# anchor only ref.
doc_name=os.path.splitext(url_info.path)[0],
anchor=anchor)
return '\ :{link_type}:`{text} <{doc_link}>`\ '.format(
link_type=link_type, doc_link=doc_link, text=text)
def image(self, src, title, text):
"""Rendering a image with title and text.
:param src: source link of the image.
:param title: title text of the image.
:param text: alt text of the image.
"""
# rst does not support title option
# and I couldn't find title attribute in HTML standard
return '\n'.join([
'',
'.. image:: {}'.format(src),
' :target: {}'.format(src),
' :alt: {}'.format(text),
'',
])
def inline_html(self, html):
"""Rendering span level pure html content.
:param html: text content of the html snippet.
"""
return self._raw_html(html)
def newline(self):
"""Rendering newline element."""
return ''
def footnote_ref(self, key, index):
"""Rendering the ref anchor of a footnote.
:param key: identity key for the footnote.
:param index: the index count of current footnote.
"""
return '\ [#fn-{}]_\ '.format(key)
def footnote_item(self, key, text):
"""Rendering a footnote item.
:param key: identity key for the footnote.
:param text: text content of the footnote.
"""
return '.. [#fn-{0}] {1}\n'.format(key, text.strip())
def footnotes(self, text):
"""Wrapper for all footnotes.
:param text: contents of all footnotes.
"""
if text:
return '\n\n' + text
else:
return ''
"""Below outputs are for rst."""
def image_link(self, url, target, alt):
"""image_link"""
return '\n'.join([
'',
'.. image:: {}'.format(url),
' :target: {}'.format(target),
' :alt: {}'.format(alt),
'',
])
def rest_role(self, text):
"""rest_role"""
return text
def rest_link(self, text):
"""rest_link"""
return text
def inline_math(self, math):
"""Extension of recommonmark"""
return re.sub(r'\$(.*?)\$',
lambda x: '\ :math:`{}`\ '.format(x.group(1)), math)
def eol_literal_marker(self, marker):
"""Extension of recommonmark"""
return marker
def directive(self, text):
"""directive"""
return '\n' + text + '\n'
def rest_code_block(self):
"""rest_code_block"""
return '\n\n'
class M2R(mistune.Markdown):
"""M2R"""
def __init__(self,
renderer=None,
inline=RestInlineLexer,
block=RestBlockLexer,
**kwargs):
if renderer is None:
renderer = RestRenderer(**kwargs)
super(M2R, self).__init__(
renderer, inline=inline, block=block, **kwargs)
def parse(self, text):
"""parse"""
output = super(M2R, self).parse(text)
return self.post_process(output)
def output_directive(self):
"""output_directive"""
return self.renderer.directive(self.token['text'])
def output_rest_code_block(self):
"""output_rest_code_block"""
return self.renderer.rest_code_block()
def post_process(self, text):
"""post_process"""
output = (text.replace('\\ \n', '\n').replace('\n\\ ', '\n')
.replace(' \\ ', ' ').replace('\\ ', ' ')
.replace('\\ .', '.'))
if self.renderer._include_raw_html:
return prolog + output
else:
return output
class M2RParser(rst.Parser, object):
"""M2RParser"""
# Explicitly tell supported formats to sphinx
supported = ('markdown', 'md', 'mkd')
def parse(self, inputstrings, document):
"""parse"""
if isinstance(inputstrings, statemachine.StringList):
inputstring = '\n'.join(inputstrings)
else:
inputstring = inputstrings
config = document.settings.env.config
converter = M2R(no_underscore_emphasis=config.no_underscore_emphasis,
parse_relative_links=config.m2r_parse_relative_links,
anonymous_references=config.m2r_anonymous_references,
disable_inline_math=config.m2r_disable_inline_math)
super(M2RParser, self).parse(converter(inputstring), document)
class MdInclude(rst.Directive):
"""Directive class to include markdown in sphinx.
Load a file and convert it to rst and insert as a node. Currently
directive-specific options are not implemented.
"""
required_arguments = 1
optional_arguments = 0
option_spec = {
'start-line': int,
'end-line': int,
}
def run(self):
"""Most of this method is from ``docutils.parser.rst.Directive``.
docutils version: 0.12
"""
if not self.state.document.settings.file_insertion_enabled:
raise self.warning('"%s" directive disabled.' % self.name)
source = self.state_machine.input_lines.source(
self.lineno - self.state_machine.input_offset - 1)
source_dir = os.path.dirname(os.path.abspath(source))
path = rst.directives.path(self.arguments[0])
path = os.path.normpath(os.path.join(source_dir, path))
path = utils.relative_path(None, path)
path = nodes.reprunicode(path)
# get options (currently not use directive-specific options)
encoding = self.options.get(
'encoding', self.state.document.settings.input_encoding)
e_handler = self.state.document.settings.input_encoding_error_handler
tab_width = self.options.get('tab-width',
self.state.document.settings.tab_width)
# open the including file
try:
self.state.document.settings.record_dependencies.add(path)
include_file = io.FileInput(
source_path=path, encoding=encoding, error_handler=e_handler)
except UnicodeEncodeError as error:
raise self.severe('Problems with "%s" directive path:\n'
'Cannot encode input file path "%s" '
'(wrong locale?).' %
(self.name, SafeString(path)))
except IOError as error:
raise self.severe('Problems with "%s" directive path:\n%s.' %
(self.name, ErrorString(error)))
# read from the file
startline = self.options.get('start-line', None)
endline = self.options.get('end-line', None)
try:
if startline or (endline is not None):
lines = include_file.readlines()
rawtext = ''.join(lines[startline:endline])
else:
rawtext = include_file.read()
except UnicodeError as error:
raise self.severe('Problem with "%s" directive:\n%s' %
(self.name, ErrorString(error)))
config = self.state.document.settings.env.config
converter = M2R(no_underscore_emphasis=config.no_underscore_emphasis,
parse_relative_links=config.m2r_parse_relative_links,
anonymous_references=config.m2r_anonymous_references,
disable_inline_math=config.m2r_disable_inline_math)
include_lines = statemachine.string2lines(
converter(rawtext), tab_width, convert_whitespace=True)
self.state_machine.insert_input(include_lines, path)
return []
def setup(app):
"""When used for sphinx extension."""
global _is_sphinx
_is_sphinx = True
app.add_config_value('no_underscore_emphasis', False, 'env')
app.add_config_value('m2r_parse_relative_links', False, 'env')
app.add_config_value('m2r_anonymous_references', False, 'env')
app.add_config_value('m2r_disable_inline_math', False, 'env')
app.add_source_parser('.md', M2RParser)
app.add_directive('mdinclude', MdInclude)
metadata = dict(
version=__version__,
parallel_read_safe=True,
parallel_write_safe=True, )
return metadata
def convert(text, **kwargs):
"""convert"""
return M2R(**kwargs)(text)
def parse_from_file(file, encoding='utf-8', **kwargs):
"""parse_from_file"""
if not os.path.exists(file):
raise OSError('No such file exists: {}'.format(file))
with _open(file, encoding=encoding) as f:
src = f.read()
output = convert(src, **kwargs)
return output
def save_to_file(file, src, encoding='utf-8', **kwargs):
"""save_to_file"""
target = os.path.splitext(file)[0] + '.rst'
if not options.overwrite and os.path.exists(target):
confirm = input('{} already exists. overwrite it? [y/n]: '.format(
target))
if confirm.upper() not in ('Y', 'YES'):
print('skip {}'.format(file))
return
with _open(target, 'w', encoding=encoding) as f:
f.write(src)
def main():
"""main"""
parse_options() # parse cli options
if not options.input_file:
parser.print_help()
parser.exit(0)
for file in options.input_file:
output = parse_from_file(file)
if options.dry_run:
print(output)
else:
save_to_file(file, output)
if __name__ == '__main__':
main()
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = u'PLSC'
copyright = u'2020, Paddle Authors'
author = u'Paddle Authors'
# The short X.Y version
version = u''
# The full version, including alpha/beta/rc tags
release = u''
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = u'zh'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'PLSCdoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'PLSC.tex', u'PLSC Documentation',
u'Paddle Authors', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'plsc', u'PLSC Documentation',
[author], 1)
]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'PLSC', u'PLSC Documentation',
author, 'PLSC', 'One line description of project.',
'Miscellaneous'),
]
# -- Options for Epub output -------------------------------------------------
# Bibliographic Dublin Core info.
epub_title = project
# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''
# A unique identification for the text.
#
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
:github_url: https://github.com/PaddlePaddle/PLSC
Quick Start
===========
.. toctree::
   :maxdepth: 1
   :caption: Quick Start
   :hidden:

   instruction.rst
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
.. toctree::
   :maxdepth: 1
   :caption: Examples

   examples/index.rst
.. toctree::
   :maxdepth: 2
   :caption: API Reference

   api/plsc
Development Team
================
PLSC is developed and maintained by Baidu's Nimitz team.

License
=======
PLSC is licensed under the Apache License 2.0.
# Advanced Guide
## Uploading and Downloading Model Parameters (HDFS)
When HDFS information is configured via set_hdfs_info(fs_name, fs_ugi, fs_dir_for_save=None, fs_checkpoint_dir=None), PLSC automatically downloads the pretrained model parameters from HDFS before training starts, and automatically uploads the saved model parameters to the specified HDFS directory after training finishes.
### Uploading model parameters
An example training script that uploads model parameters:
```python
from plsc import Entry

if __name__ == "__main__":
    ins = Entry()
    ins.set_model_save_dir('./saved_model')
    ins.set_hdfs_info("you_hdfs_addr", "name,passwd", "some_dir")
    ins.train()
```
### Downloading model parameters
An example training script that downloads model parameters:
```python
from plsc import Entry

if __name__ == "__main__":
    ins = Entry()
    ins.set_checkpoint_dir('./saved_model')
    ins.set_hdfs_info("you_hdfs_addr",
                      "name,passwd",
                      fs_checkpoint_dir="some_dir")
    ins.train()
```
This script downloads all model parameters stored under the "some_dir" directory on HDFS into the local "./saved_model" directory. Make sure the "./saved_model" directory exists.
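The two directions can also be combined. The sketch below, based on the set_hdfs_info parameter descriptions in the API reference (the HDFS address and directory names are placeholders), downloads a checkpoint before training and uploads the newly saved parameters afterwards:

```python
from plsc import Entry

if __name__ == "__main__":
    ins = Entry()
    ins.set_checkpoint_dir('./saved_model')   # local directory for the downloaded checkpoint
    ins.set_model_save_dir('./saved_model')   # local directory whose contents are uploaded
    ins.set_hdfs_info("you_hdfs_addr",
                      "name,passwd",
                      fs_dir_for_save="dir_for_save",
                      fs_checkpoint_dir="some_dir")
    ins.train()
```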
## Preprocessing Base64-Encoded Image Data
In practice, a common storage format for training data is to encode the images in base64. Each line of a training data file then stores the base64 data of one image together with its label, usually separated by a tab character ('\t').
Typically, the list of all training data files is recorded in a separate file, and the directory structure of the whole training dataset looks like this:
```shell script
dataset
|-- file_list.txt
|-- dataset.part1
|-- dataset.part2
... ....
`-- dataset.part10
```
Here, file_list.txt records the list of training data files, one file per line. For the example above, its content is:
```shell script
dataset.part1
dataset.part2
...
dataset.part10
```
Each line of a data file contains the base64 representation of one image followed by the image label, separated by a tab.
For distributed training, every GPU card must process the same number of images, and a global shuffle of the training data is usually performed before training.
This document describes the base64 image preprocessing tool, which globally shuffles the training data and splits it evenly into multiple data files, one per GPU card used for training. When the total number of training samples is not divisible by the number of GPU cards, extra images (randomly sampled from the training dataset) are padded in so that the total number of training images is an integer multiple of the number of GPU cards, i.e. every data file contains the same number of images.
### Usage
The tool is located in the tools directory and requires the sqlite3 module, which can be installed with:
```shell script
pip install sqlite3
```
The usage help of the tool can be displayed with:
```shell script
python tools/process_base64_files.py --help
```
The tool supports the following command-line options:
* data_dir: the root directory of the training data
* file_list: the file listing the training data files, e.g. file_list.txt
* nranks: the number of GPU cards used for training
The tool can be run as follows:
```shell script
python tools/process_base64_files.py --data_dir=./dataset --file_list=file_list.txt --nranks=8
```
This produces 8 data files, each containing the same number of training samples.
The base64-encoded image data can then be read with plsc.utils.base64_reader.
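For illustration only, the snippet below decodes one line of such a data file into a (data, label) sample following the tab-separated layout described above; the helper name and decoding details are assumptions, not the actual plsc.utils.base64_reader implementation:

```python
import base64
import io

import numpy as np
from PIL import Image


def decode_base64_line(line):
    """Decode one '<base64 image><tab><label>' line into (data, label)."""
    b64_data, label = line.rstrip('\n').split('\t')
    img_bytes = base64.b64decode(b64_data)
    img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    # HWC uint8 -> CHW float32, matching the (data, label) sample format
    data = np.array(img).astype('float32').transpose((2, 0, 1))
    return data, int(label)
```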
## Mixed-Precision Training
PLSC supports mixed-precision training, which speeds up training and reduces GPU memory consumption.
### Usage
Mixed-precision training can be enabled as follows:
```python
from plsc import Entry


def main():
    ins = Entry()
    ins.set_mixed_precision(True)
    ins.train()


if __name__ == "__main__":
    main()
```
### Parameters
The set_mixed_precision method takes 7 parameters. use_fp16 is required and determines whether mixed-precision training is enabled; the other 6 parameters all have default values, described in the table below, and a call spelling all of them out is shown after the table.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| use_fp16 | bool | none, must be set by the user | whether to enable mixed-precision training; set to True to enable it |
| init_loss_scaling | float | 1.0 | initial loss scaling value; it may affect the accuracy of mixed-precision training, and the default value is recommended |
| incr_every_n_steps | int | 2000 | if no FP16 overflow occurs for `incr_every_n_steps` consecutive iterations, loss_scaling is multiplied by `incr_ratio`; the default value is recommended |
| decr_every_n_nan_or_inf | int | 2 | if FP16 overflow occurs in `decr_every_n_nan_or_inf` accumulated iterations, loss_scaling is multiplied by `decr_ratio`; the default value is recommended |
| incr_ratio | float | 2.0 | factor by which loss_scaling is increased; the default value is recommended |
| decr_ratio | float | 0.5 | factor by which loss_scaling is decreased; the default value is recommended |
| use_dynamic_loss_scaling | bool | True | whether to use dynamic loss scaling; only when it is enabled are `incr_every_n_steps`, `decr_every_n_nan_or_inf`, `incr_ratio` and `decr_ratio` used. Enabling it improves the stability and accuracy of mixed-precision training, and the default value is recommended |
| amp_lists | AutoMixedPrecisionLists | None | auto mixed precision lists object that specifies the operators computed in fp16; the default value is recommended |
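For reference, a call that spells out all of the defaults listed above could look like the sketch below; in practice only use_fp16 has to be passed:

```python
from plsc import Entry

ins = Entry()
# Equivalent to ins.set_mixed_precision(True): every keyword argument below
# is simply the documented default written out explicitly.
ins.set_mixed_precision(use_fp16=True,
                        init_loss_scaling=1.0,
                        incr_every_n_steps=2000,
                        decr_every_n_nan_or_inf=2,
                        incr_ratio=2.0,
                        decr_ratio=0.5,
                        use_dynamic_loss_scaling=True,
                        amp_lists=None)
ins.train()
```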
For more information on mixed-precision training, see:
- Paper: [MIXED PRECISION TRAINING](https://arxiv.org/abs/1710.03740)
- Nvidia Introduction: [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
### Training performance
Configuration: a single machine with 8 NVIDIA Tesla V100 GPUs

| Model \ Throughput | FP32 training | Mixed-precision training | Speedup |
| --- | --- | --- | --- |
| ResNet50 | 2567.96 images/s | 3643.11 images/s | 1.42 |

Note: the loss_type used for the training above is 'dist_arcface'.
## Custom Models
By default, PLSC builds the training network on top of the ResNet50 model.
PLSC provides the model base class plsc.models.base_model.BaseModel, on top of which users can build their own network models. A user-defined model class must inherit from this base class and implement the build_network method, which builds the custom network.
The following example shows how to define a custom network model with the BaseModel base class and how to use it:
```python
import paddle.fluid as fluid
from plsc import Entry
from plsc.models.base_model import BaseModel


class ResNet(BaseModel):
    def __init__(self, layers=50, emb_dim=512):
        super(ResNet, self).__init__()
        self.layers = layers
        self.emb_dim = emb_dim

    def build_network(self,
                      input,
                      label,
                      is_train):
        layers = self.layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers {}, but given {}".format(supported_layers, layers)
        if layers == 50:
            depth = [3, 4, 14, 3]
            num_filters = [64, 128, 256, 512]
        elif layers == 101:
            depth = [3, 4, 23, 3]
            num_filters = [256, 512, 1024, 2048]
        elif layers == 152:
            depth = [3, 8, 36, 3]
            num_filters = [256, 512, 1024, 2048]
        conv = self.conv_bn_layer(input=input,
                                  num_filters=64,
                                  filter_size=3,
                                  stride=1,
                                  pad=1,
                                  act='prelu',
                                  is_train=is_train)
        for block in range(len(depth)):
            for i in range(depth[block]):
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters[block],
                    stride=2 if i == 0 else 1,
                    is_train=is_train)
        bn = fluid.layers.batch_norm(input=conv,
                                     act=None,
                                     epsilon=2e-05,
                                     is_test=False if is_train else True)
        drop = fluid.layers.dropout(
            x=bn,
            dropout_prob=0.4,
            dropout_implementation='upscale_in_train',
            is_test=False if is_train else True)
        fc = fluid.layers.fc(
            input=drop,
            size=self.emb_dim,
            act=None,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Xavier(uniform=False, fan_in=0.0)),
            bias_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.ConstantInitializer()))
        emb = fluid.layers.batch_norm(input=fc,
                                      act=None,
                                      epsilon=2e-05,
                                      is_test=False if is_train else True)
        return emb

    def conv_bn_layer(
        ... ...


if __name__ == "__main__":
    ins = Entry()
    ins.set_model(ResNet())
    ins.train()
```
A user-defined model class must inherit from the base class BaseModel and implement the build_network method.
The inputs of build_network are:
* input: the input image data
* label: the class labels of the images
* is_train: whether the network is built for training or for test/prediction
build_network returns the output variable of the user-defined network; a minimal sketch of this contract is shown below.
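The following deliberately tiny model is only meant to illustrate that contract; the layer choices are assumptions for illustration and not part of PLSC:

```python
import paddle.fluid as fluid
from plsc.models.base_model import BaseModel


class TinyCNN(BaseModel):
    """Minimal sketch of the build_network contract, not a real face model."""

    def __init__(self, emb_dim=512):
        super(TinyCNN, self).__init__()
        self.emb_dim = emb_dim

    def build_network(self, input, label, is_train):
        conv = fluid.layers.conv2d(input=input,
                                   num_filters=64,
                                   filter_size=3,
                                   stride=2,
                                   padding=1,
                                   act='relu')
        pool = fluid.layers.pool2d(input=conv,
                                   pool_type='avg',
                                   global_pooling=True)
        # The returned variable is the embedding consumed by the loss layer.
        emb = fluid.layers.fc(input=pool, size=self.emb_dim, act=None)
        return emb
```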
## Custom Training Data
By default, we assume that the training data directory is organized as follows:
```shell script
train_data/
|-- images
`-- label.txt
```
Here, the images directory stores the training images, and label.txt records the path and class label of every image in the training data.
When the training data is organized in another, custom format, it can be used through the following steps:
1. Define a reader function (a generator) that preprocesses the data (e.g. cropping) and yields samples one at a time;
   * each sample is a tuple of the form (data, label), where data is the decoded and preprocessed image data and label is the class label of the image.
2. Wrap the reader generator with paddle.batch to obtain a new generator, batched_reader;
3. Assign batched_reader to the train_reader member of the plsc.Entry instance.
For ease of description, we again assume that the training data is organized as follows:
```shell script
train_data/
|-- images
`-- label.txt
```
The sample generator can be defined as follows (reader.py):
```python
import os
import random

import numpy as np
from PIL import Image


def arc_train(data_dir):
    label_file = os.path.join(data_dir, 'label.txt')
    train_image_list = None
    with open(label_file, 'r') as f:
        train_image_list = f.readlines()
    train_image_list = get_train_image_list(data_dir)

    def reader():
        for j in range(len(train_image_list)):
            path, label = train_image_list[j]
            path = os.path.join(data_dir, path)
            img = Image.open(path)
            if random.randint(0, 1) == 1:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img = np.array(img).astype('float32').transpose((2, 0, 1))
            yield img, label

    return reader
```
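The example above calls get_train_image_list without showing its definition. A minimal version consistent with the space-separated label.txt format described in the quick start could look like this (an assumption for illustration, not PLSC's own helper):

```python
import os


def get_train_image_list(data_dir):
    """Parse label.txt into a list of (relative_path, int_label) tuples."""
    label_file = os.path.join(data_dir, 'label.txt')
    image_list = []
    with open(label_file, 'r') as f:
        for line in f:
            path, label = line.strip().split()
            image_list.append((path, int(label)))
    return image_list
```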
The training code that uses the custom training data is shown below:
```python
import argparse

import paddle
from plsc import Entry

import reader

parser = argparse.ArgumentParser()
parser.add_argument("--data_dir",
                    type=str,
                    default="./data",
                    help="Directory for datasets.")
args = parser.parse_args()


def main():
    global args
    ins = Entry()
    ins.set_dataset_dir(args.data_dir)
    train_reader = reader.arc_train(args.data_dir)
    # Batch the above samples;
    batched_train_reader = paddle.batch(train_reader,
                                        ins.train_batch_size)
    # Set the reader to use during training to the above batch reader.
    ins.train_reader = batched_train_reader
    ins.train()


if __name__ == "__main__":
    main()
```
For more details, see the [demo code](../../../demo/custom_reader.py).
# API Reference
## Default Configuration
PLSC provides default configuration parameters for model training, validation and the model itself, such as the training dataset directory and the number of training epochs.
These defaults live in plsc.config; their meanings and default values are listed below.
### Training-related parameters

| Parameter | Meaning | Default |
| :------- | :------- | :----- |
| train_batch_size | batch size used for training | 128 |
| dataset_dir | root directory of the dataset | './train_data' |
| train_image_num | number of training images | 5822653 |
| train_epochs | number of training epochs | 120 |
| warmup_epochs | number of warmup epochs | 0 |
| lr | initial learning rate | 0.1 |
| lr_steps | steps at which the learning rate decays | (100000, 160000, 220000) |
### Validation-related parameters

| Parameter | Meaning | Default |
| :------- | :------- | :----- |
| val_targets | names of the validation datasets, separated by commas, e.g. 'lfw,cfp_fp' | lfw |
| test_batch_size | batch size used for validation | 120 |
| with_test | whether to validate the model after every training epoch | True |
### Model-related parameters

| Parameter | Meaning | Default |
| :------- | :------- | :----- |
| model_name | name of the model to use | 'ResNet50' |
| checkpoint_dir | directory of the pretrained model (checkpoint) | "", i.e. no pretrained model is used |
| model_save_dir | directory where trained models are saved | "./output" |
| loss_type | method used to compute the loss | 'dist_arcface' |
| num_classes | number of classes | 85742 |
| image_shape | image shape in CHW format | [3, 112, 112] |
| margin | margin parameter of dist_arcface and arcface | 0.5 |
| scale | scale parameter of dist_arcface and arcface | 64.0 |
| emb_size | output dimension of the last hidden layer of the model | 512 |

Notes:
* difference between checkpoint_dir and model_save_dir:
  * checkpoint_dir is the directory of the pretrained model loaded before training/validation;
  * model_save_dir is the directory where trained models are saved.
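The sketch below overrides several of the defaults listed above through the setter APIs described in the next section; the values shown are simply the documented defaults, written out for illustration:

```python
from plsc import Entry

ins = Entry()
ins.set_train_batch_size(128)         # train_batch_size
ins.set_dataset_dir('./train_data')   # dataset_dir
ins.set_train_epochs(120)             # train_epochs
ins.set_loss_type('dist_arcface')     # loss_type
ins.set_class_num(85742)              # num_classes
ins.set_emb_size(512)                 # emb_size
ins.set_model_save_dir('./output')    # model_save_dir
ins.train()
```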
## Configuration APIs
The Entry class of PLSC provides the following APIs for changing the default configuration:
* set_val_targets(targets)
  * Sets the names of the validation datasets, separated by commas; type: string.
* set_train_batch_size(size)
  * Sets the training batch size; type: int.
* set_test_batch_size(size)
  * Sets the validation batch size; type: int.
* set_mixed_precision(use_fp16, init_loss_scaling=1.0, incr_every_n_steps=2000, decr_every_n_nan_or_inf=2, incr_ratio=2.0, decr_ratio=0.5, use_dynamic_loss_scaling=True, amp_lists=None)
  * Enables or disables mixed-precision training and sets the related parameters.
* set_hdfs_info(fs_name, fs_ugi, fs_dir_for_save, fs_checkpoint_dir)
  * Sets the HDFS file system information; the parameters are:
    * fs_name: the HDFS address; type: string;
    * fs_ugi: comma-separated user name and password; type: string;
    * fs_dir_for_save: the HDFS directory to upload models to; when set, the saved model parameters are automatically uploaded to this directory after training;
    * fs_checkpoint_dir: the HDFS directory holding pretrained model parameters; when both this parameter and the checkpoint directory are set, the model parameters are automatically downloaded before training starts.
* set_model_save_dir(dir)
  * Sets the model save directory model_save_dir; type: string.
* set_dataset_dir(dir)
  * Sets the dataset root directory dataset_dir; type: string.
* set_train_image_num(num)
  * Sets the total number of training images; type: int.
* set_calc_acc(calc)
  * Sets whether to compute acc1 and acc5 during training; type: bool. Computing accuracy during training consumes extra GPU memory and lowers the maximum number of classes that can be supported, so set it only when necessary.
* set_class_num(num)
  * Sets the total number of classes; type: int.
* set_emb_size(size)
  * Sets the output dimension of the last hidden layer; type: int.
* set_model(model)
  * Sets the user-defined model, which must be an instance of a subclass of BaseModel.
* set_train_epochs(num)
  * Sets the number of training epochs; type: int.
* set_checkpoint_dir(dir)
  * Sets the directory of the pretrained model to load; type: string.
* set_warmup_epochs(num)
  * Sets the number of warmup epochs; type: int.
* set_loss_type(loss_type)
  * Sets the method used to compute the loss; options are 'arcface', 'softmax', 'dist_softmax' and 'dist_arcface'; type: string.
  * 'arcface' and 'softmax' use pure data parallelism without a distributed FC layer, while 'dist_arcface' and 'dist_softmax' use the distributed versions of arcface and softmax, i.e. the parameters of the last FC layer are sharded across multiple GPU cards.
  * For details on arcface, see [ArcFace: Additive Angular Margin Loss for Deep Face Recognition](https://arxiv.org/abs/1801.07698).
* set_image_shape(size)
  * Sets the image shape in CHW format; type: tuple or list.
* set_optimizer(optimizer)
  * Sets the optimizer used for training; the value must be an instance of Optimizer or one of its subclasses.
* set_with_test(with_test)
  * Sets whether to validate the model after every training epoch; type: bool.
* set_distfc_attr(param_attr=None, bias_attr=None)
  * Sets the attributes of the W and b parameters of the last FC layer; see [parameter attributes](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/ParamAttr_cn.html#paramattr).
* convert_for_prediction()
  * Converts a pretrained model into an inference model.
* test()
  * Runs model validation.
* train()
  * Runs model training.
Note on set_warmup_epochs and set_train_image_num:
By default, we assume that a total batch size of 1024 across all GPUs gives good training results; for example, with 8 GPU cards the per-card batch size is 128. When the total batch size differs from 1024, the initial learning rate needs to be scaled accordingly: lr = total_batch_size / 1024 * lr, where lr denotes the initial learning rate. In addition, to keep training stable, warmup_epochs is usually set, i.e. during the first warmup_epochs epochs the learning rate grows gradually from a small value up to the initial learning rate. To implement this warmup schedule, the number of training steps per epoch is computed from the number of images in the training dataset; a worked example follows.
When this behaviour needs to be changed, a custom Optimizer can be implemented and configured via set_optimizer.
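As a concrete illustration of the scaling rule above (the GPU count and warmup setting are example values, not defaults):

```python
# Learning-rate scaling and warmup-step arithmetic from the note above.
num_gpus = 4                   # example value
per_gpu_batch_size = 128       # default train_batch_size
base_lr = 0.1                  # default initial learning rate
train_image_num = 5822653      # default train_image_num
warmup_epochs = 2              # example value

total_batch_size = num_gpus * per_gpu_batch_size        # 512
scaled_lr = total_batch_size / 1024 * base_lr           # 0.05
steps_per_epoch = train_image_num // total_batch_size   # 11372
warmup_steps = steps_per_epoch * warmup_epochs          # 22744
```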
The APIs described in this section are all methods of PLSC's Entry class and must be called on an instance of that class, for example:
```python
from plsc import Entry
ins = Entry()
ins.set_class_num(85742)
ins.train()
```
# Quick Start
## Installation
Python version requirement:
* python 2.7+
### 1. Install PaddlePaddle
#### 1.1 Version requirement
```shell script
PaddlePaddle>=1.6.2 or the develop version
```
#### 1.2 Installing with pip
Currently, the large-scale classification library requires the GPU version of PaddlePaddle.
```shell script
pip install paddlepaddle-gpu>=1.6.2
```
For PaddlePaddle's compatibility with operating systems, CUDA, cuDNN and other software versions, and for more installation instructions, see the [PaddlePaddle installation guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html).
To use the develop version of PaddlePaddle, first uninstall the installed PaddlePaddle with the command below, then install the develop version. For how to obtain and install the develop version of PaddlePaddle, see the [list of wheel packages](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/Tables.html#ciwhls).
```shell script
pip uninstall paddlepaddle-gpu
```
### 2. Install the PLSC large-scale classification library
The PLSC large-scale classification library can be installed directly with pip:
```shell script
pip install plsc
```
## Training and Validation
PLSC provides a complete pipeline from training and evaluation to inference deployment. This section describes how to quickly train a model and validate its accuracy with PLSC.
### Data preparation
We assume that the user's data is organized as follows:
```shell script
train_data/
|-- agedb_30.bin
|-- cfp_ff.bin
|-- cfp_fp.bin
|-- images
|-- label.txt
`-- lfw.bin
```
Here, *train_data* is the root directory of the user's data; *agedb_30.bin*, *cfp_ff.bin*, *cfp_fp.bin* and *lfw.bin* are different validation datasets, and none of them is strictly required. This tutorial uses lfw.bin as the validation dataset by default, so please make sure lfw.bin is available when following it. The *images* directory contains the training images in JPEG format, and each line of *label.txt* corresponds to one training image and its class label.
An example of the content of *label.txt*:
```shell script
images/00000000.jpg 0
images/00000001.jpg 0
images/00000002.jpg 0
images/00000003.jpg 0
images/00000004.jpg 0
images/00000005.jpg 0
images/00000006.jpg 0
images/00000007.jpg 0
... ...
```
Each line contains the path of one image and the class label of that image, separated by a space.
### Model training
#### Training code
The following script *train.py* trains a large-scale classification model with PLSC:
```python
from plsc import Entry

if __name__ == "__main__":
    ins = Entry()
    ins.set_train_epochs(1)
    ins.set_model_save_dir("./saved_model")
    # ins.set_with_test(False)  # uncomment when no validation set is available
    # ins.set_loss_type('arcface')  # uncomment when only one GPU card is available
    ins.train()
```
Training with PLSC involves the following main steps:
1. Import the Entry class from the plsc package; it is the interface to all functionality of the PLSC large-scale classification library.
2. Create an instance of the Entry class.
3. Call the train method of the Entry class to start training.
By default, the training script computes the loss with 'dist_arcface', which requires two or more GPU cards. When only one GPU card is available, the loss computation can be switched to 'arcface' with the following statement:
```python
ins.set_loss_type('arcface')
```
By default, the model is validated on the validation set after every training epoch. When no validation dataset is available, validation can be disabled with *set_with_test(False)*.
#### Launching the training job
The following example shows how to launch a training job with the script above:
```shell script
python -m paddle.distributed.launch \
--cluster_node_ips="127.0.0.1" \
--node_ip="127.0.0.1" \
--selected_gpus=0,1,2,3,4,5,6,7 \
train.py
```
The paddle.distributed.launch module launches multi-node/multi-GPU distributed training jobs and simplifies the launch process. Its parameters are:
* cluster_node_ips: a comma-separated list of the IP addresses of the machines participating in training;
* node_ip: the IP address of the current machine;
* selected_gpus: a comma-separated list of the GPU devices used on each training node.
For single-node multi-GPU training, the cluster_node_ips and node_ip parameters can be omitted, as shown below:
```shell script
python -m paddle.distributed.launch \
--selected_gpus=0,1,2,3,4,5,6,7 \
train.py
```
When only one GPU card is used, start the training job with:
```shell script
python train.py
```
### Model validation
This section uses the lfw.bin validation set as an example to show how to evaluate a trained model.
#### Validation code
The following is the validation script *val.py*:
```python
from plsc import Entry

if __name__ == "__main__":
    ins = Entry()
    ins.set_checkpoint_dir("./saved_model/0/")
    ins.test()
```
During training, the model parameters are saved in the './saved_model' directory, with the parameters of each epoch in a separate subdirectory; for example, './saved_model/0' holds the model parameters after the first epoch, and so on.
For validation, first set the directory of the model parameters, then call the test method of the Entry class to start validation.
#### Launching the validation job
The following example shows how to launch a validation job with the script above:
```shell script
python -m paddle.distributed.launch \
--selected_gpus=0,1,2,3,4,5,6,7 \
val.py
```
With this command, the validation task runs in parallel on multiple GPU cards, which shortens the validation time.
When only one GPU card is available, start the validation job with:
```shell script
python val.py
```
# Inference Deployment
## Exporting the Inference Model
Normally, the models PLSC saves during training contain only the model parameters, not the network structure needed for inference. To deploy with the PLSC inference library, the pretrained model has to be exported as an inference model, which contains both the model parameters and the network structure required for prediction and is used for subsequent inference tasks (see [Using the Inference Library](#using-the-inference-library)).
The pretrained model can be exported as an inference model with the following code ('export_for_inference.py'):
```python
from plsc import Entry

if __name__ == "__main__":
    ins = Entry()
    ins.set_checkpoint_dir('./pretrain_model')
    ins.set_model_save_dir('./inference_model')
    ins.convert_for_prediction()
```
Here, the './pretrain_model' directory holds the pretrained model parameters, and './inference_model' is the directory of the resulting inference model.
Start the export with:
```shell script
python export_for_inference.py
```
## Using the Inference Library
Python version requirement:
* python3
### Installation
#### Server side
```shell script
pip3 install plsc-serving
```
#### Client side
* Install ujson:
```shell script
pip install ujson
```
* Copy the [client script](./serving/client/face_service/face_service.py) to the directory where it will be used.
### Usage
#### Server side
Currently, inference is only supported on GPU machines, and CUDA version >= 9.0 is required.
Run the server side with the following script:
```python
from plsc_serving.run import PLSCServer

fs = PLSCServer()
# Set the path of the model to use; type: str; must be an absolute path.
fs.with_model(model_path='/XXX/XXX')
# Run a single process. gpu_index selects the GPU to use (int, default 0);
# port sets the port to listen on (int, default 8866).
fs.run(gpu_index=0, port=8010)
```
#### Client side
Run the client side with the following script:
```python
from face_service import FaceService

with open('./data/00000000.jpg', 'rb') as f:
    image = f.read()
fc = FaceService()
# Connect to the server; the argument is a string; the default is port 8010
# on the local machine.
fc.connect('127.0.0.1:8010')
# Call the server for prediction. The input is a list of samples, and the
# return value is the list of embeddings for the samples, with shape
# batch size * embedding size.
result = fc.encode([image])
print(result[0])
fc.close()
```
@@ -136,6 +136,9 @@ class Entry(object):
self.model_save_dir = os.path.abspath(self.model_save_dir)
if self.dataset_dir:
self.dataset_dir = os.path.abspath(self.dataset_dir)
self.lr_decay_factor = 0.1
self.log_period = 200
logger.info('=' * 30)
logger.info("Default configuration:")
@@ -143,6 +146,8 @@ class Entry(object):
logger.info('\t' + str(key) + ": " + str(self.config[key]))
logger.info('trainer_id: {}, num_trainers: {}'.format(
trainer_id, num_trainers))
logger.info('default lr_decay_factor: {}'.format(self.lr_decay_factor))
logger.info('default log period: {}'.format(self.log_period))
logger.info('=' * 30)
def set_val_targets(self, targets):
@@ -157,6 +162,20 @@ class Entry(object):
self.global_train_batch_size = batch_size * self.num_trainers
logger.info("Set train batch size to {}.".format(batch_size))
def set_log_period(self, period):
self.log_period = period
logger.info("Set log period to {}.".format(period))
def set_lr_decay_factor(self, factor):
self.lr_decay_factor = factor
logger.info("Set lr decay factor to {}.".format(factor))
def set_step_boundaries(self, boundaries):
if not isinstance(boundaries, list):
raise ValueError("The parameter must be of type list.")
self.lr_steps = boundaries
logger.info("Set step boundaries to {}.".format(boundaries))
def set_mixed_precision(self,
use_fp16,
init_loss_scaling=1.0,
@@ -332,7 +351,8 @@ class Entry(object):
warmup_steps = steps_per_pass * self.warmup_epochs
batch_denom = 1024
base_lr = start_lr * global_batch_size / batch_denom
lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]
lr_decay_factor = self.lr_decay_factor
lr = [base_lr * (lr_decay_factor ** i) for i in range(len(bd) + 1)]
logger.info("LR boundaries: {}".format(bd))
logger.info("lr_step: {}".format(lr))
if self.warmup_epochs:
@@ -938,7 +958,7 @@ class Entry(object):
local_time = 0.0
nsamples = 0
inspect_steps = 200
inspect_steps = self.log_period
global_batch_size = self.global_train_batch_size
for pass_id in range(self.train_epochs):
self.train_pass_id = pass_id
@@ -971,15 +991,15 @@ class Entry(object):
avg_lr = np.mean(local_train_info[1])
speed = nsamples / local_time
if self.calc_train_acc:
logger.info("Pass:{} batch:%d lr:{:.8f} loss:{:.6f} "
logger.info("Pass:{} batch:{} lr:{:.8f} loss:{:.6f} "
"qps:{:.2f} acc1:{:.6f} acc5:{:.6f}".format(
pass_id,
batch_id,
avg_lr,
avg_loss,
speed,
acc1,
acc5))
acc1[0],
acc5[0]))
else:
logger.info("Pass:{} batch:{} lr:{:.8f} loss:{:.6f} "
"qps:{:.2f}".format(pass_id,
......
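Based on the diff above, the newly added setters could be used like this (a sketch; the values are illustrative, not recommendations):

```python
from plsc import Entry

ins = Entry()
ins.set_log_period(100)          # log every 100 batches instead of the default 200
ins.set_lr_decay_factor(0.5)     # override the default decay factor of 0.1
ins.set_step_boundaries([100000, 160000, 220000])  # must be a list
ins.train()
```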
numpy>=1.12, <=1.16.4 ; python_version<"3.5"
numpy>=1.12 ; python_version>="3.5"
scipy>=0.19.0, <=1.2.1 ; python_version<"3.5"
scipy<=1.3.1 ; python_version>="3.5"
scikit-learn<=0.20 ; python_version<"3.5"
scikit-learn ; python_version>="3.5"
scipy>=0.19.0, <=1.2.1 ; python_version<"3.5"
scipy ; python_version>="3.5"
Pillow
sklearn
easydict
Pillow
six
paddlepaddle-gpu>=1.6.2
@@ -28,7 +28,7 @@ REQUIRED_PACKAGES = [
'scikit-learn<=0.20;python_version<"3.5"',
'scikit-learn;python_version>="3.5"',
'scipy>=0.19.0,<=1.2.1;python_version<"3.5"',
'scipy;python_version>="3.5"',
'scipy<=1.3.1;python_version>="3.5"',
'sklearn',
'easydict',
'Pillow',
......