提交 8364688c 编写于 作者: M minqiyang

Fix py_func_op's problem

上级 b40e41fb
......@@ -93,9 +93,9 @@ execution.
class GetPlacesInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext &ctx) const override {
for (auto &o_name : ctx.Output("Out")) {
ctx.SetType(o_name, framework::proto::VarType::PLACE_LIST);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o_name : ctx->Output("Out")) {
ctx->SetType(o_name, framework::proto::VarType::PLACE_LIST);
}
}
};
......
......@@ -99,7 +99,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext *ctx) const override {
bool has_out = (ctx->HasOutput("Out") && !ctx->Output("Out").empty());
bool has_in = (ctx->HasInput("X") && !ctx->Input("Out").empty());
bool has_in = (ctx->HasInput("X") && !ctx->Input("X").empty());
/**
* X or Out can be empty, so that py_func can be more flexible
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册