/Users/deen/code/yugabyte-db/src/yb/rpc/serialization.cc
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // |
18 | | // The following only applies to changes made to this file as part of YugaByte development. |
19 | | // |
20 | | // Portions Copyright (c) YugaByte, Inc. |
21 | | // |
22 | | // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except |
23 | | // in compliance with the License. You may obtain a copy of the License at |
24 | | // |
25 | | // http://www.apache.org/licenses/LICENSE-2.0 |
26 | | // |
27 | | // Unless required by applicable law or agreed to in writing, software distributed under the License |
28 | | // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express |
29 | | // or implied. See the License for the specific language governing permissions and limitations |
30 | | // under the License. |
31 | | // |
32 | | |
33 | | #include "yb/rpc/serialization.h" |
34 | | |
35 | | #include <google/protobuf/io/coded_stream.h> |
36 | | #include <google/protobuf/message.h> |
37 | | |
38 | | #include "yb/gutil/endian.h" |
39 | | #include "yb/gutil/stringprintf.h" |
40 | | |
41 | | #include "yb/rpc/constants.h" |
42 | | #include "yb/rpc/lightweight_message.h" |
43 | | |
44 | | #include "yb/rpc/rpc_header.pb.h" |
45 | | |
46 | | #include "yb/util/faststring.h" |
47 | | #include "yb/util/ref_cnt_buffer.h" |
48 | | #include "yb/util/result.h" |
49 | | #include "yb/util/slice.h" |
50 | | #include "yb/util/status_format.h" |
51 | | |
52 | | DECLARE_uint64(rpc_max_message_size); |
53 | | |
54 | | using google::protobuf::MessageLite; |
55 | | using google::protobuf::io::CodedInputStream; |
56 | | using google::protobuf::io::CodedOutputStream; |
57 | | |
58 | | namespace yb { |
59 | | namespace rpc { |
60 | | |
61 | 292M | size_t SerializedMessageSize(size_t body_size, size_t additional_size) { |
62 | 292M | auto full_size = body_size + additional_size; |
63 | 292M | return body_size + CodedOutputStream::VarintSize32(narrow_cast<uint32_t>(full_size)); |
64 | 292M | } |
65 | | |
66 | | CHECKED_STATUS SerializeMessage( |
67 | | AnyMessageConstPtr msg, size_t body_size, const RefCntBuffer& param_buf, |
68 | 146M | size_t additional_size, size_t offset) { |
69 | 146M | DCHECK_EQ(msg.SerializedSize(), body_size); |
70 | 146M | auto size = SerializedMessageSize(body_size, additional_size); |
71 | | |
72 | 146M | auto total_size = size + additional_size; |
73 | 146M | if (total_size > FLAGS_rpc_max_message_size) { |
74 | 0 | return STATUS_FORMAT(InvalidArgument, "Sending too long RPC message ($0 bytes)", total_size); |
75 | 0 | } |
76 | | |
77 | 146M | CHECK_EQ(param_buf.size(), offset + size) << "offset = " << offset0 ; |
78 | 146M | uint8_t *dst = param_buf.udata() + offset; |
79 | 146M | dst = CodedOutputStream::WriteVarint32ToArray( |
80 | 146M | narrow_cast<uint32_t>(body_size + additional_size), dst); |
81 | 146M | dst = VERIFY_RESULT146M (146M msg.SerializeToArray(dst)); |
82 | 0 | CHECK_EQ(dst - param_buf.udata(), param_buf.size()); |
83 | | |
84 | 146M | return Status::OK(); |
85 | 146M | } |
86 | | |
87 | | Status SerializeHeader(const MessageLite& header, |
88 | | size_t param_len, |
89 | | RefCntBuffer* header_buf, |
90 | | size_t reserve_for_param, |
91 | 70.7M | size_t* header_size) { |
92 | 70.7M | if (PREDICT_FALSE(!header.IsInitialized())) { |
93 | 0 | LOG(DFATAL) << "Uninitialized RPC header"; |
94 | 0 | return STATUS(InvalidArgument, "RPC header missing required fields", |
95 | 0 | header.InitializationErrorString()); |
96 | 0 | } |
97 | | |
98 | | // Compute all the lengths for the packet. |
99 | 70.7M | size_t header_pb_len = header.ByteSize(); |
100 | 70.7M | size_t header_tot_len = kMsgLengthPrefixLength // Int prefix for the total length. |
101 | 70.7M | + CodedOutputStream::VarintSize32( |
102 | 70.7M | narrow_cast<uint32_t>(header_pb_len)) // Varint delimiter for header PB. |
103 | 70.7M | + header_pb_len; // Length for the header PB itself. |
104 | 70.7M | size_t total_size = header_tot_len + param_len; |
105 | | |
106 | 70.7M | *header_buf = RefCntBuffer(header_tot_len + reserve_for_param); |
107 | 70.8M | if (header_size != nullptr70.7M ) { |
108 | 70.8M | *header_size = header_tot_len; |
109 | 70.8M | } |
110 | 70.7M | uint8_t* dst = header_buf->udata(); |
111 | | |
112 | | // 1. The length for the whole request, not including the 4-byte |
113 | | // length prefix. |
114 | 70.7M | NetworkByteOrder::Store32(dst, narrow_cast<uint32_t>(total_size - kMsgLengthPrefixLength)); |
115 | 70.7M | dst += sizeof(uint32_t); |
116 | | |
117 | | // 2. The varint-prefixed RequestHeader PB |
118 | 70.7M | dst = CodedOutputStream::WriteVarint32ToArray(narrow_cast<uint32_t>(header_pb_len), dst); |
119 | 70.7M | dst = header.SerializeWithCachedSizesToArray(dst); |
120 | | |
121 | | // We should have used the whole buffer we allocated. |
122 | 70.7M | CHECK_EQ(dst, header_buf->udata() + header_tot_len); |
123 | | |
124 | 70.7M | return Status::OK(); |
125 | 70.7M | } |
126 | | |
127 | | Result<RefCntBuffer> SerializeRequest( |
128 | | size_t body_size, size_t additional_size, const google::protobuf::Message& header, |
129 | 70.7M | AnyMessageConstPtr body) { |
130 | 70.7M | auto message_size = SerializedMessageSize(body_size, additional_size); |
131 | 70.7M | size_t header_size = 0; |
132 | 70.7M | RefCntBuffer result; |
133 | 70.7M | RETURN_NOT_OK(SerializeHeader( |
134 | 70.7M | header, message_size + additional_size, &result, message_size, &header_size)); |
135 | | |
136 | 70.7M | RETURN_NOT_OK(SerializeMessage(body, body_size, result, additional_size, header_size)); |
137 | 70.7M | return result; |
138 | 70.7M | } |
139 | | |
140 | 6 | bool SkipField(uint8_t type, CodedInputStream* in) { |
141 | 6 | switch (type) { |
142 | 2 | case 0: { |
143 | 2 | uint64_t temp; |
144 | 2 | return in->ReadVarint64(&temp); |
145 | 0 | } |
146 | 0 | case 1: |
147 | 0 | return in->Skip(8); |
148 | 0 | case 2: { |
149 | 0 | uint32_t temp; |
150 | 0 | return in->ReadVarint32(&temp) && in->Skip(temp); |
151 | 0 | } |
152 | 0 | case 5: |
153 | 0 | return in->Skip(4); |
154 | 4 | default: |
155 | 4 | return false; |
156 | 6 | } |
157 | 6 | } |
158 | | |
159 | 148M | Result<Slice> ParseString(const Slice& buf, const char* name, CodedInputStream* in) { |
160 | 148M | uint32_t len; |
161 | 148M | if (!in->ReadVarint32(&len) || in->BytesUntilLimit() < implicit_cast<int>(len)148M ) { |
162 | 0 | return STATUS(Corruption, "Unable to decode field", Slice(name)); |
163 | 0 | } |
164 | 148M | Slice result(buf.data() + in->CurrentPosition(), len); |
165 | 148M | in->Skip(len); |
166 | 148M | return result; |
167 | 148M | } |
168 | | |
169 | | CHECKED_STATUS ParseHeader( |
170 | 70.8M | const Slice& buf, CodedInputStream* in, ParsedRequestHeader* parsed_header) { |
171 | 283M | while (in->BytesUntilLimit() > 0) { |
172 | 212M | auto tag = in->ReadTag(); |
173 | 212M | auto field = tag >> 3; |
174 | 212M | switch (field) { |
175 | 70.7M | case RequestHeader::kCallIdFieldNumber: { |
176 | 70.7M | uint32_t temp; |
177 | 70.7M | if (!in->ReadVarint32(&temp)) { |
178 | 0 | return STATUS(Corruption, "Unable to decode call_id field"); |
179 | 0 | } |
180 | 70.7M | parsed_header->call_id = static_cast<int32_t>(temp); |
181 | 70.7M | } break; |
182 | 70.8M | case RequestHeader::kRemoteMethodFieldNumber: |
183 | 70.8M | parsed_header->remote_method = VERIFY_RESULT(ParseString(buf, "remote_method", in)); |
184 | 0 | break; |
185 | 70.7M | case RequestHeader::kTimeoutMillisFieldNumber: |
186 | 70.7M | if (!in->ReadVarint32(&parsed_header->timeout_ms)) { |
187 | 0 | return STATUS(Corruption, "Unable to decode timeout_ms field"); |
188 | 0 | } |
189 | 70.7M | break; |
190 | 70.7M | default: { |
191 | 0 | if (!SkipField(tag & 7, in)) { |
192 | 0 | return STATUS_FORMAT(Corruption, "Unable to skip: $0", tag); |
193 | 0 | } |
194 | 0 | } |
195 | 212M | } |
196 | 212M | } |
197 | | |
198 | 70.7M | return Status::OK(); |
199 | 70.8M | } |
200 | | |
201 | 72.4M | CHECKED_STATUS ParseHeader(const Slice& buf, CodedInputStream* in, MessageLite* parsed_header) { |
202 | 72.4M | if (PREDICT_FALSE(!parsed_header->ParseFromCodedStream(in))) { |
203 | 0 | return STATUS(Corruption, "Invalid packet: header too short", |
204 | 0 | buf.ToDebugString()); |
205 | 0 | } |
206 | | |
207 | 72.4M | return Status::OK(); |
208 | 72.4M | } |
209 | | |
210 | | namespace { |
211 | | |
212 | | template <class Header> |
213 | | CHECKED_STATUS DoParseYBMessage(const Slice& buf, |
214 | | Header* parsed_header, |
215 | 143M | Slice* parsed_main_message) { |
216 | 143M | CodedInputStream in(buf.data(), narrow_cast<int>(buf.size())); |
217 | 143M | SetupLimit(&in); |
218 | | |
219 | 143M | uint32_t header_len; |
220 | 143M | if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) { |
221 | 0 | return STATUS(Corruption, "Invalid packet: missing header delimiter", |
222 | 0 | buf.ToDebugString()); |
223 | 0 | } |
224 | | |
225 | 143M | auto l = in.PushLimit(header_len); |
226 | 143M | RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header)); |
227 | 143M | in.PopLimit(l); |
228 | | |
229 | 143M | uint32_t main_msg_len; |
230 | 143M | if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) { |
231 | 0 | return STATUS(Corruption, "Invalid packet: missing main msg length", |
232 | 0 | buf.ToDebugString()); |
233 | 0 | } |
234 | | |
235 | 143M | if (PREDICT_FALSE(!in.Skip(main_msg_len))) { |
236 | 0 | return STATUS(Corruption, |
237 | 0 | StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len), |
238 | 0 | buf.ToDebugString()); |
239 | 0 | } |
240 | | |
241 | 143M | if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) { |
242 | 0 | return STATUS(Corruption, |
243 | 0 | StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()), |
244 | 0 | buf.ToDebugString()); |
245 | 0 | } |
246 | | |
247 | 143M | *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len, |
248 | 143M | main_msg_len); |
249 | 143M | return Status::OK(); |
250 | 143M | } serialization.cc:yb::Status yb::rpc::(anonymous namespace)::DoParseYBMessage<yb::rpc::ParsedRequestHeader>(yb::Slice const&, yb::rpc::ParsedRequestHeader*, yb::Slice*) Line | Count | Source | 215 | 70.8M | Slice* parsed_main_message) { | 216 | 70.8M | CodedInputStream in(buf.data(), narrow_cast<int>(buf.size())); | 217 | 70.8M | SetupLimit(&in); | 218 | | | 219 | 70.8M | uint32_t header_len; | 220 | 70.8M | if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) { | 221 | 0 | return STATUS(Corruption, "Invalid packet: missing header delimiter", | 222 | 0 | buf.ToDebugString()); | 223 | 0 | } | 224 | | | 225 | 70.8M | auto l = in.PushLimit(header_len); | 226 | 70.8M | RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header)); | 227 | 70.8M | in.PopLimit(l); | 228 | | | 229 | 70.8M | uint32_t main_msg_len; | 230 | 70.8M | if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) { | 231 | 0 | return STATUS(Corruption, "Invalid packet: missing main msg length", | 232 | 0 | buf.ToDebugString()); | 233 | 0 | } | 234 | | | 235 | 70.8M | if (PREDICT_FALSE(!in.Skip(main_msg_len))) { | 236 | 0 | return STATUS(Corruption, | 237 | 0 | StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len), | 238 | 0 | buf.ToDebugString()); | 239 | 0 | } | 240 | | | 241 | 70.8M | if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) { | 242 | 0 | return STATUS(Corruption, | 243 | 0 | StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()), | 244 | 0 | buf.ToDebugString()); | 245 | 0 | } | 246 | | | 247 | 70.8M | *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len, | 248 | 70.8M | main_msg_len); | 249 | 70.8M | return Status::OK(); | 250 | 70.8M | } |
serialization.cc:yb::Status yb::rpc::(anonymous namespace)::DoParseYBMessage<google::protobuf::MessageLite>(yb::Slice const&, google::protobuf::MessageLite*, yb::Slice*) Line | Count | Source | 215 | 72.7M | Slice* parsed_main_message) { | 216 | 72.7M | CodedInputStream in(buf.data(), narrow_cast<int>(buf.size())); | 217 | 72.7M | SetupLimit(&in); | 218 | | | 219 | 72.7M | uint32_t header_len; | 220 | 72.7M | if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) { | 221 | 0 | return STATUS(Corruption, "Invalid packet: missing header delimiter", | 222 | 0 | buf.ToDebugString()); | 223 | 0 | } | 224 | | | 225 | 72.7M | auto l = in.PushLimit(header_len); | 226 | 72.7M | RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header)); | 227 | 72.7M | in.PopLimit(l); | 228 | | | 229 | 72.7M | uint32_t main_msg_len; | 230 | 72.7M | if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) { | 231 | 0 | return STATUS(Corruption, "Invalid packet: missing main msg length", | 232 | 0 | buf.ToDebugString()); | 233 | 0 | } | 234 | | | 235 | 72.7M | if (PREDICT_FALSE(!in.Skip(main_msg_len))) { | 236 | 0 | return STATUS(Corruption, | 237 | 0 | StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len), | 238 | 0 | buf.ToDebugString()); | 239 | 0 | } | 240 | | | 241 | 72.7M | if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) { | 242 | 0 | return STATUS(Corruption, | 243 | 0 | StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()), | 244 | 0 | buf.ToDebugString()); | 245 | 0 | } | 246 | | | 247 | 72.7M | *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len, | 248 | 72.7M | main_msg_len); | 249 | 72.7M | return Status::OK(); | 250 | 72.7M | } |
|
251 | | |
252 | | } // namespace |
253 | | |
254 | | Status ParseYBMessage(const Slice& buf, |
255 | | ParsedRequestHeader* parsed_header, |
256 | 70.8M | Slice* parsed_main_message) { |
257 | 70.8M | return DoParseYBMessage(buf, parsed_header, parsed_main_message); |
258 | 70.8M | } |
259 | | |
260 | | Status ParseYBMessage(const Slice& buf, |
261 | | MessageLite* parsed_header, |
262 | 72.7M | Slice* parsed_main_message) { |
263 | 72.7M | return DoParseYBMessage(buf, parsed_header, parsed_main_message); |
264 | 72.7M | } |
265 | | |
266 | 38.9M | Result<ParsedRemoteMethod> ParseRemoteMethod(const Slice& buf) { |
267 | 38.9M | CodedInputStream in(buf.data(), narrow_cast<int>(buf.size())); |
268 | 38.9M | in.PushLimit(narrow_cast<int>(buf.size())); |
269 | 38.9M | ParsedRemoteMethod result; |
270 | 116M | while (in.BytesUntilLimit() > 0) { |
271 | 77.8M | auto tag = in.ReadTag(); |
272 | 77.8M | auto field = tag >> 3; |
273 | 77.8M | switch (field) { |
274 | 38.9M | case RemoteMethodPB::kServiceNameFieldNumber: |
275 | 38.9M | result.service = VERIFY_RESULT(ParseString(buf, "service_name", &in)); |
276 | 0 | break; |
277 | 38.9M | case RemoteMethodPB::kMethodNameFieldNumber: |
278 | 38.9M | result.method = VERIFY_RESULT(ParseString(buf, "method_name", &in)); |
279 | 0 | break; |
280 | 6 | default: { |
281 | 6 | if (!SkipField(tag & 7, &in)) { |
282 | 4 | return STATUS_FORMAT(Corruption, "Unable to skip: $0", tag); |
283 | 4 | } |
284 | 6 | } |
285 | 77.8M | } |
286 | 77.8M | } |
287 | 38.9M | return result; |
288 | 38.9M | } |
289 | | |
290 | 9.26k | std::string ParsedRequestHeader::RemoteMethodAsString() const { |
291 | 9.26k | auto parsed_remote_method = ParseRemoteMethod(remote_method); |
292 | 9.26k | if (parsed_remote_method.ok()) { |
293 | 9.24k | return parsed_remote_method->service.ToBuffer() + "." + |
294 | 9.24k | parsed_remote_method->method.ToBuffer(); |
295 | 9.24k | } else { |
296 | 29 | return parsed_remote_method.status().ToString(); |
297 | 29 | } |
298 | 9.26k | } |
299 | | |
300 | 1 | void ParsedRequestHeader::ToPB(RequestHeader* out) const { |
301 | 1 | out->set_call_id(call_id); |
302 | 1 | if (timeout_ms) { |
303 | 0 | out->set_timeout_millis(timeout_ms); |
304 | 0 | } |
305 | 1 | auto parsed_remote_method = ParseRemoteMethod(remote_method); |
306 | 1 | if (parsed_remote_method.ok()) { |
307 | 1 | out->mutable_remote_method()->set_service_name(parsed_remote_method->service.ToBuffer()); |
308 | 1 | out->mutable_remote_method()->set_method_name(parsed_remote_method->method.ToBuffer()); |
309 | 1 | } |
310 | 1 | } |
311 | | |
312 | | } // namespace rpc |
313 | | } // namespace yb |