提交 490eb906 编写于 作者: S sneaxiy

polish infer shape of py_func op

test=develop
上级 dc8847af
......@@ -34,8 +34,6 @@ class CompileTimeInferShapeContext : public InferShapeContext {
public:
CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
InferShapeOpPtr GetOp() const override { return &op_; }
bool HasInput(const std::string &name) const override;
bool HasOutput(const std::string &name) const override;
......
......@@ -481,8 +481,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
InferShapeOpPtr GetOp() const override { return &op_; }
bool HasInput(const std::string& name) const override {
// has only one input
const auto& ins = op_.Inputs();
......
......@@ -28,7 +28,6 @@ namespace framework {
class OperatorBase;
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
using InferShapeOpPtr = boost::variant<const OpDesc *, const OperatorBase *>;
class InferShapeContext {
public:
......@@ -41,8 +40,6 @@ class InferShapeContext {
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const;
virtual InferShapeOpPtr GetOp() const = 0;
virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0;
......
......@@ -91,63 +91,65 @@ static void CallPythonFunc(py::object *callable,
}
}
class PyFuncOpShapeInference : public framework::InferShapeBase {
class PyFuncOpVarTypInference : public framework::VarTypeInference {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"Infer shape cannot be called in runtime.");
void operator()(const framework::OpDesc &op,
framework::BlockDesc *block) const override {
auto &outs = op.Outputs();
bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty());
auto &ins = op.Inputs();
bool has_in = (ins.count("X") > 0 && !ins.at("X").empty());
/**
* X or Out can be empty, so that py_func can be more flexible
* to support Python functions with no input or no output
*/
PADDLE_ENFORCE(ctx->HasInputs("X") || ctx->HasOutputs("Out"),
"Input(X) or Output(Out) must exist");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>(kForwardPythonCallableId), 0,
PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
PADDLE_ENFORCE_GE(boost::get<int>(op.GetAttr(kForwardPythonCallableId)), 0,
"Function id cannot be less than 0");
if (!has_out) return;
/**
* Traverse all outputs, check if name of any output ends with @GRAD.
* If found, set its shape, dtype, lod_level, type to be the same as
* the corresponding forward variable
*
* Why not get input dims from InferShapeContext?
* Because some variables in forward inputs/outputs may not be needed
* in backward. Those variables are not inside InferShapeContext.
*
* InferShape would be only called in compile time. During runtime,
* the shapes of outputs should be guaranteed by user-defined Python
* functions.
*/
auto *op = boost::get<const framework::OpDesc *>(ctx->GetOp());
auto *block = op->Block();
const std::string kGradVarSuffix = framework::kGradVarSuffix;
auto out_vars = ctx->GetOutputVarPtrs("Out");
for (auto &out_var : out_vars) {
auto *out_var_desc = boost::get<framework::VarDesc *>(out_var);
if (out_var_desc == nullptr) {
continue;
}
auto out_name = out_var_desc->Name();
if (out_name == framework::kEmptyVarName ||
out_name.size() < kGradVarSuffix.size()) {
auto &out_var_names = outs.at("Out");
for (auto &out_var_name : out_var_names) {
if (out_var_name == framework::kEmptyVarName ||
out_var_name.size() < kGradVarSuffix.size()) {
continue;
}
size_t len = out_name.size() - kGradVarSuffix.size();
if (out_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_name.substr(0, len);
auto *in_var_desc = block->FindVarRecursive(fwd_var_name);
PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found",
size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_var_name.substr(0, len);
auto *out_var_desc = block->FindVarRecursive(out_var_name);
auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name);
PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found",
out_var_name);
PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found",
fwd_var_name);
VLOG(10) << "Infer shape of Output(" << out_name << ") as Input("
<< in_var_desc->Name() << ")";
out_var_desc->SetShape(in_var_desc->GetShape());
out_var_desc->SetDataType(in_var_desc->GetDataType());
out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel());
out_var_desc->SetType(in_var_desc->GetType());
VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< fwd_var_name << ")";
out_var_desc->SetShape(fwd_var_desc->GetShape());
out_var_desc->SetDataType(fwd_var_desc->GetDataType());
out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel());
out_var_desc->SetType(fwd_var_desc->GetType());
}
}
}
};
class PyFuncOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"Infer shape cannot be called in runtime.");
}
};
......@@ -307,4 +309,5 @@ class PyFuncOp : public framework::OperatorBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker);
ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference,
ops::PyFuncOpGradDescMaker);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册