提交 b884bc33 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #4551 from reyoung/feature/grad_reg_mechanism_cont

Add helper function in GradOpDescMakerBase. Make it easier to use.
......@@ -33,22 +33,45 @@ The mapping relationship between an operator and its gradient operators is a fun
```cpp
// (OpDesc) --> vector<OpDesc>
using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>;
std::function<std::vector<OpDescBind>(const OpDescBind&)>;
```
The function take a `OpDesc` of the forward operator and return one or many gradient operator descriptions.
The function takes an `OpDescBind` of the forward operator and returns one or many gradient operator descriptions. `OpDescBind` is a C++ wrapper for protobuf message `OpDesc` to manipulate `OpDesc` fast.
The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be
```cpp
struct OpInfo {
GradOpDescMaker grad_op_maker_;
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
...
};
```
The `grad_op_maker_ ` is `nullptr` if the operator does not have associated gradient operators.
We propose a base class called `GradOpDescMakerBase` to let operator developers generate `Gradient Operators` easily. The public interface of that class is
```cpp
class GradOpDescMakerBase {
public:
GradOpDescMakerBase(const OpDescBind& );
virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
};
```
We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
```cpp
using GradOpMaker = ...;
std::function<std::vector<OpDescBind>(const OpDescBind&)> func;
func = [] (const OpDescBind& fwd_op) {
GradOpMaker maker(fwd_op);
return maker();
};
```
We can write many helper functions since the `GradOpDescMakerBase` is a class now. The basic helper functions get the variables of `Input`, `Output`, `InputGradient` and `OutputGradient` in the forwarding operator.
We should chagne register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OPERATOR ` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`.
The user interface should be
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/framework/grad_op_desc_maker.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
......@@ -96,7 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = new T();
info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
T maker(fwd_op);
return maker();
};
}
};
} // namespace details
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
class GradOpDescMakerBase {
public:
explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
protected:
static std::vector<std::string> ToGradNames(
const std::vector<std::string>& var_names) {
std::vector<std::string> ret_val;
ret_val.reserve(var_names.size());
std::transform(var_names.begin(), var_names.end(),
std::back_inserter(ret_val), GradVarName);
return ret_val;
}
std::vector<std::string> InputGrad(const std::string& name) const {
return ToGradNames(fwd_op_.Input(name));
}
std::vector<std::string> OutputGrad(const std::string& name) const {
return ToGradNames(fwd_op_.Output(name));
}
std::vector<std::string> InputNames() const {
return this->fwd_op_.InputNames();
}
std::vector<std::string> OutputNames() const {
return this->fwd_op_.OutputNames();
}
std::vector<std::string> Input(const std::string& name) const {
return fwd_op_.Input(name);
}
std::vector<std::string> Output(const std::string& name) const {
return fwd_op_.Output(name);
}
const std::unordered_map<std::string, Attribute>& Attrs() const {
return fwd_op_.GetAttrMap();
}
const Attribute& GetAttr(const std::string& name) const {
auto& map = fwd_op_.GetAttrMap();
auto it = map.find(name);
PADDLE_ENFORCE(it != map.end(), "Cannot find attribute %s", name);
return it->second;
}
std::string ForwardOpType() const { return this->fwd_op_.Type(); }
private:
const OpDescBind& fwd_op_;
};
class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv;
retv.emplace_back(this->Apply());
return retv;
}
protected:
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
};
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
public:
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual std::unique_ptr<OpDescBind> Apply() const {
auto* grad = new OpDescBind();
grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
}
for (auto& output_param : this->OutputNames()) {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
}
grad->SetAttrMap(this->Attrs());
return std::unique_ptr<OpDescBind>(grad);
}
virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
};
} // namespace framework
} // namespace paddle
......@@ -31,15 +31,6 @@ const std::vector<std::string> &OpDescBind::Input(
return it->second;
}
std::vector<std::string> OpDescBind::InputNames() const {
std::vector<std::string> retv;
retv.reserve(this->inputs_.size());
for (auto &ipt : this->inputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetInput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
......@@ -54,15 +45,6 @@ const std::vector<std::string> &OpDescBind::Output(
return it->second;
}
std::vector<std::string> OpDescBind::OutputNames() const {
std::vector<std::string> retv;
retv.reserve(this->outputs_.size());
for (auto &ipt : this->outputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetOutput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
......
......@@ -35,15 +35,11 @@ class OpDescBind {
const std::vector<std::string> &Input(const std::string &name) const;
std::vector<std::string> InputNames() const;
void SetInput(const std::string &param_name,
const std::vector<std::string> &args);
const std::vector<std::string> &Output(const std::string &name) const;
std::vector<std::string> OutputNames() const;
void SetOutput(const std::string &param_name,
const std::vector<std::string> &args);
......@@ -61,9 +57,6 @@ class OpDescBind {
void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const;
......@@ -71,7 +64,23 @@ class OpDescBind {
// Only be used in C++
const AttributeMap &GetAttrMap() const;
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
std::vector<typename MapType::key_type> ret_val;
ret_val.reserve(map.size());
std::transform(
map.begin(), map.end(), std::back_inserter(ret_val),
[](const typename MapType::value_type &pair) { return pair.first; });
return ret_val;
}
void Sync();
OpDesc op_desc_;
......
......@@ -25,16 +25,10 @@
namespace paddle {
namespace framework {
class GradOpDescMakerBase {
public:
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()(const OpDescBind&) const = 0;
};
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
GradOpDescMakerBase* grad_op_maker_{nullptr};
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
......
......@@ -20,6 +20,7 @@
namespace paddle {
namespace framework {
class OperatorBase;
class OpDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto
......@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册