From 578a357b616ee188d692764843ae834a449e81c2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 2 Oct 2017 15:12:20 -0700 Subject: [PATCH] Make compile pass --- paddle/framework/CMakeLists.txt | 4 +--- paddle/framework/backward.cc | 33 +++++++++++++++++++++++++++++---- paddle/framework/op_desc.h | 14 ++++++++------ paddle/framework/op_registry.cc | 6 ++++++ paddle/framework/op_registry.h | 8 +++++--- 5 files changed, 49 insertions(+), 16 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 9140854a96c..eb316b4c8cc 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) -cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) -cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info) +cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) -cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op) py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index ab2567a25c0..eb34bc36932 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/framework/backward.h" +#include "paddle/operators/net_op.h" #include #include @@ -24,6 +25,32 @@ namespace paddle { namespace framework { +static inline std::unique_ptr CreateGradOp( + const OperatorBase& op) { + OpDescBind op_desc; + op_desc.SetInputMap(op.Inputs()); + op_desc.SetOutputMap(op.Outputs()); + op_desc.SetType(op.Type()); + op_desc.SetAttrMap(op.Attrs()); + auto& info = OpInfoMap::Instance().Get(op.Type()); + auto grad_descs = info.grad_op_maker_(op_desc); + std::vector> grad_ops; + grad_ops.reserve(grad_descs.size()); + std::transform( + grad_descs.begin(), grad_descs.end(), std::back_inserter(grad_ops), + [](OpDescBind& grad_desc) { return OpRegistry::CreateOp(&grad_desc); }); + PADDLE_ENFORCE_GT(grad_ops.size(), 0); + if (grad_ops.size() == 1) { + return std::move(grad_ops[0]); + } else { + auto net_op = new operators::NetOp(); + for (auto& grad_op : grad_ops) { + net_op->AppendOp(std::move(grad_op)); + } + return std::unique_ptr(net_op); + } +} + template static void ForEachVarName(const Map& names, T callback) { for (auto& name : names) { @@ -154,10 +181,8 @@ static std::unique_ptr BackwardRecursive( net->InsertOp(pos.first + 1, std::move(pos.second)); } } else { - OpDescBind fwd_desc; - fwd_desc.SetInput(forwardOp.Inputs()); - - std::unique_ptr grad_op(OpRegistry::CreateGradOp(forwardOp)); + std::unique_ptr grad_op(CreateGradOp(forwardOp)); + PADDLE_ENFORCE(grad_op != nullptr); ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( const std::string& grad_input) { diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index ec92d087688..72d7a0379b9 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -76,18 +76,22 @@ class OpDescBind { return MapKeys(outputs_); } - void SetInput( - const std::unordered_map> &input) { + void SetInputMap(const VariableNameMap &input) { this->inputs_ = input; this->need_update_ = true; } - void SetOutput( - const std::unordered_map> &output) { + void SetOutputMap(const VariableNameMap &output) { this->outputs_ = output; this->need_update_ = true; } + void Sync(); + + const VariableNameMap &Inputs() const { return inputs_; } + + const VariableNameMap &Outputs() const { return outputs_; } + private: template static std::vector MapKeys(const MapType &map) { @@ -99,8 +103,6 @@ class OpDescBind { return ret_val; } - void Sync(); - OpDesc op_desc_; VariableNameMap inputs_; VariableNameMap outputs_; diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 0a2b6fd582a..35f280981ba 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -52,5 +52,11 @@ std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { return CreateOp(op_desc.type(), inputs, outputs, attrs); } +std::unique_ptr OpRegistry::CreateOp(OpDescBind* op_desc) { + op_desc->Sync(); + return CreateOp(op_desc->Type(), op_desc->Inputs(), op_desc->Outputs(), + op_desc->GetAttrMap()); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 0f377f34cbf..d14f70008b3 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -23,8 +23,8 @@ limitations under the License. */ #include "paddle/framework/attribute.h" #include "paddle/framework/details/op_registry.h" #include "paddle/framework/framework.pb.h" -#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_desc_maker.h" +#include "paddle/framework/op_desc.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" @@ -46,15 +46,15 @@ class Registrar { template struct OperatorRegistrar : public Registrar { explicit OperatorRegistrar(const char* op_type) : op_type(op_type) { + std::cerr << "Reg operator " << op_type << std::endl; PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); static_assert(sizeof...(ARGS) != 0, "OperatorRegistrar should be invoked at least by OpClass"); details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info); + OpInfoMap::Instance().Insert(op_type, info); } - ~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); } - const char* op_type; OpInfo info; @@ -79,6 +79,8 @@ class OpRegistry { AttributeMap attrs); static std::unique_ptr CreateOp(const OpDesc& op_desc); + + static std::unique_ptr CreateOp(OpDescBind* op_desc); }; template -- GitLab