Unverified commit b2a10916, authored by cyber-pioneer, committed by GitHub

Merge ops composite into to_static (#49836)

* support @to_static+to_prime+cinn

* fix code logic

* debug4

* debug5

* debug6

* debug7

* debug 8

* debug 9

* debug10

* debug11

* debug11

* debug 12
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Parent: 412573f0
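Taken together, the change wires composite-op lowering into the dynamic-to-static path: when a layer is converted with @to_static and the CINN build pass is enabled, non-primitive ops such as softmax are decomposed into primitive ops before the backward program is built, gated by the new FLAGS_prim_forward / FLAGS_prim_backward flags. A minimal sketch of the intended usage, mirroring the new test_cinn_prim.py added below (it assumes a Paddle build where CINN and these prim flags are available):

import paddle
import paddle.nn.functional as F


class PrimeNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        return F.softmax(self.fc(x))


def apply_to_static(net, use_cinn):
    # build_cinn_pass is the switch that triggers the prim lowering added here
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(net, build_strategy=build_strategy)


net = apply_to_static(PrimeNet(), use_cinn=True)
out = net(paddle.randn([2, 4]))  # softmax runs as decomposed primitive ops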
@@ -371,3 +371,37 @@ def set_paddle_lib_path():


 set_paddle_lib_path()
+
+
+def set_prim_forward(value):
+    """set flag FLAGS_prim_forward."""
+    flag = str(value)
+    if flag.lower() not in ["true", "false", "debug"]:
+        raise TypeError(f"flag {flag} should be string of bool or 'debug'.")
+    os.environ["FLAGS_prim_forward"] = flag
+    return
+
+
+def enable_prim_forward():
+    flag = os.getenv("FLAGS_prim_forward", "true").lower()
+    if flag == "false":
+        return False
+    if flag == "debug":
+        return "debug"
+    return True
+
+
+def set_prim_backward(value):
+    """set flag FLAGS_prim_backward,"""
+    flag = str(value)
+    if flag.lower() not in ["true", "false"]:
+        raise TypeError(f"flag {flag} should be bool or string of bool.")
+    os.environ["FLAGS_prim_backward"] = flag
+    return
+
+
+def enable_prim_backward():
+    flag = os.getenv("FLAGS_prim_backward", "true")
+    if flag.lower() == "false":
+        return False
+    return True
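The four helpers above are added next to set_paddle_lib_path(), so the rest of this commit reaches them through paddle.fluid.core (for example core.set_prim_forward("debug") in the ResNet test further down). A minimal sketch of their behaviour, following only the definitions above:

from paddle.fluid import core

core.set_prim_forward("debug")                # writes FLAGS_prim_forward="debug"
assert core.enable_prim_forward() == "debug"

core.set_prim_backward(False)                 # writes FLAGS_prim_backward="False"
assert core.enable_prim_backward() is False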
@@ -57,8 +57,7 @@ attrs = Attr()


 def fn(x):
-    y = paddle.tan(x)
-    return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype)
+    return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)


 def expect_forward(inputs):
@@ -81,8 +80,17 @@ class TestCompositeSoftmax(unittest.TestCase):
             )
             y = fn(x)
             blocks = main_program.blocks
+
+            fwd_ops = [op.type for op in blocks[0].ops]
+            # Ensure that softmax in original block
+            self.assertTrue('softmax' in fwd_ops)
+
             paddle.incubate.autograd.to_prim(blocks)
+
+            fwd_ops_new = [op.type for op in blocks[0].ops]
+            # Ensure that softmax is splitted into small ops
+            self.assertTrue('softmax' not in fwd_ops_new)

         exe = paddle.static.Executor()
         exe.run(startup_program)
         res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
@@ -97,7 +105,7 @@ class TestCompositeSoftmax(unittest.TestCase):
         actual = self.cal_composite(np_data)[0]

         assert expect.dtype == actual.dtype
-        assert np.allclose(
+        np.testing.assert_allclose(
             expect,
             actual,
             rtol=attrs.get_rtol("forward"),
......
@@ -19,6 +19,7 @@ from utils import TOLERANCE

 import paddle
 import paddle.nn.functional as F
+from paddle.fluid import core


 def generate_data(shape, dtype="float32"):
@@ -57,11 +58,11 @@ attrs = Attr()


 def fn(x):
-    y = paddle.tan(x)
-    return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype)
+    return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype)


 def expect_grad(inputs):
+    paddle.disable_static()
     inputs.stop_gradient = False
     res = fn(inputs)
@@ -86,8 +87,22 @@ class TestCompositeSoftmax(unittest.TestCase):
             x.stop_gradient = False
             y = fn(x)
             blocks = main_program.blocks
+
+            fwd_ops = [op.type for op in blocks[0].ops]
+            # Ensure that softmax in original block
+            self.assertTrue('softmax' in fwd_ops)
+
             paddle.incubate.autograd.to_prim(blocks)
+
+            fwd_ops_new = [op.type for op in blocks[0].ops]
+            # Ensure that softmax is splitted into small ops
+            self.assertTrue('softmax' not in fwd_ops_new)
+
             z = paddle.static.gradients([y], x)
+
+            fwd_ops_grad = [op.type for op in blocks[0].ops]
+            # Ensure that softmax_grad not in grad block
+            self.assertTrue('softmax_grad' not in fwd_ops_grad)

         exe = paddle.static.Executor()
         exe.run(startup_program)
@@ -103,7 +118,7 @@ class TestCompositeSoftmax(unittest.TestCase):
         actual = self.cal_composite_grad(np_data)[0]

         assert expect.dtype == actual.dtype
-        assert np.allclose(
+        np.testing.assert_allclose(
             expect,
             actual,
             rtol=attrs.get_rtol("backward"),
@@ -120,5 +135,59 @@ class TestCompositeSoftmax(unittest.TestCase):
                     self.compare_backward()


+class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
+    "test composite softmax and prim backward"
+
+    def setUp(self):
+        core.set_prim_enabled(True)
+        self.dtypes = ["float32"]
+        self.shapes = [[2, 3, 4], [2, 3]]
+        self.axes = [-1, 0, 1]
+
+    def cal_composite_grad(self, inputs):
+        paddle.enable_static()
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program, startup_program):
+            x = paddle.static.data(
+                'x', shape=inputs.shape, dtype=str(inputs.dtype)
+            )
+            x.stop_gradient = False
+            y = fn(x)
+            blocks = main_program.blocks
+            paddle.incubate.autograd.to_prim(blocks)
+            z = paddle.static.gradients([y], x)
+
+        exe = paddle.static.Executor()
+        exe.run(startup_program)
+        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
+        paddle.disable_static()
+        return res
+
+    def compare_backward(self):
+        np_data = generate_data(attrs.shape)
+        tensor_data = paddle.to_tensor(np_data)
+
+        expect = expect_grad(tensor_data)[0].numpy()
+        actual = self.cal_composite_grad(np_data)[0]
+
+        assert expect.dtype == actual.dtype
+        np.testing.assert_allclose(
+            expect,
+            actual,
+            rtol=attrs.get_rtol("prim_backward"),
+            atol=attrs.get_rtol("prim_backward"),
+        )
+
+    def test_prim_backward(self):
+        for i in self.axes:
+            for j in self.dtypes:
+                for t in self.shapes:
+                    attrs.set_axis(i)
+                    attrs.set_dtype(j)
+                    attrs.set_shape(t)
+                    self.compare_backward()
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -15,11 +15,13 @@
 TOLERANCE = {
     "float32": {
-        "forward": {"rtol": 1e-6, "atol": 1e-6},
-        "backward": {"rtol": 1e-6, "atol": 1e-6},
-    },
-    "float64": {
         "forward": {"rtol": 1e-7, "atol": 1e-7},
         "backward": {"rtol": 1e-7, "atol": 1e-7},
+        "prim_backward": {"rtol": 1e-6, "atol": 1e-6},
+    },
+    "float64": {
+        "forward": {"rtol": 1e-16, "atol": 1e-16},
+        "backward": {"rtol": 1e-15, "atol": 1e-15},
+        "prim_backward": {"rtol": 1e-15, "atol": 1e-15},
     },
 }
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle.fluid import core


def apply_to_static(net, use_cinn):
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(net, build_strategy=build_strategy)


class PrimeNet(paddle.nn.Layer):
    def __init__(self):
        super(PrimeNet, self).__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        y = self.fc(x)
        out = F.softmax(y)
        return out


class TestPrimForward(unittest.TestCase):
    """
    This case only tests prim_forward + to_static + cinn. Thus we need to
    set this flag as False to avoid prim_backward.
    core.set_prim_backward(False)
    """

    def setUp(self):
        core.set_prim_backward(False)
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False

    def train(self, use_prim):
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
        if use_prim:
            net = apply_to_static(net, use_prim)

        res = []
        for _ in range(10):
            out = net(self.x)
            loss = paddle.mean(out)
            loss.backward()
            sgd.step()
            sgd.clear_grad()

            res.append(out.numpy())

        self.check_prim(net, use_prim)

        return res

    def check_prim(self, net, use_prim):
        if not use_prim:
            return
        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
        # Ensure that softmax is splitted into small ops
        self.assertTrue('softmax' not in fwd_ops)

    def test_cinn_prim_forward(self):
        dy_res = self.train(use_prim=False)
        cinn_res = self.train(use_prim=True)

        for i in range(len(dy_res)):
            np.testing.assert_allclose(
                cinn_res[i], dy_res[i], rtol=1e-7, atol=1e-7
            )


class TestPrimForwardAndBackward(unittest.TestCase):
    """
    Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
    """

    def setUp(self):
        paddle.seed(2022)
        self.x = paddle.randn([2, 4])
        self.x.stop_gradient = False

    def train(self, use_prim):
        core.set_prim_backward(True)
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
        if use_prim:
            net = apply_to_static(net, use_prim)

        res = []
        for _ in range(10):
            out = net(self.x)
            loss = paddle.mean(out)
            loss.backward()
            sgd.step()
            sgd.clear_grad()

            res.append(out.numpy())

        self.check_prim(net, use_prim)

        return res

    def check_prim(self, net, use_prim):
        if not use_prim:
            return
        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
        # Ensure that softmax is splitted into small ops
        self.assertTrue('softmax' not in fwd_ops)

    def test_cinn_prim(self):
        plat = platform.system()
        if plat == "Linux":
            dy_res = self.train(use_prim=False)
            cinn_res = self.train(use_prim=True)

            for i in range(len(dy_res)):
                np.testing.assert_allclose(
                    cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6
                )
        else:
            pass


if __name__ == '__main__':
    unittest.main()
@@ -14,6 +14,7 @@

 import math
 import os
+import platform
 import tempfile
 import time
 import unittest
@@ -450,5 +451,67 @@ class TestResnet(unittest.TestCase):
             fluid.set_flags({'FLAGS_use_mkldnn': False})


+class TestResnetPrim(unittest.TestCase):
+    "test prim forward + prim backward + to_static"
+
+    def setUp(self):
+        self.resnet_helper = ResNetHelper()
+
+    def train(self, to_static):
+        paddle.jit.enable_to_static(to_static)
+        return self.resnet_helper.train(to_static)
+
+    def verify_predict(self):
+        image = np.random.random([1, 3, 224, 224]).astype('float32')
+        dy_pre = self.resnet_helper.predict_dygraph(image)
+        st_pre = self.resnet_helper.predict_static(image)
+        dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
+        predictor_pre = self.resnet_helper.predict_analysis_inference(image)
+        np.testing.assert_allclose(
+            dy_pre,
+            st_pre,
+            rtol=1e-05,
+            err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre),
+        )
+        np.testing.assert_allclose(
+            dy_jit_pre,
+            st_pre,
+            rtol=1e-05,
+            err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format(
+                dy_jit_pre, st_pre
+            ),
+        )
+        np.testing.assert_allclose(
+            predictor_pre,
+            st_pre,
+            rtol=1e-05,
+            err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format(
+                predictor_pre, st_pre
+            ),
+        )
+
+    def test_resnet_composite(self):
+        plat = platform.system()
+        if plat == "Linux":
+            print("=================== origin resnet ===================")
+            core.set_prim_enabled(False)
+            static_loss = self.train(to_static=True)
+            print("======= resnet with prim forward and backward =======")
+            core.set_prim_enabled(True)
+            core.set_prim_forward("debug")
+            dygraph_loss = self.train(to_static=True)
+            np.testing.assert_allclose(
+                static_loss,
+                dygraph_loss,
+                rtol=1e-02,
+                err_msg='static_loss: {} \n dygraph_loss: {}'.format(
+                    static_loss, dygraph_loss
+                ),
+            )
+            core.set_prim_enabled(False)
+        else:
+            pass
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -29,7 +29,9 @@ def _composite(op, *args):
 @REGISTER_COMPOSITE('softmax')
 def softmax_composite(x, axis):
     """define composite rule of op softmax"""
-    molecular = exp(x)
-    denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
+    max_temp = max(x, axis, keepdim=True)
+    max_temp.stop_gradient = True
+    molecular = exp(x - max_temp)
+    denominator = sum(molecular, axis=axis, keepdim=True)
     res = divide(molecular, denominator)
     return res
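The new rule subtracts the per-axis maximum before exponentiating, which is the standard numerically stable form of softmax and changes nothing mathematically; marking max_temp as stop_gradient keeps it out of the backward graph. A quick NumPy check of the identity the rule relies on (an illustrative sketch, not part of the commit):

import numpy as np

x = np.random.randn(2, 3).astype("float32")
axis = -1

max_temp = x.max(axis=axis, keepdims=True)          # constant w.r.t. the gradient
molecular = np.exp(x - max_temp)                    # exp(x - max) never overflows
res = molecular / molecular.sum(axis=axis, keepdims=True)

naive = np.exp(x) / np.exp(x).sum(axis=axis, keepdims=True)
np.testing.assert_allclose(res, naive, rtol=1e-6, atol=1e-6)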
@@ -16,7 +16,7 @@ import logging
 import typing

 import paddle
-from paddle.fluid import backward, framework
+from paddle.fluid import backward, core, framework
 from paddle.incubate.autograd import primx, utils
@@ -218,13 +218,22 @@ def grad(outputs, inputs, grad_outputs=None):
 @framework.static_only
 def to_prim(blocks):
     """Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
+    if not core.enable_prim_forward():
+        return
     if isinstance(blocks, paddle.fluid.framework.Block):
         logging.info("Atomize composite op to primitive ops begin.")
-        primx._lower_composite(blocks)
-        return
+        main_program = blocks.program
     elif isinstance(blocks, typing.Sequence):
         for item in blocks:
-            to_prim(item)
-        return
+            if not isinstance(item, paddle.fluid.framework.Block):
+                raise TypeError(
+                    f"Expect block or sequence of blocks, but sequence contains {type(item)}."
+                )
+        main_program = blocks[0].program
     else:
-        raise TypeError
+        raise TypeError(
+            f"Expect block or sequence of blocks, but got {type(blocks)}."
+        )
+    with framework.program_guard(main_program):
+        primx._lower_composite(blocks)
+    return
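After this change, to_prim bails out when enable_prim_forward() returns False, accepts either a Block or a sequence of Blocks, and runs primx._lower_composite under the owning program's guard. A minimal static-graph sketch of the call pattern, the same one the new softmax tests use (illustrative only):

import paddle
import paddle.nn.functional as F

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data('x', shape=[2, 3], dtype='float32')
    y = F.softmax(x)

# Replaces the softmax op with its registered composite of primitive ops, in place.
paddle.incubate.autograd.to_prim(main_program.blocks)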
@@ -39,6 +39,8 @@ from paddle.tensor import log1p  # noqa: F401
 from paddle.tensor import logcumsumexp  # noqa: F401
 from paddle.tensor import logit  # noqa: F401
 from paddle.tensor import logsumexp  # noqa: F401
+from paddle.tensor import max  # noqa: F401
+from paddle.tensor import min  # noqa: F401
 from paddle.tensor import multiply  # noqa: F401
 from paddle.tensor import pow  # noqa: F401
 from paddle.tensor import prod  # noqa: F401
@@ -73,6 +75,8 @@ math_op = [
     'logsumexp',
     'logcumsumexp',
     'logit',
+    'max',
+    'min',
 ]

 trigonometric_op = [
......
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+import typing
 from collections import OrderedDict

 import paddle
@@ -575,90 +577,101 @@ def _lower_composite(block, blacklist=[]):
                return_list.append(x)
        return return_list

    if isinstance(block, paddle.fluid.framework.Block):
        logging.info("Atomize composite op to primitive ops begin.")

        # Step1: Do some preparatory work for lower
        lower_fn = _composite
        lookup_fn = lookup_composite

        value_table = {}
        to_bind = {}
        to_bind_rev = {}
        for var in block.desc.all_vars():
            value_table[var.name()] = block.var(var.name())

        ops_to_remove = []
        vars_to_remove = set()

        # Step2: Process all ops in the target block
        for op_idx in range(len(block.ops)):
            op = block.ops[op_idx]
            ops_to_remove.append(op_idx)
            if lookup_fn(op.type) is not None and op.type not in blacklist:
                input_args = prepare_python_api_arguments(op)
                bind(input_args, to_bind, value_table)

                for orig_out, new_out in zip(
                    expand_nested_list(get_output_var_list(op)),
                    expand_nested_list(as_tensors(lower_fn(op, *input_args))),
                ):
                    assert not (orig_out is None) ^ (
                        new_out is None
                    ), "orig_out and new_out should match."
                    vars_to_remove.add(new_out.name)
                    value_table[new_out.name] = new_out
                    to_bind[orig_out.name] = new_out.name
                    to_bind_rev[new_out.name] = orig_out.name
            else:
                inputs = {}
                for i in range(len(op.input_names)):
                    inputs[op.input_names[i]] = bind_name(
                        op.input(op.input_names[i]), to_bind
                    )

                outputs = {}
                for i in range(len(op.output_names)):
                    outputs[op.output_names[i]] = op.output(op.output_names[i])

                attrs = {}
                for name in sorted(op.attr_names):
                    attrs[name] = op.attr(name)
                from paddle.fluid.dygraph.base import param_guard

                new_op_desc = block.desc.append_op()
                with param_guard(inputs), param_guard(outputs):
                    op = Operator(
                        block=block,
                        desc=new_op_desc,
                        type=op.type,
                        inputs=inputs,
                        outputs=outputs,
                        attrs=attrs,
                    )
                block.ops.append(op)

        # Step3: Do some post-processing work
        for op_idx in reversed(ops_to_remove):
            block.desc._remove_op(op_idx, op_idx + 1)
            del block.ops[op_idx]
        block._sync_with_cpp()

        for op_idx in range(len(block.ops)):
            op = block.ops[op_idx]
            for in_name in op.input_arg_names:
                if in_name in to_bind_rev:
                    op._rename_input(in_name, to_bind_rev[in_name])

            for out_name in op.output_arg_names:
                if out_name in to_bind_rev:
                    op._rename_output(out_name, to_bind_rev[out_name])

        for var_name in sorted(vars_to_remove):
            assert (
                var_name in to_bind_rev
            ), 'var_name "{}" is not in to_bind_rev.'.format(var_name)
            if var_name != to_bind_rev[var_name]:
                block.desc._remove_var(var_name.encode())
                del block.vars[var_name]
        block._sync_with_cpp()
        return

    elif isinstance(block, typing.Sequence):
        for item in block:
            _lower_composite(item)
        return
    else:
        raise TypeError
@framework.static_only
......
@@ -571,7 +571,13 @@ class PartialProgramLayer:
             targets.append(program.global_block().var(out.name))

         if targets:
-            backward.gradients(targets=targets, inputs=[])
+            enable_prim = self._build_strategy.build_cinn_pass
+            if enable_prim and core.enable_prim_backward():
+                core.set_prim_enabled(True)
+                backward.gradients(targets=targets, inputs=[])
+                core.set_prim_enabled(False)
+            else:
+                backward.gradients(targets=targets, inputs=[])

         start_idx = len(main_program.block(0).ops) + 2 * len(
             self._outputs.tolist()
......
@@ -18,7 +18,7 @@ import textwrap
 import threading
 import weakref

-from paddle.fluid import _non_static_mode, framework
+from paddle.fluid import _non_static_mode, core, framework
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import layers
 from paddle.fluid.dygraph.base import param_guard, switch_to_static_graph
@@ -930,6 +930,13 @@ class ConcreteProgram:
         self.function = function
         self.kwargs = kwargs

+    @switch_to_static_graph
+    def _to_prim(self):
+        # TODO(Aurelius84): Fix this cycle import problem
+        from paddle.incubate.autograd.primapi import to_prim
+
+        to_prim(self.main_program.blocks)
+
     @staticmethod
     @switch_to_static_graph
     def from_func_spec(
@@ -1083,6 +1090,11 @@ class ProgramCache:
         self._recent_cache_key = None

     def _build_once(self, cache_key):
+        # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim
+        enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
+        if enable_prim and core.enable_prim_backward():
+            core.set_prim_enabled(True)
+
         concrete_program = ConcreteProgram.from_func_spec(
             func_spec=cache_key.function_spec,
             input_spec=cache_key.input_args_with_spec,
@@ -1090,6 +1102,10 @@ class ProgramCache:
             class_instance=cache_key.class_instance,
             **cache_key.kwargs
         )
+        if enable_prim or core.enable_prim_forward() == "debug":
+            concrete_program._to_prim()
+        core.set_prim_enabled(False)
+
         return concrete_program, partial_program_from(concrete_program)

     def __getitem__(self, item):
......