提交 fc68290b 编写于 作者: Q Qiao Longfei 提交者: GitHub

Update _create_op_func_ and support generating the dropout layer (#5134)

上级 aa379ccb
......@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
if (ctx->Attrs().Get<bool>("is_training") == 1) {
if (ctx->Attrs().Get<bool>("is_training") == true) {
ctx->SetOutputDim("Mask", x_dims);
}
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -43,7 +43,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
DropoutOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f);
AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
......@@ -69,7 +69,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
......@@ -77,8 +77,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<AttrType>("dropout_prob"), 1);
PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
auto x_dims = ctx->GetInputDim("X");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(x_dims, out_dims,
......
......@@ -33,7 +33,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y = context.Output<Tensor>("Out");
const auto* x_data = x->data<T>();
auto* y_data = y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
float dropout_prob = context.Attr<float>("dropout_prob");
if (context.Attr<bool>("is_training")) {
auto* mask = context.Output<Tensor>("Mask");
......@@ -41,7 +41,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
int seed = context.Attr<int>("seed");
std::minstd_rand engine;
engine.seed(seed);
std::uniform_real_distribution<AttrType> dist(0, 1);
std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) {
......
......@@ -97,15 +97,28 @@ def _convert_(name):
def _create_op_func_(op_type):
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
if len(op_proto.outputs) != 1:
not_intermediate_outputs = \
filter(lambda output: not output.intermediate, op_proto.outputs)
intermediate_outputs = \
filter(lambda output: output.intermediate, op_proto.outputs)
if len(not_intermediate_outputs) != 1:
raise ValueError(
"Only one output operator can be automatically generated")
"Only one not intermediate output operator can be automatically generated"
)
if op_proto.outputs[0].duplicable:
if not_intermediate_outputs[0].duplicable:
raise ValueError(
"Only not duplicable op can be automatically generated")
o_name = op_proto.outputs[0].name
for output in intermediate_outputs:
if output.duplicable:
raise ValueError(
"Only when all intermediate ops are not duplicable, "
"this op can be automatically generated")
o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs]
def func(**kwargs):
helper = LayerHelper(op_type, **kwargs)
......@@ -128,9 +141,13 @@ def _create_op_func_(op_type):
"operator {0} must input same dtype".format(op_type))
inputs[ipt.name] = val
outputs = dict()
out = helper.create_tmp_variable(dtype=dtype)
outputs[o_name] = [out]
for name in intermediate_output_names:
outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
helper.append_op(
type=op_type, inputs=inputs, outputs={o_name: [out]}, attrs=kwargs)
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
return out
func.__name__ = op_type
......@@ -141,6 +158,7 @@ def _create_op_func_(op_type):
_create_op_func_('mean')
_create_op_func_('mul')
_create_op_func_('dropout')
def concat(input, axis, program=None, init_program=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册