instruction.cc 16.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/framework/instruction.h"

#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/cinn/common/test_helper.h"
#include "paddle/cinn/hlir/framework/accuracy_checker.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/profiler.h"

DECLARE_bool(cinn_sync_run);
DECLARE_string(cinn_self_check_accuracy);

namespace cinn {
namespace hlir {
namespace framework {

namespace details {
class ResultsPrint {
 public:
  static ResultsPrint* GetInstance() {
    static ResultsPrint print;
    return &print;
  }

  void write(const std::string& result_str) {
    if (of_.is_open()) {
      of_ << result_str << std::endl;
    } else if (!FLAGS_cinn_self_check_accuracy.empty()) {
      std::cerr << result_str << std::endl;
    } else {
      VLOG(2) << result_str << std::endl;
    }
  }

 private:
  ResultsPrint() {
52 53 54 55
    bool print_to_file =
        !FLAGS_cinn_self_check_accuracy.empty() &&
        !cinn::runtime::CheckStringFlagTrue(FLAGS_cinn_self_check_accuracy) &&
        !cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_self_check_accuracy);
56 57 58 59

    if (print_to_file) {
      of_.open(FLAGS_cinn_self_check_accuracy, std::ios_base::out);
      if (of_.is_open()) {
60 61
        LOG(INFO) << "The CINN compute results will writing into file: \""
                  << FLAGS_cinn_self_check_accuracy << "\"";
62
      } else if (!FLAGS_cinn_self_check_accuracy.empty()) {
63 64
        LOG(WARNING) << "Failed to open file: \""
                     << FLAGS_cinn_self_check_accuracy
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
                     << "\", the CINN compute result will print.";
      }
    }
  }

  ~ResultsPrint() {
    if (of_.is_open()) {
      of_.close();
    }
  }

  std::ofstream of_;
};
}  // namespace details

80 81
void Instruction::UpdateArgsCache(
    const std::map<std::string, cinn_pod_value_t>* name2podargs) {
82 83 84 85 86 87
  int cache_size = size();
  args_cached_.resize(cache_size);

  for (int i = 0; i < cache_size; ++i) {
    common::ArgsBuilder builder;
    std::vector<std::string> all_args = in_args_[i];
88 89
    all_args.insert(
        std::end(all_args), out_args_[i].begin(), out_args_[i].end());
90 91 92

    if (name2podargs != nullptr) {
      for (const auto& arg : all_args) {
93 94 95 96
        CHECK_NE(name2podargs->count(arg), 0)
            << "Argument [" << arg << "] not found in the name2podargs";
        VLOG(5) << "Get a argument, name=" << arg
                << ",type_code=" << name2podargs->at(arg).type_code();
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
        builder.Add(name2podargs->at(arg));
      }
    } else {
      for (const auto& arg : all_args) {
        auto* var = scope_->FindVar(arg);
        CHECK(var) << "Argument [" << arg << "] not found in the scope";

        // TODO(Superjomn) Support other types.
        auto& tensor = absl::get<Tensor>(*var);
        VLOG(5) << "Get a argument, name=" << arg;
        builder.Add(tensor->buffer());
      }
    }

    args_cached_[i] = builder.Build();
  }
}

/// Marks the instruction as ready to run, first reconciling argument groups.
///
/// If there are several lowered functions but the number of input groups
/// differs from the number of functions, the leading group is folded away.
/// Its first output overwrites the first output of the last group, and then
/// the leading input and output groups are removed.
void Instruction::Finalize() {
  const bool fold_leading_group =
      fn_ptrs_.size() > 1 && fn_ptrs_.size() != in_args_.size();

  if (fold_leading_group) {
    // The copy must happen before the leading output group is erased.
    out_args_.back()[0] = out_args_.front()[0];
    in_args_.erase(in_args_.begin());
    out_args_.erase(out_args_.begin());
  }

  finalized_flag_ = true;
}

125 126 127 128 129 130 131
void Instruction::Run(
    const std::map<std::string, cinn_pod_value_t>* name2podargs,
    bool dryrun,
    void* stream,
    bool use_cache) {
  utils::RecordEvent record_run(function_name_,
                                cinn::utils::EventType::kInstruction);
132 133 134 135 136 137 138 139 140
  CHECK(finalized_flag_) << "Instruction must be finalized before run";
  if (function_name_ == "no_run") {
    VLOG(2) << "skip instruction";
    return;
  }

  VLOG(2) << "Run function " << function_name_;

  {
141 142
    utils::RecordEvent record_args("UpdateArgsCache",
                                   cinn::utils::EventType::kInstruction);
143 144 145 146 147
    if (!use_cache || args_cached_.size() != size()) {
      UpdateArgsCache(name2podargs);
    }
  }

148 149
  utils::RecordEvent record_args("Instruction::Run",
                                 cinn::utils::EventType::kInstruction);
150 151 152 153
#if defined(CINN_WITH_CUDA) && !defined(CINN_WITH_CUDNN)
  if (function_name_ == "cublas_gemm" && target_.arch == Target::Arch::NVGPU) {
    auto& pod_args = args_cached_[0];
    VLOG(3) << "The pod_args size of cublas_gemm: " << pod_args.size();
154 155 156 157 158 159 160 161
    runtime::cuda::cinn_gpu_cublas_gemm(attrs,
                                        pod_args[0],
                                        pod_args[1],
                                        pod_args[2],
                                        pod_args[3],
                                        static_cast<cudaStream_t>(stream));
  } else if (function_name_ == "cublas_matmul" &&
             target_.arch == Target::Arch::NVGPU) {
162 163
    auto& pod_args = args_cached_[0];
    VLOG(3) << "The pod_args size of cublas_matmul: " << pod_args.size();
164 165 166 167 168 169
    runtime::cuda::cinn_gpu_cublas_gemm(attrs,
                                        pod_args[0],
                                        pod_args[1],
                                        nullptr,
                                        pod_args[2],
                                        static_cast<cudaStream_t>(stream));
170 171 172 173 174
  } else {
    VLOG(3) << "Runing extern function " << function_name_;
    for (int idx = 0; idx < fn_ptrs_.size(); ++idx) {
      VLOG(3) << "Runing func name: " << fn_names_[idx];
      auto& pod_args = args_cached_[idx];
175 176
      CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by "
                              "calling SetLoweredFunc method";
177 178
      if (!dryrun) {
        if (target_ == common::DefaultNVGPUTarget()) {
179 180
          ((lower_func_ptr_g)fn_ptrs_[idx])(
              static_cast<void*>(pod_args.data()), pod_args.size(), stream);
181
        } else {
182 183
          ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
                                            pod_args.size());
184 185 186 187 188 189 190
        }
      }
    }
    VLOG(3) << "Done Runing extern function " << function_name_;
  }
#elif defined(CINN_WITH_CUDNN)
  auto& pod_args = args_cached_[0];
191 192 193 194
  // Here conv2d and depthwise_conv2d are implemented by one cudnn api
  // cudnnConvolutionForward
  if ((function_name_ == "conv2d" || function_name_ == "depthwise_conv2d") &&
      target_.arch == Target::Arch::NVGPU) {
195 196 197
    if (str_attrs[0] == "forward") {
      if (str_attrs.size() > 1 && str_attrs[1] == "NHWC") {
        absl::flat_hash_map<std::string, int> attrs_map = {
198 199 200 201 202 203 204 205 206 207
            {"input_n", attrs[0]},     {"input_h", attrs[1]},
            {"input_w", attrs[2]},     {"input_c", attrs[3]},
            {"weights_n", attrs[4]},   {"weights_c", attrs[5]},
            {"weights_h", attrs[6]},   {"weights_w", attrs[7]},
            {"pad_h", attrs[8]},       {"pad_w", attrs[9]},
            {"stride_h", attrs[10]},   {"stride_w", attrs[11]},
            {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]},
            {"groups", attrs[14]},     {"output_n", attrs[15]},
            {"output_h", attrs[16]},   {"output_w", attrs[17]},
            {"output_c", attrs[18]},
208
        };
209 210 211 212 213 214
        runtime::cuda::cinn_gpu_cudnn_conv2d(attrs_map,
                                             pod_args[0],
                                             pod_args[1],
                                             pod_args[2],
                                             static_cast<cudaStream_t>(stream),
                                             common::Layout::kNHWC);
215 216 217

      } else {
        absl::flat_hash_map<std::string, int> attrs_map = {
218 219 220 221 222 223 224 225 226 227
            {"input_n", attrs[0]},     {"input_c", attrs[1]},
            {"input_h", attrs[2]},     {"input_w", attrs[3]},
            {"weights_n", attrs[4]},   {"weights_c", attrs[5]},
            {"weights_h", attrs[6]},   {"weights_w", attrs[7]},
            {"pad_h", attrs[8]},       {"pad_w", attrs[9]},
            {"stride_h", attrs[10]},   {"stride_w", attrs[11]},
            {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]},
            {"groups", attrs[14]},     {"output_n", attrs[15]},
            {"output_c", attrs[16]},   {"output_h", attrs[17]},
            {"output_w", attrs[18]},
228
        };
229 230 231 232 233 234
        runtime::cuda::cinn_gpu_cudnn_conv2d(attrs_map,
                                             pod_args[0],
                                             pod_args[1],
                                             pod_args[2],
                                             static_cast<cudaStream_t>(stream),
                                             common::Layout::kNCHW);
235 236 237 238
      }
    } else if (str_attrs[0] == "backward_data") {
      // w, dy, dx
      absl::flat_hash_map<std::string, int> attrs_map = {
239 240 241 242 243 244 245 246 247 248
          {"input_n", attrs[15]},    {"input_c", attrs[16]},
          {"input_h", attrs[17]},    {"input_w", attrs[18]},
          {"weights_n", attrs[0]},   {"weights_c", attrs[1]},
          {"weights_h", attrs[2]},   {"weights_w", attrs[3]},
          {"pad_h", attrs[8]},       {"pad_w", attrs[9]},
          {"stride_h", attrs[10]},   {"stride_w", attrs[11]},
          {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]},
          {"groups", attrs[14]},     {"output_n", attrs[4]},
          {"output_c", attrs[5]},    {"output_h", attrs[6]},
          {"output_w", attrs[7]},
249 250 251
      };
      // w, dy, dx
      runtime::cuda::cinn_gpu_cudnn_conv2d_backward_data(
252 253 254 255 256
          attrs_map,
          pod_args[0],
          pod_args[1],
          pod_args[2],
          static_cast<cudaStream_t>(stream));
257 258 259
    } else {
      // x, dy, w
      absl::flat_hash_map<std::string, int> attrs_map = {
260 261 262 263 264 265 266 267 268 269
          {"input_n", attrs[0]},     {"input_c", attrs[1]},
          {"input_h", attrs[2]},     {"input_w", attrs[3]},
          {"weights_n", attrs[15]},  {"weights_c", attrs[16]},
          {"weights_h", attrs[17]},  {"weights_w", attrs[18]},
          {"pad_h", attrs[8]},       {"pad_w", attrs[9]},
          {"stride_h", attrs[10]},   {"stride_w", attrs[11]},
          {"dilation_h", attrs[12]}, {"dilation_w", attrs[13]},
          {"groups", attrs[14]},     {"output_n", attrs[4]},
          {"output_c", attrs[5]},    {"output_h", attrs[6]},
          {"output_w", attrs[7]},
270 271 272
      };
      // x, dy, w
      runtime::cuda::cinn_gpu_cudnn_conv2d_backward_filter(
273 274 275 276 277
          attrs_map,
          pod_args[0],
          pod_args[1],
          pod_args[2],
          static_cast<cudaStream_t>(stream));
278
    }
279 280 281 282 283 284 285 286 287
  } else if (function_name_ == "pool2d" &&
             target_.arch == Target::Arch::NVGPU) {
    runtime::cuda::cinn_gpu_cudnn_pool2d(attrs,
                                         str_attrs,
                                         pod_args[0],
                                         pod_args[1],
                                         static_cast<cudaStream_t>(stream));
  } else if (function_name_ == "softmax" &&
             target_.arch == Target::Arch::NVGPU) {
288
    CHECK_EQ(pod_args.size(), 3);
289 290
    runtime::cuda::cinn_gpu_cudnn_softmax(
        attrs, pod_args[0], pod_args[1], static_cast<cudaStream_t>(stream));
291 292
  } else if (function_name_ == "mul" && target_.arch == Target::Arch::NVGPU) {
    CHECK_EQ(pod_args.size(), 4);
293 294 295 296 297 298 299
    runtime::cuda::cinn_gpu_cublas_mul(attrs,
                                       pod_args[0],
                                       pod_args[1],
                                       pod_args[2],
                                       static_cast<cudaStream_t>(stream));
  } else if (function_name_ == "cublas_gemm" &&
             target_.arch == Target::Arch::NVGPU) {
300
    VLOG(3) << "The pod_args size of cublas_gemm: " << pod_args.size();
301 302 303 304 305 306 307 308
    runtime::cuda::cinn_gpu_cublas_gemm(attrs,
                                        pod_args[0],
                                        pod_args[1],
                                        pod_args[2],
                                        pod_args[3],
                                        static_cast<cudaStream_t>(stream));
  } else if (function_name_ == "cublas_matmul" &&
             target_.arch == Target::Arch::NVGPU) {
309 310
    auto& pod_args = args_cached_[0];
    VLOG(3) << "The pod_args size of cublas_matmul: " << pod_args.size();
311 312 313 314 315 316
    runtime::cuda::cinn_gpu_cublas_gemm(attrs,
                                        pod_args[0],
                                        pod_args[1],
                                        nullptr,
                                        pod_args[2],
                                        static_cast<cudaStream_t>(stream));
317 318 319 320 321
  } else {
    VLOG(3) << "Runing extern function " << function_name_;
    for (int idx = 0; idx < fn_ptrs_.size(); ++idx) {
      VLOG(3) << "Runing func name: " << fn_names_[idx];
      auto& pod_args = args_cached_[idx];
322 323
      CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by "
                              "calling SetLoweredFunc method";
324 325
      if (!dryrun) {
        if (target_ == common::DefaultNVGPUTarget()) {
326 327
          ((lower_func_ptr_g)fn_ptrs_[idx])(
              static_cast<void*>(pod_args.data()), pod_args.size(), stream);
328
        } else {
329 330
          ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
                                            pod_args.size());
331 332 333 334 335 336 337 338 339 340
        }
      }
    }
    VLOG(3) << "Done Runing extern function " << function_name_;
  }
#else
  VLOG(3) << "Runing extern function " << function_name_;
  for (int idx = 0; idx < fn_ptrs_.size(); ++idx) {
    VLOG(3) << "Runing func name: " << fn_names_[idx];
    auto& pod_args = args_cached_[idx];
341 342
    CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first by "
                            "calling SetLoweredFunc method";
343 344
    if (!dryrun) {
      if (target_ == common::DefaultNVGPUTarget()) {
345 346
        ((lower_func_ptr_g)fn_ptrs_[idx])(
            static_cast<void*>(pod_args.data()), pod_args.size(), stream);
347
      } else {
348 349
        ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
                                          pod_args.size());
350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367
      }
    }
  }
  VLOG(3) << "Done Runing extern function " << function_name_;
#endif

  if (!cinn::runtime::CheckStringFlagFalse(FLAGS_cinn_self_check_accuracy)) {
    CheckResults(name2podargs, stream);
  }
  // TODO(thisjiang): revert while flags correct
  //   if (FLAGS_cinn_sync_run) {
  // #ifdef CINN_WITH_CUDA
  //     utils::RecordEvent record_sync("FLAGS_cinn_sync_run");
  //     CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
  // #endif
  //   }
}

368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390
std::string Instruction::DumpInstruction() {
  std::stringstream ss;
  ss << "Instruction {" << std::endl;
  for (size_t i = 0; i < fn_names_.size(); ++i) {
    ss << "  Function " << fn_names_[i] << ":" << std::endl;
    ss << "    function ptr: " << fn_ptrs_[i] << std::endl;

    auto in_arg = in_args_[i];
    std::sort(in_arg.begin(), in_arg.end());
    for (auto& in_name : in_arg) {
      ss << "    input: " << in_name << std::endl;
    }

    auto out_arg = out_args_[i];
    std::sort(out_arg.begin(), out_arg.end());
    for (auto& out_name : out_arg) {
      ss << "    output: " << out_name << std::endl;
    }
  }
  ss << "}" << std::endl;
  return ss.str();
}

391 392
void Instruction::CheckResults(
    const std::map<std::string, cinn_pod_value_t>* name2podargs, void* stream) {
393 394 395 396 397
#ifdef CINN_WITH_CUDA
  cudaStreamSynchronize(static_cast<cudaStream_t>(stream));
#endif

  if (fn_names_.size() == 1) {
398 399
    std::unordered_set<std::string> skipped_instr_set = {
        "malloc_buffer_instruction", "free_buffer_instruction"};
400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446
    for (auto& name : skipped_instr_set) {
      if (fn_names_[0].find(name) != std::string::npos) {
        // Skip the malloc & free buffer instructions.
        return;
      }
    }
  }

  AccuracyChecker checker(target_, scope_);

  std::stringstream ss;
  ss << "Instruction {" << std::endl;
  for (size_t i = 0; i < fn_names_.size(); ++i) {
    ss << "  Function " << fn_names_[i] << ":" << std::endl;

    auto in_arg = in_args_[i];
    std::sort(in_arg.begin(), in_arg.end());
    for (auto& in_name : in_arg) {
      std::string result_str;
      if (name2podargs) {
        result_str = checker(name2podargs, in_name);
      } else {
        result_str = checker(in_name);
      }
      ss << "    input: " << result_str << std::endl;
    }

    auto out_arg = out_args_[i];
    std::sort(out_arg.begin(), out_arg.end());
    for (auto& out_name : out_arg) {
      std::string result_str;
      if (name2podargs) {
        result_str = checker(name2podargs, out_name);
      } else {
        result_str = checker(out_name);
      }
      ss << "    output: " << result_str << std::endl;
    }
  }
  ss << "}" << std::endl;

  details::ResultsPrint::GetInstance()->write(ss.str());
}

}  // namespace framework
}  // namespace hlir
}  // namespace cinn