Commit e2fd2bd0 authored by Yu Yang

Follow comments and merge develop

Parent 80baf861
@@ -31,88 +31,74 @@ static bool AllInSet(const std::vector<std::string>& names,
   return true;
 }
 
-static std::vector<size_t> InSetIdx(
-    const std::vector<std::string>& names, const std::string& suffix,
-    const std::unordered_set<std::string>& set) {
-  std::vector<size_t> ret_val;
-  ret_val.reserve(names.size());
-  for (size_t i = 0; i < names.size(); ++i) {
-    if (set.find(names[i] + suffix) != set.end()) {
-      ret_val.push_back(i);
-    }
-  }
-  return ret_val;
-}
-
-static std::shared_ptr<OperatorBase> EmptyOp() {
+static std::shared_ptr<OperatorBase> NOP() {
   auto net_op = std::make_shared<NetOp>();
-  net_op->type_ = "@EMPTY_OP@";
+  net_op->type_ = "@NOP@";
   net_op->CompleteAddOp();
   return net_op;
 }
 
-/**
- * @brief Backward an operator, implementation
- * @param forwardOp the forward operator
- * @param no_grad_names variable names not calculate for gradient. Like X@GRAD
- * is not needed.
- * @param uniq_id a unique index used inside BackwardImpl, it will be shared
- * through recursive invoke.
- * @return The backward operator. For simple situation, it is a simple operator.
- * For complex situation, it is a NetOp.
- *
- * See Backward.h for details
- */
-static std::shared_ptr<OperatorBase> BackwardImpl(
+// Get backward operator from a forward operator, recursively implementation.
+//
+// no_grad_names the gradient variable names without gradient calculating.
+//
+// uniq_id is a unique index used inside recursively calling BackwardRecursive.
+// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through
+// recursive calling.
+//
+// returns The backward operator. For simple situation, it is a simple
+// operator. For complex situation, it is a NetOp.
+//
+// See Backward.h for details
+static std::shared_ptr<OperatorBase> BackwardRecursive(
+    const OperatorBase& forwardOp,
+    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
+std::shared_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
-  /**
-   * If all input gradients of forwarding operator do not need to calculate,
-   * just return an EmptyOp. Not return null ptr because EmptyOp does not take
-   * too much time for calculation, but it is useful for simplifying logic.
-   */
+  // If all input gradients of forwarding operator do not need to calculate,
+  // just return an NOP. Not return null ptr because NOP does not take
+  // too much time for calculation, but it is useful for simplifying logic.
   if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
-    return EmptyOp();
+    return NOP();
   }
 
-  /**
-   * All output gradients of forwarding operator do not need to calculate. Then
-   * all input gradients cannot be computed at all, and we put them into
-   * `no_grad_names` set. Return an EmptyOp.
-   */
+  // All output gradients of forwarding operator do not need to calculate. Then
+  // all input gradients cannot be computed at all, and we put them into
+  // `no_grad_names` set. Return an NOP.
   if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
                no_grad_names)) {
     for (auto& name : forwardOp.inputs_) {
-      /// Mark all input is not need
+      // Mark all input is not need
       no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
     }
-    return EmptyOp();
+    return NOP();
   }
 
-  //! Returned gradient network
+  // Returned gradient network
   auto net = std::make_shared<NetOp>();
 
   if (forwardOp.IsNetOp()) {
-    /// Because forwardOp is a net op, it can static_cast.
+    // Because forwardOp is a net op, it can static_cast.
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
 
-    //! Map from output gradient variable name to operator's indices in backward
-    //! net. That operator generates that variable.
+    // Map from output gradient variable name to operator's indices in backward
+    // net. That operator generates that variable.
     std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
 
     size_t local_op_id = 0;
-    /// reversely travel forwardNet
+    // reversely travel forwardNet
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
       auto fwd = *it;
-      auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
+      auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
       net->AddOp(bwd);
       for (auto& out : bwd->outputs_) {
         dup_output_ops[out].emplace_back(local_op_id);
       }
     }
-    /// Get unique ID for this method.
+    // Get unique ID for this method.
     auto uid = uniq_id++;
     // TODO(dzh): more comment
     using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
@@ -145,13 +131,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     }
   } else {
-    //! TODO(fjy)
     std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
     for (std::string& grad_input : grad_op->inputs_) {
       if (no_grad_names.count(grad_input)) {
         std::string prefix = grad_input.substr(
             0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
         grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
+        // If part of input gradient of that operator is not calculated, fill
+        // zero variables to that input gradient.
         net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
                                         {grad_input}, {}));
       }
@@ -173,8 +161,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
   return net;
 }
 
-//! See header for comments
-extern std::shared_ptr<OperatorBase> Backward(
+// See header for comments
+std::shared_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_names;
@@ -184,7 +172,7 @@ extern std::shared_ptr<OperatorBase> Backward(
     no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
   }
   size_t uid = 0;
-  return BackwardImpl(forwardOp, no_grad_names, uid);
+  return BackwardRecursive(forwardOp, no_grad_names, uid);
 }
 }  // namespace framework
 }  // namespace paddle
@@ -18,12 +18,8 @@
 namespace paddle {
 namespace framework {
 
-/**
- * @brief
- * @param forwardOp
- * @param no_grad_vars ignored input name of forward
- * @return
- */
+// Create the backward operator from a forward operator.
+// TODO(yuyang18): Add more API reference comment.
 extern std::shared_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars);
...
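As an aside, here is a minimal usage sketch of the Backward API declared above. It is not part of this commit: the op type "rowwise_add", the variable names, and the header include paths are assumptions for illustration only; CreateOp's signature and the no_grad_vars convention follow the code shown in the diff.

#include <unordered_set>
#include "paddle/framework/backward.h"     // assumed header path
#include "paddle/framework/op_registry.h"  // assumed header path

namespace f = paddle::framework;

void BackwardSketch() {
  // Build a forward op; "rowwise_add", "X", "b", "Out" are illustrative names.
  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
  // Variables listed here get no gradient; Backward appends GRAD_VAR_SUFFIX()
  // to each name before adding it to the internal no_grad_names set.
  std::unordered_set<std::string> no_grad_vars{"b"};
  // Returns a single gradient op, or a NetOp when the forward op is a net.
  std::shared_ptr<f::OperatorBase> bwd = f::Backward(*fwd, no_grad_vars);
}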
@@ -169,7 +169,6 @@ TEST(Backward, simple_op_grad) {
   ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
             gop->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
-  // LOG(INFO) << gop->Output("X" + "@GRAD");
 }
 
 TEST(Backward, simple_op_not_need_grad) {
...
@@ -21,15 +21,17 @@ namespace operators {
 class FillZerosLikeOp : public framework::OperatorWithKernel {
 protected:
-  void InferShape(
-      const std::vector<const framework::Tensor *> &inputs,
-      const std::vector<framework::Tensor *> &outputs) const override {
-    PADDLE_ENFORCE(inputs.size() == 1,
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE(ctx.InputSize() == 1UL,
                    "Input size of FillZerosLikeOp must be one.");
-    PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one.");
-    PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
-                   "Outputs of FillZerosLikeOp must all be set.");
-    outputs[0]->Resize(inputs[0]->dims());
+    PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
+                   "Output size of AddOp must be one.");
+    PADDLE_ENFORCE(ctx.InputVar(0) != nullptr,
+                   "Input of FillZerosLikeOp must be set.");
+    PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
+                   "Output of FillZerosLikeOp must be set.");
+    ctx.Output<framework::Tensor>(0)->Resize(
+        ctx.Input<framework::Tensor>(0)->dims());
   }
 };
...
@@ -23,8 +23,8 @@ namespace operators {
 template <typename Place, typename T>
 class FillZerosLikeKernel : public framework::OpKernel {
 public:
-  void Compute(const framework::KernelContext& context) const override {
-    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* output = context.Output<framework::Tensor>(0);
     output->mutable_data<T>(context.GetPlace());
     framework::EigenVector<T>::Flatten(*output).setZero();
   }
...
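The kernel body above amounts to viewing the output buffer as a flat Eigen expression and calling setZero() on it. Below is a standalone sketch of the same idea using a plain Eigen::Map rather than the framework's EigenVector<T>::Flatten helper; the buffer size and values are illustrative, not taken from the commit.

#include <Eigen/Core>
#include <iostream>

int main() {
  float data[6] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  // View the raw buffer as a flat vector, roughly what Flatten() does for a
  // Tensor, then overwrite every element with zero in one pass.
  Eigen::Map<Eigen::VectorXf> flat(data, 6);
  flat.setZero();
  std::cout << flat.transpose() << "\n";  // prints: 0 0 0 0 0 0
  return 0;
}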
@@ -312,13 +312,14 @@ public:
       : OpProtoAndCheckerMaker(proto, op_checker) {
     const auto& name = RecurrentOp::kArgName;
     // inputs and outputs stored in proto
-    AddInputs(name.inlinks,
-              "the input that need to be segmented for each step.");
-    AddInputs(name.boot_memories, "variables to initialize memories.");
+    AddInput(name.inlinks, "the input that need to be segmented for each step.")
+        .SetMultiple();
+    AddInput(name.boot_memories, "variables to initialize memories.")
+        .SetMultiple();
     AddInput(name.step_net, "network shared by all steps.");
-    AddOutputs(name.outlinks,
-               "the output that need to concated for all steps.");
+    AddOutput(name.outlinks, "the output that need to concated for all steps.")
+        .SetMultiple();
     AddOutput(name.step_scopes, "step scopes");
 
     // Attributes stored in AttributeMap
...
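The proto maker now marks repeatable slots by chaining .SetMultiple() onto AddInput/AddOutput instead of calling the separate AddInputs/AddOutputs helpers. A minimal standalone sketch of that chained-builder shape follows; VarBuilder and ProtoMakerSketch are made-up stand-ins, not the real OpProtoAndCheckerMaker.

#include <deque>
#include <iostream>
#include <string>

// Made-up stand-in for the per-variable builder returned by AddInput/AddOutput.
class VarBuilder {
 public:
  explicit VarBuilder(std::string name) : name_(std::move(name)) {}
  VarBuilder& SetMultiple() {  // mark the slot as accepting a list of variables
    multiple_ = true;
    return *this;              // returning *this is what enables the chaining
  }
  void Print() const {
    std::cout << name_ << (multiple_ ? " [multiple]" : "") << "\n";
  }

 private:
  std::string name_;
  bool multiple_ = false;
};

// Made-up stand-in for the proto maker that owns the builders.
class ProtoMakerSketch {
 public:
  VarBuilder& AddInput(const std::string& name, const std::string& /*comment*/) {
    inputs_.emplace_back(name);
    return inputs_.back();  // deque keeps references valid across later adds
  }
  void Dump() const {
    for (const auto& in : inputs_) in.Print();
  }

 private:
  std::deque<VarBuilder> inputs_;
};

int main() {
  ProtoMakerSketch maker;
  maker.AddInput("inlinks", "segmented input").SetMultiple();
  maker.AddInput("step_net", "network shared by all steps");
  maker.Dump();  // prints "inlinks [multiple]" then "step_net"
  return 0;
}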