YugabyteDB (2.13.0.0-b42, bfc6a6643e7399ac8a0e81d06a3ee6d6571b33ab)

Coverage Report

Created: 2022-03-09 17:30

/Users/deen/code/yugabyte-db/src/yb/rpc/secure_stream.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) YugaByte, Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4
// in compliance with the License.  You may obtain a copy of the License at
5
//
6
// http://www.apache.org/licenses/LICENSE-2.0
7
//
8
// Unless required by applicable law or agreed to in writing, software distributed under the License
9
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10
// or implied.  See the License for the specific language governing permissions and limitations
11
// under the License.
12
//
13
14
#include "yb/rpc/secure_stream.h"
15
16
#include <openssl/err.h>
17
#include <openssl/ssl.h>
18
#include <openssl/x509v3.h>
19
20
#include <boost/tokenizer.hpp>
21
22
#include "yb/encryption/encryption_util.h"
23
24
#include "yb/gutil/casts.h"
25
26
#include "yb/rpc/outbound_data.h"
27
#include "yb/rpc/refined_stream.h"
28
29
#include "yb/util/enums.h"
30
#include "yb/util/errno.h"
31
#include "yb/util/logging.h"
32
#include "yb/util/scope_exit.h"
33
#include "yb/util/status_format.h"
34
35
using namespace std::literals;
36
37
DEFINE_bool(allow_insecure_connections, true, "Whether we should allow insecure connections.");
38
DEFINE_bool(dump_certificate_entries, false, "Whether we should dump certificate entries.");
39
DEFINE_bool(verify_client_endpoint, false, "Whether client endpoint should be verified.");
40
DEFINE_bool(verify_server_endpoint, true, "Whether server endpoint should be verified.");
41
DEFINE_string(ssl_protocols, "",
42
              "List of allowed SSL protocols (ssl2, ssl3, tls10, tls11, tls12). "
43
                  "Empty to allow TLS only.");
44
45
DEFINE_string(cipher_list, "",
46
              "Define the list of available ciphers (TLSv1.2 and below).");
47
48
DEFINE_string(ciphersuites, "",
49
              "Define the available TLSv1.3 ciphersuites.");
50
51
#define YB_RPC_SSL_TYPE_DEFINE(name) \
52
2.05k
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \
53
2.05k
    BOOST_PP_CAT(name, _free)(value); \
54
2.05k
  } \
_ZNK2yb3rpc6detail12EVP_PKEYFreeclEP11evp_pkey_st
Line
Count
Source
52
26
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \
53
26
    BOOST_PP_CAT(name, _free)(value); \
54
26
  } \
_ZNK2yb3rpc6detail7SSLFreeclEP6ssl_st
Line
Count
Source
52
892
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \
53
892
    BOOST_PP_CAT(name, _free)(value); \
54
892
  } \
_ZNK2yb3rpc6detail11SSL_CTXFreeclEP10ssl_ctx_st
Line
Count
Source
52
13
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \
53
13
    BOOST_PP_CAT(name, _free)(value); \
54
13
  } \
_ZNK2yb3rpc6detail8X509FreeclEP7x509_st
Line
Count
Source
52
26
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \
53
26
    BOOST_PP_CAT(name, _free)(value); \
54
26
  } \
Unexecuted instantiation: secure_stream.cc:_ZNK2yb3rpc12_GLOBAL__N_17RSAFreeclEP6rsa_st
secure_stream.cc:_ZNK2yb3rpc12_GLOBAL__N_113X509_NAMEFreeclEP12X509_name_st
Line
Count
Source
52
26
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \
53
26
    BOOST_PP_CAT(name, _free)(value); \
54
26
  } \
secure_stream.cc:_ZNK2yb3rpc12_GLOBAL__N_116ASN1_INTEGERFreeclEP14asn1_string_st
Line
Count
Source
52
26
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \
53
26
    BOOST_PP_CAT(name, _free)(value); \
54
26
  } \
secure_stream.cc:_ZNK2yb3rpc12_GLOBAL__N_17BIOFreeclEP6bio_st
Line
Count
Source
52
1.04k
  void BOOST_PP_CAT(name, Free)::operator()(name* value) const { \
53
1.04k
    BOOST_PP_CAT(name, _free)(value); \
54
1.04k
  } \
55
56
namespace yb {
57
namespace rpc {
58
59
namespace {
60
61
YB_RPC_SSL_TYPE_DECLARE(BIO);
62
YB_RPC_SSL_TYPE_DEFINE(BIO)
63
YB_RPC_SSL_TYPE_DECLARE(SSL);
64
65
const unsigned char kContextId[] = { 'Y', 'u', 'g', 'a', 'B', 'y', 't', 'e' };
66
67
142
std::string SSLErrorMessage(uint64_t error) {
68
142
  auto message = ERR_reason_error_string(error);
69
142
  return message ? message : "no error";
70
142
}
71
72
#define YB_RPC_SSL_TYPE(name) YB_RPC_SSL_TYPE_DECLARE(name) YB_RPC_SSL_TYPE_DEFINE(name)
73
74
0
#define SSL_STATUS(type, format) STATUS_FORMAT(type, format, SSLErrorMessage(ERR_get_error()))
75
76
150
Result<BIOPtr> BIOFromSlice(const Slice& data) {
77
150
  BIOPtr bio(BIO_new_mem_buf(data.data(), narrow_cast<int>(data.size())));
78
150
  if (!bio) {
79
0
    return SSL_STATUS(IOError, "Create BIO failed: $0");
80
0
  }
81
150
  return std::move(bio);
82
150
}
83
84
75
Result<detail::X509Ptr> X509FromSlice(const Slice& data) {
85
75
  ERR_clear_error();
86
87
75
  auto bio = VERIFY_RESULT(BIOFromSlice(data));
88
89
75
  detail::X509Ptr cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
90
75
  if (!cert) {
91
0
    return SSL_STATUS(IOError, "Read cert failed: $0");
92
0
  }
93
94
75
  return std::move(cert);
95
75
}
96
97
YB_RPC_SSL_TYPE(ASN1_INTEGER);
98
YB_RPC_SSL_TYPE(RSA);
99
YB_RPC_SSL_TYPE(X509_NAME);
100
101
26
Result<detail::EVP_PKEYPtr> GeneratePrivateKey(int bits) {
102
26
  RSAPtr rsa(RSA_generate_key(bits, 65537, nullptr, nullptr));
103
26
  if (!rsa) {
104
0
    return SSL_STATUS(InvalidArgument, "Failed to generate private key: $0");
105
0
  }
106
107
26
  detail::EVP_PKEYPtr pkey(EVP_PKEY_new());
108
26
  auto res = EVP_PKEY_assign_RSA(pkey.get(), rsa.release());
109
26
  if (res != 1) {
110
0
    return SSL_STATUS(InvalidArgument, "Failed to assign private key: $0");
111
0
  }
112
113
26
  return std::move(pkey);
114
26
}
115
116
class ExtensionConfigurator {
117
 public:
118
13
  explicit ExtensionConfigurator(X509 *cert) : cert_(cert) {
119
    // No configuration database
120
13
    X509V3_set_ctx_nodb(&ctx_);
121
    // Both issuer and subject certs
122
13
    X509V3_set_ctx(&ctx_, cert, cert, nullptr, nullptr, 0);
123
13
  }
124
125
13
  CHECKED_STATUS Add(int nid, const char* value) {
126
13
    X509_EXTENSION *ex = X509V3_EXT_conf_nid(nullptr, &ctx_, nid, value);
127
13
    if (!ex) {
128
0
      return SSL_STATUS(InvalidArgument, "Failed to create extension: $0");
129
0
    }
130
131
13
    X509_add_ext(cert_, ex, -1);
132
13
    X509_EXTENSION_free(ex);
133
134
13
    return Status::OK();
135
13
  }
136
137
 private:
138
  X509V3_CTX ctx_;
139
  X509* cert_;
140
};
141
142
Result<detail::X509Ptr> CreateCertificate(
143
26
    EVP_PKEY* key, const std::string& common_name, EVP_PKEY* ca_pkey, X509* ca_cert) {
144
26
  detail::X509Ptr cert(X509_new());
145
26
  if (!cert) {
146
0
    return SSL_STATUS(IOError, "Failed to create new certificate: $0");
147
0
  }
148
149
26
  if (X509_set_version(cert.get(), 2) != 1) {
150
0
    return SSL_STATUS(IOError, "Failed to set certificate version: $0");
151
0
  }
152
153
26
  ASN1_INTEGERPtr aserial(ASN1_INTEGER_new());
154
26
  ASN1_INTEGER_set(aserial.get(), 0);
155
26
  if (!X509_set_serialNumber(cert.get(), aserial.get())) {
156
0
    return SSL_STATUS(IOError, "Failed to set serial number: $0");
157
0
  }
158
159
26
  X509_NAMEPtr name(X509_NAME_new());
160
26
  auto bytes = pointer_cast<const unsigned char*>(common_name.c_str());
161
26
  if (!X509_NAME_add_entry_by_txt(
162
0
      name.get(), "CN", MBSTRING_ASC, bytes, narrow_cast<int>(common_name.length()), -1, 0)) {
163
0
    return SSL_STATUS(IOError, "Failed to create subject: $0");
164
0
  }
165
166
26
  if (X509_set_subject_name(cert.get(), name.get()) != 1) {
167
0
    return SSL_STATUS(IOError, "Failed to set subject: $0");
168
0
  }
169
170
26
  X509_NAME* issuer = name.get();
171
26
  if (ca_cert) {
172
13
    issuer = X509_get_subject_name(ca_cert);
173
13
    if (!issuer) {
174
0
      return SSL_STATUS(IOError, "Failed to get CA subject name: $0");
175
0
    }
176
13
  } else {
177
13
    ExtensionConfigurator configurator(cert.get());
178
13
    RETURN_NOT_OK(configurator.Add(NID_basic_constraints, "critical,CA:TRUE"));
179
13
  }
180
181
26
  if (X509_set_issuer_name(cert.get(), issuer) != 1) {
182
0
    return SSL_STATUS(IOError, "Failed to set issuer: $0");
183
0
  }
184
185
26
  if (X509_set_pubkey(cert.get(), key) != 1) {
186
0
    return SSL_STATUS(IOError, "Failed to set public key: $0");
187
0
  }
188
189
26
  if (!X509_gmtime_adj(X509_get_notBefore(cert.get()), 0)) {
190
0
    return SSL_STATUS(IOError, "Failed to set not before: $0");
191
0
  }
192
193
26
  const auto k1Year = 365 * 24h;
194
26
  auto seconds = std::chrono::duration_cast<std::chrono::seconds>(k1Year).count();
195
26
  if (!X509_gmtime_adj(X509_get_notAfter(cert.get()), seconds)) {
196
0
    return SSL_STATUS(IOError, "Failed to set not after: $0");
197
0
  }
198
199
26
  if (ca_cert) {
200
13
    X509V3_CTX ctx;
201
13
    X509V3_set_ctx(&ctx, ca_cert, cert.get(), nullptr, nullptr, 0);
202
13
  }
203
204
26
  if (!X509_sign(cert.get(), ca_pkey, EVP_sha256())) {
205
0
    return SSL_STATUS(IOError, "Sign failed: $0");
206
0
  }
207
208
26
  return std::move(cert);
209
26
}
210
211
0
const std::unordered_map<std::string, int64_t>& SSLProtocolMap() {
212
0
  static const std::unordered_map<std::string, int64_t> result = {
213
0
      {"ssl2", SSL_OP_NO_SSLv2},
214
0
      {"ssl3", SSL_OP_NO_SSLv3},
215
0
      {"tls10", SSL_OP_NO_TLSv1},
216
0
      {"tls11", SSL_OP_NO_TLSv1_1},
217
0
      {"tls12", SSL_OP_NO_TLSv1_2},
218
0
      {"tls13", SSL_OP_NO_TLSv1_3},
219
0
  };
220
0
  return result;
221
0
}
222
223
88
int64_t ProtocolsOption() {
224
88
  constexpr int64_t kDefaultProtocols = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
225
226
88
  const std::string& ssl_protocols = FLAGS_ssl_protocols;
227
88
  if (ssl_protocols.empty()) {
228
88
    return kDefaultProtocols;
229
88
  }
230
231
0
  const auto& protocol_map = SSLProtocolMap();
232
0
  int64_t result = SSL_OP_NO_SSL_MASK;
233
0
  boost::tokenizer<> tokenizer(ssl_protocols);
234
0
  for (const auto& protocol : tokenizer) {
235
0
    auto it = protocol_map.find(protocol);
236
0
    if (it == protocol_map.end()) {
237
0
      LOG(DFATAL) << "Unknown SSL protocol: " << protocol;
238
0
      return kDefaultProtocols;
239
0
    }
240
0
    result &= ~it->second;
241
0
  }
242
243
0
  return result;
244
0
}
245
246
} // namespace
247
248
namespace detail {
249
250
YB_RPC_SSL_TYPE_DEFINE(EVP_PKEY)
251
YB_RPC_SSL_TYPE_DEFINE(SSL)
252
YB_RPC_SSL_TYPE_DEFINE(SSL_CTX)
253
YB_RPC_SSL_TYPE_DEFINE(X509)
254
255
}
256
257
88
SecureContext::SecureContext() {
258
88
  encryption::InitOpenSSL();
259
260
88
  context_.reset(SSL_CTX_new(SSLv23_method()));
261
88
  DCHECK(context_);
262
263
88
  int64_t protocols = ProtocolsOption();
264
0
  VLOG(1) << "Protocols option: " << protocols;
265
88
  SSL_CTX_set_options(context_.get(), protocols | SSL_OP_NO_COMPRESSION);
266
267
88
  auto cipher_list = FLAGS_cipher_list;
268
88
  if (!cipher_list.empty()) {
269
0
    LOG(INFO) << "Use cipher list: " << cipher_list;
270
0
    auto res = SSL_CTX_set_cipher_list(context_.get(), cipher_list.c_str());
271
0
    LOG_IF(DFATAL, res != 1) << "Failed to set cipher list: "
272
0
                             << SSLErrorMessage(ERR_get_error());
273
0
  }
274
275
88
  auto ciphersuites = FLAGS_ciphersuites;
276
88
  if (!ciphersuites.empty()) {
277
0
    LOG(INFO) << "Use cipher suites: " << ciphersuites;
278
0
    auto res = SSL_CTX_set_ciphersuites(context_.get(), ciphersuites.c_str());
279
0
    LOG_IF(DFATAL, res != 1) << "Failed to set ciphersuites: "
280
0
                           << SSLErrorMessage(ERR_get_error());
281
0
  }
282
283
88
  auto res = SSL_CTX_set_session_id_context(context_.get(), kContextId, sizeof(kContextId));
284
0
  LOG_IF(DFATAL, res != 1) << "Failed to set session id for SSL context: "
285
0
                           << SSLErrorMessage(ERR_get_error());
286
88
}
287
288
1.89k
detail::SSLPtr SecureContext::Create() const {
289
1.89k
  return detail::SSLPtr(SSL_new(context_.get()));
290
1.89k
}
291
292
75
Status SecureContext::AddCertificateAuthorityFile(const std::string& file) {
293
75
  X509_STORE* store = SSL_CTX_get_cert_store(context_.get());
294
75
  if (!store) {
295
0
    return SSL_STATUS(IllegalState, "Failed to get store: $0");
296
0
  }
297
298
75
  auto bytes = pointer_cast<const char*>(file.c_str());
299
75
  auto res = X509_STORE_load_locations(store, bytes, nullptr);
300
75
  if (res != 1) {
301
0
    return SSL_STATUS(InvalidArgument, "Failed to add certificate file: $0");
302
0
  }
303
304
75
  return Status::OK();
305
75
}
306
307
0
Status SecureContext::AddCertificateAuthority(const Slice& data) {
308
0
  return AddCertificateAuthority(VERIFY_RESULT(X509FromSlice(data)).get());
309
0
}
310
311
13
Status SecureContext::AddCertificateAuthority(X509* cert) {
312
13
  X509_STORE* store = SSL_CTX_get_cert_store(context_.get());
313
13
  if (!store) {
314
0
    return SSL_STATUS(IllegalState, "Failed to get store: $0");
315
0
  }
316
317
13
  auto res = X509_STORE_add_cert(store, cert);
318
13
  if (res != 1) {
319
0
    return SSL_STATUS(InvalidArgument, "Failed to add certificate: $0");
320
0
  }
321
322
13
  return Status::OK();
323
13
}
324
325
13
Status SecureContext::TEST_GenerateKeys(int bits, const std::string& common_name) {
326
13
  auto ca_key = VERIFY_RESULT(GeneratePrivateKey(bits));
327
13
  auto ca_cert = VERIFY_RESULT(CreateCertificate(ca_key.get(), "YugaByte", ca_key.get(), nullptr));
328
13
  auto key = VERIFY_RESULT(GeneratePrivateKey(bits));
329
13
  auto cert = VERIFY_RESULT(CreateCertificate(key.get(), common_name, ca_key.get(), ca_cert.get()));
330
331
13
  RETURN_NOT_OK(AddCertificateAuthority(ca_cert.get()));
332
13
  pkey_ = std::move(key);
333
13
  certificate_ = std::move(cert);
334
335
13
  return Status::OK();
336
13
}
337
338
75
Status SecureContext::UsePrivateKey(const Slice& slice) {
339
75
  ERR_clear_error();
340
341
75
  auto bio = VERIFY_RESULT(BIOFromSlice(slice));
342
343
75
  auto pkey = PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr);
344
75
  if (!pkey) {
345
0
    return SSL_STATUS(IOError, "Failed to read private key: $0");
346
0
  }
347
348
75
  pkey_.reset(pkey);
349
75
  return Status::OK();
350
75
}
351
352
75
Status SecureContext::UseCertificate(const Slice& data) {
353
75
  ERR_clear_error();
354
355
75
  certificate_ = VERIFY_RESULT(X509FromSlice(data));
356
357
75
  return Status::OK();
358
75
}
359
360
namespace {
361
362
class SecureRefiner : public StreamRefiner {
363
 public:
364
  SecureRefiner(const SecureContext& context, const StreamCreateData& data)
365
2.97k
    : secure_context_(context), remote_hostname_(data.remote_hostname) {
366
2.97k
  }
367
368
 private:
369
2.99k
  void Start(RefinedStream* stream) override {
370
2.99k
    stream_ = stream;
371
2.99k
  }
372
373
  CHECKED_STATUS Handshake() override;
374
  CHECKED_STATUS Init();
375
376
  CHECKED_STATUS Send(OutboundDataPtr data) override;
377
  CHECKED_STATUS ProcessHeader() override;
378
  Result<ReadBufferFull> Read(StreamReadBuffer* out) override;
379
380
188
  std::string ToString() const override {
381
188
    return "SECURE";
382
188
  }
383
384
0
  const Protocol* GetProtocol() override {
385
0
    return SecureStreamProtocol();
386
0
  }
387
388
  static int VerifyCallback(int preverified, X509_STORE_CTX* store_context);
389
  CHECKED_STATUS Verify(bool preverified, X509_STORE_CTX* store_context);
390
  bool MatchEndpoint(X509* cert, GENERAL_NAMES* gens);
391
  bool MatchUid(X509* cert, GENERAL_NAMES* gens);
392
  bool MatchUidEntry(const Slice& value, const char* name);
393
  Result<bool> WriteEncrypted(OutboundDataPtr data);
394
  void DecryptReceived();
395
396
1.74k
  CHECKED_STATUS Established(RefinedStreamState state) {
397
18.4E
    VLOG_WITH_PREFIX(4) << "Established with state: " << state << ", used cipher: "
398
18.4E
                        << SSL_get_cipher_name(ssl_.get());
399
400
1.74k
    return stream_->Established(state);
401
1.74k
  }
402
403
44
  const std::string& LogPrefix() const {
404
44
    return stream_->LogPrefix();
405
44
  }
406
407
  const SecureContext& secure_context_;
408
  const std::string remote_hostname_;
409
  RefinedStream* stream_ = nullptr;
410
  std::vector<std::string> certificate_entries_;
411
412
  BIOPtr bio_;
413
  detail::SSLPtr ssl_;
414
  Status verification_status_;
415
};
416
417
40.8k
Status SecureRefiner::Send(OutboundDataPtr data) {
418
40.8k
  boost::container::small_vector<RefCntBuffer, 10> queue;
419
40.8k
  data->Serialize(&queue);
420
40.7k
  for (const auto& buf : queue) {
421
40.7k
    Slice slice(buf.data(), buf.size());
422
74.5k
    for (;;) {
423
74.5k
      int slice_size = narrow_cast<int>(slice.size());
424
74.5k
      auto len = SSL_write(ssl_.get(), slice.data(), slice_size);
425
74.5k
      if (len == slice_size) {
426
40.7k
        break;
427
40.7k
      }
428
33.8k
      auto error = len <= 0 ? SSL_get_error(ssl_.get(), len) : SSL_ERROR_NONE;
429
24
      VLOG_WITH_PREFIX(4) << "SSL_write was not full: " << slice.size() << ", written: " << len
430
24
                          << ", error: " << error;
431
33.8k
      if (error != SSL_ERROR_NONE) {
432
0
        if (error != SSL_ERROR_WANT_WRITE || !VERIFY_RESULT(WriteEncrypted(nullptr))) {
433
0
          return STATUS_FORMAT(
434
0
              NetworkError, "SSL write failed: $0 ($1)", SSLErrorMessage(error), error);
435
0
        }
436
33.8k
      } else {
437
33.8k
        RETURN_NOT_OK(WriteEncrypted(nullptr));
438
33.8k
      }
439
33.8k
      if (len > 0) {
440
33.7k
        slice.remove_prefix(len);
441
33.7k
      }
442
33.8k
    }
443
40.7k
  }
444
40.8k
  return ResultToStatus(WriteEncrypted(std::move(data)));
445
40.8k
}
446
447
74.3k
Result<bool> SecureRefiner::WriteEncrypted(OutboundDataPtr data) {
448
74.3k
  auto pending = BIO_ctrl_pending(bio_.get());
449
74.3k
  if (pending == 0) {
450
0
    return data ? STATUS(NetworkError, "No pending data during write") : Result<bool>(false);
451
0
  }
452
74.3k
  RefCntBuffer buf(pending);
453
74.3k
  int buf_size = narrow_cast<int>(buf.size());
454
74.3k
  auto len = BIO_read(bio_.get(), buf.data(), buf_size);
455
18.4E
  LOG_IF_WITH_PREFIX(DFATAL, len != buf_size)
456
18.4E
      << "BIO_read was not full: " << buf.size() << ", read: " << len;
457
18.4E
  VLOG_WITH_PREFIX(4) << "Write encrypted: " << len << ", " << AsString(data);
458
74.3k
  RETURN_NOT_OK(stream_->SendToLower(std::make_shared<SingleBufferOutboundData>(
459
74.3k
      buf, std::move(data))));
460
74.3k
  return true;
461
74.3k
}
462
463
1.10k
Status SecureRefiner::ProcessHeader() {
464
1.10k
  auto data = stream_->ReadBuffer().AppendedVecs();
465
1.10k
  if (data.empty() || data[0].iov_len < 2) {
466
0
    return Status::OK();
467
0
  }
468
469
1.10k
  const auto* bytes = static_cast<const uint8_t*>(data[0].iov_base);
470
1.10k
  if (bytes[0] == 0x16 && bytes[1] == 0x03) { // TLS handshake header
471
1.04k
    RETURN_NOT_OK(Init());
472
1.04k
    return stream_->StartHandshake();
473
57
  }
474
475
57
  if (!FLAGS_allow_insecure_connections) {
476
57
    return STATUS_FORMAT(NetworkError, "Insecure connection header: $0",
477
57
                         Slice(bytes, 2).ToDebugHexString());
478
57
  }
479
480
0
  return Established(RefinedStreamState::kDisabled);
481
0
}
482
483
// Tries to do SSL_read up to num bytes from buf. Possible results:
484
// > 0 - number of bytes actually read.
485
// = 0 - in case of SSL_ERROR_WANT_READ.
486
// Status with network error - in case of other errors.
487
73.9k
Result<ReadBufferFull> SecureRefiner::Read(StreamReadBuffer* out) {
488
73.9k
  DecryptReceived();
489
73.9k
  auto total = 0;
490
73.9k
  auto iovecs = VERIFY_RESULT(out->PrepareAppend());
491
73.9k
  auto iov_it = iovecs.begin();
492
148k
  for (;;) {
493
148k
    auto len = SSL_read(ssl_.get(), iov_it->iov_base, narrow_cast<int>(iov_it->iov_len));
494
495
148k
    if (len <= 0) {
496
73.8k
      auto error = SSL_get_error(ssl_.get(), len);
497
73.8k
      if (error == SSL_ERROR_WANT_READ) {
498
1
        VLOG_WITH_PREFIX(4) << "Read decrypted: SSL_ERROR_WANT_READ";
499
73.7k
        break;
500
73.7k
      }
501
94
      auto status = STATUS_FORMAT(
502
94
          NetworkError, "SSL read failed: $0 ($1)", SSLErrorMessage(error), error);
503
94
      LOG_WITH_PREFIX(INFO) << status;
504
94
      return status;
505
94
    }
506
507
18.4E
    VLOG_WITH_PREFIX(4) << "Read decrypted: " << len;
508
74.8k
    total += len;
509
74.8k
    IoVecRemovePrefix(len, &*iov_it);
510
74.8k
    if (iov_it->iov_len == 0) {
511
112
      if (++iov_it == iovecs.end()) {
512
100
        break;
513
100
      }
514
112
    }
515
74.8k
  }
516
73.8k
  out->DataAppended(total);
517
73.8k
  return ReadBufferFull(out->Full());
518
73.9k
}
519
520
77.7k
void SecureRefiner::DecryptReceived() {
521
77.7k
  auto& inp = stream_->ReadBuffer();
522
77.7k
  if (inp.Empty()) {
523
3.14k
    return;
524
3.14k
  }
525
74.6k
  size_t total = 0;
526
74.6k
  for (const auto& iov : inp.AppendedVecs()) {
527
74.6k
    auto res = BIO_write(bio_.get(), iov.iov_base, narrow_cast<int>(iov.iov_len));
528
15
    VLOG_WITH_PREFIX(4) << "Decrypted: " << res << " of " << iov.iov_len;
529
74.6k
    if (res <= 0) {
530
0
      break;
531
0
    }
532
74.6k
    total += res;
533
74.6k
    if (implicit_cast<size_t>(res) < iov.iov_len) {
534
31.1k
      break;
535
31.1k
    }
536
74.6k
  }
537
74.6k
  inp.Consume(total, {});
538
74.6k
}
539
540
3.89k
Status SecureRefiner::Handshake() {
541
3.89k
  RETURN_NOT_OK(Init());
542
543
3.89k
  DecryptReceived();
544
545
7.49k
  for (;;) {
546
7.49k
    if (stream_->IsConnected()) {
547
1.74k
      return Status::OK();
548
1.74k
    }
549
550
5.75k
    auto pending_before = BIO_ctrl_pending(bio_.get());
551
5.75k
    ERR_clear_error();
552
5.75k
    int result = stream_->local_side() == LocalSide::kClient
553
3.23k
        ? SSL_connect(ssl_.get()) : SSL_accept(ssl_.get());
554
5.75k
    int ssl_error = SSL_get_error(ssl_.get(), result);
555
5.75k
    int sys_error = static_cast<int>(ERR_get_error());
556
5.75k
    auto pending_after = BIO_ctrl_pending(bio_.get());
557
558
5.75k
    if (ssl_error == SSL_ERROR_SSL || ssl_error == SSL_ERROR_SYSCALL) {
559
144
      std::string message = verification_status_.ok()
560
99
          ? (ssl_error == SSL_ERROR_SSL ? SSLErrorMessage(sys_error) : ErrnoToString(sys_error))
561
45
          : verification_status_.ToString();
562
144
      std::string message_suffix;
563
144
      if (FLAGS_dump_certificate_entries) {
564
0
        message_suffix = Format(", certificate entries: $0", certificate_entries_);
565
0
      }
566
144
      return STATUS_FORMAT(NetworkError, "Handshake failed: $0, address: $1, hostname: $2$3",
567
144
                           message, stream_->Remote().address(), remote_hostname_, message_suffix);
568
144
    }
569
570
5.62k
    if (ssl_error == SSL_ERROR_WANT_WRITE || pending_after > pending_before) {
571
      // SSL expects that we would write to underlying transport.
572
3.63k
      RefCntBuffer buffer(pending_after);
573
3.63k
      int len = BIO_read(bio_.get(), buffer.data(), narrow_cast<int>(buffer.size()));
574
3.63k
      DCHECK_EQ(len, pending_after);
575
3.63k
      RETURN_NOT_OK(stream_->SendToLower(
576
3.63k
          std::make_shared<SingleBufferOutboundData>(buffer, nullptr)));
577
      // If SSL_connect/SSL_accept returned positive result it means that TLS connection
578
      // was succesfully established. We just have to send last portion of data.
579
3.63k
      if (result > 0) {
580
1.74k
        RETURN_NOT_OK(Established(RefinedStreamState::kEnabled));
581
1.74k
      }
582
1.98k
    } else if (ssl_error == SSL_ERROR_WANT_READ) {
583
      // SSL expects that we would read from underlying transport.
584
1.98k
      return Status::OK();
585
18.4E
    } else if (SSL_get_shutdown(ssl_.get()) & SSL_RECEIVED_SHUTDOWN) {
586
0
      return STATUS(Aborted, "Handshake aborted");
587
18.4E
    } else {
588
18.4E
      return Established(RefinedStreamState::kEnabled);
589
18.4E
    }
590
5.60k
  }
591
3.89k
}
592
593
4.93k
Status SecureRefiner::Init() {
594
4.93k
  if (ssl_) {
595
3.04k
    return Status::OK();
596
3.04k
  }
597
598
1.89k
  ssl_ = secure_context_.Create();
599
1.89k
  SSL_set_mode(ssl_.get(), SSL_MODE_ENABLE_PARTIAL_WRITE);
600
1.89k
  SSL_set_mode(ssl_.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
601
1.89k
  SSL_set_mode(ssl_.get(), SSL_MODE_RELEASE_BUFFERS);
602
1.89k
  SSL_set_app_data(ssl_.get(), this);
603
604
1.89k
  if (stream_->local_side() == LocalSide::kServer || secure_context_.use_client_certificate()) {
605
1.43k
    auto res = SSL_use_PrivateKey(ssl_.get(), secure_context_.private_key());
606
1.43k
    if (res != 1) {
607
0
      return SSL_STATUS(InvalidArgument, "Failed to use private key: $0");
608
0
    }
609
1.43k
    res = SSL_use_certificate(ssl_.get(), secure_context_.certificate());
610
1.43k
    if (res != 1) {
611
0
      return SSL_STATUS(InvalidArgument, "Failed to use certificate: $0");
612
0
    }
613
1.89k
  }
614
615
1.89k
  BIO* int_bio = nullptr;
616
1.89k
  BIO* temp_bio = nullptr;
617
1.89k
  BIO_new_bio_pair(&int_bio, 0, &temp_bio, 0);
618
1.89k
  SSL_set_bio(ssl_.get(), int_bio, int_bio);
619
1.89k
  bio_.reset(temp_bio);
620
621
1.89k
  int verify_mode = SSL_VERIFY_PEER;
622
1.89k
  if (secure_context_.require_client_certificate()) {
623
901
    verify_mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
624
901
  }
625
1.89k
  SSL_set_verify(ssl_.get(), verify_mode, &VerifyCallback);
626
627
1.89k
  return Status::OK();
628
1.89k
}
629
630
2.61k
int SecureRefiner::VerifyCallback(int preverified, X509_STORE_CTX* store_context) {
631
2.61k
  if (!store_context) {
632
0
    return preverified;
633
0
  }
634
635
2.61k
  auto ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
636
2.61k
      store_context, SSL_get_ex_data_X509_STORE_CTX_idx()));
637
2.61k
  if (!ssl) {
638
0
    return preverified;
639
0
  }
640
641
2.61k
  auto refiner = static_cast<SecureRefiner*>(SSL_get_app_data(ssl));
642
643
2.61k
  if (!refiner) {
644
0
    return preverified;
645
0
  }
646
647
2.61k
  auto status = refiner->Verify(preverified != 0, store_context);
648
2.61k
  if (status.ok()) {
649
2.55k
    return 1;
650
2.55k
  }
651
652
17
  VLOG(4) << refiner->LogPrefix() << status;
653
62
  refiner->verification_status_ = status;
654
62
  return 0;
655
62
}
656
657
namespace {
658
659
// Matches pattern from RFC 2818:
660
// Names may contain the wildcard character * which is considered to match any single domain name
661
// component or component fragment. E.g., *.a.com matches foo.a.com but not bar.foo.a.com.
662
// f*.com matches foo.com but not bar.com.
663
90
bool MatchPattern(Slice pattern, Slice host) {
664
90
  const char* p = pattern.cdata();
665
90
  const char* p_end = pattern.cend();
666
90
  const char* h = host.cdata();
667
90
  const char* h_end = host.cend();
668
669
90
  while (p != p_end && h != h_end) {
670
0
    if (*p == '*') {
671
0
      ++p;
672
0
      while (h != h_end && *h != '.') {
673
0
        if (MatchPattern(Slice(p, p_end), Slice(h, h_end))) {
674
0
          return true;
675
0
        }
676
0
        ++h;
677
0
      }
678
0
    } else if (std::tolower(*p) == std::tolower(*h)) {
679
0
      ++p;
680
0
      ++h;
681
0
    } else {
682
0
      return false;
683
0
    }
684
0
  }
685
686
90
  return p == p_end && h == h_end;
687
90
}
688
689
107
Slice GetEntryByNid(X509* cert, int nid) {
690
107
  X509_NAME* name = X509_get_subject_name(cert);
691
107
  int last_i = -1;
692
214
  for (int i = -1; (i = X509_NAME_get_index_by_NID(name, nid, i)) >= 0; ) {
693
107
    last_i = i;
694
107
  }
695
107
  if (last_i == -1) {
696
0
    return Slice();
697
0
  }
698
107
  auto* name_entry = X509_NAME_get_entry(name, last_i);
699
107
  if (!name_entry) {
700
0
    LOG(DFATAL) << "No name entry in certificate at index: " << last_i;
701
0
    return Slice();
702
0
  }
703
107
  auto* common_name = X509_NAME_ENTRY_get_data(name_entry);
704
705
107
  if (common_name && common_name->data && common_name->length) {
706
107
    return Slice(common_name->data, common_name->length);
707
107
  }
708
709
0
  return Slice();
710
0
}
711
712
107
Slice GetCommonName(X509* cert) {
713
107
  return GetEntryByNid(cert, NID_commonName);
714
107
}
715
716
} // namespace
717
718
1.09k
bool SecureRefiner::MatchEndpoint(X509* cert, GENERAL_NAMES* gens) {
719
1.09k
  auto address = stream_->Remote().address();
720
721
1.18k
  for (int i = 0; i < sk_GENERAL_NAME_num(gens); ++i) {
722
1.07k
    GENERAL_NAME* gen = sk_GENERAL_NAME_value(gens, i);
723
1.07k
    if (gen->type == GEN_DNS) {
724
45
      ASN1_IA5STRING* domain = gen->d.dNSName;
725
45
      if (domain->type == V_ASN1_IA5STRING && domain->data && domain->length) {
726
45
        Slice domain_slice(domain->data, domain->length);
727
0
        VLOG_WITH_PREFIX(4) << "Domain: " << domain_slice.ToBuffer() << " vs " << remote_hostname_;
728
45
        if (FLAGS_dump_certificate_entries) {
729
0
          certificate_entries_.push_back(Format("DNS:$0", domain_slice.ToBuffer()));
730
0
        }
731
45
        if (MatchPattern(domain_slice, remote_hostname_)) {
732
0
          return true;
733
0
        }
734
1.02k
      }
735
1.03k
    } else if (gen->type == GEN_IPADD) {
736
1.03k
      ASN1_OCTET_STRING* ip_address = gen->d.iPAddress;
737
1.03k
      if (ip_address->type == V_ASN1_OCTET_STRING && ip_address->data) {
738
1.03k
        if (ip_address->length == 4) {
739
1.03k
          boost::asio::ip::address_v4::bytes_type bytes;
740
1.03k
          memcpy(&bytes, ip_address->data, bytes.size());
741
1.03k
          auto allowed_address = boost::asio::ip::address_v4(bytes);
742
7
          VLOG_WITH_PREFIX(4) << "IPv4: " << allowed_address.to_string() << " vs " << address;
743
1.03k
          if (FLAGS_dump_certificate_entries) {
744
0
            certificate_entries_.push_back(Format("IP Address:$0", allowed_address));
745
0
          }
746
1.03k
          if (address == allowed_address) {
747
980
            return true;
748
980
          }
749
18.4E
        } else if (ip_address->length == 16) {
750
0
          boost::asio::ip::address_v6::bytes_type bytes;
751
0
          memcpy(&bytes, ip_address->data, bytes.size());
752
0
          auto allowed_address = boost::asio::ip::address_v6(bytes);
753
0
          VLOG_WITH_PREFIX(4) << "IPv6: " << allowed_address.to_string() << " vs " << address;
754
0
          if (FLAGS_dump_certificate_entries) {
755
0
            certificate_entries_.push_back(Format("IP Address:$0", allowed_address));
756
0
          }
757
0
          if (address == allowed_address) {
758
0
            return true;
759
0
          }
760
0
        }
761
1.02k
      }
762
1.03k
    }
763
1.07k
  }
764
765
  // No match in the alternate names, so try the common names. We should only
766
  // use the "most specific" common name, which is the last one in the list.
767
115
  Slice common_name = GetCommonName(cert);
768
115
  if (!common_name.empty()) {
769
0
    VLOG_WITH_PREFIX(4) << "Common name: " << common_name.ToBuffer() << " vs "
770
0
                        << stream_->Remote().address() << "/" << remote_hostname_;
771
107
    if (common_name == stream_->Remote().address().to_string() ||
772
62
        MatchPattern(common_name, remote_hostname_)) {
773
62
      return true;
774
62
    }
775
53
  }
776
777
53
  VLOG_WITH_PREFIX(4) << "Nothing suitable for " << stream_->Remote().address() << "/"
778
8
                      << remote_hostname_;
779
780
53
  return false;
781
53
}
782
783
0
bool SecureRefiner::MatchUidEntry(const Slice& value, const char* name) {
784
0
  if (value == secure_context_.required_uid()) {
785
0
    VLOG_WITH_PREFIX(4) << "Accepted " << name << ": " << value.ToBuffer();
786
0
    return true;
787
0
  } else if (!value.empty()) {
788
0
    VLOG_WITH_PREFIX(4) << "Rejected " << name << ": " << value.ToBuffer() << ", while "
789
0
                        << secure_context_.required_uid() << " required";
790
0
  }
791
0
  return false;
792
0
}
793
794
0
bool IsStringType(int type) {
795
0
  switch (type) {
796
0
    case V_ASN1_UTF8STRING: FALLTHROUGH_INTENDED;
797
0
    case V_ASN1_IA5STRING: FALLTHROUGH_INTENDED;
798
0
    case V_ASN1_UNIVERSALSTRING: FALLTHROUGH_INTENDED;
799
0
    case V_ASN1_BMPSTRING: FALLTHROUGH_INTENDED;
800
0
    case V_ASN1_VISIBLESTRING: FALLTHROUGH_INTENDED;
801
0
    case V_ASN1_PRINTABLESTRING: FALLTHROUGH_INTENDED;
802
0
    case V_ASN1_TELETEXSTRING: FALLTHROUGH_INTENDED;
803
0
    case V_ASN1_GENERALSTRING: FALLTHROUGH_INTENDED;
804
0
    case V_ASN1_NUMERICSTRING:
805
0
      return true;
806
0
  }
807
0
  return false;
808
0
}
809
810
0
// Tries to find the UID required by the secure context in the certificate. Checks, in order:
// the common name, the userId (NID_userId) subject entry, and every string-typed otherName in
// the subject alternative names `gens` (may be null when the certificate has no SAN extension).
// Returns true on the first match. Inspected UID/otherName entries are recorded in
// certificate_entries_ when --dump_certificate_entries is set.
bool SecureRefiner::MatchUid(X509* cert, GENERAL_NAMES* gens) {
  if (MatchUidEntry(GetCommonName(cert), "common name")) {
    return true;
  }

  auto uid = GetEntryByNid(cert, NID_userId);
  if (!uid.empty()) {
    if (FLAGS_dump_certificate_entries) {
      certificate_entries_.push_back(Format("UID:$0", uid.ToBuffer()));
    }
    if (MatchUidEntry(uid, "uid")) {
      return true;
    }
  }

  // sk_GENERAL_NAME_num() reports -1 for a null stack, so a missing SAN extension skips the loop.
  for (int i = 0; i < sk_GENERAL_NAME_num(gens); ++i) {
    GENERAL_NAME* gen = sk_GENERAL_NAME_value(gens, i);
    if (gen->type != GEN_OTHERNAME) {
      continue;
    }
    auto value = gen->d.otherName->value;
    if (value == nullptr || !IsStringType(value->type)) {
      continue;
    }
    ASN1_STRING* asn1_string = value->value.asn1_string;

    // BMPString and UniversalString carry UCS-2/UCS-4 code units, so their raw bytes can never
    // equal a textual UID. Normalize to UTF-8 first; when OpenSSL cannot convert the value
    // (e.g. GeneralString, or malformed UTF-8) fall back to the raw bytes as before.
    unsigned char* utf8 = nullptr;
    const int utf8_length = ASN1_STRING_to_UTF8(&utf8, asn1_string);
    auto free_utf8 = ScopeExit([utf8] {
      OPENSSL_free(utf8);
    });
    Slice other_name = utf8_length >= 0
        ? Slice(utf8, utf8_length)
        : Slice(asn1_string->data, asn1_string->length);
    if (other_name.empty()) {
      continue;
    }

    if (FLAGS_dump_certificate_entries) {
      certificate_entries_.push_back(Format("ON:$0", other_name.ToBuffer()));
    }
    if (MatchUidEntry(other_name, "other name")) {
      return true;
    }
  }
  VLOG_WITH_PREFIX(4) << "Not found entry for UID " << secure_context_.required_uid();

  return false;
}
846
847
// Verify according to RFC 2818.
848
2.61k
Status SecureRefiner::Verify(bool preverified, X509_STORE_CTX* store_context) {
849
  // Don't bother looking at certificates that have failed pre-verification.
850
2.61k
  if (!preverified) {
851
0
    auto err = X509_STORE_CTX_get_error(store_context);
852
0
    return STATUS_FORMAT(
853
0
        NetworkError, "Unverified certificate: $0", X509_verify_cert_error_string(err));
854
0
  }
855
856
  // We're only interested in checking the certificate at the end of the chain.
857
2.61k
  int depth = X509_STORE_CTX_get_error_depth(store_context);
858
2.61k
  if (depth > 0) {
859
0
    VLOG_WITH_PREFIX(4) << "Intermediate certificate";
860
1.30k
    return Status::OK();
861
1.30k
  }
862
863
1.30k
  X509* cert = X509_STORE_CTX_get_current_cert(store_context);
864
1.30k
  auto gens = static_cast<GENERAL_NAMES*>(X509_get_ext_d2i(
865
1.30k
      cert, NID_subject_alt_name, nullptr, nullptr));
866
1.30k
  auto se = ScopeExit([gens] {
867
1.30k
    GENERAL_NAMES_free(gens);
868
1.30k
  });
869
870
1.30k
  if (FLAGS_dump_certificate_entries) {
871
0
    certificate_entries_.push_back(Format("CN:$0", GetCommonName(cert).ToBuffer()));
872
0
  }
873
874
1.30k
  if (!secure_context_.required_uid().empty()) {
875
0
    if (!MatchUid(cert, gens)) {
876
0
      return STATUS_FORMAT(
877
0
          NetworkError, "Uid does not match: $0", secure_context_.required_uid());
878
0
    }
879
1.30k
  } else {
880
3
    VLOG_WITH_PREFIX(4) << "Skip UID verification";
881
1.30k
  }
882
883
1.30k
  bool verify_endpoint = stream_->local_side() == LocalSide::kClient ? FLAGS_verify_server_endpoint
884
468
                                                                     : FLAGS_verify_client_endpoint;
885
1.30k
  if (verify_endpoint) {
886
1.09k
    if (!MatchEndpoint(cert, gens)) {
887
45
      return STATUS(NetworkError, "Endpoint does not match");
888
45
    }
889
214
  } else {
890
3
    VLOG_WITH_PREFIX(4) << "Skip endpoint verification";
891
214
  }
892
893
1.26k
  return Status::OK();
894
1.30k
}
895
896
} // namespace
897
898
174
// Returns the single, lazily constructed "tcps" protocol descriptor used for secure streams.
// The pointer refers to a function-local static and stays valid for the program's lifetime.
const Protocol* SecureStreamProtocol() {
  static Protocol tcps_protocol("tcps");
  return &tcps_protocol;
}
902
903
// Builds a stream factory that layers a SecureRefiner over streams produced by
// `lower_layer_factory`. `context` is captured by raw pointer and dereferenced each time a
// stream is created. NOTE(review): the caller presumably keeps it alive for as long as the
// returned factory is in use — confirm at call sites.
StreamFactoryPtr SecureStreamFactory(
    StreamFactoryPtr lower_layer_factory, const MemTrackerPtr& buffer_tracker,
    const SecureContext* context) {
  auto make_secure_refiner = [context](const StreamCreateData& data) {
    return std::make_unique<SecureRefiner>(*context, data);
  };
  return std::make_shared<RefinedStreamFactory>(
      std::move(lower_layer_factory), buffer_tracker, std::move(make_secure_refiner));
}
911
912
} // namespace rpc
913
} // namespace yb