Electroneum
websocket.h
Go to the documentation of this file.
1 #pragma once
2 #include "socket_adaptors.h"
3 #include "http_request.h"
4 #include "TinySHA1.hpp"
5 
6 namespace crow
7 {
8  namespace websocket
9  {
/// States of the incoming-frame parser that do_read() switches on:
/// the 2-byte frame header (MiniHeader), an optional extended payload
/// length (Len16 or Len64), the 4-byte masking key (Mask), then the
/// payload bytes themselves (Payload).
enum class WebSocketReadState
{
    MiniHeader, Len16, Len64, Mask, Payload,
};
18 
/// Abstract per-connection handle handed to the user's WebSocket callbacks.
struct connection
{
    /// Sends msg as a binary message.
    virtual void send_binary(const std::string& msg) = 0;
    /// Sends msg as a text message.
    virtual void send_text(const std::string& msg) = 0;
    /// Sends a close frame carrying msg (defaults to "quit" when called
    /// through this interface).
    virtual void close(const std::string& msg = "quit") = 0;
    virtual ~connection() = default;
};
26 
27  template <typename Adaptor>
28  class Connection : public connection
29  {
30  public:
31  Connection(const crow::request& req, Adaptor&& adaptor,
32  std::function<void(crow::websocket::connection&)> open_handler,
33  std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
34  std::function<void(crow::websocket::connection&, const std::string&)> close_handler,
35  std::function<void(crow::websocket::connection&)> error_handler)
36  : adaptor_(std::move(adaptor)), open_handler_(std::move(open_handler)), message_handler_(std::move(message_handler)), close_handler_(std::move(close_handler)), error_handler_(std::move(error_handler))
37  {
38  if (req.get_header_value("upgrade") != "websocket")
39  {
40  adaptor.close();
41  delete this;
42  return;
43  }
44  // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
45  // Sec-WebSocket-Version: 13
46  std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
47  sha1::SHA1 s;
48  s.processBytes(magic.data(), magic.size());
49  uint8_t digest[20];
50  s.getDigestBytes(digest);
51  start(crow::utility::base64encode((char*)digest, 20));
52  }
53 
54  template<typename CompletionHandler>
55  void dispatch(CompletionHandler handler)
56  {
57  adaptor_.get_io_service().dispatch(handler);
58  }
59 
60  template<typename CompletionHandler>
61  void post(CompletionHandler handler)
62  {
63  adaptor_.get_io_service().post(handler);
64  }
65 
66  void send_pong(const std::string& msg)
67  {
68  dispatch([this, msg]{
69  char buf[3] = "\x8A\x00";
70  buf[1] += msg.size();
71  write_buffers_.emplace_back(buf, buf+2);
72  write_buffers_.emplace_back(msg);
73  do_write();
74  });
75  }
76 
77  void send_binary(const std::string& msg) override
78  {
79  dispatch([this, msg]{
80  auto header = build_header(2, msg.size());
81  write_buffers_.emplace_back(std::move(header));
82  write_buffers_.emplace_back(msg);
83  do_write();
84  });
85  }
86 
87  void send_text(const std::string& msg) override
88  {
89  dispatch([this, msg]{
90  auto header = build_header(1, msg.size());
91  write_buffers_.emplace_back(std::move(header));
92  write_buffers_.emplace_back(msg);
93  do_write();
94  });
95  }
96 
// Starts (or answers) the closing handshake. On the adaptor's io_service it sets
// has_sent_close_, possibly invokes close_handler_, and queues a close frame
// (opcode 0x8) whose payload is msg.
// NOTE(review): RFC 6455 §5.5.1 requires a non-empty close payload to begin with a
// 2-byte status code. msg (default "quit" via connection::close) is sent as-is, so
// the peer will decode its first two bytes as a status code.
97  void close(const std::string& msg) override
98  {
99  dispatch([this, msg]{
100  has_sent_close_ = true;
// NOTE(review): embedded lines 101 and 103 (the condition guarding this block and
// one statement inside it) are not present in this copy. Judging from the member
// index they involve is_close_handler_called_ — confirm against the original.
102  {
104  if (close_handler_)
105  close_handler_(*this, msg);
106  }
107  auto header = build_header(0x8, msg.size());
108  write_buffers_.emplace_back(std::move(header));
109  write_buffers_.emplace_back(msg);
110  do_write();
111  });
112  }
113 
114  protected:
115 
/// Builds an unmasked server-to-client frame header with FIN set.
///
/// @param opcode Frame opcode (0x1 text, 0x2 binary, 0x8 close, 0xA pong, ...);
///               only its low 4 bits are used.
/// @param size   Payload length in bytes.
/// @return 2, 4 or 10 header bytes. The 16-bit and 64-bit extended lengths are
///         written most-significant byte first (network byte order), as
///         RFC 6455 §5.2 requires. The previous version stored them in host
///         order through unaligned uint16_t*/uint64_t* casts, which produced
///         wrong lengths on little-endian hosts.
std::string build_header(int opcode, size_t size)
{
    std::string header;
    header.reserve(2 + 8);
    header.push_back(static_cast<char>(0x80 | (opcode & 0x0F)));

    const uint64_t length = size;
    int extended_bytes = 0;  // how many extended-length bytes follow
    if (length < 126)
    {
        header.push_back(static_cast<char>(length));
    }
    else if (length < 0x10000)
    {
        header.push_back(static_cast<char>(126));
        extended_bytes = 2;
    }
    else
    {
        header.push_back(static_cast<char>(127));
        extended_bytes = 8;
    }
    for (int i = extended_bytes - 1; i >= 0; --i)
        header.push_back(static_cast<char>((length >> (8 * i)) & 0xFF));
    return header;
}
138 
139  void start(std::string&& hello)
140  {
141  static std::string header = "HTTP/1.1 101 Switching Protocols\r\n"
142  "Upgrade: websocket\r\n"
143  "Connection: Upgrade\r\n"
144  "Sec-WebSocket-Accept: ";
145  static std::string crlf = "\r\n";
146  write_buffers_.emplace_back(header);
147  write_buffers_.emplace_back(std::move(hello));
148  write_buffers_.emplace_back(crlf);
149  write_buffers_.emplace_back(crlf);
150  do_write();
151  if (open_handler_)
152  open_handler_(*this);
153  do_read();
154  }
155 
// Issues the next asynchronous read of the incoming-frame state machine, selected
// by state_. Every completion handler clears is_reading. On success the header,
// length and mask handlers call do_read() again. On failure a handler sets
// close_connection_, runs error_handler_ and closes the adaptor.
// NOTE(review): the `case WebSocketReadState::...:` labels and the `state_ = ...;`
// transitions (embedded lines 161, 181, 185, 189-190, 205, 222, 236, 252, 266, 280,
// 292, 309) are not present in this copy — confirm against the original header.
// NOTE(review): no upper bound is applied to the announced payload length, so a peer
// can make fragment_/message_ grow without limit.
156  void do_read()
157  {
158  is_reading = true;
159  switch(state_)
160  {
// Frame header: 2 bytes. After the byte swap, the high byte holds FIN/opcode and
// the low byte holds MASK/7-bit length (see is_FIN(), opcode() and the 0x80/0x7f masks).
162  {
163  //boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&mini_header_, 1),
// NOTE(review): async_read_some may complete after only 1 byte. That case is only
// detected in CROW_ENABLE_DEBUG builds (below); async_read would guarantee 2 bytes.
164  adaptor_.socket().async_read_some(boost::asio::buffer(&mini_header_, 2),
165  [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
166  {
167  is_reading = false;
// Convert from network byte order (htons performs the same swap as ntohs).
168  mini_header_ = htons(mini_header_);
169 #ifdef CROW_ENABLE_DEBUG
170 
171  if (!ec && bytes_transferred != 2)
172  {
173  throw std::runtime_error("WebSocket:MiniHeader:async_read fail:asio bug?");
174  }
175 #endif
176 
// Client-to-server frames must be masked (RFC 6455 §5.1). An unmasked frame or
// a read error ends the connection.
177  if (!ec && ((mini_header_ & 0x80) == 0x80))
178  {
// 127 selects a 64-bit extended length and 126 a 16-bit one. Any other value is
// the payload length itself.
179  if ((mini_header_ & 0x7f) == 127)
180  {
182  }
183  else if ((mini_header_ & 0x7f) == 126)
184  {
186  }
187  else
188  {
191  }
192  do_read();
193  }
194  else
195  {
196  close_connection_ = true;
197  adaptor_.close();
198  if (error_handler_)
199  error_handler_(*this);
200  check_destroy();
201  }
202  });
203  }
204  break;
// 16-bit extended payload length, network byte order on the wire.
206  {
207  remaining_length_ = 0;
208  boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 2),
209  [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
210  {
211  is_reading = false;
// NOTE(review): reads the first two bytes of a uint64_t through a uint16_t*
// (type punning that strict aliasing does not formally allow).
212  remaining_length_ = ntohs(*(uint16_t*)&remaining_length_);
213 #ifdef CROW_ENABLE_DEBUG
214  if (!ec && bytes_transferred != 2)
215  {
216  throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?");
217  }
218 #endif
219 
220  if (!ec)
221  {
223  do_read();
224  }
225  else
226  {
227  close_connection_ = true;
228  adaptor_.close();
229  if (error_handler_)
230  error_handler_(*this);
231  check_destroy();
232  }
233  });
234  }
235  break;
// 64-bit extended payload length. The expression below swaps the two 32-bit halves
// and each half's bytes when the host is little-endian (1 != ntohl(1)).
237  {
238  boost::asio::async_read(adaptor_.socket(), boost::asio::buffer(&remaining_length_, 8),
239  [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
240  {
241  is_reading = false;
242  remaining_length_ = ((1==ntohl(1)) ? (remaining_length_) : ((uint64_t)ntohl((remaining_length_) & 0xFFFFFFFF) << 32) | ntohl((remaining_length_) >> 32));
243 #ifdef CROW_ENABLE_DEBUG
244  if (!ec && bytes_transferred != 8)
245  {
// NOTE(review): message says "Len16" although this is the 8-byte (Len64) read.
246  throw std::runtime_error("WebSocket:Len16:async_read fail:asio bug?");
247  }
248 #endif
249 
250  if (!ec)
251  {
253  do_read();
254  }
255  else
256  {
257  close_connection_ = true;
258  adaptor_.close();
259  if (error_handler_)
260  error_handler_(*this);
261  check_destroy();
262  }
263  });
264  }
265  break;
// 4-byte masking key; handle_fragment() XORs the payload with it.
267  boost::asio::async_read(adaptor_.socket(), boost::asio::buffer((char*)&mask_, 4),
268  [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
269  {
270  is_reading = false;
271 #ifdef CROW_ENABLE_DEBUG
272  if (!ec && bytes_transferred != 4)
273  {
274  throw std::runtime_error("WebSocket:Mask:async_read fail:asio bug?");
275  }
276 #endif
277 
278  if (!ec)
279  {
281  do_read();
282  }
283  else
284  {
// NOTE(review): unlike the header/length branches, this error path does not call
// check_destroy().
285  close_connection_ = true;
286  if (error_handler_)
287  error_handler_(*this);
288  adaptor_.close();
289  }
290  });
291  break;
// Payload: read up to buffer_.size() bytes, capped at remaining_length_, and
// append them to fragment_.
293  {
294  size_t to_read = buffer_.size();
295  if (remaining_length_ < to_read)
296  to_read = remaining_length_;
297  adaptor_.socket().async_read_some( boost::asio::buffer(buffer_, to_read),
298  [this](const boost::system::error_code& ec, std::size_t bytes_transferred)
299  {
300  is_reading = false;
301 
302  if (!ec)
303  {
304  fragment_.insert(fragment_.end(), buffer_.begin(), buffer_.begin() + bytes_transferred);
305  remaining_length_ -= bytes_transferred;
// NOTE(review): if remaining_length_ is still non-zero after this read, no further
// do_read() is issued. A payload that needs more than one read (e.g. one larger
// than buffer_) therefore appears to stall the connection.
306  if (remaining_length_ == 0)
307  {
// NOTE(review): handle_fragment() can reach check_destroy() (close reply) and
// `delete this`; the do_read() below would then run on a freed object.
308  handle_fragment();
310  do_read();
311  }
312  }
313  else
314  {
// NOTE(review): like the Mask branch, this error path does not call check_destroy().
315  close_connection_ = true;
316  if (error_handler_)
317  error_handler_(*this);
318  adaptor_.close();
319  }
320  });
321  }
322  break;
323  }
324  }
325 
326  bool is_FIN()
327  {
328  return mini_header_ & 0x8000;
329  }
330 
331  int opcode()
332  {
333  return (mini_header_ & 0x0f00) >> 8;
334  }
335 
337  {
// Unmasks fragment_ with mask_, then acts on opcode():
//   - data frames (0 continuation, 1 text, 2 binary) are appended to message_,
//     which is cleared once a FIN frame has been handled;
//   - 0x8 runs the closing handshake;
//   - 0x9 is a ping;
//   - 0xA sets pong_received_.
// fragment_ is cleared at the end.
// NOTE(review): the signature line (embedded 336) is not present in this copy.
// Lines 350, 361 and 373 (presumably the message_handler_ calls) and 389, 393 and
// 401 (presumably a send_pong reply to the ping) are also missing.
338  for(decltype(fragment_.length()) i = 0; i < fragment_.length(); i ++)
339  {
340  fragment_[i] ^= ((char*)&mask_)[i%4];
341  }
342  switch(opcode())
343  {
344  case 0: // Continuation
345  {
346  message_ += fragment_;
347  if (is_FIN())
348  {
349  if (message_handler_)
351  message_.clear();
352  }
353  }
// NOTE(review): case 0 has no `break`, so a continuation frame falls through into
// the Text case. fragment_ is appended to message_ a second time, is_binary_ is
// forced to false, and the FIN handling runs again. A `break;` looks intended.
354  case 1: // Text
355  {
356  is_binary_ = false;
357  message_ += fragment_;
358  if (is_FIN())
359  {
360  if (message_handler_)
362  message_.clear();
363  }
364  }
365  break;
366  case 2: // Binary
367  {
368  is_binary_ = true;
369  message_ += fragment_;
370  if (is_FIN())
371  {
372  if (message_handler_)
374  message_.clear();
375  }
376  }
377  break;
378  case 0x8: // Close
379  {
// A peer-initiated close is answered with our own close frame. When this is the
// reply to a close we already sent, shut down the adaptor and release the connection.
380  has_recv_close_ = true;
381  if (!has_sent_close_)
382  {
383  close(fragment_);
384  }
385  else
386  {
387  adaptor_.close();
388  close_connection_ = true;
390  {
391  if (close_handler_)
392  close_handler_(*this, fragment_);
394  }
395  check_destroy();
// NOTE(review): check_destroy() may `delete this`. Execution then continues to
// `fragment_.clear()` below and back into the caller in do_read(), and both touch
// members — a use-after-free if the object was freed.
396  }
397  }
398  break;
399  case 0x9: // Ping
400  {
402  }
403  break;
404  case 0xA: // Pong
405  {
406  pong_received_ = true;
407  }
408  break;
409  }
410 
411  fragment_.clear();
412  }
413 
// Starts one async_write covering every string in sending_buffers_, unless a write
// is already in flight (sending_buffers_ non-empty).
// NOTE(review): embedded line 418 is not present in this copy. It presumably moves
// the queued write_buffers_ into sending_buffers_ — confirm against the original.
414  void do_write()
415  {
416  if (sending_buffers_.empty())
417  {
419  std::vector<boost::asio::const_buffer> buffers;
420  buffers.reserve(sending_buffers_.size());
421  for(auto& s:sending_buffers_)
422  {
423  buffers.emplace_back(boost::asio::buffer(s));
424  }
// NOTE(review): the handler outlives this call but captures with [&]. It only uses
// members (through the captured `this`), not the local `buffers`, so it is safe
// today; [this] would state that intent explicitly.
425  boost::asio::async_write(adaptor_.socket(), buffers,
426  [&](const boost::system::error_code& ec, std::size_t /*bytes_transferred*/)
427  {
// The strings written by this operation are released here. If more data was queued
// meanwhile, the next write starts. Once our close frame has gone out,
// close_connection_ is set. An error, or completion after close_connection_ was
// already set, goes through check_destroy().
428  sending_buffers_.clear();
429  if (!ec && !close_connection_)
430  {
431  if (!write_buffers_.empty())
432  do_write();
433  if (has_sent_close_)
434  close_connection_ = true;
435  }
436  else
437  {
438  close_connection_ = true;
439  check_destroy();
440  }
441  });
442  }
443  }
444 
446  {
// Reports an unclean close to close_handler_. It then frees this connection once
// nothing asynchronous still refers to it: no write in flight (sending_buffers_
// empty) and no outstanding read (!is_reading).
// NOTE(review): the function signature (embedded line 445) and line 448 are not
// present in this copy. Line 448 presumably guards the close_handler_ call (e.g.
// via is_close_handler_called_) — confirm against the original header.
// After `delete this`, callers must not touch any member.
447  //if (has_sent_close_ && has_recv_close_)
449  if (close_handler_)
450  close_handler_(*this, "uncleanly");
451  if (sending_buffers_.empty() && !is_reading)
452  delete this;
453  }
454  private:
// Socket adaptor for this connection (moved in by the constructor).
455  Adaptor adaptor_;
456 
// Strings owned by the async_write currently in flight. do_write() only starts a
// new write while this is empty, and its completion handler clears it.
457  std::vector<std::string> sending_buffers_;
// Outgoing strings queued by start(), close() and the send_* functions for the next write.
458  std::vector<std::string> write_buffers_;
459 
// Scratch buffer for payload reads in do_read().
460  boost::array<char, 4096> buffer_;
// NOTE(review): embedded lines 461, 464 and 474 are missing from this copy. This
// page's symbol index lists them as `bool is_binary_`, `WebSocketReadState state_`
// and `bool is_close_handler_called_`.
// message_ accumulates data fragments until a FIN frame; fragment_ holds the
// current frame's payload.
462  std::string message_;
463  std::string fragment_;
// Payload bytes still to be read for the current frame.
465  uint64_t remaining_length_{0};
466  bool close_connection_{false};
// True while a read started by do_read() has not yet completed.
467  bool is_reading{false};
// NOTE(review): mask_ and mini_header_ have no initializer; do_read() fills both
// before they are used.
468  uint32_t mask_;
469  uint16_t mini_header_;
470  bool has_sent_close_{false};
471  bool has_recv_close_{false};
// NOTE(review): error_occured_ is never read or written in the visible code, and
// pong_received_ is only ever set (in handle_fragment()).
472  bool error_occured_{false};
473  bool pong_received_{false};
475 
476  std::function<void(crow::websocket::connection&)> open_handler_;
477  std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler_;
478  std::function<void(crow::websocket::connection&, const std::string&)> close_handler_;
479  std::function<void(crow::websocket::connection&)> error_handler_;
480  };
481  }
482 }
Definition: TinySHA1.hpp:30
bool has_recv_close_
Definition: websocket.h:471
std::function< void(crow::websocket::connection &, const std::string &)> close_handler_
Definition: websocket.h:478
boost::array< char, 4096 > buffer_
Definition: websocket.h:460
std::function< void(crow::websocket::connection &)> error_handler_
Definition: websocket.h:479
Definition: http_request.h:23
bool pong_received_
Definition: websocket.h:473
std::string base64encode(const char *data, size_t size, const char *key="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/")
Definition: utility.h:503
void do_write()
Definition: websocket.h:414
void post(CompletionHandler handler)
Definition: websocket.h:61
uint16_t mini_header_
Definition: websocket.h:469
WebSocketReadState state_
Definition: websocket.h:464
std::string fragment_
Definition: websocket.h:463
std::function< void(crow::websocket::connection &)> open_handler_
Definition: websocket.h:476
Definition: block_queue.cpp:41
virtual ~connection()
Definition: websocket.h:24
void dispatch(CompletionHandler handler)
Definition: websocket.h:55
void close(const std::string &msg) override
Definition: websocket.h:97
Definition: websocket.h:28
bool is_binary_
Definition: websocket.h:461
virtual void send_binary(const std::string &msg)=0
uint64_t remaining_length_
Definition: websocket.h:465
const std::string & get_header_value(const std::string &key) const
Definition: http_request.h:50
void check_destroy()
Definition: websocket.h:445
virtual void send_text(const std::string &msg)=0
bool is_reading
Definition: websocket.h:467
bool is_FIN()
Definition: websocket.h:326
std::vector< std::string > write_buffers_
Definition: websocket.h:458
int opcode()
Definition: websocket.h:331
uint32_t mask_
Definition: websocket.h:468
WebSocketReadState
Definition: websocket.h:10
Definition: websocket.h:19
void send_text(const std::string &msg) override
Definition: websocket.h:87
std::vector< std::string > sending_buffers_
Definition: websocket.h:457
bool has_sent_close_
Definition: websocket.h:470
Adaptor adaptor_
Definition: websocket.h:455
std::string message_
Definition: websocket.h:462
void handle_fragment()
Definition: websocket.h:336
Definition: ci_map.h:7
void send_pong(const std::string &msg)
Definition: websocket.h:66
std::function< void(crow::websocket::connection &, const std::string &, bool)> message_handler_
Definition: websocket.h:477
Connection(const crow::request &req, Adaptor &&adaptor, std::function< void(crow::websocket::connection &)> open_handler, std::function< void(crow::websocket::connection &, const std::string &, bool)> message_handler, std::function< void(crow::websocket::connection &, const std::string &)> close_handler, std::function< void(crow::websocket::connection &)> error_handler)
Definition: websocket.h:31
bool close_connection_
Definition: websocket.h:466
bool error_occured_
Definition: websocket.h:472
void send_binary(const std::string &msg) override
Definition: websocket.h:77
void start(std::string &&hello)
Definition: websocket.h:139
std::string build_header(int opcode, size_t size)
Definition: websocket.h:116
virtual void close(const std::string &msg="quit")=0
#define s(x, c)
Definition: aesb.c:46
void do_read()
Definition: websocket.h:156
bool is_close_handler_called_
Definition: websocket.h:474