Commit bb9c6afa authored by S sneaxiy

merge develop to solve conflict again, test=develop

@@ -339,9 +339,11 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>());
}
  void Check(AttributeMap* attr_map, bool explicit_only = false) const {
    auto checker_num = attr_checkers_.size();
    if (explicit_only) checker_num = explicit_checker_num_;
    for (size_t i = 0; i < checker_num; ++i) {
      attr_checkers_[i](attr_map, false);
    }
  }
@@ -353,8 +355,21 @@ class OpAttrChecker {
return default_values_map;
}
void RecordExplicitCheckerNum() {
explicit_checker_num_ = attr_checkers_.size();
}
private:
std::vector<AttrChecker> attr_checkers_;
  // In order to improve the efficiency of dynamic graph mode,
  // we divide the attributes into explicit type and implicit type.
  // An explicit attribute is one added in a customized op maker,
  // usually defined in the overloaded Make method.
  // An implicit attribute is one added outside of the Make method,
  // such as "op_role" and "op_role_var"; these are useless in dynamic
  // graph mode.
  size_t explicit_checker_num_;
};
} // namespace framework
......
@@ -62,6 +62,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
proto_ = proto;
op_checker_ = attr_checker;
Make();
op_checker_->RecordExplicitCheckerNum();
AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
......
@@ -440,13 +440,11 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
"Cannot find QK");
    auto *input_d = input->data<T>();
    auto *w_d = w->data<T>();
    auto *bias_d = bias->data<T>();
    auto *bias_qk_d = bias_qk.data<T>();
T scale = static_cast<T>(context.Attr<float>("alpha"));
int head_number = context.Attr<int>("head_number");
@@ -463,6 +461,10 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int all_head_size = w_dims[2];
int head_size = all_head_size / head_number;
auto *out = context.Output<framework::Tensor>("Out");
out->Resize({batch, seq_len, all_head_size});
auto *output_d = out->mutable_data<T>(context.GetPlace());
// (B*S, hidden)
const Tensor input_matrix =
framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
......
@@ -48,9 +48,11 @@ def parse_args():
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
    '--fp32_model',
    type=str,
    default='',
    help='A path to an FP32 model. If empty, the QAT model will be used for FP32 inference.'
)
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--labels', type=str, default='', help='File with labels.')
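# Example invocation (a sketch; the script name and paths are hypothetical):
#
#   python qat_int8_nlp_comparison_test.py \
#       --qat_model=/path/to/qat_model \
#       --fp32_model=/path/to/fp32_model \
#       --infer_data=/path/to/infer_data.bin \
#       --labels=/path/to/labels.txt
#
# If --fp32_model is omitted, the QAT model itself is reused for the FP32
# baseline run.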
@@ -240,7 +242,10 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
return
qat_model_path = test_case_args.qat_model
assert qat_model_path, 'The QAT model path cannot be empty. Please use the --qat_model option.'
fp32_model_path = test_case_args.fp32_model if test_case_args.fp32_model else qat_model_path
data_path = test_case_args.infer_data
assert data_path, 'The dataset path cannot be empty. Please use the --infer_data option.'
labels_path = test_case_args.labels
batch_size = test_case_args.batch_size
batch_num = test_case_args.batch_num
@@ -251,6 +256,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
_logger.info('FP32 model: {0}'.format(fp32_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Labels: {0}'.format(labels_path))
_logger.info('Batch size: {0}'.format(batch_size))
@@ -263,11 +269,12 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
self._reader_creator(data_path, labels_path), batch_size=batch_size)
fp32_acc, fp32_pps, fp32_lat = self._predict(
val_reader,
fp32_model_path,
batch_size,
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
@@ -278,6 +285,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
batch_num,
skip_batch_num,
transform_to_int8=True)
_logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
......
@@ -20,10 +20,18 @@ from .ast_transformer import *
from . import static_analysis
from .static_analysis import *
from . import loop_transformer
from .loop_transformer import *
from . import variable_trans_func
from .variable_trans_func import *
from . import cache_program
from .cache_program import *
__all__ = []
__all__ += ast_transformer.__all__
__all__ += loop_transformer.__all__
__all__ += static_analysis.__all__
__all__ += variable_trans_func.__all__
__all__ += cache_program.__all__
@@ -13,17 +13,21 @@
# limitations under the License.
from __future__ import print_function
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
import gast
import textwrap
import inspect
import astor

from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func
from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from .utils import *
__all__ = ['DygraphToStaticAst', 'convert_to_static']
@@ -124,17 +128,19 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root
    def transfer_from_node_type(self, node_wrapper):
        # Generic transformation
        self.visit(node_wrapper.node)

        # Transform basic api of dygraph to static graph
        basic_api_trans = BasicApiTransformer(node_wrapper)
        basic_api_trans.ast_visit()
        self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()

        # Transform all if/else statements of dygraph into static graph
        IfElseTransformer(node_wrapper).ast_visit()

        LoopTransformer(node_wrapper).transform()
def visit_FunctionDef(self, node):
if self.decorate_func_name is None:
......
@@ -14,8 +14,8 @@
from __future__ import print_function
import ast
import astor
import gast
import six
import copy
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import copy
import gast
from collections import defaultdict
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import create_funcDef_node
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
__all__ = ['LoopTransformer', 'NameVisitor']
WHILE_CONDITION_PREFIX = 'while_condition'
WHILE_BODY_PREFIX = 'while_body'
def create_while_node(condition_name, body_name, loop_var_names):
while_args = []
while_args.append(
gast.Name(
id=condition_name,
ctx=gast.Param(),
annotation=None,
type_comment=None))
while_args.append(
gast.Name(
id=body_name, ctx=gast.Param(), annotation=None, type_comment=None))
assign_targets = [
gast.Name(
id=var_name, ctx=gast.Param(), annotation=None, type_comment=None)
for var_name in loop_var_names
]
while_args.append(gast.List(elts=assign_targets, ctx=gast.Param()))
while_func_id = gast.parse('fluid.layers.while_loop').body[0].value
while_node = gast.Call(func=while_func_id, args=while_args, keywords=[])
assign_node = gast.Assign(
targets=[gast.Tuple(
elts=assign_targets, ctx=gast.Store())],
value=while_node)
return assign_node
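# For reference, a sketch of the source form this node corresponds to,
# assuming illustrative names `while_condition_0`/`while_body_0` and loop
# vars `i`, `x`:
#
#   i, x = fluid.layers.while_loop(while_condition_0, while_body_0, [i, x])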
class NameVisitor(gast.NodeVisitor):
'''
Analyze name liveness for the loop transformer.
'''
def __init__(self, root_node):
# Set of gast.Name
self.current_seen_vars = set()
# List of gast.While/gast.For nodes
self.current_loop = []
# Mapping from gast.While/gast.For to string name of vars
self.before_loop_vars = defaultdict(set)
self.in_loop_vars = defaultdict(set)
self.visit(root_node)
def is_control_flow_loop(self, node):
# TODO: make a better condition
return True
def get_loop_var_names(self, node):
        assert isinstance(node, gast.While) or isinstance(
            node, gast.For), "Input node is not gast loop node"
loop_var_names = set()
create_var_names = set()
read_context = {type(gast.Load()), type(gast.AugLoad())}
in_loop_vars = self.in_loop_vars[node]
in_loop_name_strs = set(name.id for name in in_loop_vars)
before_loop_vars = self.before_loop_vars[node]
before_loop_name_strs = set(name.id for name in before_loop_vars)
after_loop_vars = self.current_seen_vars - before_loop_vars - in_loop_vars
after_loop_name_strs = set(
name.id for name in after_loop_vars
if type(name.ctx) in read_context)
for name in in_loop_name_strs:
if name in before_loop_name_strs:
# If a variable is used in the loop and created before the
# loop, it should be an input of loop_var
loop_var_names.add(name)
elif name in after_loop_name_strs:
# If a variable is created in the loop and read after the
# loop, it should be in loop_var and we should create it
loop_var_names.add(name)
create_var_names.add(name)
return loop_var_names, create_var_names
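# An illustrative example of the classification above (names are
# hypothetical):
#
#   a = 1
#   while a < 10:
#       a += 1      # 'a' is used in the loop and created before it
#       b = a       # 'b' is created in the loop ...
#   c = b           # ... and read after it
#
# Here get_loop_var_names returns loop_var_names == {'a', 'b'} and
# create_var_names == {'b'}.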
def visit_Name(self, node):
self.current_seen_vars.add(node)
for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node)
self.generic_visit(node)
def visit_For(self, node):
self.current_loop.append(node)
self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
def visit_While(self, node):
self.current_loop.append(node)
self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
class LoopTransformer(gast.NodeTransformer):
"""
This class transforms Python while/for statements into static graph AST nodes.
"""
def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.name_visitor = NameVisitor(self.root)
def transform(self):
self.visit(self.root)
def get_for_stmt_nodes(self, node):
self.generic_visit(node)
# TODO
return node
def visit(self, node):
self.generic_visit(node)
# All parent nodes that may contain gast.While/gast.For
if hasattr(node, 'body'):
self.replace_stmt_list(node.body)
if hasattr(node, 'orelse'):
self.replace_stmt_list(node.orelse)
return node
def replace_stmt_list(self, body_list):
if not isinstance(body_list, list):
return
i = 0
while i < len(body_list):
if isinstance(body_list[i], gast.While):
new_stmts = self.get_while_stmt_nodes(body_list[i])
body_list[i:i + 1] = new_stmts
i += len(new_stmts)
elif isinstance(body_list[i], gast.For):
# TODO
i += 1
else:
i += 1
def get_while_stmt_nodes(self, node):
# TODO: consider while-else statements in Python
# self.generic_visit(node)
if not self.name_visitor.is_control_flow_loop(node):
return [node]
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node)
new_stmts = []
# Python can create a variable in a loop and use it after the loop, e.g.
#
#   while x < 10:
#       x += 1
#       y = x
#   z = y
#
# We need to create static variables for those variables
for name in create_var_names:
new_stmts.append(create_static_variable_gast_node(name))
# `while x < 10` in dygraph should be converted into `static tensor < 10`
for name in loop_var_names:
new_stmts.append(to_static_variable_gast_node(name))
condition_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_CONDITION_PREFIX),
args=gast.arguments(
args=[
gast.Name(
id=name,
ctx=gast.Param(),
annotation=None,
type_comment=None) for name in loop_var_names
],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=[gast.Return(value=node.test)],
decorator_list=[],
returns=None,
type_comment=None)
new_stmts.append(condition_func_node)
new_body = node.body
new_body.append(
gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load())))
body_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_BODY_PREFIX),
args=gast.arguments(
args=[
gast.Name(
id=name,
ctx=gast.Param(),
annotation=None,
type_comment=None) for name in loop_var_names
],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=new_body,
decorator_list=[],
returns=None,
type_comment=None)
new_stmts.append(body_func_node)
while_loop_node = create_while_node(condition_func_node.name,
body_func_node.name, loop_var_names)
new_stmts.append(while_loop_node)
return new_stmts
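# Putting the pieces above together, a sketch of the whole transformation
# (generated function names are illustrative):
#
#   while x < 10:       # original dygraph code
#       x += 1
#       y = x
#   z = y
#
# becomes roughly:
#
#   y = fluid.layers.data(name='y', shape=[-1], dtype='float32')
#   x = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable(x)
#   y = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable(y)
#
#   def while_condition_0(x, y):
#       return x < 10
#
#   def while_body_0(x, y):
#       x += 1
#       y = x
#       return x, y
#
#   x, y = fluid.layers.while_loop(while_condition_0, while_body_0, [x, y])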
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
from paddle.fluid.layers import fill_constant
__all__ = ['to_static_variable_gast_node', 'create_static_variable_gast_node']
def to_static_variable_gast_node(name):
func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})".format(
name, name)
return gast.parse(func_code)
def create_static_variable_gast_node(name):
func_code = "{} = fluid.layers.data(name='{}', shape=[-1], dtype='float32')".format(
name, name)
return gast.parse(func_code)
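# For a name 'x', the two helpers above produce the ASTs of, respectively:
#
#   x = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable(x)
#   x = fluid.layers.data(name='x', shape=[-1], dtype='float32')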
def to_static_variable(x):
'''
Translate a Python value into a PaddlePaddle static graph variable.
'''
if isinstance(x, bool):
return fill_constant(shape=[1], dtype='bool', value=x)
if isinstance(x, int):
return fill_constant(shape=[1], dtype='int64', value=x)
if isinstance(x, float):
return fill_constant(shape=[1], dtype='float64', value=x)
return x
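# A quick sketch of the conversion rules above:
#
#   to_static_variable(True)   # -> 1-D bool tensor filled with True
#   to_static_variable(3)      # -> 1-D int64 tensor filled with 3
#   to_static_variable(0.5)    # -> 1-D float64 tensor filled with 0.5
#   to_static_variable(var)    # -> anything else passes through unchanged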
@@ -137,7 +137,9 @@ class CleanupFuncRegistrar():
# BlockingQueue) may not be completely released, resulting in the corresponding
# memory-mapped file remaining on the disk (/dev/shm), so register this function
# to clean up shared memory objects in these two queues before the python interpreter exits.
# NOTE: Currently multi-process DataLoader only supports Linux platform
if not (sys.platform == 'darwin' or sys.platform == 'win32'):
    CleanupFuncRegistrar.register(_cleanup)
class DataLoaderBase(object):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
import inspect
import numpy as np
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
# from paddle.fluid.dygraph.dygraph_to_static import NameVisitor
SEED = 2020
np.random.seed(SEED)
def while_loop_dyfunc(x):
i = fluid.dygraph.to_variable(x)
while x < 10:
i = i + x
x = x + 1
return i
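# With the test input x = 0, the loop body runs for x = 0..9 and adds each
# value to i, so the expected result asserted below is 0 + 1 + ... + 9 = 45.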
class TestNameVisitor(unittest.TestCase):
def test_loop_vars(self):
#TODO
pass
class TestTransformWhile(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.x = np.zeros(shape=(1), dtype=np.int32)
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
x_var = fluid.layers.assign(self.x)
static_func = dygraph_to_static_graph(while_loop_dyfunc)
out = static_func(x_var)
exe = fluid.Executor(self.place)
ret = exe.run(main_program, fetch_list=out)
return ret
def _run_dygraph(self):
with fluid.dygraph.guard(self.place):
ret = while_loop_dyfunc(fluid.dygraph.to_variable(self.x))
return ret.numpy()
def test_ast_to_func(self):
static_numpy = self._run_static()
self.assertTrue(
np.allclose(
np.full(
shape=(1), fill_value=45, dtype=np.int32), static_numpy))
# Enable the following lines once Paddle dygraph supports `while x < 10`
#
# self._run_dygraph()
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
if __name__ == '__main__':
unittest.main()