diff --git a/config/debugger_config.json b/config/debugger_config.json new file mode 100644 index 0000000000000000000000000000000000000000..ac0766d2122ce433b1f6d6c390432776c1061561 --- /dev/null +++ b/config/debugger_config.json @@ -0,0 +1,14 @@ +{ + "DebuggerSettings": { + "ssl_certificate": true, + "certificate_path": "/home/maning/sslcertificates/client.pfx", + "certificate_passphrase": "12345678" + }, + + "DebuggerSettingsSpec": { + "ssl_certificate": "true, secure_mode enable, default is true", + "certificate_path": "path to the certificate file", + "certificate_passphrase": "passphrase of the certificate" + }, + "other": {} +} \ No newline at end of file diff --git a/mindspore/ccsrc/debug/debugger/debugger.cc b/mindspore/ccsrc/debug/debugger/debugger.cc index 01c46d1393d55768c0074e01e1eaabe700d78c0c..4da4755301cb4b4256aa6c31260b96ab6b4bdf11 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.cc +++ b/mindspore/ccsrc/debug/debugger/debugger.cc @@ -30,6 +30,11 @@ #include "backend/session/anf_runtime_algorithm.h" #include "runtime/device/kernel_runtime_manager.h" +#include "utils/system/file_system.h" +#include "utils/system/env.h" +#include +using json = nlohmann::json; + using debugger::EventReply; using debugger::GraphProto; using debugger::ModelProto; @@ -60,7 +65,10 @@ Debugger::Debugger() is_dataset_graph_(false), partial_memory_(false), last_overflow_bin_(0), - overflow_bin_path_("") {} + overflow_bin_path_(""), + ssl_certificate(true), + certificate_dir(""), + certificate_passphrase("") {} void Debugger::Init(const uint32_t device_id, const std::string device_target) { // access lock for public method @@ -175,7 +183,8 @@ void Debugger::EnableDebugger() { // initialize grpc client if (debugger_enabled_) { - grpc_client_ = std::make_unique(host, port); + SetDebuggerConfFromJsonFile(); + grpc_client_ = std::make_unique(host, port, ssl_certificate, certificate_dir, certificate_passphrase); } debug_services_ = std::make_unique(); @@ -787,4 +796,88 @@ std::vector Debugger::CheckOpOverflow() { return op_names; } +bool Debugger::SetDebuggerConfFromJsonFile() { + const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); + if (config_path_str != nullptr) { + MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; + } else { + MS_LOG(INFO) << "No need debugger config path. please export MINDSPORE_CONFIG_PATH eg: MINDSPORE_CONFIG_PATH=/etc"; + ssl_certificate = false; + return false; + } + char real_path[4096] = {0}; + if (nullptr == realpath(config_path_str, real_path)) { + MS_LOG(ERROR) << "Env debugger config path error, " << config_path_str; + ssl_certificate = false; + return false; + } + std::string debugger_config_file = std::string(real_path) + "/debugger_config.json"; + std::shared_ptr fs = system::Env::GetFileSystem(); + MS_EXCEPTION_IF_NULL(fs); + if (!fs->FileExist(debugger_config_file)) { + MS_LOG(ERROR) << debugger_config_file << " not exist."; + ssl_certificate = false; + return false; + } + + return ParseDebuggerConfig(debugger_config_file); +} + +bool Debugger::ParseDebuggerConfig(const std::string &debugger_config_file) { + std::ifstream jsonFile(debugger_config_file); + if (!jsonFile.is_open()) { + MS_LOG(ERROR) << debugger_config_file << " open failed."; + ssl_certificate = false; + return false; + } + json j; + jsonFile >> j; + if (j.find("DebuggerSettings") == j.end()) { + MS_LOG(ERROR) << "DebuggerSettings is not exist."; + ssl_certificate = false; + return false; + } else { + json debuggerSettings = j.at("DebuggerSettings"); + // convert json to string + std::stringstream ss; + ss << debuggerSettings; + std::string cfg = ss.str(); + MS_LOG(INFO) << "Debugger Settings Json: " << cfg; + if (!IsConfigExist(debuggerSettings)) { + return false; + } + if (!IsConfigValid(debuggerSettings)) { + return false; + } + } + return true; +} + +bool Debugger::IsConfigExist(const nlohmann::json &debuggerSettings) { + if (debuggerSettings.find("ssl_certificate") == debuggerSettings.end() || + debuggerSettings.find("certificate_dir") == debuggerSettings.end() || + debuggerSettings.find("certificate_passphrase") == debuggerSettings.end()) { + MS_LOG(ERROR) << "DebuggerSettings keys is not exist."; + ssl_certificate = false; + return false; + } + return true; +} + +bool Debugger::IsConfigValid(const nlohmann::json &debuggerSettings) { + auto enable_secure = debuggerSettings.at("ssl_certificate"); + auto certificate_dir_ = debuggerSettings.at("certificate_dir"); + auto certificate_passphrase_ = debuggerSettings.at("certificate_passphrase"); + if (!(enable_secure.is_boolean() && certificate_dir_.is_string() && certificate_passphrase_.is_string())) { + MS_LOG(ERROR) << "Element's type in Debugger config json is invalid."; + ssl_certificate = false; + return false; + } + + ssl_certificate = enable_secure; + certificate_dir = certificate_dir_; + certificate_passphrase = certificate_passphrase_; + return true; +} + } // namespace mindspore diff --git a/mindspore/ccsrc/debug/debugger/debugger.h b/mindspore/ccsrc/debug/debugger/debugger.h index bfccd6aabad8b95cf65be313212691bbda73ce11..09059851847f26c6ae56390ee888eb91f077f39b 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.h +++ b/mindspore/ccsrc/debug/debugger/debugger.h @@ -25,6 +25,7 @@ #include "backend/session/kernel_graph.h" #include "debug/debugger/grpc_client.h" #include "debug/debug_services.h" +#include using debugger::Chunk; using debugger::DataType; @@ -165,6 +166,14 @@ class Debugger : public std::enable_shared_from_this { // singleton static std::mutex instance_lock_; static std::shared_ptr debugger_; + + bool ssl_certificate; + std::string certificate_dir; + std::string certificate_passphrase; + bool SetDebuggerConfFromJsonFile(); + bool ParseDebuggerConfig(const std::string &dump_config_file); + bool IsConfigExist(const nlohmann::json &dumpSettings); + bool IsConfigValid(const nlohmann::json &dumpSettings); }; using DebuggerPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/debug/debugger/grpc_client.cc b/mindspore/ccsrc/debug/debugger/grpc_client.cc index 8677e9051fc6059dacde3a68d0e825250c1d9a5c..a80db02cc76af8880c948a98795c0c1496ec27a2 100644 --- a/mindspore/ccsrc/debug/debugger/grpc_client.cc +++ b/mindspore/ccsrc/debug/debugger/grpc_client.cc @@ -13,10 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "debug/debugger/grpc_client.h" + +#include +#include +#include +#include +#include +#include +#include #include #include -#include "debug/debugger/grpc_client.h" #include "utils/log_adapter.h" using debugger::Chunk; @@ -31,13 +39,96 @@ using debugger::WatchpointHit; #define CHUNK_SIZE 1024 * 1024 * 3 namespace mindspore { -GrpcClient::GrpcClient(const std::string &host, const std::string &port) : stub_(nullptr) { Init(host, port); } +GrpcClient::GrpcClient(const std::string &host, const std::string &port, const bool &ssl_certificate, + const std::string &certificate_dir, const std::string &certificate_passphrase) + : stub_(nullptr) { + Init(host, port, ssl_certificate, certificate_dir, certificate_passphrase); +} -void GrpcClient::Init(const std::string &host, const std::string &port) { +void GrpcClient::Init(const std::string &host, const std::string &port, const bool &ssl_certificate, + const std::string &certificate_dir, const std::string &certificate_passphrase) { std::string target_str = host + ":" + port; MS_LOG(INFO) << "GrpcClient connecting to: " << target_str; - std::shared_ptr channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()); + std::shared_ptr channel; + if (ssl_certificate) { + FILE *fp; + EVP_PKEY *pkey = NULL; + X509 *cert = NULL; + STACK_OF(X509) *ca = NULL; + PKCS12 *p12 = NULL; + + if ((fp = fopen(certificate_dir.c_str(), "rb")) == NULL) { + MS_LOG(ERROR) << "Error opening file: " << certificate_dir; + exit(EXIT_FAILURE); + } + p12 = d2i_PKCS12_fp(fp, NULL); + fclose(fp); + if (p12 == NULL) { + MS_LOG(ERROR) << "Error reading PKCS#12 file"; + X509_free(cert); + EVP_PKEY_free(pkey); + sk_X509_pop_free(ca, X509_free); + exit(EXIT_FAILURE); + } + if (!PKCS12_parse(p12, certificate_passphrase.c_str(), &pkey, &cert, &ca)) { + MS_LOG(ERROR) << "Error parsing PKCS#12 file"; + X509_free(cert); + EVP_PKEY_free(pkey); + sk_X509_pop_free(ca, X509_free); + exit(EXIT_FAILURE); + } + std::string strca; + std::string strcert; + std::string strkey; + + if (pkey == NULL || cert == NULL || ca == NULL) { + MS_LOG(ERROR) << "Error private key or cert or CA certificate."; + X509_free(cert); + EVP_PKEY_free(pkey); + sk_X509_pop_free(ca, X509_free); + exit(EXIT_FAILURE); + } else { + ASN1_TIME *validtime = X509_getm_notAfter(cert); + if (X509_cmp_current_time(validtime) < 0) { + MS_LOG(ERROR) << "This certificate is over its valid time, please use a new certificate."; + X509_free(cert); + EVP_PKEY_free(pkey); + sk_X509_pop_free(ca, X509_free); + exit(EXIT_FAILURE); + } + int nid = X509_get_signature_nid(cert); + int keybit = EVP_PKEY_bits(pkey); + if (nid == NID_sha1) { + MS_LOG(WARNING) << "Signature algrithm is sha1, which maybe not secure enough."; + } else if (keybit < 2048) { + MS_LOG(WARNING) << "The private key bits is: " << keybit << ", which maybe not secure enough."; + } + int dwPriKeyLen = i2d_PrivateKey(pkey, NULL); // get the length of private key + unsigned char *pribuf = (unsigned char *)malloc(sizeof(unsigned char) * dwPriKeyLen); + i2d_PrivateKey(pkey, &pribuf); // PrivateKey DER code + strkey = std::string(reinterpret_cast(pribuf), dwPriKeyLen); + + int dwcertLen = i2d_X509(cert, NULL); // get the length of private key + unsigned char *certbuf = (unsigned char *)malloc(sizeof(unsigned char) * dwcertLen); + i2d_X509(cert, &certbuf); // PrivateKey DER code + strcert = std::string(reinterpret_cast(certbuf), dwcertLen); + + int dwcaLen = i2d_X509(sk_X509_value(ca, 0), NULL); // get the length of private key + unsigned char *cabuf = (unsigned char *)malloc(sizeof(unsigned char) * dwcaLen); + i2d_X509(sk_X509_value(ca, 0), &cabuf); // PrivateKey DER code + strcat = std::string(reinterpret_cast(cabuf), dwcaLen); + + free(pribuf); + free(certbuf); + free(cabuf); + } + + grpc::SslCredentialsOptions opts = {strca, strkey, strcert}; + channel = grpc::CreateChannel(target_str, grpc::SslCredentials(opts)); + } else { + channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()); + } stub_ = EventListener::NewStub(channel); } diff --git a/mindspore/ccsrc/debug/debugger/grpc_client.h b/mindspore/ccsrc/debug/debugger/grpc_client.h index 0b5359e4447419df83688b5cc8ef7f13df8a9cef..fb209b99098b1a9be04ad439fd4c2357ce1e1f0b 100644 --- a/mindspore/ccsrc/debug/debugger/grpc_client.h +++ b/mindspore/ccsrc/debug/debugger/grpc_client.h @@ -17,9 +17,17 @@ #define MINDSPORE_CCSRC_DEBUG_DEBUGGER_GRPC_CLIENT_H_ #include + +#include +#include +#include +#include +#include + #include #include #include + #include "proto/debug_grpc.grpc.pb.h" using debugger::EventListener; @@ -33,13 +41,15 @@ namespace mindspore { class GrpcClient { public: // constructor - GrpcClient(const std::string &host, const std::string &port); + GrpcClient(const std::string &host, const std::string &port, const bool &ssl_certificate, + const std::string &certificate_dir, const std::string &certificate_passphrase); // deconstructor ~GrpcClient() = default; // init - void Init(const std::string &host, const std::string &port); + void Init(const std::string &host, const std::string &port, const bool &ssl_certificate, + const std::string &certificate_dir, const std::string &certificate_passphrase); // reset void Reset();