Unverified · Commit b2a10916, authored by cyber-pioneer, committed by GitHub

Merge ops composite into to_static (#49836)

* support @to_static + to_prim + cinn

* fix code logic

* debug4

* debug5

* debug6

* debug7

* debug 8

* debug 9

* debug10

* debug11

* debug11

* debug 12
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Parent 412573f0
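The title and the first bullet describe wiring the composite-op (prim) lowering and CINN into @to_static. As orientation before the per-file diffs, here is a minimal usage sketch mirroring the apply_to_static helper and PrimeNet model added in the new CINN test further down; only public Paddle APIs appear, and the behavior noted in the comments is what this PR sets out to enable:

import paddle
import paddle.nn.functional as F


class PrimeNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        return F.softmax(self.fc(x))


def apply_to_static(net, use_cinn):
    # In this PR, build_cinn_pass also gates the prim (composite-op)
    # lowering inside the @to_static program cache.
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(net, build_strategy=build_strategy)


net = apply_to_static(PrimeNet(), use_cinn=True)
out = net(paddle.randn([2, 4]))  # softmax is expected to be lowered to primitive ops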
@@ -371,3 +371,37 @@ def set_paddle_lib_path():
 set_paddle_lib_path()
+
+
+def set_prim_forward(value):
+    """Set the flag FLAGS_prim_forward."""
+    flag = str(value)
+    if flag.lower() not in ["true", "false", "debug"]:
+        raise TypeError(f"flag {flag} should be string of bool or 'debug'.")
+    os.environ["FLAGS_prim_forward"] = flag
+    return
+
+
+def enable_prim_forward():
+    flag = os.getenv("FLAGS_prim_forward", "true").lower()
+    if flag == "false":
+        return False
+    if flag == "debug":
+        return "debug"
+    return True
+
+
+def set_prim_backward(value):
+    """Set the flag FLAGS_prim_backward."""
+    flag = str(value)
+    if flag.lower() not in ["true", "false"]:
+        raise TypeError(f"flag {flag} should be bool or string of bool.")
+    os.environ["FLAGS_prim_backward"] = flag
+    return
+
+
+def enable_prim_backward():
+    flag = os.getenv("FLAGS_prim_backward", "true")
+    if flag.lower() == "false":
+        return False
+    return True
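Since these helpers only read and write environment variables, their behavior can be sketched directly. The snippet below is an illustrative sketch that assumes they are exposed as paddle.fluid.core.set_prim_forward / enable_prim_forward and so on, which is how the new tests later in this diff call them:

from paddle.fluid import core

core.set_prim_forward("debug")                 # writes FLAGS_prim_forward=debug
assert core.enable_prim_forward() == "debug"   # "debug" lowers forward ops even without CINN

core.set_prim_backward(False)                  # writes FLAGS_prim_backward=false
assert core.enable_prim_backward() is False    # prim backward stays off

core.set_prim_forward(True)                    # restore the defaults ("true")
core.set_prim_backward(True)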
@@ -57,8 +57,7 @@ attrs = Attr()


 def fn(x):
-    y = paddle.tan(x)
-    return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype)
+    return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)


 def expect_forward(inputs):
@@ -81,8 +80,17 @@ class TestCompositeSoftmax(unittest.TestCase):
             )
             y = fn(x)
             blocks = main_program.blocks
+
+            fwd_ops = [op.type for op in blocks[0].ops]
+            # Ensure that softmax is in the original block
+            self.assertTrue('softmax' in fwd_ops)
+
             paddle.incubate.autograd.to_prim(blocks)
+
+            fwd_ops_new = [op.type for op in blocks[0].ops]
+            # Ensure that softmax is split into small ops
+            self.assertTrue('softmax' not in fwd_ops_new)
+
         exe = paddle.static.Executor()
         exe.run(startup_program)
         res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
@@ -97,7 +105,7 @@ class TestCompositeSoftmax(unittest.TestCase):
         actual = self.cal_composite(np_data)[0]

         assert expect.dtype == actual.dtype
-        assert np.allclose(
+        np.testing.assert_allclose(
             expect,
             actual,
             rtol=attrs.get_rtol("forward"),
...
@@ -19,6 +19,7 @@ from utils import TOLERANCE

 import paddle
 import paddle.nn.functional as F
+from paddle.fluid import core


 def generate_data(shape, dtype="float32"):
@@ -57,11 +58,11 @@ attrs = Attr()


 def fn(x):
-    y = paddle.tan(x)
-    return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype)
+    return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)


 def expect_grad(inputs):
+    paddle.disable_static()
     inputs.stop_gradient = False
     res = fn(inputs)
@@ -86,8 +87,22 @@ class TestCompositeSoftmax(unittest.TestCase):
             x.stop_gradient = False
             y = fn(x)
             blocks = main_program.blocks
+
+            fwd_ops = [op.type for op in blocks[0].ops]
+            # Ensure that softmax is in the original block
+            self.assertTrue('softmax' in fwd_ops)
+
             paddle.incubate.autograd.to_prim(blocks)
+
+            fwd_ops_new = [op.type for op in blocks[0].ops]
+            # Ensure that softmax is split into small ops
+            self.assertTrue('softmax' not in fwd_ops_new)
+
             z = paddle.static.gradients([y], x)
+
+            fwd_ops_grad = [op.type for op in blocks[0].ops]
+            # Ensure that softmax_grad is not in the grad block
+            self.assertTrue('softmax_grad' not in fwd_ops_grad)
+
         exe = paddle.static.Executor()
         exe.run(startup_program)
@@ -103,7 +118,7 @@ class TestCompositeSoftmax(unittest.TestCase):
         actual = self.cal_composite_grad(np_data)[0]

         assert expect.dtype == actual.dtype
-        assert np.allclose(
+        np.testing.assert_allclose(
             expect,
             actual,
             rtol=attrs.get_rtol("backward"),
@@ -120,5 +135,59 @@ class TestCompositeSoftmax(unittest.TestCase):
                     self.compare_backward()
+
+
+class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
+    "test composite softmax and prim backward"
+
+    def setUp(self):
+        core.set_prim_enabled(True)
+        self.dtypes = ["float32"]
+        self.shapes = [[2, 3, 4], [2, 3]]
+        self.axes = [-1, 0, 1]
+
+    def cal_composite_grad(self, inputs):
+        paddle.enable_static()
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program, startup_program):
+            x = paddle.static.data(
+                'x', shape=inputs.shape, dtype=str(inputs.dtype)
+            )
+            x.stop_gradient = False
+            y = fn(x)
+            blocks = main_program.blocks
+            paddle.incubate.autograd.to_prim(blocks)
+            z = paddle.static.gradients([y], x)
+        exe = paddle.static.Executor()
+        exe.run(startup_program)
+        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
+        paddle.disable_static()
+        return res
+
+    def compare_backward(self):
+        np_data = generate_data(attrs.shape)
+        tensor_data = paddle.to_tensor(np_data)
+
+        expect = expect_grad(tensor_data)[0].numpy()
+        actual = self.cal_composite_grad(np_data)[0]
+
+        assert expect.dtype == actual.dtype
+        np.testing.assert_allclose(
+            expect,
+            actual,
+            rtol=attrs.get_rtol("prim_backward"),
+            atol=attrs.get_rtol("prim_backward"),
+        )
+
+    def test_prim_backward(self):
+        for i in self.axes:
+            for j in self.dtypes:
+                for t in self.shapes:
+                    attrs.set_axis(i)
+                    attrs.set_dtype(j)
+                    attrs.set_shape(t)
+                    self.compare_backward()
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -15,11 +15,13 @@

 TOLERANCE = {
     "float32": {
-        "forward": {"rtol": 1e-6, "atol": 1e-6},
-        "backward": {"rtol": 1e-6, "atol": 1e-6},
-    },
-    "float64": {
         "forward": {"rtol": 1e-7, "atol": 1e-7},
         "backward": {"rtol": 1e-7, "atol": 1e-7},
+        "prim_backward": {"rtol": 1e-6, "atol": 1e-6},
+    },
+    "float64": {
+        "forward": {"rtol": 1e-16, "atol": 1e-16},
+        "backward": {"rtol": 1e-15, "atol": 1e-15},
+        "prim_backward": {"rtol": 1e-15, "atol": 1e-15},
     },
 }
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def apply_to_static(net, use_cinn):
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(net, build_strategy=build_strategy)


class PrimeNet(paddle.nn.Layer):
    def __init__(self):
        super(PrimeNet, self).__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        y = self.fc(x)
        out = F.softmax(y)
        return out


class TestPrimForward(unittest.TestCase):
    """
    This case only tests prim_forward + to_static + cinn, so we need to
    set this flag to False to avoid prim_backward:
    core.set_prim_backward(False)
    """

    def setUp(self):
        core.set_prim_backward(False)
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False

    def train(self, use_prim):
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
        if use_prim:
            net = apply_to_static(net, use_prim)

        res = []
        for _ in range(10):
            out = net(self.x)
            loss = paddle.mean(out)
            loss.backward()
            sgd.step()
            sgd.clear_grad()

            res.append(out.numpy())

        self.check_prim(net, use_prim)

        return res

    def check_prim(self, net, use_prim):
        if not use_prim:
            return
        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
        # Ensure that softmax is split into small ops
        self.assertTrue('softmax' not in fwd_ops)

    def test_cinn_prim_forward(self):
        dy_res = self.train(use_prim=False)
        cinn_res = self.train(use_prim=True)

        for i in range(len(dy_res)):
            np.testing.assert_allclose(
                cinn_res[i], dy_res[i], rtol=1e-7, atol=1e-7
            )


class TestPrimForwardAndBackward(unittest.TestCase):
    """
    Test PrimeNet with @to_static + prim forward + prim backward + cinn vs. Dygraph
    """

    def setUp(self):
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False

    def train(self, use_prim):
        core.set_prim_backward(True)
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
        if use_prim:
            net = apply_to_static(net, use_prim)

        res = []
        for _ in range(10):
            out = net(self.x)
            loss = paddle.mean(out)
            loss.backward()
            sgd.step()
            sgd.clear_grad()

            res.append(out.numpy())

        self.check_prim(net, use_prim)

        return res

    def check_prim(self, net, use_prim):
        if not use_prim:
            return
        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
        # Ensure that softmax is split into small ops
        self.assertTrue('softmax' not in fwd_ops)

    def test_cinn_prim(self):
        plat = platform.system()
        if plat == "Linux":
            dy_res = self.train(use_prim=False)
            cinn_res = self.train(use_prim=True)

            for i in range(len(dy_res)):
                np.testing.assert_allclose(
                    cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
                )
        else:
            pass


if __name__ == '__main__':
    unittest.main()
@@ -14,6 +14,7 @@

 import math
 import os
+import platform
 import tempfile
 import time
 import unittest
@@ -450,5 +451,67 @@ class TestResnet(unittest.TestCase):
             fluid.set_flags({'FLAGS_use_mkldnn': False})


+class TestResnetPrim(unittest.TestCase):
+    "test prim forward + prim backward + to_static"
+
+    def setUp(self):
+        self.resnet_helper = ResNetHelper()
+
+    def train(self, to_static):
+        paddle.jit.enable_to_static(to_static)
+        return self.resnet_helper.train(to_static)
+
+    def verify_predict(self):
+        image = np.random.random([1, 3, 224, 224]).astype('float32')
+        dy_pre = self.resnet_helper.predict_dygraph(image)
+        st_pre = self.resnet_helper.predict_static(image)
+        dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
+        predictor_pre = self.resnet_helper.predict_analysis_inference(image)
+        np.testing.assert_allclose(
+            dy_pre,
+            st_pre,
+            rtol=1e-05,
+            err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre),
+        )
+        np.testing.assert_allclose(
+            dy_jit_pre,
+            st_pre,
+            rtol=1e-05,
+            err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format(
+                dy_jit_pre, st_pre
+            ),
+        )
+        np.testing.assert_allclose(
+            predictor_pre,
+            st_pre,
+            rtol=1e-05,
+            err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format(
+                predictor_pre, st_pre
+            ),
+        )
+
+    def test_resnet_composite(self):
+        plat = platform.system()
+        if plat == "Linux":
+            print("=================== origin resnet ===================")
+            core.set_prim_enabled(False)
+            static_loss = self.train(to_static=True)
+            print("======= resnet with prim forward and backward =======")
+            core.set_prim_enabled(True)
+            core.set_prim_forward("debug")
+            dygraph_loss = self.train(to_static=True)
+            np.testing.assert_allclose(
+                static_loss,
+                dygraph_loss,
+                rtol=1e-02,
+                err_msg='static_loss: {} \n dygraph_loss: {}'.format(
+                    static_loss, dygraph_loss
+                ),
+            )
+            core.set_prim_enabled(False)
+        else:
+            pass
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -29,7 +29,9 @@ def _composite(op, *args):
 @REGISTER_COMPOSITE('softmax')
 def softmax_composite(x, axis):
     """define composite rule of op softmax"""
-    molecular = exp(x)
-    denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
+    max_temp = max(x, axis, keepdim=True)
+    max_temp.stop_gradient = True
+    molecular = exp(x - max_temp)
+    denominator = sum(molecular, axis=axis, keepdim=True)
     res = divide(molecular, denominator)
     return res
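The math behind the new rule can be checked outside Paddle. The NumPy sketch below is an illustration only, not Paddle code: subtracting the row-wise max is a constant shift that leaves the softmax value unchanged but keeps exp() from overflowing, and marking max_temp as stop_gradient simply keeps that constant out of the backward graph.

import numpy as np


def softmax_composite_np(x, axis=-1):
    # Same steps as the new rule: max -> subtract -> exp -> sum -> divide.
    max_temp = np.max(x, axis=axis, keepdims=True)
    molecular = np.exp(x - max_temp)
    denominator = np.sum(molecular, axis=axis, keepdims=True)
    return molecular / denominator


x = np.random.randn(2, 3).astype("float32")
naive = np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)  # the old, shift-free rule
np.testing.assert_allclose(softmax_composite_np(x), naive, rtol=1e-6)

big = x + 1000.0  # exp(1000) overflows, so the naive form degenerates to nan
assert np.isfinite(softmax_composite_np(big)).all()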
@@ -16,7 +16,7 @@ import logging
 import typing

 import paddle
-from paddle.fluid import backward, framework
+from paddle.fluid import backward, core, framework
 from paddle.incubate.autograd import primx, utils
@@ -218,13 +218,22 @@ def grad(outputs, inputs, grad_outputs=None):
 @framework.static_only
 def to_prim(blocks):
     """Search nonbasic ops which have registered composite rules and replace them with primitive ops."""
+    if not core.enable_prim_forward():
+        return
     if isinstance(blocks, paddle.fluid.framework.Block):
         logging.info("Atomize composite op to primitive ops begin.")
-        primx._lower_composite(blocks)
-        return
+        main_program = blocks.program
     elif isinstance(blocks, typing.Sequence):
         for item in blocks:
-            to_prim(item)
-        return
+            if not isinstance(item, paddle.fluid.framework.Block):
+                raise TypeError(
+                    f"Expect block or sequence of blocks, but sequence contains {type(item)}."
+                )
+        main_program = blocks[0].program
     else:
-        raise TypeError
+        raise TypeError(
+            f"Expect block or sequence of blocks, but got {type(blocks)}."
+        )
+    with framework.program_guard(main_program):
+        primx._lower_composite(blocks)
+    return
@@ -39,6 +39,8 @@ from paddle.tensor import log1p  # noqa: F401
 from paddle.tensor import logcumsumexp  # noqa: F401
 from paddle.tensor import logit  # noqa: F401
 from paddle.tensor import logsumexp  # noqa: F401
+from paddle.tensor import max  # noqa: F401
+from paddle.tensor import min  # noqa: F401
 from paddle.tensor import multiply  # noqa: F401
 from paddle.tensor import pow  # noqa: F401
 from paddle.tensor import prod  # noqa: F401
@@ -73,6 +75,8 @@ math_op = [
     'logsumexp',
     'logcumsumexp',
     'logit',
+    'max',
+    'min',
 ]

 trigonometric_op = [
...
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+import typing
 from collections import OrderedDict

 import paddle
@@ -575,6 +577,9 @@ def _lower_composite(block, blacklist=[]):
                 return_list.append(x)
             return return_list

+    if isinstance(block, paddle.fluid.framework.Block):
+        logging.info("Atomize composite op to primitive ops begin.")
+
     # Step1: Do some preparatory work for lower
     lower_fn = _composite
     lookup_fn = lookup_composite
@@ -659,6 +664,14 @@ def _lower_composite(block, blacklist=[]):
                 block.desc._remove_var(var_name.encode())
                 del block.vars[var_name]
         block._sync_with_cpp()
+        return
+
+    elif isinstance(block, typing.Sequence):
+        for item in block:
+            _lower_composite(item)
+        return
+    else:
+        raise TypeError


 @framework.static_only
...
@@ -571,6 +571,12 @@ class PartialProgramLayer:
                 targets.append(program.global_block().var(out.name))

         if targets:
+            enable_prim = self._build_strategy.build_cinn_pass
+            if enable_prim and core.enable_prim_backward():
+                core.set_prim_enabled(True)
+                backward.gradients(targets=targets, inputs=[])
+                core.set_prim_enabled(False)
+            else:
                 backward.gradients(targets=targets, inputs=[])

         start_idx = len(main_program.block(0).ops) + 2 * len(
...
@@ -18,7 +18,7 @@ import textwrap
 import threading
 import weakref

-from paddle.fluid import _non_static_mode, framework
+from paddle.fluid import _non_static_mode, core, framework
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import layers
 from paddle.fluid.dygraph.base import param_guard, switch_to_static_graph
@@ -930,6 +930,13 @@ class ConcreteProgram:
         self.function = function
         self.kwargs = kwargs

+    @switch_to_static_graph
+    def _to_prim(self):
+        # TODO(Aurelius84): Fix this cycle import problem
+        from paddle.incubate.autograd.primapi import to_prim
+
+        to_prim(self.main_program.blocks)
+
     @staticmethod
     @switch_to_static_graph
     def from_func_spec(
@@ -1083,6 +1090,11 @@ class ProgramCache:
         self._recent_cache_key = None

     def _build_once(self, cache_key):
+        # TODO(Aurelius84): Need a global FLAGS to enable/disable to_prim
+        enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
+        if enable_prim and core.enable_prim_backward():
+            core.set_prim_enabled(True)
+
         concrete_program = ConcreteProgram.from_func_spec(
             func_spec=cache_key.function_spec,
             input_spec=cache_key.input_args_with_spec,
@@ -1090,6 +1102,10 @@ class ProgramCache:
             class_instance=cache_key.class_instance,
             **cache_key.kwargs
         )
+        if enable_prim or core.enable_prim_forward() == "debug":
+            concrete_program._to_prim()
+        core.set_prim_enabled(False)
+
         return concrete_program, partial_program_from(concrete_program)

     def __getitem__(self, item):
...
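Read together with the flag helpers and the primapi.py change above, the dy2static hunks gate everything on build_strategy.build_cinn_pass plus the two flags. The following is a condensed paraphrase of the resulting flow in ProgramCache._build_once, assembled from the hunks above rather than copied from the implementation:

def _build_once(self, cache_key):
    # prim backward: only when CINN is requested and FLAGS_prim_backward != "false"
    enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
    if enable_prim and core.enable_prim_backward():
        core.set_prim_enabled(True)      # gradients are built with the prim backward rules

    concrete_program = ConcreteProgram.from_func_spec(...)  # builds the static program

    # prim forward: lower composite ops (e.g. softmax) to primitive ops,
    # either because CINN is on or because FLAGS_prim_forward == "debug"
    if enable_prim or core.enable_prim_forward() == "debug":
        concrete_program._to_prim()

    core.set_prim_enabled(False)
    return concrete_program, partial_program_from(concrete_program)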