Unverified commit 6a94adbe, authored by TeFeng Chen, committed by GitHub

add check of data type and support mutable_data with compiled infos (#40920)

* support checking data type and mutable_data with compiled infos in paddle with cinn

* update cinn_instruction_run_op_test with multi data type
Parent b8236b7b
......@@ -26,7 +26,7 @@ add_definitions(-w)
######################################
include(ExternalProject)
set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN)
set(CINN_GIT_TAG 56879b637e2c4db19091eedad03d7cc674e092a2)
set(CINN_GIT_TAG e11c5e672f9961e28cfa403d86f99808beb58817)
set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION}
-DWITH_CUDA=${WITH_GPU}
-DWITH_CUDNN=${WITH_GPU}
......
cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(transform_type SRCS transform_type.cc DEPS errors enforce cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS framework_proto graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn cinn_launch_context)
......@@ -17,6 +18,9 @@ if (WITH_TESTING)
cc_test(transform_desc_test SRCS transform_desc_test.cc DEPS transform_desc)
set_tests_properties(transform_desc_test PROPERTIES LABELS "RUN_TYPE=CINN")
cc_test(transform_type_test SRCS transform_type_test.cc DEPS transform_type)
set_tests_properties(transform_type_test PROPERTIES LABELS "RUN_TYPE=CINN")
cc_test(cinn_graph_symbolization_test SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization)
set_tests_properties(cinn_graph_symbolization_test PROPERTIES LABELS "RUN_TYPE=CINN")
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "cinn/common/type.h"
#include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle::framework::paddle2cinn {
::phi::DataType TransToPaddleDataType(const ::cinn::common::Type& type) {
#define SET_TYPE_CASE_ITEM(cinn_common_type, pd_type) \
if (type == ::cinn::common::cinn_common_type()) { \
return ::phi::DataType::pd_type; \
}
SET_TYPE_CASE_ITEM(Bool, BOOL)
SET_TYPE_CASE_ITEM(I8, INT8)
SET_TYPE_CASE_ITEM(I16, INT16)
SET_TYPE_CASE_ITEM(I32, INT32)
SET_TYPE_CASE_ITEM(I64, INT64)
SET_TYPE_CASE_ITEM(UI8, UINT8)
SET_TYPE_CASE_ITEM(UI16, UINT16)
SET_TYPE_CASE_ITEM(UI32, UINT32)
SET_TYPE_CASE_ITEM(UI64, UINT64)
SET_TYPE_CASE_ITEM(F16, FLOAT16)
SET_TYPE_CASE_ITEM(F32, FLOAT32)
SET_TYPE_CASE_ITEM(F64, FLOAT64)
PADDLE_THROW(
platform::errors::Unimplemented("Type(%s) not supported yet", type));
return ::phi::DataType::UNDEFINED;
#undef SET_TYPE_CASE_ITEM
}
::phi::DataType TransToPaddleDataType(const cinn_type_t& type) {
#define SET_TYPE_CASE_ITEM(cinn_runtime_type, pd_type) \
if (type == cinn_runtime_type()) { \
return ::phi::DataType::pd_type; \
}
SET_TYPE_CASE_ITEM(cinn_bool_t, BOOL)
SET_TYPE_CASE_ITEM(cinn_int8_t, INT8)
SET_TYPE_CASE_ITEM(cinn_int32_t, INT32)
SET_TYPE_CASE_ITEM(cinn_int64_t, INT64)
SET_TYPE_CASE_ITEM(cinn_uint32_t, UINT32)
SET_TYPE_CASE_ITEM(cinn_uint64_t, UINT64)
SET_TYPE_CASE_ITEM(cinn_float32_t, FLOAT32)
SET_TYPE_CASE_ITEM(cinn_float64_t, FLOAT64)
PADDLE_THROW(platform::errors::Unimplemented("Input type not supported yet"));
return ::phi::DataType::UNDEFINED;
#undef SET_TYPE_CASE_ITEM
}
} // namespace paddle::framework::paddle2cinn
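For orientation, a minimal usage sketch of the two overloads above; it is not part of this commit, and the function name UsageSketch is hypothetical:

#include "cinn/common/type.h"
#include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"

// Hypothetical illustration: both overloads map CINN types to phi dtypes
// and throw platform::errors::Unimplemented for unsupported types.
void UsageSketch() {
  namespace p2c = paddle::framework::paddle2cinn;
  // compile-time CINN type -> phi dtype
  ::phi::DataType a = p2c::TransToPaddleDataType(::cinn::common::F32());
  // runtime CINN type -> phi dtype
  ::phi::DataType b = p2c::TransToPaddleDataType(cinn_int32_t());
  // a == ::phi::DataType::FLOAT32, b == ::phi::DataType::INT32
  (void)a;
  (void)b;
}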
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/data_type.h"
// forward type declarations
struct cinn_type_t;
namespace cinn::common {
struct Type;
} // namespace cinn::common
namespace paddle::framework::paddle2cinn {
::phi::DataType TransToPaddleDataType(const ::cinn::common::Type& type);
::phi::DataType TransToPaddleDataType(const cinn_type_t& type);
} // namespace paddle::framework::paddle2cinn
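The forward declarations above compile because both overloads take the CINN types by const reference, so no definition is needed in the header. The same pattern in isolation, with a hypothetical type:

// Hypothetical illustration of the forward-declaration pattern used above:
// a const-reference parameter needs only a declaration at the signature
// site; the full definition is required only where the object is inspected.
struct Opaque;                     // forward declaration, no definition yet
int Inspect(const Opaque& value);  // compiles: only the name is needed

// in some .cc file:
// struct Opaque { int x; };
// int Inspect(const Opaque& value) { return value.x; }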
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "cinn/common/type.h"
#include "cinn/runtime/cinn_runtime.h"
#include "gtest/gtest.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle::framework::paddle2cinn {
TEST(TransToPaddleDataType, common_type) {
ASSERT_EQ(::phi::DataType::BOOL,
TransToPaddleDataType(::cinn::common::Bool()));
ASSERT_EQ(::phi::DataType::INT8, TransToPaddleDataType(::cinn::common::I8()));
ASSERT_EQ(::phi::DataType::INT16,
TransToPaddleDataType(::cinn::common::I16()));
ASSERT_EQ(::phi::DataType::INT32,
TransToPaddleDataType(::cinn::common::I32()));
ASSERT_EQ(::phi::DataType::INT64,
TransToPaddleDataType(::cinn::common::I64()));
ASSERT_EQ(::phi::DataType::UINT8,
TransToPaddleDataType(::cinn::common::UI8()));
ASSERT_EQ(::phi::DataType::UINT16,
TransToPaddleDataType(::cinn::common::UI16()));
ASSERT_EQ(::phi::DataType::UINT32,
TransToPaddleDataType(::cinn::common::UI32()));
ASSERT_EQ(::phi::DataType::UINT64,
TransToPaddleDataType(::cinn::common::UI64()));
ASSERT_EQ(::phi::DataType::FLOAT16,
TransToPaddleDataType(::cinn::common::F16()));
ASSERT_EQ(::phi::DataType::FLOAT32,
TransToPaddleDataType(::cinn::common::F32()));
ASSERT_EQ(::phi::DataType::FLOAT64,
TransToPaddleDataType(::cinn::common::F64()));
ASSERT_THROW(TransToPaddleDataType(::cinn::common::Type()),
paddle::platform::EnforceNotMet);
}
TEST(TransToPaddleDataType, runtime_type) {
ASSERT_EQ(::phi::DataType::BOOL, TransToPaddleDataType(cinn_bool_t()));
ASSERT_EQ(::phi::DataType::INT8, TransToPaddleDataType(cinn_int8_t()));
ASSERT_EQ(::phi::DataType::INT32, TransToPaddleDataType(cinn_int32_t()));
ASSERT_EQ(::phi::DataType::INT64, TransToPaddleDataType(cinn_int64_t()));
ASSERT_EQ(::phi::DataType::UINT32, TransToPaddleDataType(cinn_uint32_t()));
ASSERT_EQ(::phi::DataType::UINT64, TransToPaddleDataType(cinn_uint64_t()));
ASSERT_EQ(::phi::DataType::FLOAT32, TransToPaddleDataType(cinn_float32_t()));
ASSERT_EQ(::phi::DataType::FLOAT64, TransToPaddleDataType(cinn_float64_t()));
ASSERT_THROW(TransToPaddleDataType(cinn_type_t()),
paddle::platform::EnforceNotMet);
}
} // namespace paddle::framework::paddle2cinn
include(operators)
cc_library(cinn_op_helper SRCS cinn_op_helper.cc DEPS operator device_context)
cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope proto_desc graph build_strategy device_context parallel_executor cinn)
cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope proto_desc graph build_strategy device_context parallel_executor transform_type cinn)
SET(CINN_OP_DEPS parallel_executor string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context)
SET(CINN_OP_DEPS parallel_executor string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context transform_type)
register_operators(DEPS ${CINN_OP_DEPS})
if (WITH_TESTING)
......
......@@ -47,50 +47,17 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel {
}
protected:
/* [Why use single type kernel]:
*
* The kernel's data type (int, float, or any other) has no effect on
* its execution logic, so a single data type is specified here
* directly.
*
*/
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// Why do we need to override GetExpectedKernelType?
// A cinn-graph may have no input var; if we used the base function,
// it would check whether the input tensors are initialized. Here we
// rewrite the function so that we can infer the kernel type from the
// output data type.
if (ctx.InputSize(kX)) {
// if the instruction has inputs, infer the kernel type from the input
// data type:
return OperatorWithKernel::GetExpectedKernelType(ctx);
}
// Otherwise, infer the kernel type from the output data type:
// `OutputVar` will check that kOutputs has exactly one output var
const framework::Variable* var = ctx.OutputVar(kOutputs);
PADDLE_ENFORCE_NE(
var, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Variable should not empty."));
const framework::Tensor* tensor = nullptr;
if (var->IsType<framework::Tensor>()) {
tensor = &var->Get<framework::Tensor>();
} else if (var->IsType<framework::LoDTensor>()) {
tensor = &var->Get<framework::LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
tensor = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<framework::LoDTensorArray>()) {
auto t_arr = &var->Get<framework::LoDTensorArray>();
PADDLE_ENFORCE_EQ(t_arr->size(), 1UL,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op should just has One "
"Output when Input empty."));
tensor = &(t_arr->front());
}
PADDLE_ENFORCE_NE(
tensor, nullptr,
platform::errors::InvalidArgument(
"The cinn_instruction_run Op's Output Tensor should not empty."));
VLOG(4) << "The tensor [" << ctx.OutputName(kOutputs) << "]'s dtype is "
<< paddle::framework::DataType2String(tensor->dtype());
auto output_type = paddle::framework::TransToProtoVarType(tensor->dtype());
return framework::OpKernelType(output_type, ctx.device_context());
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};
......@@ -151,8 +118,4 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
cinn_instruction_run,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, bool>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, int>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, int64_t>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, float>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, double>);
ops::CinnInstructionRunOpKernel<CPUDeviceContext, float>);
......@@ -17,10 +17,7 @@ limitations under the License. */
namespace ops = paddle::operators;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
/* see [Why use single type kernel] */
REGISTER_OP_CUDA_KERNEL(
cinn_instruction_run,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, bool>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, int>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, int64_t>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, float>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, double>);
ops::CinnInstructionRunOpKernel<CUDADeviceContext, float>);
......@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
......@@ -57,8 +58,9 @@ class CinnInstructionRunOpKernel : public framework::OpKernel<T> {
cinn_buffer_t* buffer = launch_context->GetCinnBufferOfVar(var_name);
framework::Variable* var = ctx.scope().GetVar(var_name);
auto* tensor = var->template GetMutable<framework::LoDTensor>();
buffer->memory =
reinterpret_cast<uint8_t*>(tensor->mutable_data<T>(ctx.GetPlace()));
buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
ctx.GetPlace(),
framework::paddle2cinn::TransToPaddleDataType(buffer->type)));
};
std::vector<std::string> in_args = ctx.InputNames(kX);
std::for_each(in_args.begin(), in_args.end(), share_argument_buffer_fn);
......
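The kernel change above is the core of this commit: buffer memory is now allocated with the dtype recorded in the compiled cinn_buffer_t rather than the kernel's template parameter T, so a single kernel instantiation can serve arguments of any compiled dtype. A hedged sketch of the binding logic in isolation; the helper name BindBufferToTensor is hypothetical:

#include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "paddle/fluid/platform/place.h"

// Hypothetical helper mirroring the lambda above: point the CINN buffer
// at tensor memory allocated with the buffer's own compiled dtype.
void BindBufferToTensor(cinn_buffer_t* buffer,
                        paddle::framework::LoDTensor* tensor,
                        const paddle::platform::Place& place) {
  buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
      place,
      paddle::framework::paddle2cinn::TransToPaddleDataType(buffer->type)));
}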
......@@ -46,20 +46,15 @@ TEST(CinnInstructionOpTest, TestWithElementwiseAdd) {
auto compilation_key = CinnCompiler::GetInstance()->AddGraph(
CreateOnlyElementwiseAddGraph("x", "y", test_op_out_name));
// create a cinn_launch_op and run it first to launch the compilation
// of the above graph and cache the compiled object in CinnCompiler
auto cinn_launch_op = paddle::framework::OpRegistry::CreateOp(
"cinn_launch", {{"X", {"x", "y"}}}, {{"Out", {test_op_out_name}}},
{{"compilation_key", compilation_key}});
// create cinn_instruction_run_op and elementwise_add op
// create necessary ops
auto cinn_instruction_run_op = paddle::framework::OpRegistry::CreateOp(
"cinn_instruction_run", {{"X", {"x", "y"}}},
{{"Out", {test_op_out_name}}},
{{"cached_index", 0}, {"instruction_index", 0}});
auto elementwise_add_op = paddle::framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {"x"}}, {"Y", {"y"}}},
{{"Out", {add_op_out_name}}}, {{}});
auto cinn_launch_op = paddle::framework::OpRegistry::CreateOp(
"cinn_launch", {{"X", {"x", "y"}}}, {{"Out", {test_op_out_name}}},
{{"compilation_key", compilation_key}});
// check case: no compiled object has been cached before cinn_launch_op
// runs, so a cinn_instruction_run_op will throw an error
......@@ -69,18 +64,48 @@ TEST(CinnInstructionOpTest, TestWithElementwiseAdd) {
scope.Var(test_op_out_name)->GetMutable<LoDTensor>();
ASSERT_THROW(cinn_instruction_run_op->Run(scope, place),
paddle::platform::EnforceNotMet);
// run cinn_launch_op first to launch the compilation
// of the above graph and cache two compiled results,
// one of type float and one of type int
cinn_launch_op->Run(scope, place);
scope.EraseVars({"x", "y", test_op_out_name});
scope.Var(test_op_out_name)->GetMutable<LoDTensor>();
InitVariablesWithRandomValue<int>({"x", "y"}, {30, 40}, place, &scope);
cinn_launch_op->Run(scope, place);
// Run ops and check the computation results
auto run_and_check_fn = [&](const platform::Place& place) {
framework::Scope scope;
InitVariablesWithRandomValue<float>({"x", "y"}, {10, 20}, place, &scope);
scope.Var(test_op_out_name)->GetMutable<LoDTensor>();
scope.Var(add_op_out_name)->GetMutable<LoDTensor>();
auto elementwise_add_op = paddle::framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {"x"}}, {"Y", {"y"}}},
{{"Out", {add_op_out_name}}}, {{}});
// 1. check on type float
InitVariablesWithRandomValue<float>({"x", "y"}, {10, 20}, place, &scope);
cinn_instruction_run_op->SetAttr("cached_index", 0);
cinn_instruction_run_op->Run(scope, place);
elementwise_add_op->Run(scope, place);
CompareOpResult<float>(scope.GetVar(test_op_out_name),
scope.GetVar(add_op_out_name));
// 2. check on type int to show that the cinn_instruction_run op
// can allocate mutable data according to the compiled result
scope.EraseVars({"x", "y", test_op_out_name, add_op_out_name});
scope.Var(test_op_out_name)->GetMutable<LoDTensor>();
scope.Var(add_op_out_name)->GetMutable<LoDTensor>();
InitVariablesWithRandomValue<int>({"x", "y"}, {30, 40}, place, &scope);
cinn_instruction_run_op->SetAttr("cached_index", 1);
cinn_instruction_run_op->Run(scope, place);
// need to reconstruct elementwise_add_op so it chooses a new kernel with type int
elementwise_add_op = paddle::framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {"x"}}, {"Y", {"y"}}},
{{"Out", {add_op_out_name}}}, {{}});
elementwise_add_op->Run(scope, place);
CompareOpResult<int>(scope.GetVar(test_op_out_name),
scope.GetVar(add_op_out_name));
};
// CPU
......
......@@ -22,12 +22,15 @@
#include "cinn/hlir/framework/scope.h"
#include "cinn/hlir/framework/tensor.h"
#include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/intrinsic.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
......@@ -211,10 +214,16 @@ void CinnLaunchContext::CheckTensorEquivalent(
PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims,
platform::errors::PreconditionNotMet(
"Tensors' shape in variable(%s) are not equivalent, "
"paddle's shape = [%s], but cinn's shape = [%s].",
"paddle is = [%s], but cinn is = [%s].",
var_name, paddle_tensor.dims(), cinn_dims));
// TODO(CtfGo): check the underlying data type after CINN is ready
auto cinn_dtype =
framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type());
PADDLE_ENFORCE_EQ(paddle_tensor.dtype(), cinn_dtype,
platform::errors::PreconditionNotMet(
"Tensors' dtype in variable(%s) are not equivalent, "
"paddle is = [%s], but cinn is = [%s].",
var_name, paddle_tensor.dtype(), cinn_dtype));
}
void CinnLaunchContext::InitializeArguments() {
......@@ -224,13 +233,15 @@ void CinnLaunchContext::InitializeArguments() {
// assign dimensions with corresponding compiled tensor
cinn_buffer->resize(cinn_tensor->shape().data().data(),
cinn_tensor->shape().data().size());
cinn_buffer->type = cinn::runtime::ToRuntimeType(cinn_tensor->type());
VLOG(4) << string::Sprintf(
"Append an argument:name(%s),dims(%s),argument size:(%lu)", arg,
"Append an argument:name(%s),dims(%s),type(%s)",
framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(),
name2argument_.size());
cinn_tensor->type());
name2argument_.emplace(arg, cinn_buffer.get());
hold_buffers_.emplace_back(std::move(cinn_buffer));
}
VLOG(4) << "Total argument size:" << name2argument_.size();
}
void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) {
......@@ -325,9 +336,8 @@ framework::ProgramDesc CinnLaunchContext::BuildCompiledProgram(
}
auto cinn_tensor = GetCinnTensorOfVar(var_name);
// TODO(CtfGo): set the corresponding data type after CINN is ready;
// currently set to FP32 by default
var_desc->SetDataType(framework::proto::VarType::FP32);
var_desc->SetDataType(framework::TransToProtoVarType(
framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type())));
var_desc->SetShape(std::vector<int64_t>(cinn_tensor->shape().data().begin(),
cinn_tensor->shape().data().end()));
}
......
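The double conversion above (CINN type -> phi::DataType -> proto VarType) is the full chain from compiled type info into the generated program desc. A minimal hedged sketch of that chain; the wrapper name CinnToProtoVarType is hypothetical:

#include "cinn/common/type.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"

// Hypothetical wrapper for the conversion chain used above:
// CINN compile-time type -> phi dtype -> legacy proto var type.
paddle::framework::proto::VarType::Type CinnToProtoVarType(
    const ::cinn::common::Type& type) {
  return paddle::framework::TransToProtoVarType(
      paddle::framework::paddle2cinn::TransToPaddleDataType(type));
}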
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <set>
#include <utility>
#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/instruction.h"
#include "cinn/hlir/framework/scope.h"
......@@ -93,15 +94,19 @@ CinnCompiledObject* InitDefaultCompiledObject() {
std::call_once(initialized, [result = compiled_obj.get()]() {
auto& scope = result->scope;
scope = std::make_shared<CinnScope>();
scope->Var<CinnTensor>("cinn_var1");
std::vector<std::string> cinn_vars(
{"cinn_var1", "cinn_var2", "cinn_var3", "cinn_var4", "cinn_var5"});
// initialize variables and set their data type
for (const auto& var_name : cinn_vars) {
scope->Var<CinnTensor>(var_name);
scope->GetTensor(var_name)->set_type(::cinn::common::F32());
}
scope->GetTensor("cinn_var1")->Resize(CinnShape({3, 4}));
scope->Var<CinnTensor>("cinn_var2");
scope->GetTensor("cinn_var2")->Resize(CinnShape({6, 7, 8}));
scope->Var<CinnTensor>("cinn_var3");
scope->GetTensor("cinn_var3")->Resize(CinnShape({10, 16}));
scope->Var<CinnTensor>("cinn_var4");
scope->GetTensor("cinn_var4")->Resize(CinnShape({10, 16}));
scope->Var<CinnTensor>("cinn_var5");
scope->GetTensor("cinn_var5")->Resize(CinnShape({10, 16}));
// input variables: var1, var2; output: var5
......@@ -182,13 +187,17 @@ TEST_F(CinnLaunchContextTest, TestConstructResult) {
TEST_F(CinnLaunchContextTest, TestCheckTensorEquivalent) {
platform::CPUPlace place;
framework::Scope scope;
launch_context->UpdateCapturedEnv(scope, place);
auto* tensor1 = scope.Var("var1")->GetMutable<LoDTensor>();
auto* tensor2 = scope.Var("var2")->GetMutable<LoDTensor>();
// CheckTensorEquivalent: tensor dimension not equivalent
// dimension not equivalent
tensor1->mutable_data<float>(phi::make_ddim({3, 5}), place);
ASSERT_THROW(launch_context->CheckTensorEquivalent("var1", *tensor1),
paddle::platform::EnforceNotMet);
// data type not equivalent
tensor2->mutable_data<int>(phi::make_ddim({6, 7, 8}), place);
ASSERT_THROW(launch_context->CheckTensorEquivalent("var2", *tensor2),
paddle::platform::EnforceNotMet);
}
TEST_F(CinnLaunchContextTest, TestBuildCompiledProgram) {
......