YugabyteDB (2.13.1.0-b60, 21121d69985fbf76aa6958d8f04a9bfa936293b5)

Coverage Report

Created: 2022-03-22 16:43

/Users/deen/code/yugabyte-db/src/yb/rpc/serialization.cc
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
//
18
// The following only applies to changes made to this file as part of YugaByte development.
19
//
20
// Portions Copyright (c) YugaByte, Inc.
21
//
22
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
23
// in compliance with the License.  You may obtain a copy of the License at
24
//
25
// http://www.apache.org/licenses/LICENSE-2.0
26
//
27
// Unless required by applicable law or agreed to in writing, software distributed under the License
28
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
29
// or implied.  See the License for the specific language governing permissions and limitations
30
// under the License.
31
//
32
33
#include "yb/rpc/serialization.h"
34
35
#include <google/protobuf/io/coded_stream.h>
36
#include <google/protobuf/message.h>
37
38
#include "yb/gutil/endian.h"
39
#include "yb/gutil/stringprintf.h"
40
41
#include "yb/rpc/constants.h"
42
#include "yb/rpc/lightweight_message.h"
43
44
#include "yb/rpc/rpc_header.pb.h"
45
46
#include "yb/util/faststring.h"
47
#include "yb/util/ref_cnt_buffer.h"
48
#include "yb/util/result.h"
49
#include "yb/util/slice.h"
50
#include "yb/util/status_format.h"
51
52
DECLARE_uint64(rpc_max_message_size);
53
54
using google::protobuf::MessageLite;
55
using google::protobuf::io::CodedInputStream;
56
using google::protobuf::io::CodedOutputStream;
57
58
namespace yb {
59
namespace rpc {
60
61
292M
size_t SerializedMessageSize(size_t body_size, size_t additional_size) {
62
292M
  auto full_size = body_size + additional_size;
63
292M
  return body_size + CodedOutputStream::VarintSize32(narrow_cast<uint32_t>(full_size));
64
292M
}
65
66
CHECKED_STATUS SerializeMessage(
67
    AnyMessageConstPtr msg, size_t body_size, const RefCntBuffer& param_buf,
68
146M
    size_t additional_size, size_t offset) {
69
146M
  DCHECK_EQ(msg.SerializedSize(), body_size);
70
146M
  auto size = SerializedMessageSize(body_size, additional_size);
71
72
146M
  auto total_size = size + additional_size;
73
146M
  if (total_size > FLAGS_rpc_max_message_size) {
74
0
    return STATUS_FORMAT(InvalidArgument, "Sending too long RPC message ($0 bytes)", total_size);
75
0
  }
76
77
146M
  CHECK_EQ
(param_buf.size(), offset + size) << "offset = " << offset0
;
78
146M
  uint8_t *dst = param_buf.udata() + offset;
79
146M
  dst = CodedOutputStream::WriteVarint32ToArray(
80
146M
      narrow_cast<uint32_t>(body_size + additional_size), dst);
81
146M
  dst = 
VERIFY_RESULT146M
(146M
msg.SerializeToArray(dst));
82
0
  CHECK_EQ(dst - param_buf.udata(), param_buf.size());
83
84
146M
  return Status::OK();
85
146M
}
86
87
Status SerializeHeader(const MessageLite& header,
88
                       size_t param_len,
89
                       RefCntBuffer* header_buf,
90
                       size_t reserve_for_param,
91
70.7M
                       size_t* header_size) {
92
70.7M
  if (PREDICT_FALSE(!header.IsInitialized())) {
93
0
    LOG(DFATAL) << "Uninitialized RPC header";
94
0
    return STATUS(InvalidArgument, "RPC header missing required fields",
95
0
                                  header.InitializationErrorString());
96
0
  }
97
98
  // Compute all the lengths for the packet.
99
70.7M
  size_t header_pb_len = header.ByteSize();
100
70.7M
  size_t header_tot_len = kMsgLengthPrefixLength        // Int prefix for the total length.
101
70.7M
      + CodedOutputStream::VarintSize32(
102
70.7M
            narrow_cast<uint32_t>(header_pb_len))      // Varint delimiter for header PB.
103
70.7M
      + header_pb_len;                                  // Length for the header PB itself.
104
70.7M
  size_t total_size = header_tot_len + param_len;
105
106
70.7M
  *header_buf = RefCntBuffer(header_tot_len + reserve_for_param);
107
70.8M
  if (
header_size != nullptr70.7M
) {
108
70.8M
    *header_size = header_tot_len;
109
70.8M
  }
110
70.7M
  uint8_t* dst = header_buf->udata();
111
112
  // 1. The length for the whole request, not including the 4-byte
113
  // length prefix.
114
70.7M
  NetworkByteOrder::Store32(dst, narrow_cast<uint32_t>(total_size - kMsgLengthPrefixLength));
115
70.7M
  dst += sizeof(uint32_t);
116
117
  // 2. The varint-prefixed RequestHeader PB
118
70.7M
  dst = CodedOutputStream::WriteVarint32ToArray(narrow_cast<uint32_t>(header_pb_len), dst);
119
70.7M
  dst = header.SerializeWithCachedSizesToArray(dst);
120
121
  // We should have used the whole buffer we allocated.
122
70.7M
  CHECK_EQ(dst, header_buf->udata() + header_tot_len);
123
124
70.7M
  return Status::OK();
125
70.7M
}
126
127
Result<RefCntBuffer> SerializeRequest(
128
    size_t body_size, size_t additional_size, const google::protobuf::Message& header,
129
70.7M
    AnyMessageConstPtr body) {
130
70.7M
  auto message_size = SerializedMessageSize(body_size, additional_size);
131
70.7M
  size_t header_size = 0;
132
70.7M
  RefCntBuffer result;
133
70.7M
  RETURN_NOT_OK(SerializeHeader(
134
70.7M
      header, message_size + additional_size, &result, message_size, &header_size));
135
136
70.7M
  RETURN_NOT_OK(SerializeMessage(body, body_size, result, additional_size, header_size));
137
70.7M
  return result;
138
70.7M
}
139
140
6
bool SkipField(uint8_t type, CodedInputStream* in) {
141
6
  switch (type) {
142
2
    case 0: {
143
2
      uint64_t temp;
144
2
      return in->ReadVarint64(&temp);
145
0
    }
146
0
    case 1:
147
0
      return in->Skip(8);
148
0
    case 2: {
149
0
      uint32_t temp;
150
0
      return in->ReadVarint32(&temp) && in->Skip(temp);
151
0
    }
152
0
    case 5:
153
0
      return in->Skip(4);
154
4
    default:
155
4
      return false;
156
6
  }
157
6
}
158
159
148M
Result<Slice> ParseString(const Slice& buf, const char* name, CodedInputStream* in) {
160
148M
  uint32_t len;
161
148M
  if (!in->ReadVarint32(&len) || 
in->BytesUntilLimit() < implicit_cast<int>(len)148M
) {
162
0
    return STATUS(Corruption, "Unable to decode field", Slice(name));
163
0
  }
164
148M
  Slice result(buf.data() + in->CurrentPosition(), len);
165
148M
  in->Skip(len);
166
148M
  return result;
167
148M
}
168
169
CHECKED_STATUS ParseHeader(
170
70.8M
    const Slice& buf, CodedInputStream* in, ParsedRequestHeader* parsed_header) {
171
283M
  while (in->BytesUntilLimit() > 0) {
172
212M
    auto tag = in->ReadTag();
173
212M
    auto field = tag >> 3;
174
212M
    switch (field) {
175
70.7M
      case RequestHeader::kCallIdFieldNumber: {
176
70.7M
        uint32_t temp;
177
70.7M
        if (!in->ReadVarint32(&temp)) {
178
0
          return STATUS(Corruption, "Unable to decode call_id field");
179
0
        }
180
70.7M
        parsed_header->call_id = static_cast<int32_t>(temp);
181
70.7M
        } break;
182
70.8M
      case RequestHeader::kRemoteMethodFieldNumber:
183
70.8M
        parsed_header->remote_method = VERIFY_RESULT(ParseString(buf, "remote_method", in));
184
0
        break;
185
70.7M
      case RequestHeader::kTimeoutMillisFieldNumber:
186
70.7M
        if (!in->ReadVarint32(&parsed_header->timeout_ms)) {
187
0
          return STATUS(Corruption, "Unable to decode timeout_ms field");
188
0
        }
189
70.7M
        break;
190
70.7M
      default: {
191
0
        if (!SkipField(tag & 7, in)) {
192
0
          return STATUS_FORMAT(Corruption, "Unable to skip: $0", tag);
193
0
        }
194
0
      }
195
212M
    }
196
212M
  }
197
198
70.7M
  return Status::OK();
199
70.8M
}
200
201
72.4M
CHECKED_STATUS ParseHeader(const Slice& buf, CodedInputStream* in, MessageLite* parsed_header) {
202
72.4M
  if (PREDICT_FALSE(!parsed_header->ParseFromCodedStream(in))) {
203
0
    return STATUS(Corruption, "Invalid packet: header too short",
204
0
                              buf.ToDebugString());
205
0
  }
206
207
72.4M
  return Status::OK();
208
72.4M
}
209
210
namespace {
211
212
template <class Header>
213
CHECKED_STATUS DoParseYBMessage(const Slice& buf,
214
                                Header* parsed_header,
215
143M
                                Slice* parsed_main_message) {
216
143M
  CodedInputStream in(buf.data(), narrow_cast<int>(buf.size()));
217
143M
  SetupLimit(&in);
218
219
143M
  uint32_t header_len;
220
143M
  if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) {
221
0
    return STATUS(Corruption, "Invalid packet: missing header delimiter",
222
0
                              buf.ToDebugString());
223
0
  }
224
225
143M
  auto l = in.PushLimit(header_len);
226
143M
  RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header));
227
143M
  in.PopLimit(l);
228
229
143M
  uint32_t main_msg_len;
230
143M
  if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) {
231
0
    return STATUS(Corruption, "Invalid packet: missing main msg length",
232
0
                              buf.ToDebugString());
233
0
  }
234
235
143M
  if (PREDICT_FALSE(!in.Skip(main_msg_len))) {
236
0
    return STATUS(Corruption,
237
0
        StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len),
238
0
        buf.ToDebugString());
239
0
  }
240
241
143M
  if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) {
242
0
    return STATUS(Corruption,
243
0
      StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()),
244
0
      buf.ToDebugString());
245
0
  }
246
247
143M
  *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len,
248
143M
                              main_msg_len);
249
143M
  return Status::OK();
250
143M
}
serialization.cc:yb::Status yb::rpc::(anonymous namespace)::DoParseYBMessage<yb::rpc::ParsedRequestHeader>(yb::Slice const&, yb::rpc::ParsedRequestHeader*, yb::Slice*)
Line
Count
Source
215
70.8M
                                Slice* parsed_main_message) {
216
70.8M
  CodedInputStream in(buf.data(), narrow_cast<int>(buf.size()));
217
70.8M
  SetupLimit(&in);
218
219
70.8M
  uint32_t header_len;
220
70.8M
  if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) {
221
0
    return STATUS(Corruption, "Invalid packet: missing header delimiter",
222
0
                              buf.ToDebugString());
223
0
  }
224
225
70.8M
  auto l = in.PushLimit(header_len);
226
70.8M
  RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header));
227
70.8M
  in.PopLimit(l);
228
229
70.8M
  uint32_t main_msg_len;
230
70.8M
  if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) {
231
0
    return STATUS(Corruption, "Invalid packet: missing main msg length",
232
0
                              buf.ToDebugString());
233
0
  }
234
235
70.8M
  if (PREDICT_FALSE(!in.Skip(main_msg_len))) {
236
0
    return STATUS(Corruption,
237
0
        StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len),
238
0
        buf.ToDebugString());
239
0
  }
240
241
70.8M
  if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) {
242
0
    return STATUS(Corruption,
243
0
      StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()),
244
0
      buf.ToDebugString());
245
0
  }
246
247
70.8M
  *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len,
248
70.8M
                              main_msg_len);
249
70.8M
  return Status::OK();
250
70.8M
}
serialization.cc:yb::Status yb::rpc::(anonymous namespace)::DoParseYBMessage<google::protobuf::MessageLite>(yb::Slice const&, google::protobuf::MessageLite*, yb::Slice*)
Line
Count
Source
215
72.7M
                                Slice* parsed_main_message) {
216
72.7M
  CodedInputStream in(buf.data(), narrow_cast<int>(buf.size()));
217
72.7M
  SetupLimit(&in);
218
219
72.7M
  uint32_t header_len;
220
72.7M
  if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) {
221
0
    return STATUS(Corruption, "Invalid packet: missing header delimiter",
222
0
                              buf.ToDebugString());
223
0
  }
224
225
72.7M
  auto l = in.PushLimit(header_len);
226
72.7M
  RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header));
227
72.7M
  in.PopLimit(l);
228
229
72.7M
  uint32_t main_msg_len;
230
72.7M
  if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) {
231
0
    return STATUS(Corruption, "Invalid packet: missing main msg length",
232
0
                              buf.ToDebugString());
233
0
  }
234
235
72.7M
  if (PREDICT_FALSE(!in.Skip(main_msg_len))) {
236
0
    return STATUS(Corruption,
237
0
        StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len),
238
0
        buf.ToDebugString());
239
0
  }
240
241
72.7M
  if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) {
242
0
    return STATUS(Corruption,
243
0
      StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()),
244
0
      buf.ToDebugString());
245
0
  }
246
247
72.7M
  *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len,
248
72.7M
                              main_msg_len);
249
72.7M
  return Status::OK();
250
72.7M
}
251
252
} // namespace
253
254
Status ParseYBMessage(const Slice& buf,
255
                      ParsedRequestHeader* parsed_header,
256
70.8M
                      Slice* parsed_main_message) {
257
70.8M
  return DoParseYBMessage(buf, parsed_header, parsed_main_message);
258
70.8M
}
259
260
Status ParseYBMessage(const Slice& buf,
261
                      MessageLite* parsed_header,
262
72.7M
                      Slice* parsed_main_message) {
263
72.7M
  return DoParseYBMessage(buf, parsed_header, parsed_main_message);
264
72.7M
}
265
266
38.9M
Result<ParsedRemoteMethod> ParseRemoteMethod(const Slice& buf) {
267
38.9M
  CodedInputStream in(buf.data(), narrow_cast<int>(buf.size()));
268
38.9M
  in.PushLimit(narrow_cast<int>(buf.size()));
269
38.9M
  ParsedRemoteMethod result;
270
116M
  while (in.BytesUntilLimit() > 0) {
271
77.8M
    auto tag = in.ReadTag();
272
77.8M
    auto field = tag >> 3;
273
77.8M
    switch (field) {
274
38.9M
      case RemoteMethodPB::kServiceNameFieldNumber:
275
38.9M
        result.service = VERIFY_RESULT(ParseString(buf, "service_name", &in));
276
0
        break;
277
38.9M
      case RemoteMethodPB::kMethodNameFieldNumber:
278
38.9M
        result.method = VERIFY_RESULT(ParseString(buf, "method_name", &in));
279
0
        break;
280
6
      default: {
281
6
        if (!SkipField(tag & 7, &in)) {
282
4
          return STATUS_FORMAT(Corruption, "Unable to skip: $0", tag);
283
4
        }
284
6
      }
285
77.8M
    }
286
77.8M
  }
287
38.9M
  return result;
288
38.9M
}
289
290
9.26k
std::string ParsedRequestHeader::RemoteMethodAsString() const {
291
9.26k
  auto parsed_remote_method = ParseRemoteMethod(remote_method);
292
9.26k
  if (parsed_remote_method.ok()) {
293
9.24k
    return parsed_remote_method->service.ToBuffer() + "." +
294
9.24k
           parsed_remote_method->method.ToBuffer();
295
9.24k
  } else {
296
29
    return parsed_remote_method.status().ToString();
297
29
  }
298
9.26k
}
299
300
1
void ParsedRequestHeader::ToPB(RequestHeader* out) const {
301
1
  out->set_call_id(call_id);
302
1
  if (timeout_ms) {
303
0
    out->set_timeout_millis(timeout_ms);
304
0
  }
305
1
  auto parsed_remote_method = ParseRemoteMethod(remote_method);
306
1
  if (parsed_remote_method.ok()) {
307
1
    out->mutable_remote_method()->set_service_name(parsed_remote_method->service.ToBuffer());
308
1
    out->mutable_remote_method()->set_method_name(parsed_remote_method->method.ToBuffer());
309
1
  }
310
1
}
311
312
}  // namespace rpc
313
}  // namespace yb