/Users/deen/code/yugabyte-db/src/yb/util/net/tunnel.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Copyright (c) YugaByte, Inc. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except |
4 | | // in compliance with the License. You may obtain a copy of the License at |
5 | | // |
6 | | // http://www.apache.org/licenses/LICENSE-2.0 |
7 | | // |
8 | | // Unless required by applicable law or agreed to in writing, software distributed under the License |
9 | | // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express |
10 | | // or implied. See the License for the specific language governing permissions and limitations |
11 | | // under the License. |
12 | | // |
13 | | |
14 | | #include "yb/util/net/tunnel.h" |
15 | | |
16 | | #include <boost/asio/ip/tcp.hpp> |
17 | | #include <boost/asio/strand.hpp> |
18 | | #include <boost/asio/write.hpp> |
19 | | #include <boost/optional.hpp> |
20 | | |
21 | | #include "yb/util/logging.h" |
22 | | #include "yb/util/size_literals.h" |
23 | | #include "yb/util/status.h" |
24 | | #include "yb/util/status_format.h" |
25 | | |
26 | | using namespace std::placeholders; |
27 | | |
28 | | namespace yb { |
29 | | |
// Forward declaration: SemiTunnel stores a shared pointer to the connection before the class
// itself is defined.
class TunnelConnection;

// Shared-ownership handle to a TunnelConnection.
using TunnelConnectionPtr = std::shared_ptr<TunnelConnection>;
33 | | |
34 | | struct SemiTunnel { |
35 | | boost::asio::ip::tcp::socket* input; |
36 | | boost::asio::ip::tcp::socket* output; |
37 | | std::vector<char>* buffer; |
38 | | TunnelConnectionPtr self; |
39 | | }; |
40 | | |
41 | | class TunnelConnection : public std::enable_shared_from_this<TunnelConnection> { |
42 | | public: |
43 | | explicit TunnelConnection(IoService* io_service, boost::asio::ip::tcp::socket* socket) |
44 | 0 | : inbound_socket_(std::move(*socket)), outbound_socket_(*io_service), strand_(*io_service) { |
45 | 0 | } |
46 | | |
47 | 1.15k | void Start(const Endpoint& dest) { |
48 | 1.15k | boost::system::error_code ec; |
49 | 1.15k | auto remote = inbound_socket_.remote_endpoint(ec); |
50 | 1.15k | auto inbound = inbound_socket_.local_endpoint(ec); |
51 | 1.15k | log_prefix_ = Format("$0 => $1 => $2: ", remote, inbound, dest); |
52 | 1.15k | outbound_socket_.async_connect( |
53 | 1.15k | dest, |
54 | 1.15k | strand_.wrap(std::bind(&TunnelConnection::HandleConnect, this, _1, shared_from_this()))); |
55 | 1.15k | } |
56 | | |
57 | 70 | void Shutdown() { |
58 | 70 | strand_.dispatch([this, shared_self = shared_from_this()] { |
59 | 70 | boost::system::error_code ec; |
60 | 70 | inbound_socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_type::shutdown_both, ec); |
61 | 70 | LOG_IF_WITH_PREFIX(INFO, ec) << "Shutdown failed: " << ec.message(); |
62 | 70 | outbound_socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_type::shutdown_both, ec); |
63 | 0 | LOG_IF_WITH_PREFIX(INFO, ec) << "Shutdown failed: " << ec.message(); |
64 | 70 | }); |
65 | 70 | } |
66 | | |
67 | | private: |
68 | 1.15k | void HandleConnect(const boost::system::error_code& ec, const TunnelConnectionPtr& self) { |
69 | 1.15k | if (ec) { |
70 | 0 | LOG_WITH_PREFIX(WARNING) << "Connect failed: " << ec.message(); |
71 | 0 | return; |
72 | 0 | } |
73 | | |
74 | 1.15k | if (VLOG_IS_ON(2)) { |
75 | 0 | boost::system::error_code endpoint_ec; |
76 | 0 | VLOG_WITH_PREFIX(2) << "Connected: " << outbound_socket_.local_endpoint(endpoint_ec); |
77 | 0 | } |
78 | | |
79 | 1.15k | in2out_buffer_.resize(4_KB); |
80 | 1.15k | out2in_buffer_.resize(4_KB); |
81 | 1.15k | StartRead({&inbound_socket_, &outbound_socket_, &in2out_buffer_, self}); |
82 | 1.15k | StartRead({&outbound_socket_, &inbound_socket_, &out2in_buffer_, self}); |
83 | 1.15k | } |
84 | | |
85 | 4.69k | void StartRead(const SemiTunnel& semi_tunnel) { |
86 | 4.69k | semi_tunnel.input->async_read_some( |
87 | 4.69k | boost::asio::buffer(*semi_tunnel.buffer), |
88 | 4.69k | strand_.wrap(std::bind(&TunnelConnection::HandleRead, this, _1, _2, semi_tunnel))); |
89 | 4.69k | } |
90 | | |
91 | | void HandleRead(const boost::system::error_code& ec, size_t transferred, |
92 | 3.58k | const SemiTunnel& semi_tunnel) { |
93 | 3.58k | if (ec) { |
94 | 18.4E | VLOG_WITH_PREFIX(1) << "Read failed: " << ec.message(); |
95 | 1.21k | return; |
96 | 1.21k | } |
97 | | |
98 | 2.37k | async_write( |
99 | 2.37k | *semi_tunnel.output, boost::asio::buffer(semi_tunnel.buffer->data(), transferred), |
100 | 2.37k | strand_.wrap(std::bind(&TunnelConnection::HandleWrite, this, _1, _2, semi_tunnel))); |
101 | 2.37k | } |
102 | | |
103 | | void HandleWrite(const boost::system::error_code& ec, size_t transferred, |
104 | 2.37k | const SemiTunnel& semi_tunnel) { |
105 | 2.37k | if (ec) { |
106 | 0 | VLOG_WITH_PREFIX(1) << "Write failed: " << ec.message(); |
107 | 0 | return; |
108 | 0 | } |
109 | | |
110 | 2.37k | StartRead(semi_tunnel); |
111 | 2.37k | } |
112 | | |
113 | 70 | const std::string& LogPrefix() const { |
114 | 70 | return log_prefix_; |
115 | 70 | } |
116 | | |
117 | | boost::asio::ip::tcp::socket inbound_socket_; |
118 | | boost::asio::ip::tcp::socket outbound_socket_; |
119 | | boost::asio::io_context::strand strand_; |
120 | | std::vector<char> in2out_buffer_; |
121 | | std::vector<char> out2in_buffer_; |
122 | | std::string log_prefix_; |
123 | | }; |
124 | | |
125 | | class Tunnel::Impl { |
126 | | public: |
127 | | explicit Impl(boost::asio::io_context* io_context) |
128 | 906 | : io_context_(*io_context), strand_(*io_context) {} |
129 | | |
130 | 188 | ~Impl() { |
131 | 30 | LOG_IF(DFATAL, !closing_.load(std::memory_order_acquire)) |
132 | 30 | << "Tunnel shutdown has not been started"; |
133 | 188 | } |
134 | | |
135 | | CHECKED_STATUS Start(const Endpoint& local, const Endpoint& remote, |
136 | 906 | AddressChecker address_checker) { |
137 | 906 | auto acceptor = std::make_shared<boost::asio::ip::tcp::acceptor>(io_context_); |
138 | 906 | boost::system::error_code ec; |
139 | | |
140 | 906 | LOG(INFO) << "Starting tunnel: " << local << " => " << remote; |
141 | | |
142 | 906 | acceptor->open(local.protocol(), ec); |
143 | 906 | if (ec) { |
144 | 0 | return STATUS_FORMAT(NetworkError, "Open failed: $0", ec.message()); |
145 | 0 | } |
146 | 906 | acceptor->set_option(boost::asio::socket_base::reuse_address(true), ec); |
147 | 906 | if (ec) { |
148 | 0 | return STATUS_FORMAT(NetworkError, "Reuse address failed: $0", ec.message()); |
149 | 0 | } |
150 | 906 | acceptor->bind(local, ec); |
151 | 906 | if (ec) { |
152 | 0 | return STATUS_FORMAT(NetworkError, "Bind failed: $0", ec.message()); |
153 | 0 | } |
154 | 906 | acceptor->listen(boost::asio::ip::tcp::socket::max_listen_connections, ec); |
155 | 906 | if (ec) { |
156 | 0 | return STATUS_FORMAT(NetworkError, "Listen failed: $0", ec.message()); |
157 | 0 | } |
158 | 906 | strand_.dispatch([ |
159 | 906 | this, acceptor, local, remote, address_checker]() { |
160 | 906 | local_ = local; |
161 | 906 | remote_ = remote; |
162 | 906 | address_checker_ = address_checker; |
163 | 906 | acceptor_.emplace(std::move(*acceptor)); |
164 | 906 | StartAccept(); |
165 | 906 | }); |
166 | 906 | return Status::OK(); |
167 | 906 | } |
168 | | |
169 | 160 | void Shutdown() { |
170 | 160 | closing_.store(true, std::memory_order_release); |
171 | 160 | strand_.dispatch([this] { |
172 | 160 | LOG(INFO) << "Shutdown tunnel: " << local_ << " => " << remote_; |
173 | 160 | if (acceptor_) { |
174 | 160 | boost::system::error_code ec; |
175 | 160 | acceptor_->cancel(ec); |
176 | 0 | LOG_IF(WARNING, ec) << "Cancel failed: " << ec.message(); |
177 | 160 | acceptor_->close(ec); |
178 | 0 | LOG_IF(WARNING, ec) << "Close failed: " << ec.message(); |
179 | 160 | } |
180 | | |
181 | 70 | for (auto& connection : connections_) { |
182 | 70 | auto shared_connection = connection.lock(); |
183 | 70 | if (shared_connection) { |
184 | 70 | shared_connection->Shutdown(); |
185 | 70 | } |
186 | 70 | } |
187 | 160 | connections_.clear(); |
188 | 160 | }); |
189 | 160 | } |
190 | | |
191 | | private: |
192 | 2.06k | void StartAccept() { |
193 | 2.06k | socket_.emplace(io_context_); |
194 | 2.06k | acceptor_->async_accept(*socket_, strand_.wrap(std::bind(&Impl::HandleAccept, this, _1))); |
195 | 2.06k | } |
196 | | |
197 | 1.31k | void HandleAccept(const boost::system::error_code& ec) { |
198 | 1.31k | if (ec) { |
199 | 1 | LOG_IF(WARNING, ec != boost::asio::error::operation_aborted) |
200 | 1 | << "Accept failed: " << ec.message(); |
201 | 161 | return; |
202 | 161 | } |
203 | | |
204 | 1.15k | if (!CheckAddress()) { |
205 | 0 | boost::system::error_code ec; |
206 | 0 | socket_->close(ec); |
207 | 0 | LOG_IF(WARNING, ec) << "Close failed: " << ec.message(); |
208 | 0 | StartAccept(); |
209 | 0 | return; |
210 | 0 | } |
211 | | |
212 | 1.15k | auto connection = std::make_shared<TunnelConnection>(&io_context_, socket_.get_ptr()); |
213 | 1.15k | connection->Start(remote_); |
214 | 1.15k | bool found = false; |
215 | 1.33k | for (auto& weak_connection : connections_) { |
216 | 1.33k | auto shared_connection = weak_connection.lock(); |
217 | 1.33k | if (!shared_connection) { |
218 | 0 | found = true; |
219 | 0 | weak_connection = connection; |
220 | 0 | break; |
221 | 0 | } |
222 | 1.33k | } |
223 | 1.15k | if (!found) { |
224 | 1.15k | connections_.push_back(connection); |
225 | 1.15k | } |
226 | 1.15k | StartAccept(); |
227 | 1.15k | } |
228 | | |
229 | 1.15k | bool CheckAddress() { |
230 | 1.15k | if (!address_checker_) { |
231 | 1.15k | return true; |
232 | 1.15k | } |
233 | | |
234 | 0 | boost::system::error_code ec; |
235 | 0 | auto endpoint = socket_->remote_endpoint(ec); |
236 | |
|
237 | 0 | if (ec) { |
238 | 0 | LOG(WARNING) << "Cannot get remote endpoint: " << ec.message(); |
239 | 0 | return true; |
240 | 0 | } |
241 | | |
242 | 0 | return address_checker_(endpoint.address()); |
243 | 0 | } |
244 | | |
245 | | boost::asio::io_context& io_context_; |
246 | | boost::asio::io_context::strand strand_; |
247 | | AddressChecker address_checker_; |
248 | | Endpoint local_; |
249 | | Endpoint remote_; |
250 | | boost::optional<boost::asio::ip::tcp::acceptor> acceptor_; |
251 | | boost::optional<boost::asio::ip::tcp::socket> socket_; |
252 | | std::vector<std::weak_ptr<TunnelConnection>> connections_; |
253 | | std::atomic<bool> closing_{false}; |
254 | | }; |
255 | | |
256 | 906 | Tunnel::Tunnel(boost::asio::io_context* io_context) : impl_(new Impl(io_context)) { |
257 | 906 | } |
258 | | |
259 | 188 | Tunnel::~Tunnel() { |
260 | 188 | } |
261 | | |
262 | | Status Tunnel::Start(const Endpoint& local, const Endpoint& remote, |
263 | 906 | AddressChecker address_checker) { |
264 | 906 | return impl_->Start(local, remote, std::move(address_checker)); |
265 | 906 | } |
266 | | |
267 | 160 | void Tunnel::Shutdown() { |
268 | 160 | impl_->Shutdown(); |
269 | 160 | } |
270 | | |
271 | | } // namespace yb |