diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index f69a3cfbf8b195cceb15fc2ce81d8bb9d28e5830..f4fef055daf39e9be0645deaafdad4132fc7e35f 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -24,9 +24,10 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) -cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog) +cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) -cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog) +cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 133869e7b58dd2082bd6e099351609f7ed37e96a..c2d6f124ad292bf46b4e7e9a1dcc2984aae7fcda 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -16,15 +16,51 @@ limitations under the License. */ #include #include #include +#include "glog/logging.h" #include "paddle/framework/block_desc.h" #include "paddle/framework/operator.h" #include "paddle/framework/program_desc.h" - -#include "glog/logging.h" +#include "paddle/framework/shape_inference.h" namespace paddle { namespace framework { +class OpDescBind; +class BlockDescBind; +class CompileTimeInferShapeContext : public InferShapeContext { + public: + CompileTimeInferShapeContext(const OpDescBind &op, + const BlockDescBind &block); + + bool HasInput(const std::string &name) const override; + + bool HasOutput(const std::string &name) const override; + + bool HasInputs(const std::string &name) const override; + + bool HasOutputs(const std::string &name) const override; + + DDim GetInputDim(const std::string &name) const override; + + void SetOutputDim(const std::string &name, const DDim &dim) override; + + AttrReader Attrs() const override; + + const std::vector &Inputs( + const std::string &name) const override; + + const std::vector &Outputs( + const std::string &name) const override; + + private: + DDim GetDim(const std::string &name) const override; + + void SetDim(const std::string &name, const DDim &dim) override; + + const OpDescBind &op_; + const BlockDescBind &block_; +}; + OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs) { @@ -288,5 +324,97 @@ void OpDescBind::InferVarType(BlockDescBind *block) const { } } +CompileTimeInferShapeContext::CompileTimeInferShapeContext( + const OpDescBind &op, const BlockDescBind &block) + : op_(op), block_(block) {} + +bool CompileTimeInferShapeContext::HasInput(const std::string &name) const { + const std::vector &input_names = op_.Input(name); + auto length = input_names.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, + "Input(%s) should have only one value, " + "but it have %d now", + name, length); + return block_.HasVarRecursive(input_names[0]); +} + +bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const { + const std::vector &output_names = op_.Output(name); + auto length = output_names.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, + "Output(%s) should have only one value, " + "but it have %d now", + name, length); + return block_.HasVarRecursive(output_names[0]); +} + +bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const { + const std::vector &input_names = op_.Input(name); + if (input_names.empty()) { + return false; + } + for (auto &input : input_names) { + if (!block_.HasVarRecursive(input)) return false; + } + return true; +} + +bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const { + const std::vector &output_names = op_.Output(name); + if (output_names.empty()) { + return false; + } + for (auto &output : output_names) { + if (!block_.HasVarRecursive(output)) return false; + } + return true; +} + +DDim CompileTimeInferShapeContext::GetInputDim(const std::string &name) const { + std::vector ddims = GetInputsDim(name); + auto length = ddims.size(); + PADDLE_ENFORCE_EQ(length, 1UL, + "Input(%s) should have 1 value, " + "but it has %d now", + name, length); + return ddims[0]; +} + +void CompileTimeInferShapeContext::SetOutputDim(const std::string &name, + const DDim &dim) { + SetOutputsDim(name, {dim}); +} + +AttrReader CompileTimeInferShapeContext::Attrs() const { + return AttrReader(op_.GetAttrMap()); +} + +const std::vector &CompileTimeInferShapeContext::Inputs( + const std::string &name) const { + return op_.Input(name); +} + +const std::vector &CompileTimeInferShapeContext::Outputs( + const std::string &name) const { + return op_.Output(name); +} + +DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { + auto var = block_.FindVarRecursive(name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); + return framework::make_ddim(var->Shape()); +} + +void CompileTimeInferShapeContext::SetDim(const std::string &name, + const DDim &dim) { + block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim)); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index ed85c386ec2632604bf5faf0ff9b1a087eb9c276..deacf41f993e42c9c5e41b93433ddacf95a175e5 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/framework/op_desc.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" +#include "paddle/framework/shape_inference.h" namespace paddle { namespace framework { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index db154e4f76fbec444ae4347523cadd1b6d29d319..9e1e955aaeaf8336eee0c0a7cbada56aa26352e2 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/framework/operator.h" #include #include +#include "paddle/framework/shape_inference.h" namespace paddle { namespace framework { @@ -273,5 +274,136 @@ bool OpSupportGPU(const std::string& op_type) { return false; } +class RuntimeInferShapeContext : public InferShapeContext { + public: + RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) + : op_(op), scope_(scope) {} + + bool HasInput(const std::string& name) const override { + auto& ins = Inputs(name); + size_t length = ins.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs", + name); + auto ipt = ins[0]; + auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); + return var != nullptr; + } + + bool HasOutput(const std::string& name) const override { + auto& outs = Outputs(name); + size_t length = outs.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs", + name); + auto ipt = outs[0]; + auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); + return var != nullptr; + } + + bool HasInputs(const std::string& name) const override { + auto inputs = op_.Inputs(name); + if (inputs.empty()) { + return false; + } + for (auto& input : inputs) { + if (scope_.FindVar(input) == nullptr) { + return false; + } + } + return true; + } + + bool HasOutputs(const std::string& name) const override { + auto outputs = op_.Outputs(name); + if (outputs.empty()) { + return false; + } + for (auto& output : outputs) { + if (scope_.FindVar(output) == nullptr) { + return false; + } + } + return true; + } + + DDim GetInputDim(const std::string& name) const override { + return GetDim(op_.Input(name)); + } + + void SetOutputDim(const std::string& name, const DDim& dim) override { + SetDim(op_.Output(name), dim); + } + + AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } + + const std::vector& Inputs( + const std::string& name) const override { + return op_.Inputs(name); + } + + const std::vector& Outputs( + const std::string& name) const override { + return op_.Outputs(name); + } + + private: + DDim GetDim(const std::string& name) const override { + Variable* var = scope_.FindVar(name); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().GetCompleteDims(); + } else { + PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); + } + } + + void SetDim(const std::string& name, const DDim& dim) override { + Variable* var = scope_.FindVar(name); + if (var->IsType()) { + var->GetMutable()->Resize(dim); + } else if (var->IsType()) { + var->GetMutable()->set_height(dim[0]); + } else { + PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); + } + } + + const OperatorBase& op_; + const Scope& scope_; +}; + +void OperatorWithKernel::Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const { + VLOG(3) << "Running operator " << this->Type(); + RuntimeInferShapeContext infer_shape_ctx(*this, scope); + this->InferShape(&infer_shape_ctx); + + ExecutionContext ctx(*this, scope, dev_ctx); + + // check if op[type] has kernel registered. + auto& all_op_kernels = AllOpKernels(); + auto kernels_iter = all_op_kernels.find(type_); + if (kernels_iter == all_op_kernels.end()) { + PADDLE_THROW("op[%s] has no kernel", type_); + } + + // check if op[type] have kernel for kernel_key + OpKernelMap& kernels = kernels_iter->second; + auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx); + auto kernel_iter = kernels.find(kernel_key); + + if (kernel_iter == kernels.end()) { + PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_, kernel_key); + } + + kernel_iter->second->Compute(ctx); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index aa79f16df82ab9d81e093af60b730d9aacd09568..3a9c7a732885b489c23c7a7961066a6e87849203 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -29,7 +29,6 @@ limitations under the License. */ #include "paddle/framework/op_info.h" #include "paddle/framework/scope.h" #include "paddle/framework/selected_rows.h" -#include "paddle/framework/shape_inference.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/place.h" @@ -317,226 +316,6 @@ template <> std::vector ExecutionContext::MultiOutput( const std::string& name) const; -class CompileTimeInferShapeContext : public InferShapeContext { - public: - CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block) - : op_(op), block_(block) {} - - bool HasInput(const std::string& name) const override { - const std::vector& input_names = op_.Input(name); - auto length = input_names.size(); - if (length == 0) { - return false; - } - PADDLE_ENFORCE_EQ(length, 1UL, - "Input(%s) should have only one value, " - "but it have %d now", - name, length); - return block_.HasVarRecursive(input_names[0]); - } - - bool HasOutput(const std::string& name) const override { - const std::vector& output_names = op_.Output(name); - auto length = output_names.size(); - if (length == 0) { - return false; - } - PADDLE_ENFORCE_EQ(length, 1UL, - "Output(%s) should have only one value, " - "but it have %d now", - name, length); - return block_.HasVarRecursive(output_names[0]); - } - - bool HasInputs(const std::string& name) const override { - const std::vector& input_names = op_.Input(name); - if (input_names.empty()) { - return false; - } - for (auto& input : input_names) { - if (!block_.HasVarRecursive(input)) return false; - } - return true; - } - - bool HasOutputs(const std::string& name) const override { - const std::vector& output_names = op_.Output(name); - if (output_names.empty()) { - return false; - } - for (auto& output : output_names) { - if (!block_.HasVarRecursive(output)) return false; - } - return true; - } - - DDim GetInputDim(const std::string& name) const override { - std::vector ddims = GetInputsDim(name); - auto length = ddims.size(); - PADDLE_ENFORCE_EQ(length, 1UL, - "Input(%s) should have 1 value, " - "but it has %d now", - name, length); - return ddims[0]; - } - - void SetInputDim(const std::string& name, const DDim& dim) override { - SetInputsDim(name, {dim}); - } - - DDim GetOutputDim(const std::string& name) const override { - std::vector ddims = GetOutputsDim(name); - auto length = ddims.size(); - PADDLE_ENFORCE_EQ(length, 1UL, - "Output(%s) should have 1 value, " - "but it has %d now", - name, length); - return ddims[0]; - } - - void SetOutputDim(const std::string& name, const DDim& dim) override { - SetOutputsDim(name, {dim}); - } - - AttrReader Attrs() const override { return AttrReader(op_.GetAttrMap()); } - - const std::vector& Inputs( - const std::string& name) const override { - return op_.Input(name); - } - - const std::vector& Outputs( - const std::string& name) const override { - return op_.Output(name); - } - - private: - DDim GetDim(const std::string& name) const override { - auto var = block_.FindVarRecursive(name); - PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); - return framework::make_ddim(var->Shape()); - } - - void SetDim(const std::string& name, const DDim& dim) override { - block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim)); - } - - const OpDescBind& op_; - const BlockDescBind& block_; -}; - -class RuntimeInferShapeContext : public InferShapeContext { - public: - RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) - : op_(op), scope_(scope) {} - - bool HasInput(const std::string& name) const override { - auto& ins = Inputs(name); - size_t length = ins.size(); - if (length == 0) { - return false; - } - PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs", - name); - auto ipt = ins[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; - } - - bool HasOutput(const std::string& name) const override { - auto& outs = Outputs(name); - size_t length = outs.size(); - if (length == 0) { - return false; - } - PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs", - name); - auto ipt = outs[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; - } - - bool HasInputs(const std::string& name) const override { - auto inputs = op_.Inputs(name); - if (inputs.empty()) { - return false; - } - for (auto& input : inputs) { - if (scope_.FindVar(input) == nullptr) { - return false; - } - } - return true; - } - - bool HasOutputs(const std::string& name) const override { - auto outputs = op_.Outputs(name); - if (outputs.empty()) { - return false; - } - for (auto& output : outputs) { - if (scope_.FindVar(output) == nullptr) { - return false; - } - } - return true; - } - - DDim GetInputDim(const std::string& name) const override { - return GetDim(op_.Input(name)); - } - - void SetInputDim(const std::string& name, const DDim& dim) override { - SetDim(op_.Input(name), dim); - } - - DDim GetOutputDim(const std::string& name) const override { - return GetDim(op_.Output(name)); - } - - void SetOutputDim(const std::string& name, const DDim& dim) override { - SetDim(op_.Output(name), dim); - } - - AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } - - const std::vector& Inputs( - const std::string& name) const override { - return op_.Inputs(name); - } - - const std::vector& Outputs( - const std::string& name) const override { - return op_.Outputs(name); - } - - private: - DDim GetDim(const std::string& name) const override { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - return var->Get().dims(); - } else if (var->IsType()) { - return var->Get().GetCompleteDims(); - } else { - PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); - } - } - - void SetDim(const std::string& name, const DDim& dim) override { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - var->GetMutable()->Resize(dim); - } else if (var->IsType()) { - var->GetMutable()->set_height(dim[0]); - } else { - PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); - } - } - - const OperatorBase& op_; - const Scope& scope_; -}; - class OpKernelBase { public: /** @@ -595,32 +374,7 @@ class OperatorWithKernel : public OperatorBase { : OperatorBase(type, inputs, outputs, attrs) {} void Run(const Scope& scope, - const platform::DeviceContext& dev_ctx) const final { - VLOG(3) << "Running operator " << this->Type(); - RuntimeInferShapeContext infer_shape_ctx(*this, scope); - this->InferShape(&infer_shape_ctx); - - ExecutionContext ctx(*this, scope, dev_ctx); - - // check if op[type] has kernel registered. - auto& all_op_kernels = AllOpKernels(); - auto kernels_iter = all_op_kernels.find(type_); - if (kernels_iter == all_op_kernels.end()) { - PADDLE_THROW("op[%s] has no kernel", type_); - } - - // check if op[type] have kernel for kernel_key - OpKernelMap& kernels = kernels_iter->second; - auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx); - auto kernel_iter = kernels.find(kernel_key); - - if (kernel_iter == kernels.end()) { - PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_, - kernel_key); - } - - kernel_iter->second->Compute(ctx); - } + const platform::DeviceContext& dev_ctx) const final; static std::unordered_map& AllOpKernels() { diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..33a1d0b9b217c5d2a4b0fb63f427529e7988b24e --- /dev/null +++ b/paddle/framework/shape_inference.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include "paddle/framework/shape_inference.h" + +namespace paddle { +namespace framework { + +std::vector InferShapeContext::GetInputsDim( + const std::string &name) const { + const std::vector &names = Inputs(name); + return GetDims(names); +} + +void InferShapeContext::SetOutputsDim( + const std::string &name, const std::vector &dims) { + auto &names = Outputs(name); + SetDims(names, dims); +} + +void InferShapeContext::ShareLoD(const std::string &in, const std::string &out, + size_t i, size_t j) const {} + +std::vector InferShapeContext::GetDims( + const std::vector &names) const { + std::vector ret; + ret.reserve(names.size()); + std::transform( + names.begin(), names.end(), std::back_inserter(ret), + [this](const std::string &name) { return this->GetDim(name); }); + return ret; +} + +void InferShapeContext::SetDims(const std::vector &names, + const std::vector &dims) { + size_t length = names.size(); + PADDLE_ENFORCE_EQ(length, dims.size()); + for (size_t i = 0; i < length; ++i) { + SetDim(names[i], dims[i]); + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index b93f980cf6d279d18388b9637a2ff45d797ca78e..f1f1e44bccd771be81cad7c28efe9b1b885eef6b 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/framework/attribute.h" #include "paddle/framework/ddim.h" namespace paddle { @@ -21,7 +22,7 @@ namespace framework { class InferShapeContext { public: - virtual ~InferShapeContext() {} + virtual ~InferShapeContext() = default; virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; @@ -29,57 +30,32 @@ class InferShapeContext { virtual bool HasOutputs(const std::string &name) const = 0; virtual framework::DDim GetInputDim(const std::string &name) const = 0; - std::vector GetInputsDim(const std::string &name) const { - const std::vector &names = Inputs(name); - return GetDims(names); - } - virtual void SetInputDim(const std::string &name, - const framework::DDim &dim) = 0; - void SetInputsDim(const std::string &name, - const std::vector &dims) { - auto &names = Inputs(name); - SetDims(names, dims); - } - virtual framework::DDim GetOutputDim(const std::string &name) const = 0; - std::vector GetOutputsDim(const std::string &name) const { - const std::vector &names = Outputs(name); - return GetDims(names); - } + + std::vector GetInputsDim(const std::string &name) const; + virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; void SetOutputsDim(const std::string &name, - const std::vector &dims) { - auto &names = Outputs(name); - SetDims(names, dims); - } + const std::vector &dims); + virtual AttrReader Attrs() const = 0; virtual const std::vector &Inputs( const std::string &name) const = 0; virtual const std::vector &Outputs( const std::string &name) const = 0; + // TODO(qiao) implement this function void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, - size_t j = 0) const {} + size_t j = 0) const; protected: virtual framework::DDim GetDim(const std::string &name) const = 0; virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; + std::vector GetDims( - const std::vector &names) const { - std::vector ret; - ret.reserve(names.size()); - std::transform( - names.begin(), names.end(), std::back_inserter(ret), - [this](const std::string &name) { return this->GetDim(name); }); - return ret; - } + const std::vector &names) const; + void SetDims(const std::vector &names, - const std::vector &dims) { - size_t length = names.size(); - PADDLE_ENFORCE_EQ(length, dims.size()); - for (size_t i = 0; i < length; ++i) { - SetDim(names[i], dims[i]); - } - } + const std::vector &dims); }; } // namespace framework