Unverified commit b2a10916 authored by cyber-pioneer and committed by GitHub

Merge ops composite into to_static (#49836)

* support @to_static+to_prime+cinn

* fix code logic

* debug4

* debug5

* debug6

* debug7

* debug 8

* debug 9

* debug10

* debug11

* debug11

* debug 12
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Parent 412573f0
......@@ -371,3 +371,37 @@ def set_paddle_lib_path():
set_paddle_lib_path()
def set_prim_forward(value):
"""set flag FLAGS_prim_forward."""
flag = str(value)
if flag.lower() not in ["true", "false", "debug"]:
raise TypeError(f"flag {flag} should be a string of bool or 'debug'.")
os.environ["FLAGS_prim_forward"] = flag
return
def enable_prim_forward():
flag = os.getenv("FLAGS_prim_forward", "true").lower()
if flag == "false":
return False
if flag == "debug":
return "debug"
return True
def set_prim_backward(value):
"""set flag FLAGS_prim_backward."""
flag = str(value)
if flag.lower() not in ["true", "false"]:
raise TypeError(f"flag {flag} should be a bool or a string of bool.")
os.environ["FLAGS_prim_backward"] = flag
return
def enable_prim_backward():
flag = os.getenv("FLAGS_prim_backward", "true")
if flag.lower() == "false":
return False
return True
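Note: these helpers just round-trip string-valued flags through environment variables, so they can also be set from the shell (FLAGS_prim_forward=debug FLAGS_prim_backward=false) before launching. A minimal usage sketch, calling them through paddle.fluid.core the same way the new unit tests in this PR do:

from paddle.fluid import core

core.set_prim_forward("debug")   # lower composite forward ops even without CINN
core.set_prim_backward(False)    # keep the original op-level backward
assert core.enable_prim_forward() == "debug"
assert core.enable_prim_backward() is False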
......@@ -57,8 +57,7 @@ attrs = Attr()
def fn(x):
y = paddle.tan(x)
return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype)
return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)
def expect_forward(inputs):
......@@ -81,8 +80,17 @@ class TestCompositeSoftmax(unittest.TestCase):
)
y = fn(x)
blocks = main_program.blocks
fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that softmax is in the original block
self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is split into small ops
self.assertTrue('softmax' not in fwd_ops_new)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
......@@ -97,7 +105,7 @@ class TestCompositeSoftmax(unittest.TestCase):
actual = self.cal_composite(np_data)[0]
assert expect.dtype == actual.dtype
assert np.allclose(
np.testing.assert_allclose(
expect,
actual,
rtol=attrs.get_rtol("forward"),
......
......@@ -19,6 +19,7 @@ from utils import TOLERANCE
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
def generate_data(shape, dtype="float32"):
......@@ -57,11 +58,11 @@ attrs = Attr()
def fn(x):
y = paddle.tan(x)
return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype)
return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)
def expect_grad(inputs):
paddle.disable_static()
inputs.stop_gradient = False
res = fn(inputs)
......@@ -86,8 +87,22 @@ class TestCompositeSoftmax(unittest.TestCase):
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that softmax is in the original block
self.assertTrue('softmax' in fwd_ops)
paddle.incubate.autograd.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that softmax is split into small ops
self.assertTrue('softmax' not in fwd_ops_new)
z = paddle.static.gradients([y], x)
fwd_ops_grad = [op.type for op in blocks[0].ops]
# Ensure that softmax_grad is not in the grad block
self.assertTrue('softmax_grad' not in fwd_ops_grad)
exe = paddle.static.Executor()
exe.run(startup_program)
......@@ -103,7 +118,7 @@ class TestCompositeSoftmax(unittest.TestCase):
actual = self.cal_composite_grad(np_data)[0]
assert expect.dtype == actual.dtype
assert np.allclose(
np.testing.assert_allclose(
expect,
actual,
rtol=attrs.get_rtol("backward"),
......@@ -120,5 +135,59 @@ class TestCompositeSoftmax(unittest.TestCase):
self.compare_backward()
class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
"test composite softmax and prim backward"
def setUp(self):
core.set_prim_enabled(True)
self.dtypes = ["float32"]
self.shapes = [[2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]
def cal_composite_grad(self, inputs):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
x.stop_gradient = False
y = fn(x)
blocks = main_program.blocks
paddle.incubate.autograd.to_prim(blocks)
z = paddle.static.gradients([y], x)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
paddle.disable_static()
return res
def compare_backward(self):
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)
expect = expect_grad(tensor_data)[0].numpy()
actual = self.cal_composite_grad(np_data)[0]
assert expect.dtype == actual.dtype
np.testing.assert_allclose(
expect,
actual,
rtol=attrs.get_rtol("prim_backward"),
atol=attrs.get_rtol("prim_backward"),
)
def test_prim_backward(self):
for i in self.axes:
for j in self.dtypes:
for t in self.shapes:
attrs.set_axis(i)
attrs.set_dtype(j)
attrs.set_shape(t)
self.compare_backward()
if __name__ == '__main__':
unittest.main()
......@@ -15,11 +15,13 @@
TOLERANCE = {
"float32": {
"forward": {"rtol": 1e-6, "atol": 1e-6},
"backward": {"rtol": 1e-6, "atol": 1e-6},
},
"float64": {
"forward": {"rtol": 1e-7, "atol": 1e-7},
"backward": {"rtol": 1e-7, "atol": 1e-7},
"prim_backward": {"rtol": 1e-6, "atol": 1e-6},
},
"float64": {
"forward": {"rtol": 1e-16, "atol": 1e-16},
"backward": {"rtol": 1e-15, "atol": 1e-15},
"prim_backward": {"rtol": 1e-15, "atol": 1e-15},
},
}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
self.fc = paddle.nn.Linear(4, 4)
def forward(self, x):
y = self.fc(x)
out = F.softmax(y)
return out
class TestPrimForward(unittest.TestCase):
"""
This case only tests prim_forward + to_static + cinn. Thus we need to
set this flag to False to avoid prim_backward:
core.set_prim_backward(False)
"""
def setUp(self):
core.set_prim_backward(False)
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
def train(self, use_prim):
paddle.seed(2022)
net = PrimeNet()
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)
if use_prim:
net = apply_to_static(net, use_prim)
res = []
for _ in range(10):
out = net(self.x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_grad()
res.append(out.numpy())
self.check_prim(net, use_prim)
return res
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
# Ensure that softmax is split into small ops
self.assertTrue('softmax' not in fwd_ops)
def test_cinn_prim_forward(self):
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
cinn_res[i], dy_res[i], rtol=1e-7, atol=1e-7
)
class TestPrimForwardAndBackward(unittest.TestCase):
"""
Test PrimeNet with @to_static + prim forward + prim backward + cinn vs. dygraph
"""
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([2, 4])
self.x.stop_gradient = False
def train(self, use_prim):
core.set_prim_backward(True)
paddle.seed(2022)
net = PrimeNet()
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)
if use_prim:
net = apply_to_static(net, use_prim)
res = []
for _ in range(10):
out = net(self.x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_grad()
res.append(out.numpy())
self.check_prim(net, use_prim)
return res
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
# Ensure that softmax is split into small ops
self.assertTrue('softmax' not in fwd_ops)
def test_cinn_prim(self):
plat = platform.system()
if plat == "Linux":
dy_res = self.train(use_prim=False)
cinn_res = self.train(use_prim=True)
for i in range(len(dy_res)):
np.testing.assert_allclose(
cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
)
else:
pass
if __name__ == '__main__':
unittest.main()
......@@ -14,6 +14,7 @@
import math
import os
import platform
import tempfile
import time
import unittest
......@@ -450,5 +451,67 @@ class TestResnet(unittest.TestCase):
fluid.set_flags({'FLAGS_use_mkldnn': False})
class TestResnetPrim(unittest.TestCase):
"test prim forward + prim backward + to_static"
def setUp(self):
self.resnet_helper = ResNetHelper()
def train(self, to_static):
paddle.jit.enable_to_static(to_static)
return self.resnet_helper.train(to_static)
def verify_predict(self):
image = np.random.random([1, 3, 224, 224]).astype('float32')
dy_pre = self.resnet_helper.predict_dygraph(image)
st_pre = self.resnet_helper.predict_static(image)
dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
predictor_pre = self.resnet_helper.predict_analysis_inference(image)
np.testing.assert_allclose(
dy_pre,
st_pre,
rtol=1e-05,
err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre),
)
np.testing.assert_allclose(
dy_jit_pre,
st_pre,
rtol=1e-05,
err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format(
dy_jit_pre, st_pre
),
)
np.testing.assert_allclose(
predictor_pre,
st_pre,
rtol=1e-05,
err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format(
predictor_pre, st_pre
),
)
def test_resnet_composite(self):
plat = platform.system()
if plat == "Linux":
print("=================== origin resnet ===================")
core.set_prim_enabled(False)
static_loss = self.train(to_static=True)
print("======= resnet with prim forward and backward =======")
core.set_prim_enabled(True)
core.set_prim_forward("debug")
dygraph_loss = self.train(to_static=True)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-02,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
core.set_prim_enabled(False)
else:
pass
if __name__ == '__main__':
unittest.main()
......@@ -29,7 +29,9 @@ def _composite(op, *args):
@REGISTER_COMPOSITE('softmax')
def softmax_composite(x, axis):
"""define composite rule of op softmax"""
molecular = exp(x)
denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
max_temp = max(x, axis, keepdim=True)
max_temp.stop_gradient = True
molecular = exp(x - max_temp)
denominator = sum(molecular, axis=axis, keepdim=True)
res = divide(molecular, denominator)
return res
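The rewritten rule is the standard numerically stable softmax: subtracting the per-axis maximum keeps every exponent non-positive, so exp cannot overflow, and it does not change the result because softmax is shift-invariant. With m = \max_j x_j:

\mathrm{softmax}(x)_i
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \frac{e^{m}\, e^{x_i - m}}{e^{m} \sum_j e^{x_j - m}}
  = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}

The explicit broadcast_to is also dropped: with keepdim=True the reduced denominator broadcasts against the numerator inside divide.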
......@@ -16,7 +16,7 @@ import logging
import typing
import paddle
from paddle.fluid import backward, framework
from paddle.fluid import backward, core, framework
from paddle.incubate.autograd import primx, utils
......@@ -218,13 +218,22 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only
def to_prim(blocks):
"""Search nonbasic ops which have registered composite rules and replace them with primitive ops."""
if not core.enable_prim_forward():
return
if isinstance(blocks, paddle.fluid.framework.Block):
logging.info("Atomize composite op to primitive ops begin.")
primx._lower_composite(blocks)
return
main_program = blocks.program
elif isinstance(blocks, typing.Sequence):
for item in blocks:
to_prim(item)
return
if not isinstance(item, paddle.fluid.framework.Block):
raise TypeError(
f"Expect block or sequence of blocks, but sequence contains {type(item)}."
)
main_program = blocks[0].program
else:
raise TypeError
raise TypeError(
f"Expect block or sequence of blocks, but got {type(blocks)}."
)
with framework.program_guard(main_program):
primx._lower_composite(blocks)
return
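A minimal sketch of calling to_prim on a program's blocks, mirroring the updated unit tests above (the data shape and names are illustrative only):

import paddle

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data('x', shape=[2, 3, 4], dtype='float32')
    y = paddle.nn.functional.softmax(x)

# Lower registered composite ops (e.g. softmax) to primitive ops in place.
paddle.incubate.autograd.to_prim(main_program.blocks)
assert 'softmax' not in [op.type for op in main_program.block(0).ops]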
......@@ -39,6 +39,8 @@ from paddle.tensor import log1p # noqa: F401
from paddle.tensor import logcumsumexp # noqa: F401
from paddle.tensor import logit # noqa: F401
from paddle.tensor import logsumexp # noqa: F401
from paddle.tensor import max # noqa: F401
from paddle.tensor import min # noqa: F401
from paddle.tensor import multiply # noqa: F401
from paddle.tensor import pow # noqa: F401
from paddle.tensor import prod # noqa: F401
......@@ -73,6 +75,8 @@ math_op = [
'logsumexp',
'logcumsumexp',
'logit',
'max',
'min',
]
trigonometric_op = [
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import typing
from collections import OrderedDict
import paddle
......@@ -575,90 +577,101 @@ def _lower_composite(block, blacklist=[]):
return_list.append(x)
return return_list
# Step1: Do some preparatory work for lower
lower_fn = _composite
lookup_fn = lookup_composite
value_table = {}
to_bind = {}
to_bind_rev = {}
for var in block.desc.all_vars():
value_table[var.name()] = block.var(var.name())
ops_to_remove = []
vars_to_remove = set()
# Step2: Process all ops in the target block
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist:
input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table)
for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)),
expand_nested_list(as_tensors(lower_fn(op, *input_args))),
):
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."
vars_to_remove.add(new_out.name)
value_table[new_out.name] = new_out
to_bind[orig_out.name] = new_out.name
to_bind_rev[new_out.name] = orig_out.name
else:
inputs = {}
for i in range(len(op.input_names)):
inputs[op.input_names[i]] = bind_name(
op.input(op.input_names[i]), to_bind
)
outputs = {}
for i in range(len(op.output_names)):
outputs[op.output_names[i]] = op.output(op.output_names[i])
attrs = {}
for name in sorted(op.attr_names):
attrs[name] = op.attr(name)
from paddle.fluid.dygraph.base import param_guard
new_op_desc = block.desc.append_op()
with param_guard(inputs), param_guard(outputs):
op = Operator(
block=block,
desc=new_op_desc,
type=op.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
)
block.ops.append(op)
# Step3: Do some post-processing work
for op_idx in reversed(ops_to_remove):
block.desc._remove_op(op_idx, op_idx + 1)
del block.ops[op_idx]
block._sync_with_cpp()
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
for in_name in op.input_arg_names:
if in_name in to_bind_rev:
op._rename_input(in_name, to_bind_rev[in_name])
for out_name in op.output_arg_names:
if out_name in to_bind_rev:
op._rename_output(out_name, to_bind_rev[out_name])
if isinstance(block, paddle.fluid.framework.Block):
logging.info("Atomize composite op to primitive ops begin.")
# Step1: Do some preparatory work for lower
lower_fn = _composite
lookup_fn = lookup_composite
value_table = {}
to_bind = {}
to_bind_rev = {}
for var in block.desc.all_vars():
value_table[var.name()] = block.var(var.name())
ops_to_remove = []
vars_to_remove = set()
# Step2: Process all ops in the target block
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist:
input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table)
for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)),
expand_nested_list(as_tensors(lower_fn(op, *input_args))),
):
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."
vars_to_remove.add(new_out.name)
value_table[new_out.name] = new_out
to_bind[orig_out.name] = new_out.name
to_bind_rev[new_out.name] = orig_out.name
else:
inputs = {}
for i in range(len(op.input_names)):
inputs[op.input_names[i]] = bind_name(
op.input(op.input_names[i]), to_bind
)
outputs = {}
for i in range(len(op.output_names)):
outputs[op.output_names[i]] = op.output(op.output_names[i])
attrs = {}
for name in sorted(op.attr_names):
attrs[name] = op.attr(name)
from paddle.fluid.dygraph.base import param_guard
new_op_desc = block.desc.append_op()
with param_guard(inputs), param_guard(outputs):
op = Operator(
block=block,
desc=new_op_desc,
type=op.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
)
block.ops.append(op)
# Step3: Do some post-processing work
for op_idx in reversed(ops_to_remove):
block.desc._remove_op(op_idx, op_idx + 1)
del block.ops[op_idx]
block._sync_with_cpp()
for var_name in sorted(vars_to_remove):
assert (
var_name in to_bind_rev
), 'var_name "{}" is not in to_bind_rev.'.format(var_name)
if var_name != to_bind_rev[var_name]:
block.desc._remove_var(var_name.encode())
del block.vars[var_name]
block._sync_with_cpp()
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
for in_name in op.input_arg_names:
if in_name in to_bind_rev:
op._rename_input(in_name, to_bind_rev[in_name])
for out_name in op.output_arg_names:
if out_name in to_bind_rev:
op._rename_output(out_name, to_bind_rev[out_name])
for var_name in sorted(vars_to_remove):
assert (
var_name in to_bind_rev
), 'var_name "{}" is not in to_bind_rev.'.format(var_name)
if var_name != to_bind_rev[var_name]:
block.desc._remove_var(var_name.encode())
del block.vars[var_name]
block._sync_with_cpp()
return
elif isinstance(block, typing.Sequence):
for item in block:
_lower_composite(item)
return
else:
raise TypeError
@framework.static_only
......
......@@ -571,7 +571,13 @@ class PartialProgramLayer:
targets.append(program.global_block().var(out.name))
if targets:
backward.gradients(targets=targets, inputs=[])
enable_prim = self._build_strategy.build_cinn_pass
if enable_prim and core.enable_prim_backward():
core.set_prim_enabled(True)
backward.gradients(targets=targets, inputs=[])
core.set_prim_enabled(False)
else:
backward.gradients(targets=targets, inputs=[])
start_idx = len(main_program.block(0).ops) + 2 * len(
self._outputs.tolist()
......
......@@ -18,7 +18,7 @@ import textwrap
import threading
import weakref
from paddle.fluid import _non_static_mode, framework
from paddle.fluid import _non_static_mode, core, framework
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import param_guard, switch_to_static_graph
......@@ -930,6 +930,13 @@ class ConcreteProgram:
self.function = function
self.kwargs = kwargs
@switch_to_static_graph
def _to_prim(self):
# TODO(Aurelius84): Fix this cycle import problem
from paddle.incubate.autograd.primapi import to_prim
to_prim(self.main_program.blocks)
@staticmethod
@switch_to_static_graph
def from_func_spec(
......@@ -1083,6 +1090,11 @@ class ProgramCache:
self._recent_cache_key = None
def _build_once(self, cache_key):
# TODO(Aurelius84): Need a global FLAGS to enable/disable to_prim
enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
if enable_prim and core.enable_prim_backward():
core.set_prim_enabled(True)
concrete_program = ConcreteProgram.from_func_spec(
func_spec=cache_key.function_spec,
input_spec=cache_key.input_args_with_spec,
......@@ -1090,6 +1102,10 @@ class ProgramCache:
class_instance=cache_key.class_instance,
**cache_key.kwargs
)
if enable_prim or core.enable_prim_forward() == "debug":
concrete_program._to_prim()
core.set_prim_enabled(False)
return concrete_program, partial_program_from(concrete_program)
def __getitem__(self, item):
......
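Putting the pieces together: @to_static with build_cinn_pass=True now triggers composite lowering in ProgramCache._build_once, and prim backward is used unless FLAGS_prim_backward is false. A sketch mirroring the new tests (assumes a CINN-enabled Linux build; the toy network is illustrative only):

import paddle
import paddle.nn.functional as F

class Net(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        return F.softmax(self.fc(x))

build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True        # also enables composite lowering in to_static
net = paddle.jit.to_static(Net(), build_strategy=build_strategy)

x = paddle.randn([2, 4])
x.stop_gradient = False
out = net(x)
paddle.mean(out).backward()                  # prim backward unless FLAGS_prim_backward=false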