Unverified commit 91ae7848, authored by liuwei1031, committed by GitHub

improve efficiency of runtime InferVarType (#22778) (#24181)

* cherry pick #22778

Parent 57b062e1
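In short: the PR replaces name-based, per-variable queries on InferVarTypeContext (Input/Output plus GetType/SetType on individual variable names) with slot-level accessors (InputTypeAnyOf, GetInputType, SetOutputType, SyncTypeAndDataType, where index ALL_ELEMENTS = -1 addresses every variable in a slot). Name-based access moves behind a protected interface, so the dygraph RuntimeInferVarTypeContext no longer has to build name-to-variable maps on every op invocation. A condensed before/after, taken from the SumOpVarTypeInference hunk below:

// Before: materialize the input name list, then look each name up again.
void operator()(InferVarTypeContext *ctx) const override {
  auto &inputs = ctx->Input("X");
  auto default_var_type = proto::VarType::SELECTED_ROWS;
  if (std::any_of(inputs.begin(), inputs.end(),
                  [&ctx](const std::string &name) {
                    return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
                  })) {
    default_var_type = proto::VarType::LOD_TENSOR;
  }
  ctx->SetType(ctx->Output("Out").front(), default_var_type);
}

// After: the context answers the slot-level query itself, so the caller
// never touches variable names.
void operator()(InferVarTypeContext *ctx) const override {
  auto default_var_type = proto::VarType::SELECTED_ROWS;
  if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
    default_var_type = proto::VarType::LOD_TENSOR;
  }
  ctx->SetOutputType("Out", default_var_type);  // defaults to index 0
}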
@@ -45,19 +45,13 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
   void operator()(InferVarTypeContext *ctx) const override {
-    auto &inputs = ctx->Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;
-    bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
-          return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
-        });
-    if (any_input_is_lod_tensor) {
+    if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
-    auto out_var_name = ctx->Output("Out").front();
-    ctx->SetType(out_var_name, default_var_type);
+    ctx->SetOutputType("Out", default_var_type);
   }
 };
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <algorithm>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -25,8 +26,14 @@ namespace framework {
 class OpDesc;
 class BlockDesc;
+class StaticGraphVarTypeInference;
 // default infer var type context
+
+static const int ALL_ELEMENTS = -1;
+
 class InferVarTypeContext {
+  friend class StaticGraphVarTypeInference;
+
  public:
   InferVarTypeContext(const OpDesc* op, BlockDesc* block)
       : op_(op), block_(block) {}
@@ -34,91 +41,267 @@ class InferVarTypeContext {
   virtual ~InferVarTypeContext() {}
 
   virtual Attribute GetAttr(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
     return op_->GetAttr(name);
   }
 
-  virtual bool HasVar(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
-    return block_->FindVarRecursive(name) != nullptr;
-  }
-
   virtual bool HasInput(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
     auto& inputs = op_->Inputs();
     auto input = inputs.find(name);
     return input != inputs.end() && !input->second.empty();
   }
 
   virtual bool HasOutput(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
    auto& outputs = op_->Outputs();
     auto output = outputs.find(name);
     return output != outputs.end() && !output->second.empty();
   }
 
-  virtual const std::vector<std::string>& Input(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
-    return op_->Input(name);
-  }
+  virtual size_t InputSize(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return op_->Inputs().at(name).size();
+  }
+
+  virtual const std::string& InputVarName(const std::string& name,
+                                          const int index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return op_->Inputs().at(name)[index];
+  }
+
+  virtual bool InputTypeAnyOf(const std::string& name,
+                              proto::VarType::Type type) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& inputs = op_->Input(name);
+    return std::any_of(inputs.begin(), inputs.end(),
+                       [this, &type](const std::string& name) {
+                         return this->GetVarType(name) == type;
+                       });
+  }
+
+  virtual bool InputTypeAllOf(const std::string& name,
+                              proto::VarType::Type type) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& inputs = op_->Input(name);
+    return std::all_of(inputs.begin(), inputs.end(),
+                       [this, &type](const std::string& name) {
+                         return this->GetVarType(name) == type;
+                       });
+  }
+
+  virtual void SyncTypeAndDataType(const std::string& input_name,
+                                   const std::string& output_name,
+                                   int index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& x_name = op_->Input(input_name).at(index);
+    auto& out_name = op_->Output(output_name).at(index);
+    if (x_name != out_name) {
+      this->SetVarType(out_name, this->GetVarType(x_name));
+      this->SetVarDataType(out_name, this->GetVarDataType(x_name));
+    }
+  }
+
+  virtual void SetOutputType(const std::string& name, proto::VarType::Type type,
+                             int index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    if (ALL_ELEMENTS == index) {
+      for (const auto& var_name : op_->Output(name)) {
+        this->SetVarType(var_name, type);
+      }
+    } else {
+      auto& var_name = op_->Output(name).at(index);
+      this->SetVarType(var_name, type);
+    }
+  }
+
+  virtual proto::VarType::Type GetInputType(const std::string& name,
+                                            const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return this->GetVarType(op_->Input(name).at(index));
+  }
+
+  virtual proto::VarType::Type GetOutputType(const std::string& name,
+                                             const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return this->GetVarType(op_->Output(name).at(index));
+  }
+
+  virtual proto::VarType::Type GetInputDataType(const std::string& name,
+                                                const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return this->GetVarDataType(op_->Input(name).at(index));
+  }
+
+  virtual void SetOutputDataType(const std::string& name,
+                                 proto::VarType::Type type, int index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    if (ALL_ELEMENTS == index) {
+      for (const auto& var_name : op_->Output(name)) {
+        this->SetVarDataType(var_name, type);
+      }
+    } else {
+      auto& var_name = op_->Output(name).at(index);
+      this->SetVarDataType(var_name, type);
+    }
+  }
+
+  virtual std::vector<proto::VarType::Type> GetInputDataTypes(
+      const std::string& name, const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return this->GetVarDataTypes(op_->Input(name).at(index));
+  }
+
+  virtual void SetOutputDataTypes(
+      const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type,
+      const int& index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Output(name).at(index);
+    this->SetVarDataTypes(var_name, multiple_data_type);
+  }
+
+  virtual std::vector<int64_t> GetInputShape(const std::string& name,
+                                             const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Input(name).at(index);
+    return this->GetVarShape(var_name);
+  }
+
+  virtual void SetOutputShape(const std::string& name,
+                              const std::vector<int64_t>& dims,
+                              const int& index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Output(name).at(index);
+    this->SetVarShape(var_name, dims);
+  }
+
+  virtual int32_t GetInputLoDLevel(const std::string& name,
+                                   const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Input(name).at(index);
+    return this->GetVarLoDLevel(var_name);
+  }
+
+  virtual void SetOutputLoDLevel(const std::string& name, int32_t lod_level,
+                                 const int& index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Output(name).at(index);
+    this->SetVarLoDLevel(var_name, lod_level);
+  }
+
+  // add a special API for save_op
+  // avoid using this API for common logic
+  virtual void InsertVar(const std::string& var_name,
+                         proto::VarType::Type var_type) {
+    if (!IsDygraph()) this->SetVarType(var_name, var_type);
+  }
+
+  virtual bool IsDygraph() const { return false; }
 
-  virtual const std::vector<std::string>& Output(
-      const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+ protected:
+  virtual bool HasVar(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
+    return block_->FindVarRecursive(name) != nullptr;
+  }
+
+  virtual const std::vector<std::string>& InputVars(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return op_->Input(name);
+  }
+
+  virtual const std::vector<std::string>& OutputVars(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
     return op_->Output(name);
   }
 
-  virtual proto::VarType::Type GetType(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual proto::VarType::Type GetVarType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetType();
   }
 
-  virtual void SetType(const std::string& name, proto::VarType::Type type) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual void SetVarType(const std::string& name, proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetType(type);
   }
 
-  virtual proto::VarType::Type GetDataType(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual proto::VarType::Type GetVarDataType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetDataType();
   }
 
-  virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual void SetVarDataType(const std::string& name,
+                              proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetDataType(type);
   }
 
-  virtual std::vector<proto::VarType::Type> GetDataTypes(
+  virtual std::vector<proto::VarType::Type> GetVarDataTypes(
       const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
   }
 
-  virtual void SetDataTypes(
+  virtual void SetVarDataTypes(
       const std::string& name,
       const std::vector<proto::VarType::Type>& multiple_data_type) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
   }
 
-  virtual std::vector<int64_t> GetShape(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual std::vector<int64_t> GetVarShape(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetShape();
   }
 
-  virtual void SetShape(const std::string& name,
-                        const std::vector<int64_t>& dims) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual void SetVarShape(const std::string& name,
+                           const std::vector<int64_t>& dims) {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetShape(dims);
   }
 
-  virtual int32_t GetLoDLevel(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual int32_t GetVarLoDLevel(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
    return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
   }
 
-  virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual void SetVarLoDLevel(const std::string& name, int32_t lod_level) {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
   }
@@ -133,22 +316,85 @@ class VarTypeInference {
   virtual void operator()(InferVarTypeContext* context) const = 0;  // NOLINT
 };
 
+class StaticGraphVarTypeInference : public VarTypeInference {
+ protected:
+  bool HasVar(InferVarTypeContext* ctx, const std::string& name) const {
+    return ctx->HasVar(name);
+  }
+
+  const std::vector<std::string>& Input(InferVarTypeContext* ctx,
+                                        const std::string& name) const {
+    return ctx->InputVars(name);
+  }
+
+  const std::vector<std::string>& Output(InferVarTypeContext* ctx,
+                                         const std::string& name) const {
+    return ctx->OutputVars(name);
+  }
+
+  proto::VarType::Type GetType(InferVarTypeContext* ctx,
+                               const std::string& name) const {
+    return ctx->GetVarType(name);
+  }
+
+  void SetType(InferVarTypeContext* ctx, const std::string& name,
+               proto::VarType::Type type) const {
+    ctx->SetVarType(name, type);
+  }
+
+  proto::VarType::Type GetDataType(InferVarTypeContext* ctx,
+                                   const std::string& name) const {
+    return ctx->GetVarDataType(name);
+  }
+
+  void SetDataType(InferVarTypeContext* ctx, const std::string& name,
+                   proto::VarType::Type type) const {
+    ctx->SetVarDataType(name, type);
+  }
+
+  std::vector<proto::VarType::Type> GetDataTypes(
+      InferVarTypeContext* ctx, const std::string& name) const {
+    return ctx->GetVarDataTypes(name);
+  }
+
+  void SetDataTypes(
+      InferVarTypeContext* ctx, const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type) {
+    return ctx->SetVarDataTypes(name, multiple_data_type);
+  }
+
+  std::vector<int64_t> GetShape(InferVarTypeContext* ctx,
+                                const std::string& name) const {
+    return ctx->GetVarShape(name);
+  }
+
+  void SetShape(InferVarTypeContext* ctx, const std::string& name,
+                const std::vector<int64_t>& dims) const {
+    ctx->SetVarShape(name, dims);
+  }
+
+  int32_t GetLoDLevel(InferVarTypeContext* ctx, const std::string& name) const {
+    return ctx->GetVarLoDLevel(name);
+  }
+
+  void SetLoDLevel(InferVarTypeContext* ctx, const std::string& name,
+                   int32_t lod_level) const {
+    ctx->SetVarLoDLevel(name, lod_level);
+  }
+};
+
 class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const final {  // NOLINT
-    auto in_out_var_names = this->GetInputOutputWithSameType();
+    auto& in_out_var_names = this->GetInputOutputWithSameType();
     for (auto& i_o_n : in_out_var_names) {
-      auto& x_name = ctx->Input(i_o_n.first).at(0);
-      auto& out_name = ctx->Output(i_o_n.second).at(0);
-      ctx->SetType(out_name, ctx->GetType(x_name));
-      ctx->SetDataType(out_name, ctx->GetDataType(x_name));
+      ctx->SyncTypeAndDataType(i_o_n.first, i_o_n.second);
     }
   }
 
  protected:
-  virtual std::unordered_map<std::string, std::string>
+  virtual std::unordered_map<std::string, std::string>&
   GetInputOutputWithSameType() const = 0;
 };
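A side note on the GetInputOutputWithSameType signature change, which recurs in every operator diff below: returning a reference to a function-local static map builds the map once per process instead of constructing and copying a fresh unordered_map on every InferVarType call. The pattern, as the operator files below apply it:

std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
    const override {
  // Initialized exactly once (thread-safe since C++11), then reused.
  static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
  return m;
}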
...
@@ -24,13 +24,13 @@ namespace framework {
 class NOP : public OperatorBase {
  public:
-  NOP(const std::string &type, const VariableNameMap &inputs,
-      const VariableNameMap &outputs, const AttributeMap &attrs)
+  NOP(const std::string& type, const VariableNameMap& inputs,
+      const VariableNameMap& outputs, const AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
  private:
-  void RunImpl(const Scope &scope,
-               const platform::Place &place) const override {}
+  void RunImpl(const Scope& scope,
+               const platform::Place& place) const override {}
 };
 class SumOpMaker : public OpProtoAndCheckerMaker {
@@ -44,20 +44,14 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto &inputs = ctx->Input("X");
+  void operator()(framework::InferVarTypeContext* ctx) const override {
     auto default_var_type = proto::VarType::SELECTED_ROWS;
-    bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
-          return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
-        });
-    if (any_input_is_lod_tensor) {
+    if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
-    auto out_var_name = ctx->Output("Out").front();
-    ctx->SetType(out_var_name, default_var_type);
+    ctx->SetOutputType("Out", default_var_type);
   }
 };
 }  // namespace framework
@@ -71,9 +65,79 @@ REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
 namespace paddle {
 namespace framework {
 
+class TestStaticGraphVarTypeInference : public StaticGraphVarTypeInference {
+ public:
+  void operator()(InferVarTypeContext* context) const override {}
+
+  bool HasVar(InferVarTypeContext* ctx, const std::string& name) const {
+    return StaticGraphVarTypeInference::HasVar(ctx, name);
+  }
+
+  const std::vector<std::string>& Input(InferVarTypeContext* ctx,
+                                        const std::string& name) const {
+    return StaticGraphVarTypeInference::Input(ctx, name);
+  }
+
+  const std::vector<std::string>& Output(InferVarTypeContext* ctx,
+                                         const std::string& name) const {
+    return StaticGraphVarTypeInference::Output(ctx, name);
+  }
+
+  proto::VarType::Type GetType(InferVarTypeContext* ctx,
+                               const std::string& name) const {
+    return StaticGraphVarTypeInference::GetType(ctx, name);
+  }
+
+  void SetType(InferVarTypeContext* ctx, const std::string& name,
+               proto::VarType::Type type) const {
+    StaticGraphVarTypeInference::SetType(ctx, name, type);
+  }
+
+  proto::VarType::Type GetDataType(InferVarTypeContext* ctx,
+                                   const std::string& name) const {
+    return StaticGraphVarTypeInference::GetDataType(ctx, name);
+  }
+
+  void SetDataType(InferVarTypeContext* ctx, const std::string& name,
+                   proto::VarType::Type type) const {
+    StaticGraphVarTypeInference::SetDataType(ctx, name, type);
+  }
+
+  std::vector<proto::VarType::Type> GetDataTypes(
+      InferVarTypeContext* ctx, const std::string& name) const {
+    return StaticGraphVarTypeInference::GetDataTypes(ctx, name);
+  }
+
+  void SetDataTypes(
+      InferVarTypeContext* ctx, const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type) {
+    return StaticGraphVarTypeInference::SetDataTypes(ctx, name,
+                                                     multiple_data_type);
+  }
+
+  std::vector<int64_t> GetShape(InferVarTypeContext* ctx,
+                                const std::string& name) const {
+    return StaticGraphVarTypeInference::GetShape(ctx, name);
+  }
+
+  void SetShape(InferVarTypeContext* ctx, const std::string& name,
+                const std::vector<int64_t>& dims) const {
+    StaticGraphVarTypeInference::SetShape(ctx, name, dims);
+  }
+
+  int32_t GetLoDLevel(InferVarTypeContext* ctx, const std::string& name) const {
+    return StaticGraphVarTypeInference::GetLoDLevel(ctx, name);
+  }
+
+  void SetLoDLevel(InferVarTypeContext* ctx, const std::string& name,
+                   int32_t lod_level) const {
+    StaticGraphVarTypeInference::SetLoDLevel(ctx, name, lod_level);
+  }
+};
+
 TEST(InferVarType, sum_op) {
   ProgramDesc prog;
-  auto *op = prog.MutableBlock(0)->AppendOp();
+  auto* op = prog.MutableBlock(0)->AppendOp();
   op->SetType("sum");
   op->SetInput("X", {"test_a", "test_b", "test_c"});
   op->SetOutput("Out", {"test_out"});
@@ -96,7 +160,7 @@ TEST(InferVarType, sum_op) {
 
 TEST(InferVarType, sum_op_without_infer_var_type) {
   ProgramDesc prog;
-  auto *op = prog.MutableBlock(0)->AppendOp();
+  auto* op = prog.MutableBlock(0)->AppendOp();
   op->SetType("sum_without_infer_var_type");
   op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
   op->SetOutput("Out", {"test2_out"});
@@ -112,5 +176,112 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
 
+TEST(InferVarType, multiple_api) {
+  ProgramDesc prog;
+
+  auto* block = prog.MutableBlock(0);
+  auto* op = block->AppendOp();
+  op->SetType("sum_without_infer_var_type");
+  op->SetInput("X", {"test2_a", "test2_b"});
+  op->SetOutput("Out", {"test2_a_out", "test2_b_out"});
+
+  block->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
+  block->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
+  block->Var("test2_a_out");
+  block->Var("test2_b_out");
+
+  InferVarTypeContext ctx(op, block);
+
+  ASSERT_TRUE(ctx.HasInput("X"));
+  ASSERT_TRUE(ctx.HasOutput("Out"));
+
+  ASSERT_EQ(2u, ctx.InputSize("X"));
+  ASSERT_EQ("test2_a", ctx.InputVarName("X", 0));
+
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetInputType("X"));
+  ASSERT_TRUE(ctx.InputTypeAllOf("X", proto::VarType::SELECTED_ROWS));
+  ASSERT_FALSE(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
+
+  ctx.SyncTypeAndDataType("X", "Out");
+
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
+  ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
+
+  ctx.SetOutputType("Out", proto::VarType::SELECTED_ROWS, ALL_ELEMENTS);
+  ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR, 1);
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
+  ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
+
+  ASSERT_EQ(0, ctx.GetInputDataType("X"));
+
+  ctx.SetOutputDataType("Out", proto::VarType::FP32, ALL_ELEMENTS);
+  ctx.SetOutputDataType("Out", proto::VarType::INT8, 1);
+  ASSERT_EQ(proto::VarType::FP32,
+            prog.MutableBlock(0)->Var("test2_a_out")->GetDataType());
+  ASSERT_EQ(proto::VarType::INT8,
+            prog.MutableBlock(0)->Var("test2_b_out")->GetDataType());
+
+  ASSERT_FALSE(ctx.IsDygraph());
+
+  // test StaticGraphVarTypeInference
+  TestStaticGraphVarTypeInference infer;
+  ASSERT_TRUE(infer.HasVar(&ctx, "test2_a"));
+  ASSERT_EQ(infer.Input(&ctx, "X").size(), infer.Output(&ctx, "Out").size());
+
+  ASSERT_EQ(proto::VarType::FP32, infer.GetDataType(&ctx, "test2_a_out"));
+  infer.SetDataType(&ctx, "test2_a_out", proto::VarType::FP64);
+  ASSERT_EQ(proto::VarType::FP64, infer.GetDataType(&ctx, "test2_a_out"));
+
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS, infer.GetType(&ctx, "test2_a_out"));
+  infer.SetType(&ctx, "test2_a_out", proto::VarType::LOD_TENSOR);
+  ASSERT_EQ(proto::VarType::LOD_TENSOR, infer.GetType(&ctx, "test2_a_out"));
+
+  ASSERT_ANY_THROW(infer.GetDataTypes(&ctx, "test2_a_out"));
+  ASSERT_ANY_THROW(infer.SetDataTypes(&ctx, "test2_a_out", {}));
+
+  ASSERT_EQ(0u, infer.GetShape(&ctx, "test2_a_out").size());
+  infer.SetShape(&ctx, "test2_a_out", {1, 3, 3});
+  ASSERT_EQ(3u, infer.GetShape(&ctx, "test2_a_out").size());
+
+  ASSERT_EQ(0, infer.GetLoDLevel(&ctx, "test2_a_out"));
+  infer.SetLoDLevel(&ctx, "test2_a_out", 2);
+  ASSERT_EQ(2, infer.GetLoDLevel(&ctx, "test2_a_out"));
+}
+
+TEST(InferVarType, test_enforce_check) {
+  InferVarTypeContext ctx(nullptr, nullptr);
+  ASSERT_ANY_THROW(ctx.HasInput("X"));
+  ASSERT_ANY_THROW(ctx.HasOutput("Out"));
+
+  ASSERT_ANY_THROW(ctx.InputSize("X"));
+  ASSERT_ANY_THROW(ctx.InputVarName("X"));
+
+  ASSERT_ANY_THROW(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
+  ASSERT_ANY_THROW(ctx.InputTypeAllOf("X", proto::VarType::LOD_TENSOR));
+
+  ASSERT_ANY_THROW(ctx.SyncTypeAndDataType("X", "Out"));
+
+  ASSERT_ANY_THROW(ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR));
+  ASSERT_ANY_THROW(ctx.GetInputType("X"));
+  ASSERT_ANY_THROW(ctx.GetOutputType("Out"));
+
+  ASSERT_ANY_THROW(ctx.GetInputDataType("X"));
+  ASSERT_ANY_THROW(ctx.SetOutputDataType("Out", proto::VarType::LOD_TENSOR));
+
+  ASSERT_ANY_THROW(ctx.GetInputDataTypes("X"));
+  ASSERT_ANY_THROW(ctx.SetOutputDataTypes("Out", {}));
+
+  ASSERT_ANY_THROW(ctx.GetInputShape("X"));
+  ASSERT_ANY_THROW(ctx.SetOutputShape("Out", {}));
+
+  ASSERT_ANY_THROW(ctx.GetInputLoDLevel("X"));
+  ASSERT_ANY_THROW(ctx.SetOutputLoDLevel("Out", 1));
+
+  ASSERT_ANY_THROW(ctx.InsertVar("var", proto::VarType::LOD_TENSOR));
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -14,6 +14,7 @@
 #pragma once
 
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -35,30 +36,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
       : InferVarTypeContext(nullptr, nullptr),
         inputs_(inputs),
         outputs_(outputs),
-        attrs_(attrs_map),
-        input_names_(),
-        output_names_(),
-        var_set_() {
-    input_names_.reserve(inputs_.size());
-    for (auto& it : inputs_) {
-      for (auto& var : it.second) {
-        if (var) {
-          input_names_[it.first].emplace_back(var->Name());
-          var_set_[var->Name()] = var.get();
-        }
-      }
-    }
-
-    output_names_.reserve(outputs_.size());
-    for (auto& it : outputs_) {
-      for (auto& var : it.second) {
-        if (var) {
-          output_names_[it.first].emplace_back(var->Name());
-          var_set_[var->Name()] = var.get();
-        }
-      }
-    }
-  }
+        attrs_(attrs_map) {}
 
   virtual ~RuntimeInferVarTypeContext() {}
@@ -70,10 +48,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
     return iter->second;
   }
 
-  bool HasVar(const std::string& name) const override {
-    return var_set_.count(name) > 0;
-  }
-
   bool HasInput(const std::string& name) const override {
     auto it = inputs_.find(name);
     return (it != inputs_.end() && it->second.size() > 0);
@@ -84,93 +58,173 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
     return (it != outputs_.end() && it->second.size() > 0);
   }
-  const std::vector<std::string>& Input(
-      const std::string& name) const override {
-    auto iter = input_names_.find(name);
-    PADDLE_ENFORCE_EQ(
-        iter != input_names_.end(), true,
-        platform::errors::NotFound("Cannot find input var %s", name));
-    return iter->second;
-  }
-
-  const std::vector<std::string>& Output(
-      const std::string& name) const override {
-    auto iter = output_names_.find(name);
-    PADDLE_ENFORCE_EQ(
-        iter != output_names_.end(), true,
-        platform::errors::NotFound("Cannot find output var %s", name));
-    return iter->second;
-  }
-
-  framework::proto::VarType::Type GetType(
-      const std::string& name) const override {
-    auto iter = var_set_.find(name);
-    PADDLE_ENFORCE_EQ(
-        iter != var_set_.end(), true,
-        platform::errors::NotFound("Cannot find var %s in GetType", name));
-    return iter->second->Type();
-  }
-
-  void SetType(const std::string& name,
-               framework::proto::VarType::Type type) override {
-    if (name == "kLookupTablePath") {
-      VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
-    } else {
-      var_set_[name]->SetType(type);
-      if ((var_set_[name]->MutableVar()->IsInitialized() == true) &&
-          (var_set_[name]->MutableVar()->Type() != type)) {
-        var_set_[name]->MutableVar()->Clear();
-      }
-    }
-  }
-
-  framework::proto::VarType::Type GetDataType(
-      const std::string& name) const override {
-    auto iter = var_set_.find(name);
-    PADDLE_ENFORCE_EQ(
-        iter != var_set_.end(), true,
-        platform::errors::NotFound("Cannot find var %s in GetDataType", name));
-    return iter->second->DataType();
-  }
-
-  void SetDataType(const std::string& name,
-                   framework::proto::VarType::Type type) override {
-    var_set_[name]->SetDataType(type);
-  }
-
-  std::vector<framework::proto::VarType::Type> GetDataTypes(
-      const std::string& name) const override {
-    PADDLE_THROW(platform::errors::PermissionDenied(
-        "GetDataTypes is not supported in runtime InferVarType"));
-  }
-
-  void SetDataTypes(const std::string& name,
-                    const std::vector<framework::proto::VarType::Type>&
-                        multiple_data_type) override {
-    PADDLE_THROW(platform::errors::PermissionDenied(
-        "SetDataTypes is not supported in runtime InferVarType"));
-  }
-
-  std::vector<int64_t> GetShape(const std::string& name) const override {
-    PADDLE_THROW(platform::errors::PermissionDenied(
-        "Do not handle Shape in runtime InferVarType"));
-  }
-
-  void SetShape(const std::string& name,
-                const std::vector<int64_t>& dims) override {
-    PADDLE_THROW(platform::errors::PermissionDenied(
-        "Do not handle Shape in runtime InferVarType"));
-  }
-
-  int32_t GetLoDLevel(const std::string& name) const override {
-    PADDLE_THROW(platform::errors::PermissionDenied(
-        "Do not handle LoDLevel in runtime InferVarType"));
-  }
-
-  void SetLoDLevel(const std::string& name, int32_t lod_level) override {
-    PADDLE_THROW(platform::errors::PermissionDenied(
-        "Do not handle LoDLevel in runtime InferVarType"));
-  }
+  size_t InputSize(const std::string& name) const {
+    return inputs_.at(name).size();
+  }
+
+  const std::string& InputVarName(const std::string& name,
+                                  const int index = 0) const {
+    return inputs_.at(name)[index]->Name();
+  }
+
+  bool InputTypeAnyOf(const std::string& name,
+                      framework::proto::VarType::Type type) const override {
+    auto& inputs = inputs_.at(name);
+    return std::any_of(inputs.begin(), inputs.end(),
+                       [&type](const std::shared_ptr<VarType>& var) {
+                         return var->Type() == type;
+                       });
+  }
+
+  bool InputTypeAllOf(const std::string& name,
+                      framework::proto::VarType::Type type) const override {
+    auto& inputs = inputs_.at(name);
+    return std::all_of(inputs.begin(), inputs.end(),
+                       [&type](const std::shared_ptr<VarType>& var) {
+                         return var->Type() == type;
+                       });
+  }
+
+  void SyncTypeAndDataType(const std::string& input_name,
+                           const std::string& output_name,
+                           int index = 0) override {
+    auto in_var = inputs_.at(input_name)[index];
+    auto out_var = outputs_.at(output_name)[index];
+    if (in_var != out_var) {
+      this->SetVarBaseType(out_var, in_var->Type());
+      this->SetVarBaseDataType(out_var, in_var->DataType());
+    }
+  }
+
+  void SetOutputType(const std::string& name,
+                     framework::proto::VarType::Type type,
+                     int index = 0) override {
+    if (index == framework::ALL_ELEMENTS) {
+      for (auto& item : outputs_.at(name)) {
+        this->SetVarBaseType(item, type);
+      }
+    } else {
+      auto& var = outputs_.at(name)[index];
+      this->SetVarBaseType(var, type);
+    }
+  }
+
+  void SetVarBaseType(std::shared_ptr<VarType> out,
+                      framework::proto::VarType::Type type) {
+    out->SetType(type);
+    if ((out->MutableVar()->IsInitialized() == true) &&
+        (out->MutableVar()->Type() != type)) {
+      out->MutableVar()->Clear();
+    }
+  }
+
+  void SetVarBaseDataType(std::shared_ptr<VarType> out,
+                          framework::proto::VarType::Type type) {
+    out->SetDataType(type);
+  }
+
+  framework::proto::VarType::Type GetInputType(
+      const std::string& name, const int& index = 0) const override {
+    return inputs_.at(name)[index]->Type();
+  }
+
+  framework::proto::VarType::Type GetOutputType(
+      const std::string& name, const int& index = 0) const override {
+    return outputs_.at(name)[index]->Type();
+  }
+
+  framework::proto::VarType::Type GetInputDataType(
+      const std::string& name, const int& index = 0) const override {
+    return inputs_.at(name)[index]->DataType();
+  }
+
+  void SetOutputDataType(const std::string& name,
+                         framework::proto::VarType::Type type,
+                         int index = 0) override {
+    if (framework::ALL_ELEMENTS == index) {
+      for (auto& item : outputs_.at(name)) {
+        this->SetVarBaseDataType(item, type);
+      }
+    } else {
+      auto& var = outputs_.at(name)[index];
+      this->SetVarBaseDataType(var, type);
+    }
+  }
+
+  bool IsDygraph() const override { return true; }
+
+ protected:
+  bool HasVar(const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "HasVar is not supported in runtime InferVarType"));
+  }
+
+  const std::vector<std::string>& InputVars(
+      const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "InputVars is not supported in runtime InferVarType"));
+  }
+
+  const std::vector<std::string>& OutputVars(
+      const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "OutputVars is not supported in runtime InferVarType"));
+  }
+
+  framework::proto::VarType::Type GetVarType(
+      const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not manipulate var in runtime InferVarType"));
+  }
+
+  void SetVarType(const std::string& name,
+                  framework::proto::VarType::Type type) override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not manipulate var in runtime InferVarType"));
+  }
+
+  framework::proto::VarType::Type GetVarDataType(
+      const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not manipulate var in runtime InferVarType"));
+  }
+
+  void SetVarDataType(const std::string& name,
+                      framework::proto::VarType::Type type) override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not manipulate var in runtime InferVarType"));
+  }
+
+  std::vector<framework::proto::VarType::Type> GetVarDataTypes(
+      const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "GetVarDataTypes is not supported in runtime InferVarType"));
+  }
+
+  void SetVarDataTypes(const std::string& name,
+                       const std::vector<framework::proto::VarType::Type>&
+                           multiple_data_type) override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "SetVarDataTypes is not supported in runtime InferVarType"));
+  }
+
+  std::vector<int64_t> GetVarShape(const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not handle Shape in runtime InferVarType"));
+  }
+
+  void SetVarShape(const std::string& name,
+                   const std::vector<int64_t>& dims) override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not handle Shape in runtime InferVarType"));
+  }
+
+  int32_t GetVarLoDLevel(const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not handle LoDLevel in runtime InferVarType"));
+  }
+
+  void SetVarLoDLevel(const std::string& name, int32_t lod_level) override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not handle LoDLevel in runtime InferVarType"));
+  }
@@ -179,9 +233,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
   const NameVarMap<VarType>& inputs_;
   const NameVarMap<VarType>& outputs_;
   const framework::AttributeMap& attrs_;
-  std::unordered_map<std::string, std::vector<std::string>> input_names_;
-  std::unordered_map<std::string, std::vector<std::string>> output_names_;
-  std::unordered_map<std::string, VarType*> var_set_;
 };
 
 }  // namespace imperative
...
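This file is where the efficiency claim in the title lands: the removed constructor body above walked every input and output variable to populate input_names_, output_names_, and var_set_ on each op invocation, whereas the new context only stores references to the NameVarMaps and resolves slot/index queries against them on demand; name-based access now throws instead. A condensed usage sketch, with variable names borrowed from the test below:

// ins/outs map slot names ("X", "Out") to lists of shared_ptr<VarBase>.
imperative::NameVarBaseMap ins = {{"X", {vin}}};
imperative::NameVarBaseMap outs = {{"Out", {vout}}};
framework::AttributeMap attrs;

// Construction is now cheap: no name->var maps are built here.
imperative::RuntimeInferVarTypeContext<imperative::VarBase> ctx(ins, outs,
                                                                attrs);
ctx.SyncTypeAndDataType("X", "Out");  // vout takes vin's type and dtype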
@@ -37,33 +37,154 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
 using var_pair = std::pair<std::string, vb_vector>;
 
+template <typename VarType>
+class TestRuntimeInferVarTypeContext
+    : public RuntimeInferVarTypeContext<VarType> {
+ public:
+  TestRuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
+                                 const NameVarMap<VarType>& outputs,
+                                 const framework::AttributeMap& attrs_map)
+      : RuntimeInferVarTypeContext<VarType>(inputs, outputs, attrs_map) {}
+
+  bool HasVar(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::HasVar(name);
+  }
+
+  const std::vector<std::string>& InputVars(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::InputVars(name);
+  }
+
+  const std::vector<std::string>& OutputVars(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::OutputVars(name);
+  }
+
+  framework::proto::VarType::Type GetVarType(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarType(name);
+  }
+
+  void SetVarType(const std::string& name,
+                  framework::proto::VarType::Type type) {
+    RuntimeInferVarTypeContext<VarType>::SetVarType(name, type);
+  }
+
+  framework::proto::VarType::Type GetVarDataType(
+      const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarDataType(name);
+  }
+
+  void SetVarDataType(const std::string& name,
+                      framework::proto::VarType::Type type) {
+    RuntimeInferVarTypeContext<VarType>::SetVarDataType(name, type);
+  }
+
+  std::vector<framework::proto::VarType::Type> GetVarDataTypes(
+      const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarDataTypes(name);
+  }
+
+  void SetVarDataTypes(
+      const std::string& name,
+      const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
+    RuntimeInferVarTypeContext<VarType>::SetVarDataTypes(name,
+                                                         multiple_data_type);
+  }
+
+  std::vector<int64_t> GetVarShape(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarShape(name);
+  }
+
+  void SetVarShape(const std::string& name, const std::vector<int64_t>& dims) {
+    RuntimeInferVarTypeContext<VarType>::SetVarShape(name, dims);
+  }
+
+  int32_t GetVarLoDLevel(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarLoDLevel(name);
+  }
+
+  void SetVarLoDLevel(const std::string& name, int32_t lod_level) {
+    RuntimeInferVarTypeContext<VarType>::SetVarLoDLevel(name, lod_level);
+  }
+};
 TEST(test_layer, test_runtime_context) {
   std::shared_ptr<imperative::VarBase> vin(
       new imperative::VarBase(false, "vin"));
+  std::shared_ptr<imperative::VarBase> vin_b(
+      new imperative::VarBase(false, "vin_b"));
   std::shared_ptr<imperative::VarBase> vout(
       new imperative::VarBase(false, "vout"));
-  var_pair in_pair = var_pair("X", vb_vector(1, vin));
-  var_pair out_pair = var_pair("Out", vb_vector(1, vout));
+  std::shared_ptr<imperative::VarBase> vout_b(
+      new imperative::VarBase(false, "vout_b"));
+  var_pair in_pair = var_pair("X", {vin, vin_b});
+  var_pair out_pair = var_pair("Out", {vout, vout_b});
   imperative::NameVarBaseMap ins = {in_pair};
   imperative::NameVarBaseMap outs = {out_pair};
   framework::AttributeMap attrs;
-  auto *ctx = new imperative::RuntimeInferVarTypeContext<imperative::VarBase>(
-      ins, outs, attrs);
+
+  auto* ctx =
+      new imperative::TestRuntimeInferVarTypeContext<imperative::VarBase>(
+          ins, outs, attrs);
 
-  ASSERT_TRUE(ctx->HasVar("vin"));
   ASSERT_TRUE(ctx->HasInput("X"));
   ASSERT_TRUE(ctx->HasOutput("Out"));
-  ASSERT_ANY_THROW(ctx->GetDataTypes("vin"));
+
+  ASSERT_EQ(2u, ctx->InputSize("X"));
+  ASSERT_EQ("vin", ctx->InputVarName("X", 0));
+
+  ASSERT_TRUE(ctx->InputTypeAnyOf("X", framework::proto::VarType::LOD_TENSOR));
+  ASSERT_TRUE(ctx->InputTypeAllOf("X", framework::proto::VarType::LOD_TENSOR));
+
+  ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetInputType("X"));
+  ASSERT_EQ(framework::proto::VarType::FP32, ctx->GetInputDataType("X"));
+
+  ctx->SyncTypeAndDataType("X", "Out");
+  ASSERT_EQ(framework::proto::VarType::FP32, vout->DataType());
+  ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetOutputType("Out"));
+
+  ctx->SetOutputType("Out", framework::proto::VarType::SELECTED_ROWS,
+                     framework::ALL_ELEMENTS);
+  ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR_ARRAY);
+  ASSERT_EQ(framework::proto::VarType::LOD_TENSOR_ARRAY, vout->Type());
+  ASSERT_EQ(framework::proto::VarType::SELECTED_ROWS, vout_b->Type());
+
+  ctx->SetOutputDataType("Out", framework::proto::VarType::FP64,
+                         framework::ALL_ELEMENTS);
+  ctx->SetOutputDataType("Out", framework::proto::VarType::INT8);
+  ASSERT_EQ(framework::proto::VarType::INT8, vout->DataType());
+  ASSERT_EQ(framework::proto::VarType::FP64, vout_b->DataType());
+
+  // no throw, but do nothing
+  ASSERT_NO_THROW(
+      ctx->InsertVar("vout", framework::proto::VarType::LOD_TENSOR));
+  ASSERT_EQ(framework::proto::VarType::LOD_TENSOR_ARRAY, vout->Type());
+
+  ASSERT_ANY_THROW(ctx->HasVar("vin"));
+  ASSERT_ANY_THROW(ctx->InputVars("X"));
+  ASSERT_ANY_THROW(ctx->OutputVars("Out"));
+  ASSERT_ANY_THROW(ctx->GetVarType("vin"));
+  ASSERT_ANY_THROW(
+      ctx->SetVarType("vin", framework::proto::VarType::LOD_TENSOR));
+  ASSERT_ANY_THROW(ctx->GetVarDataType("vin"));
+  ASSERT_ANY_THROW(
+      ctx->SetVarDataType("vout", framework::proto::VarType::FP32));
+  ASSERT_ANY_THROW(ctx->GetVarDataTypes("vin"));
   std::vector<framework::proto::VarType::Type> NullType;
-  ASSERT_ANY_THROW(ctx->SetDataTypes("vin", NullType));
-  ASSERT_ANY_THROW(ctx->GetShape("vin"));
-  ASSERT_ANY_THROW(ctx->GetLoDLevel("vin"));
-  ASSERT_ANY_THROW(ctx->SetLoDLevel("vin", 2));
+  ASSERT_ANY_THROW(ctx->SetVarDataTypes("vin", NullType));
+  ASSERT_ANY_THROW(ctx->GetVarShape("vin"));
+  ASSERT_ANY_THROW(ctx->SetVarShape("vin", {}));
+  ASSERT_ANY_THROW(ctx->GetVarLoDLevel("vin"));
+  ASSERT_ANY_THROW(ctx->SetVarLoDLevel("vin", 2));
+
+  ASSERT_TRUE(ctx->IsDygraph());
 }
-std::string LayerDebugString(const std::string &op_type,
-                             const NameVarBaseMap &ins,
-                             const NameVarBaseMap &outs);
+std::string LayerDebugString(const std::string& op_type,
+                             const NameVarBaseMap& ins,
+                             const NameVarBaseMap& outs);
 TEST(test_layer, test_debug_string) {
   platform::CPUPlace place;
@@ -71,7 +192,7 @@ TEST(test_layer, test_debug_string) {
       new imperative::VarBase(false, "vin"));
   var_pair in_pair = var_pair("X", vb_vector(1, vin));
 
-  auto test_func = [&](std::shared_ptr<imperative::VarBase> &vout) {
+  auto test_func = [&](std::shared_ptr<imperative::VarBase>& vout) {
     var_pair out_pair = var_pair("Out", vb_vector(1, vout));
     imperative::NameVarBaseMap ins = {in_pair};
     imperative::NameVarBaseMap outs = {out_pair};
@@ -124,26 +245,26 @@ TEST(test_layer, test_debug_string) {
 }
 static std::shared_ptr<imperative::GradOpNode> CreateGradNode(
-    size_t id, const std::string &type, const imperative::NameVarBaseMap &ins,
-    const imperative::NameVarBaseMap &outs,
-    const framework::AttributeMap &attrs, const platform::Place &place) {
+    size_t id, const std::string& type, const imperative::NameVarBaseMap& ins,
+    const imperative::NameVarBaseMap& outs,
+    const framework::AttributeMap& attrs, const platform::Place& place) {
   auto node = std::make_shared<imperative::GradOpNode>();
-  auto *op = &(node->emplace_back());
+  auto* op = &(node->emplace_back());
   op->SetId(id);
   op->SetPlace(place);
   op->SetType(type);
   op->SetAttrMap(attrs);
-  for (auto &pair : ins) {
+  for (auto& pair : ins) {
     std::vector<std::shared_ptr<VariableWrapper>> vars;
-    for (auto &var : pair.second) {
+    for (auto& var : pair.second) {
       vars.emplace_back(var->SharedVar());
     }
     op->SetInput(pair.first, vars, false);
   }
-  for (auto &pair : outs) {
+  for (auto& pair : outs) {
     std::vector<std::shared_ptr<VariableWrapper>> vars;
-    for (auto &var : pair.second) {
+    for (auto& var : pair.second) {
      vars.emplace_back(var->SharedVar());
     }
     op->SetOutput(pair.first, vars, false);
@@ -173,7 +294,7 @@ TEST(test_layer, test_clear_backward_info) {
   node->InsertGradPendingNode(pending_node);
   ASSERT_EQ(node->size(), 1UL);
-  auto *op = &(node->back());
+  auto* op = &(node->back());
 
   ASSERT_GT(op->GetInsMap().size(), 0UL);
   ASSERT_GT(op->GetOutsMap().size(), 0UL);
@@ -196,10 +317,10 @@ TEST(test_layer, test_varbase_basic) {
   std::shared_ptr<imperative::VarBase> vin_with_grad(
       new imperative::VarBase(true, "vin"));
   ASSERT_ANY_THROW(vin->MutableGradVar());
-  ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable *>(
+  ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable*>(
                       vin_with_grad->MutableGradVar()) != 0));
-  ASSERT_TRUE(dynamic_cast<framework::Variable *>(
-                  vin_with_grad->MutableGradVar()) != 0);
+  ASSERT_TRUE(
+      dynamic_cast<framework::Variable*>(vin_with_grad->MutableGradVar()) != 0);
   vin_with_grad->SetOverridedStopGradient(false);
   ASSERT_FALSE(vin_with_grad->OverridedStopGradient());
   ASSERT_NO_FATAL_FAILURE(vin_with_grad->SetPersistable(true));
@@ -228,9 +349,9 @@ TEST(test_layer, test_dygraph_execution_context) {
   auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
 
   paddle::platform::CPUPlace cpu_place;
-  paddle::platform::DeviceContextPool &pool =
+  paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
-  auto *dev_ctx = pool.Get(cpu_place);
+  auto* dev_ctx = pool.Get(cpu_place);
   paddle::framework::RuntimeContext ctx({}, {});
   framework::Scope scope;
...
@@ -129,9 +129,10 @@ class ActivationOp : public framework::OperatorWithKernel {
 class ActivationOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
...
@@ -103,8 +103,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
 class AllcloseOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto out_var_name = ctx->Output("Out").front();
-    ctx->SetDataType(out_var_name, framework::proto::VarType::BOOL);
+    ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
   }
 };
...
@@ -60,11 +60,7 @@ class AssignOp : public framework::OperatorWithKernel {
 class AssignInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto out_var_name = ctx->Output("Out")[0];
-    auto input_type = ctx->GetType(ctx->Input("X")[0]);
-    auto input_data_type = ctx->GetDataType(ctx->Input("X")[0]);
-    ctx->SetType(out_var_name, input_type);
-    ctx->SetDataType(out_var_name, input_data_type);
+    ctx->SyncTypeAndDataType("X", "Out");
   }
 };
...
@@ -171,9 +171,10 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
 class BatchNormOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
+    return m;
   }
 };
...
@@ -204,12 +204,10 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    for (auto& o : ctx->Output("SentenceIds")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
-    }
-    for (auto& o : ctx->Output("SentenceScores")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
-    }
+    ctx->SetOutputType("SentenceIds", framework::proto::VarType::LOD_TENSOR,
+                       framework::ALL_ELEMENTS);
+    ctx->SetOutputType("SentenceScores", framework::proto::VarType::LOD_TENSOR,
+                       framework::ALL_ELEMENTS);
   }
 };
...
@@ -122,12 +122,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
 class BeamSearchInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    for (auto &o : ctx->Output("selected_ids")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
-    }
-    for (auto &o : ctx->Output("selected_scores")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
-    }
+    ctx->SetOutputType("selected_ids", framework::proto::VarType::LOD_TENSOR,
+                       framework::ALL_ELEMENTS);
+    ctx->SetOutputType("selected_scores", framework::proto::VarType::LOD_TENSOR,
+                       framework::ALL_ELEMENTS);
   }
 };
...
@@ -92,9 +92,8 @@ execution.
 class GetPlacesInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    for (auto &o_name : ctx->Output("Out")) {
-      ctx->SetType(o_name, framework::proto::VarType::PLACE_LIST);
-    }
+    ctx->SetOutputType("Out", framework::proto::VarType::PLACE_LIST,
+                       framework::ALL_ELEMENTS);
   }
 };
...
@@ -111,15 +111,15 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
   }
 };
 
-class WriteToArrayInferVarType : public framework::VarTypeInference {
+class WriteToArrayInferVarType : public framework::StaticGraphVarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto x_name = ctx->Input("X")[0];
-    auto out_name = ctx->Output("Out")[0];
+    auto x_name = Input(ctx, "X")[0];
+    auto out_name = Output(ctx, "Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
-    ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
-    if (ctx->HasVar(x_name)) {
-      ctx->SetDataType(out_name, ctx->GetDataType(x_name));
+    SetType(ctx, out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
+    if (HasVar(ctx, x_name)) {
+      SetDataType(ctx, out_name, GetDataType(ctx, x_name));
     }
   }
 };
...
@@ -398,18 +398,19 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
-class WhileGradOpVarTypeInference : public framework::VarTypeInference {
+class WhileGradOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto p_names = ctx->Input(kX);
-    auto pg_ig_names = ctx->Output(framework::GradVarName(kX));
+    auto p_names = Input(ctx, kX);
+    auto pg_ig_names = Output(ctx, framework::GradVarName(kX));
 
     for (size_t i = 0; i < p_names.size(); ++i) {
-      if (ctx->HasVar(pg_ig_names[i])) {
+      if (HasVar(ctx, pg_ig_names[i])) {
         VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
-                << " type: " << ctx->GetType(p_names[i]);
-        ctx->SetType(pg_ig_names[i], ctx->GetType(p_names[i]));
-        ctx->SetDataType(pg_ig_names[i], ctx->GetDataType(p_names[i]));
+                << " type: " << GetType(ctx, p_names[i]);
+        SetType(ctx, pg_ig_names[i], GetType(ctx, p_names[i]));
+        SetDataType(ctx, pg_ig_names[i], GetDataType(ctx, p_names[i]));
       }
     }
   }
...
@@ -254,10 +254,11 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
 class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{
-        {"Input", /*->*/ "Output"}};
+    static std::unordered_map<std::string, std::string> m{
+        {"Input", /*->*/ "Output"}};
+    return m;
   }
 };
...
...@@ -177,9 +177,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { ...@@ -177,9 +177,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
class CrossEntropyOpInferVarType class CrossEntropyOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput { : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
} }
}; };
......
...@@ -115,10 +115,8 @@ class MergeIdsOp : public framework::OperatorWithKernel { ...@@ -115,10 +115,8 @@ class MergeIdsOp : public framework::OperatorWithKernel {
class MergeIdsOpInferVarType : public framework::VarTypeInference { class MergeIdsOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]); auto input_type = ctx->GetInputType("Ids");
for (auto &out_var : ctx->Output("Out")) { ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
ctx->SetType(out_var, input_type);
}
} }
}; };
......
...@@ -73,10 +73,8 @@ class SplitIdsOp : public framework::OperatorWithKernel { ...@@ -73,10 +73,8 @@ class SplitIdsOp : public framework::OperatorWithKernel {
class SplitIdsOpInferVarType : public framework::VarTypeInference { class SplitIdsOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]); auto input_type = ctx->GetInputType("Ids");
for (auto &out_var : ctx->Output("Out")) { ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
ctx->SetType(out_var, input_type);
}
} }
}; };
......
...@@ -119,9 +119,10 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -119,9 +119,10 @@ class ElementwiseOp : public framework::OperatorWithKernel {
class ElementwiseOpInferVarType class ElementwiseOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput { : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
} }
}; };
......
...@@ -49,8 +49,7 @@ class EyeOpVarTypeInference : public framework::VarTypeInference { ...@@ -49,8 +49,7 @@ class EyeOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>( auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype"))); boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front(); ctx->SetOutputDataType("Out", data_type);
ctx->SetDataType(out_var_name, data_type);
} }
}; };
......
...@@ -72,14 +72,12 @@ The output will have the same shape and dtype as the input. ...@@ -72,14 +72,12 @@ The output will have the same shape and dtype as the input.
class FillAnyLikeVarTypeInference : public framework::VarTypeInference { class FillAnyLikeVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
auto var_data_type = static_cast<framework::proto::VarType::Type>( auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype"))); boost::get<int>(ctx->GetAttr("dtype")));
if (var_data_type < 0) { if (var_data_type < 0) {
const auto &input_var_name = ctx->Input("X").front(); ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
ctx->SetDataType(out_var_name, ctx->GetDataType(input_var_name));
} else { } else {
ctx->SetDataType(out_var_name, var_data_type); ctx->SetOutputDataType("Out", var_data_type);
} }
} }
}; };
......
...@@ -64,8 +64,7 @@ class FillConstantOpVarTypeInference : public framework::VarTypeInference { ...@@ -64,8 +64,7 @@ class FillConstantOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>( auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype"))); boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front(); ctx->SetOutputDataType("Out", data_type);
ctx->SetDataType(out_var_name, data_type);
} }
}; };
......
...@@ -63,8 +63,7 @@ class FillOpVarTypeInference : public framework::VarTypeInference { ...@@ -63,8 +63,7 @@ class FillOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>( auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype"))); boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front(); ctx->SetOutputDataType("Out", data_type);
ctx->SetDataType(out_var_name, data_type);
} }
}; };
......
...@@ -114,9 +114,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -114,9 +114,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
} }
}; };
......
...@@ -85,9 +85,10 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -85,9 +85,10 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
class FusedBatchNormActOpInferVarType class FusedBatchNormActOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput { : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
} }
}; };
......
...@@ -146,19 +146,20 @@ class FusedEmbeddingSeqPoolOpGradVarTypeInference ...@@ -146,19 +146,20 @@ class FusedEmbeddingSeqPoolOpGradVarTypeInference
: public framework::VarTypeInference { : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front(); auto out_var_name = framework::GradVarName("W");
auto attr = ctx->GetAttr("is_sparse"); auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr); bool is_sparse = boost::get<bool>(attr);
if (is_sparse) { if (is_sparse) {
VLOG(3) << "fused_embedding_seq_pool_grad op " VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to SelectedRows"; << framework::GradVarName("W") << " is set to SelectedRows";
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS); ctx->SetOutputType(out_var_name,
framework::proto::VarType::SELECTED_ROWS);
} else { } else {
VLOG(3) << "fused_embedding_seq_pool_grad op " VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to LoDTensor"; << framework::GradVarName("W") << " is set to LoDTensor";
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR); ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
} }
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0])); ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
} }
}; };
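A subtlety in the hunk above: out_var_name used to hold a concrete variable name (ctx->Output(...).front()), but after the refactor it holds the slot name framework::GradVarName("W"), and the slot-based setters resolve the underlying variable internally. The sparse-versus-dense gradient dispatch then condenses to a few lines; a sketch that keeps the "W" slot but drops the logging:

// Sketch: emit a SELECTED_ROWS gradient when is_sparse, else a dense tensor.
void operator()(framework::InferVarTypeContext* ctx) const override {
  auto out_var_name = framework::GradVarName("W");  // slot name, not var name
  bool is_sparse = boost::get<bool>(ctx->GetAttr("is_sparse"));
  ctx->SetOutputType(out_var_name,
                     is_sparse ? framework::proto::VarType::SELECTED_ROWS
                               : framework::proto::VarType::LOD_TENSOR);
  ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
}

The same shape recurs below in lookup_table_grad, lookup_table_v2_grad, and nce_grad.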
......
...@@ -83,11 +83,8 @@ class GetTensorFromSelectedRowsOpVarTypeInference ...@@ -83,11 +83,8 @@ class GetTensorFromSelectedRowsOpVarTypeInference
: public framework::VarTypeInference { : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const { // NOLINT void operator()(framework::InferVarTypeContext *ctx) const { // NOLINT
auto out_var_name = ctx->Output("Out").front(); ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR);
auto in_var_name = ctx->Input("X").front(); ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
} }
}; };
......
...@@ -216,9 +216,10 @@ DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut, ...@@ -216,9 +216,10 @@ DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
class GroupNormOpInferVarType class GroupNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput { : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override { const override {
return {{"X", /*->*/ "Y"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
} }
}; };
......
...@@ -229,31 +229,30 @@ class HierarchicalSigmoidGradOpGradVarTypeInference ...@@ -229,31 +229,30 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
: public framework::VarTypeInference { : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto w_grad_var_name = ctx->Output(framework::GradVarName("W")).front(); auto w_grad_var_name = framework::GradVarName("W");
auto has_bias_grad_var = ctx->HasOutput(framework::GradVarName("Bias")); auto bias_grad_var_name = framework::GradVarName("Bias");
std::string bias_grad_var_name; if (ctx->HasOutput(bias_grad_var_name)) {
bool hasBias = false; VLOG(3) << "hierarchical_sigmoid_grad op "
if (has_bias_grad_var) { << framework::GradVarName("Bias") << " is set to LoDTensor";
hasBias = true; ctx->SetOutputType(bias_grad_var_name,
bias_grad_var_name = ctx->Output(framework::GradVarName("Bias")).front(); framework::proto::VarType::LOD_TENSOR);
} }
auto attr = ctx->GetAttr("is_sparse"); auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr); bool is_sparse = boost::get<bool>(attr);
if (is_sparse) { if (is_sparse) {
VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows"; << " is set to SelectedRows";
ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS); ctx->SetOutputType(w_grad_var_name,
framework::proto::VarType::SELECTED_ROWS);
} else { } else {
VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor"; << " is set to LoDTensor";
ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR); ctx->SetOutputType(w_grad_var_name,
framework::proto::VarType::LOD_TENSOR);
} }
if (hasBias) {
VLOG(3) << "hierarchical_sigmoid_grad op " ctx->SetOutputDataType(w_grad_var_name, ctx->GetInputDataType("W"));
<< framework::GradVarName("Bias") << " is set to LoDTensor";
ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0]));
} }
}; };
......
...@@ -123,9 +123,10 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -123,9 +123,10 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
class InstanceNormOpInferVarType class InstanceNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput { : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", "Y"}}; static std::unordered_map<std::string, std::string> m{{"X", "Y"}};
return m;
} }
}; };
......
...@@ -65,9 +65,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase { ...@@ -65,9 +65,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class LoDRankTableInferVarType : public framework::VarTypeInference { class LoDRankTableInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o : ctx->Output("Out")) { ctx->SetOutputType("Out", framework::proto::VarType::LOD_RANK_TABLE,
ctx->SetType(o, framework::proto::VarType::LOD_RANK_TABLE); framework::ALL_ELEMENTS);
}
} }
}; };
......
...@@ -76,24 +76,25 @@ class LoDResetOp : public framework::OperatorWithKernel { ...@@ -76,24 +76,25 @@ class LoDResetOp : public framework::OperatorWithKernel {
} }
}; };
class LoDResetOpVarTypeInference : public framework::VarTypeInference { class LoDResetOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_var_name = ctx->Input("X").front(); auto x_var_name = Input(ctx, "X").front();
auto out_var_name = ctx->Output("Out").front(); auto out_var_name = Output(ctx, "Out").front();
bool append = boost::get<bool>(ctx->GetAttr("append")); bool append = boost::get<bool>(ctx->GetAttr("append"));
if (ctx->HasInput("Y")) { if (ctx->HasInput("Y")) {
auto y_var_name = ctx->Input("Y").front(); auto y_var_name = Input(ctx, "Y").front();
auto y_lod_level = std::max(ctx->GetLoDLevel(y_var_name), 1); auto y_lod_level = std::max(GetLoDLevel(ctx, y_var_name), 1);
ctx->SetLoDLevel(out_var_name, y_lod_level); SetLoDLevel(ctx, out_var_name, y_lod_level);
} else if (append) { } else if (append) {
auto x_lod_level = std::max(ctx->GetLoDLevel(x_var_name), 1); auto x_lod_level = std::max(GetLoDLevel(ctx, x_var_name), 1);
ctx->SetLoDLevel(out_var_name, x_lod_level); SetLoDLevel(ctx, out_var_name, x_lod_level);
} else { } else {
ctx->SetLoDLevel(out_var_name, 1); SetLoDLevel(ctx, out_var_name, 1);
} }
ctx->SetDataType(out_var_name, ctx->GetDataType(x_var_name)); SetDataType(ctx, out_var_name, GetDataType(ctx, x_var_name));
ctx->SetType(out_var_name, paddle::framework::proto::VarType::LOD_TENSOR); SetType(ctx, out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
} }
}; };
......
...@@ -221,9 +221,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase { ...@@ -221,9 +221,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
class LoDTensorToArrayInferVarType : public framework::VarTypeInference { class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output("Out")) { ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR_ARRAY,
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY); framework::ALL_ELEMENTS);
}
} }
}; };
......
...@@ -173,19 +173,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -173,19 +173,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front(); auto out_var_name = framework::GradVarName("W");
auto attr = ctx->GetAttr("is_sparse"); auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr); bool is_sparse = boost::get<bool>(attr);
if (is_sparse) { if (is_sparse) {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows"; << " is set to SelectedRows";
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS); ctx->SetOutputType(out_var_name,
framework::proto::VarType::SELECTED_ROWS);
} else { } else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor"; << " is set to LoDTensor";
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR); ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
} }
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0])); ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
} }
}; };
......
...@@ -160,19 +160,20 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel { ...@@ -160,19 +160,20 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel {
class LookupTableV2OpGradVarTypeInference : public framework::VarTypeInference { class LookupTableV2OpGradVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front(); auto out_var_name = framework::GradVarName("W");
auto attr = ctx->GetAttr("is_sparse"); auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr); bool is_sparse = boost::get<bool>(attr);
if (is_sparse) { if (is_sparse) {
VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W") VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows"; << " is set to SelectedRows";
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS); ctx->SetOutputType(out_var_name,
framework::proto::VarType::SELECTED_ROWS);
} else { } else {
VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W") VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor"; << " is set to LoDTensor";
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR); ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
} }
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0])); ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
} }
}; };
......
...@@ -45,9 +45,10 @@ Mean Operator calculates the mean of all elements in X. ...@@ -45,9 +45,10 @@ Mean Operator calculates the mean of all elements in X.
class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
} }
}; };
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/merge_selected_rows_op.h" #include "paddle/fluid/operators/merge_selected_rows_op.h"
#include <unordered_map>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -79,9 +80,10 @@ Example: ...@@ -79,9 +80,10 @@ Example:
class MergeSelectedRowsOpInferVarType class MergeSelectedRowsOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput { : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
} }
}; };
......
...@@ -200,9 +200,10 @@ or not. But the output only shares the LoD information with input $X$. ...@@ -200,9 +200,10 @@ or not. But the output only shares the LoD information with input $X$.
class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
} }
}; };
......
...@@ -61,8 +61,7 @@ class NCCLInitOp : public framework::OperatorBase { ...@@ -61,8 +61,7 @@ class NCCLInitOp : public framework::OperatorBase {
class NCCLInitOpVarTypeInference : public framework::VarTypeInference { class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Communicator").front(); ctx->SetOutputType("Communicator", framework::proto::VarType::RAW);
ctx->SetType(out_var_name, framework::proto::VarType::RAW);
} }
}; };
......
...@@ -280,20 +280,20 @@ class NCEOpGrad : public framework::OperatorWithKernel { ...@@ -280,20 +280,20 @@ class NCEOpGrad : public framework::OperatorWithKernel {
class NCEOpGradVarTypeInference : public framework::VarTypeInference { class NCEOpGradVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto weight_grad = ctx->Output(framework::GradVarName("Weight")).front(); auto weight_grad = framework::GradVarName("Weight");
auto attr = ctx->GetAttr("is_sparse"); auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr); bool is_sparse = boost::get<bool>(attr);
if (is_sparse) { if (is_sparse) {
VLOG(3) << "nce_op_grad op " << weight_grad << " and " VLOG(3) << "nce_op_grad op " << weight_grad << " and "
<< " is set to SelectedRows"; << " is set to SelectedRows";
ctx->SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS); ctx->SetOutputType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
} else { } else {
VLOG(3) << "nce_op_grad op " << weight_grad << " and " VLOG(3) << "nce_op_grad op " << weight_grad << " and "
<< " is set to LoDTensor"; << " is set to LoDTensor";
ctx->SetType(weight_grad, framework::proto::VarType::LOD_TENSOR); ctx->SetOutputType(weight_grad, framework::proto::VarType::LOD_TENSOR);
} }
ctx->SetDataType(weight_grad, ctx->GetDataType(ctx->Input("Input")[0])); ctx->SetOutputDataType(weight_grad, ctx->GetInputDataType("Input"));
} }
}; };
......
...@@ -22,18 +22,15 @@ using Tensor = framework::Tensor; ...@@ -22,18 +22,15 @@ using Tensor = framework::Tensor;
class MomentumOpInferVarType : public framework::VarTypeInference { class MomentumOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto& input_var = ctx->Input("Param")[0]; auto in_var_type = ctx->GetInputType("Param");
for (auto& out_var : ctx->Output("ParamOut")) { PADDLE_ENFORCE_EQ(
if (ctx->GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) { in_var_type == framework::proto::VarType::SELECTED_ROWS ||
ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS); in_var_type == framework::proto::VarType::LOD_TENSOR,
} else if (ctx->GetType(input_var) == true,
framework::proto::VarType::LOD_TENSOR) { platform::errors::InvalidArgument(
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR); "Only support LodTensor and SelectedRows, Unexpected Input Type."));
} else {
PADDLE_THROW( ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
"Only support LodTensor and SelectedRows, Unexpected Input Type.");
}
}
} }
}; };
......
...@@ -75,19 +75,15 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -75,19 +75,15 @@ class SGDOp : public framework::OperatorWithKernel {
class SGDOpInferVarType : public framework::VarTypeInference { class SGDOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto &input_var_n = ctx->Input("Param")[0]; auto in_var_type = ctx->GetInputType("Param");
auto in_var_type = ctx->GetType(input_var_n); PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR, in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows," true, platform::errors::InvalidArgument(
" but the received var(%s)'s type is %s", "The input Var's type should be LoDtensor or "
input_var_n, in_var_type); "SelectedRows, but the received type is %s",
in_var_type));
for (auto &out_var_n : ctx->Output("ParamOut")) { ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
if (ctx->GetType(out_var_n) != in_var_type) {
ctx->SetType(out_var_n, in_var_type);
}
}
} }
}; };
......
...@@ -422,9 +422,10 @@ Example: ...@@ -422,9 +422,10 @@ Example:
class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
} }
}; };
......
...@@ -260,9 +260,7 @@ class PrintOpInferShape : public framework::InferShapeBase { ...@@ -260,9 +260,7 @@ class PrintOpInferShape : public framework::InferShapeBase {
class PrintOpVarTypeInference : public framework::VarTypeInference { class PrintOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("In")[0]); ctx->SetOutputType("Out", ctx->GetInputType("In"));
auto out_name = ctx->Output("Out").front();
ctx->SetType(out_name, input_type);
} }
}; };
......
...@@ -116,12 +116,11 @@ static void CallPythonFunc(py::object *callable, ...@@ -116,12 +116,11 @@ static void CallPythonFunc(py::object *callable,
} }
} }
class PyFuncOpVarTypeInference : public framework::VarTypeInference { class PyFuncOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
bool has_out = (ctx->HasOutput("Out") && !ctx->Output("Out").empty()); bool has_out = ctx->HasOutput("Out");
bool has_in = ctx->HasInput("X");
bool has_in = (ctx->HasInput("X") && !ctx->Input("X").empty());
/** /**
* X or Out can be empty, so that py_func can be more flexible * X or Out can be empty, so that py_func can be more flexible
...@@ -147,7 +146,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { ...@@ -147,7 +146,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
* the corresponding forward variable * the corresponding forward variable
*/ */
const std::string kGradVarSuffix = framework::kGradVarSuffix; const std::string kGradVarSuffix = framework::kGradVarSuffix;
auto &out_var_names = ctx->Output("Out"); auto &out_var_names = Output(ctx, "Out");
for (auto &out_var_name : out_var_names) { for (auto &out_var_name : out_var_names) {
if (out_var_name == framework::kEmptyVarName || if (out_var_name == framework::kEmptyVarName ||
out_var_name.size() < kGradVarSuffix.size()) { out_var_name.size() < kGradVarSuffix.size()) {
...@@ -157,19 +156,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { ...@@ -157,19 +156,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
size_t len = out_var_name.size() - kGradVarSuffix.size(); size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_var_name.substr(len) == kGradVarSuffix) { if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_var_name.substr(0, len); auto fwd_var_name = out_var_name.substr(0, len);
PADDLE_ENFORCE_EQ(ctx->HasVar(out_var_name), true, OP_INOUT_CHECK(HasVar(ctx, out_var_name), "Var", out_var_name,
platform::errors::InvalidArgument( "py_func");
"Backward variable %s not found", out_var_name)); OP_INOUT_CHECK(HasVar(ctx, fwd_var_name), "Var", fwd_var_name,
PADDLE_ENFORCE_EQ(ctx->HasVar(fwd_var_name), true, "py_func");
platform::errors::InvalidArgument(
"Backward variable %s not found", fwd_var_name));
VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input(" VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< fwd_var_name << ")"; << fwd_var_name << ")";
ctx->SetShape(out_var_name, ctx->GetShape(fwd_var_name)); SetShape(ctx, out_var_name, GetShape(ctx, fwd_var_name));
ctx->SetDataType(out_var_name, ctx->GetDataType(fwd_var_name)); SetDataType(ctx, out_var_name, GetDataType(ctx, fwd_var_name));
ctx->SetLoDLevel(out_var_name, ctx->GetLoDLevel(fwd_var_name)); SetLoDLevel(ctx, out_var_name, GetLoDLevel(ctx, fwd_var_name));
ctx->SetType(out_var_name, ctx->GetType(fwd_var_name)); SetType(ctx, out_var_name, GetType(ctx, fwd_var_name));
} }
} }
} }
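OP_INOUT_CHECK, used above in place of two hand-written PADDLE_ENFORCE_EQ blocks, folds the existence assertion and its error message into one macro taking the condition, the kind of entity being checked, the entity's name, and the op type. A one-line sketch (the variable and op names are hypothetical):

// Sketch: assert the variable exists, with the framework's uniform message.
OP_INOUT_CHECK(HasVar(ctx, var_name), "Var", var_name, "my_op");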
......
...@@ -75,8 +75,7 @@ class RandpermOpVarTypeInference : public framework::VarTypeInference { ...@@ -75,8 +75,7 @@ class RandpermOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_data_type = static_cast<framework::proto::VarType::Type>( auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype"))); boost::get<int>(ctx->GetAttr("dtype")));
auto out_var_name = ctx->Output("Out").front(); ctx->SetOutputDataType("Out", var_data_type);
ctx->SetDataType(out_var_name, var_data_type);
} }
}; };
......
...@@ -70,18 +70,18 @@ class ReadInferShape : public framework::InferShapeBase { ...@@ -70,18 +70,18 @@ class ReadInferShape : public framework::InferShapeBase {
} }
}; };
class ReadInferVarType : public framework::VarTypeInference { class ReadInferVarType : public framework::StaticGraphVarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
bool infer_out = boost::get<bool>(ctx->GetAttr("infer_out")); bool infer_out = boost::get<bool>(ctx->GetAttr("infer_out"));
if (infer_out) { if (infer_out) {
std::string reader_name = ctx->Input("Reader")[0]; std::string reader_name = Input(ctx, "Reader")[0];
std::vector<std::string> out_names = ctx->Output("Out"); auto& out_names = Output(ctx, "Out");
auto dtypes = ctx->GetDataTypes(reader_name); auto dtypes = GetDataTypes(ctx, reader_name);
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
for (size_t i = 0; i < dtypes.size(); ++i) { for (size_t i = 0; i < dtypes.size(); ++i) {
ctx->SetType(out_names[i], framework::proto::VarType::LOD_TENSOR); SetType(ctx, out_names[i], framework::proto::VarType::LOD_TENSOR);
ctx->SetDataType(out_names[i], dtypes[i]); SetDataType(ctx, out_names[i], dtypes[i]);
} }
} }
} }
......
...@@ -100,8 +100,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const { ...@@ -100,8 +100,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
void FileReaderInferVarType::operator()( void FileReaderInferVarType::operator()(
framework::InferVarTypeContext* ctx) const { framework::InferVarTypeContext* ctx) const {
std::string reader_name = ctx->Output("Out")[0]; ctx->SetOutputType("Out", framework::proto::VarType::READER);
ctx->SetType(reader_name, framework::proto::VarType::READER);
} }
void DecoratedReaderInferShape::operator()( void DecoratedReaderInferShape::operator()(
...@@ -125,10 +124,8 @@ void DecoratedReaderInferShape::operator()( ...@@ -125,10 +124,8 @@ void DecoratedReaderInferShape::operator()(
void DecoratedReaderInferVarType::operator()( void DecoratedReaderInferVarType::operator()(
framework::InferVarTypeContext* ctx) const { framework::InferVarTypeContext* ctx) const {
const std::string& in_reader_name = ctx->Input("UnderlyingReader")[0]; ctx->SetOutputType("Out", framework::proto::VarType::READER);
const std::string& out_reader_name = ctx->Output("Out")[0]; ctx->SetOutputDataTypes("Out", ctx->GetInputDataTypes("UnderlyingReader"));
ctx->SetType(out_reader_name, framework::proto::VarType::READER);
ctx->SetDataTypes(out_reader_name, ctx->GetDataTypes(in_reader_name));
} }
void DecoratedReaderMakerBase::Make() { void DecoratedReaderMakerBase::Make() {
......
...@@ -58,8 +58,7 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference { ...@@ -58,8 +58,7 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>( auto data_type = static_cast<paddle::framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("out_dtype"))); boost::get<int>(ctx->GetAttr("out_dtype")));
if (data_type >= 0) { if (data_type >= 0) {
auto& out_var_name = ctx->Output("Out").front(); ctx->SetOutputDataType("Out", data_type);
ctx->SetDataType(out_var_name, data_type);
} }
} }
}; };
......
...@@ -85,9 +85,8 @@ to a file on disk. ...@@ -85,9 +85,8 @@ to a file on disk.
class SaveCombineOpInferVarType : public framework::VarTypeInference { class SaveCombineOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
for (auto& o : ctx->Output("Y")) { ctx->SetOutputType("Y", framework::proto::VarType::RAW,
ctx->SetType(o, framework::proto::VarType::RAW); framework::ALL_ELEMENTS);
}
} }
}; };
......
...@@ -73,7 +73,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference { ...@@ -73,7 +73,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_type = framework::proto::VarType::RAW; auto var_type = framework::proto::VarType::RAW;
ctx->SetType(LOOKUP_TABLE_PATH, var_type); ctx->InsertVar(LOOKUP_TABLE_PATH, var_type);
} }
}; };
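save_op is the odd one out: LOOKUP_TABLE_PATH is a raw variable name rather than an operator slot, so the slot-based SetOutputType does not apply. The refactored context instead exposes InsertVar for registering such a variable directly:

// Sketch: register a variable by raw name, outside any input/output slot.
ctx->InsertVar(LOOKUP_TABLE_PATH, framework::proto::VarType::RAW);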
......
...@@ -82,13 +82,7 @@ $$Out = scale*(X + bias)$$ ...@@ -82,13 +82,7 @@ $$Out = scale*(X + bias)$$
class ScaleOpVarTypeInference : public framework::VarTypeInference { class ScaleOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto &in_var_name = ctx->Input("X").front(); ctx->SyncTypeAndDataType("X", "Out");
auto out_var_name = ctx->Output("Out").front();
if (in_var_name != out_var_name) {
ctx->SetType(out_var_name, ctx->GetType(in_var_name));
ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
}
} }
}; };
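SyncTypeAndDataType collapses scale_op's previous bookkeeping — fetch both names, compare, copy type and dtype — into a single call that forwards both properties from the first variable of one slot to the first variable of another, presumably with the same in-place guard the old code wrote by hand. A complete functor in this style is one line:

// Sketch: forward var type and data type from "X" to "Out" in one call.
void operator()(framework::InferVarTypeContext *ctx) const override {
  ctx->SyncTypeAndDataType("X", "Out");
}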
......
...@@ -45,9 +45,10 @@ class SeluOp : public framework::OperatorWithKernel { ...@@ -45,9 +45,10 @@ class SeluOp : public framework::OperatorWithKernel {
class SeluOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class SeluOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
} }
}; };
......
...@@ -145,9 +145,10 @@ For each row $i$ and each column $j$ in the matrix, we have: ...@@ -145,9 +145,10 @@ For each row $i$ and each column $j$ in the matrix, we have:
class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
} }
}; };
......
...@@ -64,9 +64,8 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel { ...@@ -64,9 +64,8 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference { class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output("Out")) { ctx->SetOutputType("Out", framework::proto::VarType::SELECTED_ROWS,
ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS); framework::ALL_ELEMENTS);
}
} }
}; };
......
...@@ -210,43 +210,36 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -210,43 +210,36 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public framework::VarTypeInference { class SumOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext* ctx) const override { void operator()(framework::InferVarTypeContext* ctx) const override {
auto& inputs = ctx->Input("X"); if (!ctx->IsDygraph()) {
auto var_type = framework::proto::VarType::SELECTED_ROWS; auto var_type = framework::proto::VarType::SELECTED_ROWS;
for (auto& name : ctx->Input("X")) { if (VLOG_IS_ON(10)) {
VLOG(10) << name << " " << ctx->GetType(name); for (size_t ind = 0; ind < ctx->InputSize("X"); ++ind) {
VLOG(10) << ctx->InputVarName("X", ind) << " "
<< ctx->GetInputType("X", ind);
}
} }
bool any_input_is_lod_tensor = std::any_of( if (ctx->InputTypeAnyOf("X",
inputs.begin(), inputs.end(), [ctx](const std::string& name) { framework::proto::VarType::LOD_TENSOR_ARRAY)) {
return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR; if (!ctx->InputTypeAllOf("X",
}); framework::proto::VarType::LOD_TENSOR_ARRAY)) {
auto is_tensor_array = [ctx](const std::string& name) {
return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
};
bool any_input_is_tensor_array =
std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
bool all_inputs_are_tensor_array =
std::all_of(inputs.begin(), inputs.end(), is_tensor_array);
if (any_input_is_tensor_array) {
if (!all_inputs_are_tensor_array) {
std::ostringstream os; std::ostringstream os;
for (auto& each : inputs) { for (size_t ind = 0; ind < ctx->InputSize("X"); ++ind) {
os << " " << each << " type is " << ctx->GetType(each) << "\n"; os << " " << ctx->InputVarName("X", ind) << " type is "
<< ctx->GetInputType("X", ind) << "\n";
} }
PADDLE_ENFORCE_EQ(all_inputs_are_tensor_array, true, PADDLE_THROW(platform::errors::InvalidArgument(
"Not all inputs are tensor array:\n%s", os.str()); "Not all inputs are tensor array:\n%s", os.str()));
} }
var_type = framework::proto::VarType::LOD_TENSOR_ARRAY; var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) { } else if (ctx->InputTypeAnyOf("X",
framework::proto::VarType::LOD_TENSOR)) {
var_type = framework::proto::VarType::LOD_TENSOR; var_type = framework::proto::VarType::LOD_TENSOR;
} }
auto out_var_name = ctx->Output("Out").front(); ctx->SetOutputType("Out", var_type);
ctx->SetType(out_var_name, var_type); ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front())); }
} }
}; };
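The sum op exercises most of the new query surface at once: IsDygraph() fences off static-graph-only work, InputSize and InputVarName drive the (now VLOG-guarded) debug loop, and InputTypeAnyOf / InputTypeAllOf test a type predicate across a slot's elements without materializing the name list. Stripped of logging and the tensor-array branch, the core logic reduces to this sketch:

// Sketch: output is SELECTED_ROWS unless some input is a dense LoDTensor.
void operator()(framework::InferVarTypeContext* ctx) const override {
  if (ctx->IsDygraph()) return;  // skip static-graph-only inference
  auto var_type = framework::proto::VarType::SELECTED_ROWS;
  if (ctx->InputTypeAnyOf("X", framework::proto::VarType::LOD_TENSOR)) {
    var_type = framework::proto::VarType::LOD_TENSOR;
  }
  ctx->SetOutputType("Out", var_type);
  ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
}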
......
...@@ -213,9 +213,9 @@ class LoDTensorArray2TensorGradInferVarType ...@@ -213,9 +213,9 @@ class LoDTensorArray2TensorGradInferVarType
: public framework::VarTypeInference { : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output(framework::GradVarName("X"))) { ctx->SetOutputType(framework::GradVarName("X"),
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY); framework::proto::VarType::LOD_TENSOR_ARRAY,
} framework::ALL_ELEMENTS);
} }
}; };
......
...@@ -232,15 +232,13 @@ uniform distribution. The random result is in set [min, max). ...@@ -232,15 +232,13 @@ uniform distribution. The random result is in set [min, max).
class UniformRandomOpVarTypeInference : public framework::VarTypeInference { class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
auto var_data_type = static_cast<framework::proto::VarType::Type>( auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype"))); boost::get<int>(ctx->GetAttr("dtype")));
if (ctx->GetType(out_var_name) != if (ctx->GetOutputType("Out") != framework::proto::VarType::SELECTED_ROWS) {
framework::proto::VarType::SELECTED_ROWS) { ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR);
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
} }
ctx->SetDataType(out_var_name, var_data_type); ctx->SetOutputDataType("Out", var_data_type);
} }
}; };
......