ESPHome 2025.5.0
Loading...
Searching...
No Matches
streaming_model.h
Go to the documentation of this file.
1#pragma once
2
3#ifdef USE_ESP_IDF
4
6
8
9#include <tensorflow/lite/core/c/common.h>
10#include <tensorflow/lite/micro/micro_interpreter.h>
11#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
12
13namespace esphome {
14namespace micro_wake_word {
15
// Number of new feature slices that must be processed before detections are reported;
// used as the (negated) initial value of StreamingModel::ignore_windows_.
16static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100;
// Arena size for the TFLite Micro resource-variable allocator (var_arena_).
// NOTE(review): presumably bytes — confirm against the load_model_() implementation.
17static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024;
18
// NOTE(review): Doxygen-extraction fragment — the opening `struct DetectionEvent {`
// line and at least one member (a `detected` flag and the max/average probability
// fields implied by the tooltips further down) were dropped by the scrape; verify
// against the upstream header before relying on this fragment.
// Describes the outcome of one streaming-inference window.
20 std::string *wake_word;  // Raw pointer to the model's wake word — presumably non-owning; confirm lifetime.
22 bool partially_detection; // Set if the most recent probability exceeds the threshold, but the sliding window
23 // average hasn't yet
26 bool blocked_by_vad = false;  // NOTE(review): presumably set when a VADModel reports no voice activity — confirm.
27};
28
// NOTE(review): Doxygen-extraction fragment — the `class StreamingModel {` header line
// and several declarations were dropped by the scrape (the tooltips below show a pure
// virtual `determine_detected()`, a `reset_probabilities()`, and sliding-window /
// arena-size members at the gaps in the numbering). Verify against the upstream header.
// Abstract base for a TFLite Micro streaming wake-word/VAD model.
 30 public:
 // Logs this model's configuration (implemented by each concrete model).
 31 virtual void log_model_config() = 0;
 33
 34 // Performs inference on the given features.
 35 // - If the model is enabled but not loaded, it will load it
 36 // - If the model is disabled but loaded, it will unload it
 37 // Returns true if successful or false if there is an error
 38 bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]);
 39
 42
 // Destroys the TFLite interpreter and frees the tensor and variable arenas.
 44 void unload_model();
 45
 // Enable the model. The next perform_streaming_inference call will load it.
 47 virtual void enable() { this->enabled_ = true; }
 48
 // Disable the model. The next perform_streaming_inference call will unload it.
 50 virtual void disable() { this->enabled_ = false; }
 51
 // Return true if the model is enabled.
 53 bool is_enabled() const { return this->enabled_; }
 54
 56
 57 // Quantized probability cutoffs mapping 0.0 - 1.0 to 0 - 255
 59 uint8_t get_probability_cutoff() const { return this->probability_cutoff_; }
 60 void set_probability_cutoff(uint8_t probability_cutoff) { this->probability_cutoff_ = probability_cutoff; }
 61
 62 protected:
 // Allocates tensor and variable arenas and sets up the model interpreter.
 65 bool load_model_();
 // Registers the TensorFlow operations the streaming model needs; true on success.
 67 bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver);
 68
 69 tflite::MicroMutableOpResolver<20> streaming_op_resolver_;
 70
 71 bool loaded_{false};   // True while the interpreter/arenas are allocated.
 72 bool enabled_{true};   // Toggled by enable()/disable(); drives lazy load/unload.
 // Counts up from -MIN_SLICES_BEFORE_DETECTION; detections are presumably ignored
 // until it reaches 0 — confirm in perform_streaming_inference's implementation.
 75 int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION};
 76
 80
 // Ring-buffer write index into recent_streaming_probabilities_.
 81 size_t last_n_index_{0};
 // Sliding window of recent quantized (0-255) output probabilities.
 83 std::vector<uint8_t> recent_streaming_probabilities_;
 84
 85 const uint8_t *model_start_;          // Start of the flatbuffer model data (non-owning).
 86 uint8_t *tensor_arena_{nullptr};      // Heap arena for tensors; freed by unload_model().
 87 uint8_t *var_arena_{nullptr};         // Heap arena for resource variables; freed by unload_model().
 88 std::unique_ptr<tflite::MicroInterpreter> interpreter_;
 89 tflite::MicroResourceVariables *mrv_{nullptr};  // Streaming-state resource variables.
 90 tflite::MicroAllocator *ma_{nullptr};           // Allocator backing var_arena_.
 91};
92
// A streaming model that detects a specific wake word.
// NOTE(review): Doxygen-extraction fragment — the declarations at the numbering gaps
// (the `determine_detected()` override per the tooltips, and protected members such as
// `internal_only_`, which is used below but never declared in the visible text) were
// dropped by the scrape; verify against the upstream header.
93class WakeWordModel final : public StreamingModel {
 94 public:
 // Constructs a wake word model object (per the tooltip documentation further down).
105 WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff,
106 size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
107 bool default_enabled, bool internal_only);
108
109 void log_model_config() override;
110
115
116 const std::string &get_id() const { return this->id_; }
117 const std::string &get_wake_word() const { return this->wake_word_; }
118
119 void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); }
120 const std::vector<std::string> &get_trained_languages() const { return this->trained_languages_; }
121
 // Enable the model and save to flash. The next perform_streaming_inference call will load it.
123 void enable() override;
124
 // Disable the model and save to flash. The next perform_streaming_inference call will unload it.
126 void disable() override;
127
128 bool get_internal_only() { return this->internal_only_; }
129
130 protected:
131 std::string id_;          // Identifier used when persisting the enabled state.
132 std::string wake_word_;   // Human-readable wake word this model detects.
133 std::vector<std::string> trained_languages_;
134
136
138};
139
// A streaming model that detects voice activity (used to gate wake-word detections).
// NOTE(review): Doxygen-extraction fragment — the `determine_detected()` override
// (documented in the tooltips below as comparing the max probability in the sliding
// window against the cutoff) was dropped by the scrape; verify against upstream.
140class VADModel final : public StreamingModel {
141 public:
142 VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size,
143 size_t tensor_arena_size);
144
145 void log_model_config() override;
146
151};
152
153} // namespace micro_wake_word
154} // namespace esphome
155
156#endif
virtual void disable()
Disable the model. The next perform_streaming_inference call will unload it.
virtual DetectionEvent determine_detected()=0
bool load_model_()
Allocates tensor and variable arenas and sets up the model interpreter.
virtual void enable()
Enable the model. The next perform_streaming_inference call will load it.
std::unique_ptr< tflite::MicroInterpreter > interpreter_
tflite::MicroMutableOpResolver< 20 > streaming_op_resolver_
bool register_streaming_ops_(tflite::MicroMutableOpResolver< 20 > &op_resolver)
Returns true if successfully registered the streaming model's TensorFlow operations.
void reset_probabilities()
Sets all recent_streaming_probabilities to 0 and resets the ignore window count.
std::vector< uint8_t > recent_streaming_probabilities_
tflite::MicroResourceVariables * mrv_
bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE])
void unload_model()
Destroys the TFLite interpreter and frees the tensor and variable arenas' memory.
void set_probability_cutoff(uint8_t probability_cutoff)
bool is_enabled() const
Return true if the model is enabled.
DetectionEvent determine_detected() override
Checks for voice activity by comparing the max probability in the sliding window with the probability...
VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size)
void enable() override
Enable the model and save to flash. The next perform_streaming_inference call will load it.
const std::string & get_wake_word() const
DetectionEvent determine_detected() override
Checks for the wake word by comparing the mean probability in the sliding window with the probability...
const std::vector< std::string > & get_trained_languages() const
WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, bool default_enabled, bool internal_only)
Constructs a wake word model object.
void add_trained_language(const std::string &language)
void disable() override
Disable the model and save to flash. The next perform_streaming_inference call will unload it.
std::vector< std::string > trained_languages_
Providing packet encoding functions for exchanging data with a remote host.
Definition a01nyub.cpp:7