Add gdb-only buffer helpers and break-point hooks around kernel Forward/Backward.

Summary of the patch below:
  * oneflow/core/common/gdb.cpp
      - removes CpuBlobCopiedFromGpuBlobPtr;
      - adds file-local helpers: MallocThenCpyD2H (malloc a host buffer, then cudaMemcpy
        device-to-host into it), CpyH2DThenFree (cudaMemcpy host-to-device, then free the
        host buffer) and LoadFromStrFile (parse one value per text line into a buffer);
      - adds uint64_t-address wrappers around those helpers, GetAllBlobNames, and four
        *BreakPoint functions whose bodies are intentionally empty ("do nothing").
  * oneflow/core/common/gdb.h (new file): declares the four *BreakPoint functions.
  * oneflow/core/kernel/kernel.cpp: Kernel::Launch calls the matching Enter/Leave hook
    immediately before and after Forward() or Backward().

NOTE(review): this patch text was restored from a copy in which line breaks, indentation and
every <...> template-argument list had been stripped. Hunk line counts were re-checked against
the @@ headers, and template arguments were inferred from the declared types at each use site;
exact whitespace and the arguments on removed ('-') lines should be confirmed against the
repository before applying (or apply with whitespace-insensitive matching).

diff --git a/oneflow/core/common/gdb.cpp b/oneflow/core/common/gdb.cpp
index ed76e80d8874efd3b6c816124b39291af7b40296..c26b70674000e8427ccee413cb41cf193e9c362c 100644
--- a/oneflow/core/common/gdb.cpp
+++ b/oneflow/core/common/gdb.cpp
@@ -1,21 +1,54 @@
 #include "oneflow/core/register/blob.h"
 #include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/common/protobuf.h"
 
 namespace oneflow {
 
 // used by gdb only
 namespace gdb {
 
+namespace {
+
+static char* MallocThenCpyD2H(const char* gpu_src, size_t size) {
+  char* cpu_dst = reinterpret_cast<char*>(malloc(size));
+  cudaMemcpy(cpu_dst, gpu_src, size, cudaMemcpyDeviceToHost);
+  return cpu_dst;
+}
+
+static void CpyH2DThenFree(char* gpu_dst, char* cpu_src, size_t size) {
+  cudaMemcpy(gpu_dst, cpu_src, size, cudaMemcpyHostToDevice);
+  free(cpu_src);
+}
+
+template<typename T>
+void LoadFromStrFile(T* buf, const std::string& file_name) {
+  std::ifstream file(file_name);
+  CHECK(file.is_open());
+  std::string line;
+  for (int64_t i = 0; std::getline(file, line); ++i) { buf[i] = oneflow_cast<T>(line); }
+  file.close();
+}
+
+}  // namespace
+
 // used by passing std::string param
 static std::string param0;
 
-static const Blob* CpuBlobCopiedFromGpuBlobPtr(uint64_t gpu_blob_ptr) {
-  Blob* gpu_blob = reinterpret_cast<Blob*>(gpu_blob_ptr);
-  char* cpu_body_ptr = reinterpret_cast<char*>(malloc(gpu_blob->ByteSizeOfDataContentField()));
-  cudaMemcpy(cpu_body_ptr, gpu_blob->dptr(), gpu_blob->ByteSizeOfDataContentField(),
-             cudaMemcpyDeviceToHost);
-  return new Blob(const_cast<Regst*>(gpu_blob->regst()), gpu_blob->blob_desc_ptr(),
-                  reinterpret_cast<char*>(gpu_blob->mut_header_ptr()), cpu_body_ptr);
+static void CudaMemCpyH2DThenFreeCpuPtr(uint64_t gpu_dst, uint64_t cpu_src, size_t size) {
+  CpyH2DThenFree(reinterpret_cast<char*>(gpu_dst), reinterpret_cast<char*>(cpu_src), size);
+}
+
+static void* MallocCpuBufThenCudaMemCpyD2H(uint64_t gpu_src, size_t size) {
+  return MallocThenCpyD2H(reinterpret_cast<const char*>(gpu_src), size);
+}
+
+static void FloatBufLoadFromStrFile(uint64_t ptr, const char* file_name) {
+  LoadFromStrFile(reinterpret_cast<float*>(ptr), std::string(file_name));
+}
+
+static void Int32BufLoadFromStrFile(uint64_t ptr, const char* file_name) {
+  LoadFromStrFile(reinterpret_cast<int32_t*>(ptr), std::string(file_name));
 }
 
 static Blob* Blob4BnInOp(const std::function<Blob*(const std::string&)>* BnInOp2Blob,
@@ -23,6 +56,41 @@ static Blob* Blob4BnInOp(const std::function<Blob*(const std::string&)>* BnInOp2
   return (*BnInOp2Blob)(std::string(bn_in_op));
 }
 
+static HashMap<std::string, std::vector<std::string>> GetAllBlobNames(
+    const OpAttribute& op_attribute) {
+  std::list<std::string> attrs{
+      "input_bns",         "input_diff_bns", "output_bns",   "output_diff_bns", "data_tmp_bns",
+      "fw_buf_bns",        "bw_buf_bns",     "model_bns",    "model_diff_bns",  "const_model_bns",
+      "forward_model_bns", "const_buf_bns",  "pb_input_bns", "pb_output_bns",
+  };
+  HashMap<std::string, std::vector<std::string>> ret;
+  for (const auto& attr : attrs) {
+    const auto& repeated_field = GetPbRpfFromPbMessage<std::string>(op_attribute, attr);
+    if (repeated_field.empty() == false) { ret.insert({attr, PbRpf2StdVec(repeated_field)}); }
+  }
+  return ret;
+}
+
+void ForwardEnterBreakPoint(const OpAttribute& op_attribute,
+                            const std::function<Blob*(const std::string&)>& BnInOp2Blob) {
+  // do nothing
+}
+
+void ForwardLeaveBreakPoint(const OpAttribute& op_attribute,
+                            const std::function<Blob*(const std::string&)>& BnInOp2Blob) {
+  // do nothing
+}
+
+void BackwardEnterBreakPoint(const OpAttribute& op_attribute,
+                             const std::function<Blob*(const std::string&)>& BnInOp2Blob) {
+  // do nothing
+}
+
+void BackwardLeaveBreakPoint(const OpAttribute& op_attribute,
+                             const std::function<Blob*(const std::string&)>& BnInOp2Blob) {
+  // do nothing
+}
+
 }  // namespace gdb
 
 }  // namespace oneflow
diff --git a/oneflow/core/common/gdb.h b/oneflow/core/common/gdb.h
new file mode 100644
index 0000000000000000000000000000000000000000..62093380e0bd6e71f11e3f0c9d4e73a06919b4b2
--- /dev/null
+++ b/oneflow/core/common/gdb.h
@@ -0,0 +1,24 @@
+#ifndef ONEFLOW_CORE_COMMON_GDB_H_
+#define ONEFLOW_CORE_COMMON_GDB_H_
+
+namespace oneflow {
+
+namespace gdb {
+
+void ForwardEnterBreakPoint(const OpAttribute& op_attribute,
+                            const std::function<Blob*(const std::string&)>& BnInOp2Blob);
+
+void ForwardLeaveBreakPoint(const OpAttribute& op_attribute,
+                            const std::function<Blob*(const std::string&)>& BnInOp2Blob);
+
+void BackwardEnterBreakPoint(const OpAttribute& op_attribute,
+                             const std::function<Blob*(const std::string&)>& BnInOp2Blob);
+
+void BackwardLeaveBreakPoint(const OpAttribute& op_attribute,
+                             const std::function<Blob*(const std::string&)>& BnInOp2Blob);
+
+}  // namespace gdb
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_COMMON_GDB_H_
diff --git a/oneflow/core/kernel/kernel.cpp b/oneflow/core/kernel/kernel.cpp
index f50695ab616a05adb2bb98b09995b81a791f557b..dc81e7c99ccf81e4f27794b2663ab4fb2e14e48c 100644
--- a/oneflow/core/kernel/kernel.cpp
+++ b/oneflow/core/kernel/kernel.cpp
@@ -1,4 +1,5 @@
 #include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/common/gdb.h"
 
 namespace oneflow {
 
@@ -30,9 +31,13 @@ void Kernel::InitModelAndConstBuf(const KernelCtx& ctx, const ParallelContext* p
 void Kernel::Launch(const KernelCtx& ctx,
                     std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   if (kernel_conf_.is_forward()) {
+    gdb::ForwardEnterBreakPoint(op_attribute(), BnInOp2Blob);
     Forward(ctx, BnInOp2Blob);
+    gdb::ForwardLeaveBreakPoint(op_attribute(), BnInOp2Blob);
   } else {
+    gdb::BackwardEnterBreakPoint(op_attribute(), BnInOp2Blob);
     Backward(ctx, BnInOp2Blob);
+    gdb::BackwardLeaveBreakPoint(op_attribute(), BnInOp2Blob);
   }
 }
 