YugabyteDB (2.13.0.0-b42, bfc6a6643e7399ac8a0e81d06a3ee6d6571b33ab)

Coverage Report

Created: 2022-03-09 17:30

/Users/deen/code/yugabyte-db/src/yb/util/net/tunnel.cc
Line
Count
Source (jump to first uncovered line)
1
// Copyright (c) YugaByte, Inc.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4
// in compliance with the License.  You may obtain a copy of the License at
5
//
6
// http://www.apache.org/licenses/LICENSE-2.0
7
//
8
// Unless required by applicable law or agreed to in writing, software distributed under the License
9
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10
// or implied.  See the License for the specific language governing permissions and limitations
11
// under the License.
12
//
13
14
#include "yb/util/net/tunnel.h"
15
16
#include <boost/asio/ip/tcp.hpp>
17
#include <boost/asio/strand.hpp>
18
#include <boost/asio/write.hpp>
19
#include <boost/optional.hpp>
20
21
#include "yb/util/logging.h"
22
#include "yb/util/size_literals.h"
23
#include "yb/util/status.h"
24
#include "yb/util/status_format.h"
25
26
using namespace std::placeholders;
27
28
namespace yb {
29
30
class TunnelConnection;

// Shared handle to a TunnelConnection. Pending asynchronous operations hold one
// of these so the connection stays alive while they are in flight.
using TunnelConnectionPtr = std::shared_ptr<TunnelConnection>;
33
34
struct SemiTunnel {
35
  boost::asio::ip::tcp::socket* input;
36
  boost::asio::ip::tcp::socket* output;
37
  std::vector<char>* buffer;
38
  TunnelConnectionPtr self;
39
};
40
41
class TunnelConnection : public std::enable_shared_from_this<TunnelConnection> {
42
 public:
43
  explicit TunnelConnection(IoService* io_service, boost::asio::ip::tcp::socket* socket)
44
0
      : inbound_socket_(std::move(*socket)), outbound_socket_(*io_service), strand_(*io_service) {
45
0
  }
46
47
1.15k
  void Start(const Endpoint& dest) {
48
1.15k
    boost::system::error_code ec;
49
1.15k
    auto remote = inbound_socket_.remote_endpoint(ec);
50
1.15k
    auto inbound = inbound_socket_.local_endpoint(ec);
51
1.15k
    log_prefix_ = Format("$0 => $1 => $2: ", remote, inbound, dest);
52
1.15k
    outbound_socket_.async_connect(
53
1.15k
        dest,
54
1.15k
        strand_.wrap(std::bind(&TunnelConnection::HandleConnect, this, _1, shared_from_this())));
55
1.15k
  }
56
57
70
  void Shutdown() {
58
70
    strand_.dispatch([this, shared_self = shared_from_this()] {
59
70
      boost::system::error_code ec;
60
70
      inbound_socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_type::shutdown_both, ec);
61
70
      LOG_IF_WITH_PREFIX(INFO, ec) << "Shutdown failed: " << ec.message();
62
70
      outbound_socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_type::shutdown_both, ec);
63
0
      LOG_IF_WITH_PREFIX(INFO, ec) << "Shutdown failed: " << ec.message();
64
70
    });
65
70
  }
66
67
 private:
68
1.15k
  void HandleConnect(const boost::system::error_code& ec, const TunnelConnectionPtr& self) {
69
1.15k
    if (ec) {
70
0
      LOG_WITH_PREFIX(WARNING) << "Connect failed: " << ec.message();
71
0
      return;
72
0
    }
73
74
1.15k
    if (VLOG_IS_ON(2)) {
75
0
      boost::system::error_code endpoint_ec;
76
0
      VLOG_WITH_PREFIX(2) << "Connected: " << outbound_socket_.local_endpoint(endpoint_ec);
77
0
    }
78
79
1.15k
    in2out_buffer_.resize(4_KB);
80
1.15k
    out2in_buffer_.resize(4_KB);
81
1.15k
    StartRead({&inbound_socket_, &outbound_socket_, &in2out_buffer_, self});
82
1.15k
    StartRead({&outbound_socket_, &inbound_socket_, &out2in_buffer_, self});
83
1.15k
  }
84
85
4.69k
  void StartRead(const SemiTunnel& semi_tunnel) {
86
4.69k
    semi_tunnel.input->async_read_some(
87
4.69k
        boost::asio::buffer(*semi_tunnel.buffer),
88
4.69k
        strand_.wrap(std::bind(&TunnelConnection::HandleRead, this, _1, _2, semi_tunnel)));
89
4.69k
  }
90
91
  void HandleRead(const boost::system::error_code& ec, size_t transferred,
92
3.58k
                  const SemiTunnel& semi_tunnel) {
93
3.58k
    if (ec) {
94
18.4E
      VLOG_WITH_PREFIX(1) << "Read failed: " << ec.message();
95
1.21k
      return;
96
1.21k
    }
97
98
2.37k
    async_write(
99
2.37k
        *semi_tunnel.output, boost::asio::buffer(semi_tunnel.buffer->data(), transferred),
100
2.37k
        strand_.wrap(std::bind(&TunnelConnection::HandleWrite, this, _1, _2, semi_tunnel)));
101
2.37k
  }
102
103
  void HandleWrite(const boost::system::error_code& ec, size_t transferred,
104
2.37k
                   const SemiTunnel& semi_tunnel) {
105
2.37k
    if (ec) {
106
0
      VLOG_WITH_PREFIX(1) << "Write failed: " << ec.message();
107
0
      return;
108
0
    }
109
110
2.37k
    StartRead(semi_tunnel);
111
2.37k
  }
112
113
70
  const std::string& LogPrefix() const {
114
70
    return log_prefix_;
115
70
  }
116
117
  boost::asio::ip::tcp::socket inbound_socket_;
118
  boost::asio::ip::tcp::socket outbound_socket_;
119
  boost::asio::io_context::strand strand_;
120
  std::vector<char> in2out_buffer_;
121
  std::vector<char> out2in_buffer_;
122
  std::string log_prefix_;
123
};
124
125
// Implementation of Tunnel.
//
// Listens on a local endpoint. For every accepted connection that passes the
// optional address checker, it creates a TunnelConnection forwarding to the
// remote endpoint. Apart from `closing_`, all mutable state is touched only
// from code running on `strand_`.
class Tunnel::Impl {
 public:
  explicit Impl(boost::asio::io_context* io_context)
      : io_context_(*io_context), strand_(*io_context) {}

  // Shutdown() must have been called before destruction.
  ~Impl() {
    LOG_IF(DFATAL, !closing_.load(std::memory_order_acquire))
        << "Tunnel shutdown has not been started";
  }

  // Opens, binds and starts listening on `local` synchronously, so setup
  // errors are returned to the caller. The accept loop is then started on the
  // strand. Accepted connections are forwarded to `remote`; when
  // `address_checker` is set, it decides which client addresses are tunneled.
  CHECKED_STATUS Start(const Endpoint& local, const Endpoint& remote,
                       AddressChecker address_checker) {
    auto acceptor = std::make_shared<boost::asio::ip::tcp::acceptor>(io_context_);
    boost::system::error_code ec;

    LOG(INFO) << "Starting tunnel: " << local << " => " << remote;

    acceptor->open(local.protocol(), ec);
    if (ec) {
      return STATUS_FORMAT(NetworkError, "Open failed: $0", ec.message());
    }
    acceptor->set_option(boost::asio::socket_base::reuse_address(true), ec);
    if (ec) {
      return STATUS_FORMAT(NetworkError, "Reuse address failed: $0", ec.message());
    }
    acceptor->bind(local, ec);
    if (ec) {
      return STATUS_FORMAT(NetworkError, "Bind failed: $0", ec.message());
    }
    acceptor->listen(boost::asio::ip::tcp::socket::max_listen_connections, ec);
    if (ec) {
      return STATUS_FORMAT(NetworkError, "Listen failed: $0", ec.message());
    }
    strand_.dispatch([
        this, acceptor, local, remote, address_checker = std::move(address_checker)]() {
      local_ = local;
      remote_ = remote;
      address_checker_ = address_checker;
      acceptor_.emplace(std::move(*acceptor));
      StartAccept();
    });
    return Status::OK();
  }

  // Marks the tunnel as closing. Then, on the strand, it stops accepting and
  // shuts down every live connection.
  void Shutdown() {
    closing_.store(true, std::memory_order_release);
    strand_.dispatch([this] {
      LOG(INFO) << "Shutdown tunnel: " << local_ << " => " << remote_;
      if (acceptor_) {
        boost::system::error_code ec;
        acceptor_->cancel(ec);
        LOG_IF(WARNING, ec) << "Cancel failed: " << ec.message();
        acceptor_->close(ec);
        LOG_IF(WARNING, ec) << "Close failed: " << ec.message();
      }

      for (auto& connection : connections_) {
        auto shared_connection = connection.lock();
        if (shared_connection) {
          shared_connection->Shutdown();
        }
      }
      connections_.clear();
    });
  }

 private:
  // Accepts the next connection into a fresh socket_. The completion runs on the strand.
  void StartAccept() {
    socket_.emplace(io_context_);
    acceptor_->async_accept(*socket_, strand_.wrap(std::bind(&Impl::HandleAccept, this, _1)));
  }

  // Completion of an accept. Tunnels the new connection unless the tunnel is
  // closing or the address check rejects it, then accepts the next connection.
  void HandleAccept(const boost::system::error_code& ec) {
    if (ec) {
      LOG_IF(WARNING, ec != boost::asio::error::operation_aborted)
          << "Accept failed: " << ec.message();
      return;
    }

    // Once Shutdown() has been requested, no new connection is tunneled.
    // Shutdown() may already have cleared connections_, so such a connection
    // would never be shut down. Calling StartAccept() would also re-arm an
    // accept on an acceptor that is (or is about to be) closed.
    if (closing_.load(std::memory_order_acquire)) {
      CloseAcceptedSocket();
      return;
    }

    if (!CheckAddress()) {
      CloseAcceptedSocket();
      StartAccept();
      return;
    }

    auto connection = std::make_shared<TunnelConnection>(&io_context_, socket_.get_ptr());
    connection->Start(remote_);
    RegisterConnection(connection);
    StartAccept();
  }

  // Closes the just-accepted socket_ without tunneling it.
  void CloseAcceptedSocket() {
    boost::system::error_code close_ec;
    socket_->close(close_ec);
    LOG_IF(WARNING, close_ec) << "Close failed: " << close_ec.message();
  }

  // Tracks `connection` weakly so Shutdown() can reach it. The slot of an
  // already-destroyed connection is reused when one exists.
  void RegisterConnection(const TunnelConnectionPtr& connection) {
    for (auto& weak_connection : connections_) {
      if (weak_connection.expired()) {
        weak_connection = connection;
        return;
      }
    }
    connections_.push_back(connection);
  }

  // Returns whether the peer of socket_ may be tunneled. All peers are allowed
  // when no checker is configured or the peer address cannot be determined.
  bool CheckAddress() {
    if (!address_checker_) {
      return true;
    }

    boost::system::error_code ec;
    auto endpoint = socket_->remote_endpoint(ec);

    if (ec) {
      LOG(WARNING) << "Cannot get remote endpoint: " << ec.message();
      return true;
    }

    return address_checker_(endpoint.address());
  }

  boost::asio::io_context& io_context_;
  boost::asio::io_context::strand strand_;
  AddressChecker address_checker_;
  Endpoint local_;
  Endpoint remote_;
  boost::optional<boost::asio::ip::tcp::acceptor> acceptor_;
  boost::optional<boost::asio::ip::tcp::socket> socket_;
  std::vector<std::weak_ptr<TunnelConnection>> connections_;
  std::atomic<bool> closing_{false};
};
255
256
906
Tunnel::Tunnel(boost::asio::io_context* io_context) : impl_(new Impl(io_context)) {
257
906
}
258
259
188
Tunnel::~Tunnel() {
260
188
}
261
262
Status Tunnel::Start(const Endpoint& local, const Endpoint& remote,
263
906
                     AddressChecker address_checker) {
264
906
  return impl_->Start(local, remote, std::move(address_checker));
265
906
}
266
267
160
void Tunnel::Shutdown() {
268
160
  impl_->Shutdown();
269
160
}
270
271
} // namespace yb