Commit 6f12fd28 authored by fengjiayi, committed by GitHub

Merge pull request #3192 from Canpio/dev_simplify_GradOpBuilder

Simplify building process of gradient operators
paddle/framework/backward.cc
@@ -42,9 +42,9 @@ static std::shared_ptr<OperatorBase> NOP() {
 //
 // no_grad_names the gradient variable names without gradient calculating.
 //
-// uniq_id is a unique index used inside recursively calling BackwardRecursive.
-// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
-// recursive calling.
+// uniq_id is a unique index used inside recursively calling
+// BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and
+// pass `uniq_id` through recursive calling.
 //
 // returns The backward operator. For simple situation, it is a simple
 // operator. For complex situation, it is a NetOp.
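Editor's note: the `uniq_id` convention described in this comment is just a counter threaded through the recursion by reference. A minimal standalone sketch of the pattern (names are illustrative, not taken from backward.cc):

```cpp
#include <cstddef>
#include <iostream>

// Each call claims a unique index with `uid = uniq_id++` and forwards the
// counter by reference, exactly as the comment above prescribes.
static void Recurse(int depth, size_t& uniq_id) {
  size_t uid = uniq_id++;  // this call's unique index
  std::cout << "depth " << depth << " -> uid " << uid << "\n";
  if (depth < 2) {
    Recurse(depth + 1, uniq_id);  // pass uniq_id through recursive calling
    Recurse(depth + 1, uniq_id);
  }
}

int main() {
  size_t uniq_id = 0;
  Recurse(0, uniq_id);  // prints seven distinct ids, 0 through 6
  return 0;
}
```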
@@ -64,8 +64,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     return NOP();
   }

-  // All output gradients of forwarding operator do not need to calculate. Then
-  // all input gradients cannot be computed at all, and we put them into
+  // All output gradients of forwarding operator do not need to calculate.
+  // Then all input gradients cannot be computed at all, and we put them into
   // `no_grad_names` set. Return an NOP.
   if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
@@ -83,8 +83,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     // Because forwardOp is a net op, it can static_cast.
     auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);

-    // Map from output gradient variable name to operator's indices in backward
-    // net. That operator generates that variable.
+    // Map from output gradient variable name to operator's indices in
+    // backward net. That operator generates that variable.
     std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;

     size_t local_op_id = 0;
...
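A note on `dup_output_ops`: it maps each gradient variable name to the indices of the backward-net operators that produce it, so any entry with more than one index marks a duplicate that must be reconciled after the loop (the combining step is outside this hunk). A standalone sketch of the bookkeeping, with hypothetical variable names:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical outputs of three backward ops, in creation order.
  std::vector<std::vector<std::string>> op_outputs = {
      {"W@GRAD"}, {"b@GRAD"}, {"W@GRAD"}};  // W@GRAD is written twice

  // Same shape as the map in BackwardRecursive: name -> producing op indices.
  std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
  size_t local_op_id = 0;
  for (const auto& outs : op_outputs) {
    for (const auto& name : outs) dup_output_ops[name].push_back(local_op_id);
    ++local_op_id;
  }

  for (const auto& kv : dup_output_ops) {
    if (kv.second.size() > 1) {
      std::cout << kv.first << " is produced by " << kv.second.size()
                << " ops and needs to be merged downstream\n";
    }
  }
  return 0;
}
```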
paddle/framework/backward_test.cc
@@ -162,8 +162,8 @@ TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   ASSERT_NE(fwd, nullptr);
   auto gop = f::OpRegistry::CreateGradOp(*fwd);
-  ASSERT_EQ(1UL, gop->inputs_.size());
-  ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
+  ASSERT_EQ(4UL, gop->inputs_.size());
+  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), gop->inputs_[0]);
   ASSERT_EQ("rowwise_add_grad", gop->type_);
   ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
   ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
@@ -360,7 +360,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
             3UL /* external input number */
                 + 1UL /* external output number*/
                 + 1UL /* number of gradient of external output*/
-                - 1UL /*ignoreGradient varable number*/
                 + 2U /* internal variable number*/);
   EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
                                          + 2UL /* input number of rowwise_add */
...
paddle/framework/grad_op_builder.cc
@@ -8,107 +8,97 @@ You may obtain a copy of the License at
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+WITHOpArgType::OUT WARRANTIES OR CONDITIONS OF ANY KOpArgType::IND, either
+express or implied. See the License for the specific language governing
+permissions and limitations under the License. */

 #include "paddle/framework/grad_op_builder.h"
-#include "paddle/framework/op_proto.pb.h"
 #include "paddle/framework/op_registry.h"

 namespace paddle {
 namespace framework {

-OperatorBase* GradOpBuilder::Build() {
-  BuildOpInOutArgList();
-  std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_);
-  OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
-  grad_op->type_ = grad_op_type;
-  CompleteGradOp(grad_op);
-  return grad_op;
-}
+class OpRegistry;
+
+using VarIndexMap = std::unordered_map<std::string, int>;
+
+enum class OpArgType { IN, OUT };

-OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
-                                    const VarIndexMap& var_map,
-                                    const std::vector<int>& format,
-                                    InOutType type) {
-  int idx = var_map.at(var.name());
-  int begin_idx = format.empty() ? idx : format.at(idx);
-  int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
-  return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx,
-                        end_idx);
-}
+static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
+  std::string key = type == OpArgType::IN ? "input_format" : "output_name";
+  return op->attrs_.count(key)
+             ? &boost::get<std::vector<int>>(op->attrs_.at(key))
+             : nullptr;
+}

-void GradOpBuilder::BuildOpInOutArgList() {
-  const OpProto& op_proto = OpRegistry::protos().at(op_.type_);
-  const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_));
-  const std::vector<int>& in_format =
-      op_.attrs_.count("input_format")
-          ? op_.GetAttr<std::vector<int>>("input_format")
-          : std::vector<int>();
-  const std::vector<int>& out_format =
-      op_.attrs_.count("output_format")
-          ? op_.GetAttr<std::vector<int>>("output_format")
-          : std::vector<int>();
-  for (const auto& var : op_proto.inputs()) {
-    arg_list_.emplace_back(
-        std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, in_format, IN)));
-  }
-  for (const auto& var : op_proto.outputs()) {
-    arg_list_.emplace_back(
-        std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, out_format, OUT)));
-  }
-}
+static const std::vector<int>* GetOpFormat(const OperatorBase* op,
+                                           const OpArgType& type) {
+  std::string key = type == OpArgType::IN ? "input_format" : "output_name";
+  return op->attrs_.count(key)
+             ? &boost::get<std::vector<int>>(op->attrs_.at(key))
+             : nullptr;
+}

-void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
-                                     std::vector<std::string>& in_out,
-                                     std::vector<int>& format,
-                                     VarIndexMap* varmap, int& idx,
-                                     bool is_grad) const {
-  std::string var_name = arg->proto_name_;
-  if (is_grad) {
-    var_name += OperatorBase::GRAD_VAR_SUFFIX();
-  }
-  (*varmap)[var_name] = idx++;
-  size_t pre_sz = in_out.size();
-  auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin();
-  std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
-            std::back_inserter(in_out));
-  if (is_grad) {
-    for (size_t i = pre_sz; i < in_out.size(); ++i) {
-      in_out[i] += OperatorBase::GRAD_VAR_SUFFIX();
-    }
-  }
-  format.push_back(in_out.size());
-}
+static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
+                       const OpArgType& src_type, const OpArgType& dst_type,
+                       int& idx, bool is_grad) {
+  const std::vector<std::string>& src_inout =
+      src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
+  const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
+
+  std::vector<std::string>& dst_inout =
+      dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
+  std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
+  const OpProto& proto = OpRegistry::protos().at(src_op->type_);
+  const auto& src_arg_list =
+      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
+
+  for (const auto& arg : src_arg_list) {
+    std::string src_name = arg.name();
+    std::string dst_name =
+        is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name;
+    (*dst_op->in_out_idxs_)[dst_name] = idx++;
+    int src_arg_idx = src_op->in_out_idxs_->at(src_name);
+    int src_begin =
+        src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
+    int src_end = src_format == nullptr ? src_arg_idx + 1
+                                        : src_format->at(src_arg_idx + 1);
+    for (int i = src_begin; i < src_end; ++i) {
+      std::string s = is_grad ? src_inout[i] + OperatorBase::GRAD_VAR_SUFFIX()
+                              : arg.ignore_gradient()
+                                    ? OperatorBase::EMPTY_VAR_NAME()
+                                    : src_inout[i];
+      dst_inout.emplace_back(s);
+    }
+    if (dst_format != nullptr) {
+      dst_format->push_back(dst_inout.size());
+    }
+  }
+}
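For readers new to the `input_format`/`output_format` attribute used by `GetOpFormat` and `TransOpArg`: it is a CSR-style offset vector, where declared argument k owns the slice `[format[k], format[k+1])` of the flat name list, and a missing format means one name per argument. A minimal sketch with made-up names:

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Flat name list plus segment offsets: arg0 -> [0,2), arg1 -> [2,3), arg2 -> [3,4).
  std::vector<std::string> names = {"W1", "W2", "b", "x"};
  std::vector<int> format = {0, 2, 3, 4};

  for (size_t k = 0; k + 1 < format.size(); ++k) {
    std::cout << "arg " << k << ":";
    for (int i = format[k]; i < format[k + 1]; ++i) {
      std::cout << " " << names[i];
    }
    std::cout << "\n";
  }
  return 0;
}
```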
-void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
-  grad_op->attrs_ = op_.attrs_;
+OperatorBase* BuildGradOp(const OperatorBase* op) {
+  std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
+  OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
+  grad_op->type_ = grad_op_type;
+  grad_op->attrs_ = op->attrs_;
   grad_op->attrs_.erase("input_format");
   grad_op->attrs_.erase("output_format");
-  VarIndexMap* grad_varmap = new VarIndexMap();
+  if (GetOpFormat(op, OpArgType::OUT) != nullptr) {
+    grad_op->attrs_["output_format"] = std::vector<int>({0});
+  }
+  if (GetOpFormat(op, OpArgType::IN) != nullptr ||
+      GetOpFormat(op, OpArgType::OUT) != nullptr) {
+    grad_op->attrs_["input_format"] = std::vector<int>({0});
+  }
+  grad_op->in_out_idxs_.reset(new VarIndexMap());
   int in_idx = 0;
   int out_idx = 0;
-  std::vector<int> in_format({0});
-  std::vector<int> out_format({0});
-  for (const auto& arg : arg_list_) {
-    // op_'s inputs_ and outputs_
-    if (arg->needed_in_grad_) {
-      AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
-                       in_idx, false);
-    }
-    if (arg->type_ == IN) {
-      // gradients of op_'s inputs_
-      AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
-                       out_idx, true);
-    } else {
-      // gradients of op_'s outputs_
-      AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
-                       in_idx, true);
-    }
-  }
-  grad_op->attrs_["input_format"] = in_format;
-  grad_op->attrs_["output_format"] = out_format;
-  grad_op->in_out_idxs_.reset(grad_varmap);
+  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false);   // I
+  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false);  // G
+  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true);   // OG
+  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true);  // IG
+  return grad_op;
 }

 } // namespace framework
...
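To make the four `TransOpArg` calls concrete, here is a standalone sketch of the argument layout `BuildGradOp` produces, using a made-up two-input op and assuming `GRAD_VAR_SUFFIX()` expands to "@GRAD" (placeholder handling for `ignore_gradient` is omitted):

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string kGrad = "@GRAD";  // assumed value of GRAD_VAR_SUFFIX()
  std::vector<std::string> fwd_in = {"A", "B"};  // hypothetical forward inputs
  std::vector<std::string> fwd_out = {"Out"};    // hypothetical forward output

  std::vector<std::string> grad_in, grad_out;
  for (const auto& v : fwd_in) grad_in.push_back(v);           // I
  for (const auto& v : fwd_out) grad_in.push_back(v);          // G
  for (const auto& v : fwd_out) grad_in.push_back(v + kGrad);  // OG
  for (const auto& v : fwd_in) grad_out.push_back(v + kGrad);  // IG

  // Prints: inputs = A B Out Out@GRAD / outputs = A@GRAD B@GRAD
  std::cout << "inputs =";
  for (const auto& v : grad_in) std::cout << " " << v;
  std::cout << " / outputs =";
  for (const auto& v : grad_out) std::cout << " " << v;
  std::cout << "\n";
  return 0;
}
```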
paddle/framework/grad_op_builder.h
 #pragma once

-#include "paddle/framework/op_proto.pb.h"
 #include "paddle/framework/operator.h"

 namespace paddle {
 namespace framework {

-class OpRegistry;
-
-enum InOutType { IN, OUT };
-
-struct OpInOutArg {
-  OpInOutArg(const std::string& proto_name, const InOutType& type,
-             bool needed_in_grad, size_t begin_idx, size_t end_idx)
-      : proto_name_(proto_name),
-        type_(type),
-        needed_in_grad_(needed_in_grad),
-        begin_idx_(begin_idx),
-        end_idx_(end_idx) {}
-
-  std::string proto_name_;
-  InOutType type_;
-  bool needed_in_grad_;
-  size_t begin_idx_;
-  size_t end_idx_;
-};
-
-class GradOpBuilder {
-  using VarIndexMap = std::unordered_map<std::string, int>;
-
- public:
-  GradOpBuilder(const OperatorBase& op) : op_(op) {}
-  OperatorBase* Build();
-
- private:
-  OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
-                       const std::vector<int>& format, InOutType type);
-  void BuildOpInOutArgList();
-  void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
-                        std::vector<int>& format, VarIndexMap* varmap, int& idx,
-                        bool is_grad) const;
-  void CompleteGradOp(OperatorBase* grad_op) const;
-  const OperatorBase& op_;
-  std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
-};
+OperatorBase* BuildGradOp(const OperatorBase* op);

 } // namespace framework
 } // namespace paddle
paddle/framework/op_registry.h
@@ -306,8 +306,7 @@ class OpRegistry {
   static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
     PADDLE_ENFORCE(!op.IsNetOp(),
                    "Use framework::Backward to get backward ops");
-    GradOpBuilder builder(op);
-    std::shared_ptr<OperatorBase> grad_op(builder.Build());
+    std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
     grad_op->Init();
     return grad_op;
   }
...
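For reference, the caller-side flow after this change, adapted from the updated `simple_op_grad` test above (a fragment, not a self-contained program; `f` aliases `paddle::framework` as in backward_test.cc):

```cpp
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto gop = f::OpRegistry::CreateGradOp(*fwd);  // BuildGradOp + Init under the hood
// gop->inputs_ holds the I, G, and OG sections (4 entries here), with
// EMPTY_VAR_NAME() placeholders where ignore_gradient is set.
```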