From 1b1cb44f13242f2e315b6f648679cf936eb999a2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 13 Oct 2017 14:05:28 -0700 Subject: [PATCH] Complete infer_var_type --- paddle/framework/CMakeLists.txt | 3 + paddle/framework/details/op_registry.h | 19 +++- paddle/framework/op_desc.cc | 14 +++ paddle/framework/op_desc.h | 2 + paddle/framework/op_info.h | 2 +- paddle/framework/type_defs.h | 9 ++ paddle/framework/var_type_inference.h | 29 ++++++ paddle/framework/var_type_inference_test.cc | 103 ++++++++++++++++++++ 8 files changed, 178 insertions(+), 3 deletions(-) create mode 100644 paddle/framework/var_type_inference.h create mode 100644 paddle/framework/var_type_inference_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 14947b6f2..2c61ae40a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -53,3 +53,6 @@ endif() cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) + +cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry + proto_desc) diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h index ca8584b78..71ee14301 100644 --- a/paddle/framework/details/op_registry.h +++ b/paddle/framework/details/op_registry.h @@ -18,6 +18,7 @@ #include "paddle/framework/op_info.h" #include "paddle/framework/op_proto_maker.h" #include "paddle/framework/operator.h" +#include "paddle/framework/var_type_inference.h" namespace paddle { namespace framework { @@ -26,7 +27,8 @@ namespace details { enum OpInfoFillType { kOperator = 0, kOpProtoAndCheckerMaker = 1, - kGradOpDescMaker = 2 + kGradOpDescMaker = 2, + kVarTypeInference = 3 }; template @@ -38,7 +40,9 @@ struct OpInfoFillTypeID { ? kOpProtoAndCheckerMaker : (std::is_base_of::value ? kGradOpDescMaker - : static_cast(-1))); + : (std::is_base_of::value + ? kVarTypeInference + : static_cast(-1)))); } }; @@ -105,6 +109,17 @@ struct OpInfoFiller { }; } }; + +template +struct OpInfoFiller { + void operator()(const char* op_type, OpInfo* info) const { + info->infer_var_type_ = [](const OpDescBind& fwd_op, BlockDescBind* block) { + T inference; + inference(fwd_op, block); + }; + } +}; + } // namespace details } // namespace framework diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index a5d515bbc..09a544fb9 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -236,5 +236,19 @@ void OpDescBind::InferShape(const BlockDescBind &block) const { it->second(&ctx); } +void OpDescBind::InferVarType(BlockDescBind *block) const { + auto &info = OpInfoMap::Instance().Get(this->Type()); + if (info.infer_var_type_) { + info.infer_var_type_(*this, block); + } else { + // all output type is LoDTensor by default + for (auto &out_pair : this->outputs_) { + for (auto &out_var_name : out_pair.second) { + block->Var(out_var_name)->SetType(VarDesc::LOD_TENSOR); + } + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 90155fade..d05ee0875 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -104,6 +104,8 @@ class OpDescBind { void InferShape(const BlockDescBind &block) const; + void InferVarType(BlockDescBind *block) const; + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index c504f69e3..e92618078 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -19,7 +19,6 @@ #include #include "paddle/framework/attribute.h" -#include "paddle/framework/op_desc.h" #include "paddle/framework/type_defs.h" #include "paddle/platform/macros.h" @@ -31,6 +30,7 @@ struct OpInfo { GradOpMakerFN grad_op_maker_; OpProto* proto_{nullptr}; OpAttrChecker* checker_{nullptr}; + InferVarTypeFN infer_var_type_; bool HasOpProtoAndChecker() const { return proto_ != nullptr && checker_ != nullptr; diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index 7e1b79c97..a4e8253bf 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -16,12 +16,18 @@ #include #include #include +#include +#include +#include +#include #include "paddle/platform/variant.h" namespace paddle { namespace framework { class OperatorBase; class OpDescBind; +class BlockDescBind; +class BlockDesc; using VariableNameMap = std::map>; // The order should be as same as framework.proto @@ -39,5 +45,8 @@ using OpCreator = std::function>( const OpDescBind&, const std::unordered_set& /*no_grad_set*/)>; +using InferVarTypeFN = std::function; + } // namespace framework } // namespace paddle diff --git a/paddle/framework/var_type_inference.h b/paddle/framework/var_type_inference.h new file mode 100644 index 000000000..32abbeb33 --- /dev/null +++ b/paddle/framework/var_type_inference.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/type_defs.h" + +namespace paddle { +namespace framework { + +class VarTypeInference { + public: + virtual ~VarTypeInference() {} + virtual void operator()(const OpDescBind& op_desc, + BlockDescBind* block) const = 0; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/var_type_inference_test.cc b/paddle/framework/var_type_inference_test.cc new file mode 100644 index 000000000..e3f4893f1 --- /dev/null +++ b/paddle/framework/var_type_inference_test.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/var_type_inference.h" +#include "gtest/gtest.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/program_desc.h" + +namespace paddle { +namespace framework { + +class SumOpMaker : public OpProtoAndCheckerMaker { + public: + SumOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class SumOpVarTypeInference : public VarTypeInference { + public: + void operator()(const OpDescBind &op_desc, + BlockDescBind *block) const override { + auto default_var_type = VarDesc::LOD_TENSOR; + for (auto &in_var_name : op_desc.Input("X")) { + auto in_var_type = block->Var(in_var_name)->GetType(); + if (in_var_type != default_var_type) { + default_var_type = in_var_type; + break; + } + } + auto out_var_name = op_desc.Output("Out").front(); + block->Var(out_var_name)->SetType(default_var_type); + } +}; +} // namespace framework +} // namespace paddle + +REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker, + paddle::framework::SumOpVarTypeInference); +REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP, + paddle::framework::SumOpMaker); + +namespace paddle { +namespace framework { + +TEST(InferVarType, sum_op) { + auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); + auto *op = prog.Block(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"test_a", "test_b", "test_c"}); + op->SetOutput("Out", {"test_out"}); + + prog.Block(0)->NewVar("test_a")->SetType(VarDesc_VarType_LOD_TENSOR); + prog.Block(0)->NewVar("test_b")->SetType(VarDesc_VarType_LOD_TENSOR); + prog.Block(0)->NewVar("test_c")->SetType(VarDesc_VarType_LOD_TENSOR); + prog.Block(0)->NewVar("test_out"); + + op->InferVarType(prog.Block(0)); + + ASSERT_EQ(VarDesc_VarType_LOD_TENSOR, + prog.Block(0)->Var("test_out")->GetType()); + + prog.Block(0)->Var("test_b")->SetType(VarDesc_VarType_SELECTED_ROWS); + op->InferVarType(prog.Block(0)); + ASSERT_EQ(VarDesc_VarType_SELECTED_ROWS, + prog.Block(0)->Var("test_out")->GetType()); +} + +TEST(InferVarType, sum_op_without_infer_var_type) { + auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); + auto *op = prog.Block(0)->AppendOp(); + op->SetType("sum_without_infer_var_type"); + op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); + op->SetOutput("Out", {"test2_out"}); + + prog.Block(0)->NewVar("test2_a")->SetType(VarDesc_VarType_LOD_TENSOR); + prog.Block(0)->NewVar("test2_b")->SetType(VarDesc_VarType_SELECTED_ROWS); + prog.Block(0)->NewVar("test2_c")->SetType(VarDesc_VarType_LOD_TENSOR); + prog.Block(0)->NewVar("test2_out"); + + op->InferVarType(prog.Block(0)); + + ASSERT_EQ(VarDesc_VarType_LOD_TENSOR, + prog.Block(0)->Var("test2_out")->GetType()); +} + +} // namespace framework +} // namespace paddle \ No newline at end of file -- GitLab