Commit fc68290b authored by Qiao Longfei, committed by GitHub

update _create_op_func_ and support generate dropout layer (#5134)

Parent aa379ccb
@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     ctx->SetOutputDim("Out", x_dims);
-    if (ctx->Attrs().Get<bool>("is_training") == 1) {
+    if (ctx->Attrs().Get<bool>("is_training") == true) {
       ctx->SetOutputDim("Mask", x_dims);
     }
     ctx->ShareLoD("X", /*->*/ "Out");
@@ -43,7 +43,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
   DropoutOpMaker(framework::OpProto* proto,
                  framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
+    AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
         .SetDefault(.5f);
     AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true);
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
@@ -69,7 +69,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
                       "GradOp is only callable when is_training is true");
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
@@ -77,8 +77,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) must not be null.");
-    PADDLE_ENFORCE_GE(ctx->Attrs().Get<AttrType>("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx->Attrs().Get<AttrType>("dropout_prob"), 1);
+    PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
+    PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
     auto x_dims = ctx->GetInputDim("X");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     PADDLE_ENFORCE_EQ(x_dims, out_dims,
......
@@ -33,7 +33,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     auto* y = context.Output<Tensor>("Out");
     const auto* x_data = x->data<T>();
     auto* y_data = y->mutable_data<T>(context.GetPlace());
-    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
+    float dropout_prob = context.Attr<float>("dropout_prob");

     if (context.Attr<bool>("is_training")) {
       auto* mask = context.Output<Tensor>("Mask");
@@ -41,7 +41,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       int seed = context.Attr<int>("seed");
       std::minstd_rand engine;
       engine.seed(seed);
-      std::uniform_real_distribution<AttrType> dist(0, 1);
+      std::uniform_real_distribution<float> dist(0, 1);
       size_t size = framework::product(mask->dims());
       for (size_t i = 0; i < size; ++i) {
         if (dist(engine) < dropout_prob) {
......
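For reference, the masking logic visible in the CPU kernel excerpt above can be sketched in NumPy. This is an illustrative reconstruction only: the function name is made up, NumPy's generator stands in for std::minstd_rand, and the excerpt cuts off before showing how the mask is combined with the input or whether kept values are rescaled.

import numpy as np

def dropout_mask_forward(x, dropout_prob=0.5, seed=0):
    # one uniform sample in [0, 1) per element, analogous to the kernel's
    # uniform_real_distribution<float> dist(0, 1) loop
    rng = np.random.RandomState(seed)
    samples = rng.uniform(0.0, 1.0, size=x.shape)
    # an element whose sample falls below dropout_prob is dropped (mask = 0)
    mask = (samples >= dropout_prob).astype(x.dtype)
    return x * mask, mask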
@@ -97,15 +97,28 @@ def _convert_(name):
 def _create_op_func_(op_type):
     op_proto = OpProtoHolder.instance().get_op_proto(op_type)
-    if len(op_proto.outputs) != 1:
+    not_intermediate_outputs = \
+        filter(lambda output: not output.intermediate, op_proto.outputs)
+    intermediate_outputs = \
+        filter(lambda output: output.intermediate, op_proto.outputs)
+
+    if len(not_intermediate_outputs) != 1:
         raise ValueError(
-            "Only one output operator can be automatically generated")
+            "Only one not intermediate output operator can be automatically generated"
+        )

-    if op_proto.outputs[0].duplicable:
+    if not_intermediate_outputs[0].duplicable:
         raise ValueError(
             "Only not duplicable op can be automatically generated")
-    o_name = op_proto.outputs[0].name
+
+    for output in intermediate_outputs:
+        if output.duplicable:
+            raise ValueError(
+                "Only when all intermediate ops are not duplicable, "
+                "this op can be automatically generated")
+
+    o_name = not_intermediate_outputs[0].name
+    intermediate_output_names = [output.name for output in intermediate_outputs]

     def func(**kwargs):
         helper = LayerHelper(op_type, **kwargs)
@@ -128,9 +141,13 @@ def _create_op_func_(op_type):
                     "operator {0} must input same dtype".format(op_type))
             inputs[ipt.name] = val

+        outputs = dict()
         out = helper.create_tmp_variable(dtype=dtype)
+        outputs[o_name] = [out]
+        for name in intermediate_output_names:
+            outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
+
         helper.append_op(
-            type=op_type, inputs=inputs, outputs={o_name: [out]}, attrs=kwargs)
+            type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
         return out

     func.__name__ = op_type
@@ -141,6 +158,7 @@ def _create_op_func_(op_type):
 _create_op_func_('mean')
 _create_op_func_('mul')
+_create_op_func_('dropout')


 def concat(input, axis, program=None, init_program=None):
......
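With _create_op_func_('dropout') registered above, layers.py now exposes an auto-generated dropout function: its single non-intermediate output ("Out") is returned, and its intermediate output (presumably "Mask") is created internally as a temporary variable. A hypothetical call is sketched below; the keyword names are assumptions inferred from the op's input ("X"), its attributes (dropout_prob, is_training, seed), and the program/init_program convention used by other layers in this file, not a confirmed signature.

# hypothetical usage of the generated dropout layer; hidden stands for the
# output Variable of a preceding layer and is assumed to exist
dropped = dropout(
    x=hidden,
    dropout_prob=0.5,
    is_training=True,
    seed=1,
    program=program,
    init_program=init_program)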