diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 6585fb7b5d295816d6b6c9c9e062e3358b7613f4..4bc3fdeeea461ea2a1f82caa00d6c0c11a2775d0 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -47,5 +47,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope frame cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) +cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry + proto_desc) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h index ed7c5f17b0854809bde923276f36440cce193a88..357ad21f39f3b1f6dbdb98063f8fb24ec6800ec6 100644 --- a/paddle/framework/details/op_registry.h +++ b/paddle/framework/details/op_registry.h @@ -18,6 +18,7 @@ #include "paddle/framework/op_info.h" #include "paddle/framework/op_proto_maker.h" #include "paddle/framework/operator.h" +#include "paddle/framework/var_type_inference.h" namespace paddle { namespace framework { @@ -26,7 +27,8 @@ namespace details { enum OpInfoFillType { kOperator = 0, kOpProtoAndCheckerMaker = 1, - kGradOpDescMaker = 2 + kGradOpDescMaker = 2, + kVarTypeInference = 3 }; template @@ -38,7 +40,9 @@ struct OpInfoFillTypeID { ? kOpProtoAndCheckerMaker : (std::is_base_of::value ? kGradOpDescMaker - : static_cast(-1))); + : (std::is_base_of::value + ? kVarTypeInference + : static_cast(-1)))); } }; @@ -106,6 +110,17 @@ struct OpInfoFiller { }; } }; + +template +struct OpInfoFiller { + void operator()(const char* op_type, OpInfo* info) const { + info->infer_var_type_ = [](const OpDescBind& fwd_op, BlockDescBind* block) { + T inference; + inference(fwd_op, block); + }; + } +}; + } // namespace details } // namespace framework diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 7f7cebb026703d53220b3b4d6dc092d5bcbbe039..18fabe481dac9c1b70e7c30cb83ec5ee8ac47026 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -239,5 +239,19 @@ void OpDescBind::InferShape(const BlockDescBind &block) const { it->second(&ctx); } +void OpDescBind::InferVarType(BlockDescBind *block) const { + auto &info = OpInfoMap::Instance().Get(this->Type()); + if (info.infer_var_type_) { + info.infer_var_type_(*this, block); + } else { + // all output type is LoDTensor by default + for (auto &out_pair : this->outputs_) { + for (auto &out_var_name : out_pair.second) { + block->Var(out_var_name)->SetType(VarDesc::LOD_TENSOR); + } + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 73b5cf846f702fe21277ae139156ec9784aa79b3..313bf538ac7c947c5e77ca0ead6bb53e6a156478 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -102,6 +102,8 @@ class OpDescBind { void InferShape(const BlockDescBind &block) const; + void InferVarType(BlockDescBind *block) const; + void Flush(); private: diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index c504f69e30bb899c183bd4281d2eadb50fd3b376..e926180780609c0a8ffc6270627835c50bbce782 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -19,7 +19,6 @@ #include #include "paddle/framework/attribute.h" -#include "paddle/framework/op_desc.h" #include "paddle/framework/type_defs.h" #include "paddle/platform/macros.h" @@ -31,6 +30,7 @@ struct OpInfo { GradOpMakerFN grad_op_maker_; OpProto* proto_{nullptr}; OpAttrChecker* checker_{nullptr}; + InferVarTypeFN infer_var_type_; bool HasOpProtoAndChecker() const { return proto_ != nullptr && checker_ != nullptr; diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index 0d1564a7510ddf0106ff417fb0b487ddbde1ac2e..00da7289394cf18e013220a4bedde2c182f6a4a4 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -16,12 +16,18 @@ #include #include #include +#include +#include +#include +#include #include "paddle/platform/variant.h" namespace paddle { namespace framework { class OperatorBase; class OpDescBind; +class BlockDescBind; +class BlockDesc; using VariableNameMap = std::map>; // The order should be as same as framework.proto @@ -40,5 +46,8 @@ using GradOpMakerFN = std::function>( const OpDescBind&, const std::unordered_set& /*no_grad_set*/, std::unordered_map* /*grad_to_var*/)>; +using InferVarTypeFN = std::function; + } // namespace framework } // namespace paddle diff --git a/paddle/framework/var_type_inference.h b/paddle/framework/var_type_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..32abbeb33479444c5e7a9889f4211f59af07f98f --- /dev/null +++ b/paddle/framework/var_type_inference.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/type_defs.h" + +namespace paddle { +namespace framework { + +class VarTypeInference { + public: + virtual ~VarTypeInference() {} + virtual void operator()(const OpDescBind& op_desc, + BlockDescBind* block) const = 0; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/var_type_inference_test.cc b/paddle/framework/var_type_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..97b8c647485c4520dea895f220077ec6a6eedd70 --- /dev/null +++ b/paddle/framework/var_type_inference_test.cc @@ -0,0 +1,104 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/var_type_inference.h" +#include "gtest/gtest.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/program_desc.h" + +namespace paddle { +namespace framework { + +class SumOpMaker : public OpProtoAndCheckerMaker { + public: + SumOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class SumOpVarTypeInference : public VarTypeInference { + public: + void operator()(const OpDescBind &op_desc, + BlockDescBind *block) const override { + auto &inputs = op_desc.Input("X"); + auto default_var_type = VarDesc::SELECTED_ROWS; + + bool any_input_is_lod_tensor = std::any_of( + inputs.begin(), inputs.end(), [block](const std::string &name) { + return block->Var(name)->GetType() == VarDesc::LOD_TENSOR; + }); + if (any_input_is_lod_tensor) { + default_var_type = VarDesc::LOD_TENSOR; + } + + auto out_var_name = op_desc.Output("Out").front(); + block->Var(out_var_name)->SetType(default_var_type); + } +}; +} // namespace framework +} // namespace paddle + +REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker, + paddle::framework::SumOpVarTypeInference); +REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP, + paddle::framework::SumOpMaker); + +namespace paddle { +namespace framework { + +TEST(InferVarType, sum_op) { + auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); + auto *op = prog.Block(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"test_a", "test_b", "test_c"}); + op->SetOutput("Out", {"test_out"}); + + prog.Block(0)->NewVar("test_a")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test_b")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test_c")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test_out"); + + op->InferVarType(prog.Block(0)); + + ASSERT_EQ(VarDesc::SELECTED_ROWS, prog.Block(0)->Var("test_out")->GetType()); + + prog.Block(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR); + op->InferVarType(prog.Block(0)); + ASSERT_EQ(VarDesc::LOD_TENSOR, prog.Block(0)->Var("test_out")->GetType()); +} + +TEST(InferVarType, sum_op_without_infer_var_type) { + auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); + auto *op = prog.Block(0)->AppendOp(); + op->SetType("sum_without_infer_var_type"); + op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); + op->SetOutput("Out", {"test2_out"}); + + prog.Block(0)->NewVar("test2_a")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test2_b")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test2_c")->SetType(VarDesc::SELECTED_ROWS); + prog.Block(0)->NewVar("test2_out"); + + op->InferVarType(prog.Block(0)); + + ASSERT_EQ(VarDesc_VarType_LOD_TENSOR, + prog.Block(0)->Var("test2_out")->GetType()); +} + +} // namespace framework +} // namespace paddle \ No newline at end of file