Commit 9418717f authored by fengjiayi

Fix compile errors

Parent f41fcd43
...
@@ -19,7 +19,8 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
 cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
+cc_library(grad_op_creator SRCS grad_op_creator.cc)
+cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc grad_op_creator)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
...
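A note on the build change: this hunk compiles the new grad_op_creator.cc into its own cc_library target and adds it to op_registry's DEPS. Presumably the missing target and dependency edge were one source of the errors the commit title refers to, since op_registry now links against the GradOpCreator code changed below.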
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
 #include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace framework {
...
@@ -22,15 +37,15 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
 }
 
 void GradOpCreator::BuildOpInOutArgList() {
-  const OpProto& op_proto = OpRegistry::protos().at(op_->type);
-  const auto& var_map = *(OpRegistry::VarIndexMaps().at(op->type_));
+  const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
+  const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
   const std::vector<int>& in_format =
       op_->attrs_.count("input_format")
-          ? op->GetAttr<std::vector<int>>("input_format")
+          ? op_->GetAttr<std::vector<int>>("input_format")
           : std::vector<int>();
   const std::vector<int>& out_format =
       op_->attrs_.count("output_format")
-          ? op->GetAttr<std::vector<int>>("output_format")
+          ? op_->GetAttr<std::vector<int>>("output_format")
           : std::vector<int>();
   for (const auto& var : op_proto.inputs()) {
     arg_list_.emplace_back(
...
@@ -46,14 +61,15 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
                                      std::vector<std::string>& in_out,
                                      std::vector<int>& format,
                                      VarIndexMap* varmap, int& idx,
-                                     bool is_grad) {
+                                     bool is_grad) const {
   std::string var_name = arg->proto_name_;
   if (is_grad) {
     var_name += OperatorBase::GRAD_VAR_SUFFIX();
   }
-  *(varmap)[var_name] = idx++;
+  (*varmap)[var_name] = idx++;
   size_t pre_sz = in_out.size();
-  auto base_it = arg->type == IN ? op_->inputs_.begin() : op_->outputs_.begin();
+  auto base_it =
+      arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
   std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
             std::back_inserter(in_out));
   if (is_grad) {
...
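A note on the two fixes in this hunk, beyond the op → op_ and type → type_ member-name corrections: the subscript operator binds tighter than unary *, so the old *(varmap)[var_name] parsed as *(varmap[var_name]), which indexes a raw pointer with a std::string and cannot compile. A minimal standalone sketch of the precedence fix (illustrative names, not Paddle code):

#include <cassert>
#include <string>
#include <unordered_map>

using VarIndexMap = std::unordered_map<std::string, int>;

int main() {
  VarIndexMap map;
  VarIndexMap* varmap = &map;  // same shape as the AddArgIntoGradOp parameter
  std::string var_name = "W@GRAD";
  int idx = 0;

  // *(varmap)[var_name] = idx++;   // does not compile: parses as
  //                                // *(varmap[var_name]), i.e. pointer
  //                                // arithmetic with a std::string index
  (*varmap)[var_name] = idx++;  // dereference first, then call
                                // unordered_map::operator[]

  assert(map.at("W@GRAD") == 0 && idx == 1);
  return 0;
}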
...
@@ -7,17 +7,9 @@ namespace paddle {
 namespace framework {
 class OpRegistry;
 
-class GradOpCreator {
-  using VarIndexMap = std::unordered_map<std::string, int>;
-
- public:
-  GradOpCreator(const OperatorBase* op) : op_(op) {}
-  OperatorBase* Create();
-
- private:
-  enum InOutType { IN, OUT };
-  struct OpInOutArg {
+enum InOutType { IN, OUT };
+
+struct OpInOutArg {
   OpInOutArg(const std::string& proto_name, const InOutType& type,
              bool needed_in_grad, size_t begin_idx, size_t end_idx)
       : proto_name_(proto_name),
...
@@ -31,14 +23,22 @@ class GradOpCreator {
   bool needed_in_grad_;
   size_t begin_idx_;
   size_t end_idx_;
 };
 
+class GradOpCreator {
+  using VarIndexMap = std::unordered_map<std::string, int>;
+
+ public:
+  GradOpCreator(const OperatorBase* op) : op_(op) {}
+  OperatorBase* Create();
+
+ private:
   OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
                        const std::vector<int>& format, InOutType type);
   void BuildOpInOutArgList();
   void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
                         std::vector<int>& format, VarIndexMap* varmap, int& idx,
-                        bool is_grad);
+                        bool is_grad) const;
   void CompleteGradOp(OperatorBase* grad_op) const;
 
   const OperatorBase* op_;
   std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
...
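The header reshuffle above hoists InOutType and OpInOutArg from private nested types of GradOpCreator to namespace scope, declared ahead of the class. One plausible reason, consistent with the unqualified OpInOutArg* return type in the grad_op_creator.cc hunk earlier: the leading return type of an out-of-class member definition is looked up before the compiler enters the class's scope, so a nested type used there must be written in qualified form. A minimal sketch of the rule (illustrative names, not Paddle code):

#include <vector>

struct Arg {  // at namespace scope, declared before the class that uses it
  int v;
};

class Creator {
 public:
  Arg* Build();  // fine: Arg is visible at this point

 private:
  std::vector<Arg*> args_;
};

// Had Arg stayed nested inside Creator, this definition would have to
// begin with the qualified return type Creator::Arg*, because the leading
// return type is resolved before Creator::Build brings us into class scope.
Arg* Creator::Build() { return nullptr; }

int main() {
  Creator c;
  return c.Build() == nullptr ? 0 : 1;
}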
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
 #pragma once
 
 #include <algorithm>
...