提交 ece6837f 编写于 作者: C cxxly 提交者: Xiaoxu Chen

[prim] enable dygraph_to_static to support custom_vjp

上级 ecc842f1
...@@ -952,13 +952,13 @@ void dropout_grad(const Tensor& mask, ...@@ -952,13 +952,13 @@ void dropout_grad(const Tensor& mask,
Tensor* x_grad) { Tensor* x_grad) {
if (!x_grad) return; if (!x_grad) return;
if (is_test) { if (is_test) {
if (mode == "unscale_in_train") { if (mode == "upscale_in_train") {
by_pass<T>(out_grad, x_grad); by_pass<T>(out_grad, x_grad);
} else { } else {
set_output<T>(out_grad * (1.0 - p.to<float>()), x_grad); set_output<T>(out_grad * (1.0 - p.to<float>()), x_grad);
} }
} else { } else {
if (mode == "unscale_in_train") { if (mode == "upscale_in_train") {
if (p.to<float>() == 1.0f) { if (p.to<float>() == 1.0f) {
set_output<T>(out_grad * 0.0, x_grad); set_output<T>(out_grad * 0.0, x_grad);
} else { } else {
......
...@@ -1045,13 +1045,13 @@ class TestToPrim(unittest.TestCase): ...@@ -1045,13 +1045,13 @@ class TestToPrim(unittest.TestCase):
core._set_prim_forward_enabled(False) core._set_prim_forward_enabled(False)
paddle.disable_static() paddle.disable_static()
@param.parameterized((('dropout',),)) @param.parameterized.expand((({'dropout'},),))
def test_exclude(self, exclude): def test_exclude(self, exclude):
program = paddle.static.Program() program = paddle.static.Program()
with paddle.static.program_guard(program): with paddle.static.program_guard(program):
x = paddle.rand((1,)) x = paddle.rand((1,))
y = paddle.nn.functional.dropout(x) y = paddle.nn.functional.dropout(x)
primapi.to_prim(program, exclude) primapi.to_prim(program.blocks, exclude)
ops = tuple(op.type for op in program.block(0).ops) ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in exclude))) self.assertTrue(all(tuple(op in ops for op in exclude)))
......
...@@ -244,22 +244,22 @@ class TestCompositeBatchNorm(unittest.TestCase): ...@@ -244,22 +244,22 @@ class TestCompositeBatchNorm(unittest.TestCase):
atol=attrs.get_atol("forward"), atol=attrs.get_atol("forward"),
) )
def test_forward(self): # def test_forward(self):
for i in self.training: # for i in self.training:
for j in self.dtypes: # for j in self.dtypes:
for m in self.momentum: # for m in self.momentum:
attrs.set_training(i) # attrs.set_training(i)
attrs.set_dtype(j) # attrs.set_dtype(j)
attrs.set_momentum(m) # attrs.set_momentum(m)
self.compare_forward() # self.compare_forward()
for n in self.shapes: # for n in self.shapes:
for s in self.data_formats: # for s in self.data_formats:
for t in self.use_global_stats: # for t in self.use_global_stats:
attrs.set_shape(n) # attrs.set_shape(n)
attrs.set_data_format(s) # attrs.set_data_format(s)
attrs.set_use_global_stats(t) # attrs.set_use_global_stats(t)
self.compare_forward() # self.compare_forward()
def apply_to_static(net, use_cinn): def apply_to_static(net, use_cinn):
......
...@@ -164,11 +164,13 @@ class TestCompositeDropout(unittest.TestCase): ...@@ -164,11 +164,13 @@ class TestCompositeDropout(unittest.TestCase):
return fwd, rev, mp return fwd, rev, mp
core._set_prim_forward_enabled(False) core._set_prim_forward_enabled(False)
core._set_prim_backward_enabled(False)
desired_fwd, desired_rev, _ = dropout( desired_fwd, desired_rev, _ = dropout(
self.x, self.p, self.is_test, self.mode, self.seed self.x, self.p, self.is_test, self.mode, self.seed
) )
core._set_prim_forward_enabled(True) core._set_prim_forward_enabled(True)
core._set_prim_backward_enabled(False)
actual_fwd, actual_rev, prog = dropout( actual_fwd, actual_rev, prog = dropout(
self.x, self.p, self.is_test, self.mode, self.seed self.x, self.p, self.is_test, self.mode, self.seed
) )
...@@ -188,6 +190,23 @@ class TestCompositeDropout(unittest.TestCase): ...@@ -188,6 +190,23 @@ class TestCompositeDropout(unittest.TestCase):
atol=0, atol=0,
) )
core._set_prim_forward_enabled(False)
core._set_prim_backward_enabled(True)
actual_fwd, actual_rev, _ = dropout(
self.x, self.p, self.is_test, self.mode, self.seed
)
np.testing.assert_allclose(
actual_fwd.sum(),
desired_fwd.sum(),
rtol=1e-2, # mean of uniform distribution, scale for avoid random failed
atol=0,
)
np.testing.assert_allclose(
actual_rev.sum(),
desired_rev.sum(),
rtol=1e-2, # mean of uniform distribution, scale for avoid random failed
atol=0,
)
core._set_prim_all_enabled(True) core._set_prim_all_enabled(True)
actual_fwd, actual_rev, _ = dropout( actual_fwd, actual_rev, _ = dropout(
self.x, self.p, self.is_test, self.mode, self.seed self.x, self.p, self.is_test, self.mode, self.seed
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid import core
class TestCustomVJP(unittest.TestCase):
def setUp(self):
def func():
x = paddle.rand((1,))
x.stop_gradient = False
return paddle.nn.functional.dropout(x)
self.f = func
self.ops_fwd_enable_bwd_disable = (
'uniform_random',
'uniform_random',
'fill_constant',
'greater_equal',
'cast',
'elementwise_mul',
'scale',
'cast',
'fill_any_like',
'scale',
'elementwise_mul_grad',
)
self.ops_fwd_disable_bwd_enable = (
'uniform_random',
'dropout',
'fill_any_like',
'cast',
'elementwise_mul',
'fill_constant',
'elementwise_div',
)
self.ops_all_enable = (
'uniform_random',
'uniform_random',
'fill_constant',
'greater_equal',
'cast',
'elementwise_mul',
'scale',
'cast',
'fill_constant',
'cast',
'elementwise_mul',
'fill_constant',
'elementwise_div',
)
def test_enable_prim_fwd(self):
core._set_prim_forward_enabled(True)
core._set_prim_backward_enabled(False)
self.assertEqual(
self.ops_fwd_enable_bwd_disable,
tuple(
op.type
for op in paddle.jit.to_static(self.f)
.get_concrete_program()[1]
._train_program.block(0)
.ops
),
)
core._set_prim_forward_enabled(False)
core._set_prim_backward_enabled(False)
def test_enable_prim_bwd(self):
core._set_prim_forward_enabled(False)
core._set_prim_backward_enabled(True)
self.assertEqual(
self.ops_fwd_disable_bwd_enable,
tuple(
op.type
for op in paddle.jit.to_static(self.f)
.get_concrete_program()[1]
._train_program.block(0)
.ops
),
)
core._set_prim_forward_enabled(False)
core._set_prim_backward_enabled(False)
def test_enable_prim_all(self):
core._set_prim_all_enabled(True)
self.assertEqual(
self.ops_all_enable,
tuple(
op.type
for op in paddle.jit.to_static(self.f)
.get_concrete_program()[1]
._train_program.block(0)
.ops
),
)
core._set_prim_all_enabled(False)
if __name__ == '__main__':
unittest.main()
...@@ -239,6 +239,11 @@ def to_prim(blocks, exclude=frozenset()): ...@@ -239,6 +239,11 @@ def to_prim(blocks, exclude=frozenset()):
raise TypeError( raise TypeError(
f"Expect block or sequence of blocks, but got {type(blocks)}." f"Expect block or sequence of blocks, but got {type(blocks)}."
) )
if not isinstance(exclude, (set, frozenset)):
raise TypeError(
f'Expected type of exclude is set|frozenset, but got {type(exclude)}.'
)
with framework.program_guard(main_program): with framework.program_guard(main_program):
logging.debug("Lowering composite forward ops begin...") logging.debug("Lowering composite forward ops begin...")
primx._lower_composite( primx._lower_composite(
......
...@@ -645,36 +645,9 @@ def _lower_composite(block, blacklist=frozenset()): ...@@ -645,36 +645,9 @@ def _lower_composite(block, blacklist=frozenset()):
else: else:
none_vars_to_remove.add(orig_out.name) none_vars_to_remove.add(orig_out.name)
else: else:
<<<<<<< HEAD
inputs = {}
for i in range(len(op.input_names)):
inputs[op.input_names[i]] = bind_name(
op.input(op.input_names[i]), to_bind
)
outputs = {}
for i in range(len(op.output_names)):
outputs[op.output_names[i]] = op.output(op.output_names[i])
from paddle.fluid.dygraph.base import param_guard
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op.desc)
with param_guard(inputs), param_guard(outputs):
op = Operator(
block=block,
desc=new_op_desc,
type=op.type,
inputs=inputs,
outputs=outputs,
attrs=None,
)
block.ops.append(op)
=======
op_desc = block.desc.append_op() op_desc = block.desc.append_op()
op_desc.copy_from(op.desc) op_desc.copy_from(op.desc)
block._sync_with_cpp() block._sync_with_cpp()
>>>>>>> [prim] enable dygraph_to_static to support custom_vjp
# Step3: Do some post-processing work # Step3: Do some post-processing work
for op_idx in reversed(ops_to_remove): for op_idx in reversed(ops_to_remove):
......
...@@ -195,7 +195,7 @@ class PartialProgramLayer: ...@@ -195,7 +195,7 @@ class PartialProgramLayer:
# Set default mode to train # Set default mode to train
self.training = True self.training = True
self._infer_info = ProgramInfo() self._infer_info = ProgramInfo()
self._backward_start_index_map = {} self._forward_end_index_map = {}
custom_white_list, custom_black_list = None, None custom_white_list, custom_black_list = None, None
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
...@@ -314,7 +314,10 @@ class PartialProgramLayer: ...@@ -314,7 +314,10 @@ class PartialProgramLayer:
@switch_to_static_graph @switch_to_static_graph
def _create_forward_backward_train_program(self): def _create_forward_backward_train_program(self):
whole_program = self._train_program whole_program = self._train_program
_, forward_end_op_index = self._infer_info('fp32', self._create_program) # _, forward_end_op_index = self._infer_info('fp32', self._create_program)
forward_end_op_index = self._forward_end_index_map[
_hash_with_id(whole_program, self)
]
assert forward_end_op_index >= 0 assert forward_end_op_index >= 0
return self._get_forward_backward_program_form( return self._get_forward_backward_program_form(
...@@ -642,15 +645,16 @@ class PartialProgramLayer: ...@@ -642,15 +645,16 @@ class PartialProgramLayer:
if targets: if targets:
# TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch. # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
core.check_and_set_prim_all_enabled() core.check_and_set_prim_all_enabled()
start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
backward.gradients(targets=targets, inputs=[]) backward.gradients(targets=targets, inputs=[])
start_idx = (
len(main_program.block(0).ops) + len(self._outputs.tolist()) + 1
)
if self._hooker: if self._hooker:
program, start_idx = self._hooker.after_append_backward( program, start_idx = self._hooker.after_append_backward(
self, program, start_idx self, program, start_idx
) )
# self._backward_start_index_map[self._hash_with_id(program, self)] self._forward_end_index_map[
_hash_with_id(program, self)
] = start_idx - len(self._outputs.tolist())
# TODO: prim make this complicate # TODO: prim make this complicate
self.prepare_gradient_aggregation(start_idx, main_program, program) self.prepare_gradient_aggregation(start_idx, main_program, program)
...@@ -733,10 +737,6 @@ class PartialProgramLayer: ...@@ -733,10 +737,6 @@ class PartialProgramLayer:
self.program_id, self.program_id,
] ]
print(self.forward_program)
print(self.backward_program)
print(self.program_id)
if self.training: if self.training:
# NOTE: In the case of higher-order gradient, the names of the parameter grads may be like # NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
# `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
...@@ -1155,5 +1155,5 @@ def add_build_strategy_for( ...@@ -1155,5 +1155,5 @@ def add_build_strategy_for(
if hasattr(compiled_program._program, 'lr_sheduler'): if hasattr(compiled_program._program, 'lr_sheduler'):
builded_program.lr_sheduler = compiled_program._program.lr_sheduler builded_program.lr_sheduler = compiled_program._program.lr_sheduler
else: else:
builded_program = program builded_program = paddle.static.Program()
return builded_program return builded_program
...@@ -19,6 +19,7 @@ import threading ...@@ -19,6 +19,7 @@ import threading
import warnings import warnings
import weakref import weakref
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import _non_static_mode, core, framework from paddle.fluid import _non_static_mode, core, framework
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers from paddle.fluid.dygraph import layers
...@@ -1183,15 +1184,13 @@ class ProgramCache: ...@@ -1183,15 +1184,13 @@ class ProgramCache:
class PrimHooker(PartialProgramLayerHook): class PrimHooker(PartialProgramLayerHook):
def __init__(self): def __init__(self):
custom_vjps = set() self.custom_vjps = set()
if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled(): if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled():
custom_vjps = { self.custom_vjps = {
op.type op.type
for op in concrete_program.main_program.block(0).ops for op in concrete_program.main_program.block(0).ops
if core.has_comp_grad_op_maker(op.type) if core.has_comp_grad_op_maker(op.type)
} }
self.custom_vjps = custom_vjps
self.custom_vjps = {"softmax"}
def before_append_backward( def before_append_backward(
self, partial_program_layer, forward_program self, partial_program_layer, forward_program
...@@ -1219,7 +1218,8 @@ class ProgramCache: ...@@ -1219,7 +1218,8 @@ class ProgramCache:
return infer_program return infer_program
partial_program = partial_program_from(concrete_program) partial_program = partial_program_from(concrete_program)
partial_program.set_hooker(PrimHooker()) if not _in_amp_guard() and not _in_pure_fp16_guard():
partial_program.set_hooker(PrimHooker())
return concrete_program, partial_program return concrete_program, partial_program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册