diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc
index 37d22ec566c1927983b7c6b19fca8a965c433213..66507fe7cafbb82e2944dbed9614758cebfedad9 100644
--- a/paddle/fluid/framework/ir/graph_test.cc
+++ b/paddle/fluid/framework/ir/graph_test.cc
@@ -45,19 +45,13 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
   void operator()(InferVarTypeContext *ctx) const override {
-    auto &inputs = ctx->Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;
 
-    bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
-          return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
-        });
-    if (any_input_is_lod_tensor) {
+    if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
 
-    auto out_var_name = ctx->Output("Out").front();
-    ctx->SetType(out_var_name, default_var_type);
+    ctx->SetOutputType("Out", default_var_type);
   }
 };
 
diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h
index 66e6ac81623a1cd1c79981c1e4a97d974e9c2426..9312ac075dec3e240e6fa56a632b50509c0c2632 100644
--- a/paddle/fluid/framework/var_type_inference.h
+++ b/paddle/fluid/framework/var_type_inference.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <algorithm>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -25,8 +26,14 @@ namespace framework {
 
 class OpDesc;
 class BlockDesc;
+class StaticGraphVarTypeInference;
 
 // default infer var type context
+
+static const int ALL_ELEMENTS = -1;
+
 class InferVarTypeContext {
+  friend class StaticGraphVarTypeInference;
+
  public:
   InferVarTypeContext(const OpDesc* op, BlockDesc* block)
       : op_(op), block_(block) {}
@@ -34,91 +41,267 @@ class InferVarTypeContext {
   virtual ~InferVarTypeContext() {}
 
   virtual Attribute GetAttr(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
     return op_->GetAttr(name);
   }
 
-  virtual bool HasVar(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
-    return block_->FindVarRecursive(name) != nullptr;
-  }
-
   virtual bool HasInput(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
     auto& inputs = op_->Inputs();
     auto input = inputs.find(name);
     return input != inputs.end() && !input->second.empty();
   }
 
   virtual bool HasOutput(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
     auto& outputs = op_->Outputs();
     auto output = outputs.find(name);
     return output != outputs.end() && !output->second.empty();
   }
 
-  virtual const std::vector<std::string>& Input(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+  virtual size_t InputSize(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return op_->Inputs().at(name).size();
+  }
+
+  virtual const std::string& InputVarName(const std::string& name,
+                                          const int index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return op_->Inputs().at(name)[index];
+  }
+
+  virtual bool InputTypeAnyOf(const std::string& name,
+                              proto::VarType::Type type) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& inputs = op_->Input(name);
+    return std::any_of(inputs.begin(), inputs.end(),
+                       [this, &type](const std::string& name) {
+                         return this->GetVarType(name) == type;
+                       });
+  }
+
+  virtual bool InputTypeAllOf(const std::string& name,
+                              proto::VarType::Type type) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& inputs = op_->Input(name);
+    return std::all_of(inputs.begin(), inputs.end(),
+                       [this, &type](const std::string& name) {
+                         return this->GetVarType(name) == type;
+                       });
+  }
+
+  virtual void SyncTypeAndDataType(const std::string& input_name,
+                                   const std::string& output_name,
+                                   int index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& x_name = op_->Input(input_name).at(index);
+    auto& out_name = op_->Output(output_name).at(index);
+
+    if (x_name != out_name) {
+      this->SetVarType(out_name, this->GetVarType(x_name));
+      this->SetVarDataType(out_name, this->GetVarDataType(x_name));
+    }
+  }
+
+  virtual void SetOutputType(const std::string& name, proto::VarType::Type type,
+                             int index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    if (ALL_ELEMENTS == index) {
+      for (const auto& var_name : op_->Output(name)) {
+        this->SetVarType(var_name, type);
+      }
+    } else {
+      auto& var_name = op_->Output(name).at(index);
+      this->SetVarType(var_name, type);
+    }
+  }
+
+  virtual proto::VarType::Type GetInputType(const std::string& name,
+                                            const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return this->GetVarType(op_->Input(name).at(index));
+  }
+
+  virtual proto::VarType::Type GetOutputType(const std::string& name,
+                                             const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return this->GetVarType(op_->Output(name).at(index));
+  }
+
+  virtual proto::VarType::Type GetInputDataType(const std::string& name,
+                                                const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return this->GetVarDataType(op_->Input(name).at(index));
+  }
+
+  virtual void SetOutputDataType(const std::string& name,
+                                 proto::VarType::Type type, int index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    if (ALL_ELEMENTS == index) {
+      for (const auto& var_name : op_->Output(name)) {
+        this->SetVarDataType(var_name, type);
+      }
+    } else {
+      auto& var_name = op_->Output(name).at(index);
+      this->SetVarDataType(var_name, type);
+    }
+  }
+
+  virtual std::vector<proto::VarType::Type> GetInputDataTypes(
+      const std::string& name, const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    return this->GetVarDataTypes(op_->Input(name).at(index));
+  }
+
+  virtual void SetOutputDataTypes(
+      const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type,
+      const int& index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Output(name).at(index);
+    this->SetVarDataTypes(var_name, multiple_data_type);
+  }
+
+  virtual std::vector<int64_t> GetInputShape(const std::string& name,
+                                             const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Input(name).at(index);
+    return this->GetVarShape(var_name);
+  }
+
+  virtual void SetOutputShape(const std::string& name,
+                              const std::vector<int64_t>& dims,
+                              const int& index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Output(name).at(index);
+    this->SetVarShape(var_name, dims);
+  }
+
+  virtual int32_t GetInputLoDLevel(const std::string& name,
+                                   const int& index = 0) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Input(name).at(index);
+    return this->GetVarLoDLevel(var_name);
+  }
+
+  virtual void SetOutputLoDLevel(const std::string& name, int32_t lod_level,
+                                 const int& index = 0) {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
+    auto& var_name = op_->Output(name).at(index);
+    this->SetVarLoDLevel(var_name, lod_level);
+  }
+
+  // add a special API for save_op
+  // avoid using this API for common logic
+  virtual void InsertVar(const std::string& var_name,
+                         proto::VarType::Type var_type) {
+    if (!IsDygraph()) this->SetVarType(var_name, var_type);
+  }
+
+  virtual bool IsDygraph() const { return false; }
+
+ protected:
+  virtual bool HasVar(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
+    return block_->FindVarRecursive(name) != nullptr;
+  }
+
+  virtual const std::vector<std::string>& InputVars(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
     return op_->Input(name);
   }
 
-  virtual const std::vector<std::string>& Output(
+  virtual const std::vector<std::string>& OutputVars(
       const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(op_);
+    PADDLE_ENFORCE_NOT_NULL(
+        op_, platform::errors::PreconditionNotMet("op_ should not be null"));
     return op_->Output(name);
   }
 
-  virtual proto::VarType::Type GetType(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual proto::VarType::Type GetVarType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetType();
   }
 
-  virtual void SetType(const std::string& name, proto::VarType::Type type) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual void SetVarType(const std::string& name, proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetType(type);
   }
 
-  virtual proto::VarType::Type GetDataType(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual proto::VarType::Type GetVarDataType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetDataType();
   }
 
-  virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual void SetVarDataType(const std::string& name,
+                              proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetDataType(type);
   }
 
-  virtual std::vector<proto::VarType::Type> GetDataTypes(
+  virtual std::vector<proto::VarType::Type> GetVarDataTypes(
       const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
   }
 
-  virtual void SetDataTypes(
+  virtual void SetVarDataTypes(
       const std::string& name,
       const std::vector<proto::VarType::Type>& multiple_data_type) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
   }
 
-  virtual std::vector<int64_t> GetShape(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual std::vector<int64_t> GetVarShape(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetShape();
   }
 
-  virtual void SetShape(const std::string& name,
-                        const std::vector<int64_t>& dims) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual void SetVarShape(const std::string& name,
+                           const std::vector<int64_t>& dims) {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetShape(dims);
   }
 
-  virtual int32_t GetLoDLevel(const std::string& name) const {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual int32_t GetVarLoDLevel(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
   }
 
-  virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
-    PADDLE_ENFORCE_NOT_NULL(block_);
+  virtual void SetVarLoDLevel(const std::string& name, int32_t lod_level) {
+    PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
+                                        "block_ should not be null"));
     block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
   }
 
@@ -133,22 +316,85 @@ class VarTypeInference {
   virtual void operator()(InferVarTypeContext* context) const = 0;  // NOLINT
 };
 
+class StaticGraphVarTypeInference : public VarTypeInference {
+ protected:
+  bool HasVar(InferVarTypeContext* ctx, const std::string& name) const {
+    return ctx->HasVar(name);
+  }
+
+  const std::vector<std::string>& Input(InferVarTypeContext* ctx,
+                                        const std::string& name) const {
+    return ctx->InputVars(name);
+  }
+
+  const std::vector<std::string>& Output(InferVarTypeContext* ctx,
+                                         const std::string& name) const {
+    return ctx->OutputVars(name);
+  }
+
+  proto::VarType::Type GetType(InferVarTypeContext* ctx,
+                               const std::string& name) const {
+    return ctx->GetVarType(name);
+  }
+
+  void SetType(InferVarTypeContext* ctx, const std::string& name,
+               proto::VarType::Type type) const {
+    ctx->SetVarType(name, type);
+  }
+
+  proto::VarType::Type GetDataType(InferVarTypeContext* ctx,
+                                   const std::string& name) const {
+    return ctx->GetVarDataType(name);
+  }
+
+  void SetDataType(InferVarTypeContext* ctx, const std::string& name,
+                   proto::VarType::Type type) const {
+    ctx->SetVarDataType(name, type);
+  }
+
+  std::vector<proto::VarType::Type> GetDataTypes(
+      InferVarTypeContext* ctx, const std::string& name) const {
+    return ctx->GetVarDataTypes(name);
+  }
+
+  void SetDataTypes(
+      InferVarTypeContext* ctx, const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type) {
+    return ctx->SetVarDataTypes(name, multiple_data_type);
+  }
+
+  std::vector<int64_t> GetShape(InferVarTypeContext* ctx,
+                                const std::string& name) const {
+    return ctx->GetVarShape(name);
+  }
+
+  void SetShape(InferVarTypeContext* ctx, const std::string& name,
+                const std::vector<int64_t>& dims) const {
+    ctx->SetVarShape(name, dims);
+  }
+
+  int32_t GetLoDLevel(InferVarTypeContext* ctx,
+                      const std::string& name) const {
+    return ctx->GetVarLoDLevel(name);
+  }
+
+  void SetLoDLevel(InferVarTypeContext* ctx, const std::string& name,
+                   int32_t lod_level) const {
+    ctx->SetVarLoDLevel(name, lod_level);
+  }
+};
+
 class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const final {  // NOLINT
-    auto in_out_var_names = this->GetInputOutputWithSameType();
+    auto& in_out_var_names = this->GetInputOutputWithSameType();
     for (auto& i_o_n : in_out_var_names) {
-      auto& x_name = ctx->Input(i_o_n.first).at(0);
-      auto& out_name = ctx->Output(i_o_n.second).at(0);
-
-      ctx->SetType(out_name, ctx->GetType(x_name));
-      ctx->SetDataType(out_name, ctx->GetDataType(x_name));
+      ctx->SyncTypeAndDataType(i_o_n.first, i_o_n.second);
     }
   }
 
  protected:
-  virtual std::unordered_map<std::string, std::string>
+  virtual std::unordered_map<std::string, std::string>&
   GetInputOutputWithSameType() const = 0;
 };
 
diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/paddle/fluid/framework/var_type_inference_test.cc
index 6bbb25a573d076d5ec6d6fd960a304639e9e3d49..dc86d58f600b83a8ed59f22d9cd73fac7fab13b3 100644
--- a/paddle/fluid/framework/var_type_inference_test.cc
+++ b/paddle/fluid/framework/var_type_inference_test.cc
@@ -24,13 +24,13 @@ namespace framework {
 
 class NOP : public OperatorBase {
  public:
-  NOP(const std::string &type, const VariableNameMap &inputs,
-      const VariableNameMap &outputs, const AttributeMap &attrs)
+  NOP(const std::string& type, const VariableNameMap& inputs,
+      const VariableNameMap& outputs, const AttributeMap& attrs)
       : OperatorBase(type, inputs, outputs, attrs) {}
 
  private:
-  void RunImpl(const Scope &scope,
-               const platform::Place &place) const override {}
+  void RunImpl(const Scope& scope,
+               const platform::Place& place) const override {}
 };
 
 class SumOpMaker : public OpProtoAndCheckerMaker {
@@ -44,20 +44,14 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto &inputs = ctx->Input("X");
+  void operator()(framework::InferVarTypeContext* ctx) const override {
     auto default_var_type = proto::VarType::SELECTED_ROWS;
 
-    bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
-          return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
-        });
-    if (any_input_is_lod_tensor) {
+    if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
 
-    auto out_var_name = ctx->Output("Out").front();
-    ctx->SetType(out_var_name, default_var_type);
+    ctx->SetOutputType("Out", default_var_type);
   }
 };
 }  // namespace framework
@@ -71,9 +65,79 @@ REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
 
 namespace paddle {
 namespace framework {
 
+class TestStaticGraphVarTypeInference : public StaticGraphVarTypeInference {
+ public:
+  void operator()(InferVarTypeContext* context) const override {}
+
+  bool HasVar(InferVarTypeContext* ctx, const std::string& name) const {
+    return StaticGraphVarTypeInference::HasVar(ctx, name);
+  }
+
+  const std::vector<std::string>& Input(InferVarTypeContext* ctx,
+                                        const std::string& name) const {
+    return StaticGraphVarTypeInference::Input(ctx, name);
+  }
+
+  const std::vector<std::string>& Output(InferVarTypeContext* ctx,
+                                         const std::string& name) const {
+    return StaticGraphVarTypeInference::Output(ctx, name);
+  }
+
+  proto::VarType::Type GetType(InferVarTypeContext* ctx,
+                               const std::string& name) const {
+    return StaticGraphVarTypeInference::GetType(ctx, name);
+  }
+
+  void SetType(InferVarTypeContext* ctx, const std::string& name,
+               proto::VarType::Type type) const {
+    StaticGraphVarTypeInference::SetType(ctx, name, type);
+  }
+
+  proto::VarType::Type GetDataType(InferVarTypeContext* ctx,
+                                   const std::string& name) const {
+    return StaticGraphVarTypeInference::GetDataType(ctx, name);
+  }
+
+  void SetDataType(InferVarTypeContext* ctx, const std::string& name,
+                   proto::VarType::Type type) const {
+    StaticGraphVarTypeInference::SetDataType(ctx, name, type);
+  }
+
+  std::vector<proto::VarType::Type> GetDataTypes(
+      InferVarTypeContext* ctx, const std::string& name) const {
+    return StaticGraphVarTypeInference::GetDataTypes(ctx, name);
+  }
+
+  void SetDataTypes(
+      InferVarTypeContext* ctx, const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type) {
+    return StaticGraphVarTypeInference::SetDataTypes(ctx, name,
+                                                     multiple_data_type);
+  }
+
+  std::vector<int64_t> GetShape(InferVarTypeContext* ctx,
+                                const std::string& name) const {
+    return StaticGraphVarTypeInference::GetShape(ctx, name);
+  }
+
+  void SetShape(InferVarTypeContext* ctx, const std::string& name,
+                const std::vector<int64_t>& dims) const {
+    StaticGraphVarTypeInference::SetShape(ctx, name, dims);
+  }
+
+  int32_t GetLoDLevel(InferVarTypeContext* ctx,
+                      const std::string& name) const {
+    return StaticGraphVarTypeInference::GetLoDLevel(ctx, name);
+  }
+
+  void SetLoDLevel(InferVarTypeContext* ctx, const std::string& name,
+                   int32_t lod_level) const {
+    StaticGraphVarTypeInference::SetLoDLevel(ctx, name, lod_level);
+  }
+};
+
 TEST(InferVarType, sum_op) {
   ProgramDesc prog;
-  auto *op = prog.MutableBlock(0)->AppendOp();
+  auto* op = prog.MutableBlock(0)->AppendOp();
   op->SetType("sum");
   op->SetInput("X", {"test_a", "test_b", "test_c"});
   op->SetOutput("Out", {"test_out"});
@@ -96,7 +160,7 @@ TEST(InferVarType, sum_op) {
 
 TEST(InferVarType, sum_op_without_infer_var_type) {
   ProgramDesc prog;
-  auto *op = prog.MutableBlock(0)->AppendOp();
+  auto* op = prog.MutableBlock(0)->AppendOp();
   op->SetType("sum_without_infer_var_type");
   op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
   op->SetOutput("Out", {"test2_out"});
@@ -112,5 +176,112 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
 
+TEST(InferVarType, multiple_api) {
+  ProgramDesc prog;
+
+  auto* block = prog.MutableBlock(0);
+  auto* op = block->AppendOp();
+  op->SetType("sum_without_infer_var_type");
+  op->SetInput("X", {"test2_a", "test2_b"});
+  op->SetOutput("Out", {"test2_a_out", "test2_b_out"});
+
+  block->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
+  block->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
+  block->Var("test2_a_out");
+  block->Var("test2_b_out");
+
+  InferVarTypeContext ctx(op, block);
+
+  ASSERT_TRUE(ctx.HasInput("X"));
+  ASSERT_TRUE(ctx.HasOutput("Out"));
+
+  ASSERT_EQ(2u, ctx.InputSize("X"));
+  ASSERT_EQ("test2_a", ctx.InputVarName("X", 0));
+
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetInputType("X"));
+
+  ASSERT_TRUE(ctx.InputTypeAllOf("X", proto::VarType::SELECTED_ROWS));
+  ASSERT_FALSE(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
+
+  ctx.SyncTypeAndDataType("X", "Out");
+
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
+  ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
+
+  ctx.SetOutputType("Out", proto::VarType::SELECTED_ROWS, ALL_ELEMENTS);
+  ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR, 1);
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
+  ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
+
+  ASSERT_EQ(0, ctx.GetInputDataType("X"));
+
+  ctx.SetOutputDataType("Out", proto::VarType::FP32, ALL_ELEMENTS);
+  ctx.SetOutputDataType("Out", proto::VarType::INT8, 1);
+  ASSERT_EQ(proto::VarType::FP32,
+            prog.MutableBlock(0)->Var("test2_a_out")->GetDataType());
+  ASSERT_EQ(proto::VarType::INT8,
+            prog.MutableBlock(0)->Var("test2_b_out")->GetDataType());
+
+  ASSERT_FALSE(ctx.IsDygraph());
+
+  // test StaticGraphVarTypeInference
+  TestStaticGraphVarTypeInference infer;
+  ASSERT_TRUE(infer.HasVar(&ctx, "test2_a"));
+  ASSERT_EQ(infer.Input(&ctx, "X").size(), infer.Output(&ctx, "Out").size());
+
+  ASSERT_EQ(proto::VarType::FP32, infer.GetDataType(&ctx, "test2_a_out"));
+  infer.SetDataType(&ctx, "test2_a_out", proto::VarType::FP64);
+  ASSERT_EQ(proto::VarType::FP64, infer.GetDataType(&ctx, "test2_a_out"));
+
+  ASSERT_EQ(proto::VarType::SELECTED_ROWS, infer.GetType(&ctx, "test2_a_out"));
+  infer.SetType(&ctx, "test2_a_out", proto::VarType::LOD_TENSOR);
+  ASSERT_EQ(proto::VarType::LOD_TENSOR, infer.GetType(&ctx, "test2_a_out"));
+
+  ASSERT_ANY_THROW(infer.GetDataTypes(&ctx, "test2_a_out"));
+  ASSERT_ANY_THROW(infer.SetDataTypes(&ctx, "test2_a_out", {}));
+
+  ASSERT_EQ(0u, infer.GetShape(&ctx, "test2_a_out").size());
+  infer.SetShape(&ctx, "test2_a_out", {
+                                          1, 3, 3,
+                                      });
+  ASSERT_EQ(3u, infer.GetShape(&ctx, "test2_a_out").size());
+
+  ASSERT_EQ(0, infer.GetLoDLevel(&ctx, "test2_a_out"));
+  infer.SetLoDLevel(&ctx, "test2_a_out", 2);
+  ASSERT_EQ(2, infer.GetLoDLevel(&ctx, "test2_a_out"));
+}
+
+TEST(InferVarType, test_enforce_check) {
+  InferVarTypeContext ctx(nullptr, nullptr);
+  ASSERT_ANY_THROW(ctx.HasInput("X"));
+  ASSERT_ANY_THROW(ctx.HasOutput("Out"));
+
+  ASSERT_ANY_THROW(ctx.InputSize("X"));
+  ASSERT_ANY_THROW(ctx.InputVarName("X"));
+
+  ASSERT_ANY_THROW(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
+  ASSERT_ANY_THROW(ctx.InputTypeAllOf("X", proto::VarType::LOD_TENSOR));
+
+  ASSERT_ANY_THROW(ctx.SyncTypeAndDataType("X", "Out"));
+
+  ASSERT_ANY_THROW(ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR));
+  ASSERT_ANY_THROW(ctx.GetInputType("X"));
+  ASSERT_ANY_THROW(ctx.GetOutputType("Out"));
+
+  ASSERT_ANY_THROW(ctx.GetInputDataType("X"));
+  ASSERT_ANY_THROW(ctx.SetOutputDataType("Out", proto::VarType::LOD_TENSOR));
+
+  ASSERT_ANY_THROW(ctx.GetInputDataTypes("X"));
+  ASSERT_ANY_THROW(ctx.SetOutputDataTypes("Out", {}));
+
+  ASSERT_ANY_THROW(ctx.GetInputShape("X"));
+  ASSERT_ANY_THROW(ctx.SetOutputShape("Out", {}));
+
+  ASSERT_ANY_THROW(ctx.GetInputLoDLevel("X"));
+  ASSERT_ANY_THROW(ctx.SetOutputLoDLevel("Out", 1));
+
+  ASSERT_ANY_THROW(ctx.InsertVar("var", proto::VarType::LOD_TENSOR));
+}
+
 }  // namespace framework
 }  // namespace paddle
 
diff --git a/paddle/fluid/imperative/infer_var_type_context.h b/paddle/fluid/imperative/infer_var_type_context.h
index e46ac315d2d9cadbcf4c37db60a271261c1ed8e3..f740507fa508600fd268c8b80e5850497b07ea3d 100644
--- a/paddle/fluid/imperative/infer_var_type_context.h
+++ b/paddle/fluid/imperative/infer_var_type_context.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <algorithm>
 #include <string>
 #include <unordered_map>
 #include <utility>
@@ -35,30 +36,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
       : InferVarTypeContext(nullptr, nullptr),
         inputs_(inputs),
         outputs_(outputs),
-        attrs_(attrs_map),
-        input_names_(),
-        output_names_(),
-        var_set_() {
-    input_names_.reserve(inputs_.size());
-    for (auto& it : inputs_) {
-      for (auto& var : it.second) {
-        if (var) {
-          input_names_[it.first].emplace_back(var->Name());
-          var_set_[var->Name()] = var.get();
-        }
-      }
-    }
-
-    output_names_.reserve(outputs_.size());
-    for (auto& it : outputs_) {
-      for (auto& var : it.second) {
-        if (var) {
-          output_names_[it.first].emplace_back(var->Name());
-          var_set_[var->Name()] = var.get();
-        }
-      }
-    }
-  }
+        attrs_(attrs_map) {}
 
   virtual ~RuntimeInferVarTypeContext() {}
 
@@ -70,10 +48,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
     return iter->second;
   }
 
-  bool HasVar(const std::string& name) const override {
-    return var_set_.count(name) > 0;
-  }
-
   bool HasInput(const std::string& name) const override {
     auto it = inputs_.find(name);
     return (it != inputs_.end() && it->second.size() > 0);
@@ -84,93 +58,173 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
     return (it != outputs_.end() && it->second.size() > 0);
   }
 
-  const std::vector<std::string>& Input(
-      const std::string& name) const override {
-    auto iter = input_names_.find(name);
-    PADDLE_ENFORCE_EQ(
-        iter != input_names_.end(), true,
-        platform::errors::NotFound("Cannot find input var %s", name));
-    return iter->second;
+  size_t InputSize(const std::string& name) const {
+    return inputs_.at(name).size();
   }
 
-  const std::vector<std::string>& Output(
-      const std::string& name) const override {
-    auto iter = output_names_.find(name);
+  const std::string& InputVarName(const std::string& name,
+                                  const int index = 0) const {
+    return inputs_.at(name)[index]->Name();
+  }
 
-    PADDLE_ENFORCE_EQ(
-        iter != output_names_.end(), true,
-        platform::errors::NotFound("Cannot find output var %s", name));
-    return iter->second;
+  bool InputTypeAnyOf(const std::string& name,
+                      framework::proto::VarType::Type type) const override {
+    auto& inputs = inputs_.at(name);
+    return std::any_of(inputs.begin(), inputs.end(),
+                       [&type](const std::shared_ptr<VarType>& var) {
+                         return var->Type() == type;
+                       });
   }
 
-  framework::proto::VarType::Type GetType(
-      const std::string& name) const override {
-    auto iter = var_set_.find(name);
+  bool InputTypeAllOf(const std::string& name,
+                      framework::proto::VarType::Type type) const override {
+    auto& inputs = inputs_.at(name);
+    return std::all_of(inputs.begin(), inputs.end(),
+                       [&type](const std::shared_ptr<VarType>& var) {
+                         return var->Type() == type;
+                       });
+  }
 
-    PADDLE_ENFORCE_EQ(
-        iter != var_set_.end(), true,
-        platform::errors::NotFound("Cannot find var %s in GetType", name));
-    return iter->second->Type();
+  void SyncTypeAndDataType(const std::string& input_name,
+                           const std::string& output_name,
+                           int index = 0) override {
+    auto in_var = inputs_.at(input_name)[index];
+    auto out_var = outputs_.at(output_name)[index];
+    if (in_var != out_var) {
+      this->SetVarBaseType(out_var, in_var->Type());
+      this->SetVarBaseDataType(out_var, in_var->DataType());
+    }
   }
 
-  void SetType(const std::string& name,
-               framework::proto::VarType::Type type) override {
-    if (name == "kLookupTablePath") {
-      VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
+  void SetOutputType(const std::string& name,
+                     framework::proto::VarType::Type type,
+                     int index = 0) override {
+    if (index == framework::ALL_ELEMENTS) {
+      for (auto& item : outputs_.at(name)) {
+        this->SetVarBaseType(item, type);
+      }
     } else {
-      var_set_[name]->SetType(type);
-      if ((var_set_[name]->MutableVar()->IsInitialized() == true) &&
-          (var_set_[name]->MutableVar()->Type() != type)) {
-        var_set_[name]->MutableVar()->Clear();
+      auto& var = outputs_.at(name)[index];
+      this->SetVarBaseType(var, type);
+    }
+  }
+
+  void SetVarBaseType(std::shared_ptr<VarType> out,
+                      framework::proto::VarType::Type type) {
+    out->SetType(type);
+    if ((out->MutableVar()->IsInitialized() == true) &&
+        (out->MutableVar()->Type() != type)) {
+      out->MutableVar()->Clear();
+    }
+  }
+
+  void SetVarBaseDataType(std::shared_ptr<VarType> out,
+                          framework::proto::VarType::Type type) {
+    out->SetDataType(type);
+  }
+
+  framework::proto::VarType::Type GetInputType(
+      const std::string& name, const int& index = 0) const override {
+    return inputs_.at(name)[index]->Type();
+  }
+
+  framework::proto::VarType::Type GetOutputType(
+      const std::string& name, const int& index = 0) const override {
+    return outputs_.at(name)[index]->Type();
+  }
+
+  framework::proto::VarType::Type GetInputDataType(
+      const std::string& name, const int& index = 0) const override {
+    return inputs_.at(name)[index]->DataType();
+  }
+
+  void SetOutputDataType(const std::string& name,
+                         framework::proto::VarType::Type type,
+                         int index = 0) override {
+    if (framework::ALL_ELEMENTS == index) {
+      for (auto& item : outputs_.at(name)) {
+        this->SetVarBaseDataType(item, type);
       }
+    } else {
+      auto& var = outputs_.at(name)[index];
+      this->SetVarBaseDataType(var, type);
     }
   }
 
-  framework::proto::VarType::Type GetDataType(
+  bool IsDygraph() const override { return true; }
+
+ protected:
+  bool HasVar(const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "HasVar is not supported in runtime InferVarType"));
+  }
+
+  const std::vector<std::string>& InputVars(
      const std::string& name) const override {
-    auto iter = var_set_.find(name);
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "InputVars is not supported in runtime InferVarType"));
+  }
 
-    PADDLE_ENFORCE_EQ(
-        iter != var_set_.end(), true,
-        platform::errors::NotFound("Cannot find var %s in GetDataType", name));
-    return iter->second->DataType();
+  const std::vector<std::string>& OutputVars(
+      const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "OutputVars is not supported in runtime InferVarType"));
+  }
+
+  framework::proto::VarType::Type GetVarType(
+      const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not manipulate var in runtime InferVarType"));
+  }
+
+  void SetVarType(const std::string& name,
+                  framework::proto::VarType::Type type) override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not manipulate var in runtime InferVarType"));
   }
 
-  void SetDataType(const std::string& name,
-                   framework::proto::VarType::Type type) override {
-    var_set_[name]->SetDataType(type);
+  framework::proto::VarType::Type GetVarDataType(
+      const std::string& name) const override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not manipulate var in runtime InferVarType"));
+  }
+
+  void SetVarDataType(const std::string& name,
+                      framework::proto::VarType::Type type) override {
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Do not manipulate var in runtime InferVarType"));
   }
 
-  std::vector<framework::proto::VarType::Type> GetDataTypes(
+  std::vector<framework::proto::VarType::Type> GetVarDataTypes(
      const std::string& name) const override {
     PADDLE_THROW(platform::errors::PermissionDenied(
-        "GetDataTypes is not supported in runtime InferVarType"));
+        "GetVarDataTypes is not supported in runtime InferVarType"));
   }
 
-  void SetDataTypes(const std::string& name,
-                    const std::vector<framework::proto::VarType::Type>&
-                        multiple_data_type) override {
+  void SetVarDataTypes(const std::string& name,
+                       const std::vector<framework::proto::VarType::Type>&
+                           multiple_data_type) override {
     PADDLE_THROW(platform::errors::PermissionDenied(
-        "SetDataTypes is not supported in runtime InferVarType"));
+        "SetVarDataTypes is not supported in runtime InferVarType"));
   }
 
-  std::vector<int64_t> GetShape(const std::string& name) const override {
+  std::vector<int64_t> GetVarShape(const std::string& name) const override {
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Do not handle Shape in runtime InferVarType"));
   }
 
-  void SetShape(const std::string& name,
-                const std::vector<int64_t>& dims) override {
+  void SetVarShape(const std::string& name,
+                   const std::vector<int64_t>& dims) override {
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Do not handle Shape in runtime InferVarType"));
   }
 
-  int32_t GetLoDLevel(const std::string& name) const override {
+  int32_t GetVarLoDLevel(const std::string& name) const override {
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Do not handle LoDLevel in runtime InferVarType"));
   }
 
-  void SetLoDLevel(const std::string& name, int32_t lod_level) override {
+  void SetVarLoDLevel(const std::string& name, int32_t lod_level) override {
     PADDLE_THROW(platform::errors::PermissionDenied(
         "Do not handle LoDLevel in runtime InferVarType"));
   }
@@ -179,9 +233,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
   const NameVarMap<VarType>& inputs_;
   const NameVarMap<VarType>& outputs_;
   const framework::AttributeMap& attrs_;
-  std::unordered_map<std::string, std::vector<std::string>> input_names_;
-  std::unordered_map<std::string, std::vector<std::string>> output_names_;
-  std::unordered_map<std::string, VarType*> var_set_;
 };
 
 }  // namespace imperative
 
diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc
index 36c448402d7325d3e71587853ec8f3b631c61e6b..a28916b59c3d299a37d6e4507838d1094fb8127b 100644
--- a/paddle/fluid/imperative/tests/test_layer.cc
+++ b/paddle/fluid/imperative/tests/test_layer.cc
@@ -37,33 +37,154 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
 
 using var_pair = std::pair<std::string, vb_vector>;
 
+template <typename VarType>
+class TestRuntimeInferVarTypeContext
+    : public RuntimeInferVarTypeContext<VarType> {
+ public:
+  TestRuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
+                                 const NameVarMap<VarType>& outputs,
+                                 const framework::AttributeMap& attrs_map)
+      : RuntimeInferVarTypeContext<VarType>(inputs, outputs, attrs_map) {}
+
+  bool HasVar(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::HasVar(name);
+  }
+
+  const std::vector<std::string>& InputVars(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::InputVars(name);
+  }
+
+  const std::vector<std::string>& OutputVars(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::OutputVars(name);
+  }
+
+  framework::proto::VarType::Type GetVarType(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarType(name);
+  }
+
+  void SetVarType(const std::string& name,
+                  framework::proto::VarType::Type type) {
+    RuntimeInferVarTypeContext<VarType>::SetVarType(name, type);
+  }
+
+  framework::proto::VarType::Type GetVarDataType(
+      const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarDataType(name);
+  }
+
+  void SetVarDataType(const std::string& name,
+                      framework::proto::VarType::Type type) {
+    RuntimeInferVarTypeContext<VarType>::SetVarDataType(name, type);
+  }
+
+  std::vector<framework::proto::VarType::Type> GetVarDataTypes(
+      const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarDataTypes(name);
+  }
+
+  void SetVarDataTypes(
+      const std::string& name,
+      const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
+    RuntimeInferVarTypeContext<VarType>::SetVarDataTypes(name,
+                                                         multiple_data_type);
+  }
+
+  std::vector<int64_t> GetVarShape(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarShape(name);
+  }
+
+  void SetVarShape(const std::string& name, const std::vector<int64_t>& dims) {
+    RuntimeInferVarTypeContext<VarType>::SetVarShape(name, dims);
+  }
+
+  int32_t GetVarLoDLevel(const std::string& name) const {
+    return RuntimeInferVarTypeContext<VarType>::GetVarLoDLevel(name);
+  }
+
+  void SetVarLoDLevel(const std::string& name, int32_t lod_level) {
+    RuntimeInferVarTypeContext<VarType>::SetVarLoDLevel(name, lod_level);
+  }
+};
+
 TEST(test_layer, test_runtime_context) {
   std::shared_ptr<imperative::VarBase> vin(
       new imperative::VarBase(false, "vin"));
+  std::shared_ptr<imperative::VarBase> vin_b(
+      new imperative::VarBase(false, "vin_b"));
   std::shared_ptr<imperative::VarBase> vout(
       new imperative::VarBase(false, "vout"));
-  var_pair in_pair = var_pair("X", vb_vector(1, vin));
-  var_pair out_pair = var_pair("Out", vb_vector(1, vout));
+  std::shared_ptr<imperative::VarBase> vout_b(
+      new imperative::VarBase(false, "vout_b"));
+  var_pair in_pair = var_pair("X", {vin, vin_b});
+  var_pair out_pair = var_pair("Out", {vout, vout_b});
   imperative::NameVarBaseMap ins = {in_pair};
   imperative::NameVarBaseMap outs = {out_pair};
   framework::AttributeMap attrs;
-  auto *ctx = new imperative::RuntimeInferVarTypeContext<imperative::VarBase>(
-      ins, outs, attrs);
-  ASSERT_TRUE(ctx->HasVar("vin"));
+
+  auto* ctx =
+      new imperative::TestRuntimeInferVarTypeContext<imperative::VarBase>(
+          ins, outs, attrs);
+
+  ASSERT_TRUE(ctx->HasInput("X"));
   ASSERT_TRUE(ctx->HasOutput("Out"));
-  ASSERT_ANY_THROW(ctx->GetDataTypes("vin"));
+
+  ASSERT_EQ(2u, ctx->InputSize("X"));
+  ASSERT_EQ("vin", ctx->InputVarName("X", 0));
+
+  ASSERT_TRUE(ctx->InputTypeAnyOf("X", framework::proto::VarType::LOD_TENSOR));
+  ASSERT_TRUE(ctx->InputTypeAllOf("X", framework::proto::VarType::LOD_TENSOR));
+
+  ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetInputType("X"));
+  ASSERT_EQ(framework::proto::VarType::FP32, ctx->GetInputDataType("X"));
+
+  ctx->SyncTypeAndDataType("X", "Out");
+
+  ASSERT_EQ(framework::proto::VarType::FP32, vout->DataType());
+
+  ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetOutputType("Out"));
+
+  ctx->SetOutputType("Out", framework::proto::VarType::SELECTED_ROWS,
+                     framework::ALL_ELEMENTS);
+  ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR_ARRAY);
+  ASSERT_EQ(framework::proto::VarType::LOD_TENSOR_ARRAY, vout->Type());
+  ASSERT_EQ(framework::proto::VarType::SELECTED_ROWS, vout_b->Type());
+
+  ctx->SetOutputDataType("Out", framework::proto::VarType::FP64,
+                         framework::ALL_ELEMENTS);
+  ctx->SetOutputDataType("Out", framework::proto::VarType::INT8);
+
+  ASSERT_EQ(framework::proto::VarType::INT8, vout->DataType());
+  ASSERT_EQ(framework::proto::VarType::FP64, vout_b->DataType());
+
+  // no throw, but do nothing
+  ASSERT_NO_THROW(
+      ctx->InsertVar("vout", framework::proto::VarType::LOD_TENSOR));
+  ASSERT_EQ(framework::proto::VarType::LOD_TENSOR_ARRAY, vout->Type());
+
+  ASSERT_ANY_THROW(ctx->HasVar("vin"));
+  ASSERT_ANY_THROW(ctx->InputVars("X"));
+  ASSERT_ANY_THROW(ctx->OutputVars("Out"));
+  ASSERT_ANY_THROW(ctx->GetVarType("vin"));
+  ASSERT_ANY_THROW(
+      ctx->SetVarType("vin", framework::proto::VarType::LOD_TENSOR));
+  ASSERT_ANY_THROW(ctx->GetVarDataType("vin"));
+  ASSERT_ANY_THROW(
+      ctx->SetVarDataType("vout", framework::proto::VarType::FP32));
+
+  ASSERT_ANY_THROW(ctx->GetVarDataTypes("vin"));
   std::vector<framework::proto::VarType::Type> NullType;
-  ASSERT_ANY_THROW(ctx->SetDataTypes("vin", NullType));
-  ASSERT_ANY_THROW(ctx->GetShape("vin"));
-  ASSERT_ANY_THROW(ctx->GetLoDLevel("vin"));
-  ASSERT_ANY_THROW(ctx->SetLoDLevel("vin", 2));
+  ASSERT_ANY_THROW(ctx->SetVarDataTypes("vin", NullType));
+  ASSERT_ANY_THROW(ctx->GetVarShape("vin"));
+  ASSERT_ANY_THROW(ctx->SetVarShape("vin", {}));
+  ASSERT_ANY_THROW(ctx->GetVarLoDLevel("vin"));
+  ASSERT_ANY_THROW(ctx->SetVarLoDLevel("vin", 2));
+
+  ASSERT_TRUE(ctx->IsDygraph());
 }
 
-std::string LayerDebugString(const std::string &op_type,
-                             const NameVarBaseMap &ins,
-                             const NameVarBaseMap &outs);
+std::string LayerDebugString(const std::string& op_type,
+                             const NameVarBaseMap& ins,
+                             const NameVarBaseMap& outs);
 
 TEST(test_layer, test_debug_string) {
   platform::CPUPlace place;
@@ -71,7 +192,7 @@ TEST(test_layer, test_debug_string) {
       new imperative::VarBase(false, "vin"));
   var_pair in_pair = var_pair("X", vb_vector(1, vin));
 
-  auto test_func = [&](std::shared_ptr<imperative::VarBase> &vout) {
+  auto test_func = [&](std::shared_ptr<imperative::VarBase>& vout) {
     var_pair out_pair = var_pair("Out", vb_vector(1, vout));
     imperative::NameVarBaseMap ins = {in_pair};
     imperative::NameVarBaseMap outs = {out_pair};
@@ -124,26 +245,26 @@ TEST(test_layer, test_debug_string) {
 }
 
 static std::shared_ptr<imperative::GradOpNode> CreateGradNode(
-    size_t id, const std::string &type, const imperative::NameVarBaseMap &ins,
-    const imperative::NameVarBaseMap &outs,
-    const framework::AttributeMap &attrs, const platform::Place &place) {
+    size_t id, const std::string& type, const imperative::NameVarBaseMap& ins,
+    const imperative::NameVarBaseMap& outs,
+    const framework::AttributeMap& attrs, const platform::Place& place) {
   auto node = std::make_shared<imperative::GradOpNode>();
-  auto *op = &(node->emplace_back());
+  auto* op = &(node->emplace_back());
   op->SetId(id);
   op->SetPlace(place);
   op->SetType(type);
   op->SetAttrMap(attrs);
-  for (auto &pair : ins) {
+  for (auto& pair : ins) {
     std::vector<std::shared_ptr<imperative::VariableWrapper>> vars;
-    for (auto &var : pair.second) {
+    for (auto& var : pair.second) {
       vars.emplace_back(var->SharedVar());
     }
     op->SetInput(pair.first, vars, false);
   }
 
-  for (auto &pair : outs) {
+  for (auto& pair : outs) {
     std::vector<std::shared_ptr<imperative::VariableWrapper>> vars;
-    for (auto &var : pair.second) {
+    for (auto& var : pair.second) {
      vars.emplace_back(var->SharedVar());
     }
     op->SetOutput(pair.first, vars, false);
@@ -173,7 +294,7 @@ TEST(test_layer, test_clear_backward_info) {
   node->InsertGradPendingNode(pending_node);
 
   ASSERT_EQ(node->size(), 1UL);
-  auto *op = &(node->back());
+  auto* op = &(node->back());
 
   ASSERT_GT(op->GetInsMap().size(), 0UL);
   ASSERT_GT(op->GetOutsMap().size(), 0UL);
@@ -196,10 +317,10 @@ TEST(test_layer, test_varbase_basic) {
   std::shared_ptr<imperative::VarBase> vin_with_grad(
       new imperative::VarBase(true, "vin"));
   ASSERT_ANY_THROW(vin->MutableGradVar());
-  ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable *>(
+  ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable*>(
                       vin_with_grad->MutableGradVar()) != 0));
-  ASSERT_TRUE(dynamic_cast<framework::Variable *>(
-                  vin_with_grad->MutableGradVar()) != 0);
+  ASSERT_TRUE(
+      dynamic_cast<framework::Variable*>(vin_with_grad->MutableGradVar()) != 0);
   vin_with_grad->SetOverridedStopGradient(false);
   ASSERT_FALSE(vin_with_grad->OverridedStopGradient());
   ASSERT_NO_FATAL_FAILURE(vin_with_grad->SetPersistable(true));
@@ -228,9 +349,9 @@ TEST(test_layer, test_dygraph_execution_context) {
   auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
 
   paddle::platform::CPUPlace cpu_place;
-  paddle::platform::DeviceContextPool &pool =
+  paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
-  auto *dev_ctx = pool.Get(cpu_place);
+  auto* dev_ctx = pool.Get(cpu_place);
   paddle::framework::RuntimeContext ctx({}, {});
   framework::Scope scope;
 
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 8ec77513530438cad5caecaa2955afa59d474826..4587b494b31ca04a51304742ba57135b34669b18 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -129,9 +129,10 @@ class ActivationOp : public framework::OperatorWithKernel {
 class ActivationOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
 
diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc
index dcfb4104c31c9e995c5e93cf58e25caa54e86325..911757007266c9ff88b0e348d350909ce0ff0bce 100644
--- a/paddle/fluid/operators/allclose_op.cc
+++ b/paddle/fluid/operators/allclose_op.cc
@@ -103,8 +103,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
 class AllcloseOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto out_var_name = ctx->Output("Out").front();
-    ctx->SetDataType(out_var_name, framework::proto::VarType::BOOL);
+    ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
   }
 };
 
diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc
index 549f44250c88c3675fb55fcb1c6f5ffa31189a9b..d6ab77b6cb85a647e77afebf861068a38cf3b7df 100644
--- a/paddle/fluid/operators/assign_op.cc
+++ b/paddle/fluid/operators/assign_op.cc
@@ -60,11 +60,7 @@ class AssignOp : public framework::OperatorWithKernel {
 class AssignInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto out_var_name = ctx->Output("Out")[0];
-    auto input_type = ctx->GetType(ctx->Input("X")[0]);
-    auto input_data_type = ctx->GetDataType(ctx->Input("X")[0]);
-    ctx->SetType(out_var_name, input_type);
-    ctx->SetDataType(out_var_name, input_data_type);
+    ctx->SyncTypeAndDataType("X", "Out");
   }
 };
 
diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h
index dd4043c6a7271b0ebb3be5099fc123db42ccd8c9..9f844b7c078bb7397d98dad57d9fad475283f397 100644
--- a/paddle/fluid/operators/batch_norm_op.h
+++ b/paddle/fluid/operators/batch_norm_op.h
@@ -171,9 +171,10 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker {
 class BatchNormOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
+    return m;
   }
 };
 
diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc
index 8c9e397b13f103f6d3e2fcaff01b06903d6c10ed..eece632e74a29ab3eb660f6ffa4839ecd8650d9c 100644
--- a/paddle/fluid/operators/beam_search_decode_op.cc
+++ b/paddle/fluid/operators/beam_search_decode_op.cc
@@ -204,12 +204,10 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    for (auto& o : ctx->Output("SentenceIds")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
-    }
-    for (auto& o : ctx->Output("SentenceScores")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
-    }
+    ctx->SetOutputType("SentenceIds", framework::proto::VarType::LOD_TENSOR,
+                       framework::ALL_ELEMENTS);
+    ctx->SetOutputType("SentenceScores", framework::proto::VarType::LOD_TENSOR,
+                       framework::ALL_ELEMENTS);
   }
 };
 
diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc
index 3abd70dbb0536809f5fb405c783318cf5097b5db..c866189c7a0806e6791b1955ae879e7f7e1fe6a1 100644
--- a/paddle/fluid/operators/beam_search_op.cc
+++ b/paddle/fluid/operators/beam_search_op.cc
@@ -122,12 +122,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
 class BeamSearchInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    for (auto &o : ctx->Output("selected_ids")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
-    }
-    for (auto &o : ctx->Output("selected_scores")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
-    }
+    ctx->SetOutputType("selected_ids", framework::proto::VarType::LOD_TENSOR,
+                       framework::ALL_ELEMENTS);
+    ctx->SetOutputType("selected_scores", framework::proto::VarType::LOD_TENSOR,
+                       framework::ALL_ELEMENTS);
   }
 };
 
diff --git a/paddle/fluid/operators/controlflow/get_places_op.cc b/paddle/fluid/operators/controlflow/get_places_op.cc
index eff88f54ade6e4bc71e8d80771b3f757819354a9..e60aee4b4ca99069b9510db3168b904ffcf2b114 100644
--- a/paddle/fluid/operators/controlflow/get_places_op.cc
+++ b/paddle/fluid/operators/controlflow/get_places_op.cc
@@ -92,9 +92,8 @@ execution.
 class GetPlacesInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    for (auto &o_name : ctx->Output("Out")) {
-      ctx->SetType(o_name, framework::proto::VarType::PLACE_LIST);
-    }
+    ctx->SetOutputType("Out", framework::proto::VarType::PLACE_LIST,
+                       framework::ALL_ELEMENTS);
   }
 };
 
diff --git a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
index 10443f45643dab4b20f5343eba0e140a0c038209..9f7702a5d6b63cc689535f2f1c880058e6211709 100644
--- a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
+++ b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
@@ -111,15 +111,15 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
   }
 };
 
-class WriteToArrayInferVarType : public framework::VarTypeInference {
+class WriteToArrayInferVarType : public framework::StaticGraphVarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto x_name = ctx->Input("X")[0];
-    auto out_name = ctx->Output("Out")[0];
+    auto x_name = Input(ctx, "X")[0];
+    auto out_name = Output(ctx, "Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
-    ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
-    if (ctx->HasVar(x_name)) {
-      ctx->SetDataType(out_name, ctx->GetDataType(x_name));
+    SetType(ctx, out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
+    if (HasVar(ctx, x_name)) {
+      SetDataType(ctx, out_name, GetDataType(ctx, x_name));
     }
   }
 };
 
diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc
index f6aaa49eceda0aacc1f76b235672cdb75ceba3b8..daa4486360d3d289a6c5137bf446d022eb4dd21f 100644
--- a/paddle/fluid/operators/controlflow/while_op.cc
+++ b/paddle/fluid/operators/controlflow/while_op.cc
@@ -398,18 +398,19 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker {
   }
 };
 
-class WhileGradOpVarTypeInference : public framework::VarTypeInference {
+class WhileGradOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto p_names = ctx->Input(kX);
-    auto pg_ig_names = ctx->Output(framework::GradVarName(kX));
+    auto p_names = Input(ctx, kX);
+    auto pg_ig_names = Output(ctx, framework::GradVarName(kX));
     for (size_t i = 0; i < p_names.size(); ++i) {
-      if (ctx->HasVar(pg_ig_names[i])) {
+      if (HasVar(ctx, pg_ig_names[i])) {
         VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
-                << " type: " << ctx->GetType(p_names[i]);
-        ctx->SetType(pg_ig_names[i], ctx->GetType(p_names[i]));
-        ctx->SetDataType(pg_ig_names[i], ctx->GetDataType(p_names[i]));
+                << " type: " << GetType(ctx, p_names[i]);
+        SetType(ctx, pg_ig_names[i], GetType(ctx, p_names[i]));
+        SetDataType(ctx, pg_ig_names[i], GetDataType(ctx, p_names[i]));
       }
     }
   }
 
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index 85ef42548ab1f471f85f9e8272a676a2bbbf05d2..8a5345e3cf8d9f1c657fe2996015af4dc038a1bf 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -254,10 +254,11 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
 
 class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
-    return std::unordered_map<std::string, std::string>{
+    static std::unordered_map<std::string, std::string> m{
         {"Input", /*->*/ "Output"}};
+    return m;
   }
 };
 
diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index 13ce0b2a9b68eb31c0c503e40a002c689fd5f95c..880ea0d96ce8d077eea19d2640101603a59db90a 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -177,9 +177,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
 class CrossEntropyOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
+    return m;
   }
 };
 
diff --git a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc
index 324edd7e47964ad358aa607ad606a5864d99aed7..1b309c8a2d41521c0b86fadf20153f7dc994f30e 100644
--- a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc
+++ b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc
@@ -115,10 +115,8 @@ class MergeIdsOp : public framework::OperatorWithKernel {
 class MergeIdsOpInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
-    for (auto &out_var : ctx->Output("Out")) {
-      ctx->SetType(out_var, input_type);
-    }
+    auto input_type = ctx->GetInputType("Ids");
+    ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
   }
 };
 
diff --git a/paddle/fluid/operators/distributed_ops/split_ids_op.cc b/paddle/fluid/operators/distributed_ops/split_ids_op.cc
index b708626760caead47804edf9abb83afe8c2efc7c..df9681c315c67299189c203591f533efb93489a5 100644
--- a/paddle/fluid/operators/distributed_ops/split_ids_op.cc
+++ b/paddle/fluid/operators/distributed_ops/split_ids_op.cc
@@ -73,10 +73,8 @@ class SplitIdsOp : public framework::OperatorWithKernel {
 class SplitIdsOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
-    for (auto &out_var : ctx->Output("Out")) {
-      ctx->SetType(out_var, input_type);
-    }
+    auto input_type = ctx->GetInputType("Ids");
+    ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
   }
 };
 
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index c613969343eb0e3970c7e46090e2b89e9a899084..7ea0b0d1efd397d1183f9cec93d114b6bbc8303f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -119,9 +119,10 @@ class ElementwiseOp : public framework::OperatorWithKernel {
 class ElementwiseOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
 
diff --git a/paddle/fluid/operators/eye_op.cc b/paddle/fluid/operators/eye_op.cc
index e7fd1a447342ebe9d037d608013638ad2a645d76..aa7f7035ba36a70f3cf132e0aa004cf580fe256d 100644
--- a/paddle/fluid/operators/eye_op.cc
+++ b/paddle/fluid/operators/eye_op.cc
@@ -49,8 +49,7 @@ class EyeOpVarTypeInference : public framework::VarTypeInference {
   void operator()(framework::InferVarTypeContext* ctx) const override {
     auto data_type = static_cast<framework::proto::VarType::Type>(
         boost::get<int>(ctx->GetAttr("dtype")));
-    auto& out_var_name = ctx->Output("Out").front();
-    ctx->SetDataType(out_var_name, data_type);
+    ctx->SetOutputDataType("Out", data_type);
   }
 };
 
diff --git a/paddle/fluid/operators/fill_any_like_op.cc b/paddle/fluid/operators/fill_any_like_op.cc
index 85e0789b0c04433220f41cf7b5c995b7a2b25822..d05a5163b5819ad43a508c21f0c79b48c6765064 100644
--- a/paddle/fluid/operators/fill_any_like_op.cc
+++ b/paddle/fluid/operators/fill_any_like_op.cc
@@ -72,14 +72,12 @@ The output will have the same shape and dtype as the input.
class FillAnyLikeVarTypeInference : public framework::VarTypeInference { public: void operator()(framework::InferVarTypeContext *ctx) const override { - auto out_var_name = ctx->Output("Out").front(); auto var_data_type = static_cast( boost::get(ctx->GetAttr("dtype"))); if (var_data_type < 0) { - const auto &input_var_name = ctx->Input("X").front(); - ctx->SetDataType(out_var_name, ctx->GetDataType(input_var_name)); + ctx->SetOutputDataType("Out", ctx->GetInputDataType("X")); } else { - ctx->SetDataType(out_var_name, var_data_type); + ctx->SetOutputDataType("Out", var_data_type); } } }; diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index f1a96f6f6ec6a146b7e81d64f975c10c1d96e4a2..5916082a8b172e8d04503982b8b9053f17b7d1b9 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -64,8 +64,7 @@ class FillConstantOpVarTypeInference : public framework::VarTypeInference { void operator()(framework::InferVarTypeContext* ctx) const override { auto data_type = static_cast( boost::get(ctx->GetAttr("dtype"))); - auto& out_var_name = ctx->Output("Out").front(); - ctx->SetDataType(out_var_name, data_type); + ctx->SetOutputDataType("Out", data_type); } }; diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index d1d364f0dd47eecf7c93ac8dd6a78ba2d199242a..c325874114115c959d7bcbe851d81fd7ab7fb681 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -63,8 +63,7 @@ class FillOpVarTypeInference : public framework::VarTypeInference { void operator()(framework::InferVarTypeContext* ctx) const override { auto data_type = static_cast( boost::get(ctx->GetAttr("dtype"))); - auto& out_var_name = ctx->Output("Out").front(); - ctx->SetDataType(out_var_name, data_type); + ctx->SetOutputDataType("Out", data_type); } }; diff --git a/paddle/fluid/operators/flip_op.cc b/paddle/fluid/operators/flip_op.cc index afb8cd4af416598cd82ff691aaa012757eba470b..7d0df5ffbd8945ca054fe24088b5fd7b6f5ef167 100644 --- a/paddle/fluid/operators/flip_op.cc +++ b/paddle/fluid/operators/flip_op.cc @@ -114,9 +114,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker { class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { protected: - std::unordered_map GetInputOutputWithSameType() + std::unordered_map& GetInputOutputWithSameType() const override { - return std::unordered_map{{"X", /*->*/ "Out"}}; + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; } }; diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.h b/paddle/fluid/operators/fused/fused_bn_activation_op.h index 0b7b75fe6f2c79b8c8e72fc7439565acc52091b3..b8404e4c6553fd0c25e269263b7aa7d71d2f3932 100644 --- a/paddle/fluid/operators/fused/fused_bn_activation_op.h +++ b/paddle/fluid/operators/fused/fused_bn_activation_op.h @@ -85,9 +85,10 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker { class FusedBatchNormActOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { protected: - std::unordered_map GetInputOutputWithSameType() + std::unordered_map& GetInputOutputWithSameType() const override { - return std::unordered_map{{"X", /*->*/ "Y"}}; + static std::unordered_map m{{"X", /*->*/ "Y"}}; + return m; } }; diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index 2db2ed0972858f9ffbdd8c86a6b4eed35ca2b772..faaed6545ca296071cdf469660af6a356da4f1a6 
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
index 2db2ed0972858f9ffbdd8c86a6b4eed35ca2b772..faaed6545ca296071cdf469660af6a356da4f1a6 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
@@ -146,19 +146,20 @@ class FusedEmbeddingSeqPoolOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto out_var_name = framework::GradVarName("W");
     auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
               << framework::GradVarName("W") << " is set to SelectedRows";
-      ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetOutputType(out_var_name,
+                         framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
               << framework::GradVarName("W") << " is set to LoDTensor";
-      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
+    ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
   }
 };
diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
index 8c90965678ddb166cf269096261cc4eb25136a83..8b4dec13cfb36572f323cfb8fad93ce6bd40d2ce 100644
--- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
+++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
@@ -83,11 +83,8 @@ class GetTensorFromSelectedRowsOpVarTypeInference
     : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const {  // NOLINT
-    auto out_var_name = ctx->Output("Out").front();
-    auto in_var_name = ctx->Input("X").front();
-
-    ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
-    ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
+    ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR);
+    ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
   }
 };
diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc
index 4414f1a571258a97a3e973193d539e95da8382f9..0659f51e9731f29232fbb1155c339b5a9e248a5e 100644
--- a/paddle/fluid/operators/group_norm_op.cc
+++ b/paddle/fluid/operators/group_norm_op.cc
@@ -216,9 +216,10 @@ DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
 class GroupNormOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
       const override {
-    return {{"X", /*->*/ "Y"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
+    return m;
   }
 };
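fused_embedding_seq_pool above, and lookup_table, lookup_table_v2, hierarchical_sigmoid and nce below, all share the is_sparse dispatch: the weight gradient becomes SELECTED_ROWS for sparse updates and LOD_TENSOR otherwise. Note that SetOutputType/SetOutputDataType now take the slot name GradVarName("W") rather than a resolved variable name. A condensed sketch of that recurring branch, under the same API assumptions:

    void operator()(framework::InferVarTypeContext *ctx) const override {
      auto out_var_name = framework::GradVarName("W");  // slot name, not var name
      bool is_sparse = boost::get<bool>(ctx->GetAttr("is_sparse"));
      ctx->SetOutputType(out_var_name,
                         is_sparse ? framework::proto::VarType::SELECTED_ROWS
                                   : framework::proto::VarType::LOD_TENSOR);
      // The dtype follows the forward weight, addressed by input slot as well.
      ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
    }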
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index fb3eac791b1d849a77c21b5dc08371db9743c5f1..1b8a206d66d29e608ca012013c1117fc20ebc441 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -229,31 +229,30 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    auto w_grad_var_name = ctx->Output(framework::GradVarName("W")).front();
-    auto has_bias_grad_var = ctx->HasOutput(framework::GradVarName("Bias"));
-    std::string bias_grad_var_name;
-    bool hasBias = false;
-    if (has_bias_grad_var) {
-      hasBias = true;
-      bias_grad_var_name = ctx->Output(framework::GradVarName("Bias")).front();
+    auto w_grad_var_name = framework::GradVarName("W");
+    auto bias_grad_var_name = framework::GradVarName("Bias");
+    if (ctx->HasOutput(bias_grad_var_name)) {
+      VLOG(3) << "hierarchical_sigmoid_grad op "
+              << framework::GradVarName("Bias") << " is set to LoDTensor";
+      ctx->SetOutputType(bias_grad_var_name,
+                         framework::proto::VarType::LOD_TENSOR);
     }
+
     auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
-      ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetOutputType(w_grad_var_name,
+                         framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetOutputType(w_grad_var_name,
+                         framework::proto::VarType::LOD_TENSOR);
     }
-    if (hasBias) {
-      VLOG(3) << "hierarchical_sigmoid_grad op "
-              << framework::GradVarName("Bias") << " is set to LoDTensor";
-      ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
-    }
-    ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0]));
+
+    ctx->SetOutputDataType(w_grad_var_name, ctx->GetInputDataType("W"));
   }
 };
diff --git a/paddle/fluid/operators/instance_norm_op.h b/paddle/fluid/operators/instance_norm_op.h
index e56501f2f038b016a4ec498392aefa71c5a38eff..493f54ab3baa6dbf9166ed709b392fce1c9fb889 100644
--- a/paddle/fluid/operators/instance_norm_op.h
+++ b/paddle/fluid/operators/instance_norm_op.h
@@ -123,9 +123,10 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker {
 class InstanceNormOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", "Y"}};
+    static std::unordered_map<std::string, std::string> m{{"X", "Y"}};
+    return m;
   }
 };
diff --git a/paddle/fluid/operators/lod_rank_table_op.cc b/paddle/fluid/operators/lod_rank_table_op.cc
index c73aaf75bcfbfa2bc2b0baeed53c8de1aaae095d..7cbfbd03e1dcb4983863445f6a9cd2c9ee17a8b0 100644
--- a/paddle/fluid/operators/lod_rank_table_op.cc
+++ b/paddle/fluid/operators/lod_rank_table_op.cc
@@ -65,9 +65,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
 class LoDRankTableInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    for (auto &o : ctx->Output("Out")) {
-      ctx->SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
-    }
+    ctx->SetOutputType("Out", framework::proto::VarType::LOD_RANK_TABLE,
+                       framework::ALL_ELEMENTS);
   }
 };
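lod_rank_table above (and lod_tensor_to_array, save_combine, split_selected_rows and the tensor_array grad below) use the third argument of SetOutputType to replace the hand-written loops over ctx->Output(...): passing framework::ALL_ELEMENTS stamps every variable bound to the slot, while the default index 0 touches only the first. A minimal sketch, assuming those semantics:

    // Before: for (auto &name : ctx->Output("Out")) ctx->SetType(name, type);
    // After: one call covers all variables bound to the "Out" slot.
    ctx->SetOutputType("Out", framework::proto::VarType::LOD_RANK_TABLE,
                       framework::ALL_ELEMENTS);
    // Omitting the index argument would set only Out[0].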
diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc
index 7adcc678f5c42b474fdd9271390dcf6e8d741541..1b4ab5a184d0dd042b0e2ced5c8b2aaeaf3b4d29 100644
--- a/paddle/fluid/operators/lod_reset_op.cc
+++ b/paddle/fluid/operators/lod_reset_op.cc
@@ -76,24 +76,25 @@ class LoDResetOp : public framework::OperatorWithKernel {
   }
 };
-class LoDResetOpVarTypeInference : public framework::VarTypeInference {
+class LoDResetOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto x_var_name = ctx->Input("X").front();
-    auto out_var_name = ctx->Output("Out").front();
+    auto x_var_name = Input(ctx, "X").front();
+    auto out_var_name = Output(ctx, "Out").front();
     bool append = boost::get<bool>(ctx->GetAttr("append"));
     if (ctx->HasInput("Y")) {
-      auto y_var_name = ctx->Input("Y").front();
-      auto y_lod_level = std::max(ctx->GetLoDLevel(y_var_name), 1);
-      ctx->SetLoDLevel(out_var_name, y_lod_level);
+      auto y_var_name = Input(ctx, "Y").front();
+      auto y_lod_level = std::max(GetLoDLevel(ctx, y_var_name), 1);
+      SetLoDLevel(ctx, out_var_name, y_lod_level);
     } else if (append) {
-      auto x_lod_level = std::max(ctx->GetLoDLevel(x_var_name), 1);
-      ctx->SetLoDLevel(out_var_name, x_lod_level);
+      auto x_lod_level = std::max(GetLoDLevel(ctx, x_var_name), 1);
+      SetLoDLevel(ctx, out_var_name, x_lod_level);
     } else {
-      ctx->SetLoDLevel(out_var_name, 1);
+      SetLoDLevel(ctx, out_var_name, 1);
    }
-    ctx->SetDataType(out_var_name, ctx->GetDataType(x_var_name));
-    ctx->SetType(out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
+    SetDataType(ctx, out_var_name, GetDataType(ctx, x_var_name));
+    SetType(ctx, out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
   }
 };
diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc
index 50a2a2c9467fb15d44631e3257ed74bf2c0334dc..b130e84933bc9a26653b5eb164ccc450fdb7b63e 100644
--- a/paddle/fluid/operators/lod_tensor_to_array_op.cc
+++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc
@@ -221,9 +221,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
 class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    for (auto &out_var : ctx->Output("Out")) {
-      ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
-    }
+    ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR_ARRAY,
+                       framework::ALL_ELEMENTS);
   }
 };
diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index 1cf23b124e38d0be28a9bd129196b3ea2e7cdb10..fc2a2fbb3f66e4723983af8ec80f9c8c16c75eb0 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -173,19 +173,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
 class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto out_var_name = framework::GradVarName("W");
     auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
-      ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetOutputType(out_var_name,
+                         framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
+    ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
   }
 };
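lod_reset above (like py_func and read further down) now derives from framework::StaticGraphVarTypeInference rather than VarTypeInference: inferers that genuinely need raw variable names and per-variable state (LoD level, shape) go through protected name-based helpers such as Input(ctx, ...), GetLoDLevel(ctx, ...) and SetType(ctx, ...), while the slot-based ctx methods remain the common path. A skeleton of such an inferer, assuming the helper signatures as they are used in the hunks above, with a hypothetical MyOp:

    class MyOpVarTypeInference : public framework::StaticGraphVarTypeInference {
     public:
      void operator()(framework::InferVarTypeContext *ctx) const override {
        // Name-based access is only meaningful in static-graph mode; it comes
        // from the protected helpers inherited from StaticGraphVarTypeInference.
        auto x_name = Input(ctx, "X").front();
        auto out_name = Output(ctx, "Out").front();
        SetLoDLevel(ctx, out_name, GetLoDLevel(ctx, x_name));
        SetDataType(ctx, out_name, GetDataType(ctx, x_name));
        SetType(ctx, out_name, framework::proto::VarType::LOD_TENSOR);
      }
    };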
diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc
index 02977a7887c4d19c702f56e430472923987fc73c..849287ce5e1e9490ead753fcddb0a96c3a7381d2 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op.cc
@@ -160,19 +160,20 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel {
 class LookupTableV2OpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto out_var_name = framework::GradVarName("W");
     auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
-      ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetOutputType(out_var_name,
+                         framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
+    ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
   }
 };
diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc
index 7237dafaac19e718ea119d1e22b1e599bc9772b2..7e75905bc4975b59772cb0d22d8a6db3520e1803 100644
--- a/paddle/fluid/operators/mean_op.cc
+++ b/paddle/fluid/operators/mean_op.cc
@@ -45,9 +45,10 @@ Mean Operator calculates the mean of all elements in X.
 class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
diff --git a/paddle/fluid/operators/merge_selected_rows_op.cc b/paddle/fluid/operators/merge_selected_rows_op.cc
index 50f44c7fc5ec90420d7c38f0f536ff7adb8f9ec4..e758c2bb6549c097709983d766db2a6f4388f7bc 100644
--- a/paddle/fluid/operators/merge_selected_rows_op.cc
+++ b/paddle/fluid/operators/merge_selected_rows_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/merge_selected_rows_op.h"
+#include <string>
 namespace paddle {
 namespace operators {
@@ -79,9 +80,10 @@ Example:
 class MergeSelectedRowsOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc
index d1ac0e0dff56b6324d4ca295b9da7a3eb9da8cd6..b3afba1e4f9791b1b9027ca038b495380f403773 100644
--- a/paddle/fluid/operators/mul_op.cc
+++ b/paddle/fluid/operators/mul_op.cc
@@ -200,9 +200,10 @@ or not. But the output only shares the LoD information with input $X$.
 class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
diff --git a/paddle/fluid/operators/nccl/nccl_op.cc b/paddle/fluid/operators/nccl/nccl_op.cc
index 8cf43be35cef64d5a9710901c52fa9dfe62eed72..519fcf5924a0cf2248a5ec835d1e3fac54966e65 100644
--- a/paddle/fluid/operators/nccl/nccl_op.cc
+++ b/paddle/fluid/operators/nccl/nccl_op.cc
@@ -61,8 +61,7 @@ class NCCLInitOp : public framework::OperatorBase {
 class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto out_var_name = ctx->Output("Communicator").front();
-    ctx->SetType(out_var_name, framework::proto::VarType::RAW);
+    ctx->SetOutputType("Communicator", framework::proto::VarType::RAW);
   }
 };
diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc
index a22f20100823196a1854739e88482b83e932a146..2dfd05c23566afd3d70daf78f4f2145a83be587d 100644
--- a/paddle/fluid/operators/nce_op.cc
+++ b/paddle/fluid/operators/nce_op.cc
@@ -280,20 +280,20 @@ class NCEOpGrad : public framework::OperatorWithKernel {
 class NCEOpGradVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto weight_grad = ctx->Output(framework::GradVarName("Weight")).front();
+    auto weight_grad = framework::GradVarName("Weight");
     auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
               << " is set to SelectedRows";
-      ctx->SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetOutputType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
               << " is set to LoDTensor";
-      ctx->SetType(weight_grad, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetOutputType(weight_grad, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx->SetDataType(weight_grad, ctx->GetDataType(ctx->Input("Input")[0]));
+    ctx->SetOutputDataType(weight_grad, ctx->GetInputDataType("Input"));
   }
 };
diff --git a/paddle/fluid/operators/optimizers/momentum_op.cc b/paddle/fluid/operators/optimizers/momentum_op.cc
index a0dc387fb677eef2209596823b448fddb3af7cf4..ccebfeca26ca33e4c1ff17d5cdc834af0db6d5b0 100644
--- a/paddle/fluid/operators/optimizers/momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/momentum_op.cc
@@ -22,18 +22,15 @@ using Tensor = framework::Tensor;
 class MomentumOpInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    auto& input_var = ctx->Input("Param")[0];
-    for (auto& out_var : ctx->Output("ParamOut")) {
-      if (ctx->GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) {
-        ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
-      } else if (ctx->GetType(input_var) ==
-                 framework::proto::VarType::LOD_TENSOR) {
-        ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR);
-      } else {
-        PADDLE_THROW(
-            "Only support LodTensor and SelectedRows, Unexpected Input Type.");
-      }
-    }
+    auto in_var_type = ctx->GetInputType("Param");
+    PADDLE_ENFORCE_EQ(
+        in_var_type == framework::proto::VarType::SELECTED_ROWS ||
+            in_var_type == framework::proto::VarType::LOD_TENSOR,
+        true,
+        platform::errors::InvalidArgument(
+            "Only support LodTensor and SelectedRows, Unexpected Input Type."));
+
+    ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
   }
 };
diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc
index 14db2c08d1a5183d68506db176d786b548cfadc7..aeff8da70b958a440953824c46a095a4f86b9379 100644
--- a/paddle/fluid/operators/optimizers/sgd_op.cc
+++ b/paddle/fluid/operators/optimizers/sgd_op.cc
@@ -75,19 +75,15 @@ class SGDOp : public framework::OperatorWithKernel {
 class SGDOpInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto &input_var_n = ctx->Input("Param")[0];
-    auto in_var_type = ctx->GetType(input_var_n);
-    PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
-                       in_var_type == framework::proto::VarType::LOD_TENSOR,
-                   "The input Var's type should be LoDtensor or SelectedRows,"
-                   " but the received var(%s)'s type is %s",
-                   input_var_n, in_var_type);
-
-    for (auto &out_var_n : ctx->Output("ParamOut")) {
-      if (ctx->GetType(out_var_n) != in_var_type) {
-        ctx->SetType(out_var_n, in_var_type);
-      }
-    }
+    auto in_var_type = ctx->GetInputType("Param");
+    PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
+                          in_var_type == framework::proto::VarType::LOD_TENSOR,
+                      true, platform::errors::InvalidArgument(
+                                "The input Var's type should be LoDtensor or "
+                                "SelectedRows, but the received type is %s",
+                                in_var_type));
+
+    ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
   }
 };
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index bf5f7cf6e8ad08df20de01deb819b21a232d2576..8ff3192ca24e193dfbddebdd1ce79ce91f08dbb2 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -422,9 +422,10 @@ Example:
 class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc
index 64f5ac11dcf8dfaafc9399b60e9b6b40c9ea4825..dff2074fbec0754433a6827d0a595c6e0d6f3755 100644
--- a/paddle/fluid/operators/print_op.cc
+++ b/paddle/fluid/operators/print_op.cc
@@ -260,9 +260,7 @@ class PrintOpInferShape : public framework::InferShapeBase {
 class PrintOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto input_type = ctx->GetType(ctx->Input("In")[0]);
-    auto out_name = ctx->Output("Out").front();
-    ctx->SetType(out_name, input_type);
+    ctx->SetOutputType("Out", ctx->GetInputType("In"));
   }
 };
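momentum and sgd above now read the parameter type once through GetInputType("Param"), validate it with PADDLE_ENFORCE_EQ plus a platform::errors payload instead of a bare PADDLE_THROW/PADDLE_ENFORCE, and fan the type out to every ParamOut element. A condensed sketch of this optimizer-style check, under the same API assumptions (the message wording here is illustrative):

    auto in_var_type = ctx->GetInputType("Param");
    // Validate once against the two supported variable types.
    PADDLE_ENFORCE_EQ(
        in_var_type == framework::proto::VarType::SELECTED_ROWS ||
            in_var_type == framework::proto::VarType::LOD_TENSOR,
        true,
        platform::errors::InvalidArgument(
            "Parameter must be LoDTensor or SelectedRows, received type %d.",
            static_cast<int>(in_var_type)));
    // Propagate to every variable bound to the ParamOut slot.
    ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);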
diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc
index b9c32304353e0e715cd79ea9d604cdadf6fde44f..849798166247b665655e8c3dbf9e5aafca31b4d8 100644
--- a/paddle/fluid/operators/py_func_op.cc
+++ b/paddle/fluid/operators/py_func_op.cc
@@ -116,12 +116,11 @@ static void CallPythonFunc(py::object *callable,
   }
 }
-class PyFuncOpVarTypeInference : public framework::VarTypeInference {
+class PyFuncOpVarTypeInference : public framework::StaticGraphVarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    bool has_out = (ctx->HasOutput("Out") && !ctx->Output("Out").empty());
-
-    bool has_in = (ctx->HasInput("X") && !ctx->Input("X").empty());
+    bool has_out = ctx->HasOutput("Out");
+    bool has_in = ctx->HasInput("X");
     /**
      * X or Out can be empty, so that py_func can be more flexible
@@ -147,7 +146,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
      * the corresponding forward variable
      */
     const std::string kGradVarSuffix = framework::kGradVarSuffix;
-    auto &out_var_names = ctx->Output("Out");
+    auto &out_var_names = Output(ctx, "Out");
     for (auto &out_var_name : out_var_names) {
       if (out_var_name == framework::kEmptyVarName ||
           out_var_name.size() < kGradVarSuffix.size()) {
@@ -157,19 +156,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
       size_t len = out_var_name.size() - kGradVarSuffix.size();
       if (out_var_name.substr(len) == kGradVarSuffix) {
         auto fwd_var_name = out_var_name.substr(0, len);
-        PADDLE_ENFORCE_EQ(ctx->HasVar(out_var_name), true,
-                          platform::errors::InvalidArgument(
-                              "Backward variable %s not found", out_var_name));
-        PADDLE_ENFORCE_EQ(ctx->HasVar(fwd_var_name), true,
-                          platform::errors::InvalidArgument(
-                              "Backward variable %s not found", fwd_var_name));
+        OP_INOUT_CHECK(HasVar(ctx, out_var_name), "Var", out_var_name,
+                       "py_func");
+        OP_INOUT_CHECK(HasVar(ctx, fwd_var_name), "Var", fwd_var_name,
+                       "py_func");
         VLOG(10) << "Infer var_desc of Output(" << out_var_name
                  << ") as Input(" << fwd_var_name << ")";
-        ctx->SetShape(out_var_name, ctx->GetShape(fwd_var_name));
-        ctx->SetDataType(out_var_name, ctx->GetDataType(fwd_var_name));
-        ctx->SetLoDLevel(out_var_name, ctx->GetLoDLevel(fwd_var_name));
-        ctx->SetType(out_var_name, ctx->GetType(fwd_var_name));
+        SetShape(ctx, out_var_name, GetShape(ctx, fwd_var_name));
+        SetDataType(ctx, out_var_name, GetDataType(ctx, fwd_var_name));
+        SetLoDLevel(ctx, out_var_name, GetLoDLevel(ctx, fwd_var_name));
+        SetType(ctx, out_var_name, GetType(ctx, fwd_var_name));
       }
     }
   }
diff --git a/paddle/fluid/operators/randperm_op.cc b/paddle/fluid/operators/randperm_op.cc
index 67d7c578dcd777d084fcbe14658a9ae2cd3e0ed6..70808363e16182fe6dae21faf18e0e85d74a6df5 100644
--- a/paddle/fluid/operators/randperm_op.cc
+++ b/paddle/fluid/operators/randperm_op.cc
@@ -75,8 +75,7 @@ class RandpermOpVarTypeInference : public framework::VarTypeInference {
   void operator()(framework::InferVarTypeContext *ctx) const override {
     auto var_data_type = static_cast<framework::proto::VarType::Type>(
         boost::get<int>(ctx->GetAttr("dtype")));
-    auto out_var_name = ctx->Output("Out").front();
-    ctx->SetDataType(out_var_name, var_data_type);
+    ctx->SetOutputDataType("Out", var_data_type);
   }
 };
diff --git a/paddle/fluid/operators/reader/read_op.cc b/paddle/fluid/operators/reader/read_op.cc
index 9a5ef25b5b6cf96c05ac1cbe8353fa0134a15adf..2ba2ef244fe50e720df20fc1f24f8b9ed6bfeb76 100644
--- a/paddle/fluid/operators/reader/read_op.cc
+++ b/paddle/fluid/operators/reader/read_op.cc
@@ -70,18 +70,18 @@ class ReadInferShape : public framework::InferShapeBase {
   }
 };
-class ReadInferVarType : public framework::VarTypeInference {
+class ReadInferVarType : public framework::StaticGraphVarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
     bool infer_out = boost::get<bool>(ctx->GetAttr("infer_out"));
     if (infer_out) {
-      std::string reader_name = ctx->Input("Reader")[0];
-      std::vector<std::string> out_names = ctx->Output("Out");
-      auto dtypes = ctx->GetDataTypes(reader_name);
+      std::string reader_name = Input(ctx, "Reader")[0];
+      auto& out_names = Output(ctx, "Out");
+      auto dtypes = GetDataTypes(ctx, reader_name);
       PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
       for (size_t i = 0; i < dtypes.size(); ++i) {
-        ctx->SetType(out_names[i], framework::proto::VarType::LOD_TENSOR);
-        ctx->SetDataType(out_names[i], dtypes[i]);
+        SetType(ctx, out_names[i], framework::proto::VarType::LOD_TENSOR);
+        SetDataType(ctx, out_names[i], dtypes[i]);
       }
     }
   }
diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc
index 75b11991775d030ebdde3eb3a3eb67ed8eef1e4a..e51d73f6a81f80a078f8baf9dceaf30994aac060 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.cc
+++ b/paddle/fluid/operators/reader/reader_op_registry.cc
@@ -100,8 +100,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
 void FileReaderInferVarType::operator()(
     framework::InferVarTypeContext* ctx) const {
-  std::string reader_name = ctx->Output("Out")[0];
-  ctx->SetType(reader_name, framework::proto::VarType::READER);
+  ctx->SetOutputType("Out", framework::proto::VarType::READER);
 }
 void DecoratedReaderInferShape::operator()(
@@ -125,10 +124,8 @@ void DecoratedReaderInferShape::operator()(
 void DecoratedReaderInferVarType::operator()(
     framework::InferVarTypeContext* ctx) const {
-  const std::string& in_reader_name = ctx->Input("UnderlyingReader")[0];
-  const std::string& out_reader_name = ctx->Output("Out")[0];
-  ctx->SetType(out_reader_name, framework::proto::VarType::READER);
-  ctx->SetDataTypes(out_reader_name, ctx->GetDataTypes(in_reader_name));
+  ctx->SetOutputType("Out", framework::proto::VarType::READER);
+  ctx->SetOutputDataTypes("Out", ctx->GetInputDataTypes("UnderlyingReader"));
 }
 void DecoratedReaderMakerBase::Make() {
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
index 0eb6f2e66ee6a5a09badf56f861cd4b31149d472..6e860010bcde6f0268741636975d5c66b21499f1 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -58,8 +58,7 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
     auto data_type = static_cast<framework::proto::VarType::Type>(
         boost::get<int>(ctx->GetAttr("out_dtype")));
     if (data_type >= 0) {
-      auto& out_var_name = ctx->Output("Out").front();
-      ctx->SetDataType(out_var_name, data_type);
+      ctx->SetOutputDataType("Out", data_type);
     }
   }
 };
diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc
index a4ac4f009e8fa9dd54a34d95cc087c90e68e5370..ec038f16113dda3915dde167ba49b6be245c9f02 100644
--- a/paddle/fluid/operators/save_combine_op.cc
+++ b/paddle/fluid/operators/save_combine_op.cc
@@ -85,9 +85,8 @@ to a file on disk.
 class SaveCombineOpInferVarType : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    for (auto& o : ctx->Output("Y")) {
-      ctx->SetType(o, framework::proto::VarType::RAW);
-    }
+    ctx->SetOutputType("Y", framework::proto::VarType::RAW,
+                       framework::ALL_ELEMENTS);
   }
 };
diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc
index 8a41d79433a8dade7bd931b3c68c8c2c40f0250a..c2a58b4199f32a0e5140c599028426afe98ae016 100644
--- a/paddle/fluid/operators/save_op.cc
+++ b/paddle/fluid/operators/save_op.cc
@@ -73,7 +73,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
     auto var_type = framework::proto::VarType::RAW;
-    ctx->SetType(LOOKUP_TABLE_PATH, var_type);
+    ctx->InsertVar(LOOKUP_TABLE_PATH, var_type);
   }
 };
diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc
index e6b870648cdad8f51917ec1379332e0c13b1238c..647e3cea99d3c1975d0da988d58dcab139ec1209 100644
--- a/paddle/fluid/operators/scale_op.cc
+++ b/paddle/fluid/operators/scale_op.cc
@@ -82,13 +82,7 @@ $$Out = scale*(X + bias)$$
 class ScaleOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto &in_var_name = ctx->Input("X").front();
-    auto out_var_name = ctx->Output("Out").front();
-
-    if (in_var_name != out_var_name) {
-      ctx->SetType(out_var_name, ctx->GetType(in_var_name));
-      ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
-    }
+    ctx->SyncTypeAndDataType("X", "Out");
   }
 };
diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc
index 1570a1cd83ae6791ede7604a0ee4d6cadca13419..7c77b2688e7b528f678418c67e77fa4abff04248 100644
--- a/paddle/fluid/operators/selu_op.cc
+++ b/paddle/fluid/operators/selu_op.cc
@@ -45,9 +45,10 @@ class SeluOp : public framework::OperatorWithKernel {
 class SeluOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index b2a798f66d23567b992080981f65c23deb73d551..56dca92ea68dd93800ac4d350db149f43bc35844 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -145,9 +145,10 @@ For each row $i$ and each column $j$ in the matrix, we have:
 class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
       const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
   }
 };
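scale_op's whole inferer collapses into one SyncTypeAndDataType call, which copies both the variable type and the data type from the indexed input to the indexed output and, per its definition in this diff, does nothing when the two resolve to the same variable (the in-place case). A minimal sketch under those semantics:

    // Copy type + dtype of X[0] to Out[0], unless they are the same
    // variable (in-place execution); the index argument defaults to 0.
    ctx->SyncTypeAndDataType("X", "Out");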
override {
-    for (auto &out_var : ctx->Output("Out")) {
-      ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
-    }
+    ctx->SetOutputType("Out", framework::proto::VarType::SELECTED_ROWS,
+                       framework::ALL_ELEMENTS);
   }
 };
diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc
index 39808bb8f84dd91ac47a5d6ae01658ca86960113..1c59fd99ba64dc530c93d49eb6dfddc381478f85 100644
--- a/paddle/fluid/operators/sum_op.cc
+++ b/paddle/fluid/operators/sum_op.cc
@@ -210,43 +210,36 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext* ctx) const override {
-    auto& inputs = ctx->Input("X");
-    auto var_type = framework::proto::VarType::SELECTED_ROWS;
-    for (auto& name : ctx->Input("X")) {
-      VLOG(10) << name << " " << ctx->GetType(name);
-    }
-
-    bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
-          return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
-        });
-
-    auto is_tensor_array = [ctx](const std::string& name) {
-      return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
-    };
-
-    bool any_input_is_tensor_array =
-        std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
-    bool all_inputs_are_tensor_array =
-        std::all_of(inputs.begin(), inputs.end(), is_tensor_array);
+    if (!ctx->IsDygraph()) {
+      auto var_type = framework::proto::VarType::SELECTED_ROWS;
+      if (VLOG_IS_ON(10)) {
+        for (size_t ind = 0; ind < ctx->InputSize("X"); ++ind) {
+          VLOG(10) << ctx->InputVarName("X", ind) << " "
+                   << ctx->GetInputType("X", ind);
+        }
+      }
-    if (any_input_is_tensor_array) {
-      if (!all_inputs_are_tensor_array) {
-        std::ostringstream os;
-        for (auto& each : inputs) {
-          os << "    " << each << " type is " << ctx->GetType(each) << "\n";
+      if (ctx->InputTypeAnyOf("X",
+                              framework::proto::VarType::LOD_TENSOR_ARRAY)) {
+        if (!ctx->InputTypeAllOf("X",
+                                 framework::proto::VarType::LOD_TENSOR_ARRAY)) {
+          std::ostringstream os;
+          for (size_t ind = 0; ind < ctx->InputSize("X"); ++ind) {
+            os << "    " << ctx->InputVarName("X", ind) << " type is "
+               << ctx->GetInputType("X", ind) << "\n";
+          }
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "Not all inputs are tensor array:\n%s", os.str()));
         }
-        PADDLE_ENFORCE_EQ(all_inputs_are_tensor_array, true,
-                          "Not all inputs are tensor array:\n%s", os.str());
+        var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
+      } else if (ctx->InputTypeAnyOf("X",
+                                     framework::proto::VarType::LOD_TENSOR)) {
+        var_type = framework::proto::VarType::LOD_TENSOR;
       }
-      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
-    } else if (any_input_is_lod_tensor) {
-      var_type = framework::proto::VarType::LOD_TENSOR;
-    }
-    auto out_var_name = ctx->Output("Out").front();
-    ctx->SetType(out_var_name, var_type);
-    ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
+      ctx->SetOutputType("Out", var_type);
+      ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
+    }
   }
 };
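sum_op's type-promotion rule is now expressed with the slot-level predicates InputTypeAnyOf / InputTypeAllOf instead of std::any_of/std::all_of over raw variable names, and the whole static-graph branch is skipped under dygraph via ctx->IsDygraph(). A condensed sketch of the promotion logic, assuming those predicates as introduced in this diff:

    // Promotion order: any LOD_TENSOR_ARRAY input forces an array output
    // (and then all inputs must be arrays); otherwise any LOD_TENSOR input
    // promotes to LOD_TENSOR; otherwise SELECTED_ROWS is kept.
    auto var_type = framework::proto::VarType::SELECTED_ROWS;
    if (ctx->InputTypeAnyOf("X", framework::proto::VarType::LOD_TENSOR_ARRAY)) {
      PADDLE_ENFORCE_EQ(
          ctx->InputTypeAllOf("X", framework::proto::VarType::LOD_TENSOR_ARRAY),
          true,
          platform::errors::InvalidArgument("Not all inputs are tensor array"));
      var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
    } else if (ctx->InputTypeAnyOf("X", framework::proto::VarType::LOD_TENSOR)) {
      var_type = framework::proto::VarType::LOD_TENSOR;
    }
    ctx->SetOutputType("Out", var_type);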
*ctx) const override {
-    for (auto &out_var : ctx->Output(framework::GradVarName("X"))) {
-      ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
-    }
+    ctx->SetOutputType(framework::GradVarName("X"),
+                       framework::proto::VarType::LOD_TENSOR_ARRAY,
+                       framework::ALL_ELEMENTS);
   }
 };
diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc
index 063d390a56d3767e9c1821101845bb890aeda689..ac8e66ba3bb22a41c0325a6c79505ad77beda41f 100644
--- a/paddle/fluid/operators/uniform_random_op.cc
+++ b/paddle/fluid/operators/uniform_random_op.cc
@@ -232,15 +232,13 @@ uniform distribution. The random result is in set [min, max).
 class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(framework::InferVarTypeContext *ctx) const override {
-    auto out_var_name = ctx->Output("Out").front();
     auto var_data_type = static_cast<framework::proto::VarType::Type>(
         boost::get<int>(ctx->GetAttr("dtype")));
-    if (ctx->GetType(out_var_name) !=
-        framework::proto::VarType::SELECTED_ROWS) {
-      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+    if (ctx->GetOutputType("Out") != framework::proto::VarType::SELECTED_ROWS) {
+      ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR);
     }
-    ctx->SetDataType(out_var_name, var_data_type);
+    ctx->SetOutputDataType("Out", var_data_type);
   }
 };
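uniform_random keeps its original guard but reads the current output type through GetOutputType("Out") instead of resolving the variable name first: an output that was already declared as SELECTED_ROWS elsewhere is left untouched, everything else is normalized to LOD_TENSOR before the dtype is stamped. A minimal sketch of the guard under the same API assumptions:

    // Respect a pre-declared SELECTED_ROWS output; otherwise normalize.
    if (ctx->GetOutputType("Out") != framework::proto::VarType::SELECTED_ROWS) {
      ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR);
    }
    ctx->SetOutputDataType("Out", var_data_type);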