Commit 0ac79a38 authored by D dongzhihong

Merge remote-tracking branch 'reyoung/feature/backward' into feature/backward

......@@ -33,3 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
cc_library(backward SRCS backward.cc DEPS net)
cc_test(backward_test SRCS backward_test.cc DEPS backward)
......@@ -12,8 +12,9 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/framework/backward.h>
#include <paddle/framework/net.h>
#include "paddle/framework/backward.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
......@@ -105,6 +106,24 @@ static void DeDuplicate(NetOp* net, std::unordered_se)
//! TODO(dzh)
} else {
//! TODO(fjy)
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
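// For every gradient input whose variable is in no_grad_names, rename
// "<var>@GRAD" to "<var>@ZERO" and schedule a fill_zeros_like op to produce it.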
for (std::string& grad_input : grad_op->inputs_) {
if (no_grad_names.count(grad_input)) {
std::string prefix = grad_input.substr(
0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
std::vector<std::string> fill_zeros_in = {prefix};
std::vector<std::string> fill_zeros_out = {grad_input};
net->AddOp(OpRegistry::CreateOp("fill_zeros_like", fill_zeros_in,
fill_zeros_out, AttributeMap()));
}
}
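// Gradient outputs that nobody needs are renamed to EMPTY_VAR_NAME() so the
// grad op skips writing them.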
for (std::string& grad_output : grad_op->outputs_) {
if (no_grad_names.count(grad_output)) {
grad_output = OperatorBase::EMPTY_VAR_NAME();
}
}
net->AddOp(grad_op);
}
net->CompleteAddOp();
......
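The renaming rule in the hunk above can be shown in isolation. A minimal standalone sketch (the helper name is hypothetical; the real code operates on OperatorBase and NetOp): for each gradient input marked no-grad, the "@GRAD" suffix is swapped for "@ZERO", and a fill_zeros_like op is added to materialize that zero tensor.

// Standalone sketch of the no-grad input rewrite (assumed suffixes "@GRAD"/"@ZERO"):
#include <iostream>
#include <string>
#include <unordered_set>

std::string RewriteNoGradInput(const std::string& grad_input,
                               const std::unordered_set<std::string>& no_grad_names) {
  if (no_grad_names.count(grad_input) == 0) return grad_input;
  const std::string kGrad = "@GRAD";
  const std::string kZero = "@ZERO";
  // In backward.cc a fill_zeros_like op is also added at this point so the
  // "@ZERO" variable actually holds an all-zero tensor.
  std::string prefix = grad_input.substr(0, grad_input.size() - kGrad.size());
  return prefix + kZero;
}

int main() {
  std::unordered_set<std::string> no_grad_names = {"X@GRAD"};
  std::cout << RewriteNoGradInput("X@GRAD", no_grad_names) << "\n";  // X@ZERO
  std::cout << RewriteNoGradInput("b@GRAD", no_grad_names) << "\n";  // b@GRAD (unchanged)
}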
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/backward.h"
#include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
class EmptyOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope> &scope) const override {}
void Run(const std::shared_ptr<Scope> &scope,
const platform::DeviceContext &dev_ctx) const override {}
};
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").IgnoreGradient();
AddInput("b", "Bias of Add").IgnoreGradient();
AddOutput("Out", "Out of Add").IgnoreGradient();
AddComment("Add Op");
}
};
class MulOpMaker : public OpProtoAndCheckerMaker {
public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("A", "A");
AddInput("B", "B");
AddOutput("Out", "Out");
AddComment("Mul");
}
};
class SigmoidOpMaker : public OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "X");
AddOutput("Y", "Y");
AddComment("Sigmoid");
}
};
class FcOp : public NetOp {
public:
void Init() override {
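// Composes fc as Out = sigmoid(rowwise_add(mul(X, W), b)); "before_act" is
// reused as the intermediate buffer, and the bias step is skipped when the
// "b" input is EMPTY_VAR_NAME().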
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
{Output("before_act")}, {}));
auto b_name = Input("b");
if (b_name != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
{Output("before_act")}, {}));
}
AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
{Output("Out")}, {}));
CompleteAddOp(false);
}
};
class FcOpMaker : public OpProtoAndCheckerMaker {
public:
FcOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("before_act", "before act").SetTemporary();
AddOutput("Out", "");
AddComment("");
}
};
class ManyOutputOpMaker : public OpProtoAndCheckerMaker {
public:
ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "x");
AddOutput("y", "y");
AddOutput("z", "z");
AddComment("");
}
};
class FillZeroOpMaker : public OpProtoAndCheckerMaker {
public:
FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "x");
AddOutput("out", "out");
AddComment("");
}
};
} // namespace framework
} // namespace paddle
namespace f = paddle::framework;
using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
TEST(Backward, simple_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
TEST(Backward, not_for_network) {
auto fwd =
f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
{{"temporary_index", std::vector<int>{1}}});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
TEST(Backward, all_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"X", "b"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, all_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"Out"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, part_of_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
auto backward = f::Backward(*fwd, {"Z"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 2);
auto &fill_zero = *net->ops_[0];
ASSERT_EQ("fill_zeros_like", fill_zero.type_);
ASSERT_EQ(1, fill_zero.inputs_.size());
ASSERT_EQ("Z", fill_zero.inputs_[0]);
ASSERT_EQ(1, fill_zero.outputs_.size());
ASSERT_EQ("Z@ZERO", fill_zero.outputs_[0]);
auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.type_);
ASSERT_EQ(1 + 2 + 2, d_many_out.inputs_.size()); // I/O/OG
ASSERT_EQ("Z@ZERO", d_many_out.Input("z@GRAD"));
}
\ No newline at end of file
......@@ -20,7 +20,7 @@ namespace framework {
OperatorBase* GradOpBuilder::Build() {
BuildOpInOutArgList();
std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_);
std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
CompleteGradOp(grad_op);
......@@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
}
void GradOpBuilder::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
const OpProto& op_proto = OpRegistry::protos().at(op_.type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_));
const std::vector<int>& in_format =
op_->attrs_.count("input_format")
? op_->GetAttr<std::vector<int>>("input_format")
op_.attrs_.count("input_format")
? op_.GetAttr<std::vector<int>>("input_format")
: std::vector<int>();
const std::vector<int>& out_format =
op_->attrs_.count("output_format")
? op_->GetAttr<std::vector<int>>("output_format")
op_.attrs_.count("output_format")
? op_.GetAttr<std::vector<int>>("output_format")
: std::vector<int>();
for (const auto& var : op_proto.inputs()) {
arg_list_.emplace_back(
......@@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
}
(*varmap)[var_name] = idx++;
size_t pre_sz = in_out.size();
auto base_it =
arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin();
std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
std::back_inserter(in_out));
if (is_grad) {
......@@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
}
void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->attrs_ = op_->attrs_;
grad_op->attrs_ = op_.attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
VarIndexMap* grad_varmap = new VarIndexMap();
......
......@@ -29,7 +29,7 @@ class GradOpBuilder {
using VarIndexMap = std::unordered_map<std::string, int>;
public:
GradOpBuilder(const OperatorBase* op) : op_(op) {}
GradOpBuilder(const OperatorBase& op) : op_(op) {}
OperatorBase* Build();
private:
......@@ -40,7 +40,7 @@ class GradOpBuilder {
std::vector<int>& format, VarIndexMap* varmap, int& idx,
bool is_grad) const;
void CompleteGradOp(OperatorBase* grad_op) const;
const OperatorBase* op_;
const OperatorBase& op_;
std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
};
......
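The switch from `const OperatorBase*` to `const OperatorBase&` makes the non-null contract explicit and turns every `op_->` into `op_.` in the hunks above. The pattern in miniature:

class GradOpBuilder {
 public:
  explicit GradOpBuilder(const OperatorBase& op) : op_(op) {}
 private:
  const OperatorBase& op_;  // bound at construction; cannot be null or reseated
};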
......@@ -11,7 +11,7 @@ namespace framework {
TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(*add_op);
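// Grad-op inputs follow the I/O/OG layout (cf. the "I/O/OG" comment in
// backward_test.cc): add_two has 2 inputs + 1 output + 1 output-gradient = 4;
// its outputs are the gradients of the 2 forward inputs.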
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x");
......
......@@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker {
}
protected:
void AddInput(const std::string& name, const std::string& comment,
bool multiple = false, bool ignore_gradient = false) {
auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name;
*input->mutable_comment() = comment;
input->set_ignore_gradient(ignore_gradient);
input->set_multiple(multiple);
if (multiple) {
SetHasMultipleInput();
struct VariableBuilder {
VarProto* var_;
std::function<void()> on_multiple_;
std::function<void()> on_temporary_;
VariableBuilder& SetMultiple() {
var_->set_multiple(true);
on_multiple_();
return *this;
}
VariableBuilder& SetTemporary() {
PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary");
var_->set_temporary(true);
on_temporary_();
return *this;
}
void AddInputs(const std::string& name, const std::string& comment,
bool ignore_gradient = false) {
AddInput(name, comment, true, ignore_gradient);
VariableBuilder& IgnoreGradient() {
var_->set_ignore_gradient(true);
return *this;
}
};
void AddOutput(const std::string& name, const std::string& comment,
bool temporary = false, bool multiple = false,
bool ignore_gradient = false) {
VariableBuilder AddInput(const std::string& name,
const std::string& comment) {
auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name;
*input->mutable_comment() = comment;
return VariableBuilder{input, [=] { this->SetHasMultipleInput(); },
nullptr};
}
VariableBuilder AddOutput(const std::string& name,
const std::string& comment) {
auto output = proto_->mutable_outputs()->Add();
*output->mutable_name() = name;
*output->mutable_comment() = comment;
output->set_ignore_gradient(ignore_gradient);
output->set_multiple(multiple);
if (multiple) {
SetHasMultipleOutput();
}
output->set_temporary(temporary);
if (temporary) {
SetHasTemporaryOutput();
}
}
void AddOutputs(const std::string& name, const std::string& comment,
bool temporary = false, bool ignore_gradient = false) {
AddOutput(name, comment, temporary, true, ignore_gradient);
return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); },
[=] { this->SetHasTemporaryOutput(); }};
}
template <typename T>
......@@ -300,11 +303,10 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
PADDLE_ENFORCE(!op->IsNetOp(),
static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops");
GradOpBuilder builder(op.get());
GradOpBuilder builder(op);
std::shared_ptr<OperatorBase> grad_op(builder.Build());
grad_op->Init();
return grad_op;
......
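With the VariableBuilder refactor above, AddInputs/AddOutputs and the trailing bool flags are gone; multiplicity, temporariness, and gradient-ignoring chain at the call site instead. The resulting style, as it appears in the updated makers:

AddInput("xs", "inputs of test op").SetMultiple();   // replaces AddInputs(...)
AddOutput("before_act", "the before activation output of fc operator")
    .SetTemporary();                                 // replaces the bool flag
AddInput("X", "Input X of Add").IgnoreGradient();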
......@@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInputs("input", "input of cosine op");
AddOutput("output", "output of cosine op",
/*temporary*/ true);
AddInput("input", "input of cosine op").SetMultiple();
AddOutput("output", "output of cosine op").SetTemporary();
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
......
......@@ -67,6 +67,9 @@ class OperatorBase {
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
/// Variables with this suffix are supposed to be filled up with zeros.
static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }
virtual ~OperatorBase() {}
template <typename T>
......
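ZERO_VAR_SUFFIX() completes the naming scheme this commit relies on:

// "x" + OperatorBase::GRAD_VAR_SUFFIX() -> "x@GRAD"  (the gradient of x)
// "x" + OperatorBase::ZERO_VAR_SUFFIX() -> "x@ZERO"  (that gradient, filled with zeros)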
......@@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInputs("xs", "inputs of test op");
AddInput("xs", "inputs of test op").SetMultiple();
AddInput("k", "input of test op");
AddOutputs("ys", "outputs of test op");
AddOutput("ys", "outputs of test op").SetMultiple();
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
......
......@@ -50,8 +50,8 @@ public:
AddInput("b", "the bias of fc operator");
AddOutput("Y", "the output of fc operator");
AddOutput(
"before_act", "the before activation output of fc operator", true);
AddOutput("before_act", "the before activation output of fc operator")
.SetTemporary();
AddAttr<std::string>("activation", "The activation key for fc layer")
.SetDefault("sigmoid")
.InEnum({"sigmoid", "softmax"});
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fill_zeros_like_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1,
"Input size of FillZerosLike must be one.");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of FillZerosLike must be one.");
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
"Outputs of FillZerosLike must all be set.");
outputs[0]->Resize(inputs[0]->dims());
}
};
class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillZerosLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Src", "The input of fill-zeros-like op.");
AddOutput("Dst", "The varibale will be filled up with zeros.");
AddComment(R"DOC(
Fill up a vriable with zeros.
The output will have the same size with input.
)DOC")
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(fill_zeros_like,
paddle::operators::FillZerosLikeOp,
paddle::operators::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
\ No newline at end of file
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h"
REGISTER_OP_GPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::Flatten(*output).setZero();
}
};
} // namespace operators
} // namespace paddle
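For reference, the kernel's observable behavior, sketched against a plain container (a hypothetical helper, not the Paddle API):

#include <vector>

// fill_zeros_like: allocate a buffer with the source's extent, all zeros,
// mirroring output->Resize(input->dims()) plus setZero() above.
template <typename T>
std::vector<T> FillZerosLike(const std::vector<T>& src) {
  return std::vector<T>(src.size(), T(0));
}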