YugabyteDB (2.13.1.0-b60, 21121d69985fbf76aa6958d8f04a9bfa936293b5)

Coverage Report

Created: 2022-03-22 16:43

/Users/deen/code/yugabyte-db/src/yb/rpc/secure_stream.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) YugaByte, Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4
// in compliance with the License.  You may obtain a copy of the License at
5
//
6
// http://www.apache.org/licenses/LICENSE-2.0
7
//
8
// Unless required by applicable law or agreed to in writing, software distributed under the License
9
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10
// or implied.  See the License for the specific language governing permissions and limitations
11
// under the License.
12
//
13
14
#include "yb/rpc/secure_stream.h"
15
16
#include <openssl/err.h>
17
#include <openssl/ssl.h>
18
#include <openssl/x509v3.h>
19
20
#include <boost/tokenizer.hpp>
21
22
#include "yb/encryption/encryption_util.h"
23
24
#include "yb/gutil/casts.h"
25
26
#include "yb/rpc/outbound_data.h"
27
#include "yb/rpc/refined_stream.h"
28
29
#include "yb/util/enums.h"
30
#include "yb/util/errno.h"
31
#include "yb/util/logging.h"
32
#include "yb/util/scope_exit.h"
33
#include "yb/util/status_format.h"
34
35
using namespace std::literals;
36
37
DEFINE_bool(allow_insecure_connections, true, "Whether we should allow insecure connections.");
DEFINE_bool(dump_certificate_entries, false, "Whether we should dump certificate entries.");
DEFINE_bool(verify_client_endpoint, false, "Whether client endpoint should be verified.");
DEFINE_bool(verify_server_endpoint, true, "Whether server endpoint should be verified.");

// Tokenized by ProtocolsOption(); every name must be a key of SSLProtocolMap(), which also
// accepts "tls13".
DEFINE_string(ssl_protocols, "",
              "List of allowed SSL protocols (ssl2, ssl3, tls10, tls11, tls12, tls13). "
                  "Empty to allow TLS only.");

// Passed verbatim to SSL_CTX_set_cipher_list() by the SecureContext constructor.
DEFINE_string(cipher_list, "",
              "Define the list of available ciphers (TLSv1.2 and below).");

// Passed verbatim to SSL_CTX_set_ciphersuites() by the SecureContext constructor.
DEFINE_string(ciphersuites, "",
              "Define the available TLSv1.3 ciphersuites.");
50
51
// Defines the call operator of the `<name>Free` deleter functor, which releases an OpenSSL object
// with the matching `<name>_free` function. The functor itself is presumably declared by
// YB_RPC_SSL_TYPE_DECLARE (see secure_stream.h).
#define YB_RPC_SSL_TYPE_DEFINE(name) \
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { BOOST_PP_CAT(name, _free)(value); }
55
56
namespace yb {
57
namespace rpc {
58
59
namespace {
60
61
// Deleter helpers for OpenSSL types used in this file (BIOPtr comes from the BIO pair below).
YB_RPC_SSL_TYPE_DECLARE(BIO);
YB_RPC_SSL_TYPE_DEFINE(BIO)
YB_RPC_SSL_TYPE_DECLARE(SSL);

// Session id context installed on every SSL_CTX by the SecureContext constructor.
constexpr unsigned char kContextId[] = {'Y', 'u', 'g', 'a', 'B', 'y', 't', 'e'};
66
67
361
std::string SSLErrorMessage(uint64_t error) {
68
361
  auto message = ERR_reason_error_string(error);
69
361
  return message ? message : 
"no error"0
;
70
361
}
71
72
// Declares and defines the deleter helpers for an OpenSSL type in one step.
#define YB_RPC_SSL_TYPE(name) \
    YB_RPC_SSL_TYPE_DECLARE(name) \
    YB_RPC_SSL_TYPE_DEFINE(name)

// Builds a Status of `type` whose `format` has a single `$0` placeholder. The placeholder is
// filled with the reason of the earliest error popped from the OpenSSL error queue.
#define SSL_STATUS(type, format) \
    STATUS_FORMAT(type, format, SSLErrorMessage(ERR_get_error()))
75
76
366
// Wraps `data` into a read-only memory BIO. Per the OpenSSL docs, the memory BIO reads `data` in
// place, so `data` must outlive the returned BIO.
Result<BIOPtr> BIOFromSlice(const Slice& data) {
  const int size = narrow_cast<int>(data.size());
  BIOPtr result(BIO_new_mem_buf(data.data(), size));
  if (result) {
    return std::move(result);
  }
  return SSL_STATUS(IOError, "Create BIO failed: $0");
}
83
84
183
// Parses the first PEM encoded certificate contained in `data`.
Result<detail::X509Ptr> X509FromSlice(const Slice& data) {
  // Start from an empty error queue so that SSL_STATUS reports the error raised here.
  ERR_clear_error();

  auto pem_bio = VERIFY_RESULT(BIOFromSlice(data));
  auto* raw_cert = PEM_read_bio_X509(pem_bio.get(), nullptr, nullptr, nullptr);
  if (raw_cert == nullptr) {
    return SSL_STATUS(IOError, "Read cert failed: $0");
  }
  return detail::X509Ptr(raw_cert);
}
96
97
// Deleter helpers for the OpenSSL types used by the certificate generation code below.
// Each expansion ends with a complete function definition, so no trailing ';' is needed.
YB_RPC_SSL_TYPE(ASN1_INTEGER)
YB_RPC_SSL_TYPE(RSA)
YB_RPC_SSL_TYPE(X509_NAME)
100
101
26
// Generates a new RSA key of `bits` bits (public exponent 65537) wrapped into an EVP_PKEY.
Result<detail::EVP_PKEYPtr> GeneratePrivateKey(int bits) {
  RSAPtr rsa(RSA_generate_key(bits, 65537, nullptr, nullptr));
  if (!rsa) {
    return SSL_STATUS(InvalidArgument, "Failed to generate private key: $0");
  }

  detail::EVP_PKEYPtr pkey(EVP_PKEY_new());
  if (!pkey) {
    return SSL_STATUS(IOError, "Failed to allocate private key: $0");
  }

  // EVP_PKEY_assign_RSA() takes ownership of the RSA object only when it succeeds. Keep `rsa`
  // owning it until then, so that it is still freed on the failure path.
  if (EVP_PKEY_assign_RSA(pkey.get(), rsa.get()) != 1) {
    return SSL_STATUS(InvalidArgument, "Failed to assign private key: $0");
  }
  // Ownership now belongs to `pkey`.
  static_cast<void>(rsa.release());

  return std::move(pkey);
}
115
116
class ExtensionConfigurator {
117
 public:
118
13
  explicit ExtensionConfigurator(X509 *cert) : cert_(cert) {
119
    // No configuration database
120
13
    X509V3_set_ctx_nodb(&ctx_);
121
    // Both issuer and subject certs
122
13
    X509V3_set_ctx(&ctx_, cert, cert, nullptr, nullptr, 0);
123
13
  }
124
125
13
  CHECKED_STATUS Add(int nid, const char* value) {
126
13
    X509_EXTENSION *ex = X509V3_EXT_conf_nid(nullptr, &ctx_, nid, value);
127
13
    if (!ex) {
128
0
      return SSL_STATUS(InvalidArgument, "Failed to create extension: $0");
129
0
    }
130
131
13
    X509_add_ext(cert_, ex, -1);
132
13
    X509_EXTENSION_free(ex);
133
134
13
    return Status::OK();
135
13
  }
136
137
 private:
138
  X509V3_CTX ctx_;
139
  X509* cert_;
140
};
141
142
// Builds an X509v3 certificate for `key`. It has subject CN=`common_name` and serial number 0,
// is valid from now for one year, and is signed by `ca_pkey` with SHA-256.
// When `ca_cert` is null, the certificate is self-issued (issuer == subject) and gets a critical
// "CA:TRUE" basic constraints extension. Otherwise the issuer is the subject of `ca_cert` and no
// extensions are added.
Result<detail::X509Ptr> CreateCertificate(
    EVP_PKEY* key, const std::string& common_name, EVP_PKEY* ca_pkey, X509* ca_cert) {
  detail::X509Ptr cert(X509_new());
  if (!cert) {
    return SSL_STATUS(IOError, "Failed to create new certificate: $0");
  }

  // The version field is zero based: 2 means X509v3, which extensions require.
  if (X509_set_version(cert.get(), 2) != 1) {
    return SSL_STATUS(IOError, "Failed to set certificate version: $0");
  }

  ASN1_INTEGERPtr aserial(ASN1_INTEGER_new());
  if (!aserial || ASN1_INTEGER_set(aserial.get(), 0) != 1) {
    return SSL_STATUS(IOError, "Failed to create serial number: $0");
  }
  if (!X509_set_serialNumber(cert.get(), aserial.get())) {
    return SSL_STATUS(IOError, "Failed to set serial number: $0");
  }

  X509_NAMEPtr name(X509_NAME_new());
  if (!name) {
    return SSL_STATUS(IOError, "Failed to create subject: $0");
  }
  auto bytes = pointer_cast<const unsigned char*>(common_name.c_str());
  if (!X509_NAME_add_entry_by_txt(
      name.get(), "CN", MBSTRING_ASC, bytes, narrow_cast<int>(common_name.length()), -1, 0)) {
    return SSL_STATUS(IOError, "Failed to create subject: $0");
  }

  if (X509_set_subject_name(cert.get(), name.get()) != 1) {
    return SSL_STATUS(IOError, "Failed to set subject: $0");
  }

  X509_NAME* issuer = name.get();
  if (ca_cert) {
    issuer = X509_get_subject_name(ca_cert);
    if (!issuer) {
      return SSL_STATUS(IOError, "Failed to get CA subject name: $0");
    }
  } else {
    ExtensionConfigurator configurator(cert.get());
    RETURN_NOT_OK(configurator.Add(NID_basic_constraints, "critical,CA:TRUE"));
  }

  if (X509_set_issuer_name(cert.get(), issuer) != 1) {
    return SSL_STATUS(IOError, "Failed to set issuer: $0");
  }

  if (X509_set_pubkey(cert.get(), key) != 1) {
    return SSL_STATUS(IOError, "Failed to set public key: $0");
  }

  if (!X509_gmtime_adj(X509_get_notBefore(cert.get()), 0)) {
    return SSL_STATUS(IOError, "Failed to set not before: $0");
  }

  const auto k1Year = 365 * 24h;
  auto seconds = std::chrono::duration_cast<std::chrono::seconds>(k1Year).count();
  if (!X509_gmtime_adj(X509_get_notAfter(cert.get()), seconds)) {
    return SSL_STATUS(IOError, "Failed to set not after: $0");
  }

  if (!X509_sign(cert.get(), ca_pkey, EVP_sha256())) {
    return SSL_STATUS(IOError, "Sign failed: $0");
  }

  return std::move(cert);
}
210
211
9
const std::unordered_map<std::string, int64_t>& SSLProtocolMap() {
212
9
  static const std::unordered_map<std::string, int64_t> result = {
213
9
      {"ssl2", SSL_OP_NO_SSLv2},
214
9
      {"ssl3", SSL_OP_NO_SSLv3},
215
9
      {"tls10", SSL_OP_NO_TLSv1},
216
9
      {"tls11", SSL_OP_NO_TLSv1_1},
217
9
      {"tls12", SSL_OP_NO_TLSv1_2},
218
9
      {"tls13", SSL_OP_NO_TLSv1_3},
219
9
  };
220
9
  return result;
221
9
}
222
223
211
int64_t ProtocolsOption() {
224
211
  constexpr int64_t kDefaultProtocols = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
225
226
211
  const std::string& ssl_protocols = FLAGS_ssl_protocols;
227
211
  if (ssl_protocols.empty()) {
228
202
    return kDefaultProtocols;
229
202
  }
230
231
9
  const auto& protocol_map = SSLProtocolMap();
232
9
  int64_t result = SSL_OP_NO_SSL_MASK;
233
9
  boost::tokenizer<> tokenizer(ssl_protocols);
234
9
  for (const auto& protocol : tokenizer) {
235
9
    auto it = protocol_map.find(protocol);
236
9
    if (it == protocol_map.end()) {
237
0
      LOG(DFATAL) << "Unknown SSL protocol: " << protocol;
238
0
      return kDefaultProtocols;
239
0
    }
240
9
    result &= ~it->second;
241
9
  }
242
243
9
  return result;
244
9
}
245
246
} // namespace
247
248
namespace detail {
249
250
YB_RPC_SSL_TYPE_DEFINE(EVP_PKEY)
251
YB_RPC_SSL_TYPE_DEFINE(SSL)
252
YB_RPC_SSL_TYPE_DEFINE(SSL_CTX)
253
YB_RPC_SSL_TYPE_DEFINE(X509)
254
255
}
256
257
211
// Creates the SSL_CTX shared by all connections of this context and configures it from flags:
// - allowed protocols (--ssl_protocols) with compression disabled;
// - optional cipher list (--cipher_list) and TLSv1.3 ciphersuites (--ciphersuites);
// - the session id context.
// Invalid cipher settings are reported through LOG(DFATAL).
SecureContext::SecureContext() {
  encryption::InitOpenSSL();

  // Failures below are described with ERR_get_error(), so drop any stale queued errors first.
  ERR_clear_error();

  context_.reset(SSL_CTX_new(SSLv23_method()));
  // Every statement below dereferences the context. Fail loudly in all build types instead of
  // crashing later on a null pointer.
  CHECK(context_) << "Failed to create SSL context: " << SSLErrorMessage(ERR_get_error());

  int64_t protocols = ProtocolsOption();
  VLOG(1) << "Protocols option: " << protocols;
  SSL_CTX_set_options(context_.get(), protocols | SSL_OP_NO_COMPRESSION);

  // Copy the flag values so that they stay stable while being used.
  const std::string cipher_list = FLAGS_cipher_list;
  if (!cipher_list.empty()) {
    LOG(INFO) << "Use cipher list: " << cipher_list;
    auto res = SSL_CTX_set_cipher_list(context_.get(), cipher_list.c_str());
    LOG_IF(DFATAL, res != 1) << "Failed to set cipher list: "
                             << SSLErrorMessage(ERR_get_error());
  }

  const std::string ciphersuites = FLAGS_ciphersuites;
  if (!ciphersuites.empty()) {
    LOG(INFO) << "Use cipher suites: " << ciphersuites;
    auto res = SSL_CTX_set_ciphersuites(context_.get(), ciphersuites.c_str());
    LOG_IF(DFATAL, res != 1) << "Failed to set ciphersuites: "
                           << SSLErrorMessage(ERR_get_error());
  }

  auto res = SSL_CTX_set_session_id_context(context_.get(), kContextId, sizeof(kContextId));
  LOG_IF(DFATAL, res != 1) << "Failed to set session id for SSL context: "
                           << SSLErrorMessage(ERR_get_error());
}
287
288
5.49k
// Creates a new SSL object that inherits this context's settings. The result is null if
// SSL_new() fails.
detail::SSLPtr SecureContext::Create() const {
  SSL* ssl = SSL_new(context_.get());
  return detail::SSLPtr(ssl);
}
291
292
198
// Loads the trusted CA certificate(s) stored in the PEM file `file` into the context's
// certificate store.
Status SecureContext::AddCertificateAuthorityFile(const std::string& file) {
  // Start from an empty error queue so that reported reasons belong to this call.
  ERR_clear_error();

  X509_STORE* store = SSL_CTX_get_cert_store(context_.get());
  if (!store) {
    return SSL_STATUS(IllegalState, "Failed to get store: $0");
  }

  if (X509_STORE_load_locations(store, file.c_str(), nullptr) != 1) {
    return STATUS_FORMAT(
        InvalidArgument, "Failed to add certificate file: $0, path: $1",
        SSLErrorMessage(ERR_get_error()), file);
  }

  return Status::OK();
}
306
307
0
// Parses a PEM encoded CA certificate from `data` and adds it to the trusted store.
Status SecureContext::AddCertificateAuthority(const Slice& data) {
  auto ca_cert = VERIFY_RESULT(X509FromSlice(data));
  return AddCertificateAuthority(ca_cert.get());
}
310
311
13
// Adds `cert` to the context's certificate store as a trusted certificate authority.
Status SecureContext::AddCertificateAuthority(X509* cert) {
  // Start from an empty error queue so that SSL_STATUS reports the error raised here.
  ERR_clear_error();

  X509_STORE* store = SSL_CTX_get_cert_store(context_.get());
  if (!store) {
    return SSL_STATUS(IllegalState, "Failed to get store: $0");
  }

  if (X509_STORE_add_cert(store, cert) != 1) {
    return SSL_STATUS(InvalidArgument, "Failed to add certificate: $0");
  }

  return Status::OK();
}
324
325
13
// Test-only setup:
// 1. Generates a throwaway self-signed CA ("YugaByte").
// 2. Generates a key and certificate for `common_name`, signed by that CA.
// 3. Trusts the CA and installs the key/certificate as this context's identity.
Status SecureContext::TEST_GenerateKeys(int bits, const std::string& common_name) {
  auto ca_private_key = VERIFY_RESULT(GeneratePrivateKey(bits));
  auto ca_certificate = VERIFY_RESULT(CreateCertificate(
      ca_private_key.get(), "YugaByte", ca_private_key.get(), nullptr));
  auto own_key = VERIFY_RESULT(GeneratePrivateKey(bits));
  auto own_certificate = VERIFY_RESULT(CreateCertificate(
      own_key.get(), common_name, ca_private_key.get(), ca_certificate.get()));

  RETURN_NOT_OK(AddCertificateAuthority(ca_certificate.get()));
  pkey_ = std::move(own_key);
  certificate_ = std::move(own_certificate);

  return Status::OK();
}
337
338
183
// Parses the PEM encoded private key in `slice` and makes it this context's private key.
Status SecureContext::UsePrivateKey(const Slice& slice) {
  ERR_clear_error();

  auto pem_bio = VERIFY_RESULT(BIOFromSlice(slice));
  detail::EVP_PKEYPtr loaded(PEM_read_bio_PrivateKey(pem_bio.get(), nullptr, nullptr, nullptr));
  if (!loaded) {
    return SSL_STATUS(IOError, "Failed to read private key: $0");
  }

  pkey_ = std::move(loaded);
  return Status::OK();
}
351
352
183
// Parses the PEM encoded certificate in `data` and makes it this context's certificate.
Status SecureContext::UseCertificate(const Slice& data) {
  ERR_clear_error();
  auto parsed = VERIFY_RESULT(X509FromSlice(data));
  certificate_ = std::move(parsed);
  return Status::OK();
}
359
360
namespace {
361
362
class SecureRefiner : public StreamRefiner {
363
 public:
364
  SecureRefiner(const SecureContext& context, const StreamCreateData& data)
365
7.51k
    : secure_context_(context), remote_hostname_(data.remote_hostname) {
366
7.51k
  }
367
368
 private:
369
7.54k
  void Start(RefinedStream* stream) override {
370
7.54k
    stream_ = stream;
371
7.54k
  }
372
373
  CHECKED_STATUS Handshake() override;
374
  CHECKED_STATUS Init();
375
376
  CHECKED_STATUS Send(OutboundDataPtr data) override;
377
  CHECKED_STATUS ProcessHeader() override;
378
  Result<ReadBufferFull> Read(StreamReadBuffer* out) override;
379
380
506
  std::string ToString() const override {
381
506
    return "SECURE";
382
506
  }
383
384
0
  const Protocol* GetProtocol() override {
385
0
    return SecureStreamProtocol();
386
0
  }
387
388
  static int VerifyCallback(int preverified, X509_STORE_CTX* store_context);
389
  CHECKED_STATUS Verify(bool preverified, X509_STORE_CTX* store_context);
390
  bool MatchEndpoint(X509* cert, GENERAL_NAMES* gens);
391
  bool MatchUid(X509* cert, GENERAL_NAMES* gens);
392
  bool MatchUidEntry(const Slice& value, const char* name);
393
  Result<bool> WriteEncrypted(OutboundDataPtr data);
394
  void DecryptReceived();
395
396
5.05k
  CHECKED_STATUS Established(RefinedStreamState state) {
397
5.05k
    
VLOG_WITH_PREFIX0
(4) << "Established with state: " << state << ", used cipher: "
398
0
                        << SSL_get_cipher_name(ssl_.get());
399
400
5.05k
    return stream_->Established(state);
401
5.05k
  }
402
403
62
  const std::string& LogPrefix() const {
404
62
    return stream_->LogPrefix();
405
62
  }
406
407
  const SecureContext& secure_context_;
408
  const std::string remote_hostname_;
409
  RefinedStream* stream_ = nullptr;
410
  std::vector<std::string> certificate_entries_;
411
412
  BIOPtr bio_;
413
  detail::SSLPtr ssl_;
414
  Status verification_status_;
415
};
416
417
140k
// Encrypts the buffers serialized from `data` with SSL_write() and hands the produced TLS records
// to the lower stream. The final chunk of ciphertext is sent together with `data` itself.
Status SecureRefiner::Send(OutboundDataPtr data) {
  boost::container::small_vector<RefCntBuffer, 10> queue;
  data->Serialize(&queue);
  for (const auto& buf : queue) {
    // SSL_write() must not be called with num == 0, and an empty buffer has nothing to encrypt.
    if (buf.size() == 0) {
      continue;
    }
    Slice slice(buf.data(), buf.size());
    for (;;) {
      int slice_size = narrow_cast<int>(slice.size());
      // SSL_get_error() is only reliable when the error queue is empty before the I/O call.
      ERR_clear_error();
      auto len = SSL_write(ssl_.get(), slice.data(), slice_size);
      if (len == slice_size) {
        break;
      }
      auto error = len <= 0 ? SSL_get_error(ssl_.get(), len) : SSL_ERROR_NONE;
      VLOG_WITH_PREFIX(4) << "SSL_write was not full: " << slice.size() << ", written: " << len
                          << ", error: " << error;
      if (error != SSL_ERROR_NONE) {
        // On SSL_ERROR_WANT_WRITE, flush the pending ciphertext and retry the same bytes.
        if (error != SSL_ERROR_WANT_WRITE || !VERIFY_RESULT(WriteEncrypted(nullptr))) {
          // `error` is an SSL_ERROR_* value, not a packed OpenSSL error code, so the reason text
          // has to come from the error queue.
          return STATUS_FORMAT(
              NetworkError, "SSL write failed: $0 ($1)", SSLErrorMessage(ERR_get_error()), error);
        }
      } else {
        // Partial write: flush what was encrypted so far, then continue with the remainder.
        RETURN_NOT_OK(WriteEncrypted(nullptr));
      }
      if (len > 0) {
        slice.remove_prefix(len);
      }
    }
  }
  return ResultToStatus(WriteEncrypted(std::move(data)));
}
446
447
178k
// Reads all ciphertext pending in bio_ into one buffer and sends it to the lower stream, with
// `data` (may be null) attached. Outcomes:
// - nothing pending and `data` is null: returns false;
// - nothing pending and `data` is set: returns a NetworkError;
// - otherwise: returns true once the buffer has been handed to the lower stream.
Result<bool> SecureRefiner::WriteEncrypted(OutboundDataPtr data) {
  const auto pending = BIO_ctrl_pending(bio_.get());
  if (pending == 0) {
    if (data) {
      return STATUS(NetworkError, "No pending data during write");
    }
    return Result<bool>(false);
  }

  RefCntBuffer encrypted(pending);
  const int expected = narrow_cast<int>(encrypted.size());
  const auto len = BIO_read(bio_.get(), encrypted.data(), expected);
  LOG_IF_WITH_PREFIX(DFATAL, len != expected)
      << "BIO_read was not full: " << encrypted.size() << ", read: " << len;
  VLOG_WITH_PREFIX(4) << "Write encrypted: " << len << ", " << AsString(data);

  auto outbound = std::make_shared<SingleBufferOutboundData>(encrypted, std::move(data));
  RETURN_NOT_OK(stream_->SendToLower(std::move(outbound)));
  return true;
}
462
463
3.19k
// Uses the first two received bytes to decide whether the peer speaks TLS.
// - A TLS handshake record header (0x16 0x03) starts the handshake.
// - Anything else is rejected, unless --allow_insecure_connections is set, in which case the
//   stream continues without TLS.
// Returns OK and does nothing while the first iovec holds fewer than two bytes.
Status SecureRefiner::ProcessHeader() {
  auto vecs = stream_->ReadBuffer().AppendedVecs();
  if (vecs.empty() || vecs[0].iov_len < 2) {
    return Status::OK();
  }

  const auto* header = static_cast<const uint8_t*>(vecs[0].iov_base);
  constexpr uint8_t kHandshakeRecordType = 0x16;
  constexpr uint8_t kTlsMajorVersion = 0x03;
  const bool is_tls_handshake =
      header[0] == kHandshakeRecordType && header[1] == kTlsMajorVersion;
  if (is_tls_handshake) {
    RETURN_NOT_OK(Init());
    return stream_->StartHandshake();
  }

  if (FLAGS_allow_insecure_connections) {
    return Established(RefinedStreamState::kDisabled);
  }
  return STATUS_FORMAT(NetworkError, "Insecure connection header: $0",
                       Slice(header, 2).ToDebugHexString());
}
482
483
// Tries to do SSL_read up to num bytes from buf. Possible results:
484
// > 0 - number of bytes actually read.
485
// = 0 - in case of SSL_ERROR_WANT_READ.
486
// Status with network error - in case of other errors.
487
178k
Result<ReadBufferFull> SecureRefiner::Read(StreamReadBuffer* out) {
488
178k
  DecryptReceived();
489
178k
  auto total = 0;
490
178k
  auto iovecs = VERIFY_RESULT(out->PrepareAppend());
491
0
  auto iov_it = iovecs.begin();
492
361k
  for (;;) {
493
361k
    auto len = SSL_read(ssl_.get(), iov_it->iov_base, narrow_cast<int>(iov_it->iov_len));
494
495
361k
    if (len <= 0) {
496
178k
      auto error = SSL_get_error(ssl_.get(), len);
497
178k
      if (error == SSL_ERROR_WANT_READ) {
498
18.4E
        VLOG_WITH_PREFIX(4) << "Read decrypted: SSL_ERROR_WANT_READ";
499
178k
        break;
500
178k
      }
501
471
      auto status = STATUS_FORMAT(
502
471
          NetworkError, "SSL read failed: $0 ($1)", SSLErrorMessage(error), error);
503
471
      LOG_WITH_PREFIX(INFO) << status;
504
471
      return status;
505
178k
    }
506
507
182k
    
VLOG_WITH_PREFIX85
(4) << "Read decrypted: " << len85
;
508
182k
    total += len;
509
182k
    IoVecRemovePrefix(len, &*iov_it);
510
182k
    if (iov_it->iov_len == 0) {
511
730
      if (++iov_it == iovecs.end()) {
512
100
        break;
513
100
      }
514
730
    }
515
182k
  }
516
178k
  out->DataAppended(total);
517
178k
  return ReadBufferFull(out->Full());
518
178k
}
519
520
189k
void SecureRefiner::DecryptReceived() {
521
189k
  auto& inp = stream_->ReadBuffer();
522
189k
  if (inp.Empty()) {
523
8.06k
    return;
524
8.06k
  }
525
181k
  size_t total = 0;
526
181k
  for (const auto& iov : inp.AppendedVecs()) {
527
181k
    auto res = BIO_write(bio_.get(), iov.iov_base, narrow_cast<int>(iov.iov_len));
528
181k
    
VLOG_WITH_PREFIX327
(4) << "Decrypted: " << res << " of " << iov.iov_len327
;
529
181k
    if (res <= 0) {
530
0
      break;
531
0
    }
532
181k
    total += res;
533
181k
    if (implicit_cast<size_t>(res) < iov.iov_len) {
534
35.8k
      break;
535
35.8k
    }
536
181k
  }
537
181k
  inp.Consume(total, {});
538
181k
}
539
540
11.1k
// Drives the TLS handshake: SSL_connect() on the client side, SSL_accept() on the server side.
// Uses whatever ciphertext has been received so far.
// Handshake output is forwarded to the lower stream as soon as OpenSSL produces it.
// Returns OK when the stream is connected or more input is needed. Returns a NetworkError when
// the handshake fails; the message includes a recorded certificate verification error, if any.
Status SecureRefiner::Handshake() {
  RETURN_NOT_OK(Init());

  DecryptReceived();

  for (;;) {
    if (stream_->IsConnected()) {
      return Status::OK();
    }

    auto pending_before = BIO_ctrl_pending(bio_.get());
    ERR_clear_error();
    int result = stream_->local_side() == LocalSide::kClient
        ? SSL_connect(ssl_.get()) : SSL_accept(ssl_.get());
    int ssl_error = SSL_get_error(ssl_.get(), result);
    // Keep the packed error code unsigned and full width. Narrowing it to int and widening it
    // again for SSLErrorMessage() could sign-extend it into a different code.
    auto err_code = ERR_get_error();
    auto pending_after = BIO_ctrl_pending(bio_.get());

    if (ssl_error == SSL_ERROR_SSL || ssl_error == SSL_ERROR_SYSCALL) {
      std::string message;
      if (!verification_status_.ok()) {
        message = verification_status_.ToString();
      } else if (ssl_error == SSL_ERROR_SSL || err_code != 0) {
        message = SSLErrorMessage(err_code);
      } else {
        // SSL_ERROR_SYSCALL with an empty error queue. The transport is an in-memory BIO pair,
        // so there is no errno to report, and an OpenSSL error code is not an errno value.
        message = "I/O error (SSL_ERROR_SYSCALL)";
      }
      std::string message_suffix;
      if (FLAGS_dump_certificate_entries) {
        message_suffix = Format(", certificate entries: $0", certificate_entries_);
      }
      return STATUS_FORMAT(NetworkError, "Handshake failed: $0, address: $1, hostname: $2$3",
                           message, stream_->Remote().address(), remote_hostname_, message_suffix);
    }

    if (ssl_error == SSL_ERROR_WANT_WRITE || pending_after > pending_before) {
      // SSL expects that we would write to underlying transport.
      RefCntBuffer buffer(pending_after);
      int len = BIO_read(bio_.get(), buffer.data(), narrow_cast<int>(buffer.size()));
      DCHECK_EQ(len, pending_after);
      RETURN_NOT_OK(stream_->SendToLower(
          std::make_shared<SingleBufferOutboundData>(buffer, nullptr)));
      // If SSL_connect/SSL_accept returned a positive result, the TLS connection was
      // successfully established. We just have to send the last portion of data.
      if (result > 0) {
        RETURN_NOT_OK(Established(RefinedStreamState::kEnabled));
      }
    } else if (ssl_error == SSL_ERROR_WANT_READ) {
      // SSL expects that we would read from underlying transport.
      return Status::OK();
    } else if (SSL_get_shutdown(ssl_.get()) & SSL_RECEIVED_SHUTDOWN) {
      return STATUS(Aborted, "Handshake aborted");
    } else {
      return Established(RefinedStreamState::kEnabled);
    }
  }
}
592
593
14.1k
// Lazily creates and configures the SSL object for this connection:
// - enables partial writes, moving write buffers and buffer release;
// - on the server, or when the context uses a client certificate, installs the context's
//   private key and certificate;
// - attaches a BIO pair whose second end becomes bio_;
// - enables peer verification through VerifyCallback.
// Once initialization has succeeded, later calls do nothing.
Status SecureRefiner::Init() {
  if (ssl_) {
    return Status::OK();
  }

  // Failures below are reported through SSL_STATUS, which reads the error queue.
  ERR_clear_error();

  // Configure a local handle and publish it into ssl_ only after every step succeeded. A failed
  // Init() is then retried, instead of leaving a half-initialized SSL object that later calls
  // would treat as ready.
  auto ssl = secure_context_.Create();
  if (!ssl) {
    return SSL_STATUS(IOError, "Failed to create SSL object: $0");
  }
  SSL_set_mode(ssl.get(), SSL_MODE_ENABLE_PARTIAL_WRITE);
  SSL_set_mode(ssl.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
  SSL_set_mode(ssl.get(), SSL_MODE_RELEASE_BUFFERS);
  // Lets VerifyCallback find this refiner from the SSL object.
  SSL_set_app_data(ssl.get(), this);

  if (stream_->local_side() == LocalSide::kServer || secure_context_.use_client_certificate()) {
    auto res = SSL_use_PrivateKey(ssl.get(), secure_context_.private_key());
    if (res != 1) {
      return SSL_STATUS(InvalidArgument, "Failed to use private key: $0");
    }
    res = SSL_use_certificate(ssl.get(), secure_context_.certificate());
    if (res != 1) {
      return SSL_STATUS(InvalidArgument, "Failed to use certificate: $0");
    }
  }

  BIO* int_bio = nullptr;
  BIO* temp_bio = nullptr;
  if (BIO_new_bio_pair(&int_bio, 0, &temp_bio, 0) != 1) {
    return SSL_STATUS(IOError, "Failed to create BIO pair: $0");
  }
  // The SSL object takes int_bio for both reading and writing; we keep the other end.
  SSL_set_bio(ssl.get(), int_bio, int_bio);
  bio_.reset(temp_bio);

  int verify_mode = SSL_VERIFY_PEER;
  if (secure_context_.require_client_certificate()) {
    verify_mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
  }
  SSL_set_verify(ssl.get(), verify_mode, &VerifyCallback);

  ssl_ = std::move(ssl);
  return Status::OK();
}
629
630
6.29k
int SecureRefiner::VerifyCallback(int preverified, X509_STORE_CTX* store_context) {
631
6.29k
  if (!store_context) {
632
0
    return preverified;
633
0
  }
634
635
6.29k
  auto ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
636
6.29k
      store_context, SSL_get_ex_data_X509_STORE_CTX_idx()));
637
6.29k
  if (!ssl) {
638
0
    return preverified;
639
0
  }
640
641
6.29k
  auto refiner = static_cast<SecureRefiner*>(SSL_get_app_data(ssl));
642
643
6.29k
  if (!refiner) {
644
0
    return preverified;
645
0
  }
646
647
6.29k
  auto status = refiner->Verify(preverified != 0, store_context);
648
6.29k
  if (status.ok()) {
649
6.12k
    return 1;
650
6.12k
  }
651
652
172
  VLOG
(4) << refiner->LogPrefix() << status28
;
653
172
  refiner->verification_status_ = status;
654
172
  return 0;
655
6.29k
}
656
657
namespace {
658
659
// Matches pattern from RFC 2818:
660
// Names may contain the wildcard character * which is considered to match any single domain name
661
// component or component fragment. E.g., *.a.com matches foo.a.com but not bar.foo.a.com.
662
// f*.com matches foo.com but not bar.com.
663
349
bool MatchPattern(Slice pattern, Slice host) {
664
349
  const char* p = pattern.cdata();
665
349
  const char* p_end = pattern.cend();
666
349
  const char* h = host.cdata();
667
349
  const char* h_end = host.cend();
668
669
790
  while (p != p_end && 
h != h_end769
) {
670
481
    if (*p == '*') {
671
40
      ++p;
672
80
      while (h != h_end && *h != '.') {
673
40
        if (MatchPattern(Slice(p, p_end), Slice(h, h_end))) {
674
0
          return true;
675
0
        }
676
40
        ++h;
677
40
      }
678
441
    } else if (std::tolower(*p) == std::tolower(*h)) {
679
401
      ++p;
680
401
      ++h;
681
401
    } else {
682
40
      return false;
683
40
    }
684
481
  }
685
686
309
  return p == p_end && 
h == h_end21
;
687
349
}
688
689
232
// Returns the value of the last subject-name entry of `cert` with the given `nid`, i.e. the most
// specific one. Returns an empty slice if there is no such entry or it is empty.
// The returned slice points into `cert`'s own data.
Slice GetEntryByNid(X509* cert, int nid) {
  X509_NAME* subject = X509_get_subject_name(cert);
  int found = -1;
  for (int next = X509_NAME_get_index_by_NID(subject, nid, -1); next >= 0;
       next = X509_NAME_get_index_by_NID(subject, nid, next)) {
    found = next;
  }
  if (found < 0) {
    return Slice();
  }

  auto* entry = X509_NAME_get_entry(subject, found);
  if (entry == nullptr) {
    LOG(DFATAL) << "No name entry in certificate at index: " << found;
    return Slice();
  }

  auto* value = X509_NAME_ENTRY_get_data(entry);
  if (value == nullptr || value->data == nullptr || value->length == 0) {
    return Slice();
  }
  return Slice(value->data, value->length);
}
711
712
232
// Returns the most specific (last) common name (CN) from the subject of `cert`, or an empty slice.
Slice GetCommonName(X509* cert) {
  constexpr int kCommonNameNid = NID_commonName;
  return GetEntryByNid(cert, kCommonNameNid);
}
715
716
} // namespace
717
718
2.84k
// Checks whether the peer certificate `cert` was issued for this stream's remote endpoint.
// 1. The subject alternative names `gens` are tried first: DNS entries are matched against
//    remote_hostname_ (wildcards allowed); IPv4/IPv6 entries against the remote address.
// 2. Otherwise the most specific subject common name is compared with the textual remote
//    address and, as a pattern, with remote_hostname_.
// With --dump_certificate_entries, inspected alternative names are appended to
// certificate_entries_.
bool SecureRefiner::MatchEndpoint(X509* cert, GENERAL_NAMES* gens) {
  const auto address = stream_->Remote().address();

  const int num_names = sk_GENERAL_NAME_num(gens);
  for (int idx = 0; idx < num_names; ++idx) {
    GENERAL_NAME* entry = sk_GENERAL_NAME_value(gens, idx);
    switch (entry->type) {
      case GEN_DNS: {
        ASN1_IA5STRING* domain = entry->d.dNSName;
        if (domain->type != V_ASN1_IA5STRING || !domain->data || !domain->length) {
          break;
        }
        Slice domain_slice(domain->data, domain->length);
        VLOG_WITH_PREFIX(4) << "Domain: " << domain_slice.ToBuffer() << " vs " << remote_hostname_;
        if (FLAGS_dump_certificate_entries) {
          certificate_entries_.push_back(Format("DNS:$0", domain_slice.ToBuffer()));
        }
        if (MatchPattern(domain_slice, remote_hostname_)) {
          return true;
        }
        break;
      }
      case GEN_IPADD: {
        ASN1_OCTET_STRING* ip_address = entry->d.iPAddress;
        if (ip_address->type != V_ASN1_OCTET_STRING || !ip_address->data) {
          break;
        }
        if (ip_address->length == 4) {
          boost::asio::ip::address_v4::bytes_type v4_bytes;
          memcpy(&v4_bytes, ip_address->data, v4_bytes.size());
          const boost::asio::ip::address_v4 allowed_address(v4_bytes);
          VLOG_WITH_PREFIX(4) << "IPv4: " << allowed_address.to_string() << " vs " << address;
          if (FLAGS_dump_certificate_entries) {
            certificate_entries_.push_back(Format("IP Address:$0", allowed_address));
          }
          if (address == allowed_address) {
            return true;
          }
        } else if (ip_address->length == 16) {
          boost::asio::ip::address_v6::bytes_type v6_bytes;
          memcpy(&v6_bytes, ip_address->data, v6_bytes.size());
          const boost::asio::ip::address_v6 allowed_address(v6_bytes);
          VLOG_WITH_PREFIX(4) << "IPv6: " << allowed_address.to_string() << " vs " << address;
          if (FLAGS_dump_certificate_entries) {
            certificate_entries_.push_back(Format("IP Address:$0", allowed_address));
          }
          if (address == allowed_address) {
            return true;
          }
        }
        break;
      }
      default:
        break;
    }
  }

  // No alternative name matched. Fall back to the common name; only the "most specific" one
  // (the last in the subject) is considered.
  const Slice common_name = GetCommonName(cert);
  if (!common_name.empty()) {
    VLOG_WITH_PREFIX(4) << "Common name: " << common_name.ToBuffer() << " vs "
                        << stream_->Remote().address() << "/" << remote_hostname_;
    if (common_name == stream_->Remote().address().to_string() ||
        MatchPattern(common_name, remote_hostname_)) {
      return true;
    }
  }

  VLOG_WITH_PREFIX(4) << "Nothing suitable for " << stream_->Remote().address() << "/"
                      << remote_hostname_;
  return false;
}
782
783
26
bool SecureRefiner::MatchUidEntry(const Slice& value, const char* name) {
784
26
  if (value == secure_context_.required_uid()) {
785
26
    
VLOG_WITH_PREFIX0
(4) << "Accepted " << name << ": " << value.ToBuffer()0
;
786
26
    return true;
787
26
  } else 
if (0
!value.empty()0
) {
788
0
    VLOG_WITH_PREFIX(4) << "Rejected " << name << ": " << value.ToBuffer() << ", while "
789
0
                        << secure_context_.required_uid() << " required";
790
0
  }
791
0
  return false;
792
26
}
793
794
0
bool IsStringType(int type) {
795
0
  switch (type) {
796
0
    case V_ASN1_UTF8STRING: FALLTHROUGH_INTENDED;
797
0
    case V_ASN1_IA5STRING: FALLTHROUGH_INTENDED;
798
0
    case V_ASN1_UNIVERSALSTRING: FALLTHROUGH_INTENDED;
799
0
    case V_ASN1_BMPSTRING: FALLTHROUGH_INTENDED;
800
0
    case V_ASN1_VISIBLESTRING: FALLTHROUGH_INTENDED;
801
0
    case V_ASN1_PRINTABLESTRING: FALLTHROUGH_INTENDED;
802
0
    case V_ASN1_TELETEXSTRING: FALLTHROUGH_INTENDED;
803
0
    case V_ASN1_GENERALSTRING: FALLTHROUGH_INTENDED;
804
0
    case V_ASN1_NUMERICSTRING:
805
0
      return true;
806
0
  }
807
0
  return false;
808
0
}
809
810
26
// Checks whether the certificate carries the UID required by secure_context_.
// Candidates are tried in this order, and the first match wins:
//   1. the subject common name,
//   2. the subject userId entry (NID_userId),
//   3. every string-typed otherName entry of the subjectAltName list `gens`.
// When FLAGS_dump_certificate_entries is set, the userId and otherName candidates
// examined are appended to certificate_entries_.
// Returns false (after a verbose log) when no candidate matches.
bool SecureRefiner::MatchUid(X509* cert, GENERAL_NAMES* gens) {
  if (MatchUidEntry(GetCommonName(cert), "common name")) {
    return true;
  }

  auto uid = GetEntryByNid(cert, NID_userId);
  if (!uid.empty()) {
    if (FLAGS_dump_certificate_entries) {
      certificate_entries_.push_back(Format("UID:$0", uid.ToBuffer()));
    }
    if (MatchUidEntry(uid, "uid")) {
      return true;
    }
  }

  for (int i = 0; i < sk_GENERAL_NAME_num(gens); ++i) {
    GENERAL_NAME* gen = sk_GENERAL_NAME_value(gens, i);
    if (gen->type != GEN_OTHERNAME) {
      continue;
    }
    auto* value = gen->d.otherName->value;
    if (value == nullptr || !IsStringType(value->type)) {
      continue;
    }

    // BMPString and UniversalString hold UCS-2/UCS-4 code units, so their raw bytes can
    // never equal a textual UID. Normalize every accepted string type to UTF-8 first.
    // ASN1_STRING_to_UTF8 returns a negative length on malformed input, and the buffer
    // it allocates must be released with OPENSSL_free.
    unsigned char* utf8 = nullptr;
    const int utf8_len = ASN1_STRING_to_UTF8(&utf8, value->value.asn1_string);
    auto free_utf8 = ScopeExit([utf8] {
      OPENSSL_free(utf8);
    });
    if (utf8_len <= 0) {
      continue;
    }

    Slice other_name(utf8, utf8_len);
    if (FLAGS_dump_certificate_entries) {
      certificate_entries_.push_back(Format("ON:$0", other_name.ToBuffer()));
    }
    if (MatchUidEntry(other_name, "other name")) {
      return true;
    }
  }
  VLOG_WITH_PREFIX(4) << "Not found entry for UID " << secure_context_.required_uid();

  return false;
}
846
847
// Verify according to RFC 2818.
848
6.29k
Status SecureRefiner::Verify(bool preverified, X509_STORE_CTX* store_context) {
849
  // Don't bother looking at certificates that have failed pre-verification.
850
6.29k
  if (!preverified) {
851
0
    auto err = X509_STORE_CTX_get_error(store_context);
852
0
    return STATUS_FORMAT(
853
0
        NetworkError, "Unverified certificate: $0", X509_verify_cert_error_string(err));
854
0
  }
855
856
  // We're only interested in checking the certificate at the end of the chain.
857
6.29k
  int depth = X509_STORE_CTX_get_error_depth(store_context);
858
6.29k
  if (depth > 0) {
859
3.14k
    
VLOG_WITH_PREFIX0
(4) << "Intermediate certificate"0
;
860
3.14k
    return Status::OK();
861
3.14k
  }
862
863
3.15k
  X509* cert = X509_STORE_CTX_get_current_cert(store_context);
864
3.15k
  auto gens = static_cast<GENERAL_NAMES*>(X509_get_ext_d2i(
865
3.15k
      cert, NID_subject_alt_name, nullptr, nullptr));
866
3.15k
  auto se = ScopeExit([gens] {
867
3.14k
    GENERAL_NAMES_free(gens);
868
3.14k
  });
869
870
3.15k
  if (FLAGS_dump_certificate_entries) {
871
0
    certificate_entries_.push_back(Format("CN:$0", GetCommonName(cert).ToBuffer()));
872
0
  }
873
874
3.15k
  if (!secure_context_.required_uid().empty()) {
875
26
    if (!MatchUid(cert, gens)) {
876
0
      return STATUS_FORMAT(
877
0
          NetworkError, "Uid does not match: $0", secure_context_.required_uid());
878
0
    }
879
3.12k
  } else {
880
3.12k
    
VLOG_WITH_PREFIX10
(4) << "Skip UID verification"10
;
881
3.12k
  }
882
883
3.15k
  bool verify_endpoint = stream_->local_side() == LocalSide::kClient ? 
FLAGS_verify_server_endpoint2.46k
884
3.15k
                                                                     : 
FLAGS_verify_client_endpoint685
;
885
3.15k
  if (verify_endpoint) {
886
2.83k
    if (!MatchEndpoint(cert, gens)) {
887
144
      return STATUS(NetworkError, "Endpoint does not match");
888
144
    }
889
2.83k
  } else {
890
312
    
VLOG_WITH_PREFIX11
(4) << "Skip endpoint verification"11
;
891
312
  }
892
893
3.00k
  return Status::OK();
894
3.15k
}
895
896
} // namespace
897
898
420
// Returns the "tcps" protocol descriptor for secure streams. The descriptor is a
// function-local static: it is constructed on the first call, and every call returns
// the same pointer.
const Protocol* SecureStreamProtocol() {
  static Protocol secure_protocol("tcps");
  return &secure_protocol;
}
902
903
// Creates a RefinedStreamFactory that layers a SecureRefiner over streams produced by
// `lower_layer_factory`, with buffers accounted to `buffer_tracker`.
// The refiner-creating lambda captures `context` by pointer and dereferences it on
// each invocation. The SecureContext must therefore stay alive for as long as the
// returned factory may create streams.
StreamFactoryPtr SecureStreamFactory(
    StreamFactoryPtr lower_layer_factory, const MemTrackerPtr& buffer_tracker,
    const SecureContext* context) {
  auto make_refiner = [context](const StreamCreateData& create_data) {
    return std::make_unique<SecureRefiner>(*context, create_data);
  };
  return std::make_shared<RefinedStreamFactory>(
      std::move(lower_layer_factory), buffer_tracker, std::move(make_refiner));
}
911
912
} // namespace rpc
913
} // namespace yb