提交 6ea1b21f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4822 Add secure mechanism to Mindspore Debugger

Merge pull request !4822 from maning202007/rebase_master
{
"DebuggerSettings": {
"ssl_certificate": true,
"certificate_path": "/home/maning/sslcertificates/client.pfx",
"certificate_passphrase": "12345678"
},
"DebuggerSettingsSpec": {
"ssl_certificate": "true, secure_mode enable, default is true",
"certificate_path": "path to the certificate file",
"certificate_passphrase": "passphrase of the certificate"
},
"other": {}
}
\ No newline at end of file
...@@ -30,6 +30,11 @@ ...@@ -30,6 +30,11 @@
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
#include "utils/system/file_system.h"
#include "utils/system/env.h"
#include <nlohmann/json.hpp>
using json = nlohmann::json;
using debugger::EventReply; using debugger::EventReply;
using debugger::GraphProto; using debugger::GraphProto;
using debugger::ModelProto; using debugger::ModelProto;
...@@ -60,7 +65,10 @@ Debugger::Debugger() ...@@ -60,7 +65,10 @@ Debugger::Debugger()
is_dataset_graph_(false), is_dataset_graph_(false),
partial_memory_(false), partial_memory_(false),
last_overflow_bin_(0), last_overflow_bin_(0),
overflow_bin_path_("") {} overflow_bin_path_(""),
ssl_certificate(true),
certificate_dir(""),
certificate_passphrase("") {}
void Debugger::Init(const uint32_t device_id, const std::string device_target) { void Debugger::Init(const uint32_t device_id, const std::string device_target) {
// access lock for public method // access lock for public method
...@@ -175,7 +183,8 @@ void Debugger::EnableDebugger() { ...@@ -175,7 +183,8 @@ void Debugger::EnableDebugger() {
// initialize grpc client // initialize grpc client
if (debugger_enabled_) { if (debugger_enabled_) {
grpc_client_ = std::make_unique<GrpcClient>(host, port); SetDebuggerConfFromJsonFile();
grpc_client_ = std::make_unique<GrpcClient>(host, port, ssl_certificate, certificate_dir, certificate_passphrase);
} }
debug_services_ = std::make_unique<DebugServices>(); debug_services_ = std::make_unique<DebugServices>();
...@@ -787,4 +796,88 @@ std::vector<std::string> Debugger::CheckOpOverflow() { ...@@ -787,4 +796,88 @@ std::vector<std::string> Debugger::CheckOpOverflow() {
return op_names; return op_names;
} }
bool Debugger::SetDebuggerConfFromJsonFile() {
const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
if (config_path_str != nullptr) {
MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
} else {
MS_LOG(INFO) << "No need debugger config path. please export MINDSPORE_CONFIG_PATH eg: MINDSPORE_CONFIG_PATH=/etc";
ssl_certificate = false;
return false;
}
char real_path[4096] = {0};
if (nullptr == realpath(config_path_str, real_path)) {
MS_LOG(ERROR) << "Env debugger config path error, " << config_path_str;
ssl_certificate = false;
return false;
}
std::string debugger_config_file = std::string(real_path) + "/debugger_config.json";
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
if (!fs->FileExist(debugger_config_file)) {
MS_LOG(ERROR) << debugger_config_file << " not exist.";
ssl_certificate = false;
return false;
}
return ParseDebuggerConfig(debugger_config_file);
}
bool Debugger::ParseDebuggerConfig(const std::string &debugger_config_file) {
std::ifstream jsonFile(debugger_config_file);
if (!jsonFile.is_open()) {
MS_LOG(ERROR) << debugger_config_file << " open failed.";
ssl_certificate = false;
return false;
}
json j;
jsonFile >> j;
if (j.find("DebuggerSettings") == j.end()) {
MS_LOG(ERROR) << "DebuggerSettings is not exist.";
ssl_certificate = false;
return false;
} else {
json debuggerSettings = j.at("DebuggerSettings");
// convert json to string
std::stringstream ss;
ss << debuggerSettings;
std::string cfg = ss.str();
MS_LOG(INFO) << "Debugger Settings Json: " << cfg;
if (!IsConfigExist(debuggerSettings)) {
return false;
}
if (!IsConfigValid(debuggerSettings)) {
return false;
}
}
return true;
}
bool Debugger::IsConfigExist(const nlohmann::json &debuggerSettings) {
if (debuggerSettings.find("ssl_certificate") == debuggerSettings.end() ||
debuggerSettings.find("certificate_dir") == debuggerSettings.end() ||
debuggerSettings.find("certificate_passphrase") == debuggerSettings.end()) {
MS_LOG(ERROR) << "DebuggerSettings keys is not exist.";
ssl_certificate = false;
return false;
}
return true;
}
bool Debugger::IsConfigValid(const nlohmann::json &debuggerSettings) {
auto enable_secure = debuggerSettings.at("ssl_certificate");
auto certificate_dir_ = debuggerSettings.at("certificate_dir");
auto certificate_passphrase_ = debuggerSettings.at("certificate_passphrase");
if (!(enable_secure.is_boolean() && certificate_dir_.is_string() && certificate_passphrase_.is_string())) {
MS_LOG(ERROR) << "Element's type in Debugger config json is invalid.";
ssl_certificate = false;
return false;
}
ssl_certificate = enable_secure;
certificate_dir = certificate_dir_;
certificate_passphrase = certificate_passphrase_;
return true;
}
} // namespace mindspore } // namespace mindspore
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
#include "debug/debugger/grpc_client.h" #include "debug/debugger/grpc_client.h"
#include "debug/debug_services.h" #include "debug/debug_services.h"
#include <nlohmann/json.hpp>
using debugger::Chunk; using debugger::Chunk;
using debugger::DataType; using debugger::DataType;
...@@ -165,6 +166,14 @@ class Debugger : public std::enable_shared_from_this<Debugger> { ...@@ -165,6 +166,14 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
// singleton // singleton
static std::mutex instance_lock_; static std::mutex instance_lock_;
static std::shared_ptr<Debugger> debugger_; static std::shared_ptr<Debugger> debugger_;
bool ssl_certificate;
std::string certificate_dir;
std::string certificate_passphrase;
bool SetDebuggerConfFromJsonFile();
bool ParseDebuggerConfig(const std::string &dump_config_file);
bool IsConfigExist(const nlohmann::json &dumpSettings);
bool IsConfigValid(const nlohmann::json &dumpSettings);
}; };
using DebuggerPtr = std::shared_ptr<Debugger>; using DebuggerPtr = std::shared_ptr<Debugger>;
......
...@@ -13,10 +13,18 @@ ...@@ -13,10 +13,18 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "debug/debugger/grpc_client.h"
#include <stdio.h>
#include <stdlib.h>
#include <openssl/pem.h>
#include <openssl/err.h>
#include <openssl/pkcs12.h>
#include <openssl/x509.h>
#include <openssl/evp.h>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "debug/debugger/grpc_client.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
using debugger::Chunk; using debugger::Chunk;
...@@ -31,13 +39,96 @@ using debugger::WatchpointHit; ...@@ -31,13 +39,96 @@ using debugger::WatchpointHit;
#define CHUNK_SIZE 1024 * 1024 * 3 #define CHUNK_SIZE 1024 * 1024 * 3
namespace mindspore { namespace mindspore {
GrpcClient::GrpcClient(const std::string &host, const std::string &port) : stub_(nullptr) { Init(host, port); } GrpcClient::GrpcClient(const std::string &host, const std::string &port, const bool &ssl_certificate,
const std::string &certificate_dir, const std::string &certificate_passphrase)
: stub_(nullptr) {
Init(host, port, ssl_certificate, certificate_dir, certificate_passphrase);
}
void GrpcClient::Init(const std::string &host, const std::string &port) { void GrpcClient::Init(const std::string &host, const std::string &port, const bool &ssl_certificate,
const std::string &certificate_dir, const std::string &certificate_passphrase) {
std::string target_str = host + ":" + port; std::string target_str = host + ":" + port;
MS_LOG(INFO) << "GrpcClient connecting to: " << target_str; MS_LOG(INFO) << "GrpcClient connecting to: " << target_str;
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()); std::shared_ptr<grpc::Channel> channel;
if (ssl_certificate) {
FILE *fp;
EVP_PKEY *pkey = NULL;
X509 *cert = NULL;
STACK_OF(X509) *ca = NULL;
PKCS12 *p12 = NULL;
if ((fp = fopen(certificate_dir.c_str(), "rb")) == NULL) {
MS_LOG(ERROR) << "Error opening file: " << certificate_dir;
exit(EXIT_FAILURE);
}
p12 = d2i_PKCS12_fp(fp, NULL);
fclose(fp);
if (p12 == NULL) {
MS_LOG(ERROR) << "Error reading PKCS#12 file";
X509_free(cert);
EVP_PKEY_free(pkey);
sk_X509_pop_free(ca, X509_free);
exit(EXIT_FAILURE);
}
if (!PKCS12_parse(p12, certificate_passphrase.c_str(), &pkey, &cert, &ca)) {
MS_LOG(ERROR) << "Error parsing PKCS#12 file";
X509_free(cert);
EVP_PKEY_free(pkey);
sk_X509_pop_free(ca, X509_free);
exit(EXIT_FAILURE);
}
std::string strca;
std::string strcert;
std::string strkey;
if (pkey == NULL || cert == NULL || ca == NULL) {
MS_LOG(ERROR) << "Error private key or cert or CA certificate.";
X509_free(cert);
EVP_PKEY_free(pkey);
sk_X509_pop_free(ca, X509_free);
exit(EXIT_FAILURE);
} else {
ASN1_TIME *validtime = X509_getm_notAfter(cert);
if (X509_cmp_current_time(validtime) < 0) {
MS_LOG(ERROR) << "This certificate is over its valid time, please use a new certificate.";
X509_free(cert);
EVP_PKEY_free(pkey);
sk_X509_pop_free(ca, X509_free);
exit(EXIT_FAILURE);
}
int nid = X509_get_signature_nid(cert);
int keybit = EVP_PKEY_bits(pkey);
if (nid == NID_sha1) {
MS_LOG(WARNING) << "Signature algrithm is sha1, which maybe not secure enough.";
} else if (keybit < 2048) {
MS_LOG(WARNING) << "The private key bits is: " << keybit << ", which maybe not secure enough.";
}
int dwPriKeyLen = i2d_PrivateKey(pkey, NULL); // get the length of private key
unsigned char *pribuf = (unsigned char *)malloc(sizeof(unsigned char) * dwPriKeyLen);
i2d_PrivateKey(pkey, &pribuf); // PrivateKey DER code
strkey = std::string(reinterpret_cast<char const *>(pribuf), dwPriKeyLen);
int dwcertLen = i2d_X509(cert, NULL); // get the length of private key
unsigned char *certbuf = (unsigned char *)malloc(sizeof(unsigned char) * dwcertLen);
i2d_X509(cert, &certbuf); // PrivateKey DER code
strcert = std::string(reinterpret_cast<char const *>(certbuf), dwcertLen);
int dwcaLen = i2d_X509(sk_X509_value(ca, 0), NULL); // get the length of private key
unsigned char *cabuf = (unsigned char *)malloc(sizeof(unsigned char) * dwcaLen);
i2d_X509(sk_X509_value(ca, 0), &cabuf); // PrivateKey DER code
strcat = std::string(reinterpret_cast<char const *>(cabuf), dwcaLen);
free(pribuf);
free(certbuf);
free(cabuf);
}
grpc::SslCredentialsOptions opts = {strca, strkey, strcert};
channel = grpc::CreateChannel(target_str, grpc::SslCredentials(opts));
} else {
channel = grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials());
}
stub_ = EventListener::NewStub(channel); stub_ = EventListener::NewStub(channel);
} }
......
...@@ -17,9 +17,17 @@ ...@@ -17,9 +17,17 @@
#define MINDSPORE_CCSRC_DEBUG_DEBUGGER_GRPC_CLIENT_H_ #define MINDSPORE_CCSRC_DEBUG_DEBUGGER_GRPC_CLIENT_H_
#include <grpcpp/grpcpp.h> #include <grpcpp/grpcpp.h>
#include <stdio.h>
#include <stdlib.h>
#include <openssl/pem.h>
#include <openssl/err.h>
#include <openssl/pkcs12.h>
#include <string> #include <string>
#include <list> #include <list>
#include <memory> #include <memory>
#include "proto/debug_grpc.grpc.pb.h" #include "proto/debug_grpc.grpc.pb.h"
using debugger::EventListener; using debugger::EventListener;
...@@ -33,13 +41,15 @@ namespace mindspore { ...@@ -33,13 +41,15 @@ namespace mindspore {
class GrpcClient { class GrpcClient {
public: public:
// constructor // constructor
GrpcClient(const std::string &host, const std::string &port); GrpcClient(const std::string &host, const std::string &port, const bool &ssl_certificate,
const std::string &certificate_dir, const std::string &certificate_passphrase);
// deconstructor // deconstructor
~GrpcClient() = default; ~GrpcClient() = default;
// init // init
void Init(const std::string &host, const std::string &port); void Init(const std::string &host, const std::string &port, const bool &ssl_certificate,
const std::string &certificate_dir, const std::string &certificate_passphrase);
// reset // reset
void Reset(); void Reset();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册