From 490eb9061f7d3bd19240fbff8465a2d5e4f25204 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Thu, 20 Dec 2018 08:45:43 +0000 Subject: [PATCH] polish infer shape of py_func op test=develop --- paddle/fluid/framework/op_desc.cc | 2 - paddle/fluid/framework/operator.cc | 2 - paddle/fluid/framework/shape_inference.h | 3 - paddle/fluid/operators/py_func_op.cc | 79 ++++++++++++------------ 4 files changed, 41 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 0faf9fe0548..dde642764fa 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -34,8 +34,6 @@ class CompileTimeInferShapeContext : public InferShapeContext { public: CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block); - InferShapeOpPtr GetOp() const override { return &op_; } - bool HasInput(const std::string &name) const override; bool HasOutput(const std::string &name) const override; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 222b261e2a6..66055e6f1d8 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -481,8 +481,6 @@ class RuntimeInferShapeContext : public InferShapeContext { RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) : op_(op), scope_(scope) {} - InferShapeOpPtr GetOp() const override { return &op_; } - bool HasInput(const std::string& name) const override { // has only one input const auto& ins = op_.Inputs(); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 2f95ab353ee..55349376baa 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -28,7 +28,6 @@ namespace framework { class OperatorBase; using InferShapeVarPtr = boost::variant; -using InferShapeOpPtr = boost::variant; class InferShapeContext { public: @@ -41,8 +40,6 @@ class InferShapeContext { std::vector GetOutputsVarType( const std::string &name) const; - virtual InferShapeOpPtr GetOp() const = 0; - virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index a2895b54043..a6b1c738af1 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -91,66 +91,68 @@ static void CallPythonFunc(py::object *callable, } } -class PyFuncOpShapeInference : public framework::InferShapeBase { +class PyFuncOpVarTypInference : public framework::VarTypeInference { public: - void operator()(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(!ctx->IsRuntime(), - "Infer shape cannot be called in runtime."); + void operator()(const framework::OpDesc &op, + framework::BlockDesc *block) const override { + auto &outs = op.Outputs(); + bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty()); + + auto &ins = op.Inputs(); + bool has_in = (ins.count("X") > 0 && !ins.at("X").empty()); /** * X or Out can be empty, so that py_func can be more flexible * to support Python functions with no input or no output */ - PADDLE_ENFORCE(ctx->HasInputs("X") || ctx->HasOutputs("Out"), - "Input(X) or Output(Out) must exist"); - PADDLE_ENFORCE_GE(ctx->Attrs().Get(kForwardPythonCallableId), 0, + PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist"); + + PADDLE_ENFORCE_GE(boost::get(op.GetAttr(kForwardPythonCallableId)), 0, "Function id cannot be less than 0"); + if (!has_out) return; + /** * Traverse all outputs, check if name of any output ends with @GRAD. * If found, set its shape, dtype, lod_level, type to be the same as * the corresponding forward variable - * - * Why not get input dims from InferShapeContext? - * Because some variables in forward inputs/outputs may not be needed - * in backward. Those variables are not inside InferShapeContext. - * - * InferShape would be only called in compile time. During runtime, - * the shapes of outputs should be guaranteed by user-defined Python - * functions. */ - auto *op = boost::get(ctx->GetOp()); - auto *block = op->Block(); const std::string kGradVarSuffix = framework::kGradVarSuffix; - auto out_vars = ctx->GetOutputVarPtrs("Out"); - for (auto &out_var : out_vars) { - auto *out_var_desc = boost::get(out_var); - if (out_var_desc == nullptr) { - continue; - } - auto out_name = out_var_desc->Name(); - if (out_name == framework::kEmptyVarName || - out_name.size() < kGradVarSuffix.size()) { + auto &out_var_names = outs.at("Out"); + for (auto &out_var_name : out_var_names) { + if (out_var_name == framework::kEmptyVarName || + out_var_name.size() < kGradVarSuffix.size()) { continue; } - size_t len = out_name.size() - kGradVarSuffix.size(); - if (out_name.substr(len) == kGradVarSuffix) { - auto fwd_var_name = out_name.substr(0, len); - auto *in_var_desc = block->FindVarRecursive(fwd_var_name); - PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found", + size_t len = out_var_name.size() - kGradVarSuffix.size(); + if (out_var_name.substr(len) == kGradVarSuffix) { + auto fwd_var_name = out_var_name.substr(0, len); + auto *out_var_desc = block->FindVarRecursive(out_var_name); + auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name); + PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found", + out_var_name); + PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found", fwd_var_name); - VLOG(10) << "Infer shape of Output(" << out_name << ") as Input(" - << in_var_desc->Name() << ")"; - out_var_desc->SetShape(in_var_desc->GetShape()); - out_var_desc->SetDataType(in_var_desc->GetDataType()); - out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel()); - out_var_desc->SetType(in_var_desc->GetType()); + VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input(" + << fwd_var_name << ")"; + out_var_desc->SetShape(fwd_var_desc->GetShape()); + out_var_desc->SetDataType(fwd_var_desc->GetDataType()); + out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel()); + out_var_desc->SetType(fwd_var_desc->GetType()); } } } }; +class PyFuncOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(!ctx->IsRuntime(), + "Infer shape cannot be called in runtime."); + } +}; + class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -307,4 +309,5 @@ class PyFuncOp : public framework::OperatorBase { namespace ops = paddle::operators; REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker, - ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker); + ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference, + ops::PyFuncOpGradDescMaker); -- GitLab