提交 b884bc33 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #4551 from reyoung/feature/grad_reg_mechanism_cont

Add helper function in GradOpDescMakerBase. Make it easier to use.
...@@ -33,22 +33,45 @@ The mapping relationship between an operator and its gradient operators is a fun ...@@ -33,22 +33,45 @@ The mapping relationship between an operator and its gradient operators is a fun
```cpp ```cpp
// (OpDesc) --> vector<OpDesc> // (OpDesc) --> vector<OpDesc>
using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>; std::function<std::vector<OpDescBind>(const OpDescBind&)>;
``` ```
The function take a `OpDesc` of the forward operator and return one or many gradient operator descriptions. The function takes an `OpDescBind` of the forward operator and returns one or many gradient operator descriptions. `OpDescBind` is a C++ wrapper for protobuf message `OpDesc` to manipulate `OpDesc` fast.
The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be
```cpp ```cpp
struct OpInfo { struct OpInfo {
GradOpDescMaker grad_op_maker_; std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
... ...
}; };
``` ```
The `grad_op_maker_ ` is `nullptr` if the operator does not have associated gradient operators. The `grad_op_maker_ ` is `nullptr` if the operator does not have associated gradient operators.
We propose a base class called `GradOpDescMakerBase` to let operator developers generate `Gradient Operators` easily. The public interface of that class is
```cpp
class GradOpDescMakerBase {
public:
GradOpDescMakerBase(const OpDescBind& );
virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
};
```
We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
```cpp
using GradOpMaker = ...;
std::function<std::vector<OpDescBind>(const OpDescBind&)> func;
func = [] (const OpDescBind& fwd_op) {
GradOpMaker maker(fwd_op);
return maker();
};
```
We can write many helper functions since the `GradOpDescMakerBase` is a class now. The basic helper functions get the variables of `Input`, `Output`, `InputGradient` and `OutputGradient` in the forwarding operator.
We should chagne register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OPERATOR ` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`. We should chagne register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OPERATOR ` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`.
The user interface should be The user interface should be
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/framework/grad_op_desc_maker.h"
#include "paddle/framework/op_info.h" #include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h" #include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
...@@ -96,7 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> { ...@@ -96,7 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
template <typename T> template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> { struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = new T(); info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
T maker(fwd_op);
return maker();
};
} }
}; };
} // namespace details } // namespace details
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
class GradOpDescMakerBase {
public:
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
protected:
static std::vector<std::string> ToGradNames(
const std::vector<std::string>& var_names) {
std::vector<std::string> ret_val;
ret_val.reserve(var_names.size());
std::transform(var_names.begin(), var_names.end(),
std::back_inserter(ret_val), GradVarName);
return ret_val;
}
std::vector<std::string> InputGrad(const std::string& name) const {
return ToGradNames(fwd_op_.Input(name));
}
std::vector<std::string> OutputGrad(const std::string& name) const {
return ToGradNames(fwd_op_.Output(name));
}
std::vector<std::string> InputNames() const {
return this->fwd_op_.InputNames();
}
std::vector<std::string> OutputNames() const {
return this->fwd_op_.OutputNames();
}
std::vector<std::string> Input(const std::string& name) const {
return fwd_op_.Input(name);
}
std::vector<std::string> Output(const std::string& name) const {
return fwd_op_.Output(name);
}
const std::unordered_map<std::string, Attribute>& Attrs() const {
return fwd_op_.GetAttrMap();
}
const Attribute& GetAttr(const std::string& name) const {
auto& map = fwd_op_.GetAttrMap();
auto it = map.find(name);
PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
return it->second;
}
std::string ForwardOpType() const { return this->fwd_op_.Type(); }
private:
const OpDescBind& fwd_op_;
};
class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv;
retv.emplace_back(this->Apply());
return retv;
}
protected:
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
};
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
public:
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual std::unique_ptr<OpDescBind> Apply() const {
auto* grad = new OpDescBind();
grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
}
for (auto& output_param : this->OutputNames()) {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
}
grad->SetAttrMap(this->Attrs());
return std::unique_ptr<OpDescBind>(grad);
}
virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
};
} // namespace framework
} // namespace paddle
...@@ -31,15 +31,6 @@ const std::vector<std::string> &OpDescBind::Input( ...@@ -31,15 +31,6 @@ const std::vector<std::string> &OpDescBind::Input(
return it->second; return it->second;
} }
std::vector<std::string> OpDescBind::InputNames() const {
std::vector<std::string> retv;
retv.reserve(this->inputs_.size());
for (auto &ipt : this->inputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetInput(const std::string &param_name, void OpDescBind::SetInput(const std::string &param_name,
const std::vector<std::string> &args) { const std::vector<std::string> &args) {
need_update_ = true; need_update_ = true;
...@@ -54,15 +45,6 @@ const std::vector<std::string> &OpDescBind::Output( ...@@ -54,15 +45,6 @@ const std::vector<std::string> &OpDescBind::Output(
return it->second; return it->second;
} }
std::vector<std::string> OpDescBind::OutputNames() const {
std::vector<std::string> retv;
retv.reserve(this->outputs_.size());
for (auto &ipt : this->outputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetOutput(const std::string &param_name, void OpDescBind::SetOutput(const std::string &param_name,
const std::vector<std::string> &args) { const std::vector<std::string> &args) {
need_update_ = true; need_update_ = true;
......
...@@ -35,15 +35,11 @@ class OpDescBind { ...@@ -35,15 +35,11 @@ class OpDescBind {
const std::vector<std::string> &Input(const std::string &name) const; const std::vector<std::string> &Input(const std::string &name) const;
std::vector<std::string> InputNames() const;
void SetInput(const std::string &param_name, void SetInput(const std::string &param_name,
const std::vector<std::string> &args); const std::vector<std::string> &args);
const std::vector<std::string> &Output(const std::string &name) const; const std::vector<std::string> &Output(const std::string &name) const;
std::vector<std::string> OutputNames() const;
void SetOutput(const std::string &param_name, void SetOutput(const std::string &param_name,
const std::vector<std::string> &args); const std::vector<std::string> &args);
...@@ -61,9 +57,6 @@ class OpDescBind { ...@@ -61,9 +57,6 @@ class OpDescBind {
void SetBlockAttr(const std::string &name, BlockDescBind &block); void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const;
...@@ -71,7 +64,23 @@ class OpDescBind { ...@@ -71,7 +64,23 @@ class OpDescBind {
// Only be used in C++ // Only be used in C++
const AttributeMap &GetAttrMap() const; const AttributeMap &GetAttrMap() const;
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
private: private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
std::vector<typename MapType::key_type> ret_val;
ret_val.reserve(map.size());
std::transform(
map.begin(), map.end(), std::back_inserter(ret_val),
[](const typename MapType::value_type &pair) { return pair.first; });
return ret_val;
}
void Sync(); void Sync();
OpDesc op_desc_; OpDesc op_desc_;
......
...@@ -25,16 +25,10 @@ ...@@ -25,16 +25,10 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class GradOpDescMakerBase {
public:
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()(const OpDescBind&) const = 0;
};
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
std::string grad_op_type_; std::string grad_op_type_;
GradOpDescMakerBase* grad_op_maker_{nullptr}; GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr}; OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase; class OperatorBase;
class OpDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto // The order should be as same as framework.proto
...@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*( ...@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/, const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册