Commit 5f6e5ed0 authored by Yi Wang, committed by GitHub

Merge pull request #7 from qingqing01/grad_op_builder

Update grad_op_builder after refactoring framework proto.
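Context for the diff below: the framework-proto refactoring replaced each operator's flat `inputs_`/`outputs_` string vectors — segmented through the `input_format`/`output_format` offset attributes and the `in_out_idxs_` name-to-index map — with a direct map from parameter name to the list of variable names. A minimal sketch of the two layouts (the container choice and the variable names are illustrative assumptions, not the framework's exact types):

#include <map>
#include <string>
#include <vector>

// Old layout: one flat variable list per operator; a parameter's span is
// recovered from the "input_format" offsets, e.g. {0, 1, 3, 4} means
// X = [0, 1), W = [1, 3), b = [3, 4).
struct OldOpIO {
  std::vector<std::string> inputs{"x", "w0", "w1", "b"};
  std::vector<int> input_format{0, 1, 3, 4};
};

// New layout: parameter name -> variable names; no offset attributes and
// no name-to-index map are needed.
using VarNameMap = std::map<std::string, std::vector<std::string>>;

int main() {
  VarNameMap inputs{{"X", {"x"}}, {"W", {"w0", "w1"}}, {"b", {"b"}}};
  const auto& w = inputs.at("W");  // {"w0", "w1"}, addressed directly by name
  (void)w;
  return 0;
}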
@@ -18,59 +18,32 @@ permissions and limitations under the License. */
 namespace paddle {
 namespace framework {
-/**
 class OpRegistry;

 using VarIndexMap = std::unordered_map<std::string, int>;

 enum class OpArgType { IN, OUT };

-static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
-  std::string key = type == OpArgType::IN ? "input_format" : "output_format";
-  return op->attrs_.count(key)
-             ? &boost::get<std::vector<int>>(op->attrs_.at(key))
-             : nullptr;
-}
-
-static const std::vector<int>* GetOpFormat(const OperatorBase* op,
-                                           const OpArgType& type) {
-  std::string key = type == OpArgType::IN ? "input_format" : "output_format";
-  return op->attrs_.count(key)
-             ? &boost::get<std::vector<int>>(op->attrs_.at(key))
-             : nullptr;
-}
-
 static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
                        const OpArgType& src_type, const OpArgType& dst_type,
-                       int& idx, bool is_grad) {
-  const std::vector<std::string>& src_inout =
+                       bool is_grad) {
+  const auto& src_inout =
       src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
-  const std::vector<int>* src_format = GetOpFormat(src_op, src_type);

-  std::vector<std::string>& dst_inout =
+  auto& dst_inout =
       dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
-  std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);

   const OpProto& proto = OpRegistry::protos().at(src_op->type_);
   const auto& src_arg_list =
       src_type == OpArgType::IN ? proto.inputs() : proto.outputs();

   for (const auto& arg : src_arg_list) {
     std::string src_name = arg.name();
-    std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name;
-    (*dst_op->in_out_idxs_)[dst_name] = idx++;
-    int src_arg_idx = src_op->in_out_idxs_->at(src_name);
-    int src_begin =
-        src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
-    int src_end = src_format == nullptr ? src_arg_idx + 1
-                                        : src_format->at(src_arg_idx + 1);
-    for (int i = src_begin; i < src_end; ++i) {
-      std::string s =
-          is_grad ? src_inout[i] + kGradVarSuffix
-                  : (arg.ignore_gradient() ? kEmptyVarName : src_inout[i]);
-      dst_inout.emplace_back(s);
-    }
-    if (dst_format != nullptr) {
-      dst_format->push_back(dst_inout.size());
+    std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
+    for (auto& var_name : src_inout.at(src_name)) {
+      std::string s = is_grad ? GradVarName(var_name)
+                              : (arg.no_gradient() ? kEmptyVarName : var_name);
+      dst_inout[dst_name].emplace_back(s);
     }
   }
 }
@@ -80,25 +53,12 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
   OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
   grad_op->type_ = grad_op_type;
   grad_op->attrs_ = op->attrs_;
-  grad_op->attrs_.erase("input_format");
-  grad_op->attrs_.erase("output_format");
-  if (GetOpFormat(op, OpArgType::IN) != nullptr) {
-    grad_op->attrs_["output_format"] = std::vector<int>({0});
-  }
-  if (GetOpFormat(op, OpArgType::IN) != nullptr ||
-      GetOpFormat(op, OpArgType::OUT) != nullptr) {
-    grad_op->attrs_["input_format"] = std::vector<int>({0});
-  }
-  grad_op->in_out_idxs_.reset(new VarIndexMap());
-  int in_idx = 0;
-  int out_idx = 0;
-  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false);   // I
-  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false);  // G
-  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true);   // OG
-  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true);  // IG
+  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, false);   // I
+  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, false);  // O
+  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, true);   // OG
+  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, true);   // IG
   return grad_op;
 }
-**/
-
-OperatorBase* BuildGradOp(const OperatorBase* op) { return nullptr; }

 }  // namespace framework
 }  // namespace paddle
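Read together with the test below, the four TransOpArg calls produce a predictable layout. A self-contained sketch (not part of the patch) of what BuildGradOp assembles for `add_two`, assuming the gradient suffix is "@GRAD" as the strings in the old test suggest:

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Sketch only: mirrors the framework's GradVarName under the assumed suffix.
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

using VarNameMap = std::map<std::string, std::vector<std::string>>;

int main() {
  // Forward op: inputs X = {x}, Y = {y}; output Out = {out}.
  VarNameMap grad_inputs{
      {"X", {"x"}},                                  // I : forward inputs
      {"Y", {"y"}},
      {"Out", {"out"}},                              // O : forward outputs
      {GradVarName("Out"), {GradVarName("out")}}};   // OG: output gradients
  VarNameMap grad_outputs{
      {GradVarName("X"), {GradVarName("x")}},        // IG: input gradients
      {GradVarName("Y"), {GradVarName("y")}}};

  assert(grad_inputs.size() == 4);   // matches EXPECT_EQ(..., 4UL) below
  assert(grad_outputs.size() == 2);  // matches EXPECT_EQ(..., 2UL) below
  return 0;
}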
@@ -51,14 +51,14 @@ TEST(GradOpBuilder, AddTwo) {
       "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {}));
   std::shared_ptr<f::OperatorBase> grad_add_op =
       f::OpRegistry::CreateGradOp(*add_op);
-  EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
-  EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
+  EXPECT_EQ(grad_add_op->inputs_.size(), 4UL);
+  EXPECT_EQ(grad_add_op->outputs_.size(), 2UL);
   EXPECT_EQ(grad_add_op->Input("X"), "x");
   EXPECT_EQ(grad_add_op->Input("Y"), "y");
   EXPECT_EQ(grad_add_op->Input("Out"), "out");
-  EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD");
-  EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD");
-  EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
+  EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out"));
+  EXPECT_EQ(grad_add_op->Output(f::GradVarName("X")), f::GradVarName("x"));
+  EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
 }

 REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker);
@@ -67,17 +67,16 @@ REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
 REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP);

 TEST(GradOpBuilder, MutiInOut) {
-  f::AttributeMap attrs{{"input_format", std::vector<int>{0, 1, 4, 5}},
-                        {"output_format", std::vector<int>{0, 1, 3}}};
   std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
-      "mult_io", {{"In1", {"in1"}},
-                  {"In2_mult", {"in2_1", "in2_2", "in2_3"}},
-                  {"In3", {"in3"}}},
-      {{"Out1", {"Out2_mult"}}, {"Out2", {"out2_1", "out2_2"}}}, attrs));
+      "mult_io",
+      {{"In1", {"in1"}},
+       {"In2_mult", {"in2_1", "in2_2", "in2_3"}},
+       {"In3", {"in3"}}},
+      {{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {}));
   std::shared_ptr<f::OperatorBase> grad_test_op =
       f::OpRegistry::CreateGradOp(*test_op);

-  ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL);
+  ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL);
   EXPECT_EQ(grad_test_op->Input("In1"), "in1");
   EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
             std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
@@ -85,36 +84,33 @@ TEST(GradOpBuilder, MutiInOut) {
   EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
   EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
             std::vector<std::string>({"out2_1", "out2_2"}));
-  EXPECT_EQ(grad_test_op->Input("Out1" + f::kGradVarSuffix),
-            "out1" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_test_op->Inputs("Out2_mult" + f::kGradVarSuffix),
+  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")),
+            f::GradVarName("out1"));
+  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")),
             std::vector<std::string>(
-                {"out2_1" + f::kGradVarSuffix, "out2_2" + f::kGradVarSuffix}));
+                {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));

-  ASSERT_EQ(grad_test_op->outputs_.size(), 5UL);
-  EXPECT_EQ(grad_test_op->Output("In1" + f::kGradVarSuffix),
-            "in1" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_test_op->Outputs("In2_mult" + f::kGradVarSuffix),
-            std::vector<std::string>({"in2_1" + f::kGradVarSuffix,
-                                      "in2_2" + f::kGradVarSuffix,
-                                      "in2_3" + f::kGradVarSuffix}));
-  EXPECT_EQ(grad_test_op->Output("In3" + f::kGradVarSuffix),
-            "in3" + f::kGradVarSuffix);
+  ASSERT_EQ(grad_test_op->outputs_.size(), 3UL);
+  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
+  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
+            std::vector<std::string>({f::GradVarName("in2_1"),
+                                      f::GradVarName("in2_2"),
+                                      f::GradVarName("in2_3")}));
+  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")), f::GradVarName("in3"));
 }
 TEST(GradOpBuilder, IOIgnoredInGradient) {
-  f::AttributeMap attrs{{"input_format", std::vector<int>{0, 1, 3, 5}},
-                        {"output_format", std::vector<int>{0, 2, 3}}};
   std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
-      "io_ignored", {{"In1", {"in1"}},
-                     {"In2_mult", {"in2_1", "in2_2"}},
-                     {"In3_mult", {"in3_1", "in3_2"}}},
-      {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, attrs));
+      "io_ignored",
+      {{"In1", {"in1"}},
+       {"In2_mult", {"in2_1", "in2_2"}},
+       {"In3_mult", {"in3_1", "in3_2"}}},
+      {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {}));
   std::shared_ptr<f::OperatorBase> grad_test_op =
       f::OpRegistry::CreateGradOp(*test_op);

   // 'In2' and 'Out2' are ignored in gradient calculating
-  ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL);
+  ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL);
   EXPECT_EQ(grad_test_op->Input("In1"), "in1");
   EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
             std::vector<std::string>({f::kEmptyVarName, f::kEmptyVarName}));
@@ -123,19 +119,18 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
   EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
             std::vector<std::string>({"out1_1", "out1_2"}));
   EXPECT_EQ(grad_test_op->Input("Out2"), f::kEmptyVarName);
-  EXPECT_EQ(grad_test_op->Inputs("Out1_mult" + f::kGradVarSuffix),
+  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")),
             std::vector<std::string>(
-                {"out1_1" + f::kGradVarSuffix, "out1_2" + f::kGradVarSuffix}));
-  EXPECT_EQ(grad_test_op->Input("Out2" + f::kGradVarSuffix),
-            "out2" + f::kGradVarSuffix);
+                {f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
+  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
+            f::GradVarName("out2"));

-  ASSERT_EQ(grad_test_op->outputs_.size(), 5UL);
-  EXPECT_EQ(grad_test_op->Output("In1" + f::kGradVarSuffix),
-            "in1" + f::kGradVarSuffix);
-  EXPECT_EQ(grad_test_op->Outputs("In2_mult" + f::kGradVarSuffix),
+  ASSERT_EQ(grad_test_op->outputs_.size(), 3UL);
+  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
+  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
             std::vector<std::string>(
-                {"in2_1" + f::kGradVarSuffix, "in2_2" + f::kGradVarSuffix}));
-  EXPECT_EQ(grad_test_op->Outputs("In3_mult" + f::kGradVarSuffix),
+                {f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
+  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In3_mult")),
             std::vector<std::string>(
-                {"in3_1" + f::kGradVarSuffix, "in3_2" + f::kGradVarSuffix}));
+                {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
 }
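The ignored-I/O expectations above fall directly out of the new TransOpArg loop: a `no_gradient()` argument keeps its slot but has its forward variable names replaced by `kEmptyVarName`, while gradient names are still generated for it. A standalone sketch of that loop over plain maps (the "@EMPTY@" placeholder and "@GRAD" suffix are assumed values for illustration):

#include <iostream>
#include <map>
#include <string>
#include <vector>

const char kEmptyVarName[] = "@EMPTY@";  // assumed placeholder value
std::string GradVarName(const std::string& n) { return n + "@GRAD"; }

using VarNameMap = std::map<std::string, std::vector<std::string>>;

// Mirrors the new TransOpArg body for one argument: copy its variable list
// from src to dst, renaming to gradient names or blanking forward values.
void Trans(const VarNameMap& src, VarNameMap& dst, const std::string& arg,
           bool no_gradient, bool is_grad) {
  std::string dst_name = is_grad ? GradVarName(arg) : arg;
  for (const auto& var_name : src.at(arg)) {
    dst[dst_name].push_back(is_grad ? GradVarName(var_name)
                                    : (no_gradient ? kEmptyVarName : var_name));
  }
}

int main() {
  VarNameMap fwd_in{{"In2_mult", {"in2_1", "in2_2"}}};
  VarNameMap grad_in, grad_out;
  Trans(fwd_in, grad_in, "In2_mult", /*no_gradient=*/true, /*is_grad=*/false);
  Trans(fwd_in, grad_out, "In2_mult", /*no_gradient=*/true, /*is_grad=*/true);
  // grad_in["In2_mult"]       == {"@EMPTY@", "@EMPTY@"}
  // grad_out["In2_mult@GRAD"] == {"in2_1@GRAD", "in2_2@GRAD"}
  for (const auto& v : grad_in["In2_mult"]) std::cout << v << " ";
  std::cout << "\n";
  for (const auto& v : grad_out[GradVarName("In2_mult")]) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}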
@@ -131,14 +131,6 @@ TEST(OpRegistry, DefaultValue) {
   ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
 }

-static void SetInputFormat(paddle::framework::OpDesc* desc) {
-  auto attr = desc->add_attrs();
-  attr->set_name("input_format");
-  attr->set_type(paddle::framework::INTS);
-  attr->mutable_ints()->Add(0);
-  attr->mutable_ints()->Add(1);
-}
-
 TEST(OpRegistry, CustomChecker) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("my_test_op");
@@ -149,7 +141,6 @@ TEST(OpRegistry, CustomChecker) {
   auto output = op_desc.add_outputs();
   output->set_op_proto_name("output");
   *output->mutable_var_names()->Add() = "oo";
-  SetInputFormat(&op_desc);

   // attr 'test_attr' is not set
   bool caught = false;
@@ -189,7 +180,6 @@ TEST(OpRegistry, CustomChecker) {
   attr->set_name("test_attr");
   attr->set_type(paddle::framework::AttrType::INT);
   attr->set_i(4);
-  SetInputFormat(&op_desc);
   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::platform::CPUDeviceContext dev_ctx;
   paddle::framework::Scope scope;
......
@@ -185,11 +185,11 @@ TEST(OpKernel, all) {
   op_desc.set_type("op_with_kernel");
   auto* ipt = op_desc.mutable_inputs()->Add();
   *ipt->mutable_var_names()->Add() = "IN1";
-  ipt->set_op_proto_name("input");
+  ipt->set_op_proto_name("x");

   auto* output = op_desc.mutable_outputs()->Add();
   *output->mutable_var_names()->Add() = "OUT1";
-  output->set_op_proto_name("output");
+  output->set_op_proto_name("y");

   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
@@ -234,21 +234,6 @@ TEST(OpKernel, multi_inputs) {
   attr->set_type(paddle::framework::AttrType::FLOAT);
   attr->set_f(3.14);

-  auto attr0 = op_desc.mutable_attrs()->Add();
-  attr0->set_name("input_format");
-  attr0->set_type(paddle::framework::AttrType::INTS);
-  auto input_format = attr0->mutable_ints();
-  input_format->Add(0);  // x0
-  input_format->Add(3);  // k
-  input_format->Add(4);  // end
-
-  auto attr1 = op_desc.mutable_attrs()->Add();
-  attr1->set_name("output_format");
-  attr1->set_type(paddle::framework::AttrType::INTS);
-  auto output_format = attr1->mutable_ints();
-  output_format->Add(0);  // y0
-  output_format->Add(2);  // y1
-
   paddle::platform::CPUDeviceContext cpu_device_context;
   paddle::framework::Scope scope;
   scope.NewVar("x0")->GetMutable<Tensor>();
......
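The attributes deleted above encode the old segment-offset scheme: parameter i owned the variables in the half-open range [format[i], format[i+1]) of the flat list, which is exactly what the removed GetOpFormat-based code in grad_op_builder.cc decoded via `src_format->at(src_arg_idx)` and `src_format->at(src_arg_idx + 1)`. A small sketch of that decoding (illustrative only, using the values from the multi_inputs test):

#include <cassert>
#include <string>
#include <vector>

// Old-style decoding: parameter arg_idx spans [format[arg_idx],
// format[arg_idx + 1]) of the flat variable list.
std::vector<std::string> ArgVars(const std::vector<std::string>& flat,
                                 const std::vector<int>& format, int arg_idx) {
  return {flat.begin() + format.at(arg_idx),
          flat.begin() + format.at(arg_idx + 1)};
}

int main() {
  // multi_inputs used variables {x0, x1, x2, k} with input_format
  // {0, 3, 4}: "x" = [0, 3), "k" = [3, 4).
  std::vector<std::string> flat{"x0", "x1", "x2", "k"};
  std::vector<int> format{0, 3, 4};
  assert(ArgVars(flat, format, 0) ==
         (std::vector<std::string>{"x0", "x1", "x2"}));
  assert(ArgVars(flat, format, 1) == (std::vector<std::string>{"k"}));
  return 0;
}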
@@ -74,6 +74,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
         expected1.inputs.extend(['x', 'w', 'b'])
         expected1.outputs.extend(['y'])
         expected1.type = 'fc'
+        # the input_format can be removed after testing
         attr = expected1.attrs.add()
         attr.name = 'input_format'
         attr.type = attribute_pb2.INTS
@@ -86,6 +87,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
         expected2.inputs.extend(['x1', 'x2', 'x3', 'w1', 'w2', 'w3', 'b'])
         expected2.outputs.extend(['y'])
         expected2.type = 'fc'
+        # the input_format can be removed after testing
         attr = expected2.attrs.add()
         attr.name = 'input_format'
         attr.type = attribute_pb2.INTS
......