Commit 490eb906 authored by sneaxiy

polish infer shape of py_func op

test=develop

Parent dc8847af
...@@ -34,8 +34,6 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -34,8 +34,6 @@ class CompileTimeInferShapeContext : public InferShapeContext {
public: public:
CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block); CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
InferShapeOpPtr GetOp() const override { return &op_; }
bool HasInput(const std::string &name) const override; bool HasInput(const std::string &name) const override;
bool HasOutput(const std::string &name) const override; bool HasOutput(const std::string &name) const override;
......
...@@ -481,8 +481,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -481,8 +481,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
InferShapeOpPtr GetOp() const override { return &op_; }
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
// has only one input // has only one input
const auto& ins = op_.Inputs(); const auto& ins = op_.Inputs();
......
...@@ -28,7 +28,6 @@ namespace framework { ...@@ -28,7 +28,6 @@ namespace framework {
class OperatorBase; class OperatorBase;
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>; using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
using InferShapeOpPtr = boost::variant<const OpDesc *, const OperatorBase *>;
class InferShapeContext { class InferShapeContext {
public: public:
...@@ -41,8 +40,6 @@ class InferShapeContext { ...@@ -41,8 +40,6 @@ class InferShapeContext {
std::vector<proto::VarType::Type> GetOutputsVarType( std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const; const std::string &name) const;
virtual InferShapeOpPtr GetOp() const = 0;
virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0;
......
...@@ -91,63 +91,65 @@ static void CallPythonFunc(py::object *callable, ...@@ -91,63 +91,65 @@ static void CallPythonFunc(py::object *callable,
} }
} }
class PyFuncOpShapeInference : public framework::InferShapeBase { class PyFuncOpVarTypInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { void operator()(const framework::OpDesc &op,
PADDLE_ENFORCE(!ctx->IsRuntime(), framework::BlockDesc *block) const override {
"Infer shape cannot be called in runtime."); auto &outs = op.Outputs();
bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty());
auto &ins = op.Inputs();
bool has_in = (ins.count("X") > 0 && !ins.at("X").empty());
/** /**
* X or Out can be empty, so that py_func can be more flexible * X or Out can be empty, so that py_func can be more flexible
* to support Python functions with no input or no output * to support Python functions with no input or no output
*/ */
PADDLE_ENFORCE(ctx->HasInputs("X") || ctx->HasOutputs("Out"), PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
"Input(X) or Output(Out) must exist");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>(kForwardPythonCallableId), 0, PADDLE_ENFORCE_GE(boost::get<int>(op.GetAttr(kForwardPythonCallableId)), 0,
"Function id cannot be less than 0"); "Function id cannot be less than 0");
if (!has_out) return;
/** /**
* Traverse all outputs, check if name of any output ends with @GRAD. * Traverse all outputs, check if name of any output ends with @GRAD.
* If found, set its shape, dtype, lod_level, type to be the same as * If found, set its shape, dtype, lod_level, type to be the same as
* the corresponding forward variable * the corresponding forward variable
*
* Why not get input dims from InferShapeContext?
* Because some variables in forward inputs/outputs may not be needed
* in backward. Those variables are not inside InferShapeContext.
*
* InferShape would be only called in compile time. During runtime,
* the shapes of outputs should be guaranteed by user-defined Python
* functions.
*/ */
auto *op = boost::get<const framework::OpDesc *>(ctx->GetOp());
auto *block = op->Block();
const std::string kGradVarSuffix = framework::kGradVarSuffix; const std::string kGradVarSuffix = framework::kGradVarSuffix;
auto out_vars = ctx->GetOutputVarPtrs("Out"); auto &out_var_names = outs.at("Out");
for (auto &out_var : out_vars) { for (auto &out_var_name : out_var_names) {
auto *out_var_desc = boost::get<framework::VarDesc *>(out_var); if (out_var_name == framework::kEmptyVarName ||
if (out_var_desc == nullptr) { out_var_name.size() < kGradVarSuffix.size()) {
continue;
}
auto out_name = out_var_desc->Name();
if (out_name == framework::kEmptyVarName ||
out_name.size() < kGradVarSuffix.size()) {
continue; continue;
} }
size_t len = out_name.size() - kGradVarSuffix.size(); size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_name.substr(len) == kGradVarSuffix) { if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_name.substr(0, len); auto fwd_var_name = out_var_name.substr(0, len);
auto *in_var_desc = block->FindVarRecursive(fwd_var_name); auto *out_var_desc = block->FindVarRecursive(out_var_name);
PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found", auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name);
PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found",
out_var_name);
PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found",
fwd_var_name); fwd_var_name);
VLOG(10) << "Infer shape of Output(" << out_name << ") as Input(" VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< in_var_desc->Name() << ")"; << fwd_var_name << ")";
out_var_desc->SetShape(in_var_desc->GetShape()); out_var_desc->SetShape(fwd_var_desc->GetShape());
out_var_desc->SetDataType(in_var_desc->GetDataType()); out_var_desc->SetDataType(fwd_var_desc->GetDataType());
out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel()); out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel());
out_var_desc->SetType(in_var_desc->GetType()); out_var_desc->SetType(fwd_var_desc->GetType());
}
} }
} }
};
// Compile-time-only shape inference for py_func: at runtime the output
// shapes are produced by the user-defined Python callable, so invoking
// InferShape then would be a framework misuse and is rejected outright.
class PyFuncOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(!ctx->IsRuntime(),
                   "Infer shape cannot be called in runtime.");
  }
};
...@@ -307,4 +309,5 @@ class PyFuncOp : public framework::OperatorBase { ...@@ -307,4 +309,5 @@ class PyFuncOp : public framework::OperatorBase {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker, REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker); ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference,
ops::PyFuncOpGradDescMaker);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册