YugabyteDB (2.13.0.0-b42, bfc6a6643e7399ac8a0e81d06a3ee6d6571b33ab)

Coverage Report

Created: 2022-03-09 17:30

/Users/deen/code/yugabyte-db/src/yb/rpc/serialization.cc
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
//
18
// The following only applies to changes made to this file as part of YugaByte development.
19
//
20
// Portions Copyright (c) YugaByte, Inc.
21
//
22
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
23
// in compliance with the License.  You may obtain a copy of the License at
24
//
25
// http://www.apache.org/licenses/LICENSE-2.0
26
//
27
// Unless required by applicable law or agreed to in writing, software distributed under the License
28
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
29
// or implied.  See the License for the specific language governing permissions and limitations
30
// under the License.
31
//
32
33
#include "yb/rpc/serialization.h"
34
35
#include <google/protobuf/io/coded_stream.h>
36
#include <google/protobuf/message.h>
37
38
#include "yb/gutil/endian.h"
39
#include "yb/gutil/stringprintf.h"
40
41
#include "yb/rpc/constants.h"
42
#include "yb/rpc/lightweight_message.h"
43
44
#include "yb/rpc/rpc_header.pb.h"
45
46
#include "yb/util/faststring.h"
47
#include "yb/util/ref_cnt_buffer.h"
48
#include "yb/util/result.h"
49
#include "yb/util/slice.h"
50
#include "yb/util/status_format.h"
51
52
DECLARE_uint64(rpc_max_message_size);
53
54
using google::protobuf::MessageLite;
55
using google::protobuf::io::CodedInputStream;
56
using google::protobuf::io::CodedOutputStream;
57
58
namespace yb {
59
namespace rpc {
60
61
79.0M
size_t SerializedMessageSize(size_t body_size, size_t additional_size) {
62
79.0M
  auto full_size = body_size + additional_size;
63
79.0M
  return body_size + CodedOutputStream::VarintSize32(narrow_cast<uint32_t>(full_size));
64
79.0M
}
65
66
// Writes `msg` into `param_buf` starting at byte `offset`: first a varint holding
// body_size + additional_size, then the serialized body.
// `param_buf` must end exactly where the body ends (both CHECKs enforce this). The
// `additional_size` bytes are counted in the length prefix and in the size-limit check,
// but this function does not write them.
// Returns InvalidArgument if the message would exceed FLAGS_rpc_max_message_size.
CHECKED_STATUS SerializeMessage(
    AnyMessageConstPtr msg, size_t body_size, const RefCntBuffer& param_buf,
    size_t additional_size, size_t offset) {
  DCHECK_EQ(msg.SerializedSize(), body_size);

  // Length prefix + body, without the trailing additional bytes.
  const auto size = SerializedMessageSize(body_size, additional_size);
  const auto total_size = size + additional_size;
  if (total_size > FLAGS_rpc_max_message_size) {
    return STATUS_FORMAT(InvalidArgument, "Sending too long RPC message ($0 bytes)", total_size);
  }

  CHECK_EQ(param_buf.size(), offset + size) << "offset = " << offset;

  const uint32_t length_prefix = narrow_cast<uint32_t>(body_size + additional_size);
  uint8_t *dst = CodedOutputStream::WriteVarint32ToArray(length_prefix, param_buf.udata() + offset);
  dst = VERIFY_RESULT(msg.SerializeToArray(dst));
  CHECK_EQ(dst - param_buf.udata(), param_buf.size());
  return Status::OK();
}
86
87
// Allocates *header_buf and writes the RPC frame header into it:
//   [4-byte network-byte-order length of everything after it]
//   [varint length of the header PB][header PB]
// The length in the first field covers the header plus `param_len` payload bytes.
// The buffer gets `reserve_for_param` extra bytes past the header, which this function
// leaves unwritten for the caller. If `header_size` is non-null, it receives the number
// of header bytes written.
// Returns InvalidArgument if `header` is missing required fields.
Status SerializeHeader(const MessageLite& header,
                       size_t param_len,
                       RefCntBuffer* header_buf,
                       size_t reserve_for_param,
                       size_t* header_size) {
  if (PREDICT_FALSE(!header.IsInitialized())) {
    LOG(DFATAL) << "Uninitialized RPC header";
    return STATUS(InvalidArgument, "RPC header missing required fields",
                                  header.InitializationErrorString());
  }

  // ByteSize() also caches the size used by SerializeWithCachedSizesToArray() below.
  const size_t header_pb_len = header.ByteSize();
  const uint32_t header_pb_len32 = narrow_cast<uint32_t>(header_pb_len);
  const size_t header_tot_len =
      kMsgLengthPrefixLength + CodedOutputStream::VarintSize32(header_pb_len32) + header_pb_len;
  const size_t total_size = header_tot_len + param_len;

  *header_buf = RefCntBuffer(header_tot_len + reserve_for_param);
  if (header_size) {
    *header_size = header_tot_len;
  }

  uint8_t* dst = header_buf->udata();
  // The leading length does not count the length field itself.
  NetworkByteOrder::Store32(dst, narrow_cast<uint32_t>(total_size - kMsgLengthPrefixLength));
  dst += sizeof(uint32_t);
  dst = CodedOutputStream::WriteVarint32ToArray(header_pb_len32, dst);
  dst = header.SerializeWithCachedSizesToArray(dst);

  // Every byte reserved for the header must have been written.
  CHECK_EQ(dst, header_buf->udata() + header_tot_len);
  return Status::OK();
}
126
127
// Builds a request frame: the RPC header followed by the varint-prefixed body.
// The returned buffer does not contain the `additional_size` bytes. They are only
// accounted for in the lengths written into the frame.
Result<RefCntBuffer> SerializeRequest(
    size_t body_size, size_t additional_size, const google::protobuf::Message& header,
    AnyMessageConstPtr body) {
  const auto body_bytes = SerializedMessageSize(body_size, additional_size);
  size_t header_bytes = 0;
  RefCntBuffer frame;
  RETURN_NOT_OK(SerializeHeader(
      header, body_bytes + additional_size, &frame, body_bytes, &header_bytes));
  // The body goes into the space SerializeHeader reserved right after the header.
  RETURN_NOT_OK(SerializeMessage(body, body_size, frame, additional_size, header_bytes));
  return frame;
}
139
140
0
bool SkipField(uint8_t type, CodedInputStream* in) {
141
0
  switch (type) {
142
0
    case 0: {
143
0
      uint64_t temp;
144
0
      return in->ReadVarint64(&temp);
145
0
    }
146
0
    case 1:
147
0
      return in->Skip(8);
148
0
    case 2: {
149
0
      uint32_t temp;
150
0
      return in->ReadVarint32(&temp) && in->Skip(temp);
151
0
    }
152
0
    case 5:
153
0
      return in->Skip(4);
154
0
    default:
155
0
      return false;
156
0
  }
157
0
}
158
159
25.9M
// Reads a length-delimited string field from `in` without copying it.
// The returned Slice points into `buf`, which must be the buffer `in` was constructed
// over (positions from in->CurrentPosition() are used as offsets into it). `name` is
// only used in the error message.
// Returns Corruption if the length cannot be decoded, if it exceeds the bytes remaining
// before the current limit, or if the data cannot be skipped.
Result<Slice> ParseString(const Slice& buf, const char* name, CodedInputStream* in) {
  uint32_t len = 0;
  // `len` comes from the wire and is untrusted. Compare it against the remaining bytes
  // as unsigned values: converting it to int would turn lengths >= 2^31 into negative
  // numbers that pass the bounds check and yield an out-of-bounds Slice.
  // BytesUntilLimit() returns -1 when no limit is pushed; that case is rejected, as before.
  if (!in->ReadVarint32(&len) || in->BytesUntilLimit() < 0 ||
      static_cast<uint32_t>(in->BytesUntilLimit()) < len) {
    return STATUS(Corruption, "Unable to decode field", Slice(name));
  }
  Slice result(buf.data() + in->CurrentPosition(), len);
  // Here len <= BytesUntilLimit() <= INT_MAX, so the conversion to int is lossless.
  if (!in->Skip(static_cast<int>(len))) {
    return STATUS(Corruption, "Unable to decode field", Slice(name));
  }
  return result;
}
168
169
// Decodes a RequestHeader from `in` into the lightweight ParsedRequestHeader, reading
// until the currently pushed limit is exhausted. Only call_id, remote_method and
// timeout_millis are decoded; other fields are skipped.
// remote_method is stored as a Slice pointing into `buf` (not copied), so `buf` must
// outlive `parsed_header`.
// Returns Corruption on an undecodable tag or field value.
CHECKED_STATUS ParseHeader(
    const Slice& buf, CodedInputStream* in, ParsedRequestHeader* parsed_header) {
  while (in->BytesUntilLimit() > 0) {
    const auto tag = in->ReadTag();
    // ReadTag() returns 0 when it cannot decode a tag, and 0 is never a valid protobuf
    // tag. Reject it rather than treating it as an unknown varint field and consuming
    // unrelated bytes.
    if (PREDICT_FALSE(tag == 0)) {
      return STATUS(Corruption, "Unable to decode header field tag");
    }
    switch (tag >> 3) {
      case RequestHeader::kCallIdFieldNumber: {
        uint32_t temp;
        if (!in->ReadVarint32(&temp)) {
          return STATUS(Corruption, "Unable to decode call_id field");
        }
        parsed_header->call_id = static_cast<int32_t>(temp);
        break;
      }
      case RequestHeader::kRemoteMethodFieldNumber:
        parsed_header->remote_method = VERIFY_RESULT(ParseString(buf, "remote_method", in));
        break;
      case RequestHeader::kTimeoutMillisFieldNumber:
        if (!in->ReadVarint32(&parsed_header->timeout_ms)) {
          return STATUS(Corruption, "Unable to decode timeout_ms field");
        }
        break;
      default:
        // The low 3 bits of the tag are the wire type.
        if (!SkipField(tag & 7, in)) {
          return STATUS_FORMAT(Corruption, "Unable to skip: $0", tag);
        }
        break;
    }
  }

  return Status::OK();
}
200
201
19.5M
// Parses a full protobuf header message from `in` into `parsed_header`.
// On failure, returns Corruption carrying buf.ToDebugString() for diagnostics.
CHECKED_STATUS ParseHeader(const Slice& buf, CodedInputStream* in, MessageLite* parsed_header) {
  const bool parsed = parsed_header->ParseFromCodedStream(in);
  if (PREDICT_TRUE(parsed)) {
    return Status::OK();
  }
  return STATUS(Corruption, "Invalid packet: header too short",
                            buf.ToDebugString());
}
209
210
namespace {
211
212
template <class Header>
213
CHECKED_STATUS DoParseYBMessage(const Slice& buf,
214
                                Header* parsed_header,
215
39.1M
                                Slice* parsed_main_message) {
216
39.1M
  CodedInputStream in(buf.data(), narrow_cast<int>(buf.size()));
217
39.1M
  SetupLimit(&in);
218
219
39.1M
  uint32_t header_len;
220
39.1M
  if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) {
221
0
    return STATUS(Corruption, "Invalid packet: missing header delimiter",
222
0
                              buf.ToDebugString());
223
0
  }
224
225
39.1M
  auto l = in.PushLimit(header_len);
226
39.1M
  RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header));
227
39.1M
  in.PopLimit(l);
228
229
39.1M
  uint32_t main_msg_len;
230
39.1M
  if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) {
231
0
    return STATUS(Corruption, "Invalid packet: missing main msg length",
232
0
                              buf.ToDebugString());
233
0
  }
234
235
39.1M
  if (PREDICT_FALSE(!in.Skip(main_msg_len))) {
236
0
    return STATUS(Corruption,
237
0
        StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len),
238
0
        buf.ToDebugString());
239
0
  }
240
241
39.1M
  if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) {
242
0
    return STATUS(Corruption,
243
0
      StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()),
244
0
      buf.ToDebugString());
245
0
  }
246
247
39.1M
  *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len,
248
39.1M
                              main_msg_len);
249
39.1M
  return Status::OK();
250
39.1M
}
serialization.cc:_ZN2yb3rpc12_GLOBAL__N_116DoParseYBMessageINS0_19ParsedRequestHeaderEEENS_6StatusERKNS_5SliceEPT_PS5_
Line
Count
Source
215
19.5M
                                Slice* parsed_main_message) {
216
19.5M
  CodedInputStream in(buf.data(), narrow_cast<int>(buf.size()));
217
19.5M
  SetupLimit(&in);
218
219
19.5M
  uint32_t header_len;
220
19.5M
  if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) {
221
0
    return STATUS(Corruption, "Invalid packet: missing header delimiter",
222
0
                              buf.ToDebugString());
223
0
  }
224
225
19.5M
  auto l = in.PushLimit(header_len);
226
19.5M
  RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header));
227
19.5M
  in.PopLimit(l);
228
229
19.5M
  uint32_t main_msg_len;
230
19.5M
  if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) {
231
0
    return STATUS(Corruption, "Invalid packet: missing main msg length",
232
0
                              buf.ToDebugString());
233
0
  }
234
235
19.5M
  if (PREDICT_FALSE(!in.Skip(main_msg_len))) {
236
0
    return STATUS(Corruption,
237
0
        StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len),
238
0
        buf.ToDebugString());
239
0
  }
240
241
19.5M
  if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) {
242
0
    return STATUS(Corruption,
243
0
      StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()),
244
0
      buf.ToDebugString());
245
0
  }
246
247
19.5M
  *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len,
248
19.5M
                              main_msg_len);
249
19.5M
  return Status::OK();
250
19.5M
}
serialization.cc:_ZN2yb3rpc12_GLOBAL__N_116DoParseYBMessageIN6google8protobuf11MessageLiteEEENS_6StatusERKNS_5SliceEPT_PS7_
Line
Count
Source
215
19.5M
                                Slice* parsed_main_message) {
216
19.5M
  CodedInputStream in(buf.data(), narrow_cast<int>(buf.size()));
217
19.5M
  SetupLimit(&in);
218
219
19.5M
  uint32_t header_len;
220
19.5M
  if (PREDICT_FALSE(!in.ReadVarint32(&header_len))) {
221
0
    return STATUS(Corruption, "Invalid packet: missing header delimiter",
222
0
                              buf.ToDebugString());
223
0
  }
224
225
19.5M
  auto l = in.PushLimit(header_len);
226
19.5M
  RETURN_NOT_OK(ParseHeader(buf, &in, parsed_header));
227
19.5M
  in.PopLimit(l);
228
229
19.5M
  uint32_t main_msg_len;
230
19.5M
  if (PREDICT_FALSE(!in.ReadVarint32(&main_msg_len))) {
231
0
    return STATUS(Corruption, "Invalid packet: missing main msg length",
232
0
                              buf.ToDebugString());
233
0
  }
234
235
19.5M
  if (PREDICT_FALSE(!in.Skip(main_msg_len))) {
236
0
    return STATUS(Corruption,
237
0
        StringPrintf("Invalid packet: data too short, expected %d byte main_msg", main_msg_len),
238
0
        buf.ToDebugString());
239
0
  }
240
241
19.5M
  if (PREDICT_FALSE(in.BytesUntilLimit() > 0)) {
242
0
    return STATUS(Corruption,
243
0
      StringPrintf("Invalid packet: %d extra bytes at end of packet", in.BytesUntilLimit()),
244
0
      buf.ToDebugString());
245
0
  }
246
247
19.5M
  *parsed_main_message = Slice(buf.data() + buf.size() - main_msg_len,
248
19.5M
                              main_msg_len);
249
19.5M
  return Status::OK();
250
19.5M
}
251
252
} // namespace
253
254
// Parses an RPC packet, decoding the header into the lightweight ParsedRequestHeader.
// The main message Slice points into `buf`.
Status ParseYBMessage(const Slice& buf,
                      ParsedRequestHeader* parsed_header,
                      Slice* parsed_main_message) {
  return DoParseYBMessage<ParsedRequestHeader>(buf, parsed_header, parsed_main_message);
}
259
260
// Parses an RPC packet, decoding the header into a full protobuf message.
// The main message Slice points into `buf`.
Status ParseYBMessage(const Slice& buf,
                      MessageLite* parsed_header,
                      Slice* parsed_main_message) {
  return DoParseYBMessage<MessageLite>(buf, parsed_header, parsed_main_message);
}
265
266
3.20M
// Decodes a serialized RemoteMethodPB held in `buf` without copying the strings.
// The service and method names are returned as Slices pointing into `buf`, so `buf`
// must outlive the result. Unknown fields are skipped.
// Returns Corruption on an undecodable tag or field.
Result<ParsedRemoteMethod> ParseRemoteMethod(const Slice& buf) {
  CodedInputStream in(buf.data(), narrow_cast<int>(buf.size()));
  in.PushLimit(narrow_cast<int>(buf.size()));
  ParsedRemoteMethod result;
  while (in.BytesUntilLimit() > 0) {
    const auto tag = in.ReadTag();
    // ReadTag() returns 0 when it cannot decode a tag, and 0 is never a valid protobuf
    // tag. Reject it rather than treating it as an unknown varint field.
    if (PREDICT_FALSE(tag == 0)) {
      return STATUS(Corruption, "Unable to decode remote method field tag");
    }
    switch (tag >> 3) {
      case RemoteMethodPB::kServiceNameFieldNumber:
        result.service = VERIFY_RESULT(ParseString(buf, "service_name", &in));
        break;
      case RemoteMethodPB::kMethodNameFieldNumber:
        result.method = VERIFY_RESULT(ParseString(buf, "method_name", &in));
        break;
      default:
        // The low 3 bits of the tag are the wire type.
        if (!SkipField(tag & 7, &in)) {
          return STATUS_FORMAT(Corruption, "Unable to skip: $0", tag);
        }
        break;
    }
  }
  return result;
}
289
290
2.45k
std::string ParsedRequestHeader::RemoteMethodAsString() const {
291
2.45k
  auto parsed_remote_method = ParseRemoteMethod(remote_method);
292
2.45k
  if (parsed_remote_method.ok()) {
293
2.45k
    return parsed_remote_method->service.ToBuffer() + "." +
294
2.45k
           parsed_remote_method->method.ToBuffer();
295
2
  } else {
296
2
    return parsed_remote_method.status().ToString();
297
2
  }
298
2.45k
}
299
300
1
// Copies this parsed header into the RequestHeader protobuf `out`.
// timeout_millis is set only when timeout_ms is non-zero. remote_method is set only
// when the stored remote_method bytes decode successfully; decode errors are ignored.
void ParsedRequestHeader::ToPB(RequestHeader* out) const {
  out->set_call_id(call_id);
  if (timeout_ms != 0) {
    out->set_timeout_millis(timeout_ms);
  }
  auto parsed_remote_method = ParseRemoteMethod(remote_method);
  if (!parsed_remote_method.ok()) {
    return;
  }
  auto* remote_method_pb = out->mutable_remote_method();
  remote_method_pb->set_service_name(parsed_remote_method->service.ToBuffer());
  remote_method_pb->set_method_name(parsed_remote_method->method.ToBuffer());
}
311
312
}  // namespace rpc
313
}  // namespace yb