// aes_cipher.cc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/io/crypto/aes_cipher.h"

#include <cryptopp/aes.h>
#include <cryptopp/ccm.h>
#include <cryptopp/cryptlib.h>
#include <cryptopp/filters.h>
#include <cryptopp/gcm.h>
#include <cryptopp/modes.h>
#include <cryptopp/smartptr.h>

#include <fstream>
#include <iterator>
#include <set>
#include <string>

#include "paddle/fluid/framework/io/crypto/cipher_utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

47 48
void AESCipher::Init(const std::string& cipher_name,
                     const int& iv_size,
Y
Yanghello 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
                     const int& tag_size) {
  aes_cipher_name_ = cipher_name;
  iv_size_ = iv_size;
  tag_size_ = tag_size;
  std::set<std::string> authented_cipher_set{"AES_GCM_NoPadding"};
  if (authented_cipher_set.find(cipher_name) != authented_cipher_set.end()) {
    is_authenticated_cipher_ = true;
  }
}

// Encrypts `plaintext` with one of the unauthenticated AES modes built by
// BuildCipher (ECB / CBC / CTR, chosen by aes_cipher_name_).
//
// When the mode needs an IV, a fresh one is produced by
// CipherUtils::GenKey(iv_size_), stored in iv_, and prepended to the returned
// ciphertext so that DecryptInternal can recover it. Otherwise only the raw
// ciphertext is returned.
std::string AESCipher::EncryptInternal(const std::string& plaintext,
                                       const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::SymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter> m_filter;
  bool need_iv = false;
  const unsigned char* key_char =
      reinterpret_cast<const unsigned char*>(&(key.at(0)));
  BuildCipher(true, &need_iv, &m_cipher, &m_filter);
  if (need_iv) {
    iv_ = CipherUtils::GenKey(iv_size_);
    m_cipher->SetKeyWithIV(
        key_char,
        key.size(),
        reinterpret_cast<const unsigned char*>(&(iv_.at(0))),
        iv_.size());
  } else {
    m_cipher->SetKey(key_char, key.size());
  }

  std::string ciphertext;
  // The filter takes ownership of the sink. The Redirector forwards to
  // *m_filter without owning it (m_filter stays owned by the member_ptr).
  // The temporary StringSource pumps all input (second argument true).
  m_filter->Attach(new CryptoPP::StringSink(ciphertext));
  CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter));
  return need_iv ? iv_ + ciphertext : ciphertext;
}

std::string AESCipher::DecryptInternal(const std::string& ciphertext,
                                       const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::SymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter> m_filter;
  bool need_iv = false;
  const unsigned char* key_char =
      reinterpret_cast<const unsigned char*>(&(key.at(0)));
  BuildCipher(false, &need_iv, &m_cipher, &m_filter);
  int ciphertext_beg = 0;
  if (need_iv) {
    iv_ = ciphertext.substr(0, iv_size_ / 8);
    ciphertext_beg = iv_size_ / 8;
    m_cipher.get()->SetKeyWithIV(
101 102 103 104
        key_char,
        key.size(),
        reinterpret_cast<const unsigned char*>(&(iv_.at(0))),
        iv_.size());
Y
Yanghello 已提交
105 106 107 108 109
  } else {
    m_cipher.get()->SetKey(key_char, key.size());
  }
  std::string plaintext;
  m_filter->Attach(new CryptoPP::StringSink(plaintext));
110 111
  CryptoPP::StringSource(ciphertext.substr(ciphertext_beg),
                         true,
Y
Yanghello 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
                         new CryptoPP::Redirector(*m_filter));

  return plaintext;
}

// Encrypts `plaintext` with an authenticated AES mode (AES-GCM, built by
// BuildAuthEncCipher).
//
// The authentication tag is produced by the filter and appended by it to the
// ciphertext. When an IV is required, a fresh one is generated, stored in iv_,
// and prepended to the result. iv_ itself is left holding only the IV
// (previously it was overwritten with IV + ciphertext).
std::string AESCipher::AuthenticatedEncryptInternal(
    const std::string& plaintext, const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::AuthenticatedEncryptionFilter> m_filter;
  bool need_iv = false;
  const unsigned char* key_char =
      reinterpret_cast<const unsigned char*>(&(key.at(0)));
  BuildAuthEncCipher(&need_iv, &m_cipher, &m_filter);
  if (need_iv) {
    iv_ = CipherUtils::GenKey(iv_size_);
    m_cipher->SetKeyWithIV(
        key_char,
        key.size(),
        reinterpret_cast<const unsigned char*>(&(iv_.at(0))),
        iv_.size());
  } else {
    m_cipher->SetKey(key_char, key.size());
  }

  std::string ciphertext;
  // The sink is owned by the filter; the Redirector does not own *m_filter.
  m_filter->Attach(new CryptoPP::StringSink(ciphertext));
  CryptoPP::StringSource(plaintext, true, new CryptoPP::Redirector(*m_filter));
  // Build the result without mutating iv_.
  return need_iv ? iv_ + ciphertext : ciphertext;
}

std::string AESCipher::AuthenticatedDecryptInternal(
    const std::string& ciphertext, const std::string& key) {
  CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher> m_cipher;
  CryptoPP::member_ptr<CryptoPP::AuthenticatedDecryptionFilter> m_filter;
  bool need_iv = false;
  const unsigned char* key_char =
      reinterpret_cast<const unsigned char*>(&(key.at(0)));
  BuildAuthDecCipher(&need_iv, &m_cipher, &m_filter);
  int ciphertext_beg = 0;
  if (need_iv) {
    iv_ = ciphertext.substr(0, iv_size_ / 8);
    ciphertext_beg = iv_size_ / 8;
    m_cipher.get()->SetKeyWithIV(
159 160 161 162
        key_char,
        key.size(),
        reinterpret_cast<const unsigned char*>(&(iv_.at(0))),
        iv_.size());
Y
Yanghello 已提交
163 164 165 166 167
  } else {
    m_cipher.get()->SetKey(key_char, key.size());
  }
  std::string plaintext;
  m_filter->Attach(new CryptoPP::StringSink(plaintext));
168 169
  CryptoPP::StringSource(ciphertext.substr(ciphertext_beg),
                         true,
Y
Yanghello 已提交
170 171
                         new CryptoPP::Redirector(*m_filter));
  PADDLE_ENFORCE_EQ(
172 173
      m_filter->GetLastResult(),
      true,
Y
Yanghello 已提交
174 175 176 177 178 179
      paddle::platform::errors::InvalidArgument("Integrity check failed. "
                                                "Invalid ciphertext input."));
  return plaintext;
}

void AESCipher::BuildCipher(
180 181
    bool for_encrypt,
    bool* need_iv,
Y
Yanghello 已提交
182 183 184 185 186
    CryptoPP::member_ptr<CryptoPP::SymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::StreamTransformationFilter>* m_filter) {
  if (aes_cipher_name_ == "AES_ECB_PKCSPadding" && for_encrypt) {
    m_cipher->reset(new CryptoPP::ECB_Mode<CryptoPP::AES>::Encryption);
    m_filter->reset(new CryptoPP::StreamTransformationFilter(
187 188
        *(*m_cipher).get(),
        NULL,
Y
Yanghello 已提交
189 190 191 192
        CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING));
  } else if (aes_cipher_name_ == "AES_ECB_PKCSPadding" && !for_encrypt) {
    m_cipher->reset(new CryptoPP::ECB_Mode<CryptoPP::AES>::Decryption);
    m_filter->reset(new CryptoPP::StreamTransformationFilter(
193 194
        *(*m_cipher).get(),
        NULL,
Y
Yanghello 已提交
195 196 197 198 199
        CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING));
  } else if (aes_cipher_name_ == "AES_CBC_PKCSPadding" && for_encrypt) {
    m_cipher->reset(new CryptoPP::CBC_Mode<CryptoPP::AES>::Encryption);
    *need_iv = true;
    m_filter->reset(new CryptoPP::StreamTransformationFilter(
200 201
        *(*m_cipher).get(),
        NULL,
Y
Yanghello 已提交
202 203 204 205 206
        CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING));
  } else if (aes_cipher_name_ == "AES_CBC_PKCSPadding" && !for_encrypt) {
    m_cipher->reset(new CryptoPP::CBC_Mode<CryptoPP::AES>::Decryption);
    *need_iv = true;
    m_filter->reset(new CryptoPP::StreamTransformationFilter(
207 208
        *(*m_cipher).get(),
        NULL,
Y
Yanghello 已提交
209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
        CryptoPP::BlockPaddingSchemeDef::PKCS_PADDING));
  } else if (aes_cipher_name_ == "AES_CTR_NoPadding" && for_encrypt) {
    m_cipher->reset(new CryptoPP::CTR_Mode<CryptoPP::AES>::Encryption);
    *need_iv = true;
    m_filter->reset(new CryptoPP::StreamTransformationFilter(
        *(*m_cipher).get(), NULL, CryptoPP::BlockPaddingSchemeDef::NO_PADDING));
  } else if (aes_cipher_name_ == "AES_CTR_NoPadding" && !for_encrypt) {
    m_cipher->reset(new CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption);
    *need_iv = true;
    m_filter->reset(new CryptoPP::StreamTransformationFilter(
        *(*m_cipher).get(), NULL, CryptoPP::BlockPaddingSchemeDef::NO_PADDING));
  } else {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }
}

// Creates the authenticated encryption cipher named by aes_cipher_name_
// (currently only AES-GCM) and an AuthenticatedEncryptionFilter around it.
//
// Sets *need_iv to true. The tag length passed to the filter is tag_size_ / 8
// bytes. Throws Unimplemented for any other cipher name.
void AESCipher::BuildAuthEncCipher(
    bool* need_iv,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedEncryptionFilter>* m_filter) {
  if (aes_cipher_name_ != "AES_GCM_NoPadding") {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }
  m_cipher->reset(new CryptoPP::GCM<CryptoPP::AES>::Encryption);
  *need_iv = true;
  // The sink is attached later by the caller, hence the null attachment.
  m_filter->reset(new CryptoPP::AuthenticatedEncryptionFilter(
      *m_cipher->get(),
      nullptr,
      /*putAAD=*/false,
      /*truncatedDigestSize=*/tag_size_ / 8,
      CryptoPP::DEFAULT_CHANNEL,
      CryptoPP::BlockPaddingSchemeDef::NO_PADDING));
}

// Creates the authenticated decryption cipher named by aes_cipher_name_
// (currently only AES-GCM) and an AuthenticatedDecryptionFilter around it.
//
// Sets *need_iv to true. The filter uses DEFAULT_FLAGS and expects a tag of
// tag_size_ / 8 bytes. Throws Unimplemented for any other cipher name.
void AESCipher::BuildAuthDecCipher(
    bool* need_iv,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedSymmetricCipher>* m_cipher,
    CryptoPP::member_ptr<CryptoPP::AuthenticatedDecryptionFilter>* m_filter) {
  if (aes_cipher_name_ != "AES_GCM_NoPadding") {
    PADDLE_THROW(paddle::platform::errors::Unimplemented(
        "Create cipher error. "
        "Cipher name %s is error, or has not been implemented.",
        aes_cipher_name_));
  }
  m_cipher->reset(new CryptoPP::GCM<CryptoPP::AES>::Decryption);
  *need_iv = true;
  // The sink is attached later by the caller, hence the null attachment.
  m_filter->reset(new CryptoPP::AuthenticatedDecryptionFilter(
      *m_cipher->get(),
      nullptr,
      /*flags=*/CryptoPP::AuthenticatedDecryptionFilter::DEFAULT_FLAGS,
      /*truncatedDigestSize=*/tag_size_ / 8,
      CryptoPP::BlockPaddingSchemeDef::NO_PADDING));
}

// Encrypts `plaintext` with `key`. Uses the authenticated (GCM) path when
// Init() flagged the configured cipher as authenticated, and the plain
// ECB/CBC/CTR path otherwise.
std::string AESCipher::Encrypt(const std::string& plaintext,
                               const std::string& key) {
  if (!is_authenticated_cipher_) {
    return EncryptInternal(plaintext, key);
  }
  return AuthenticatedEncryptInternal(plaintext, key);
}

// Decrypts `ciphertext` with `key`. Uses the authenticated (GCM) path when
// Init() flagged the configured cipher as authenticated, and the plain
// ECB/CBC/CTR path otherwise.
std::string AESCipher::Decrypt(const std::string& ciphertext,
                               const std::string& key) {
  if (is_authenticated_cipher_) {
    return AuthenticatedDecryptInternal(ciphertext, key);
  }
  return DecryptInternal(ciphertext, key);
}

void AESCipher::EncryptToFile(const std::string& plaintext,
                              const std::string& key,
                              const std::string& filename) {
287 288 289
  std::ofstream fout(filename, std::ios::binary);
  std::string ciphertext = this->Encrypt(plaintext, key);
  fout.write(ciphertext.data(), ciphertext.size());
Y
Yanghello 已提交
290 291 292 293 294
  fout.close();
}

std::string AESCipher::DecryptFromFile(const std::string& key,
                                       const std::string& filename) {
295
  std::ifstream fin(filename, std::ios::binary);
Y
Yanghello 已提交
296 297 298 299 300 301 302 303
  std::string ciphertext{std::istreambuf_iterator<char>(fin),
                         std::istreambuf_iterator<char>()};
  fin.close();
  return Decrypt(ciphertext, key);
}

}  // namespace framework
}  // namespace paddle