提交 b368c6ca 编写于 作者: Y Yu Yang

Rename op_proto_name/var_names -> parameter/arguments

上级 3d4ba22c
......@@ -40,8 +40,8 @@ message OpDesc {
};
message Var {
required string op_proto_name = 1;
repeated string var_names = 2;
required string parameter = 1;
repeated string arguments = 2;
};
required string type = 3;
......
......@@ -180,8 +180,8 @@ class OpRegistry {
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
VarNameMap inputs;
for (auto& input : op_desc.inputs()) {
auto& var_names = inputs[input.op_proto_name()];
auto& var_names_in_proto = input.var_names();
auto& var_names = inputs[input.parameter()];
auto& var_names_in_proto = input.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
......@@ -189,8 +189,8 @@ class OpRegistry {
VarNameMap outputs;
for (auto& output : op_desc.outputs()) {
auto& var_names = outputs[output.op_proto_name()];
auto& var_names_in_proto = output.var_names();
auto& var_names = outputs[output.parameter()];
auto& var_names_in_proto = output.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
......
......@@ -58,12 +58,12 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_op_proto_name("input");
*input->mutable_var_names()->Add() = "aa";
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_op_proto_name("output");
*output->mutable_var_names()->Add() = "bb";
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
......@@ -84,12 +84,12 @@ TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_op_proto_name("input");
*input->mutable_var_names()->Add() = "aa";
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_op_proto_name("output");
*output->mutable_var_names()->Add() = "bb";
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -114,12 +114,12 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
auto input = op_desc.add_inputs();
input->set_op_proto_name("input");
*input->mutable_var_names()->Add() = "aa";
input->set_parameter("input");
*input->mutable_arguments()->Add() = "aa";
auto output = op_desc.add_outputs();
output->set_op_proto_name("output");
*output->mutable_var_names()->Add() = "bb";
output->set_parameter("output");
*output->mutable_arguments()->Add() = "bb";
ASSERT_TRUE(op_desc.IsInitialized());
......@@ -143,12 +143,12 @@ TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op");
auto input = op_desc.add_inputs();
input->set_op_proto_name("input");
*input->mutable_var_names()->Add() = "ii";
input->set_parameter("input");
*input->mutable_arguments()->Add() = "ii";
auto output = op_desc.add_outputs();
output->set_op_proto_name("output");
*output->mutable_var_names()->Add() = "oo";
output->set_parameter("output");
*output->mutable_arguments()->Add() = "oo";
SetInputFormat(&op_desc);
// attr 'test_attr' is not set
......
......@@ -61,12 +61,12 @@ TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator");
auto* ipt = op_desc.mutable_inputs()->Add();
*ipt->mutable_var_names()->Add() = "IN1";
ipt->set_op_proto_name("input");
*ipt->mutable_arguments()->Add() = "IN1";
ipt->set_parameter("input");
auto* output = op_desc.mutable_outputs()->Add();
*output->mutable_var_names()->Add() = "OUT1";
output->set_op_proto_name("output");
*output->mutable_arguments()->Add() = "OUT1";
output->set_parameter("output");
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
......@@ -184,12 +184,12 @@ TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
auto* ipt = op_desc.mutable_inputs()->Add();
*ipt->mutable_var_names()->Add() = "IN1";
ipt->set_op_proto_name("input");
*ipt->mutable_arguments()->Add() = "IN1";
ipt->set_parameter("input");
auto* output = op_desc.mutable_outputs()->Add();
*output->mutable_var_names()->Add() = "OUT1";
output->set_op_proto_name("output");
*output->mutable_arguments()->Add() = "OUT1";
output->set_parameter("output");
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -217,17 +217,17 @@ TEST(OpKernel, multi_inputs) {
OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
auto x = op_desc.mutable_inputs()->Add();
x->set_op_proto_name("xs");
*x->mutable_var_names()->Add() = "x0";
*x->mutable_var_names()->Add() = "x1";
*x->mutable_var_names()->Add() = "x2";
x->set_parameter("xs");
*x->mutable_arguments()->Add() = "x0";
*x->mutable_arguments()->Add() = "x1";
*x->mutable_arguments()->Add() = "x2";
auto k = op_desc.mutable_inputs()->Add();
k->set_op_proto_name("k");
*k->mutable_var_names()->Add() = "k0";
k->set_parameter("k");
*k->mutable_arguments()->Add() = "k0";
auto y = op_desc.mutable_outputs()->Add();
y->set_op_proto_name("ys");
*y->mutable_var_names()->Add() = "y0";
*y->mutable_var_names()->Add() = "y1";
y->set_parameter("ys");
*y->mutable_arguments()->Add() = "y0";
*y->mutable_arguments()->Add() = "y1";
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册