提交 bcfe391e 编写于 作者: Z zchen0211

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

......@@ -33,22 +33,45 @@ The mapping relationship between an operator and its gradient operators is a fun
```cpp
// (OpDesc) --> vector<OpDesc>
using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>;
std::function<std::vector<OpDescBind>(const OpDescBind&)>;
```
The function take a `OpDesc` of the forward operator and return one or many gradient operator descriptions.
The function takes an `OpDescBind` of the forward operator and returns one or more gradient operator descriptions. `OpDescBind` is a C++ wrapper around the protobuf message `OpDesc` that makes manipulating `OpDesc` fast.
The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be
```cpp
struct OpInfo {
GradOpDescMaker grad_op_maker_;
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> grad_op_maker_;
...
};
```
The `grad_op_maker_` is `nullptr` if the operator does not have any associated gradient operators.
We propose a base class called `GradOpDescMakerBase` to let operator developers generate `Gradient Operators` easily. The public interface of that class is
```cpp
class GradOpDescMakerBase {
public:
GradOpDescMakerBase(const OpDescBind& );
virtual std::vector<std::unique_ptr<OpDescBind>> operator()()const = 0;
};
```
We can convert `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` by
```cpp
using GradOpMaker = ...;
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> func;
func = [] (const OpDescBind& fwd_op) {
GradOpMaker maker(fwd_op);
return maker();
};
```
We can write many helper functions since `GradOpDescMakerBase` is a class now. The basic helper functions get the variables of `Input`, `Output`, `InputGradient` and `OutputGradient` in the forward operator.
We should change the register macros at the same time. In the current solution, there is no difference between forward operators and backward operators, so `REGISTER_OP` just registers one operator. If `REGISTER_OPERATOR` is given both an `OpProtoAndCheckerMaker` and a `GradOpDescMaker`, we just list them in the same macro. This can be done with a macro that uses `__VA_ARGS__`.
The user interface should be
......
......@@ -147,7 +147,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
++output_idx) {
auto insert_add_x = dup_outputs[output_idx];
auto insert_add_y = dup_outputs[output_idx];
auto insert_add_y = dup_outputs[output_idx + 1];
auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
// first add op inserted
if (output_idx == dup_outputs.size() - 2) {
......@@ -158,9 +158,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
}
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
{{"Out", {insert_add_out}}}, {})});
OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}},
{{"Out", {insert_add_out}}}, {})});
}
}
......@@ -200,7 +199,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
// NOTE clean up cycle call somewhere (RNN's stepnet contains itself),
// or
// this will result in infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/framework/grad_op_desc_maker.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
......@@ -96,7 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = new T();
info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
T maker(fwd_op);
return maker();
};
}
};
} // namespace details
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
// Base class for functors that build the gradient operator descriptions of a
// single forward operator. Subclasses implement operator()() and may use the
// protected helpers below to query the forward operator's description.
class GradOpDescMakerBase {
 public:
  // Binds this maker to `fwd_op`. Only a reference is stored, so `fwd_op`
  // must stay alive for as long as the maker is used.
  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}

  virtual ~GradOpDescMakerBase() = default;

  // Builds the gradient operator descriptions for the bound forward operator.
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;

 protected:
  // Maps every name in `var_names` through GradVarName, preserving order.
  static std::vector<std::string> ToGradNames(
      const std::vector<std::string>& var_names) {
    std::vector<std::string> grad_names;
    grad_names.reserve(var_names.size());
    for (const auto& var_name : var_names) {
      grad_names.push_back(GradVarName(var_name));
    }
    return grad_names;
  }

  // Gradient variable names for the forward input parameter `name`.
  std::vector<std::string> InputGrad(const std::string& name) const {
    const auto& forward_vars = fwd_op_.Input(name);
    return ToGradNames(forward_vars);
  }

  // Gradient variable names for the forward output parameter `name`.
  std::vector<std::string> OutputGrad(const std::string& name) const {
    const auto& forward_vars = fwd_op_.Output(name);
    return ToGradNames(forward_vars);
  }

  // Parameter names of the forward operator's inputs.
  std::vector<std::string> InputNames() const { return fwd_op_.InputNames(); }

  // Parameter names of the forward operator's outputs.
  std::vector<std::string> OutputNames() const {
    return fwd_op_.OutputNames();
  }

  // Variable names bound to the forward input parameter `name` (a copy).
  std::vector<std::string> Input(const std::string& name) const {
    return fwd_op_.Input(name);
  }

  // Variable names bound to the forward output parameter `name` (a copy).
  std::vector<std::string> Output(const std::string& name) const {
    return fwd_op_.Output(name);
  }

  // The whole attribute map of the forward operator.
  const std::unordered_map<std::string, Attribute>& Attrs() const {
    return fwd_op_.GetAttrMap();
  }

  // Looks up a single forward-operator attribute; PADDLE_ENFORCE fails when
  // no attribute called `name` exists.
  const Attribute& GetAttr(const std::string& name) const {
    const auto& attrs = fwd_op_.GetAttrMap();
    const auto found = attrs.find(name);
    PADDLE_ENFORCE(found != attrs.end(), "Cannot find attribute %s", name);
    return found->second;
  }

  // Type string of the forward operator.
  std::string ForwardOpType() const { return fwd_op_.Type(); }

 private:
  // Non-owning reference to the forward operator description.
  const OpDescBind& fwd_op_;
};
class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDescBind>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv;
retv.emplace_back(this->Apply());
return retv;
}
protected:
virtual std::unique_ptr<OpDescBind> Apply() const = 0;
};
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
public:
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual std::unique_ptr<OpDescBind> Apply() const {
auto* grad = new OpDescBind();
grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
}
for (auto& output_param : this->OutputNames()) {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(GradVarName(output_param), this->OutputGrad(output_param));
}
grad->SetAttrMap(this->Attrs());
return std::unique_ptr<OpDescBind>(grad);
}
virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
};
} // namespace framework
} // namespace paddle
......@@ -31,15 +31,6 @@ const std::vector<std::string> &OpDescBind::Input(
return it->second;
}
std::vector<std::string> OpDescBind::InputNames() const {
std::vector<std::string> retv;
retv.reserve(this->inputs_.size());
for (auto &ipt : this->inputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetInput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
......@@ -54,15 +45,6 @@ const std::vector<std::string> &OpDescBind::Output(
return it->second;
}
std::vector<std::string> OpDescBind::OutputNames() const {
std::vector<std::string> retv;
retv.reserve(this->outputs_.size());
for (auto &ipt : this->outputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetOutput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
......
......@@ -35,15 +35,11 @@ class OpDescBind {
const std::vector<std::string> &Input(const std::string &name) const;
std::vector<std::string> InputNames() const;
void SetInput(const std::string &param_name,
const std::vector<std::string> &args);
const std::vector<std::string> &Output(const std::string &name) const;
std::vector<std::string> OutputNames() const;
void SetOutput(const std::string &param_name,
const std::vector<std::string> &args);
......@@ -61,9 +57,6 @@ class OpDescBind {
void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const;
......@@ -71,7 +64,23 @@ class OpDescBind {
// Only be used in C++
const AttributeMap &GetAttrMap() const;
// Only be used in C++
void SetAttrMap(const AttributeMap &attr_map);
// Returns the keys of inputs_ (the input parameter names), as collected by
// MapKeys.
std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
// Returns the keys of outputs_ (the output parameter names), as collected by
// MapKeys.
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
private:
// Returns a vector with a copy of every key in `map`, in the map's iteration
// order. Works for any associative container exposing key_type, size() and
// iteration over pair-like elements with a `first` member.
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
  std::vector<typename MapType::key_type> keys;
  keys.reserve(map.size());
  for (const auto &entry : map) {
    keys.push_back(entry.first);
  }
  return keys;
}
void Sync();
OpDesc op_desc_;
......
......@@ -25,16 +25,10 @@
namespace paddle {
namespace framework {
class GradOpDescMakerBase {
public:
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()(const OpDescBind&) const = 0;
};
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
GradOpDescMakerBase* grad_op_maker_{nullptr};
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
......
......@@ -20,6 +20,7 @@
namespace paddle {
namespace framework {
class OperatorBase;
class OpDescBind;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto
......@@ -34,5 +35,8 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
} // namespace framework
} // namespace paddle
......@@ -23,19 +23,22 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("param"),
"Input(param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("grad"),
"Input(grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
"Input(learning_rate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("param_out"),
"Output(param_out) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of SGDOp should not be null.");
auto param_dim = ctx->GetInputDim("param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"),
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of SGD Op's dimension must be same.");
ctx->SetOutputDim("param_out", param_dim);
ctx->SetOutputDim("ParamOut", param_dim);
}
};
......@@ -43,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter");
AddInput("learning_rate", "learning rate of sgd");
AddInput("grad", "input gradient");
AddOutput("param_out", "output parameter");
AddInput("Param", "Input parameter");
AddInput("LearningRate", "Learning rate of SGD");
AddInput("Grad", "Input gradient");
AddOutput("ParamOut", "output parameter");
AddComment(R"DOC(
Simplest sgd algorithm.
......
......@@ -28,10 +28,10 @@ template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>("param_out");
float lr = *ctx.Input<float>("learning_rate");
auto param = ctx.Input<Tensor>("Param");
auto grad = ctx.Input<Tensor>("Grad");
auto param_out = ctx.Output<Tensor>("ParamOut");
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
param_out->mutable_data<T>(ctx.GetPlace());
......
......@@ -8,10 +8,10 @@ class TestSGDOp(OpTest):
self.op_type = "sgd"
w = np.random.random((102, 105)).astype("float32")
g = np.random.random((102, 105)).astype("float32")
lr = 0.1
lr = np.array([0.1]).astype("float32")
self.inputs = {'param': w, 'grad': g, 'learning_rate': lr}
self.outputs = {'param_out': w - lr * g}
self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
self.outputs = {'ParamOut': w - lr * g}
def test_check_output(self):
self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册