ESPHome 2025.6.3
Loading...
Searching...
No Matches
api_frame_helper.cpp
Go to the documentation of this file.
1#include "api_frame_helper.h"
2#ifdef USE_API
4#include "esphome/core/hal.h"
6#include "esphome/core/log.h"
7#include "proto.h"
8#include "api_pb2_size.h"
9#include <cstring>
10#include <cinttypes>
11
12namespace esphome {
13namespace api {
14
15static const char *const TAG = "api.socket";
16
17const char *api_error_to_str(APIError err) {
18 // not using switch to ensure compiler doesn't try to build a big table out of it
19 if (err == APIError::OK) {
20 return "OK";
21 } else if (err == APIError::WOULD_BLOCK) {
22 return "WOULD_BLOCK";
23 } else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) {
24 return "BAD_HANDSHAKE_PACKET_LEN";
25 } else if (err == APIError::BAD_INDICATOR) {
26 return "BAD_INDICATOR";
27 } else if (err == APIError::BAD_DATA_PACKET) {
28 return "BAD_DATA_PACKET";
29 } else if (err == APIError::TCP_NODELAY_FAILED) {
30 return "TCP_NODELAY_FAILED";
31 } else if (err == APIError::TCP_NONBLOCKING_FAILED) {
32 return "TCP_NONBLOCKING_FAILED";
33 } else if (err == APIError::CLOSE_FAILED) {
34 return "CLOSE_FAILED";
35 } else if (err == APIError::SHUTDOWN_FAILED) {
36 return "SHUTDOWN_FAILED";
37 } else if (err == APIError::BAD_STATE) {
38 return "BAD_STATE";
39 } else if (err == APIError::BAD_ARG) {
40 return "BAD_ARG";
41 } else if (err == APIError::SOCKET_READ_FAILED) {
42 return "SOCKET_READ_FAILED";
43 } else if (err == APIError::SOCKET_WRITE_FAILED) {
44 return "SOCKET_WRITE_FAILED";
45 } else if (err == APIError::HANDSHAKESTATE_READ_FAILED) {
46 return "HANDSHAKESTATE_READ_FAILED";
47 } else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) {
48 return "HANDSHAKESTATE_WRITE_FAILED";
49 } else if (err == APIError::HANDSHAKESTATE_BAD_STATE) {
50 return "HANDSHAKESTATE_BAD_STATE";
51 } else if (err == APIError::CIPHERSTATE_DECRYPT_FAILED) {
52 return "CIPHERSTATE_DECRYPT_FAILED";
53 } else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) {
54 return "CIPHERSTATE_ENCRYPT_FAILED";
55 } else if (err == APIError::OUT_OF_MEMORY) {
56 return "OUT_OF_MEMORY";
57 } else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) {
58 return "HANDSHAKESTATE_SETUP_FAILED";
59 } else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) {
60 return "HANDSHAKESTATE_SPLIT_FAILED";
61 } else if (err == APIError::BAD_HANDSHAKE_ERROR_BYTE) {
62 return "BAD_HANDSHAKE_ERROR_BYTE";
63 } else if (err == APIError::CONNECTION_CLOSED) {
64 return "CONNECTION_CLOSED";
65 }
66 return "UNKNOWN";
67}
68
69// Helper method to buffer data from IOVs
70void APIFrameHelper::buffer_data_from_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len) {
71 SendBuffer buffer;
72 buffer.data.reserve(total_write_len);
73 for (int i = 0; i < iovcnt; i++) {
74 const uint8_t *data = reinterpret_cast<uint8_t *>(iov[i].iov_base);
75 buffer.data.insert(buffer.data.end(), data, data + iov[i].iov_len);
76 }
77 this->tx_buf_.push_back(std::move(buffer));
78}
79
80// This method writes data to socket or buffers it
81APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) {
82 // Returns APIError::OK if successful (or would block, but data has been buffered)
83 // Returns APIError::SOCKET_WRITE_FAILED if socket write failed, and sets state to FAILED
84
85 if (iovcnt == 0)
86 return APIError::OK; // Nothing to do, success
87
88 uint16_t total_write_len = 0;
89 for (int i = 0; i < iovcnt; i++) {
90#ifdef HELPER_LOG_PACKETS
91 ESP_LOGVV(TAG, "Sending raw: %s",
92 format_hex_pretty(reinterpret_cast<uint8_t *>(iov[i].iov_base), iov[i].iov_len).c_str());
93#endif
94 total_write_len += static_cast<uint16_t>(iov[i].iov_len);
95 }
96
97 // Try to send any existing buffered data first if there is any
98 if (!this->tx_buf_.empty()) {
99 APIError send_result = try_send_tx_buf_();
100 // If real error occurred (not just WOULD_BLOCK), return it
101 if (send_result != APIError::OK && send_result != APIError::WOULD_BLOCK) {
102 return send_result;
103 }
104
105 // If there is still data in the buffer, we can't send, buffer
106 // the new data and return
107 if (!this->tx_buf_.empty()) {
108 this->buffer_data_from_iov_(iov, iovcnt, total_write_len);
109 return APIError::OK; // Success, data buffered
110 }
111 }
112
113 // Try to send directly if no buffered data
114 ssize_t sent = this->socket_->writev(iov, iovcnt);
115
116 if (sent == -1) {
117 if (errno == EWOULDBLOCK || errno == EAGAIN) {
118 // Socket would block, buffer the data
119 this->buffer_data_from_iov_(iov, iovcnt, total_write_len);
120 return APIError::OK; // Success, data buffered
121 }
122 // Socket error
123 ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno);
124 this->state_ = State::FAILED;
125 return APIError::SOCKET_WRITE_FAILED; // Socket write failed
126 } else if (static_cast<uint16_t>(sent) < total_write_len) {
127 // Partially sent, buffer the remaining data
128 SendBuffer buffer;
129 uint16_t to_consume = static_cast<uint16_t>(sent);
130 uint16_t remaining = total_write_len - static_cast<uint16_t>(sent);
131
132 buffer.data.reserve(remaining);
133
134 for (int i = 0; i < iovcnt; i++) {
135 if (to_consume >= iov[i].iov_len) {
136 // This segment was fully sent
137 to_consume -= static_cast<uint16_t>(iov[i].iov_len);
138 } else {
139 // This segment was partially sent or not sent at all
140 const uint8_t *data = reinterpret_cast<uint8_t *>(iov[i].iov_base) + to_consume;
141 uint16_t len = static_cast<uint16_t>(iov[i].iov_len) - to_consume;
142 buffer.data.insert(buffer.data.end(), data, data + len);
143 to_consume = 0;
144 }
145 }
146
147 this->tx_buf_.push_back(std::move(buffer));
148 }
149
150 return APIError::OK; // Success, all data sent or buffered
151}
152
153// Common implementation for trying to send buffered data
154// IMPORTANT: Caller MUST ensure tx_buf_ is not empty before calling this method
156 // Try to send from tx_buf - we assume it's not empty as it's the caller's responsibility to check
157 bool tx_buf_empty = false;
158 while (!tx_buf_empty) {
159 // Get the first buffer in the queue
160 SendBuffer &front_buffer = this->tx_buf_.front();
161
162 // Try to send the remaining data in this buffer
163 ssize_t sent = this->socket_->write(front_buffer.current_data(), front_buffer.remaining());
164
165 if (sent == -1) {
166 if (errno != EWOULDBLOCK && errno != EAGAIN) {
167 // Real socket error (not just would block)
168 ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", this->info_.c_str(), errno);
169 this->state_ = State::FAILED;
170 return APIError::SOCKET_WRITE_FAILED; // Socket write failed
171 }
172 // Socket would block, we'll try again later
174 } else if (sent == 0) {
175 // Nothing sent but not an error
177 } else if (static_cast<uint16_t>(sent) < front_buffer.remaining()) {
178 // Partially sent, update offset
179 // Cast to ensure no overflow issues with uint16_t
180 front_buffer.offset += static_cast<uint16_t>(sent);
181 return APIError::WOULD_BLOCK; // Stop processing more buffers if we couldn't send a complete buffer
182 } else {
183 // Buffer completely sent, remove it from the queue
184 this->tx_buf_.pop_front();
185 // Update empty status for the loop condition
186 tx_buf_empty = this->tx_buf_.empty();
187 // Continue loop to try sending the next buffer
188 }
189 }
190
191 return APIError::OK; // All buffers sent successfully
192}
193
195 if (state_ != State::INITIALIZE || this->socket_ == nullptr) {
196 ESP_LOGVV(TAG, "%s: Bad state for init %d", this->info_.c_str(), (int) state_);
197 return APIError::BAD_STATE;
198 }
199 int err = this->socket_->setblocking(false);
200 if (err != 0) {
202 ESP_LOGVV(TAG, "%s: Setting nonblocking failed with errno %d", this->info_.c_str(), errno);
204 }
205
206 int enable = 1;
207 err = this->socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
208 if (err != 0) {
210 ESP_LOGVV(TAG, "%s: Setting nodelay failed with errno %d", this->info_.c_str(), errno);
212 }
213 return APIError::OK;
214}
215
216#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, this->info_.c_str(), ##__VA_ARGS__)
217// uncomment to log raw packets
218//#define HELPER_LOG_PACKETS
219
220#ifdef USE_API_NOISE
221static const char *const PROLOGUE_INIT = "NoiseAPIInit";
222
224std::string noise_err_to_str(int err) {
225 if (err == NOISE_ERROR_NO_MEMORY)
226 return "NO_MEMORY";
227 if (err == NOISE_ERROR_UNKNOWN_ID)
228 return "UNKNOWN_ID";
229 if (err == NOISE_ERROR_UNKNOWN_NAME)
230 return "UNKNOWN_NAME";
231 if (err == NOISE_ERROR_MAC_FAILURE)
232 return "MAC_FAILURE";
233 if (err == NOISE_ERROR_NOT_APPLICABLE)
234 return "NOT_APPLICABLE";
235 if (err == NOISE_ERROR_SYSTEM)
236 return "SYSTEM";
237 if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED)
238 return "REMOTE_KEY_REQUIRED";
239 if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED)
240 return "LOCAL_KEY_REQUIRED";
241 if (err == NOISE_ERROR_PSK_REQUIRED)
242 return "PSK_REQUIRED";
243 if (err == NOISE_ERROR_INVALID_LENGTH)
244 return "INVALID_LENGTH";
245 if (err == NOISE_ERROR_INVALID_PARAM)
246 return "INVALID_PARAM";
247 if (err == NOISE_ERROR_INVALID_STATE)
248 return "INVALID_STATE";
249 if (err == NOISE_ERROR_INVALID_NONCE)
250 return "INVALID_NONCE";
251 if (err == NOISE_ERROR_INVALID_PRIVATE_KEY)
252 return "INVALID_PRIVATE_KEY";
253 if (err == NOISE_ERROR_INVALID_PUBLIC_KEY)
254 return "INVALID_PUBLIC_KEY";
255 if (err == NOISE_ERROR_INVALID_FORMAT)
256 return "INVALID_FORMAT";
257 if (err == NOISE_ERROR_INVALID_SIGNATURE)
258 return "INVALID_SIGNATURE";
259 return to_string(err);
260}
261
264 APIError err = init_common_();
265 if (err != APIError::OK) {
266 return err;
267 }
268
269 // init prologue
270 prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT));
271
273 return APIError::OK;
274}
277 APIError err = state_action_();
278 if (err != APIError::OK && err != APIError::WOULD_BLOCK) {
279 return err;
280 }
281 if (!this->tx_buf_.empty()) {
282 err = try_send_tx_buf_();
283 if (err != APIError::OK && err != APIError::WOULD_BLOCK) {
284 return err;
285 }
286 }
287 return APIError::OK; // Convert WOULD_BLOCK to OK to avoid connection termination
288}
289
305 if (frame == nullptr) {
306 HELPER_LOG("Bad argument for try_read_frame_");
307 return APIError::BAD_ARG;
308 }
309
310 // read header
311 if (rx_header_buf_len_ < 3) {
312 // no header information yet
313 uint8_t to_read = 3 - rx_header_buf_len_;
314 ssize_t received = this->socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read);
315 if (received == -1) {
316 if (errno == EWOULDBLOCK || errno == EAGAIN) {
318 }
320 HELPER_LOG("Socket read failed with errno %d", errno);
322 } else if (received == 0) {
324 HELPER_LOG("Connection closed");
326 }
327 rx_header_buf_len_ += static_cast<uint8_t>(received);
328 if (static_cast<uint8_t>(received) != to_read) {
329 // not a full read
331 }
332
333 // header reading done
334 }
335
336 // read body
337 uint8_t indicator = rx_header_buf_[0];
338 if (indicator != 0x01) {
340 HELPER_LOG("Bad indicator byte %u", indicator);
342 }
343
344 uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2];
345
346 if (state_ != State::DATA && msg_size > 128) {
347 // for handshake message only permit up to 128 bytes
349 HELPER_LOG("Bad packet len for handshake: %d", msg_size);
351 }
352
353 // reserve space for body
354 if (rx_buf_.size() != msg_size) {
355 rx_buf_.resize(msg_size);
356 }
357
358 if (rx_buf_len_ < msg_size) {
359 // more data to read
360 uint16_t to_read = msg_size - rx_buf_len_;
361 ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read);
362 if (received == -1) {
363 if (errno == EWOULDBLOCK || errno == EAGAIN) {
365 }
367 HELPER_LOG("Socket read failed with errno %d", errno);
369 } else if (received == 0) {
371 HELPER_LOG("Connection closed");
373 }
374 rx_buf_len_ += static_cast<uint16_t>(received);
375 if (static_cast<uint16_t>(received) != to_read) {
376 // not all read
378 }
379 }
380
381 // uncomment for even more debugging
382#ifdef HELPER_LOG_PACKETS
383 ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(rx_buf_).c_str());
384#endif
385 frame->msg = std::move(rx_buf_);
386 // consume msg
387 rx_buf_ = {};
388 rx_buf_len_ = 0;
390 return APIError::OK;
391}
392
403 int err;
404 APIError aerr;
405 if (state_ == State::INITIALIZE) {
406 HELPER_LOG("Bad state for method: %d", (int) state_);
407 return APIError::BAD_STATE;
408 }
410 // waiting for client hello
411 ParsedFrame frame;
412 aerr = try_read_frame_(&frame);
413 if (aerr == APIError::BAD_INDICATOR) {
414 send_explicit_handshake_reject_("Bad indicator byte");
415 return aerr;
416 }
418 send_explicit_handshake_reject_("Bad handshake packet len");
419 return aerr;
420 }
421 if (aerr != APIError::OK)
422 return aerr;
423 // ignore contents, may be used in future for flags
424 // Reserve space for: existing prologue + 2 size bytes + frame data
425 prologue_.reserve(prologue_.size() + 2 + frame.msg.size());
426 prologue_.push_back((uint8_t) (frame.msg.size() >> 8));
427 prologue_.push_back((uint8_t) frame.msg.size());
428 prologue_.insert(prologue_.end(), frame.msg.begin(), frame.msg.end());
429
431 }
433 // send server hello
434 const std::string &name = App.get_name();
435 const std::string &mac = get_mac_address();
436
437 std::vector<uint8_t> msg;
438 // Reserve space for: 1 byte proto + name + null + mac + null
439 msg.reserve(1 + name.size() + 1 + mac.size() + 1);
440
441 // chosen proto
442 msg.push_back(0x01);
443
444 // node name, terminated by null byte
445 const uint8_t *name_ptr = reinterpret_cast<const uint8_t *>(name.c_str());
446 msg.insert(msg.end(), name_ptr, name_ptr + name.size() + 1);
447 // node mac, terminated by null byte
448 const uint8_t *mac_ptr = reinterpret_cast<const uint8_t *>(mac.c_str());
449 msg.insert(msg.end(), mac_ptr, mac_ptr + mac.size() + 1);
450
451 aerr = write_frame_(msg.data(), msg.size());
452 if (aerr != APIError::OK)
453 return aerr;
454
455 // start handshake
456 aerr = init_handshake_();
457 if (aerr != APIError::OK)
458 return aerr;
459
461 }
462 if (state_ == State::HANDSHAKE) {
463 int action = noise_handshakestate_get_action(handshake_);
464 if (action == NOISE_ACTION_READ_MESSAGE) {
465 // waiting for handshake msg
466 ParsedFrame frame;
467 aerr = try_read_frame_(&frame);
468 if (aerr == APIError::BAD_INDICATOR) {
469 send_explicit_handshake_reject_("Bad indicator byte");
470 return aerr;
471 }
473 send_explicit_handshake_reject_("Bad handshake packet len");
474 return aerr;
475 }
476 if (aerr != APIError::OK)
477 return aerr;
478
479 if (frame.msg.empty()) {
480 send_explicit_handshake_reject_("Empty handshake message");
482 } else if (frame.msg[0] != 0x00) {
483 HELPER_LOG("Bad handshake error byte: %u", frame.msg[0]);
484 send_explicit_handshake_reject_("Bad handshake error byte");
486 }
487
488 NoiseBuffer mbuf;
489 noise_buffer_init(mbuf);
490 noise_buffer_set_input(mbuf, frame.msg.data() + 1, frame.msg.size() - 1);
491 err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
492 if (err != 0) {
494 HELPER_LOG("noise_handshakestate_read_message failed: %s", noise_err_to_str(err).c_str());
495 if (err == NOISE_ERROR_MAC_FAILURE) {
496 send_explicit_handshake_reject_("Handshake MAC failure");
497 } else {
498 send_explicit_handshake_reject_("Handshake error");
499 }
501 }
502
504 if (aerr != APIError::OK)
505 return aerr;
506 } else if (action == NOISE_ACTION_WRITE_MESSAGE) {
507 uint8_t buffer[65];
508 NoiseBuffer mbuf;
509 noise_buffer_init(mbuf);
510 noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1);
511
512 err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
513 if (err != 0) {
515 HELPER_LOG("noise_handshakestate_write_message failed: %s", noise_err_to_str(err).c_str());
517 }
518 buffer[0] = 0x00; // success
519
520 aerr = write_frame_(buffer, mbuf.size + 1);
521 if (aerr != APIError::OK)
522 return aerr;
524 if (aerr != APIError::OK)
525 return aerr;
526 } else {
527 // bad state for action
529 HELPER_LOG("Bad action for handshake: %d", action);
531 }
532 }
534 return APIError::BAD_STATE;
535 }
536 return APIError::OK;
537}
539 std::vector<uint8_t> data;
540 data.resize(reason.length() + 1);
541 data[0] = 0x01; // failure
542
543 // Copy error message in bulk
544 if (!reason.empty()) {
545 std::memcpy(data.data() + 1, reason.c_str(), reason.length());
546 }
547
548 // temporarily remove failed state
549 auto orig_state = state_;
551 write_frame_(data.data(), data.size());
552 state_ = orig_state;
553}
555 int err;
556 APIError aerr;
557 aerr = state_action_();
558 if (aerr != APIError::OK) {
559 return aerr;
560 }
561
562 if (state_ != State::DATA) {
564 }
565
566 ParsedFrame frame;
567 aerr = try_read_frame_(&frame);
568 if (aerr != APIError::OK)
569 return aerr;
570
571 NoiseBuffer mbuf;
572 noise_buffer_init(mbuf);
573 noise_buffer_set_inout(mbuf, frame.msg.data(), frame.msg.size(), frame.msg.size());
574 err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
575 if (err != 0) {
577 HELPER_LOG("noise_cipherstate_decrypt failed: %s", noise_err_to_str(err).c_str());
579 }
580
581 uint16_t msg_size = mbuf.size;
582 uint8_t *msg_data = frame.msg.data();
583 if (msg_size < 4) {
585 HELPER_LOG("Bad data packet: size %d too short", msg_size);
587 }
588
589 // uint16_t type;
590 // uint16_t data_len;
591 // uint8_t *data;
592 // uint8_t *padding; zero or more bytes to fill up the rest of the packet
593 uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1];
594 uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3];
595 if (data_len > msg_size - 4) {
597 HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size);
599 }
600
601 buffer->container = std::move(frame.msg);
602 buffer->data_offset = 4;
603 buffer->data_len = data_len;
604 buffer->type = type;
605 return APIError::OK;
606}
608 std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
609 uint16_t payload_len = static_cast<uint16_t>(raw_buffer->size() - frame_header_padding_);
610
611 // Resize to include MAC space (required for Noise encryption)
612 raw_buffer->resize(raw_buffer->size() + frame_footer_size_);
613
614 // Use write_protobuf_packets with a single packet
615 std::vector<PacketInfo> packets;
616 packets.emplace_back(type, 0, payload_len);
617
618 return write_protobuf_packets(buffer, packets);
619}
620
621APIError APINoiseFrameHelper::write_protobuf_packets(ProtoWriteBuffer buffer, const std::vector<PacketInfo> &packets) {
622 APIError aerr = state_action_();
623 if (aerr != APIError::OK) {
624 return aerr;
625 }
626
627 if (state_ != State::DATA) {
629 }
630
631 if (packets.empty()) {
632 return APIError::OK;
633 }
634
635 std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
636 this->reusable_iovs_.clear();
637 this->reusable_iovs_.reserve(packets.size());
638
639 // We need to encrypt each packet in place
640 for (const auto &packet : packets) {
641 uint16_t type = packet.message_type;
642 uint16_t offset = packet.offset;
643 uint16_t payload_len = packet.payload_size;
644 uint16_t msg_len = 4 + payload_len; // type(2) + data_len(2) + payload
645
646 // The buffer already has padding at offset
647 uint8_t *buf_start = raw_buffer->data() + offset;
648
649 // Write noise header
650 buf_start[0] = 0x01; // indicator
651 // buf_start[1], buf_start[2] to be set after encryption
652
653 // Write message header (to be encrypted)
654 const uint8_t msg_offset = 3;
655 buf_start[msg_offset + 0] = (uint8_t) (type >> 8); // type high byte
656 buf_start[msg_offset + 1] = (uint8_t) type; // type low byte
657 buf_start[msg_offset + 2] = (uint8_t) (payload_len >> 8); // data_len high byte
658 buf_start[msg_offset + 3] = (uint8_t) payload_len; // data_len low byte
659 // payload data is already in the buffer starting at offset + 7
660
661 // Make sure we have space for MAC
662 // The buffer should already have been sized appropriately
663
664 // Encrypt the message in place
665 NoiseBuffer mbuf;
666 noise_buffer_init(mbuf);
667 noise_buffer_set_inout(mbuf, buf_start + msg_offset, msg_len, msg_len + frame_footer_size_);
668
669 int err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
670 if (err != 0) {
672 HELPER_LOG("noise_cipherstate_encrypt failed: %s", noise_err_to_str(err).c_str());
674 }
675
676 // Fill in the encrypted size
677 buf_start[1] = (uint8_t) (mbuf.size >> 8);
678 buf_start[2] = (uint8_t) mbuf.size;
679
680 // Add iovec for this encrypted packet
681 struct iovec iov;
682 iov.iov_base = buf_start;
683 iov.iov_len = 3 + mbuf.size; // indicator + size + encrypted data
684 this->reusable_iovs_.push_back(iov);
685 }
686
687 // Send all encrypted packets in one writev call
688 return this->write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size());
689}
690
691APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, uint16_t len) {
692 uint8_t header[3];
693 header[0] = 0x01; // indicator
694 header[1] = (uint8_t) (len >> 8);
695 header[2] = (uint8_t) len;
696
697 struct iovec iov[2];
698 iov[0].iov_base = header;
699 iov[0].iov_len = 3;
700 if (len == 0) {
701 return this->write_raw_(iov, 1);
702 }
703 iov[1].iov_base = const_cast<uint8_t *>(data);
704 iov[1].iov_len = len;
705
706 return this->write_raw_(iov, 2);
707}
708
714 int err;
715 memset(&nid_, 0, sizeof(nid_));
716 // const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
717 // err = noise_protocol_name_to_id(&nid_, proto, strlen(proto));
718 nid_.pattern_id = NOISE_PATTERN_NN;
719 nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY;
720 nid_.dh_id = NOISE_DH_CURVE25519;
721 nid_.prefix_id = NOISE_PREFIX_STANDARD;
722 nid_.hybrid_id = NOISE_DH_NONE;
723 nid_.hash_id = NOISE_HASH_SHA256;
724 nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0;
725
726 err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER);
727 if (err != 0) {
729 HELPER_LOG("noise_handshakestate_new_by_id failed: %s", noise_err_to_str(err).c_str());
731 }
732
733 const auto &psk = ctx_->get_psk();
734 err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size());
735 if (err != 0) {
737 HELPER_LOG("noise_handshakestate_set_pre_shared_key failed: %s", noise_err_to_str(err).c_str());
739 }
740
741 err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size());
742 if (err != 0) {
744 HELPER_LOG("noise_handshakestate_set_prologue failed: %s", noise_err_to_str(err).c_str());
746 }
747 // set_prologue copies it into handshakestate, so we can get rid of it now
748 prologue_ = {};
749
750 err = noise_handshakestate_start(handshake_);
751 if (err != 0) {
753 HELPER_LOG("noise_handshakestate_start failed: %s", noise_err_to_str(err).c_str());
755 }
756 return APIError::OK;
757}
758
760 assert(state_ == State::HANDSHAKE);
761
762 int action = noise_handshakestate_get_action(handshake_);
763 if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE)
764 return APIError::OK;
765 if (action != NOISE_ACTION_SPLIT) {
767 HELPER_LOG("Bad action for handshake: %d", action);
769 }
770 int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_);
771 if (err != 0) {
773 HELPER_LOG("noise_handshakestate_split failed: %s", noise_err_to_str(err).c_str());
775 }
776
777 frame_footer_size_ = noise_cipherstate_get_mac_length(send_cipher_);
778
779 HELPER_LOG("Handshake complete!");
780 noise_handshakestate_free(handshake_);
781 handshake_ = nullptr;
783 return APIError::OK;
784}
785
787 if (handshake_ != nullptr) {
788 noise_handshakestate_free(handshake_);
789 handshake_ = nullptr;
790 }
791 if (send_cipher_ != nullptr) {
792 noise_cipherstate_free(send_cipher_);
793 send_cipher_ = nullptr;
794 }
795 if (recv_cipher_ != nullptr) {
796 noise_cipherstate_free(recv_cipher_);
797 recv_cipher_ = nullptr;
798 }
799}
800
801extern "C" {
802// declare how noise generates random bytes (here with a good HWRNG based on the RF system)
803void noise_rand_bytes(void *output, size_t len) {
804 if (!esphome::random_bytes(reinterpret_cast<uint8_t *>(output), len)) {
805 ESP_LOGE(TAG, "Acquiring random bytes failed; rebooting");
806 arch_restart();
807 }
808}
809}
810
811#endif // USE_API_NOISE
812
813#ifdef USE_API_PLAINTEXT
814
817 APIError err = init_common_();
818 if (err != APIError::OK) {
819 return err;
820 }
821
823 return APIError::OK;
824}
827 if (state_ != State::DATA) {
828 return APIError::BAD_STATE;
829 }
830 if (!this->tx_buf_.empty()) {
832 if (err != APIError::OK && err != APIError::WOULD_BLOCK) {
833 return err;
834 }
835 }
836 return APIError::OK; // Convert WOULD_BLOCK to OK to avoid connection termination
837}
838
849 if (frame == nullptr) {
850 HELPER_LOG("Bad argument for try_read_frame_");
851 return APIError::BAD_ARG;
852 }
853
854 // read header
855 while (!rx_header_parsed_) {
856 // Now that we know when the socket is ready, we can read up to 3 bytes
857 // into the rx_header_buf_ before we have to switch back to reading
858 // one byte at a time to ensure we don't read past the message and
859 // into the next one.
860
861 // Read directly into rx_header_buf_ at the current position
862 // Try to get to at least 3 bytes total (indicator + 2 varint bytes), then read one byte at a time
863 ssize_t received =
864 this->socket_->read(&rx_header_buf_[rx_header_buf_pos_], rx_header_buf_pos_ < 3 ? 3 - rx_header_buf_pos_ : 1);
865 if (received == -1) {
866 if (errno == EWOULDBLOCK || errno == EAGAIN) {
868 }
870 HELPER_LOG("Socket read failed with errno %d", errno);
872 } else if (received == 0) {
874 HELPER_LOG("Connection closed");
876 }
877
878 // If this was the first read, validate the indicator byte
879 if (rx_header_buf_pos_ == 0 && received > 0) {
880 if (rx_header_buf_[0] != 0x00) {
882 HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
884 }
885 }
886
887 rx_header_buf_pos_ += received;
888
889 // Check for buffer overflow
890 if (rx_header_buf_pos_ >= sizeof(rx_header_buf_)) {
892 HELPER_LOG("Header buffer overflow");
894 }
895
896 // Need at least 3 bytes total (indicator + 2 varint bytes) before trying to parse
897 if (rx_header_buf_pos_ < 3) {
898 continue;
899 }
900
901 // At this point, we have at least 3 bytes total:
902 // - Validated indicator byte (0x00) stored at position 0
903 // - At least 2 bytes in the buffer for the varints
904 // Buffer layout:
905 // [0]: indicator byte (0x00)
906 // [1-3]: Message size varint (variable length)
907 // - 2 bytes would only allow up to 16383, which is less than noise's UINT16_MAX (65535)
908 // - 3 bytes allows up to 2097151, ensuring we support at least as much as noise
909 // [2-5]: Message type varint (variable length)
910 // We now attempt to parse both varints. If either is incomplete,
911 // we'll continue reading more bytes.
912
913 // Skip indicator byte at position 0
914 uint8_t varint_pos = 1;
915 uint32_t consumed = 0;
916
917 auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed);
918 if (!msg_size_varint.has_value()) {
919 // not enough data there yet
920 continue;
921 }
922
923 if (msg_size_varint->as_uint32() > std::numeric_limits<uint16_t>::max()) {
925 HELPER_LOG("Bad packet: message size %" PRIu32 " exceeds maximum %u", msg_size_varint->as_uint32(),
926 std::numeric_limits<uint16_t>::max());
928 }
929 rx_header_parsed_len_ = msg_size_varint->as_uint16();
930
931 // Move to next varint position
932 varint_pos += consumed;
933
934 auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[varint_pos], rx_header_buf_pos_ - varint_pos, &consumed);
935 if (!msg_type_varint.has_value()) {
936 // not enough data there yet
937 continue;
938 }
939 if (msg_type_varint->as_uint32() > std::numeric_limits<uint16_t>::max()) {
941 HELPER_LOG("Bad packet: message type %" PRIu32 " exceeds maximum %u", msg_type_varint->as_uint32(),
942 std::numeric_limits<uint16_t>::max());
944 }
945 rx_header_parsed_type_ = msg_type_varint->as_uint16();
946 rx_header_parsed_ = true;
947 }
948 // header reading done
949
950 // reserve space for body
951 if (rx_buf_.size() != rx_header_parsed_len_) {
953 }
954
956 // more data to read
957 uint16_t to_read = rx_header_parsed_len_ - rx_buf_len_;
958 ssize_t received = this->socket_->read(&rx_buf_[rx_buf_len_], to_read);
959 if (received == -1) {
960 if (errno == EWOULDBLOCK || errno == EAGAIN) {
962 }
964 HELPER_LOG("Socket read failed with errno %d", errno);
966 } else if (received == 0) {
968 HELPER_LOG("Connection closed");
970 }
971 rx_buf_len_ += static_cast<uint16_t>(received);
972 if (static_cast<uint16_t>(received) != to_read) {
973 // not all read
975 }
976 }
977
978 // uncomment for even more debugging
979#ifdef HELPER_LOG_PACKETS
980 ESP_LOGVV(TAG, "Received frame: %s", format_hex_pretty(rx_buf_).c_str());
981#endif
982 frame->msg = std::move(rx_buf_);
983 // consume msg
984 rx_buf_ = {};
985 rx_buf_len_ = 0;
987 rx_header_parsed_ = false;
988 return APIError::OK;
989}
991 APIError aerr;
992
993 if (state_ != State::DATA) {
995 }
996
997 ParsedFrame frame;
998 aerr = try_read_frame_(&frame);
999 if (aerr != APIError::OK) {
1000 if (aerr == APIError::BAD_INDICATOR) {
1001 // Make sure to tell the remote that we don't
1002 // understand the indicator byte so it knows
1003 // we do not support it.
1004 struct iovec iov[1];
1005 // The \x00 first byte is the marker for plaintext.
1006 //
1007 // The remote will know how to handle the indicator byte,
1008 // but it likely won't understand the rest of the message.
1009 //
1010 // We must send at least 3 bytes to be read, so we add
1011 // a message after the indicator byte to ensures its long
1012 // enough and can aid in debugging.
1013 const char msg[] = "\x00"
1014 "Bad indicator byte";
1015 iov[0].iov_base = (void *) msg;
1016 iov[0].iov_len = 19;
1017 this->write_raw_(iov, 1);
1018 }
1019 return aerr;
1020 }
1021
1022 buffer->container = std::move(frame.msg);
1023 buffer->data_offset = 0;
1025 buffer->type = rx_header_parsed_type_;
1026 return APIError::OK;
1027}
1029 std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
1030 uint16_t payload_len = static_cast<uint16_t>(raw_buffer->size() - frame_header_padding_);
1031
1032 // Use write_protobuf_packets with a single packet
1033 std::vector<PacketInfo> packets;
1034 packets.emplace_back(type, 0, payload_len);
1035
1036 return write_protobuf_packets(buffer, packets);
1037}
1038
1040 const std::vector<PacketInfo> &packets) {
1041 if (state_ != State::DATA) {
1042 return APIError::BAD_STATE;
1043 }
1044
1045 if (packets.empty()) {
1046 return APIError::OK;
1047 }
1048
1049 std::vector<uint8_t> *raw_buffer = buffer.get_buffer();
1050 this->reusable_iovs_.clear();
1051 this->reusable_iovs_.reserve(packets.size());
1052
1053 for (const auto &packet : packets) {
1054 uint16_t type = packet.message_type;
1055 uint16_t offset = packet.offset;
1056 uint16_t payload_len = packet.payload_size;
1057
1058 // Calculate varint sizes for header layout
1059 uint8_t size_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(payload_len));
1060 uint8_t type_varint_len = api::ProtoSize::varint(static_cast<uint32_t>(type));
1061 uint8_t total_header_len = 1 + size_varint_len + type_varint_len;
1062
1063 // Calculate where to start writing the header
1064 // The header starts at the latest possible position to minimize unused padding
1065 //
1066 // Example 1 (small values): total_header_len = 3, header_offset = 6 - 3 = 3
1067 // [0-2] - Unused padding
1068 // [3] - 0x00 indicator byte
1069 // [4] - Payload size varint (1 byte, for sizes 0-127)
1070 // [5] - Message type varint (1 byte, for types 0-127)
1071 // [6...] - Actual payload data
1072 //
1073 // Example 2 (medium values): total_header_len = 4, header_offset = 6 - 4 = 2
1074 // [0-1] - Unused padding
1075 // [2] - 0x00 indicator byte
1076 // [3-4] - Payload size varint (2 bytes, for sizes 128-16383)
1077 // [5] - Message type varint (1 byte, for types 0-127)
1078 // [6...] - Actual payload data
1079 //
1080 // Example 3 (large values): total_header_len = 6, header_offset = 6 - 6 = 0
1081 // [0] - 0x00 indicator byte
1082 // [1-3] - Payload size varint (3 bytes, for sizes 16384-2097151)
1083 // [4-5] - Message type varint (2 bytes, for types 128-32767)
1084 // [6...] - Actual payload data
1085 //
1086 // The message starts at offset + frame_header_padding_
1087 // So we write the header starting at offset + frame_header_padding_ - total_header_len
1088 uint8_t *buf_start = raw_buffer->data() + offset;
1089 uint32_t header_offset = frame_header_padding_ - total_header_len;
1090
1091 // Write the plaintext header
1092 buf_start[header_offset] = 0x00; // indicator
1093
1094 // Encode size varint directly into buffer
1095 ProtoVarInt(payload_len).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len);
1096
1097 // Encode type varint directly into buffer
1098 ProtoVarInt(type).encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len);
1099
1100 // Add iovec for this packet (header + payload)
1101 struct iovec iov;
1102 iov.iov_base = buf_start + header_offset;
1103 iov.iov_len = total_header_len + payload_len;
1104 this->reusable_iovs_.push_back(iov);
1105 }
1106
1107 // Send all packets in one writev call
1108 return write_raw_(this->reusable_iovs_.data(), this->reusable_iovs_.size());
1109}
1110
1111#endif // USE_API_PLAINTEXT
1112
1113} // namespace api
1114} // namespace esphome
1115#endif
const std::string & get_name() const
Get the name of this Application set by pre_setup().
std::vector< uint8_t > rx_buf_
std::vector< struct iovec > reusable_iovs_
void buffer_data_from_iov_(const struct iovec *iov, int iovcnt, uint16_t total_write_len)
APIError write_raw_(const struct iovec *iov, int iovcnt)
std::deque< SendBuffer > tx_buf_
APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override
APIError read_packet(ReadPacketBuffer *buffer) override
APIError try_read_frame_(ParsedFrame *frame)
Read a packet into the rx_buf_.
APIError write_protobuf_packets(ProtoWriteBuffer buffer, const std::vector< PacketInfo > &packets) override
APIError state_action_()
To be called from read/write methods.
APIError loop() override
Run through handshake messages (if in that phase)
APIError write_frame_(const uint8_t *data, uint16_t len)
std::shared_ptr< APINoiseContext > ctx_
APIError init() override
Initialize the frame helper, returns OK if successful.
void send_explicit_handshake_reject_(const std::string &reason)
APIError init_handshake_()
Initiate the data structures for the handshake.
APIError write_protobuf_packets(ProtoWriteBuffer buffer, const std::vector< PacketInfo > &packets) override
APIError init() override
Initialize the frame helper, returns OK if successful.
APIError loop() override
Not used for plaintext.
APIError read_packet(ReadPacketBuffer *buffer) override
APIError write_protobuf_packet(uint16_t type, ProtoWriteBuffer buffer) override
APIError try_read_frame_(ParsedFrame *frame)
Read a packet into the rx_buf_.
static uint32_t varint(uint32_t value)
ProtoSize class for Protocol Buffer serialization size calculation.
Representation of a VarInt - in ProtoBuf should be 64bit but we only use 32bit.
Definition proto.h:17
void encode_to_buffer_unchecked(uint8_t *buffer, size_t len)
Encode the varint value to a pre-allocated buffer without bounds checking.
Definition proto.h:98
static optional< ProtoVarInt > parse(const uint8_t *buffer, uint32_t len, uint32_t *consumed)
Definition proto.h:22
std::vector< uint8_t > * get_buffer() const
Definition proto.h:321
virtual ssize_t write(const void *buf, size_t len)=0
virtual int setblocking(bool blocking)=0
virtual ssize_t writev(const struct iovec *iov, int iovcnt)=0
virtual ssize_t read(void *buf, size_t len)=0
virtual int setsockopt(int level, int optname, const void *optval, socklen_t optlen)=0
uint8_t type
__int64 ssize_t
Definition httplib.h:175
void noise_rand_bytes(void *output, size_t len)
std::string noise_err_to_str(int err)
Convert a noise error code to a readable error.
const char * api_error_to_str(APIError err)
Provides packet encoding functions for exchanging data with a remote host.
Definition a01nyub.cpp:7
bool random_bytes(uint8_t *data, size_t len)
Generate len number of random bytes.
Definition helpers.cpp:220
std::string size_t len
Definition helpers.h:302
std::string to_string(int value)
Definition helpers.cpp:82
std::string get_mac_address()
Get the device MAC address as a string, in lowercase hex notation.
Definition helpers.cpp:726
void arch_restart()
Definition core.cpp:32
Application App
Global storage of Application pointer - only one Application can exist.
std::string format_hex_pretty(const uint8_t *data, size_t length)
Format the byte array data of length len in pretty-printed, human-readable hex.
Definition helpers.cpp:372
std::vector< uint8_t > container
void * iov_base
Definition headers.h:101
size_t iov_len
Definition headers.h:102