未验证 提交 fee84e09 编写于 作者: L levi131 提交者: GitHub

Add bernoulli primitive op and support dropout op in new AD. (#46238)

* init dropout

* small format fix

* fix pr comments

* add value test
上级 403cd2b5
......@@ -36,6 +36,7 @@ set(PRIM_OP_SRCS
pow_p_op.cc
max_p_op.cc
erf_p_op.cc
bernoulli_p_op.cc
abs_p_op.cc
cast_p_op.cc
rsqrt_p_op.cc)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class BernoulliPrimOp : public framework::OperatorBase {
public:
BernoulliPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator bernoulli_p should not be excuted directly"));
}
};
class BernoulliPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Y", "(Tensor), The output tensor of bernoulli_p op.");
AddAttr<std::vector<int64_t>>(
"shape", "(std::vector<int64_t>) The shape of output tensor.");
AddAttr<int>("dtype", "(int) The dtype of output tensor.");
AddAttr<float>("p", "(float) The probability of bernoulli distribution.");
AddComment(R"DOC(
Autograd primitive bernoulli_p operator.
)DOC");
}
};
class BernoulliPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape);
}
};
class BernoulliPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto y_name = Output(ctx, "Y")[0];
auto data_type = static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, ctx->GetAttr("dtype")));
SetDataType(ctx, y_name, data_type);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(bernoulli_p,
paddle::operators::BernoulliPrimOp,
paddle::operators::BernoulliPrimOpMaker,
paddle::operators::BernoulliPrimOpShapeInference,
paddle::operators::BernoulliPrimOpVarTypeInference);
......@@ -40,6 +40,7 @@ USE_OP_ITSELF(eq_p);
USE_OP_ITSELF(pow_p);
USE_OP_ITSELF(max_p);
USE_OP_ITSELF(erf_p);
USE_OP_ITSELF(bernoulli_p);
namespace paddle {
namespace framework {
......@@ -730,5 +731,26 @@ TEST(PrimOp, erf_p) {
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, bernoulli_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::string x0 = "x0";
AppendOp(block,
"bernoulli_p",
{{}},
{{"Y", {x0}}},
{{"p", 0.5f},
{"dtype", proto::VarType_Type_FP32},
{"shape", std::vector<int64_t>{3, 4, 5}}});
ASSERT_EQ(block->Var("x0")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x0")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x0")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
} // namespace framework
} // namespace paddle
......@@ -766,10 +766,120 @@ class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestDropoutOrig2PrimCase1(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'dropout'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Mask':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.uint8),
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype),
}
self.attrs = {
'dropout_prob': 0.5,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.orig2prim_args = (None, X)
self.all_ops = [
'bernoulli_p', 'mul_p', 'fill_constant_p', 'div_p', 'cast_p',
'dropout'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Mask'], 1: self.output['Out']}
class TestDropoutOrig2PrimCase2(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'dropout'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Mask':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.uint8),
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype),
}
self.attrs = {
'dropout_prob': 0.5,
'is_test': False,
'dropout_implementation': 'downgrade_in_infer'
}
self.orig2prim_args = (None, X)
self.all_ops = ['bernoulli_p', 'mul_p', 'cast_p', 'dropout']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Mask'], 1: self.output['Out']}
class TestDropoutOrig2PrimCase3(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'dropout'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Mask':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.uint8),
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype),
}
self.attrs = {
'dropout_prob': 0.5,
'is_test': True,
'dropout_implementation': 'upscale_in_train'
}
self.orig2prim_args = (None, X)
self.all_ops = ['bernoulli_p', 'cast_p', 'dropout']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Mask'], 1: self.output['Out']}
class TestDropoutOrig2PrimCase4(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'dropout'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Mask':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.uint8),
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype),
}
self.attrs = {
'dropout_prob': 0.5,
'is_test': True,
'dropout_implementation': 'downgrade_in_infer'
}
self.orig2prim_args = (None, X)
self.all_ops = [
'bernoulli_p', 'fill_constant_p', 'mul_p', 'cast_p', 'dropout'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Mask'], 1: self.output['Out']}
class TestReduceSumOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'reduce_sum'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
......
......@@ -670,6 +670,24 @@ class TestMaxPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Z']: 0}
class TestBernoulliPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'bernoulli_p'
self.input = {}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.float64)
}
self.attrs = {'shape': [7, 8], 'dtype': paddle.float64, 'p': 0.5}
self.prim2orig_args = ()
self.all_ops = ['bernoulli_p', 'fill_constant', 'bernoulli']
self.out_map = {self.output['Y']: 0}
class TestCastPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
......
......@@ -25,6 +25,69 @@ import config
import utils
@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
(('dropout', paddle.nn.functional.dropout,
(np.random.rand(5000, 5000), ), None, 'float32'), ))
class TestDropoutGrad(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs)
cls._rtol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("rtol")
cls._atol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("atol")
def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()
def tearDown(self):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
def test_grad(self):
def expected():
paddle.incubate.autograd.disable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
_, ys_grad = paddle.incubate.autograd.vjp(
self.fun, static_xs, static_v)
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.enable_prim()
return out
def actual():
paddle.incubate.autograd.enable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
paddle.incubate.autograd.prim2orig(mp.block(0))
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.disable_prim()
return out
expected = expected()
actual = actual()
self.assertEqual(type(actual), type(expected))
for i, j in zip(actual, expected):
np.testing.assert_allclose(np.sum(i), np.sum(j), rtol=1e-3)
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
......
......@@ -76,6 +76,15 @@ def fill_const(value, shape, dtype, out=None):
return out
def bernoulli(shape, dtype, p, out=None):
attrs = {'shape': shape, 'dtype': dtype, 'p': p}
helper = LayerHelper('bernoulli_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type=helper.layer_type, outputs={'Y': out}, attrs=attrs)
return out
def neg(x, out=None):
zero = fill_const(0.0, x.shape, x.dtype)
return sub(zero, x)
......
......@@ -23,7 +23,7 @@ from .primops import (add, broadcast, concat, cos, div, eq, erf, exp,
fill_const, gather, ge, gt, log, matmul, max, mul, ne,
neg, reduce_sum, reshape, scatter_add, select, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose, rsqrt)
transpose, bernoulli, rsqrt)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose,
......@@ -78,6 +78,7 @@ log
select
equal
elementwise_pow
dropout
These original ops are partially supported:
......@@ -439,6 +440,30 @@ def gelu_orig2prim(op, x):
erf(mul(x, fill_const(1 / math.sqrt(2.), x.shape, x.dtype)))))
@REGISTER_ORIG2PRIM('dropout')
def dropout_orig2prim(op, seed_t, x):
assert seed_t is None, 'Can not lower dropout into prim ops with seedtensor.'
mask = bernoulli(shape=x.shape, dtype=x.dtype, p=op.attr('dropout_prob'))
if op.attr('dropout_implementation') == 'upscale_in_train':
if op.attr('is_test') == False:
out = div(
mul(x, mask),
fill_const(1.0 - op.attr('dropout_prob'), x.shape, x.dtype))
return primops.cast(mask, dtype=paddle.uint8), out
else:
return primops.cast(mask, dtype=paddle.uint8), x
elif op.attr('dropout_implementation') == 'downgrade_in_infer':
if op.attr('is_test') == False:
return primops.cast(mask, dtype=paddle.uint8), mul(x, mask)
else:
return primops.cast(mask, dtype=paddle.uint8), mul(
x, fill_const(1.0 - op.attr('dropout_prob'), x.shape, x.dtype))
else:
raise RuntimeError(
'Unsupported dropout_implementation, only support upscale_in_train and downgrade_in_infer'
)
@REGISTER_ORIG2PRIM('reduce_sum')
def reduce_sum_orig2prim(op, x):
axes = tuple(range(0, len(
......@@ -634,6 +659,14 @@ def fill_constant_prim2orig(op):
dtype=INT_DTYPE_2_STRING[op.attr('dtype')])
@REGISTER_PRIM2ORIG('bernoulli_p')
def bernoulli_prim2orig(op):
t = paddle.full(shape=op.attr('shape'),
fill_value=op.attr('p'),
dtype=INT_DTYPE_2_STRING[op.attr('dtype')])
return paddle.bernoulli(t)
@REGISTER_PRIM2ORIG('select_p')
def select_prim2orig(op, condition, x, y):
return paddle.where(condition, x, y)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册