提交 4c96008a 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #4566 from reyoung/feature/grad_reg_mechanism_cont2

Complete register gradient for compile time
...@@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) ...@@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op)
py_proto_compile(framework_py_proto SRCS framework.proto) py_proto_compile(framework_py_proto SRCS framework.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include "paddle/operators/net_op.h"
#include <list> #include <list>
#include <memory> #include <memory>
...@@ -24,6 +25,35 @@ ...@@ -24,6 +25,35 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static inline std::unique_ptr<OperatorBase> CreateGradOp(
const OperatorBase& op) {
OpDescBind op_desc;
op_desc.SetInputMap(op.Inputs());
op_desc.SetOutputMap(op.Outputs());
op_desc.SetType(op.Type());
op_desc.SetAttrMap(op.Attrs());
auto& info = OpInfoMap::Instance().Get(op.Type());
auto grad_descs = info.GradOpMaker()(op_desc);
std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size());
std::transform(grad_descs.begin(), grad_descs.end(),
std::back_inserter(grad_ops),
[](const std::unique_ptr<OpDescBind>& grad_desc) {
return OpRegistry::CreateOp(*grad_desc);
});
PADDLE_ENFORCE(!grad_ops.empty());
if (grad_ops.size() == 1) {
return std::move(grad_ops[0]);
} else {
auto net_op = new operators::NetOp();
for (auto& grad_op : grad_ops) {
net_op->AppendOp(std::move(grad_op));
}
net_op->CompleteAddOp();
return std::unique_ptr<OperatorBase>(net_op);
}
}
template <typename Map, typename T> template <typename Map, typename T>
static void ForEachVarName(const Map& names, T callback) { static void ForEachVarName(const Map& names, T callback) {
for (auto& name : names) { for (auto& name : names) {
...@@ -171,7 +201,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -171,7 +201,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
net->InsertOp(pos.first + 1, std::move(pos.second)); net->InsertOp(pos.first + 1, std::move(pos.second));
} }
} else { } else {
std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp)); std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
const std::string& grad_input) { const std::string& grad_input) {
......
...@@ -21,24 +21,34 @@ ...@@ -21,24 +21,34 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext; using DeviceContext = platform::DeviceContext;
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").NotInGradient(); AddInput("X", "Input X of Add");
AddInput("b", "Bias of Add").NotInGradient(); AddInput("b", "Bias of Add");
AddOutput("Out", "Out of Add").NotInGradient(); AddOutput("Out", "Out of Add");
AddComment("Add Op"); AddComment("Add Op");
} }
}; };
class RowWiseAddGradMaker : public SingleGradOpDescMaker {
public:
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<OpDescBind> Apply() const override {
auto grad_op = new OpDescBind();
grad_op->SetInput(GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(GradVarName("X"), InputGrad("X"));
grad_op->SetOutput(GradVarName("b"), InputGrad("b"));
grad_op->SetType("rowwise_add_grad");
return std::unique_ptr<OpDescBind>(grad_op);
}
};
class MulOpMaker : public OpProtoAndCheckerMaker { class MulOpMaker : public OpProtoAndCheckerMaker {
public: public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
...@@ -137,10 +147,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -137,10 +147,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.") AddInput("X", "the input tensors of sum operator.").AsDuplicable();
.AsDuplicable() AddOutput("Out", "the output tensor of sum operator.");
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddComment(""); AddComment("");
} }
}; };
...@@ -151,8 +159,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -151,8 +159,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
namespace f = paddle::framework; namespace f = paddle::framework;
namespace ops = paddle::operators; namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet; using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad, REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker,
f::NOP); f::RowWiseAddGradMaker);
REGISTER_OPERATOR(rowwise_add_grad, f::NOP);
REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
...@@ -162,17 +171,6 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); ...@@ -162,17 +171,6 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
f::NOP); f::NOP);
TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp(
"rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(1UL, gop->Inputs().size());
ASSERT_EQ("rowwise_add_grad", gop->Type());
ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
}
TEST(Backward, simple_op_not_need_grad) { TEST(Backward, simple_op_not_need_grad) {
auto fwd = f::OpRegistry::CreateOp( auto fwd = f::OpRegistry::CreateOp(
"rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
...@@ -289,17 +287,6 @@ TEST(Backward, net_shared_weight) { ...@@ -289,17 +287,6 @@ TEST(Backward, net_shared_weight) {
ASSERT_EQ("sum", bwd_net->ops_[2]->Type()); ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
} }
TEST(Backward, op_register_grad_not_for_network) {
auto fwd =
f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
{{"mul_result", {"mul_out"}},
{"add_result", {"add_out"}},
{"Out", {"out1"}}},
{{"temporary_index", std::vector<int>{0, 1}}});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
TEST(Backward, op_all_input_are_not_need) { TEST(Backward, op_all_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp( auto fwd = f::OpRegistry::CreateOp(
"rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
......
...@@ -66,7 +66,6 @@ message OpProto { ...@@ -66,7 +66,6 @@ message OpProto {
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool not_in_gradient = 5 [ default = false ];
} }
// AttrProto describes the C++ type Attribute. // AttrProto describes the C++ type Attribute.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOpArgType::OUT WARRANTIES OR CONDITIONS OF ANY KOpArgType::IND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
enum class OpArgType { IN, OUT };
static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
bool is_grad, VariableNameMap* vars) {
const auto& src_inout =
src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
auto& dst_inout = *vars;
auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) {
if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name();
std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
dst_inout[dst_name].reserve(src_inout.at(src_name).size());
for (auto& var_name : src_inout.at(src_name)) {
std::string s = is_grad ? GradVarName(var_name) : var_name;
dst_inout[dst_name].emplace_back(s);
}
}
}
OperatorBase* BuildGradOp(const OperatorBase* op) {
auto& info = OpInfoMap::Instance().Get(op->Type());
PADDLE_ENFORCE(info.HasGradientOp());
VariableNameMap inputs;
VariableNameMap outputs;
TransOpArg(op, OpArgType::IN, false, &inputs); // I
TransOpArg(op, OpArgType::OUT, false, &inputs); // O
TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
TransOpArg(op, OpArgType::IN, true, &outputs); // IG
auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_);
return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
}
static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type,
bool is_grad, OpDescBind* dst_op,
const OpArgType& dst_type) {
PADDLE_ENFORCE(dst_op != nullptr,
"Protobuf desc of gradient op must be initialized first.");
const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) {
if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name();
std::vector<std::string> vars = src_type == OpArgType::IN
? src_op->Input(src_name)
: src_op->Output(src_name);
if (is_grad) {
for (std::string& var : vars) {
var = GradVarName(var);
}
}
std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
dst_type == OpArgType::IN ? dst_op->SetInput(dst_name, vars)
: dst_op->SetOutput(dst_name, vars);
}
}
void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) {
auto& info = OpInfoMap::Instance().Get(forw_op->Type());
PADDLE_ENFORCE(info.HasGradientOp());
grad_op->SetType(info.grad_op_type_);
TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT);
grad_op->SetAttrMap(forw_op->GetAttrMap());
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
OperatorBase* BuildGradOp(const OperatorBase* op);
void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op);
} // namespace framework
} // namespace paddle
#include "paddle/framework/grad_op_builder.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP(sum);
namespace paddle {
namespace framework {
class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
public:
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").AsDuplicable();
AddInput("In3", "another single input");
AddOutput("Out1", "a single output");
AddOutput("Out2_mult", "a multiple output").AsDuplicable();
AddComment("test op with multiple inputs and outputs");
}
};
class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
public:
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient();
AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").NotInGradient();
AddComment("op with inputs and outputs ignored in gradient calculating");
}
};
} // namespace framework
} // namespace paddle
namespace f = paddle::framework;
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
TEST(GradOpBuilder, MutiInOut) {
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
"mult_io", {{"In1", {"in1"}},
{"In2_mult", {"in2_1", "in2_2", "in2_3"}},
{"In3", {"in3"}}},
{{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_test_op =
f::OpRegistry::CreateGradOp(*test_op);
ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
EXPECT_EQ(grad_test_op->Input("In3"), "in3");
EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
std::vector<std::string>({"out2_1", "out2_2"}));
EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")),
f::GradVarName("out1"));
EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")),
std::vector<std::string>(
{f::GradVarName("out2_1"), f::GradVarName("out2_2")}));
ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
std::vector<std::string>({f::GradVarName("in2_1"),
f::GradVarName("in2_2"),
f::GradVarName("in2_3")}));
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")), f::GradVarName("in3"));
}
TEST(GradOpBuilder, IOIgnoredInGradient) {
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
"io_ignored", {{"In1", {"in1"}},
{"In2_mult", {"in2_1", "in2_2"}},
{"In3_mult", {"in3_1", "in3_2"}}},
{{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_test_op =
f::OpRegistry::CreateGradOp(*test_op);
// 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
std::vector<std::string>({"in3_1", "in3_2"}));
EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
std::vector<std::string>({"out1_1", "out1_2"}));
EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")),
std::vector<std::string>(
{f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
f::GradVarName("out2"));
ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
std::vector<std::string>(
{f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In3_mult")),
std::vector<std::string>(
{f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
}
TEST(GradOpDescBuilder, MutiInOut) {
f::OpDescBind *forw_op = new f::OpDescBind();
forw_op->SetType("mult_io");
forw_op->SetInput("In1", {"in1"});
forw_op->SetInput("In2_mult", {"in2_1", "in2_2", "in2_3"});
forw_op->SetInput("In3", {"in3"});
forw_op->SetOutput("Out1", {"out1"});
forw_op->SetOutput("Out2_mult", {"out2_1", "out2_2"});
f::OpDescBind *grad_op = new f::OpDescBind();
f::CompleteGradOpDesc(forw_op, grad_op);
EXPECT_EQ(grad_op->Type(), "mult_io_grad");
ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL);
EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
EXPECT_EQ(grad_op->Input("In2_mult"),
std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
EXPECT_EQ(grad_op->Input("In3"), std::vector<std::string>({"in3"}));
EXPECT_EQ(grad_op->Input("Out1"), std::vector<std::string>({"out1"}));
EXPECT_EQ(grad_op->Input("Out2_mult"),
std::vector<std::string>({"out2_1", "out2_2"}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out1")),
std::vector<std::string>({f::GradVarName("out1")}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out2_mult")),
std::vector<std::string>(
{f::GradVarName("out2_1"), f::GradVarName("out2_2")}));
ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
std::vector<std::string>({f::GradVarName("in1")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
std::vector<std::string>({f::GradVarName("in2_1"),
f::GradVarName("in2_2"),
f::GradVarName("in2_3")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In3")),
std::vector<std::string>({f::GradVarName("in3")}));
delete forw_op;
delete grad_op;
}
TEST(GradOpDescBuilder, IOIgnoredInGradient) {
f::OpDescBind *forw_op = new f::OpDescBind();
forw_op->SetType("io_ignored");
forw_op->SetInput("In1", {"in1"});
forw_op->SetInput("In2_mult", {"in2_1", "in2_2"});
forw_op->SetInput("In3_mult", {"in3_1", "in3_2"});
forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"});
forw_op->SetOutput("Out2", {"out2"});
f::OpDescBind *grad_op = new f::OpDescBind();
f::CompleteGradOpDesc(forw_op, grad_op);
EXPECT_EQ(grad_op->Type(), "io_ignored_grad");
// 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL);
EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
EXPECT_EQ(grad_op->Input("In3_mult"),
std::vector<std::string>({"in3_1", "in3_2"}));
EXPECT_EQ(grad_op->Input("Out1_mult"),
std::vector<std::string>({"out1_1", "out1_2"}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")),
std::vector<std::string>(
{f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")),
std::vector<std::string>({f::GradVarName("out2")}));
ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
std::vector<std::string>({f::GradVarName("in1")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
std::vector<std::string>(
{f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")),
std::vector<std::string>(
{f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
delete forw_op;
delete grad_op;
}
...@@ -70,6 +70,22 @@ class OpDescBind { ...@@ -70,6 +70,22 @@ class OpDescBind {
std::vector<std::string> InputNames() const { return MapKeys(inputs_); } std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); } std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
void SetInputMap(const VariableNameMap &input) {
this->inputs_ = input;
this->need_update_ = true;
}
void SetOutputMap(const VariableNameMap &output) {
this->outputs_ = output;
this->need_update_ = true;
}
void Sync();
const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; }
private: private:
template <typename MapType> template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) { static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
...@@ -81,8 +97,6 @@ class OpDescBind { ...@@ -81,8 +97,6 @@ class OpDescBind {
return ret_val; return ret_val;
} }
void Sync();
OpDesc op_desc_; OpDesc op_desc_;
VariableNameMap inputs_; VariableNameMap inputs_;
VariableNameMap outputs_; VariableNameMap outputs_;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <map> #include <map>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include "paddle/framework/type_defs.h" #include "paddle/framework/type_defs.h"
...@@ -27,7 +28,6 @@ namespace framework { ...@@ -27,7 +28,6 @@ namespace framework {
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
std::string grad_op_type_;
GradOpMakerFN grad_op_maker_; GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr}; OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
...@@ -43,19 +43,19 @@ struct OpInfo { ...@@ -43,19 +43,19 @@ struct OpInfo {
return *proto_; return *proto_;
} }
const OpAttrChecker& Checker() const {
PADDLE_ENFORCE_NOT_NULL(checker_,
"Operator Checker has not been registered");
return *checker_;
}
const OpCreator& Creator() const { const OpCreator& Creator() const {
PADDLE_ENFORCE_NOT_NULL(creator_, PADDLE_ENFORCE_NOT_NULL(creator_,
"Operator Creator has not been registered"); "Operator Creator has not been registered");
return creator_; return creator_;
} }
bool HasGradientOp() const { return !grad_op_type_.empty(); } const GradOpMakerFN& GradOpMaker() const {
PADDLE_ENFORCE_NOT_NULL(grad_op_maker_,
"Operator GradOpMaker has not been registered.");
return grad_op_maker_;
}
const OpAttrChecker* Checker() const { return checker_; }
}; };
class OpInfoMap { class OpInfoMap {
......
...@@ -44,11 +44,6 @@ class OpProtoAndCheckerMaker { ...@@ -44,11 +44,6 @@ class OpProtoAndCheckerMaker {
var_->set_intermediate(true); var_->set_intermediate(true);
return *this; return *this;
} }
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
}; };
VariableBuilder AddInput(const std::string& name, const std::string& comment); VariableBuilder AddInput(const std::string& name, const std::string& comment);
......
...@@ -23,7 +23,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp( ...@@ -23,7 +23,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const std::string& type, const VariableNameMap& inputs, const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, AttributeMap attrs) { const VariableNameMap& outputs, AttributeMap attrs) {
auto& info = OpInfoMap::Instance().Get(type); auto& info = OpInfoMap::Instance().Get(type);
info.Checker().Check(attrs); if (info.Checker() != nullptr) {
info.Checker()->Check(attrs);
}
auto op = info.Creator()(type, inputs, outputs, attrs); auto op = info.Creator()(type, inputs, outputs, attrs);
return std::unique_ptr<OperatorBase>(op); return std::unique_ptr<OperatorBase>(op);
} }
...@@ -52,9 +54,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) { ...@@ -52,9 +54,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
return CreateOp(op_desc.type(), inputs, outputs, attrs); return CreateOp(op_desc.type(), inputs, outputs, attrs);
} }
std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) { std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
return std::unique_ptr<OperatorBase>(BuildGradOp(&op)); op_desc.GetAttrMap());
} }
} // namespace framework } // namespace framework
......
...@@ -23,25 +23,37 @@ limitations under the License. */ ...@@ -23,25 +23,37 @@ limitations under the License. */
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/details/op_registry.h" #include "paddle/framework/details/op_registry.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_desc_maker.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Registrar {
public:
// In our design, various kinds of classes, e.g., operators and kernels,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which,
// however, are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_OP macros to
// call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker.
void Touch() {}
};
template <typename... ARGS> template <typename... ARGS>
struct OperatorRegistrar { struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const char* op_type) : op_type(op_type) { explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type); "'%s' is registered more than once.", op_type);
static_assert(sizeof...(ARGS) != 0, static_assert(sizeof...(ARGS) != 0,
"OperatorRegistrar should be invoked at least by OpClass"); "OperatorRegistrar should be invoked at least by OpClass");
details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info); details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
OpInfoMap::Instance().Insert(op_type, info);
} }
~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); }
const char* op_type; const char* op_type;
OpInfo info; OpInfo info;
...@@ -67,20 +79,7 @@ class OpRegistry { ...@@ -67,20 +79,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op); static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
class Registrar {
public:
// In our design, various kinds of classes, e.g., operators and kernels,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which,
// however, are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_OP macros to
// call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker.
void Touch() {}
}; };
template <typename OpType, typename ProtoMakerType, typename GradOpType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
...@@ -138,33 +137,41 @@ class OpKernelRegistrar : public Registrar { ...@@ -138,33 +137,41 @@ class OpKernelRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \ __test_global_namespace_##uniq_name##__>::value, \
msg) msg)
#define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
"REGISTER_OPERATOR must be called in global namespace"); \
class _OpClass_##op_type##_ : public op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
}; \
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_, \
##__VA_ARGS__> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
/** /**
* Macro to register Operator. * Macro to register Operator.
*/ */
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ #define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \ grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ REGISTER_OPERATOR(grad_op_type, grad_op_class); \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ class _GradOpDescMaker_##grad_op_type##_ \
class _OpClass_##op_type##_ : public op_class { \ : public ::paddle::framework::DefaultGradOpDescMaker { \
public: \ using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker; \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ protected: \
}; \ virtual std::string GradOpType() const { return #grad_op_type; } \
class _OpGradClass_##op_type##_ : public grad_op_class { \ }; \
public: \ REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ op_maker_class);
DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \
}; \
static ::paddle::framework::OpRegistrar< \
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP) REGISTER_OPERATOR(op_type, op_class, op_maker_class)
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
......
...@@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").NotInGradient(); AddOutput("Out", "The output of mean op");
AddComment(R"DOC( Mean Operator AddComment(R"DOC( Mean Operator
)DOC"); )DOC");
} }
...@@ -52,11 +52,27 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -52,11 +52,27 @@ class MeanGradOp : public framework::OperatorWithKernel {
} }
}; };
class MeanGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto* grad_op = new framework::OpDescBind();
grad_op->SetType("mean_grad");
grad_op->SetInput("X", Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp); REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean, REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>); ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean_grad, REGISTER_OP_CPU_KERNEL(mean_grad,
......
...@@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left tensor of minus operator.").NotInGradient(); AddInput("X", "The left tensor of minus operator.");
AddInput("Y", "The right tensor of minus operator.").NotInGradient(); AddInput("Y", "The right tensor of minus operator.");
AddOutput("Out", "The output tensor of minus operator.").NotInGradient(); AddOutput("Out", "The output tensor of minus operator.");
AddComment(R"DOC(Minus Operator AddComment(R"DOC(Minus Operator
...@@ -64,26 +64,35 @@ or not. But the output only shares the LoD with input `X`. ...@@ -64,26 +64,35 @@ or not. But the output only shares the LoD with input `X`.
)DOC"); )DOC");
} }
}; };
template <typename AttrType>
class MinusGradOp : public NetOp { class MinusGradMaker : public framework::GradOpDescMakerBase {
public: public:
MinusGradOp(const std::string &type, const framework::VariableNameMap &inputs, using framework::GradOpDescMakerBase::GradOpDescMakerBase;
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
: NetOp(type, inputs, outputs, attrs) { const override {
auto out_grad = Input(framework::GradVarName("Out")); std::vector<std::unique_ptr<framework::OpDescBind>> ops;
auto x_grad = Output(framework::GradVarName("X")); auto x_g = InputGrad("X");
auto y_grad = Output(framework::GradVarName("Y")); if (!x_g.empty()) {
auto *x_g_op = new framework::OpDescBind();
// x_grad = out_grad x_g_op->SetType("scale");
AppendOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}}, x_g_op->SetInput("X", OutputGrad("Out"));
{{"Y", {x_grad}}}, {})); x_g_op->SetOutput("Out", x_g);
x_g_op->SetAttr("scale", 1.0f);
framework::AttributeMap scale_attr; ops.emplace_back(x_g_op);
scale_attr["scale"] = static_cast<AttrType>(-1); }
AppendOp(framework::OpRegistry::CreateOp("scale", {{"X", {out_grad}}},
{{"Out", {y_grad}}}, scale_attr)); auto y_g = InputGrad("Y");
CompleteAddOp(false); if (!y_g.empty()) {
auto *y_g_op = new framework::OpDescBind();
y_g_op->SetType("scale");
y_g_op->SetInput("X", OutputGrad("Out"));
y_g_op->SetOutput("Out", y_g);
y_g_op->SetAttr("scale", -1.0f);
ops.emplace_back(y_g_op);
}
return ops;
} }
}; };
...@@ -91,7 +100,6 @@ class MinusGradOp : public NetOp { ...@@ -91,7 +100,6 @@ class MinusGradOp : public NetOp {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradMaker);
ops::MinusGradOp<float>);
REGISTER_OP_CPU_KERNEL(minus, REGISTER_OP_CPU_KERNEL(minus,
ops::MinusKernel<paddle::platform::CPUPlace, float>); ops::MinusKernel<paddle::platform::CPUPlace, float>);
...@@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker {
"The input should be a k-D tensor(k > 0 and k < 7)"); "The input should be a k-D tensor(k > 0 and k < 7)");
AddOutput("Out", AddOutput("Out",
"The output of pad op." "The output of pad op."
"A tensor with the same shape as X.") "A tensor with the same shape as X.");
.NotInGradient();
AddComment(R"DOC( AddComment(R"DOC(
Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example: Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example:
...@@ -111,11 +110,29 @@ class PadOpGrad : public framework::OperatorWithKernel { ...@@ -111,11 +110,29 @@ class PadOpGrad : public framework::OperatorWithKernel {
} }
}; };
class PadOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto* bind = new framework::OpDescBind();
bind->SetInput("X", Input("X"));
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind->SetAttrMap(Attrs());
bind->SetType("pad_grad");
return std::unique_ptr<framework::OpDescBind>(bind);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad);
REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker, ops::PadOpGradMaker);
REGISTER_OPERATOR(pad_grad, ops::PadOpGrad);
REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pad_grad, REGISTER_OP_CPU_KERNEL(pad_grad,
ops::PadGradKernel<paddle::platform::CPUPlace, float>); ops::PadGradKernel<paddle::platform::CPUPlace, float>);
...@@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient(); AddInput("X", "The input tensor of scale operator.");
AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); AddOutput("Out", "The output tensor of scale operator.");
AddComment(R"DOC(Scale operator AddComment(R"DOC(Scale operator
The equation is: Out = scale*X The equation is: Out = scale*X
...@@ -52,21 +52,18 @@ The equation is: Out = scale*X ...@@ -52,21 +52,18 @@ The equation is: Out = scale*X
} }
}; };
// The operator to calculate gradients of a scale operator is just the scale class ScaleGradMaker : public framework::SingleGradOpDescMaker {
// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
template <typename AttrType>
class ScaleGradOp : public NetOp {
public: public:
ScaleGradOp(const std::string &type, const framework::VariableNameMap &inputs, using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) protected:
: NetOp(type, inputs, outputs, attrs) { std::unique_ptr<framework::OpDescBind> Apply() const override {
AppendOp(framework::OpRegistry::CreateOp( auto *grad_op = new framework::OpDescBind();
"scale", {{"X", {Input(framework::GradVarName("Out"))}}}, grad_op->SetType("scale");
{{"Out", {Output(framework::GradVarName("X"))}}}, grad_op->SetInput("X", OutputGrad("Out"));
{{"scale", Attr<AttrType>("scale")}})); grad_op->SetOutput("Out", InputGrad("X"));
CompleteAddOp(false); grad_op->SetAttr("scale", GetAttr("scale"));
return std::unique_ptr<framework::OpDescBind>(grad_op);
} }
}; };
...@@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp { ...@@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad, REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
ops::ScaleGradOp<float>); ops::ScaleGradMaker);
REGISTER_OP_CPU_KERNEL(scale, REGISTER_OP_CPU_KERNEL(scale,
ops::ScaleKernel<paddle::platform::CPUPlace, float>); ops::ScaleKernel<paddle::platform::CPUPlace, float>);
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/operators/softmax_with_cross_entropy_op.h" #include "paddle/operators/softmax_with_cross_entropy_op.h"
#include <paddle/function/TensorType.h> #include <paddle/function/TensorType.h>
#include <iostream>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,15 +28,14 @@ class SoftmaxWithCrossEntropyOpMaker ...@@ -27,15 +28,14 @@ class SoftmaxWithCrossEntropyOpMaker
AddInput("Logits", AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities " "(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, " "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number.") "and K is the class number.");
.NotInGradient(); AddInput("Label",
AddInput( "(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
"Label", "tensor. "
"(Tensor, default: Tensor<int>), The ground truth which is a 2-D " "If softLable is set to 0, Label is a Tensor<int> with shape [N x "
"tensor. " "1]. "
"If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. " "If softLable is set to 1, Label is a Tensor<float/double> "
"If softLable is set to 1, Label is a Tensor<float/double> " "with shape [N x K].");
"with shape [N x K].");
AddOutput( AddOutput(
"Softmax", "Softmax",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. " "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
...@@ -163,15 +163,34 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -163,15 +163,34 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
} }
}; };
class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto* grad_op = new framework::OpDescBind();
grad_op->SetType("softmax_with_cross_entropy_grad");
grad_op->SetInput("Label", Input("Label"));
grad_op->SetInput("Softmax", Output("Softmax"));
grad_op->SetInput("Loss", Output("Loss"));
grad_op->SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
softmax_with_cross_entropy_grad, REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad); ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>); ops::SoftmaxWithCrossEntropyKernel<float>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
......
...@@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.") AddInput("X", "the input tensors of sum operator.").AsDuplicable();
.AsDuplicable() AddOutput("Out", "the output tensor of sum operator.");
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddComment(R"DOC( AddComment(R"DOC(
Sum the input tensors. Sum the input tensors.
...@@ -58,23 +56,26 @@ or not. But the output only shares the LoD with the first input. ...@@ -58,23 +56,26 @@ or not. But the output only shares the LoD with the first input.
} }
}; };
class SumGradOp : public NetOp { class SumGradMaker : public framework::GradOpDescMakerBase {
public: public:
SumGradOp(const std::string& type, const framework::VariableNameMap& inputs, using framework::GradOpDescMakerBase::GradOpDescMakerBase;
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: NetOp(type, inputs, outputs, attrs) {
auto& x_grad_names = Outputs(framework::GradVarName("X"));
auto out_grad_name = this->Input(framework::GradVarName("Out"));
framework::AttributeMap grad_attrs; std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
grad_attrs["scale"] = 1.0f; const override {
for (auto& x_grad_name : x_grad_names) { auto x_grads = InputGrad("X");
AppendOp(framework::OpRegistry::CreateOp( std::vector<std::unique_ptr<framework::OpDescBind>> grad_ops;
"scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}}, grad_ops.reserve(x_grads.size());
grad_attrs)); auto og = OutputGrad("Out");
} std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
CompleteAddOp(false); [&og](const std::string& x_grad) {
auto* grad_op = new framework::OpDescBind();
grad_op->SetType("scale");
grad_op->SetInput("X", og);
grad_op->SetOutput("Out", {x_grad});
grad_op->SetAttr("scale", 1.0f);
return std::unique_ptr<framework::OpDescBind>(grad_op);
});
return grad_ops;
} }
}; };
...@@ -82,5 +83,6 @@ class SumGradOp : public NetOp { ...@@ -82,5 +83,6 @@ class SumGradOp : public NetOp {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册