未验证 提交 3cb6f65e 编写于 作者: J jiangcheng 提交者: GitHub

Add transformer of paddle desc and cinn desc (#36100)

* add transformer of paddle desc and cinn desc

* change LOG(FATAL) to PADDLE_THROW for ci

* full error imformation for ci

* fix some problem as review advice

* fix some bug

* move vat type utils to tansform_desc header file

* add if NOT WITH_CINN control whether compile

* build_strategy check whether open WITH_CINN

* add control WITH_CINN in cmake
上级 ab732884
...@@ -3,6 +3,11 @@ cc_library(cinn_compiled_object SRCS cinn_compiled_object.cc DEPS feed_fetch_met ...@@ -3,6 +3,11 @@ cc_library(cinn_compiled_object SRCS cinn_compiled_object.cc DEPS feed_fetch_met
cc_library(cinn_runner SRCS cinn_runner.cc DEPS cinn_cache_key cinn_compiled_object feed_fetch_method graph lod_tensor scope) cc_library(cinn_runner SRCS cinn_runner.cc DEPS cinn_cache_key cinn_compiled_object feed_fetch_method graph lod_tensor scope)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector) cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector)
if (WITH_CINN)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_test(test_transform_desc SRCS transform_desc_test.cc DEPS transform_desc)
endif()
cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key) cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
cc_test(cinn_runner_test SRCS cinn_runner_test.cc DEPS cinn_runner proto_desc) cc_test(cinn_runner_test SRCS cinn_runner_test.cc DEPS cinn_runner proto_desc)
cc_test(cinn_compiled_object_test SRCS cinn_compiled_object_test.cc DEPS cinn_compiled_object) cc_test(cinn_compiled_object_test SRCS cinn_compiled_object_test.cc DEPS cinn_compiled_object)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
using PbVarType = framework::proto::VarType;
namespace cpp = ::cinn::frontend::paddle::cpp;
::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarTypeToCinn(
const ::paddle::framework::proto::VarType::Type &type) {
#define SET_TYPE_CASE_ITEM(type__) \
case ::paddle::framework::proto::VarType::type__: \
return ::cinn::frontend::paddle::cpp::VarDescAPI::Type::type__; \
break;
switch (type) {
SET_TYPE_CASE_ITEM(LOD_TENSOR);
SET_TYPE_CASE_ITEM(LOD_TENSOR_ARRAY);
SET_TYPE_CASE_ITEM(LOD_RANK_TABLE);
SET_TYPE_CASE_ITEM(SELECTED_ROWS);
SET_TYPE_CASE_ITEM(FEED_MINIBATCH);
SET_TYPE_CASE_ITEM(FETCH_LIST);
SET_TYPE_CASE_ITEM(STEP_SCOPES);
SET_TYPE_CASE_ITEM(PLACE_LIST);
SET_TYPE_CASE_ITEM(READER);
default:
PADDLE_THROW(platform::errors::NotFound("Cannot found var type"));
}
#undef SET_TYPE_CASE_ITEM
}
::paddle::framework::proto::VarType::Type TransformVarTypeFromCinn(
const ::cinn::frontend::paddle::cpp::VarDescAPI::Type &type) {
#define SET_TYPE_CASE_ITEM(type__) \
case ::cinn::frontend::paddle::cpp::VarDescAPI::Type::type__: \
return ::paddle::framework::proto::VarType::type__; \
break;
switch (type) {
SET_TYPE_CASE_ITEM(LOD_TENSOR);
SET_TYPE_CASE_ITEM(LOD_TENSOR_ARRAY);
SET_TYPE_CASE_ITEM(LOD_RANK_TABLE);
SET_TYPE_CASE_ITEM(SELECTED_ROWS);
SET_TYPE_CASE_ITEM(FEED_MINIBATCH);
SET_TYPE_CASE_ITEM(FETCH_LIST);
SET_TYPE_CASE_ITEM(STEP_SCOPES);
SET_TYPE_CASE_ITEM(PLACE_LIST);
SET_TYPE_CASE_ITEM(READER);
default:
PADDLE_THROW(platform::errors::NotFound("Cannot found var type"));
}
#undef SET_TYPE_CASE_ITEM
}
::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn(
const ::paddle::framework::proto::VarType::Type &type) {
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case ::paddle::framework::proto::VarType::type__: \
return ::cinn::frontend::paddle::cpp::VarDescAPI::Type::type__; \
break;
switch (type) {
SET_DATA_TYPE_CASE_ITEM(BOOL);
SET_DATA_TYPE_CASE_ITEM(SIZE_T);
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
default:
PADDLE_THROW(platform::errors::NotFound("Cannot found var data type"));
}
#undef SET_DATA_TYPE_CASE_ITEM
}
::paddle::framework::proto::VarType::Type TransformVarDataTypeFromCpp(
const ::cinn::frontend::paddle::cpp::VarDescAPI::Type &type) {
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case ::cinn::frontend::paddle::cpp::VarDescAPI::Type::type__: \
return ::paddle::framework::proto::VarType::type__; \
break;
switch (type) {
SET_DATA_TYPE_CASE_ITEM(BOOL);
SET_DATA_TYPE_CASE_ITEM(SIZE_T);
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
default:
PADDLE_THROW(platform::errors::NotFound("Cannot found var data type"));
}
#undef SET_DATA_TYPE_CASE_ITEM
}
void TransformVarDescToCinn(framework::VarDesc *pb_desc,
cpp::VarDesc *cpp_desc) {
cpp_desc->SetName(pb_desc->Name());
cpp_desc->SetType(TransformVarTypeToCinn(pb_desc->GetType()));
cpp_desc->SetPersistable(pb_desc->Persistable());
if (pb_desc->Name() != "feed" && pb_desc->Name() != "fetch") {
cpp_desc->SetDataType(TransformVarDataTypeToCinn(pb_desc->GetDataType()));
cpp_desc->SetShape(pb_desc->GetShape());
}
}
void TransformVarDescFromCinn(const cpp::VarDesc &cpp_desc,
framework::VarDesc *pb_desc) {
pb_desc->Proto()->Clear();
pb_desc->SetName(cpp_desc.Name());
pb_desc->SetType(TransformVarTypeFromCinn(cpp_desc.GetType()));
pb_desc->SetPersistable(cpp_desc.Persistable());
if (cpp_desc.Name() != "feed" && cpp_desc.Name() != "fetch") {
pb_desc->SetShape(cpp_desc.GetShape());
pb_desc->SetDataType(TransformVarDataTypeFromCpp(cpp_desc.GetDataType()));
}
}
/// For OpDesc transform
void OpInputsToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) {
for (const std::string &param : pb_desc->InputNames()) {
cpp_desc->SetInput(param, pb_desc->Input(param));
}
}
void OpInputsFromCinn(const cpp::OpDesc &cpp_desc, framework::OpDesc *pb_desc) {
pb_desc->MutableInputs()->clear();
for (const std::string &param : cpp_desc.InputArgumentNames()) {
pb_desc->SetInput(param, cpp_desc.Input(param));
}
}
void OpOutputsToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) {
for (const std::string &param : pb_desc->OutputNames()) {
cpp_desc->SetOutput(param, pb_desc->Output(param));
}
}
void OpOutputsFromCinn(const cpp::OpDesc &cpp_desc,
framework::OpDesc *pb_desc) {
pb_desc->MutableOutputs()->clear();
for (const std::string &param : cpp_desc.OutputArgumentNames()) {
pb_desc->SetOutput(param, cpp_desc.Output(param));
}
}
void OpAttrsToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) {
using AttrType = framework::proto::AttrType;
auto set_attr = [&](const std::string &name, AttrType type) {
switch (type) {
#define IMPL_ONE(type__, T) \
case AttrType::type__: \
cpp_desc->SetAttr<T>(name, pb_desc->GetAttrIfExists<T>(name)); \
break;
IMPL_ONE(INT, int32_t);
IMPL_ONE(FLOAT, float);
IMPL_ONE(STRING, std::string);
IMPL_ONE(STRINGS, std::vector<std::string>);
IMPL_ONE(FLOATS, std::vector<float>);
IMPL_ONE(INTS, std::vector<int>);
IMPL_ONE(BOOLEAN, bool);
IMPL_ONE(LONG, int64_t);
IMPL_ONE(LONGS, std::vector<int64_t>);
case AttrType::BLOCK: {
auto i = pb_desc->GetAttrIfExists<int16_t>(name);
cpp_desc->SetAttr<int32_t>(name, i);
break;
}
default:
PADDLE_THROW(platform::errors::NotFound(
"Unsupported attr type %d found ", static_cast<int>(type)));
}
};
#undef IMPL_ONE
for (const auto &attr_name : pb_desc->AttrNames()) {
auto type = pb_desc->GetAttrType(attr_name);
set_attr(attr_name, type);
}
}
void OpAttrsFromCinn(const cpp::OpDesc &cpp_desc, framework::OpDesc *pb_desc) {
pb_desc->MutableAttrMap()->clear();
using AttrType = cpp::OpDescAPI::AttrType;
auto set_attr = [&](const std::string &name, AttrType type) {
switch (type) {
#define IMPL_ONE(type__, T) \
case AttrType::type__: \
pb_desc->SetAttr(name, cpp_desc.GetAttr<T>(name)); \
break;
IMPL_ONE(INT, int32_t);
IMPL_ONE(FLOAT, float);
IMPL_ONE(STRING, std::string);
IMPL_ONE(STRINGS, std::vector<std::string>);
IMPL_ONE(FLOATS, std::vector<float>);
IMPL_ONE(INTS, std::vector<int>);
IMPL_ONE(BOOLEAN, bool);
IMPL_ONE(LONG, int64_t);
IMPL_ONE(LONGS, std::vector<int64_t>);
default:
PADDLE_THROW(platform::errors::NotFound(
"Unsupported attr type %d found ", static_cast<int>(type)));
}
};
#undef IMPL_ONE
for (const auto &attr_name : cpp_desc.AttrNames()) {
auto type = cpp_desc.GetAttrType(attr_name);
set_attr(attr_name, type);
}
}
void TransformOpDescToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) {
cpp_desc->SetType(pb_desc->Type());
OpInputsToCinn(pb_desc, cpp_desc);
OpOutputsToCinn(pb_desc, cpp_desc);
OpAttrsToCinn(pb_desc, cpp_desc);
}
void TransformOpDescFromCinn(const cpp::OpDesc &cpp_desc,
framework::OpDesc *pb_desc) {
pb_desc->Proto()->Clear();
pb_desc->SetType(cpp_desc.Type());
OpInputsFromCinn(cpp_desc, pb_desc);
OpOutputsFromCinn(cpp_desc, pb_desc);
OpAttrsFromCinn(cpp_desc, pb_desc);
}
/// For BlockDesc transform
void TransformBlockDescToCinn(framework::BlockDesc *pb_desc,
cpp::BlockDesc *cpp_desc) {
cpp_desc->SetIdx(pb_desc->ID());
cpp_desc->SetParentIdx(pb_desc->Parent());
cpp_desc->SetForwardBlockIdx(pb_desc->ForwardBlockID());
cpp_desc->ClearOps();
const auto &all_ops = pb_desc->AllOps();
for (const auto &op : all_ops) {
auto *cpp_op_desc = cpp_desc->AddOp<cpp::OpDesc>();
TransformOpDescToCinn(op, cpp_op_desc);
}
cpp_desc->ClearVars();
const auto &all_vars = pb_desc->AllVars();
for (const auto &var : all_vars) {
auto *cpp_var_desc = cpp_desc->AddVar<cpp::VarDesc>();
TransformVarDescToCinn(var, cpp_var_desc);
}
}
void TransformBlockDescFromCinn(const cpp::BlockDesc &cpp_desc,
framework::BlockDesc *pb_desc) {
pb_desc->Proto()->Clear();
pb_desc->Proto()->set_idx(cpp_desc.Idx());
pb_desc->Proto()->set_parent_idx(cpp_desc.ParentIdx());
pb_desc->Proto()->set_forward_block_idx(cpp_desc.ForwardBlockIdx());
for (size_t i = 0; i < cpp_desc.OpsSize(); ++i) {
const auto &cpp_op_desc =
cpp_desc.template GetConstOp<cpp::OpDesc>(static_cast<int32_t>(i));
auto *pb_op_desc = pb_desc->AppendOp();
TransformOpDescFromCinn(cpp_op_desc, pb_op_desc);
}
for (size_t i = 0; i < cpp_desc.VarsSize(); ++i) {
const auto &cpp_var_desc =
cpp_desc.template GetConstVar<cpp::VarDesc>(static_cast<int32_t>(i));
auto *pb_var_desc = pb_desc->Var(cpp_var_desc.Name());
TransformVarDescFromCinn(cpp_var_desc, pb_var_desc);
}
}
/// For ProgramDesc transform
void TransformProgramDescToCinn(framework::ProgramDesc *pb_desc,
cpp::ProgramDesc *cpp_desc) {
if (pb_desc->Proto()->version().has_version()) {
cpp_desc->SetVersion(pb_desc->Version());
}
cpp_desc->ClearBlocks();
for (size_t i = 0; i < pb_desc->Size(); ++i) {
auto *pb_block_desc = pb_desc->MutableBlock(i);
auto *cpp_block_desc = cpp_desc->AddBlock<cpp::BlockDesc>();
TransformBlockDescToCinn(pb_block_desc, cpp_block_desc);
}
}
void TransformProgramDescFromCinn(const cpp::ProgramDesc &cpp_desc,
framework::ProgramDesc *pb_desc) {
pb_desc->Proto()->Clear();
if (cpp_desc.HasVersion()) {
pb_desc->SetVersion(cpp_desc.Version());
}
// For paddle proto program, the only way to add block is invoke
// AppendBlock(),
// the AppendBlock need one necessary parameter: const BlockDesc &parent,
// but the only function of parent is set the block's parent_idx value.
// Meanwhile a program has at least one block, so we set block0 to all
// sub-block's parent in initial and cannot remove.
// Don't worry, it will be change in "TransformBlockDescFromCinn".
auto *block0 = pb_desc->MutableBlock(0);
for (size_t i = 0; i < cpp_desc.BlocksSize(); ++i) {
const auto &cpp_block_desc = cpp_desc.GetConstBlock<cpp::BlockDesc>(i);
framework::BlockDesc *pb_block_desc = nullptr;
if (i < pb_desc->Size()) {
pb_block_desc = pb_desc->MutableBlock(i);
} else {
pb_block_desc = pb_desc->AppendBlock(*block0);
}
TransformBlockDescFromCinn(cpp_block_desc, pb_block_desc);
}
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "cinn/frontend/paddle/cpp/block_desc.h"
#include "cinn/frontend/paddle/cpp/desc_api.h"
#include "cinn/frontend/paddle/cpp/op_desc.h"
#include "cinn/frontend/paddle/cpp/program_desc.h"
#include "cinn/frontend/paddle/cpp/var_desc.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarTypeToCinn(
const ::paddle::framework::proto::VarType::Type& type);
::paddle::framework::proto::VarType::Type TransformVarTypeFromCinn(
const ::cinn::frontend::paddle::cpp::VarDescAPI::Type& type);
::cinn::frontend::paddle::cpp::VarDescAPI::Type TransformVarDataTypeToCinn(
const ::paddle::framework::proto::VarType::Type& type);
::paddle::framework::proto::VarType::Type TransformVarDataTypeFromCpp(
const ::cinn::frontend::paddle::cpp::VarDescAPI::Type& type);
// Why use framework::VarDesc* rather than const framework::VarDesc& here?
// framework::VarDesc lack of many API like clear(), etc. On the other hand,
// the paddle node return framework::Desc* even if the node is const
void TransformVarDescToCinn(framework::VarDesc* pb_desc,
::cinn::frontend::paddle::cpp::VarDesc* cpp_desc);
void TransformVarDescFromCinn(
const ::cinn::frontend::paddle::cpp::VarDesc& cpp_desc,
framework::VarDesc* pb_desc);
void TransformOpDescToCinn(framework::OpDesc* pb_desc,
::cinn::frontend::paddle::cpp::OpDesc* cpp_desc);
void TransformOpDescFromCinn(
const ::cinn::frontend::paddle::cpp::OpDesc& cpp_desc,
framework::OpDesc* pb_desc);
void TransformBlockDescToCinn(
framework::BlockDesc* pb_desc,
::cinn::frontend::paddle::cpp::BlockDesc* cpp_desc);
void TransformBlockDescFromCinn(
const ::cinn::frontend::paddle::cpp::BlockDesc& cpp_desc,
framework::BlockDesc* pb_desc);
void TransformProgramDescToCinn(
framework::ProgramDesc* pb_desc,
::cinn::frontend::paddle::cpp::ProgramDesc* cpp_desc);
void TransformProgramDescFromCinn(
const ::cinn::frontend::paddle::cpp::ProgramDesc& cpp_desc,
framework::ProgramDesc* pb_desc);
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
using PbVarType = framework::proto::VarType;
namespace cpp = ::cinn::frontend::paddle::cpp;
// check VarDesc
cpp::VarDesc CreateCppVarDesc() {
cpp::VarDesc var("test");
var.SetType(cpp::VarDescAPI::Type::LOD_TENSOR);
var.SetPersistable(true);
var.SetDataType(cpp::VarDescAPI::Type::FP32);
var.SetShape({100, 200, 300});
return var;
}
framework::VarDesc CreatePbVarDesc() {
framework::VarDesc var("test");
var.SetType(PbVarType::LOD_TENSOR);
var.SetPersistable(true);
var.SetDataType(PbVarType::FP32);
var.SetShape({100, 200, 300});
return var;
}
TEST(TransformVarDesc, cpp2pb) {
auto cpp_var = CreateCppVarDesc();
framework::VarDesc pb_var("init");
TransformVarDescFromCinn(cpp_var, &pb_var);
auto correct_var = CreatePbVarDesc();
ASSERT_EQ(pb_var.Name(), correct_var.Name());
ASSERT_EQ(pb_var.GetType(), correct_var.GetType());
ASSERT_EQ(pb_var.Persistable(), correct_var.Persistable());
ASSERT_EQ(pb_var.GetDataType(), correct_var.GetDataType());
ASSERT_EQ(pb_var.GetShape(), correct_var.GetShape());
}
TEST(TransformVarDesc, pb2cpp) {
auto pb_var = CreatePbVarDesc();
cpp::VarDesc cpp_var;
TransformVarDescToCinn(&pb_var, &cpp_var);
auto correct_var = CreateCppVarDesc();
ASSERT_EQ(cpp_var.Name(), correct_var.Name());
ASSERT_EQ(cpp_var.GetType(), correct_var.GetType());
ASSERT_EQ(cpp_var.Persistable(), correct_var.Persistable());
ASSERT_EQ(cpp_var.GetDataType(), correct_var.GetDataType());
ASSERT_EQ(cpp_var.GetShape(), correct_var.GetShape());
}
// check OpDesc
cpp::OpDesc CreateCppOpDesc() {
cpp::OpDesc op;
op.SetType("test");
op.SetInput("X", {"x1"});
op.SetInput("Y", {"y1", "y2"});
op.SetOutput("Out", {"out1"});
op.SetAttr<float>("attr_f", 0.1f);
op.SetAttr<std::string>("attr_str", "test_attr");
return op;
}
framework::OpDesc CreatePbOpDesc() {
framework::OpDesc op;
op.SetType("test");
op.SetInput("X", {"x1"});
op.SetInput("Y", {"y1", "y2"});
op.SetOutput("Out", {"out1"});
op.SetAttr("attr_f", 0.1f);
op.SetAttr("attr_str", std::string("test_attr"));
return op;
}
TEST(TransformOpDesc, cpp2pb) {
auto cpp_op = CreateCppOpDesc();
framework::OpDesc pb_op;
TransformOpDescFromCinn(cpp_op, &pb_op);
auto correct_op = CreatePbOpDesc();
ASSERT_EQ(pb_op.Type(), correct_op.Type());
ASSERT_EQ(pb_op.Inputs(), correct_op.Inputs());
ASSERT_EQ(pb_op.Outputs(), correct_op.Outputs());
ASSERT_EQ(pb_op.AttrNames(), correct_op.AttrNames());
for (const auto &attr_name : pb_op.AttrNames()) {
ASSERT_EQ(pb_op.GetAttrType(attr_name), correct_op.GetAttrType(attr_name));
}
ASSERT_EQ(pb_op.GetAttrIfExists<float>("attr_f"),
correct_op.GetAttrIfExists<float>("attr_f"));
ASSERT_EQ(pb_op.GetAttrIfExists<std::string>("attr_str"),
correct_op.GetAttrIfExists<std::string>("attr_str"));
}
TEST(TransformOpDesc, pb2cpp) {
auto pb_op = CreatePbOpDesc();
cpp::OpDesc cpp_op;
TransformOpDescToCinn(&pb_op, &cpp_op);
auto correct_op = CreateCppOpDesc();
ASSERT_EQ(cpp_op.Type(), correct_op.Type());
ASSERT_EQ(cpp_op.inputs(), correct_op.inputs());
ASSERT_EQ(cpp_op.outputs(), correct_op.outputs());
ASSERT_EQ(cpp_op.AttrNames(), correct_op.AttrNames());
ASSERT_EQ(cpp_op.attr_types(), correct_op.attr_types());
ASSERT_EQ(cpp_op.GetAttr<float>("attr_f"),
correct_op.GetAttr<float>("attr_f"));
ASSERT_EQ(cpp_op.GetAttr<std::string>("attr_str"),
correct_op.GetAttr<std::string>("attr_str"));
}
// check BlockDesc
// framework::BlockDesc is DISABLE_COPY_AND_ASSIGN, so can not return
void CreateCppBlockDesc(cpp::BlockDesc *block) {
block->SetIdx(42);
block->SetParentIdx(4);
block->SetForwardBlockIdx(32);
auto *op = block->AddOp<cpp::OpDesc>();
*op = CreateCppOpDesc();
auto *var = block->AddVar<cpp::VarDesc>();
*var = CreateCppVarDesc();
}
void CreatePbBlockDesc(framework::BlockDesc *block) {
block->Proto()->set_idx(42);
block->Proto()->set_parent_idx(4);
block->Proto()->set_forward_block_idx(32);
auto *op = block->AppendOp();
*op = CreatePbOpDesc();
auto *var = block->Var("init");
*var = CreatePbVarDesc();
}
TEST(TransformBlockDesc, cpp2pb) {
cpp::BlockDesc cpp_block;
CreateCppBlockDesc(&cpp_block);
framework::ProgramDesc pb_prog;
auto *pb_block = pb_prog.MutableBlock(0);
TransformBlockDescFromCinn(cpp_block, pb_block);
framework::ProgramDesc correct_prog;
auto *correct_block = correct_prog.MutableBlock(0);
CreatePbBlockDesc(correct_block);
ASSERT_EQ(pb_block->ID(), correct_block->ID());
ASSERT_EQ(pb_block->Parent(), correct_block->Parent());
ASSERT_EQ(pb_block->ForwardBlockID(), correct_block->ForwardBlockID());
ASSERT_EQ(pb_block->OpSize(), correct_block->OpSize());
ASSERT_EQ(pb_block->AllVars().size(), correct_block->AllVars().size());
}
TEST(TransformBlockDesc, pb2cpp) {
framework::ProgramDesc pb_prog;
auto *pb_block = pb_prog.MutableBlock(0);
CreatePbBlockDesc(pb_block);
cpp::BlockDesc cpp_block;
TransformBlockDescToCinn(pb_block, &cpp_block);
cpp::BlockDesc correct_block;
CreateCppBlockDesc(&correct_block);
ASSERT_EQ(cpp_block.Idx(), correct_block.Idx());
ASSERT_EQ(cpp_block.ParentIdx(), correct_block.ParentIdx());
ASSERT_EQ(cpp_block.ForwardBlockIdx(), correct_block.ForwardBlockIdx());
ASSERT_EQ(cpp_block.OpsSize(), correct_block.OpsSize());
ASSERT_EQ(cpp_block.VarsSize(), correct_block.VarsSize());
}
// check ProgramDesc
cpp::ProgramDesc CreateCppProgramDesc() {
cpp::ProgramDesc prog;
prog.SetVersion(22);
auto *block = prog.AddBlock<cpp::BlockDesc>();
CreateCppBlockDesc(block);
return prog;
}
framework::ProgramDesc CreatePbProgramDesc() {
framework::ProgramDesc prog;
prog.SetVersion(22);
auto *block = prog.MutableBlock(0);
CreatePbBlockDesc(block);
return prog;
}
TEST(TransformProgramDesc, cpp2pb) {
auto cpp_prog = CreateCppProgramDesc();
framework::ProgramDesc pb_prog;
TransformProgramDescFromCinn(cpp_prog, &pb_prog);
auto correct_prog = CreatePbProgramDesc();
ASSERT_EQ(pb_prog.Version(), correct_prog.Version());
ASSERT_EQ(pb_prog.Size(), correct_prog.Size());
}
TEST(TransformProgramDesc, pb2cpp) {
auto pb_prog = CreatePbProgramDesc();
cpp::ProgramDesc cpp_prog;
TransformProgramDescToCinn(&pb_prog, &cpp_prog);
auto correct_prog = CreateCppProgramDesc();
ASSERT_EQ(cpp_prog.Version(), correct_prog.Version());
ASSERT_EQ(cpp_prog.BlocksSize(), correct_prog.BlocksSize());
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册