From b368c6cac4178e20d75b188d07aa69c8907a23b8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 9 Aug 2017 14:09:31 +0800 Subject: [PATCH] Rename op_proto_name/var_names -> parameter/arguments --- paddle/framework/framework.proto | 4 ++-- paddle/framework/op_registry.h | 8 +++---- paddle/framework/op_registry_test.cc | 32 +++++++++++++------------- paddle/framework/operator_test.cc | 34 ++++++++++++++-------------- 4 files changed, 39 insertions(+), 39 deletions(-) diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 490d7bd91bf..7077e8aa2c7 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -40,8 +40,8 @@ message OpDesc { }; message Var { - required string op_proto_name = 1; - repeated string var_names = 2; + required string parameter = 1; + repeated string arguments = 2; }; required string type = 3; diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index db23fd7bf93..f11ce8fd377 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -180,8 +180,8 @@ class OpRegistry { static std::shared_ptr CreateOp(const OpDesc& op_desc) { VarNameMap inputs; for (auto& input : op_desc.inputs()) { - auto& var_names = inputs[input.op_proto_name()]; - auto& var_names_in_proto = input.var_names(); + auto& var_names = inputs[input.parameter()]; + auto& var_names_in_proto = input.arguments(); var_names.reserve(static_cast(var_names_in_proto.size())); std::copy(var_names_in_proto.begin(), var_names_in_proto.end(), std::back_inserter(var_names)); @@ -189,8 +189,8 @@ class OpRegistry { VarNameMap outputs; for (auto& output : op_desc.outputs()) { - auto& var_names = outputs[output.op_proto_name()]; - auto& var_names_in_proto = output.var_names(); + auto& var_names = outputs[output.parameter()]; + auto& var_names_in_proto = output.arguments(); var_names.reserve(static_cast(var_names_in_proto.size())); std::copy(var_names_in_proto.begin(), var_names_in_proto.end(), std::back_inserter(var_names)); diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 7eb4de003b4..74dbf4471a0 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -58,12 +58,12 @@ TEST(OpRegistry, CreateOp) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); auto input = op_desc.add_inputs(); - input->set_op_proto_name("input"); - *input->mutable_var_names()->Add() = "aa"; + input->set_parameter("input"); + *input->mutable_arguments()->Add() = "aa"; auto output = op_desc.add_outputs(); - output->set_op_proto_name("output"); - *output->mutable_var_names()->Add() = "bb"; + output->set_parameter("output"); + *output->mutable_arguments()->Add() = "bb"; float scale = 3.3; auto attr = op_desc.mutable_attrs()->Add(); @@ -84,12 +84,12 @@ TEST(OpRegistry, IllegalAttr) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); auto input = op_desc.add_inputs(); - input->set_op_proto_name("input"); - *input->mutable_var_names()->Add() = "aa"; + input->set_parameter("input"); + *input->mutable_arguments()->Add() = "aa"; auto output = op_desc.add_outputs(); - output->set_op_proto_name("output"); - *output->mutable_var_names()->Add() = "bb"; + output->set_parameter("output"); + *output->mutable_arguments()->Add() = "bb"; auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -114,12 +114,12 @@ TEST(OpRegistry, DefaultValue) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); auto input = op_desc.add_inputs(); - input->set_op_proto_name("input"); - *input->mutable_var_names()->Add() = "aa"; + input->set_parameter("input"); + *input->mutable_arguments()->Add() = "aa"; auto output = op_desc.add_outputs(); - output->set_op_proto_name("output"); - *output->mutable_var_names()->Add() = "bb"; + output->set_parameter("output"); + *output->mutable_arguments()->Add() = "bb"; ASSERT_TRUE(op_desc.IsInitialized()); @@ -143,12 +143,12 @@ TEST(OpRegistry, CustomChecker) { paddle::framework::OpDesc op_desc; op_desc.set_type("my_test_op"); auto input = op_desc.add_inputs(); - input->set_op_proto_name("input"); - *input->mutable_var_names()->Add() = "ii"; + input->set_parameter("input"); + *input->mutable_arguments()->Add() = "ii"; auto output = op_desc.add_outputs(); - output->set_op_proto_name("output"); - *output->mutable_var_names()->Add() = "oo"; + output->set_parameter("output"); + *output->mutable_arguments()->Add() = "oo"; SetInputFormat(&op_desc); // attr 'test_attr' is not set diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index cbfbaa56c13..fa5c14b63b2 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -61,12 +61,12 @@ TEST(OperatorBase, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("test_operator"); auto* ipt = op_desc.mutable_inputs()->Add(); - *ipt->mutable_var_names()->Add() = "IN1"; - ipt->set_op_proto_name("input"); + *ipt->mutable_arguments()->Add() = "IN1"; + ipt->set_parameter("input"); auto* output = op_desc.mutable_outputs()->Add(); - *output->mutable_var_names()->Add() = "OUT1"; - output->set_op_proto_name("output"); + *output->mutable_arguments()->Add() = "OUT1"; + output->set_parameter("output"); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); @@ -184,12 +184,12 @@ TEST(OpKernel, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("op_with_kernel"); auto* ipt = op_desc.mutable_inputs()->Add(); - *ipt->mutable_var_names()->Add() = "IN1"; - ipt->set_op_proto_name("input"); + *ipt->mutable_arguments()->Add() = "IN1"; + ipt->set_parameter("input"); auto* output = op_desc.mutable_outputs()->Add(); - *output->mutable_var_names()->Add() = "OUT1"; - output->set_op_proto_name("output"); + *output->mutable_arguments()->Add() = "OUT1"; + output->set_parameter("output"); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -217,17 +217,17 @@ TEST(OpKernel, multi_inputs) { OpDesc op_desc; op_desc.set_type("op_multi_inputs_with_kernel"); auto x = op_desc.mutable_inputs()->Add(); - x->set_op_proto_name("xs"); - *x->mutable_var_names()->Add() = "x0"; - *x->mutable_var_names()->Add() = "x1"; - *x->mutable_var_names()->Add() = "x2"; + x->set_parameter("xs"); + *x->mutable_arguments()->Add() = "x0"; + *x->mutable_arguments()->Add() = "x1"; + *x->mutable_arguments()->Add() = "x2"; auto k = op_desc.mutable_inputs()->Add(); - k->set_op_proto_name("k"); - *k->mutable_var_names()->Add() = "k0"; + k->set_parameter("k"); + *k->mutable_arguments()->Add() = "k0"; auto y = op_desc.mutable_outputs()->Add(); - y->set_op_proto_name("ys"); - *y->mutable_var_names()->Add() = "y0"; - *y->mutable_var_names()->Add() = "y1"; + y->set_parameter("ys"); + *y->mutable_arguments()->Add() = "y0"; + *y->mutable_arguments()->Add() = "y1"; auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); -- GitLab