From 6a94adbe5d454a7e273912b65e9186bbb56bb71e Mon Sep 17 00:00:00 2001 From: TeFeng Chen Date: Sun, 27 Mar 2022 11:54:27 +0800 Subject: [PATCH] add check of data type and support mutable_data with compiled infos (#40920) * support check data type and mutable_data with compiled infos in paddle with cinn * update cinn_instruction_run_op_test with multi data type --- cmake/external/cinn.cmake | 2 +- .../framework/paddle2cinn/CMakeLists.txt | 4 ++ .../framework/paddle2cinn/transform_type.cc | 68 +++++++++++++++++++ .../framework/paddle2cinn/transform_type.h | 30 ++++++++ .../paddle2cinn/transform_type_test.cc | 64 +++++++++++++++++ paddle/fluid/operators/cinn/CMakeLists.txt | 4 +- .../operators/cinn/cinn_instruction_run_op.cc | 57 +++------------- .../cinn/cinn_instruction_run_op.cu.cc | 7 +- .../operators/cinn/cinn_instruction_run_op.h | 6 +- .../cinn/cinn_instruction_run_op_test.cc | 47 ++++++++++--- .../operators/cinn/cinn_launch_context.cc | 24 +++++-- .../cinn/cinn_launch_context_test.cc | 23 +++++-- 12 files changed, 254 insertions(+), 82 deletions(-) create mode 100644 paddle/fluid/framework/paddle2cinn/transform_type.cc create mode 100644 paddle/fluid/framework/paddle2cinn/transform_type.h create mode 100644 paddle/fluid/framework/paddle2cinn/transform_type_test.cc diff --git a/cmake/external/cinn.cmake b/cmake/external/cinn.cmake index d3f330ba9d..75df827a43 100644 --- a/cmake/external/cinn.cmake +++ b/cmake/external/cinn.cmake @@ -26,7 +26,7 @@ add_definitions(-w) ###################################### include(ExternalProject) set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN) -set(CINN_GIT_TAG 56879b637e2c4db19091eedad03d7cc674e092a2) +set(CINN_GIT_TAG e11c5e672f9961e28cfa403d86f99808beb58817) set(CINN_OPTIONAL_ARGS -DPY_VERSION=${PY_VERSION} -DWITH_CUDA=${WITH_GPU} -DWITH_CUDNN=${WITH_GPU} diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index a1d4eb20ff..75e258d147 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -1,6 +1,7 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc) cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce) cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn) +cc_library(transform_type SRCS transform_type.cc DEPS errors enforce cinn) cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn) cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS framework_proto graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn cinn_launch_context) @@ -17,6 +18,9 @@ if (WITH_TESTING) cc_test(transform_desc_test SRCS transform_desc_test.cc DEPS transform_desc) set_tests_properties(transform_desc_test PROPERTIES LABELS "RUN_TYPE=CINN") + cc_test(transform_type_test SRCS transform_type_test.cc DEPS transform_type) + set_tests_properties(transform_type_test PROPERTIES LABELS "RUN_TYPE=CINN") + cc_test(cinn_graph_symbolization_test SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization) set_tests_properties(cinn_graph_symbolization_test PROPERTIES LABELS "RUN_TYPE=CINN") diff --git a/paddle/fluid/framework/paddle2cinn/transform_type.cc b/paddle/fluid/framework/paddle2cinn/transform_type.cc new file mode 100644 index 0000000000..0e348084d2 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/transform_type.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/paddle2cinn/transform_type.h" +#include "cinn/common/type.h" +#include "cinn/runtime/cinn_runtime.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" + +namespace paddle::framework::paddle2cinn { + +::phi::DataType TransToPaddleDataType(const ::cinn::common::Type& type) { +#define SET_TYPE_CASE_ITEM(cinn_common_type, pd_type) \ + if (type == ::cinn::common::cinn_common_type()) { \ + return ::phi::DataType::pd_type; \ + } + + SET_TYPE_CASE_ITEM(Bool, BOOL) + SET_TYPE_CASE_ITEM(I8, INT8) + SET_TYPE_CASE_ITEM(I16, INT16) + SET_TYPE_CASE_ITEM(I32, INT32) + SET_TYPE_CASE_ITEM(I64, INT64) + SET_TYPE_CASE_ITEM(UI8, UINT8) + SET_TYPE_CASE_ITEM(UI16, UINT16) + SET_TYPE_CASE_ITEM(UI32, UINT32) + SET_TYPE_CASE_ITEM(UI64, UINT64) + SET_TYPE_CASE_ITEM(F16, FLOAT16) + SET_TYPE_CASE_ITEM(F32, FLOAT32) + SET_TYPE_CASE_ITEM(F64, FLOAT64) + + PADDLE_THROW( + platform::errors::Unimplemented("Type(%s) not supported yet", type)); + return ::phi::DataType::UNDEFINED; +#undef SET_TYPE_CASE_ITEM +} + +::phi::DataType TransToPaddleDataType(const cinn_type_t& type) { +#define SET_TYPE_CASE_ITEM(cinn_runtime_type, pd_type) \ + if (type == cinn_runtime_type()) { \ + return ::phi::DataType::pd_type; \ + } + + SET_TYPE_CASE_ITEM(cinn_bool_t, BOOL) + SET_TYPE_CASE_ITEM(cinn_int8_t, INT8) + SET_TYPE_CASE_ITEM(cinn_int32_t, INT32) + SET_TYPE_CASE_ITEM(cinn_int64_t, INT64) + SET_TYPE_CASE_ITEM(cinn_uint32_t, UINT32) + SET_TYPE_CASE_ITEM(cinn_uint64_t, UINT64) + SET_TYPE_CASE_ITEM(cinn_float32_t, FLOAT32) + SET_TYPE_CASE_ITEM(cinn_float64_t, FLOAT64) + + PADDLE_THROW(platform::errors::Unimplemented("Input type not supported yet")); + return ::phi::DataType::UNDEFINED; +#undef SET_TYPE_CASE_ITEM +} + +} // namespace paddle::framework::paddle2cinn diff --git a/paddle/fluid/framework/paddle2cinn/transform_type.h b/paddle/fluid/framework/paddle2cinn/transform_type.h new file mode 100644 index 0000000000..e44960abbd --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/transform_type.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/common/data_type.h" + +// type declaration forward +struct cinn_type_t; +namespace cinn::common { +struct Type; +} // ::cinn::common + +namespace paddle::framework::paddle2cinn { + +::phi::DataType TransToPaddleDataType(const ::cinn::common::Type& type); + +::phi::DataType TransToPaddleDataType(const cinn_type_t& type); + +} // namespace paddle::framework::paddle2cinn diff --git a/paddle/fluid/framework/paddle2cinn/transform_type_test.cc b/paddle/fluid/framework/paddle2cinn/transform_type_test.cc new file mode 100644 index 0000000000..6c5d360d34 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/transform_type_test.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/paddle2cinn/transform_type.h" +#include "cinn/common/type.h" +#include "cinn/runtime/cinn_runtime.h" +#include "gtest/gtest.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle::framework::paddle2cinn { + +TEST(TransToPaddleDataType, common_type) { + ASSERT_EQ(::phi::DataType::BOOL, + TransToPaddleDataType(::cinn::common::Bool())); + ASSERT_EQ(::phi::DataType::INT8, TransToPaddleDataType(::cinn::common::I8())); + ASSERT_EQ(::phi::DataType::INT16, + TransToPaddleDataType(::cinn::common::I16())); + ASSERT_EQ(::phi::DataType::INT32, + TransToPaddleDataType(::cinn::common::I32())); + ASSERT_EQ(::phi::DataType::INT64, + TransToPaddleDataType(::cinn::common::I64())); + ASSERT_EQ(::phi::DataType::UINT8, + TransToPaddleDataType(::cinn::common::UI8())); + ASSERT_EQ(::phi::DataType::UINT16, + TransToPaddleDataType(::cinn::common::UI16())); + ASSERT_EQ(::phi::DataType::UINT32, + TransToPaddleDataType(::cinn::common::UI32())); + ASSERT_EQ(::phi::DataType::UINT64, + TransToPaddleDataType(::cinn::common::UI64())); + ASSERT_EQ(::phi::DataType::FLOAT16, + TransToPaddleDataType(::cinn::common::F16())); + ASSERT_EQ(::phi::DataType::FLOAT32, + TransToPaddleDataType(::cinn::common::F32())); + ASSERT_EQ(::phi::DataType::FLOAT64, + TransToPaddleDataType(::cinn::common::F64())); + ASSERT_THROW(TransToPaddleDataType(::cinn::common::Type()), + paddle::platform::EnforceNotMet); +} + +TEST(TransToPaddleDataType, runtime_type) { + ASSERT_EQ(::phi::DataType::BOOL, TransToPaddleDataType(cinn_bool_t())); + ASSERT_EQ(::phi::DataType::INT8, TransToPaddleDataType(cinn_int8_t())); + ASSERT_EQ(::phi::DataType::INT32, TransToPaddleDataType(cinn_int32_t())); + ASSERT_EQ(::phi::DataType::INT64, TransToPaddleDataType(cinn_int64_t())); + ASSERT_EQ(::phi::DataType::UINT32, TransToPaddleDataType(cinn_uint32_t())); + ASSERT_EQ(::phi::DataType::UINT64, TransToPaddleDataType(cinn_uint64_t())); + ASSERT_EQ(::phi::DataType::FLOAT32, TransToPaddleDataType(cinn_float32_t())); + ASSERT_EQ(::phi::DataType::FLOAT64, TransToPaddleDataType(cinn_float64_t())); + ASSERT_THROW(TransToPaddleDataType(cinn_type_t()), + paddle::platform::EnforceNotMet); +} + +} // namespace paddle::framework::paddle2cinn diff --git a/paddle/fluid/operators/cinn/CMakeLists.txt b/paddle/fluid/operators/cinn/CMakeLists.txt index 2092f65212..2406445e6c 100644 --- a/paddle/fluid/operators/cinn/CMakeLists.txt +++ b/paddle/fluid/operators/cinn/CMakeLists.txt @@ -1,9 +1,9 @@ include(operators) cc_library(cinn_op_helper SRCS cinn_op_helper.cc DEPS operator device_context) -cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope proto_desc graph build_strategy device_context parallel_executor cinn) +cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope proto_desc graph build_strategy device_context parallel_executor transform_type cinn) -SET(CINN_OP_DEPS parallel_executor string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context) +SET(CINN_OP_DEPS parallel_executor string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context transform_type) register_operators(DEPS ${CINN_OP_DEPS}) if (WITH_TESTING) diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc index 8139530b80..0903c53e5e 100644 --- a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cc @@ -47,50 +47,17 @@ class CinnInstructionRunOp : public framework::OperatorWithKernel { } protected: + /* [Why use single type kernel]: + * + * Whether the kernel data type is int, float or other type, + * which has no effect on its execution logic, so directly + * specified a data type here. + * + */ framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - // Why we need override GetExpectedKernelType? - // A cinn-graph may has no inpute var, if we use the base function, - // it will check wheter input tensors is initialized. Here we rewrite - // the function so that we can infer kernel type by output date type. - if (ctx.InputSize(kX)) { - // if the instruction has input, infer kernel type by input date type: - return OperatorWithKernel::GetExpectedKernelType(ctx); - } - - // Else infer kernel type by output date type: - // The `OutputVar` will check wheter the kOutputs iff has one output var - const framework::Variable* var = ctx.OutputVar(kOutputs); - PADDLE_ENFORCE_NE( - var, nullptr, - platform::errors::InvalidArgument( - "The cinn_instruction_run Op's Output Variable should not empty.")); - - const framework::Tensor* tensor = nullptr; - if (var->IsType()) { - tensor = &var->Get(); - } else if (var->IsType()) { - tensor = &var->Get(); - } else if (var->IsType()) { - tensor = &(var->Get().value()); - } else if (var->IsType()) { - auto t_arr = &var->Get(); - PADDLE_ENFORCE_EQ(t_arr->size(), 1UL, - platform::errors::InvalidArgument( - "The cinn_instruction_run Op should just has One " - "Output when Input empty.")); - tensor = &(t_arr->front()); - } - - PADDLE_ENFORCE_NE( - tensor, nullptr, - platform::errors::InvalidArgument( - "The cinn_instruction_run Op's Output Tensor should not empty.")); - - VLOG(4) << "The tensor [" << ctx.OutputName(kOutputs) << "]'s dtype is " - << paddle::framework::DataType2String(tensor->dtype()); - auto output_type = paddle::framework::TransToProtoVarType(tensor->dtype()); - return framework::OpKernelType(output_type, ctx.device_context()); + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); } }; @@ -151,8 +118,4 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( cinn_instruction_run, - ops::CinnInstructionRunOpKernel, - ops::CinnInstructionRunOpKernel, - ops::CinnInstructionRunOpKernel, - ops::CinnInstructionRunOpKernel, - ops::CinnInstructionRunOpKernel); + ops::CinnInstructionRunOpKernel); diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc index a1b00a1820..ea72f6c537 100644 --- a/paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.cu.cc @@ -17,10 +17,7 @@ limitations under the License. */ namespace ops = paddle::operators; using CUDADeviceContext = paddle::platform::CUDADeviceContext; +/* see [Why use single type kernel] */ REGISTER_OP_CUDA_KERNEL( cinn_instruction_run, - ops::CinnInstructionRunOpKernel, - ops::CinnInstructionRunOpKernel, - ops::CinnInstructionRunOpKernel, - ops::CinnInstructionRunOpKernel, - ops::CinnInstructionRunOpKernel); + ops::CinnInstructionRunOpKernel); diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op.h b/paddle/fluid/operators/cinn/cinn_instruction_run_op.h index 8847faa944..81c2d23d3f 100644 --- a/paddle/fluid/operators/cinn/cinn_instruction_run_op.h +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op.h @@ -23,6 +23,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/framework/paddle2cinn/transform_type.h" #include "paddle/fluid/operators/cinn/cinn_launch_context.h" #include "paddle/fluid/operators/cinn/cinn_op_helper.h" @@ -57,8 +58,9 @@ class CinnInstructionRunOpKernel : public framework::OpKernel { cinn_buffer_t* buffer = launch_context->GetCinnBufferOfVar(var_name); framework::Variable* var = ctx.scope().GetVar(var_name); auto* tensor = var->template GetMutable(); - buffer->memory = - reinterpret_cast(tensor->mutable_data(ctx.GetPlace())); + buffer->memory = reinterpret_cast(tensor->mutable_data( + ctx.GetPlace(), + framework::paddle2cinn::TransToPaddleDataType(buffer->type))); }; std::vector in_args = ctx.InputNames(kX); std::for_each(in_args.begin(), in_args.end(), share_argument_buffer_fn); diff --git a/paddle/fluid/operators/cinn/cinn_instruction_run_op_test.cc b/paddle/fluid/operators/cinn/cinn_instruction_run_op_test.cc index 0edbee534c..358d0fc6d0 100644 --- a/paddle/fluid/operators/cinn/cinn_instruction_run_op_test.cc +++ b/paddle/fluid/operators/cinn/cinn_instruction_run_op_test.cc @@ -46,20 +46,15 @@ TEST(CinnInstructionOpTest, TestWithElementwiseAdd) { auto compilation_key = CinnCompiler::GetInstance()->AddGraph( CreateOnlyElementwiseAddGraph("x", "y", test_op_out_name)); - // create a cinn_launch_op and run firstly to launch the compilation - // of the above graph and cache the compiled object in CinnCompiler - auto cinn_launch_op = paddle::framework::OpRegistry::CreateOp( - "cinn_launch", {{"X", {"x", "y"}}}, {{"Out", {test_op_out_name}}}, - {{"compilation_key", compilation_key}}); - - // create cinn_instruction_run_op and elementwise_add op + // create necessary ops auto cinn_instruction_run_op = paddle::framework::OpRegistry::CreateOp( "cinn_instruction_run", {{"X", {"x", "y"}}}, {{"Out", {test_op_out_name}}}, {{"cached_index", 0}, {"instruction_index", 0}}); - auto elementwise_add_op = paddle::framework::OpRegistry::CreateOp( - "elementwise_add", {{"X", {"x"}}, {"Y", {"y"}}}, - {{"Out", {add_op_out_name}}}, {{}}); + + auto cinn_launch_op = paddle::framework::OpRegistry::CreateOp( + "cinn_launch", {{"X", {"x", "y"}}}, {{"Out", {test_op_out_name}}}, + {{"compilation_key", compilation_key}}); // check case: a compiled object not cached before cinn_launch_op run, // so a cinn_instruction_run_op will throw an error @@ -69,18 +64,48 @@ TEST(CinnInstructionOpTest, TestWithElementwiseAdd) { scope.Var(test_op_out_name)->GetMutable(); ASSERT_THROW(cinn_instruction_run_op->Run(scope, place), paddle::platform::EnforceNotMet); + // run cinn_launch_op firstly to launch the compilation + // of the above graph and cache two compiled results + // of both type float and int + cinn_launch_op->Run(scope, place); + scope.EraseVars({"x", "y", test_op_out_name}); + scope.Var(test_op_out_name)->GetMutable(); + InitVariablesWithRandomValue({"x", "y"}, {30, 40}, place, &scope); cinn_launch_op->Run(scope, place); // Run ops and check the computation results auto run_and_check_fn = [&](const platform::Place& place) { framework::Scope scope; - InitVariablesWithRandomValue({"x", "y"}, {10, 20}, place, &scope); scope.Var(test_op_out_name)->GetMutable(); scope.Var(add_op_out_name)->GetMutable(); + auto elementwise_add_op = paddle::framework::OpRegistry::CreateOp( + "elementwise_add", {{"X", {"x"}}, {"Y", {"y"}}}, + {{"Out", {add_op_out_name}}}, {{}}); + + // 1. check on type float + InitVariablesWithRandomValue({"x", "y"}, {10, 20}, place, &scope); + cinn_instruction_run_op->SetAttr("cached_index", 0); cinn_instruction_run_op->Run(scope, place); elementwise_add_op->Run(scope, place); CompareOpResult(scope.GetVar(test_op_out_name), scope.GetVar(add_op_out_name)); + + // 2. check on type int to indicate cinn_instruction_run op + // can mutable data according compiled result + scope.EraseVars({"x", "y", test_op_out_name, add_op_out_name}); + scope.Var(test_op_out_name)->GetMutable(); + scope.Var(add_op_out_name)->GetMutable(); + + InitVariablesWithRandomValue({"x", "y"}, {30, 40}, place, &scope); + cinn_instruction_run_op->SetAttr("cached_index", 1); + cinn_instruction_run_op->Run(scope, place); + // need reconstruct elementwise_add_op to choose a new kernel with type int + elementwise_add_op = paddle::framework::OpRegistry::CreateOp( + "elementwise_add", {{"X", {"x"}}, {"Y", {"y"}}}, + {{"Out", {add_op_out_name}}}, {{}}); + elementwise_add_op->Run(scope, place); + CompareOpResult(scope.GetVar(test_op_out_name), + scope.GetVar(add_op_out_name)); }; // CPU diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc index b76dd60409..b445527322 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc @@ -22,12 +22,15 @@ #include "cinn/hlir/framework/scope.h" #include "cinn/hlir/framework/tensor.h" #include "cinn/runtime/cinn_runtime.h" +#include "cinn/runtime/intrinsic.h" +#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/framework/paddle2cinn/transform_type.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/operators/cinn/cinn_op_helper.h" @@ -211,10 +214,16 @@ void CinnLaunchContext::CheckTensorEquivalent( PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims, platform::errors::PreconditionNotMet( "Tensors' shape in variable(%s) are not equivalent, " - "paddle's shape = [%s], but cinn's shape = [%s].", + "paddle is = [%s], but cinn is = [%s].", var_name, paddle_tensor.dims(), cinn_dims)); - // TODO(CtfGo): check the underlying data type after CINN ready + auto cinn_dtype = + framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type()); + PADDLE_ENFORCE_EQ(paddle_tensor.dtype(), cinn_dtype, + platform::errors::PreconditionNotMet( + "Tensors' dtype in variable(%s) are not equivalent, " + "paddle is = [%s], but cinn is = [%s].", + var_name, paddle_tensor.dtype(), cinn_dtype)); } void CinnLaunchContext::InitializeArguments() { @@ -224,13 +233,15 @@ void CinnLaunchContext::InitializeArguments() { // assign dimensions with corresponding compiled tensor cinn_buffer->resize(cinn_tensor->shape().data().data(), cinn_tensor->shape().data().size()); + cinn_buffer->type = cinn::runtime::ToRuntimeType(cinn_tensor->type()); VLOG(4) << string::Sprintf( - "Append an argument:name(%s),dims(%s),argument size:(%lu)", arg, + "Append an argument:name(%s),dims(%s),type(%s)", framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(), - name2argument_.size()); + cinn_tensor->type()); name2argument_.emplace(arg, cinn_buffer.get()); hold_buffers_.emplace_back(std::move(cinn_buffer)); } + VLOG(4) << "Total argument size:" << name2argument_.size(); } void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) { @@ -325,9 +336,8 @@ framework::ProgramDesc CinnLaunchContext::BuildCompiledProgram( } auto cinn_tensor = GetCinnTensorOfVar(var_name); - // TODO(CtfGo): set the corresponding data type after CINN ready, - // currently set as FP32 in default - var_desc->SetDataType(framework::proto::VarType::FP32); + var_desc->SetDataType(framework::TransToProtoVarType( + framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type()))); var_desc->SetShape(std::vector(cinn_tensor->shape().data().begin(), cinn_tensor->shape().data().end())); } diff --git a/paddle/fluid/operators/cinn/cinn_launch_context_test.cc b/paddle/fluid/operators/cinn/cinn_launch_context_test.cc index 4976a59d1d..15ea9a6926 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context_test.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context_test.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "cinn/common/target.h" +#include "cinn/common/type.h" #include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/instruction.h" #include "cinn/hlir/framework/scope.h" @@ -93,15 +94,19 @@ CinnCompiledObject* InitDefaultCompiledObject() { std::call_once(initialized, [result = compiled_obj.get()]() { auto& scope = result->scope; scope = std::make_shared(); - scope->Var("cinn_var1"); + std::vector cinn_vars( + {"cinn_var1", "cinn_var2", "cinn_var3", "cinn_var4", "cinn_var5"}); + + // initialize variable and set data type + for (const auto& var_name : cinn_vars) { + scope->Var(var_name); + scope->GetTensor(var_name)->set_type(::cinn::common::F32()); + } + scope->GetTensor("cinn_var1")->Resize(CinnShape({3, 4})); - scope->Var("cinn_var2"); scope->GetTensor("cinn_var2")->Resize(CinnShape({6, 7, 8})); - scope->Var("cinn_var3"); scope->GetTensor("cinn_var3")->Resize(CinnShape({10, 16})); - scope->Var("cinn_var4"); scope->GetTensor("cinn_var4")->Resize(CinnShape({10, 16})); - scope->Var("cinn_var5"); scope->GetTensor("cinn_var5")->Resize(CinnShape({10, 16})); // input variables: var1, var2; output: var5 @@ -182,13 +187,17 @@ TEST_F(CinnLaunchContextTest, TestConstructResult) { TEST_F(CinnLaunchContextTest, TestCheckTensorEquivalent) { platform::CPUPlace place; framework::Scope scope; - launch_context->UpdateCapturedEnv(scope, place); auto* tensor1 = scope.Var("var1")->GetMutable(); + auto* tensor2 = scope.Var("var2")->GetMutable(); - // CheckTensorEquivalent: tensor dimension not equivalent + // dimension not equivalent tensor1->mutable_data(phi::make_ddim({3, 5}), place); ASSERT_THROW(launch_context->CheckTensorEquivalent("var1", *tensor1), paddle::platform::EnforceNotMet); + // data type not equivalent + tensor2->mutable_data(phi::make_ddim({6, 7, 8}), place); + ASSERT_THROW(launch_context->CheckTensorEquivalent("var2", *tensor2), + paddle::platform::EnforceNotMet); } TEST_F(CinnLaunchContextTest, TestBuildCompiledProgram) { -- GitLab