From b2a1091663e279fc806e8c83ef4f7bb05e75fc83 Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Tue, 17 Jan 2023 14:36:33 +0800 Subject: [PATCH] Merge ops composite into to_static (#49836) * support @to_static+to_prime+cinn * fix code logic * debug4 * debug5 * debug6 * debug7 * debug 8 * debug 9 * debug10 * debug11 * debug11 * debug 12 Co-authored-by: Aurelius84 --- python/paddle/fluid/core.py | 34 ++++ .../composite_ops/test_composite_softmax.py | 14 +- .../test_composite_softmax_grad.py | 75 +++++++- .../tests/unittests/composite_ops/utils.py | 10 +- .../dygraph_to_static/test_cinn_prim.py | 151 +++++++++++++++ .../dygraph_to_static/test_resnet.py | 63 ++++++ .../incubate/autograd/composite_rules.py | 6 +- python/paddle/incubate/autograd/primapi.py | 21 +- python/paddle/incubate/autograd/primitives.py | 4 + python/paddle/incubate/autograd/primx.py | 179 ++++++++++-------- .../paddle/jit/dy2static/partial_program.py | 8 +- .../jit/dy2static/program_translator.py | 18 +- 12 files changed, 480 insertions(+), 103 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 299ca1e3661..09e079ca583 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -371,3 +371,37 @@ def set_paddle_lib_path(): set_paddle_lib_path() + + +def set_prim_forward(value): + """set flag FLAGS_prim_forward.""" + flag = str(value) + if flag.lower() not in ["true", "false", "debug"]: + raise TypeError(f"flag {flag} should be string of bool or 'debug'.") + os.environ["FLAGS_prim_forward"] = flag + return + + +def enable_prim_forward(): + flag = os.getenv("FLAGS_prim_forward", "true").lower() + if flag == "false": + return False + if flag == "debug": + return "debug" + return True + + +def set_prim_backward(value): + """set flag FLAGS_prim_backward,""" + flag = str(value) + if flag.lower() not in ["true", "false"]: + raise TypeError(f"flag {flag} should be bool or string of bool.") + os.environ["FLAGS_prim_backward"] = flag + return + + +def enable_prim_backward(): + flag = os.getenv("FLAGS_prim_backward", "true") + if flag.lower() == "false": + return False + return True diff --git a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py index 2ac671962c0..c7c876b8f8f 100644 --- a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py +++ b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax.py @@ -57,8 +57,7 @@ attrs = Attr() def fn(x): - y = paddle.tan(x) - return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype) + return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype) def expect_forward(inputs): @@ -81,8 +80,17 @@ class TestCompositeSoftmax(unittest.TestCase): ) y = fn(x) blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that softmax in original block + self.assertTrue('softmax' in fwd_ops) + paddle.incubate.autograd.to_prim(blocks) + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that softmax is splitted into small ops + self.assertTrue('softmax' not in fwd_ops_new) + exe = paddle.static.Executor() exe.run(startup_program) res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y]) @@ -97,7 +105,7 @@ class TestCompositeSoftmax(unittest.TestCase): actual = self.cal_composite(np_data)[0] assert expect.dtype == actual.dtype - assert np.allclose( + np.testing.assert_allclose( expect, actual, rtol=attrs.get_rtol("forward"), diff --git a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py index c47399ba5a9..808c5f8324b 100644 --- a/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py +++ b/python/paddle/fluid/tests/unittests/composite_ops/test_composite_softmax_grad.py @@ -19,6 +19,7 @@ from utils import TOLERANCE import paddle import paddle.nn.functional as F +from paddle.fluid import core def generate_data(shape, dtype="float32"): @@ -57,11 +58,11 @@ attrs = Attr() def fn(x): - y = paddle.tan(x) - return F.softmax(y, axis=attrs.axis, dtype=attrs.dtype) + return F.softmax(x, axis=attrs.axis, dtype=attrs.dtype) def expect_grad(inputs): + paddle.disable_static() inputs.stop_gradient = False res = fn(inputs) @@ -86,8 +87,22 @@ class TestCompositeSoftmax(unittest.TestCase): x.stop_gradient = False y = fn(x) blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that softmax in original block + self.assertTrue('softmax' in fwd_ops) + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that softmax is splitted into small ops + self.assertTrue('softmax' not in fwd_ops_new) + z = paddle.static.gradients([y], x) + fwd_ops_grad = [op.type for op in blocks[0].ops] + # Ensure that softmax_grad not in grad block + + self.assertTrue('softmax_grad' not in fwd_ops_grad) exe = paddle.static.Executor() exe.run(startup_program) @@ -103,7 +118,7 @@ class TestCompositeSoftmax(unittest.TestCase): actual = self.cal_composite_grad(np_data)[0] assert expect.dtype == actual.dtype - assert np.allclose( + np.testing.assert_allclose( expect, actual, rtol=attrs.get_rtol("backward"), @@ -120,5 +135,59 @@ class TestCompositeSoftmax(unittest.TestCase): self.compare_backward() +class TestCompositeSoftmaxPrimBackward(unittest.TestCase): + "test composite softmax and prim backward" + + def setUp(self): + core.set_prim_enabled(True) + self.dtypes = ["float32"] + self.shapes = [[2, 3, 4], [2, 3]] + self.axes = [-1, 0, 1] + + def cal_composite_grad(self, inputs): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + y = fn(x) + blocks = main_program.blocks + paddle.incubate.autograd.to_prim(blocks) + z = paddle.static.gradients([y], x) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) + paddle.disable_static() + return res + + def compare_backward(self): + np_data = generate_data(attrs.shape) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_grad(tensor_data)[0].numpy() + actual = self.cal_composite_grad(np_data)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("prim_backward"), + atol=attrs.get_rtol("prim_backward"), + ) + + def test_prim_backward(self): + for i in self.axes: + for j in self.dtypes: + for t in self.shapes: + attrs.set_axis(i) + attrs.set_dtype(j) + attrs.set_shape(t) + self.compare_backward() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/composite_ops/utils.py b/python/paddle/fluid/tests/unittests/composite_ops/utils.py index c43f79d1c05..798da50a1c4 100644 --- a/python/paddle/fluid/tests/unittests/composite_ops/utils.py +++ b/python/paddle/fluid/tests/unittests/composite_ops/utils.py @@ -15,11 +15,13 @@ TOLERANCE = { "float32": { - "forward": {"rtol": 1e-6, "atol": 1e-6}, - "backward": {"rtol": 1e-6, "atol": 1e-6}, - }, - "float64": { "forward": {"rtol": 1e-7, "atol": 1e-7}, "backward": {"rtol": 1e-7, "atol": 1e-7}, + "prim_backward": {"rtol": 1e-6, "atol": 1e-6}, + }, + "float64": { + "forward": {"rtol": 1e-16, "atol": 1e-16}, + "backward": {"rtol": 1e-15, "atol": 1e-15}, + "prim_backward": {"rtol": 1e-15, "atol": 1e-15}, }, } diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py new file mode 100644 index 00000000000..2811a348f46 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py @@ -0,0 +1,151 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle.fluid import core + + +def apply_to_static(net, use_cinn): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static(net, build_strategy=build_strategy) + + +class PrimeNet(paddle.nn.Layer): + def __init__(self): + super(PrimeNet, self).__init__() + self.fc = paddle.nn.Linear(4, 4) + + def forward(self, x): + y = self.fc(x) + out = F.softmax(y) + return out + + +class TestPrimForward(unittest.TestCase): + """ + This case only tests prim_forward + to_static + cinn. Thus we need to + set this flag as False to avoid prim_backward. + core.set_prim_backward(False) + """ + + def setUp(self): + core.set_prim_backward(False) + paddle.seed(2022) + self.x = paddle.randn([2, 4]) + self.x.stop_gradient = False + + def train(self, use_prim): + paddle.seed(2022) + net = PrimeNet() + sgd = paddle.optimizer.SGD( + learning_rate=0.1, parameters=net.parameters() + ) + if use_prim: + net = apply_to_static(net, use_prim) + + res = [] + for _ in range(10): + out = net(self.x) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_grad() + + res.append(out.numpy()) + + self.check_prim(net, use_prim) + + return res + + def check_prim(self, net, use_prim): + if not use_prim: + return + fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] + # Ensure that softmax is splitted into small ops + self.assertTrue('softmax' not in fwd_ops) + + def test_cinn_prim_forward(self): + dy_res = self.train(use_prim=False) + cinn_res = self.train(use_prim=True) + + for i in range(len(dy_res)): + np.testing.assert_allclose( + cinn_res[i], dy_res[i], rtol=1e-7, atol=1e-7 + ) + + +class TestPrimForwardAndBackward(unittest.TestCase): + """ + Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph + """ + + def setUp(self): + paddle.seed(2022) + self.x = paddle.randn([2, 4]) + self.x.stop_gradient = False + + def train(self, use_prim): + core.set_prim_backward(True) + paddle.seed(2022) + net = PrimeNet() + sgd = paddle.optimizer.SGD( + learning_rate=0.1, parameters=net.parameters() + ) + if use_prim: + net = apply_to_static(net, use_prim) + + res = [] + for _ in range(10): + out = net(self.x) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_grad() + + res.append(out.numpy()) + + self.check_prim(net, use_prim) + + return res + + def check_prim(self, net, use_prim): + if not use_prim: + return + fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] + # Ensure that softmax is splitted into small ops + self.assertTrue('softmax' not in fwd_ops) + + def test_cinn_prim(self): + plat = platform.system() + if plat == "Linux": + dy_res = self.train(use_prim=False) + cinn_res = self.train(use_prim=True) + + for i in range(len(dy_res)): + np.testing.assert_allclose( + cinn_res[i], dy_res[i], rtol=1e-6, atol=1e-6 + ) + else: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py index 40919edbce6..b195c7d342a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py @@ -14,6 +14,7 @@ import math import os +import platform import tempfile import time import unittest @@ -450,5 +451,67 @@ class TestResnet(unittest.TestCase): fluid.set_flags({'FLAGS_use_mkldnn': False}) +class TestResnetPrim(unittest.TestCase): + "test prim forward + prim backward + to_static" + + def setUp(self): + self.resnet_helper = ResNetHelper() + + def train(self, to_static): + paddle.jit.enable_to_static(to_static) + return self.resnet_helper.train(to_static) + + def verify_predict(self): + image = np.random.random([1, 3, 224, 224]).astype('float32') + dy_pre = self.resnet_helper.predict_dygraph(image) + st_pre = self.resnet_helper.predict_static(image) + dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image) + predictor_pre = self.resnet_helper.predict_analysis_inference(image) + np.testing.assert_allclose( + dy_pre, + st_pre, + rtol=1e-05, + err_msg='dy_pre:\n {}\n, st_pre: \n{}.'.format(dy_pre, st_pre), + ) + np.testing.assert_allclose( + dy_jit_pre, + st_pre, + rtol=1e-05, + err_msg='dy_jit_pre:\n {}\n, st_pre: \n{}.'.format( + dy_jit_pre, st_pre + ), + ) + np.testing.assert_allclose( + predictor_pre, + st_pre, + rtol=1e-05, + err_msg='predictor_pre:\n {}\n, st_pre: \n{}.'.format( + predictor_pre, st_pre + ), + ) + + def test_resnet_composite(self): + plat = platform.system() + if plat == "Linux": + print("=================== origin resnet ===================") + core.set_prim_enabled(False) + static_loss = self.train(to_static=True) + print("======= resnet with prim forward and backward =======") + core.set_prim_enabled(True) + core.set_prim_forward("debug") + dygraph_loss = self.train(to_static=True) + np.testing.assert_allclose( + static_loss, + dygraph_loss, + rtol=1e-02, + err_msg='static_loss: {} \n dygraph_loss: {}'.format( + static_loss, dygraph_loss + ), + ) + core.set_prim_enabled(False) + else: + pass + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 456ac20db2e..23bf8f0f7e3 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -29,7 +29,9 @@ def _composite(op, *args): @REGISTER_COMPOSITE('softmax') def softmax_composite(x, axis): """define composite rule of op softmax""" - molecular = exp(x) - denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape) + max_temp = max(x, axis, keepdim=True) + max_temp.stop_gradient = True + molecular = exp(x - max_temp) + denominator = sum(molecular, axis=axis, keepdim=True) res = divide(molecular, denominator) return res diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py index 7cfabdd9e55..76e08021942 100644 --- a/python/paddle/incubate/autograd/primapi.py +++ b/python/paddle/incubate/autograd/primapi.py @@ -16,7 +16,7 @@ import logging import typing import paddle -from paddle.fluid import backward, framework +from paddle.fluid import backward, core, framework from paddle.incubate.autograd import primx, utils @@ -218,13 +218,22 @@ def grad(outputs, inputs, grad_outputs=None): @framework.static_only def to_prim(blocks): """Search nonbasic ops which have be registered composite rules and replace them with primitive ops.""" + if not core.enable_prim_forward(): + return if isinstance(blocks, paddle.fluid.framework.Block): logging.info("Atomize composite op to primitive ops begin.") - primx._lower_composite(blocks) - return + main_program = blocks.program elif isinstance(blocks, typing.Sequence): for item in blocks: - to_prim(item) - return + if not isinstance(item, paddle.fluid.framework.Block): + raise TypeError( + f"Expect block or sequence of blocks, but sequence contains {type(item)}." + ) + main_program = blocks[0].program else: - raise TypeError + raise TypeError( + f"Expect block or sequence of blocks, but got {type(blocks)}." + ) + with framework.program_guard(main_program): + primx._lower_composite(blocks) + return diff --git a/python/paddle/incubate/autograd/primitives.py b/python/paddle/incubate/autograd/primitives.py index 371746bf349..a9ec324c05a 100644 --- a/python/paddle/incubate/autograd/primitives.py +++ b/python/paddle/incubate/autograd/primitives.py @@ -39,6 +39,8 @@ from paddle.tensor import log1p # noqa: F401 from paddle.tensor import logcumsumexp # noqa: F401 from paddle.tensor import logit # noqa: F401 from paddle.tensor import logsumexp # noqa: F401 +from paddle.tensor import max # noqa: F401 +from paddle.tensor import min # noqa: F401 from paddle.tensor import multiply # noqa: F401 from paddle.tensor import pow # noqa: F401 from paddle.tensor import prod # noqa: F401 @@ -73,6 +75,8 @@ math_op = [ 'logsumexp', 'logcumsumexp', 'logit', + 'max', + 'min', ] trigonometric_op = [ diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 6f2d4d9d521..c472137ab71 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import typing from collections import OrderedDict import paddle @@ -575,90 +577,101 @@ def _lower_composite(block, blacklist=[]): return_list.append(x) return return_list - # Step1: Do some preparatory work for lower - lower_fn = _composite - lookup_fn = lookup_composite - - value_table = {} - to_bind = {} - to_bind_rev = {} - for var in block.desc.all_vars(): - value_table[var.name()] = block.var(var.name()) - - ops_to_remove = [] - vars_to_remove = set() - - # Step2: Process all ops in the target block - for op_idx in range(len(block.ops)): - op = block.ops[op_idx] - ops_to_remove.append(op_idx) - if lookup_fn(op.type) is not None and op.type not in blacklist: - input_args = prepare_python_api_arguments(op) - bind(input_args, to_bind, value_table) - - for orig_out, new_out in zip( - expand_nested_list(get_output_var_list(op)), - expand_nested_list(as_tensors(lower_fn(op, *input_args))), - ): - assert not (orig_out is None) ^ ( - new_out is None - ), "orig_out and new_out should match." - vars_to_remove.add(new_out.name) - value_table[new_out.name] = new_out - to_bind[orig_out.name] = new_out.name - to_bind_rev[new_out.name] = orig_out.name - else: - inputs = {} - for i in range(len(op.input_names)): - inputs[op.input_names[i]] = bind_name( - op.input(op.input_names[i]), to_bind - ) - - outputs = {} - for i in range(len(op.output_names)): - outputs[op.output_names[i]] = op.output(op.output_names[i]) - - attrs = {} - for name in sorted(op.attr_names): - attrs[name] = op.attr(name) - from paddle.fluid.dygraph.base import param_guard - - new_op_desc = block.desc.append_op() - with param_guard(inputs), param_guard(outputs): - op = Operator( - block=block, - desc=new_op_desc, - type=op.type, - inputs=inputs, - outputs=outputs, - attrs=attrs, - ) - block.ops.append(op) - - # Step3: Do some post-processing work - for op_idx in reversed(ops_to_remove): - block.desc._remove_op(op_idx, op_idx + 1) - del block.ops[op_idx] - block._sync_with_cpp() - - for op_idx in range(len(block.ops)): - op = block.ops[op_idx] - for in_name in op.input_arg_names: - if in_name in to_bind_rev: - op._rename_input(in_name, to_bind_rev[in_name]) - - for out_name in op.output_arg_names: - if out_name in to_bind_rev: - op._rename_output(out_name, to_bind_rev[out_name]) + if isinstance(block, paddle.fluid.framework.Block): + logging.info("Atomize composite op to primitive ops begin.") + + # Step1: Do some preparatory work for lower + lower_fn = _composite + lookup_fn = lookup_composite + + value_table = {} + to_bind = {} + to_bind_rev = {} + for var in block.desc.all_vars(): + value_table[var.name()] = block.var(var.name()) + + ops_to_remove = [] + vars_to_remove = set() + + # Step2: Process all ops in the target block + for op_idx in range(len(block.ops)): + op = block.ops[op_idx] + ops_to_remove.append(op_idx) + if lookup_fn(op.type) is not None and op.type not in blacklist: + input_args = prepare_python_api_arguments(op) + bind(input_args, to_bind, value_table) + + for orig_out, new_out in zip( + expand_nested_list(get_output_var_list(op)), + expand_nested_list(as_tensors(lower_fn(op, *input_args))), + ): + assert not (orig_out is None) ^ ( + new_out is None + ), "orig_out and new_out should match." + vars_to_remove.add(new_out.name) + value_table[new_out.name] = new_out + to_bind[orig_out.name] = new_out.name + to_bind_rev[new_out.name] = orig_out.name + else: + inputs = {} + for i in range(len(op.input_names)): + inputs[op.input_names[i]] = bind_name( + op.input(op.input_names[i]), to_bind + ) + + outputs = {} + for i in range(len(op.output_names)): + outputs[op.output_names[i]] = op.output(op.output_names[i]) + + attrs = {} + for name in sorted(op.attr_names): + attrs[name] = op.attr(name) + from paddle.fluid.dygraph.base import param_guard + + new_op_desc = block.desc.append_op() + with param_guard(inputs), param_guard(outputs): + op = Operator( + block=block, + desc=new_op_desc, + type=op.type, + inputs=inputs, + outputs=outputs, + attrs=attrs, + ) + block.ops.append(op) + + # Step3: Do some post-processing work + for op_idx in reversed(ops_to_remove): + block.desc._remove_op(op_idx, op_idx + 1) + del block.ops[op_idx] + block._sync_with_cpp() - for var_name in sorted(vars_to_remove): - assert ( - var_name in to_bind_rev - ), 'var_name "{}" is not in to_bind_rev.'.format(var_name) - if var_name != to_bind_rev[var_name]: - block.desc._remove_var(var_name.encode()) - del block.vars[var_name] - block._sync_with_cpp() + for op_idx in range(len(block.ops)): + op = block.ops[op_idx] + for in_name in op.input_arg_names: + if in_name in to_bind_rev: + op._rename_input(in_name, to_bind_rev[in_name]) + + for out_name in op.output_arg_names: + if out_name in to_bind_rev: + op._rename_output(out_name, to_bind_rev[out_name]) + + for var_name in sorted(vars_to_remove): + assert ( + var_name in to_bind_rev + ), 'var_name "{}" is not in to_bind_rev.'.format(var_name) + if var_name != to_bind_rev[var_name]: + block.desc._remove_var(var_name.encode()) + del block.vars[var_name] + block._sync_with_cpp() + return + + elif isinstance(block, typing.Sequence): + for item in block: + _lower_composite(item) + return + else: + raise TypeError @framework.static_only diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index 44478604781..701765cc731 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -571,7 +571,13 @@ class PartialProgramLayer: targets.append(program.global_block().var(out.name)) if targets: - backward.gradients(targets=targets, inputs=[]) + enable_prim = self._build_strategy.build_cinn_pass + if enable_prim and core.enable_prim_backward(): + core.set_prim_enabled(True) + backward.gradients(targets=targets, inputs=[]) + core.set_prim_enabled(False) + else: + backward.gradients(targets=targets, inputs=[]) start_idx = len(main_program.block(0).ops) + 2 * len( self._outputs.tolist() diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 7fd6b0ce7fe..5b8493977e9 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -18,7 +18,7 @@ import textwrap import threading import weakref -from paddle.fluid import _non_static_mode, framework +from paddle.fluid import _non_static_mode, core, framework from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph import layers from paddle.fluid.dygraph.base import param_guard, switch_to_static_graph @@ -930,6 +930,13 @@ class ConcreteProgram: self.function = function self.kwargs = kwargs + @switch_to_static_graph + def _to_prim(self): + # TODO(Aurelius84): Fix this cycle import problem + from paddle.incubate.autograd.primapi import to_prim + + to_prim(self.main_program.blocks) + @staticmethod @switch_to_static_graph def from_func_spec( @@ -1083,6 +1090,11 @@ class ProgramCache: self._recent_cache_key = None def _build_once(self, cache_key): + # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim + enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass + if enable_prim and core.enable_prim_backward(): + core.set_prim_enabled(True) + concrete_program = ConcreteProgram.from_func_spec( func_spec=cache_key.function_spec, input_spec=cache_key.input_args_with_spec, @@ -1090,6 +1102,10 @@ class ProgramCache: class_instance=cache_key.class_instance, **cache_key.kwargs ) + + if enable_prim or core.enable_prim_forward() == "debug": + concrete_program._to_prim() + core.set_prim_enabled(False) return concrete_program, partial_program_from(concrete_program) def __getitem__(self, item): -- GitLab