diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 9e9491d983b3e2b5b4f70692bb9171abc3ee895d..a43861f4cd1c5e238438f4974e824384aa85b797 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -19,7 +19,8 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) -cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) +cc_library(grad_op_creator SRCS grad_op_creator.cc) +cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc grad_op_creator) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) diff --git a/paddle/framework/grad_op_creator.cc b/paddle/framework/grad_op_creator.cc index ac3663b7fcbd9e7161c2b5e9f97c68d4fe5f88cb..106c2eae9dade9ef1829fc2f1b793faf483947d4 100644 --- a/paddle/framework/grad_op_creator.cc +++ b/paddle/framework/grad_op_creator.cc @@ -1,4 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include "paddle/framework/grad_op_creator.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace framework { @@ -22,15 +37,15 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var, } void GradOpCreator::BuildOpInOutArgList() { - const OpProto& op_proto = OpRegistry::protos().at(op_->type); - const auto& var_map = *(OpRegistry::VarIndexMaps().at(op->type_)); + const OpProto& op_proto = OpRegistry::protos().at(op_->type_); + const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_)); const std::vector& in_format = op_->attrs_.count("input_format") - ? op->GetAttr>("input_format") + ? op_->GetAttr>("input_format") : std::vector(); const std::vector& out_format = op_->attrs_.count("output_format") - ? op->GetAttr>("output_format") + ? op_->GetAttr>("output_format") : std::vector(); for (const auto& var : op_proto.inputs()) { arg_list_.emplace_back( @@ -46,14 +61,15 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg, std::vector& in_out, std::vector& format, VarIndexMap* varmap, int& idx, - bool is_grad) { + bool is_grad) const { std::string var_name = arg->proto_name_; if (is_grad) { var_name += OperatorBase::GRAD_VAR_SUFFIX(); } - *(varmap)[var_name] = idx++; + (*varmap)[var_name] = idx++; size_t pre_sz = in_out.size(); - auto base_it = arg->type == IN ? op_->inputs_.begin() : op_->outputs_.begin(); + auto base_it = + arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin(); std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_, std::back_inserter(in_out)); if (is_grad) { @@ -96,4 +112,4 @@ void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const { } } // namespace framework -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/paddle/framework/grad_op_creator.h b/paddle/framework/grad_op_creator.h index 456b066f1d00160dae33499fc3b4f09d23c60f09..21b160a73f3f6402a0571e2f13be06b26b5c30bc 100644 --- a/paddle/framework/grad_op_creator.h +++ b/paddle/framework/grad_op_creator.h @@ -7,6 +7,24 @@ namespace paddle { namespace framework { class OpRegistry; +enum InOutType { IN, OUT }; + +struct OpInOutArg { + OpInOutArg(const std::string& proto_name, const InOutType& type, + bool needed_in_grad, size_t begin_idx, size_t end_idx) + : proto_name_(proto_name), + type_(type), + needed_in_grad_(needed_in_grad), + begin_idx_(begin_idx), + end_idx_(end_idx) {} + + std::string proto_name_; + InOutType type_; + bool needed_in_grad_; + size_t begin_idx_; + size_t end_idx_; +}; + class GradOpCreator { using VarIndexMap = std::unordered_map; @@ -15,30 +33,12 @@ class GradOpCreator { OperatorBase* Create(); private: - enum InOutType { IN, OUT }; - - struct OpInOutArg { - OpInOutArg(const std::string& proto_name, const InOutType& type, - bool needed_in_grad, size_t begin_idx, size_t end_idx) - : proto_name_(proto_name), - type_(type), - needed_in_grad_(needed_in_grad), - begin_idx_(begin_idx), - end_idx_(end_idx) {} - - std::string proto_name_; - InOutType type_; - bool needed_in_grad_; - size_t begin_idx_; - size_t end_idx_; - }; - OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map, const std::vector& format, InOutType type); void BuildOpInOutArgList(); void AddArgIntoGradOp(const OpInOutArg* arg, std::vector& in_out, std::vector& format, VarIndexMap* varmap, int& idx, - bool is_grad); + bool is_grad) const; void CompleteGradOp(OperatorBase* grad_op) const; const OperatorBase* op_; std::vector> arg_list_; diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 897238fc69b54842f8aa23f956a5a0dcbec8103a..bbeeefb20cab19e629c317e1340a0104c11c940c 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include