/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include #include #include #include #include #ifdef PADDLE_WITH_CUDA #include #endif #include "cinn/cinn.h" #include "cinn/common/target.h" #include "cinn/frontend/net_builder.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/framework/tensor.h" #include "cinn/hlir/op/use_ops.h" #include "cinn/hlir/pass/use_pass.h" namespace cinn { namespace frontend { Program CreateAddProgram() { constexpr int M = 32; constexpr int N = 24; NetBuilder builder("net_builder"); auto a = builder.CreateInput(Float(32), {M, N}); auto b = builder.CreateInput(Float(32), {M, N}); auto c = builder.add(a, b); auto d = builder.add(a, c); auto program = builder.Build(); return program; } void SetRandData(hlir::framework::Tensor tensor, Target target) { auto* data = tensor->mutable_data(target); std::random_device seed; std::default_random_engine engine(seed()); std::uniform_real_distribution dist(0.f, 1.f); size_t num_ele = tensor->shape().numel(); std::vector random_data(num_ele); for (size_t i = 0; i < num_ele; i++) { random_data[i] = dist(engine); // All random data } #ifdef PADDLE_WITH_CUDA cudaMemcpy(data, random_data.data(), num_ele * sizeof(float), cudaMemcpyHostToDevice); #else std::copy(random_data.begin(), random_data.end(), data); #endif } TEST(net_build, basic) { auto 
program = CreateAddProgram(); // output program for (size_t i = 0; i < program.size(); i++) { LOG(INFO) << "instruction: " << program[i]; } } TEST(net_build, program_execute_multi_elementwise_add) { auto program = CreateAddProgram(); #ifdef PADDLE_WITH_CUDA Target target = common::DefaultNVGPUTarget(); #else Target target = common::DefaultHostTarget(); #endif auto graph = std::make_shared(program, target); std::cout << "graph:\n" << graph->Visualize() << std::endl; auto scope = BuildScope(target, graph); hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); scope->Var("A"); scope->Var("B"); auto A = scope->GetTensor("A"); auto B = scope->GetTensor("B"); SetRandData(A, target); SetRandData(B, target); runtime_program->Execute(); } TEST(net_build, program_execute_fc) { constexpr int B = 10; // batch size constexpr int M = 32; constexpr int K = 18; constexpr int N = 24; NetBuilder builder("net_builder"); auto a = builder.CreateInput(Float(32), {B, M, K}, "A"); auto w = builder.CreateInput(Float(32), {N, K}, "W"); // weight auto b = builder.CreateInput(Float(32), {N}, "B"); // bias auto mul_out = builder.mul(a, w, 2, 1); auto add_out = builder.add(mul_out, b); auto program = builder.Build(); #ifdef PADDLE_WITH_CUDA Target target = common::DefaultNVGPUTarget(); #else Target target = common::DefaultHostTarget(); #endif auto graph = std::make_shared(program, target); auto scope = BuildScope(target, graph); hlir::framework::GraphCompiler gc(target, scope, graph); auto runtime_program = gc.Build(); scope->Var(std::string(a.id())); scope->Var(std::string(w.id())); scope->Var(std::string(b.id())); scope->Var(std::string(mul_out->id)); auto a_ten = scope->GetTensor(std::string(a.id())); auto w_ten = scope->GetTensor(std::string(w.id())); auto b_ten = scope->GetTensor(std::string(b.id())); auto fake_out_ten = scope->GetTensor(std::string(mul_out->id)); auto add_out_ten = scope->GetTensor(std::string(add_out->id)); SetRandData(a_ten, 
target);
  SetRandData(w_ten, target);
  SetRandData(b_ten, target);

  runtime_program->Execute();
}

}  // namespace frontend
}  // namespace cinn