From fee84e094c4368db2ab1ca7a01b1bfdbcde3efa0 Mon Sep 17 00:00:00 2001 From: levi131 <83750468+levi131@users.noreply.github.com> Date: Tue, 27 Sep 2022 18:42:43 +0800 Subject: [PATCH] Add bernoulli primitive op and support dropout op in new AD. (#46238) * init dropout * small format fix * fix pr comments * add value test --- .../fluid/operators/prim_ops/CMakeLists.txt | 1 + .../operators/prim_ops/bernoulli_p_op.cc | 82 +++++++++++++ .../fluid/operators/prim_ops/prim_op_test.cc | 22 ++++ .../unittests/autograd/test_orig2prim.py | 110 ++++++++++++++++++ .../unittests/autograd/test_prim2orig.py | 18 +++ .../tests/unittests/autograd/test_primapi.py | 63 ++++++++++ python/paddle/incubate/autograd/primops.py | 9 ++ python/paddle/incubate/autograd/primrules.py | 35 +++++- 8 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/prim_ops/bernoulli_p_op.cc diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt index 7a75b9b9857..34290303dfb 100644 --- a/paddle/fluid/operators/prim_ops/CMakeLists.txt +++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt @@ -36,6 +36,7 @@ set(PRIM_OP_SRCS pow_p_op.cc max_p_op.cc erf_p_op.cc + bernoulli_p_op.cc abs_p_op.cc cast_p_op.cc rsqrt_p_op.cc) diff --git a/paddle/fluid/operators/prim_ops/bernoulli_p_op.cc b/paddle/fluid/operators/prim_ops/bernoulli_p_op.cc new file mode 100644 index 00000000000..85e2eb7e07e --- /dev/null +++ b/paddle/fluid/operators/prim_ops/bernoulli_p_op.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class BernoulliPrimOp : public framework::OperatorBase { + public: + BernoulliPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator bernoulli_p should not be excuted directly")); + } +}; + +class BernoulliPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("Y", "(Tensor), The output tensor of bernoulli_p op."); + AddAttr>( + "shape", "(std::vector) The shape of output tensor."); + AddAttr("dtype", "(int) The dtype of output tensor."); + AddAttr("p", "(float) The probability of bernoulli distribution."); + AddComment(R"DOC( +Autograd primitive bernoulli_p operator. 
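+
+Given attributes shape, dtype and p, it produces a tensor of the requested
+shape and dtype whose elements are independently sampled to be 1 with
+probability p and 0 otherwise. Like other primitive operators it is never
+executed directly; the prim2orig pass lowers it to fill_constant followed by
+bernoulli.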
+)DOC"); + } +}; + +class BernoulliPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + auto shape = ctx->Attrs().Get>("shape"); + PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape); + } +}; + +class BernoulliPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto y_name = Output(ctx, "Y")[0]; + auto data_type = static_cast( + PADDLE_GET_CONST(int, ctx->GetAttr("dtype"))); + SetDataType(ctx, y_name, data_type); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(bernoulli_p, + paddle::operators::BernoulliPrimOp, + paddle::operators::BernoulliPrimOpMaker, + paddle::operators::BernoulliPrimOpShapeInference, + paddle::operators::BernoulliPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/prim_op_test.cc b/paddle/fluid/operators/prim_ops/prim_op_test.cc index 153a4575463..87ec25d2f8c 100644 --- a/paddle/fluid/operators/prim_ops/prim_op_test.cc +++ b/paddle/fluid/operators/prim_ops/prim_op_test.cc @@ -40,6 +40,7 @@ USE_OP_ITSELF(eq_p); USE_OP_ITSELF(pow_p); USE_OP_ITSELF(max_p); USE_OP_ITSELF(erf_p); +USE_OP_ITSELF(bernoulli_p); namespace paddle { namespace framework { @@ -730,5 +731,26 @@ TEST(PrimOp, erf_p) { ASSERT_EQ(shapes[2], 5L); } +TEST(PrimOp, bernoulli_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::string x0 = "x0"; + + AppendOp(block, + "bernoulli_p", + {{}}, + {{"Y", {x0}}}, + {{"p", 0.5f}, + {"dtype", proto::VarType_Type_FP32}, + {"shape", std::vector{3, 4, 5}}}); + ASSERT_EQ(block->Var("x0")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x0")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x0")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + } // namespace framework } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index 8b3c0a2be90..275e5f1bee8 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -766,10 +766,120 @@ class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestDropoutOrig2PrimCase1(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'dropout' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Mask': + self.layer_help.create_variable_for_type_inference( + dtype=paddle.uint8), + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype), + } + self.attrs = { + 'dropout_prob': 0.5, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + + self.orig2prim_args = (None, X) + self.all_ops = [ + 'bernoulli_p', 'mul_p', 'fill_constant_p', 'div_p', 'cast_p', + 'dropout' + ] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Mask'], 1: self.output['Out']} + + +class TestDropoutOrig2PrimCase2(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'dropout' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Mask': + 
self.layer_help.create_variable_for_type_inference( + dtype=paddle.uint8), + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype), + } + self.attrs = { + 'dropout_prob': 0.5, + 'is_test': False, + 'dropout_implementation': 'downgrade_in_infer' + } + + self.orig2prim_args = (None, X) + self.all_ops = ['bernoulli_p', 'mul_p', 'cast_p', 'dropout'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Mask'], 1: self.output['Out']} + + +class TestDropoutOrig2PrimCase3(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'dropout' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Mask': + self.layer_help.create_variable_for_type_inference( + dtype=paddle.uint8), + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype), + } + self.attrs = { + 'dropout_prob': 0.5, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + + self.orig2prim_args = (None, X) + self.all_ops = ['bernoulli_p', 'cast_p', 'dropout'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Mask'], 1: self.output['Out']} + + +class TestDropoutOrig2PrimCase4(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'dropout' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Mask': + self.layer_help.create_variable_for_type_inference( + dtype=paddle.uint8), + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype), + } + self.attrs = { + 'dropout_prob': 0.5, + 'is_test': True, + 'dropout_implementation': 'downgrade_in_infer' + } + + self.orig2prim_args = (None, X) + self.all_ops = [ + 'bernoulli_p', 'fill_constant_p', 'mul_p', 'cast_p', 'dropout' + ] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Mask'], 1: self.output['Out']} + + class TestReduceSumOrig2Prim(TestElementWiseAddOrig2Prim): def init_data(self): self.op_type = 'reduce_sum' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') self.input = {'X': X} diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py index abc2803a8c0..a1f80e5506c 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py @@ -670,6 +670,24 @@ class TestMaxPPrim2Orig(TestAddPPrim2Orig): self.out_map = {self.output['Z']: 0} +class TestBernoulliPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'bernoulli_p' + + self.input = {} + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference( + dtype=paddle.float64) + } + self.attrs = {'shape': [7, 8], 'dtype': paddle.float64, 'p': 0.5} + + self.prim2orig_args = () + self.all_ops = ['bernoulli_p', 'fill_constant', 'bernoulli'] + self.out_map = {self.output['Y']: 0} + + class TestCastPPrim2Orig(TestAddPPrim2Orig): def init_data(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py index cd5004a815a..9e7dbae5bbd 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py @@ -25,6 +25,69 @@ import config import utils +@utils.place(config.DEVICES) +@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), + (('dropout', 
paddle.nn.functional.dropout, + (np.random.rand(5000, 5000), ), None, 'float32'), )) +class TestDropoutGrad(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) + cls._rtol = config.TOLERANCE.get(str( + cls.dtype)).get("first_order_grad").get("rtol") + cls._atol = config.TOLERANCE.get(str( + cls.dtype)).get("first_order_grad").get("atol") + + def setUp(self): + paddle.enable_static() + paddle.incubate.autograd.enable_prim() + + def tearDown(self): + paddle.incubate.autograd.disable_prim() + paddle.disable_static() + + def test_grad(self): + + def expected(): + paddle.incubate.autograd.disable_prim() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + feed, static_xs, static_v = utils.gen_static_data_and_feed( + self.xs, self.v, stop_gradient=False) + _, ys_grad = paddle.incubate.autograd.vjp( + self.fun, static_xs, static_v) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run(mp, feed=feed, fetch_list=ys_grad) + paddle.incubate.autograd.enable_prim() + return out + + def actual(): + paddle.incubate.autograd.enable_prim() + sp = paddle.static.Program() + mp = paddle.static.Program() + with paddle.static.program_guard(mp, sp): + feed, static_xs, static_v = utils.gen_static_data_and_feed( + self.xs, self.v, stop_gradient=False) + ys = self.fun(*static_xs) if isinstance( + static_xs, typing.Sequence) else self.fun(static_xs) + ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v) + paddle.incubate.autograd.prim2orig(mp.block(0)) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run(mp, feed=feed, fetch_list=ys_grad) + paddle.incubate.autograd.disable_prim() + return out + + expected = expected() + actual = actual() + self.assertEqual(type(actual), type(expected)) + for i, j in zip(actual, expected): + np.testing.assert_allclose(np.sum(i), np.sum(j), rtol=1e-3) + + @utils.place(config.DEVICES) @utils.parameterize( (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py index e7002ece693..454a99b764a 100644 --- a/python/paddle/incubate/autograd/primops.py +++ b/python/paddle/incubate/autograd/primops.py @@ -76,6 +76,15 @@ def fill_const(value, shape, dtype, out=None): return out +def bernoulli(shape, dtype, p, out=None): + attrs = {'shape': shape, 'dtype': dtype, 'p': p} + helper = LayerHelper('bernoulli_p', **locals()) + if out is None: + out = helper.create_variable_for_type_inference(dtype) + helper.append_op(type=helper.layer_type, outputs={'Y': out}, attrs=attrs) + return out + + def neg(x, out=None): zero = fill_const(0.0, x.shape, x.dtype) return sub(zero, x) diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index a6a29e04184..73058912761 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -23,7 +23,7 @@ from .primops import (add, broadcast, concat, cos, div, eq, erf, exp, fill_const, gather, ge, gt, log, matmul, max, mul, ne, neg, reduce_sum, reshape, scatter_add, select, set_value, sin, slice_assign, slice_select, split, sqrt, sub, tanh, - transpose, rsqrt) + transpose, bernoulli, rsqrt) from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_TRANSPOSE, lookup_fn, lookup_jvp, lookup_orig2prim, lookup_prim2orig, lookup_transpose, @@ -78,6 +78,7 @@ log select equal elementwise_pow +dropout These original ops 
are partially supported: @@ -439,6 +440,30 @@ def gelu_orig2prim(op, x): erf(mul(x, fill_const(1 / math.sqrt(2.), x.shape, x.dtype))))) +@REGISTER_ORIG2PRIM('dropout') +def dropout_orig2prim(op, seed_t, x): + assert seed_t is None, 'Can not lower dropout into prim ops with seedtensor.' + mask = bernoulli(shape=x.shape, dtype=x.dtype, p=op.attr('dropout_prob')) + if op.attr('dropout_implementation') == 'upscale_in_train': + if op.attr('is_test') == False: + out = div( + mul(x, mask), + fill_const(1.0 - op.attr('dropout_prob'), x.shape, x.dtype)) + return primops.cast(mask, dtype=paddle.uint8), out + else: + return primops.cast(mask, dtype=paddle.uint8), x + elif op.attr('dropout_implementation') == 'downgrade_in_infer': + if op.attr('is_test') == False: + return primops.cast(mask, dtype=paddle.uint8), mul(x, mask) + else: + return primops.cast(mask, dtype=paddle.uint8), mul( + x, fill_const(1.0 - op.attr('dropout_prob'), x.shape, x.dtype)) + else: + raise RuntimeError( + 'Unsupported dropout_implementation, only support upscale_in_train and downgrade_in_infer' + ) + + @REGISTER_ORIG2PRIM('reduce_sum') def reduce_sum_orig2prim(op, x): axes = tuple(range(0, len( @@ -634,6 +659,14 @@ def fill_constant_prim2orig(op): dtype=INT_DTYPE_2_STRING[op.attr('dtype')]) +@REGISTER_PRIM2ORIG('bernoulli_p') +def bernoulli_prim2orig(op): + t = paddle.full(shape=op.attr('shape'), + fill_value=op.attr('p'), + dtype=INT_DTYPE_2_STRING[op.attr('dtype')]) + return paddle.bernoulli(t) + + @REGISTER_PRIM2ORIG('select_p') def select_prim2orig(op, condition, x, y): return paddle.where(condition, x, y) -- GitLab
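
For reference, the following is a minimal sketch (not part of the patch) showing how the new dropout lowering is exercised end to end. It follows the same sequence of calls as TestDropoutGrad above; the variable names, the [5, 8] input shape and the dropout probability of 0.5 are illustrative choices only.

import numpy as np
import paddle

paddle.enable_static()
paddle.incubate.autograd.enable_prim()

startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[5, 8], dtype='float32')
    x.stop_gradient = False
    y = paddle.nn.functional.dropout(x, p=0.5)
    # With prim enabled, grad() lowers dropout to primitive ops
    # (bernoulli_p, mul_p, fill_constant_p, div_p, cast_p, ...) before
    # differentiating.
    x_grad = paddle.incubate.autograd.grad(y, [x])
    # Rewrite the remaining primitive ops back into original operators;
    # bernoulli_p becomes fill_constant followed by bernoulli.
    paddle.incubate.autograd.prim2orig(main.block(0))

exe = paddle.static.Executor()
exe.run(startup)
out = exe.run(main,
              feed={'x': np.random.rand(5, 8).astype('float32')},
              fetch_list=x_grad)

paddle.incubate.autograd.disable_prim()
paddle.disable_static()

Here out[0] holds the gradient of the dropout output with respect to x and has the same [5, 8] shape as the input.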