/Users/deen/code/yugabyte-db/src/yb/rpc/rpc-test-base.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // |
2 | | // Copyright (c) YugaByte, Inc. |
3 | | // |
4 | | // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except |
5 | | // in compliance with the License. You may obtain a copy of the License at |
6 | | // |
7 | | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | | // |
9 | | // Unless required by applicable law or agreed to in writing, software distributed under the License |
10 | | // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express |
11 | | // or implied. See the License for the specific language governing permissions and limitations |
12 | | // under the License. |
13 | | // |
14 | | // |
15 | | |
#include "yb/rpc/rpc-test-base.h"

#include <algorithm>
#include <thread>

#include "yb/rpc/proxy.h"
#include "yb/rpc/rpc_controller.h"
#include "yb/rpc/yb_rpc.h"

#include "yb/util/debug-util.h"
#include "yb/util/flag_tags.h"
#include "yb/util/net/net_util.h"
#include "yb/util/random_util.h"
#include "yb/util/result.h"
#include "yb/util/status_log.h"
#include "yb/util/test_macros.h"
31 | | |
32 | | using namespace std::chrono_literals; |
33 | | |
34 | | DEFINE_test_flag(bool, pause_calculator_echo_request, false, |
35 | | "Pause calculator echo request execution until flag is set back to false."); |
36 | | |
37 | | DECLARE_int64(outbound_rpc_block_size); |
38 | | DECLARE_int64(outbound_rpc_memory_limit); |
39 | | |
40 | | namespace yb { namespace rpc { |
41 | | |
42 | | using yb::rpc_test::CalculatorServiceIf; |
43 | | using yb::rpc_test::CalculatorError; |
44 | | |
45 | | using yb::rpc_test::AddRequestPB; |
46 | | using yb::rpc_test::AddResponsePB; |
47 | | using yb::rpc_test::EchoRequestPB; |
48 | | using yb::rpc_test::EchoResponsePB; |
49 | | using yb::rpc_test::ForwardRequestPB; |
50 | | using yb::rpc_test::ForwardResponsePB; |
51 | | using yb::rpc_test::PanicRequestPB; |
52 | | using yb::rpc_test::PanicResponsePB; |
53 | | using yb::rpc_test::SendStringsRequestPB; |
54 | | using yb::rpc_test::SendStringsResponsePB; |
55 | | using yb::rpc_test::SleepRequestPB; |
56 | | using yb::rpc_test::SleepResponsePB; |
57 | | using yb::rpc_test::WhoAmIRequestPB; |
58 | | using yb::rpc_test::WhoAmIResponsePB; |
59 | | using yb::rpc_test::PingRequestPB; |
60 | | using yb::rpc_test::PingResponsePB; |
61 | | using yb::rpc_test::DisconnectRequestPB; |
62 | | using yb::rpc_test::DisconnectResponsePB; |
63 | | |
64 | | using yb::rpc_test_diff_package::ReqDiffPackagePB; |
65 | | using yb::rpc_test_diff_package::RespDiffPackagePB; |
66 | | |
67 | | namespace { |
68 | | |
69 | | constexpr size_t kQueueLength = 1000; |
70 | | |
71 | 25 | Slice GetSidecarPointer(const RpcController& controller, int idx, size_t expected_size) { |
72 | 25 | Slice sidecar = CHECK_RESULT(controller.GetSidecar(idx)); |
73 | 25 | CHECK_EQ(expected_size, sidecar.size()); |
74 | 25 | return sidecar; |
75 | 25 | } |
76 | | |
77 | | MessengerBuilder CreateMessengerBuilder(const std::string& name, |
78 | | const scoped_refptr<MetricEntity>& metric_entity, |
79 | 202 | const MessengerOptions& options) { |
80 | 202 | MessengerBuilder bld(name); |
81 | 202 | bld.set_num_reactors(options.n_reactors); |
82 | 202 | if (options.num_connections_to_server >= 0) { |
83 | 8 | bld.set_num_connections_to_server(options.num_connections_to_server); |
84 | 8 | } |
85 | 202 | static constexpr std::chrono::milliseconds kMinCoarseTimeGranularity(1); |
86 | 202 | static constexpr std::chrono::milliseconds kMaxCoarseTimeGranularity(100); |
87 | 202 | auto coarse_time_granularity = std::max(std::min(options.keep_alive_timeout / 10, |
88 | 202 | kMaxCoarseTimeGranularity), |
89 | 202 | kMinCoarseTimeGranularity); |
90 | 0 | VLOG(1) << "Creating a messenger with connection keep alive time: " |
91 | 0 | << options.keep_alive_timeout.count() << " ms, " |
92 | 0 | << "coarse time granularity: " << coarse_time_granularity.count() << " ms"; |
93 | 202 | bld.set_connection_keepalive_time(options.keep_alive_timeout); |
94 | 202 | bld.set_coarse_timer_granularity(coarse_time_granularity); |
95 | 202 | bld.set_metric_entity(metric_entity); |
96 | 202 | bld.CreateConnectionContextFactory<YBOutboundConnectionContext>( |
97 | 202 | FLAGS_outbound_rpc_memory_limit, |
98 | 202 | MemTracker::FindOrCreateTracker(name)); |
99 | 202 | return bld; |
100 | 202 | } |
101 | | |
102 | | std::unique_ptr<Messenger> CreateMessenger(const std::string& name, |
103 | | const scoped_refptr<MetricEntity>& metric_entity, |
104 | 134 | const MessengerOptions& options) { |
105 | 134 | return EXPECT_RESULT(CreateMessengerBuilder(name, metric_entity, options).Build()); |
106 | 134 | } |
107 | | |
#ifdef THREAD_SANITIZER
// TSAN slows execution down dramatically; use a much longer keep-alive so
// connections are not dropped spuriously during tests.
constexpr std::chrono::milliseconds kDefaultKeepAlive = 15s;
#else
constexpr std::chrono::milliseconds kDefaultKeepAlive = 1s;
#endif

} // namespace

// Default messenger configurations used throughout the RPC tests:
// one reactor for clients, three for servers.
const MessengerOptions kDefaultClientMessengerOptions = {1, kDefaultKeepAlive};
const MessengerOptions kDefaultServerMessengerOptions = {3, kDefaultKeepAlive};
118 | | |
119 | | void GenericCalculatorService::AddMethodToMap( |
120 | 56 | const RpcServicePtr& service, RpcEndpointMap* map, const char* method_name, Method method) { |
121 | 56 | size_t index = methods_.size(); |
122 | 56 | methods_.emplace_back( |
123 | 56 | RemoteMethod(CalculatorServiceIf::static_service_name(), method_name), method); |
124 | 56 | map->emplace(methods_.back().first.serialized_body(), std::make_pair(service, index)); |
125 | 56 | } |
126 | | |
127 | 14 | void GenericCalculatorService::FillEndpoints(const RpcServicePtr& service, RpcEndpointMap* map) { |
128 | 14 | AddMethodToMap( |
129 | 14 | service, map, CalculatorServiceMethods::kAddMethodName, &GenericCalculatorService::DoAdd); |
130 | 14 | AddMethodToMap( |
131 | 14 | service, map, CalculatorServiceMethods::kSleepMethodName, &GenericCalculatorService::DoSleep); |
132 | 14 | AddMethodToMap( |
133 | 14 | service, map, CalculatorServiceMethods::kEchoMethodName, &GenericCalculatorService::DoEcho); |
134 | 14 | AddMethodToMap( |
135 | 14 | service, map, CalculatorServiceMethods::kSendStringsMethodName, |
136 | 14 | &GenericCalculatorService::DoSendStrings); |
137 | 14 | } |
138 | | |
// Dispatches an inbound call to the member function registered for its
// method index (the index assigned in FillEndpoints/AddMethodToMap).
void GenericCalculatorService::Handle(InboundCallPtr incoming) {
  (this->*methods_[incoming->method_index()].second)(incoming.get());
}
142 | | |
143 | 757 | void GenericCalculatorService::GenericCalculatorService::DoAdd(InboundCall* incoming) { |
144 | 757 | Slice param(incoming->serialized_request()); |
145 | 757 | AddRequestPB req; |
146 | 757 | if (!req.ParseFromArray(param.data(), narrow_cast<int>(param.size()))) { |
147 | 0 | LOG(FATAL) << "couldn't parse: " << param.ToDebugString(); |
148 | 0 | } |
149 | | |
150 | 757 | AddResponsePB resp; |
151 | 757 | resp.set_result(req.x() + req.y()); |
152 | 757 | down_cast<YBInboundCall*>(incoming)->RespondSuccess(AnyMessageConstPtr(&resp)); |
153 | 757 | } |
154 | | |
155 | 3 | void GenericCalculatorService::DoSendStrings(InboundCall* incoming) { |
156 | 3 | Slice param(incoming->serialized_request()); |
157 | 3 | SendStringsRequestPB req; |
158 | 3 | if (!req.ParseFromArray(param.data(), narrow_cast<int>(param.size()))) { |
159 | 0 | LOG(FATAL) << "couldn't parse: " << param.ToDebugString(); |
160 | 0 | } |
161 | | |
162 | 3 | Random r(req.random_seed()); |
163 | 3 | SendStringsResponsePB resp; |
164 | 3 | auto* yb_call = down_cast<YBInboundCall*>(incoming); |
165 | 25 | for (auto size : req.sizes()) { |
166 | 25 | auto sidecar = RefCntBuffer(size); |
167 | 25 | RandomString(sidecar.udata(), size, &r); |
168 | 25 | resp.add_sidecars(narrow_cast<uint32_t>(yb_call->AddRpcSidecar(sidecar.as_slice()))); |
169 | 25 | } |
170 | | |
171 | 3 | down_cast<YBInboundCall*>(incoming)->RespondSuccess(AnyMessageConstPtr(&resp)); |
172 | 3 | } |
173 | | |
174 | 4 | void GenericCalculatorService::DoSleep(InboundCall* incoming) { |
175 | 4 | Slice param(incoming->serialized_request()); |
176 | 4 | SleepRequestPB req; |
177 | 4 | if (!req.ParseFromArray(param.data(), narrow_cast<int>(param.size()))) { |
178 | 0 | incoming->RespondFailure(ErrorStatusPB::ERROR_INVALID_REQUEST, |
179 | 0 | STATUS(InvalidArgument, "Couldn't parse pb", |
180 | 0 | req.InitializationErrorString())); |
181 | 0 | return; |
182 | 0 | } |
183 | | |
184 | 4 | LOG(INFO) << "got call: " << req.ShortDebugString(); |
185 | 4 | SleepFor(MonoDelta::FromMicroseconds(req.sleep_micros())); |
186 | 4 | SleepResponsePB resp; |
187 | 4 | down_cast<YBInboundCall*>(incoming)->RespondSuccess(AnyMessageConstPtr(&resp)); |
188 | 4 | } |
189 | | |
190 | 224k | void GenericCalculatorService::DoEcho(InboundCall* incoming) { |
191 | 224k | Slice param(incoming->serialized_request()); |
192 | 224k | EchoRequestPB req; |
193 | 224k | if (!req.ParseFromArray(param.data(), narrow_cast<int>(param.size()))) { |
194 | 0 | incoming->RespondFailure(ErrorStatusPB::ERROR_INVALID_REQUEST, |
195 | 0 | STATUS(InvalidArgument, "Couldn't parse pb", |
196 | 0 | req.InitializationErrorString())); |
197 | 0 | return; |
198 | 0 | } |
199 | | |
200 | 224k | EchoResponsePB resp; |
201 | 224k | resp.set_data(std::move(*req.mutable_data())); |
202 | 224k | down_cast<YBInboundCall*>(incoming)->RespondSuccess(AnyMessageConstPtr(&resp)); |
203 | 224k | } |
204 | | |
205 | | namespace { |
206 | | |
// Full implementation of the generated calculator test service. Unlike
// GenericCalculatorService above (which parses and dispatches by hand),
// request parsing and dispatch here are performed by the protobuf-generated
// CalculatorServiceIf plumbing.
class CalculatorService: public CalculatorServiceIf {
 public:
  explicit CalculatorService(const scoped_refptr<MetricEntity>& entity,
                             std::string name)
      : CalculatorServiceIf(entity), name_(std::move(name)) {
  }

  // The messenger is needed by Forward() to build proxies to other servers.
  void SetMessenger(Messenger* messenger) {
    messenger_ = messenger;
  }

  void Add(const AddRequestPB* req, AddResponsePB* resp, RpcContext context) override {
    resp->set_result(req->x() + req->y());
    context.RespondSuccess();
  }

  // Sleeps for req->sleep_micros() before responding. The request can also
  // ask to: return an application error (return_app_error), fail when the
  // client did not set a deadline (client_timeout_defined), or respond from
  // a detached thread (deferred).
  void Sleep(const SleepRequestPB* req, SleepResponsePB* resp, RpcContext context) override {
    if (req->return_app_error()) {
      CalculatorError my_error;
      my_error.set_extra_error_data("some application-specific error data");
      context.RespondApplicationError(CalculatorError::app_error_ext.number(),
                                      "Got some error", my_error);
      return;
    }

    // Respond w/ error if the RPC specifies that the client deadline is set,
    // but it isn't.
    if (req->client_timeout_defined()) {
      auto deadline = context.GetClientDeadline();
      if (deadline == CoarseTimePoint::max()) {
        CalculatorError my_error;
        my_error.set_extra_error_data("Timeout not set");
        context.RespondApplicationError(CalculatorError::app_error_ext.number(),
                                        "Missing required timeout", my_error);
        return;
      }
    }

    if (req->deferred()) {
      // Spawn a new thread which does the sleep and responds later.
      std::thread thread([this, req, context = std::move(context)]() mutable {
        DoSleep(req, std::move(context));
      });
      thread.detach();
      return;
    }
    DoSleep(req, std::move(context));
  }

  void Echo(const EchoRequestPB* req, EchoResponsePB* resp, RpcContext context) override {
    // Test hook: blocks while TEST_pause_calculator_echo_request is set.
    TEST_PAUSE_IF_FLAG(TEST_pause_calculator_echo_request);
    resp->set_data(req->data());
    context.RespondSuccess();
  }

  // Replies with the client's remote address as observed by the server.
  void WhoAmI(const WhoAmIRequestPB* req, WhoAmIResponsePB* resp, RpcContext context) override {
    LOG(INFO) << "Remote address: " << context.remote_address();
    resp->set_address(yb::ToString(context.remote_address()));
    context.RespondSuccess();
  }

  // Exists to verify that request/response types declared in a different
  // protobuf package work; the body is intentionally a no-op success.
  void TestArgumentsInDiffPackage(
      const ReqDiffPackagePB* req, RespDiffPackagePB* resp, RpcContext context) override {
    context.RespondSuccess();
  }

  // Deliberately panics in the handler to exercise PANIC_RPC handling.
  void Panic(const PanicRequestPB* req, PanicResponsePB* resp, RpcContext context) override {
    TRACE("Got panic request");
    PANIC_RPC(&context, "Test method panicking!");
  }

  void Ping(const PingRequestPB* req, PingResponsePB* resp, RpcContext context) override {
    auto now = MonoTime::Now();
    resp->set_time(now.ToUint64());
    context.RespondSuccess();
  }

  // Responds successfully but also closes the connection, to exercise
  // client handling of server-initiated disconnects.
  void Disconnect(
      const DisconnectRequestPB* peq, DisconnectResponsePB* resp, RpcContext context) override {
    context.CloseConnection();
    context.RespondSuccess();
  }

  // Without host/port: responds with this server's name. With host/port:
  // forwards a host/port-less Forward call to that server and relays the
  // name it returns (or the failure status).
  void Forward(const ForwardRequestPB* req, ForwardResponsePB* resp, RpcContext context) override {
    if (!req->has_host() || !req->has_port()) {
      resp->set_name(name_);
      context.RespondSuccess();
      return;
    }
    HostPort hostport(req->host(), req->port());
    ProxyCache cache(messenger_);
    rpc_test::CalculatorServiceProxy proxy(&cache, hostport);

    ForwardRequestPB forwarded_req;
    ForwardResponsePB forwarded_resp;
    RpcController controller;
    auto status = proxy.Forward(forwarded_req, &forwarded_resp, &controller);
    if (!status.ok()) {
      context.RespondFailure(status);
    } else {
      resp->set_name(forwarded_resp.name());
      context.RespondSuccess();
    }
  }

  // Exercises the lightweight (LW) protobuf implementation. Every request
  // field is echoed back intentionally transformed — negated, reversed, or
  // crossed into a field of a different-but-compatible wire type — so the
  // client test can verify that all field kinds round-trip correctly.
  // Do not "fix" the apparent field mismatches below: they are the test.
  void Lightweight(
      const rpc_test::LWLightweightRequestPB* const_req, rpc_test::LWLightweightResponsePB* resp,
      RpcContext context) override {
    // The request is deliberately mutated below (ref_* / mutable_* calls),
    // so const is cast away once up front.
    auto* req = const_cast<rpc_test::LWLightweightRequestPB*>(const_req);

    resp->set_i32(-req->i32());
    resp->set_i64(-req->i64());
    // Cross-type echoes: float fields get the unsigned values and vice
    // versa, to check conversions across wire types.
    resp->set_f32(req->u32());
    resp->set_f64(req->u64());
    resp->set_u32(req->f32());
    resp->set_u64(req->f64());
    resp->set_r32(-req->r32());
    resp->set_r64(-req->r64());
    // str/bytes swapped; ref_* shares the request's buffer without copying.
    resp->ref_str(req->bytes());
    resp->ref_bytes(req->str());
    resp->set_en(static_cast<rpc_test::LightweightEnum>(req->en() + 1));
    // Fixed-width and varint signed fields swapped.
    resp->set_sf32(req->si32());
    resp->set_sf64(req->si64());
    resp->set_si32(req->sf32());
    resp->set_si64(req->sf64());
    *resp->mutable_ru32() = req->rf32();
    *resp->mutable_rf32() = req->ru32();

    // Repeated string field returned in reverse order.
    resp->mutable_rstr()->assign(req->rstr().rbegin(), req->rstr().rend());

    auto& resp_msg = *resp->mutable_message();
    const auto& req_msg = req->message();
    resp_msg.set_sf32(-req_msg.sf32());

    resp_msg.mutable_rsi32()->assign(req_msg.rsi32().rbegin(), req_msg.rsi32().rend());

    resp_msg.dup_str(">" + req_msg.str().ToBuffer() + "<");

    resp_msg.mutable_rbytes()->assign(req_msg.rbytes().rbegin(), req_msg.rbytes().rend());

    // Shared (by-reference) copies of the repeated messages, reversed...
    for (auto it = req->mutable_repeated_messages()->rbegin();
         it != req->mutable_repeated_messages()->rend(); ++it) {
      resp->mutable_repeated_messages()->push_back_ref(&*it);
    }
    // ...and deep copies via the google::protobuf conversion, in order.
    for (const auto& msg : req->repeated_messages()) {
      auto temp = CopySharedMessage<rpc_test::LWLightweightSubMessagePB>(msg.ToGoogleProtobuf());
      resp->mutable_repeated_messages_copy()->emplace_back(*temp);
    }

    resp->mutable_packed_u64()->assign(req->packed_u64().rbegin(), req->packed_u64().rend());

    resp->mutable_packed_f32()->assign(req->packed_f32().rbegin(), req->packed_f32().rend());

    // Pairs echoed with their two members swapped.
    for (const auto& p : req->pairs()) {
      auto& pair = *resp->add_pairs();
      *pair.mutable_s1() = p.s2();
      *pair.mutable_s2() = p.s1();
    }

    resp->ref_ptr_message(req->mutable_ptr_message());

    // Should check it before filling map, because map does not preserve order.
    ASSERT_STR_EQ(AsString(resp->ToGoogleProtobuf()), AsString(*resp));

    for (const auto& p : req->map()) {
      auto& pair = *resp->add_map();
      pair.ref_key(p.key());
      pair.set_value(p.value());
    }

    // Clear the map before producing the debug string so the client can
    // compare against a stable representation.
    req->mutable_map()->clear();
    resp->dup_short_debug_string(req->ShortDebugString());

    context.RespondSuccess();
  }

  // Trivial RPC using the Result-returning (non-context) handler style.
  // Returns InvalidArgument for negative values, otherwise echoes the value.
  Result<rpc_test::TrivialResponsePB> Trivial(
      const rpc_test::TrivialRequestPB& req, CoarseTimePoint deadline) override {
    if (req.value() < 0) {
      return STATUS_FORMAT(InvalidArgument, "Negative value: $0", req.value());
    }
    rpc_test::TrivialResponsePB resp;
    resp.set_value(req.value());
    return resp;
  }

 private:
  // Shared sleep-then-respond body for both the inline and deferred paths
  // of Sleep().
  void DoSleep(const SleepRequestPB* req, RpcContext context) {
    SleepFor(MonoDelta::FromMicroseconds(req->sleep_micros()));
    context.RespondSuccess();
  }

  // Name reported by Forward() when no forwarding target is given.
  std::string name_;
  // Not owned; set via SetMessenger(). Used only by Forward().
  Messenger* messenger_ = nullptr;
};
402 | | |
403 | | std::unique_ptr<CalculatorService> CreateCalculatorService( |
404 | 69 | const scoped_refptr<MetricEntity>& metric_entity, std::string name = std::string()) { |
405 | 69 | return std::make_unique<CalculatorService>(metric_entity, std::move(name)); |
406 | 69 | } |
407 | | |
408 | | class AbacusService: public rpc_test::AbacusServiceIf { |
409 | | public: |
410 | 69 | explicit AbacusService(const scoped_refptr<MetricEntity>& entity) : AbacusServiceIf(entity) {} |
411 | | |
412 | | void Concat( |
413 | | const rpc_test::ConcatRequestPB *req, |
414 | | rpc_test::ConcatResponsePB *resp, |
415 | 1 | RpcContext context) { |
416 | 1 | resp->set_result(req->lhs() + req->rhs()); |
417 | 1 | context.RespondSuccess(); |
418 | 1 | } |
419 | | }; |
420 | | |
421 | | } // namespace |
422 | | |
// Takes ownership of the messenger, creates the worker thread pool, and
// binds to options.endpoint (recording the actual bound endpoint in
// bound_endpoint_). The acceptor is not started until Start() is called.
TestServer::TestServer(std::unique_ptr<Messenger>&& messenger,
                       const TestServerOptions& options)
    : messenger_(std::move(messenger)),
      thread_pool_(std::make_unique<ThreadPool>(
          "rpc-test", kQueueLength, options.n_worker_threads)) {

  EXPECT_OK(messenger_->ListenAddress(
      rpc::CreateConnectionContextFactory<rpc::YBInboundConnectionContext>(),
      options.endpoint, &bound_endpoint_));
}
433 | | |
// Starts accepting connections on the endpoint bound in the constructor.
CHECKED_STATUS TestServer::Start() {
  return messenger_->StartAcceptor();
}
437 | | |
438 | 151 | CHECKED_STATUS TestServer::RegisterService(std::unique_ptr<ServiceIf> service) { |
439 | 151 | const std::string& service_name = service->service_name(); |
440 | | |
441 | 151 | auto service_pool = make_scoped_refptr<ServicePool>(kQueueLength, |
442 | 151 | thread_pool_.get(), |
443 | 151 | &messenger_->scheduler(), |
444 | 151 | std::move(service), |
445 | 151 | messenger_->metric_entity()); |
446 | 151 | if (!service_pool_) { |
447 | 82 | service_pool_ = service_pool; |
448 | 82 | } |
449 | | |
450 | 151 | return messenger_->RegisterService(service_name, std::move(service_pool)); |
451 | 151 | } |
452 | | |
// Tears the server down in dependency order: worker pool first, then the
// registered services, then the messenger. Null checks make destruction
// safe after a move or an explicit Shutdown().
TestServer::~TestServer() {
  thread_pool_ = nullptr;
  if (service_pool_) {
    messenger_->UnregisterAllServices();
    service_pool_->Shutdown();
  }
  if (messenger_) {
    messenger_->Shutdown();
  }
}
463 | | |
464 | 2 | void TestServer::Shutdown() { |
465 | 2 | messenger_->UnregisterAllServices(); |
466 | 2 | service_pool_->Shutdown(); |
467 | 2 | messenger_->Shutdown(); |
468 | 2 | } |
469 | | |
// Creates the per-test metric entity all messengers/servers built by this
// fixture attach their metrics to.
RpcTestBase::RpcTestBase()
    : metric_entity_(METRIC_ENTITY_server.Instantiate(&metric_registry_, "test.rpc_test")) {
}
473 | | |
// Destroys the test server (triggering its shutdown sequence) before the
// base-class teardown runs.
void RpcTestBase::TearDown() {
  server_.reset();
  YBTest::TearDown();
}
478 | | |
479 | 771 | CHECKED_STATUS RpcTestBase::DoTestSyncCall(Proxy* proxy, const RemoteMethod* method) { |
480 | 771 | AddRequestPB req; |
481 | 771 | req.set_x(RandomUniformInt<uint32_t>()); |
482 | 771 | req.set_y(RandomUniformInt<uint32_t>()); |
483 | 771 | AddResponsePB resp; |
484 | 771 | RpcController controller; |
485 | 771 | controller.set_timeout(MonoDelta::FromMilliseconds(10000)); |
486 | 771 | RETURN_NOT_OK(proxy->SyncRequest(method, /* method_metrics= */ nullptr, req, &resp, &controller)); |
487 | | |
488 | 18.4E | VLOG(1) << "Result: " << resp.ShortDebugString(); |
489 | 755 | CHECK_EQ(req.x() + req.y(), resp.result()); |
490 | 755 | return Status::OK(); |
491 | 771 | } |
492 | | |
// Sends a SendStrings RPC requesting sidecars of the given sizes and checks
// that: the call's status code matches expected_code, and — on success —
// each returned sidecar has the requested size and the exact byte sequence
// produced by the deterministic RNG seed shared with the server.
void RpcTestBase::DoTestSidecar(Proxy* proxy,
                                std::vector<size_t> sizes,
                                Status::Code expected_code) {
  // Seed known to both sides, so the expected sidecar contents can be
  // regenerated locally for comparison.
  const uint32_t kSeed = 12345;

  SendStringsRequestPB req;
  for (auto size : sizes) {
    req.add_sizes(size);
  }
  req.set_random_seed(kSeed);

  SendStringsResponsePB resp;
  RpcController controller;
  controller.set_timeout(MonoDelta::FromMilliseconds(10000));
  auto status = proxy->SyncRequest(
      CalculatorServiceMethods::SendStringsMethod(), /* method_metrics= */ nullptr, req, &resp,
      &controller);

  ASSERT_EQ(expected_code, status.code()) << "Invalid status received: " << status.ToString();

  if (!status.ok()) {
    return;
  }

  // Regenerate the expected random payloads in order and compare each to
  // the sidecar the response points at.
  Random rng(kSeed);
  faststring expected;
  for (size_t i = 0; i != sizes.size(); ++i) {
    size_t size = sizes[i];
    expected.resize(size);
    Slice sidecar = GetSidecarPointer(controller, resp.sidecars(narrow_cast<uint32_t>(i)), size);
    RandomString(expected.data(), size, &rng);
    ASSERT_EQ(0, sidecar.compare(expected)) << "Invalid sidecar at " << i << " position";
  }
}
527 | | |
// Issues a Sleep RPC (0.5 s server-side) with a much shorter client-side
// timeout and verifies the call times out close to — but not before — the
// configured timeout, and well before the server would have answered.
void RpcTestBase::DoTestExpectTimeout(Proxy* proxy, const MonoDelta& timeout) {
  SleepRequestPB req;
  SleepResponsePB resp;
  req.set_sleep_micros(500000); // 0.5sec

  RpcController c;
  c.set_timeout(timeout);
  Stopwatch sw;
  sw.start();
  Status s = proxy->SyncRequest(
      CalculatorServiceMethods::SleepMethod(), /* method_metrics= */ nullptr, req, &resp, &c);
  ASSERT_FALSE(s.ok());
  sw.stop();

  auto expected_millis = timeout.ToMilliseconds();
  int elapsed_millis = sw.elapsed().wall_millis();

  // We shouldn't timeout significantly faster than our configured timeout.
  // (10 ms slack for timer coarseness.)
  EXPECT_GE(elapsed_millis, expected_millis - 10);
  // And we also shouldn't take the full 0.5sec that we asked for
  EXPECT_LT(elapsed_millis, 500);
  EXPECT_TRUE(s.IsTimedOut());
  LOG(INFO) << "status: " << s.ToString() << ", seconds elapsed: " << sw.elapsed().wall_seconds();
}
552 | | |
553 | 13 | void RpcTestBase::StartTestServer(Endpoint* server_endpoint, const TestServerOptions& options) { |
554 | 13 | server_ = std::make_unique<TestServer>( |
555 | 13 | CreateMessenger("TestServer", options.messenger_options), options); |
556 | 13 | EXPECT_OK(server_->RegisterService(std::make_unique<GenericCalculatorService>(metric_entity_))); |
557 | 13 | EXPECT_OK(server_->Start()); |
558 | 13 | *server_endpoint = server_->bound_endpoint(); |
559 | 13 | } |
560 | | |
561 | 12 | void RpcTestBase::StartTestServer(HostPort* server_hostport, const TestServerOptions& options) { |
562 | 12 | Endpoint endpoint; |
563 | 12 | StartTestServer(&endpoint, options); |
564 | 12 | *server_hostport = HostPort::FromBoundEndpoint(endpoint); |
565 | 12 | } |
566 | | |
// Starts and returns (by value; TestServer is movable) a server hosting
// the generated CalculatorService — named `name` for Forward() — plus the
// AbacusService. A messenger is created if none is supplied.
TestServer RpcTestBase::StartTestServer(
    const TestServerOptions& options, const std::string& name,
    std::unique_ptr<Messenger> messenger) {
  if (!messenger) {
    messenger = CreateMessenger("TestServer", options.messenger_options);
  }
  TestServer result(std::move(messenger), options);
  auto service = CreateCalculatorService(metric_entity(), name);
  // Forward() needs the messenger to reach other servers.
  service->SetMessenger(result.messenger());
  EXPECT_OK(result.RegisterService(std::move(service)));
  EXPECT_OK(result.RegisterService(std::make_unique<AbacusService>(metric_entity())));
  EXPECT_OK(result.Start());
  return result;
}
581 | | |
// Convenience overload: lets the server create its own messenger.
void RpcTestBase::StartTestServerWithGeneratedCode(HostPort* server_hostport,
                                                   const TestServerOptions& options) {
  StartTestServerWithGeneratedCode(nullptr, server_hostport, options);
}
586 | | |
// Starts a server backed by the generated CalculatorService, stores it in
// server_ (so TearDown() shuts it down), and reports its bound host/port.
void RpcTestBase::StartTestServerWithGeneratedCode(std::unique_ptr<Messenger>&& messenger,
                                                   HostPort* server_hostport,
                                                   const TestServerOptions& options) {
  server_ = std::make_unique<TestServer>(StartTestServer(
      options, std::string(), std::move(messenger)));
  *server_hostport = HostPort::FromBoundEndpoint(server_->bound_endpoint());
}
594 | | |
595 | 3 | CHECKED_STATUS RpcTestBase::StartFakeServer(Socket* listen_sock, HostPort* listen_hostport) { |
596 | 3 | RETURN_NOT_OK(listen_sock->Init(0)); |
597 | 3 | RETURN_NOT_OK(listen_sock->BindAndListen(Endpoint(), 1)); |
598 | 3 | Endpoint endpoint; |
599 | 3 | RETURN_NOT_OK(listen_sock->GetSocketAddress(&endpoint)); |
600 | 3 | LOG(INFO) << "Bound to: " << endpoint; |
601 | 3 | *listen_hostport = HostPort::FromBoundEndpoint(endpoint); |
602 | 3 | return Status::OK(); |
603 | 3 | } |
604 | | |
// Member wrapper around the file-local CreateMessenger, supplying this
// fixture's metric entity.
std::unique_ptr<Messenger> RpcTestBase::CreateMessenger(
    const string &name, const MessengerOptions& options) {
  return yb::rpc::CreateMessenger(name, metric_entity_, options);
}
609 | | |
// Creates a messenger wrapped in a holder that shuts it down automatically
// when the holder is destroyed.
AutoShutdownMessengerHolder RpcTestBase::CreateAutoShutdownMessengerHolder(
    const string &name, const MessengerOptions& options) {
  return rpc::CreateAutoShutdownMessengerHolder(CreateMessenger(name, options));
}
614 | | |
// Member wrapper around the file-local CreateMessengerBuilder, supplying
// this fixture's metric entity.
MessengerBuilder RpcTestBase::CreateMessengerBuilder(const string &name,
                                                     const MessengerOptions& options) {
  return yb::rpc::CreateMessengerBuilder(name, metric_entity_, options);
}
619 | | |
620 | | } // namespace rpc |
621 | | |
622 | | namespace rpc_test { |
623 | | |
624 | 1 | void SetupError(TrivialErrorPB* error, const Status& status) { |
625 | 1 | error->set_code(status.code()); |
626 | 1 | } |
627 | | |
628 | | } |
629 | | |
630 | | } // namespace yb |