sherpa_server/main.cpp
2026-06-30 18:09:13 +02:00

579 lines
22 KiB
C++

#include <sherpa-onnx/c-api/c-api.h>
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <unordered_map>
#include <vector>
#include <fcntl.h>
#include <io.h>
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include "tray_icon_win32.hpp"
// ---------------------------------------------------------------------------
// CLI parsing — tiny, just enough for the flags we need
// ---------------------------------------------------------------------------
// Look up a command-line flag by name. Two spellings are accepted:
//   --name=value   -> returns a pointer just past the '='
//   --name value   -> returns the following argv entry
// argv[0] (the program path) is never inspected. The first argument that
// yields a value wins. Returns nullptr when the flag is absent, or when a
// bare "--name" is the final argument and so has nothing after it.
static const char* FindFlag(int argc, char** argv, const char* name) {
  const size_t prefixLen = strlen(name);
  int idx = 1;
  while (idx < argc) {
    const char* arg = argv[idx];
    const bool hasNextArg = idx + 1 < argc;
    if (strncmp(arg, name, prefixLen) == 0) {
      const char tail = arg[prefixLen];
      if (tail == '=') return &arg[prefixLen + 1];
      if (tail == '\0' && hasNextArg) return argv[idx + 1];
    }
    ++idx;
  }
  return nullptr;
}
// Same lookup as FindFlag, but never returns nullptr: an absent flag
// (or a trailing bare "--name") yields the empty string instead.
static const char* FlagOrEmpty(int argc, char** argv, const char* name) {
  if (const char* found = FindFlag(argc, argv, name)) return found;
  return "";
}
// Integer-valued flag. Returns `def` only when the flag is absent; a
// present value is parsed with atoi, so text with no leading digits
// yields 0 rather than `def`.
static int FlagInt(int argc, char** argv, const char* name, int def) {
  const char* raw = FindFlag(argc, argv, name);
  if (raw == nullptr) return def;
  return atoi(raw);
}
// Float-valued flag. Returns `def` only when the flag is absent; a
// present value is parsed with atof (narrowed to float), so text with no
// leading number yields 0.0f rather than `def`.
static float FlagFloat(int argc, char** argv, const char* name, float def) {
  const char* raw = FindFlag(argc, argv, name);
  if (raw == nullptr) return def;
  return static_cast<float>(atof(raw));
}
// ---------------------------------------------------------------------------
// voice_presets.ini [Model.<key>] loader
// ---------------------------------------------------------------------------
// One [Model.<key>] section loaded from a voice_presets INI file.
struct ModelDef {
  std::string key;   // section-name suffix (text after "Model.")
  std::string type;  // "kokoro" / "kitten" / "vits" / "matcha" / ...
  // Every other key=value pair in the section. The map owns the strings,
  // so the raw pointers GetCStr hands out (and that SherpaOnnx* configs
  // keep) stay valid as long as this ModelDef is alive and unmodified.
  std::unordered_map<std::string, std::string> fields;

  // Value of field `name`, or "" when it is missing or blank.
  const char* GetCStr(const char* name) const {
    const auto found = fields.find(name);
    if (found == fields.end()) return "";
    const std::string& value = found->second;
    return value.empty() ? "" : value.c_str();
  }

  // Field `name` parsed with atof, or `def` when it is missing or blank.
  float GetFloat(const char* name, float def) const {
    const auto found = fields.find(name);
    if (found == fields.end()) return def;
    const std::string& value = found->second;
    return value.empty() ? def : static_cast<float>(atof(value.c_str()));
  }
};
// Trim leading/trailing ASCII whitespace + an inline " ;" comment tail
// from a string in-place.
// Normalise an INI key or value in place:
//   1. cut everything from the first blank (space/tab) that is directly
//      followed by ';' (an inline " ; comment" tail);
//   2. drop trailing spaces, tabs, CRs and LFs;
//   3. drop leading spaces and tabs (leading CR/LF are left alone).
// A ';' with no blank before it (e.g. "a;b") is kept as part of the text.
static void TrimAndStripComment(std::string& s) {
  for (size_t pos = 1; pos < s.size(); ++pos) {
    if (s[pos] == ';' && (s[pos - 1] == ' ' || s[pos - 1] == '\t')) {
      s.erase(pos - 1);
      break;
    }
  }
  const size_t lastKept = s.find_last_not_of(" \t\r\n");
  if (lastKept == std::string::npos) {
    s.clear();
    return;
  }
  s.erase(lastKept + 1);
  // Cannot be npos: the last character is not a space or tab.
  const size_t firstKept = s.find_first_not_of(" \t");
  s.erase(0, firstKept);
}
// Walk a GetPrivateProfileSectionA payload ("k=v\0k=v\0...\0\0") and call
// visit(key, value) for each entry.
// Iterate a GetPrivateProfileSectionA payload ("k=v\0k=v\0...\0\0").
// Each entry is split at its first '='; key and value are both run through
// TrimAndStripComment and then passed to visit(key, value). Entries with no
// '=', with '=' as the first character, or whose trimmed key is empty are
// skipped.
template<typename Fn>
static void ForEachIniEntry(const char* buf, Fn&& visit) {
  const char* entry = buf;
  while (*entry != '\0') {
    const size_t entryLen = strlen(entry);
    const char* sep = strchr(entry, '=');
    if (sep != nullptr && sep != entry) {
      std::string key(entry, static_cast<size_t>(sep - entry));
      std::string value(sep + 1);
      TrimAndStripComment(key);
      TrimAndStripComment(value);
      if (!key.empty()) visit(key, value);
    }
    entry += entryLen + 1;
  }
}
// Case-insensitive lookup by model key (mutable overload). Returns the first
// definition whose key matches, or nullptr if none does.
static ModelDef* FindModelDefCi(std::vector<ModelDef>& defs,
                                const std::string& key) {
  const auto match = std::find_if(defs.begin(), defs.end(),
      [&key](const ModelDef& candidate) {
        return _stricmp(candidate.key.c_str(), key.c_str()) == 0;
      });
  return match == defs.end() ? nullptr : &*match;
}
// Case-insensitive lookup by model key (read-only overload). Returns the
// first definition whose key matches, or nullptr if none does.
static const ModelDef* FindModelDefCi(const std::vector<ModelDef>& defs,
                                      const std::string& key) {
  const auto match = std::find_if(defs.begin(), defs.end(),
      [&key](const ModelDef& candidate) {
        return _stricmp(candidate.key.c_str(), key.c_str()) == 0;
      });
  return match == defs.end() ? nullptr : &*match;
}
// Discover every [Model.<key>] section in one INI and merge into `out`.
// Last-wins on duplicate keys (case-insensitive) — a later file's
// redefinition replaces the existing entry in place, mirroring the
// plugin-side loader.
// Discover every [Model.<key>] section in one INI and merge into `out`.
// Last-wins on duplicate keys (case-insensitive): a later file's
// redefinition replaces the existing entry in place, mirroring the
// plugin-side loader. A section is skipped, with a note on stderr, when
// it has no key suffix, no entries, or no type= line.
//
// Uses function-local static buffers, so it is not reentrant.
static void LoadModelDefsFromFile(const char* iniPath,
                                  std::vector<ModelDef>& out) {
  static char nameBuf[8192];
  DWORD n = GetPrivateProfileSectionNamesA(nameBuf, sizeof(nameBuf), iniPath);
  if (n == 0) return;
  // Win32 reports a truncated list by returning (buffer size - 2).
  if (n >= sizeof(nameBuf) - 2) {
    fprintf(stderr, "sherpa_server: warning — section-name list in %s "
                    "exceeded %zu bytes, some sections may be missed\n",
            iniPath, sizeof(nameBuf));
    fflush(stderr);
  }
  static char sectionBuf[16384];
  for (const char* sect = nameBuf; *sect; sect += strlen(sect) + 1) {
    if (_strnicmp(sect, "Model.", 6) != 0) continue;
    const char* keyPart = sect + 6;
    if (!*keyPart) {
      fprintf(stderr, "sherpa_server: [%s] in %s missing model-key suffix — skipping\n",
              sect, iniPath);
      fflush(stderr);
      continue;
    }
    DWORD len = GetPrivateProfileSectionA(sect, sectionBuf, sizeof(sectionBuf), iniPath);
    if (len == 0) {
      fprintf(stderr, "sherpa_server: [%s] in %s is empty — skipping\n",
              sect, iniPath);
      fflush(stderr);
      continue;
    }
    // Same (size - 2) truncation convention as above. Without this warning
    // a dropped key=value pair would only surface later as a confusing
    // engine-load failure.
    if (len >= sizeof(sectionBuf) - 2) {
      fprintf(stderr, "sherpa_server: warning — [%s] in %s exceeded %zu bytes, "
                      "some fields may be missed\n",
              sect, iniPath, sizeof(sectionBuf));
      fflush(stderr);
    }
    ModelDef def;
    def.key = keyPart;
    ForEachIniEntry(sectionBuf, [&](const std::string& k, const std::string& v) {
      if (_stricmp(k.c_str(), "type") == 0) def.type = v;
      else def.fields[k] = v;
    });
    if (def.type.empty()) {
      fprintf(stderr, "sherpa_server: [%s] in %s missing type= — skipping\n",
              sect, iniPath);
      fflush(stderr);
      continue;
    }
    if (ModelDef* existing = FindModelDefCi(out, def.key)) {
      fprintf(stderr, "sherpa_server: [%s] in %s — overriding previous definition\n",
              sect, iniPath);
      fflush(stderr);
      *existing = std::move(def);
    } else {
      fprintf(stderr, "sherpa_server: registered model '%s' (type=%s, %zu fields) from %s\n",
              def.key.c_str(), def.type.c_str(), def.fields.size(), iniPath);
      fflush(stderr);
      out.push_back(std::move(def));
    }
  }
}
// Enumerate every *.ini under `dirPath` (alphabetical, case-insensitive)
// and merge their [Model.*] sections into `out`. Last-wins on duplicates.
// Enumerate every *.ini under `dirPath` (alphabetical, case-insensitive)
// and merge their [Model.*] sections into `out`. Last-wins on duplicates.
// Returns the total number of definitions now in `out` (including any that
// were already there), or 0 if the directory could not be searched.
static int LoadModelDefsFromDir(const char* dirPath,
                                std::vector<ModelDef>& out) {
  char pattern[MAX_PATH];
  int wrote = snprintf(pattern, sizeof(pattern), "%s\\*.ini", dirPath);
  if (wrote < 0 || wrote >= (int)sizeof(pattern)) {
    fprintf(stderr, "sherpa_server: voice_presets path too long: %s\n", dirPath);
    fflush(stderr);
    return 0;
  }
  std::vector<std::string> names;
  WIN32_FIND_DATAA fd;
  HANDLE h = FindFirstFileA(pattern, &fd);
  if (h == INVALID_HANDLE_VALUE) {
    fprintf(stderr, "sherpa_server: no *.ini files in %s\n", dirPath);
    fflush(stderr);
    return 0;
  }
  do {
    if (fd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) continue;
    // FindFirstFile also matches 8.3 short names, so "*.ini" can return
    // files such as "foo.ini_old" (short name FOO~1.INI). Keep only names
    // whose real extension is exactly ".ini".
    const size_t nameLen = strlen(fd.cFileName);
    if (nameLen < 4 || _stricmp(fd.cFileName + nameLen - 4, ".ini") != 0) continue;
    names.push_back(fd.cFileName);
  } while (FindNextFileA(h, &fd));
  FindClose(h);
  std::sort(names.begin(), names.end(),
            [](const std::string& a, const std::string& b) {
              return _stricmp(a.c_str(), b.c_str()) < 0;
            });
  fprintf(stderr, "sherpa_server: loading %zu voice_presets file(s) from %s\n",
          names.size(), dirPath);
  fflush(stderr);
  for (const auto& nm : names) {
    char full[MAX_PATH];
    int fullLen = snprintf(full, sizeof(full), "%s\\%s", dirPath, nm.c_str());
    // A truncated path would name a different (or nonexistent) file.
    if (fullLen < 0 || fullLen >= (int)sizeof(full)) {
      fprintf(stderr, "sherpa_server: voice_presets file path too long, skipping: %s\\%s\n",
              dirPath, nm.c_str());
      fflush(stderr);
      continue;
    }
    LoadModelDefsFromFile(full, out);
  }
  return (int)out.size();
}
// ---------------------------------------------------------------------------
// Engine creation per family
// ---------------------------------------------------------------------------
// Process-wide engine settings, filled from the command line in main() and
// copied into every model's SherpaOnnxOfflineTtsConfig by CreateEngineForDef.
struct EngineGlobals {
const char* provider;   // requested provider ("cpu" when --provider is absent or empty); not owned
int numThreads;         // -> cfg.model.num_threads (--num-threads, default 2)
int debug;              // -> cfg.model.debug (--debug, default 0)
int maxNumSentences;    // -> cfg.max_num_sentences (--max-num-sentences, default 1)
float silenceScale;     // -> cfg.silence_scale (--silence-scale, default 0.2)
};
// Preflight the non-CPU provider DLL so we can fall back to CPU without
// onnxruntime's CUDA init aborting the whole process. Returns the
// effective provider string the caller should use.
static const char* PreflightProvider(const char* provider) {
const char* providerDll = nullptr;
if (_stricmp(provider, "cuda") == 0) providerDll = "onnxruntime_providers_cuda.dll";
else if (_stricmp(provider, "tensorrt") == 0) providerDll = "onnxruntime_providers_tensorrt.dll";
if (!providerDll) return provider;
HMODULE h = LoadLibraryA(providerDll);
if (!h) {
DWORD err = GetLastError();
fprintf(stderr,
"sherpa_server: cannot load %s (err=%lu) — provider=%s "
"unavailable on this machine. Falling back to CPU.\n",
providerDll, err, provider);
fflush(stderr);
return "cpu";
}
FreeLibrary(h);
return provider;
}
// Build a SherpaOnnxOfflineTtsConfig populated for one ModelDef and try
// to create the engine. Falls back from non-CPU providers to CPU on
// failure. Returns nullptr on hard failure; writes the loaded engine's
// sample rate / speaker count to *outSampleRate / *outNumSpeakers on
// success.
// Build a SherpaOnnxOfflineTtsConfig for one ModelDef and create the engine.
//
// The provider is first vetted by PreflightProvider. If creation then fails
// on a non-CPU provider, it is retried once on "cpu". Model paths are taken
// from the definition's fields as raw pointers, so `def` must outlive the
// call.
//
// Returns the engine handle, or nullptr when the type is unsupported or
// creation fails even on CPU. On success, writes the engine's sample rate
// and speaker count to *outSampleRate / *outNumSpeakers.
static const SherpaOnnxOfflineTts*
CreateEngineForDef(const ModelDef& def, const EngineGlobals& g,
                   int* outSampleRate, int* outNumSpeakers) {
  const char* family = def.type.c_str();

  SherpaOnnxOfflineTtsConfig config = {};
  auto& model = config.model;
  model.num_threads = g.numThreads;
  model.debug = g.debug;
  model.provider = PreflightProvider(g.provider);
  config.max_num_sentences = g.maxNumSentences;
  config.silence_scale = g.silenceScale;

  if (_stricmp(family, "kokoro") == 0) {
    auto& kokoro = model.kokoro;
    kokoro.model = def.GetCStr("model");
    kokoro.voices = def.GetCStr("voices");
    kokoro.tokens = def.GetCStr("tokens");
    kokoro.lexicon = def.GetCStr("lexicon");
    kokoro.data_dir = def.GetCStr("data_dir");
    kokoro.lang = def.GetCStr("lang");
    kokoro.length_scale = def.GetFloat("length_scale", 1.0f);
  } else if (_stricmp(family, "kitten") == 0) {
    auto& kitten = model.kitten;
    kitten.model = def.GetCStr("model");
    kitten.voices = def.GetCStr("voices");
    kitten.tokens = def.GetCStr("tokens");
    kitten.data_dir = def.GetCStr("data_dir");
    kitten.length_scale = def.GetFloat("length_scale", 1.0f);
  } else if (_stricmp(family, "vits") == 0) {
    auto& vits = model.vits;
    vits.model = def.GetCStr("model");
    vits.tokens = def.GetCStr("tokens");
    vits.lexicon = def.GetCStr("lexicon");
    vits.data_dir = def.GetCStr("data_dir");
    vits.length_scale = def.GetFloat("length_scale", 1.0f);
    vits.noise_scale = def.GetFloat("noise_scale", 0.667f);
    vits.noise_scale_w = def.GetFloat("noise_scale_w", 0.8f);
  } else if (_stricmp(family, "matcha") == 0) {
    auto& matcha = model.matcha;
    matcha.acoustic_model = def.GetCStr("acoustic_model");
    matcha.vocoder = def.GetCStr("vocoder");
    matcha.tokens = def.GetCStr("tokens");
    matcha.lexicon = def.GetCStr("lexicon");
    matcha.data_dir = def.GetCStr("data_dir");
    matcha.length_scale = def.GetFloat("length_scale", 1.0f);
    matcha.noise_scale = def.GetFloat("noise_scale", 1.0f);
  } else {
    fprintf(stderr,
            "sherpa_server: model '%s' has unsupported type='%s' — "
            "supported: kokoro, kitten, vits, matcha\n",
            def.key.c_str(), family);
    fflush(stderr);
    return nullptr;
  }

  fprintf(stderr, "sherpa_server: loading '%s' (type=%s, provider=%s, threads=%d)\n",
          def.key.c_str(), family, model.provider, model.num_threads);
  fflush(stderr);

  const SherpaOnnxOfflineTts* engine = SherpaOnnxCreateOfflineTts(&config);
  const bool retryOnCpu = engine == nullptr && strcmp(model.provider, "cpu") != 0;
  if (retryOnCpu) {
    fprintf(stderr,
            "sherpa_server: provider=%s failed for '%s' — "
            "falling back to CPU\n",
            model.provider, def.key.c_str());
    fflush(stderr);
    model.provider = "cpu";
    engine = SherpaOnnxCreateOfflineTts(&config);
  }
  if (engine == nullptr) {
    fprintf(stderr, "sherpa_server: SherpaOnnxCreateOfflineTts failed for '%s'\n",
            def.key.c_str());
    fflush(stderr);
    return nullptr;
  }

  const int sampleRate = SherpaOnnxOfflineTtsSampleRate(engine);
  const int speakers = SherpaOnnxOfflineTtsNumSpeakers(engine);
  *outSampleRate = sampleRate;
  *outNumSpeakers = speakers;
  fprintf(stderr, "sherpa_server: '%s' ready sr=%d speakers=%d\n",
          def.key.c_str(), sampleRate, speakers);
  fflush(stderr);
  return engine;
}
// ---------------------------------------------------------------------------
// WAV building — sherpa returns float [-1, 1] samples; we serialise to a
// mono 16-bit PCM RIFF file and hand raw bytes to the plugin.
// ---------------------------------------------------------------------------
// Serialise float samples (nominally in [-1, 1]) into a complete mono,
// 16-bit PCM RIFF/WAVE file image: a 44-byte header followed by the data.
//
//   samples     pointer to `n` floats (ignored when n <= 0)
//   n           sample count; a negative count, or a null `samples`,
//               produces a header-only file with an empty data chunk
//   sampleRate  written verbatim into the fmt chunk
//
// All header fields and samples are written explicitly as little-endian,
// as RIFF requires, independent of host byte order. Samples are scaled by
// 32767, clamped to the int16 range and truncated toward zero. NaN samples
// become silence (0): converting NaN to int16_t is undefined behaviour.
static std::string BuildWav(const float* samples, int32_t n, int32_t sampleRate) {
  const uint16_t channels = 1;
  const uint16_t bitsPerSample = 16;
  const uint16_t blockAlign = channels * (bitsPerSample / 8);
  const uint32_t count = (samples != nullptr && n > 0) ? (uint32_t)n : 0u;
  const uint32_t byteRate = (uint32_t)sampleRate * blockAlign;
  const uint32_t dataSize = count * blockAlign;
  const uint32_t riffSize = 36 + dataSize;
  const uint32_t fmtSize = 16;
  const uint16_t audioFormat = 1;  // PCM
  std::string wav;
  wav.reserve(44 + (size_t)dataSize);
  auto putTag = [&wav](const char* fourcc) { wav.append(fourcc, 4); };
  auto putU16 = [&wav](uint16_t v) {
    wav.push_back((char)(v & 0xFF));
    wav.push_back((char)((v >> 8) & 0xFF));
  };
  auto putU32 = [&wav](uint32_t v) {
    for (int shift = 0; shift < 32; shift += 8) {
      wav.push_back((char)((v >> shift) & 0xFF));
    }
  };
  putTag("RIFF"); putU32(riffSize);
  putTag("WAVE");
  putTag("fmt "); putU32(fmtSize);
  putU16(audioFormat); putU16(channels);
  putU32((uint32_t)sampleRate); putU32(byteRate);
  putU16(blockAlign); putU16(bitsPerSample);
  putTag("data"); putU32(dataSize);
  for (uint32_t i = 0; i < count; i++) {
    const float v = samples[i] * 32767.0f;
    int16_t pcm;
    if (std::isnan(v)) pcm = 0;
    else if (v > 32767.0f) pcm = 32767;
    else if (v < -32768.0f) pcm = -32768;
    else pcm = (int16_t)v;
    putU16((uint16_t)pcm);
  }
  return wav;
}
// Emit one framed response on stdout: a 4-byte int32 length (host byte
// order) followed by that many payload bytes, then flush so the reader on
// the other end of the pipe sees it immediately.
//
// When there is no payload (wavBytes == nullptr or wavLen <= 0) a zero
// length is sent, the same as WriteFailure. This keeps the prefix from
// promising bytes that are never written, which would desynchronise the
// framing for every later reply.
static void WriteResponse(const char* wavBytes, int32_t wavLen) {
  const int32_t len = (wavBytes != nullptr && wavLen > 0) ? wavLen : 0;
  fwrite(&len, 4, 1, stdout);
  if (len > 0) fwrite(wavBytes, 1, (size_t)len, stdout);
  fflush(stdout);
}
// Report a failed request: a zero int32 length with no payload, flushed
// immediately so the client is not left waiting.
static void WriteFailure() {
  const int32_t kNoPayload = 0;
  fwrite(&kNoPayload, sizeof kNoPayload, 1, stdout);
  fflush(stdout);
}
// ---------------------------------------------------------------------------
// Per-engine state
// ---------------------------------------------------------------------------
// Per-model-key engine slot, created lazily by main()'s GetEngine and kept
// in the `engines` map under the lower-cased key.
struct EngineState {
const SherpaOnnxOfflineTts* tts;  // engine handle; nullptr unless loading succeeded; destroyed at shutdown in main()
int sampleRate;   // filled by CreateEngineForDef from SherpaOnnxOfflineTtsSampleRate
int numSpeakers;  // filled by CreateEngineForDef from SherpaOnnxOfflineTtsNumSpeakers
bool loadFailed; // sticky: once a model fails to load, don't retry
};
// ---------------------------------------------------------------------------
// main
// ---------------------------------------------------------------------------
// Entry point: a line-oriented TTS server driven over stdin/stdout.
//
// Command line:
//   --voice-presets <dir>    (required) directory of *.ini files that
//                            declare [Model.<key>] sections
//   --provider <name>        execution provider (default "cpu")
//   --num-threads <n>        default 2
//   --debug <n>              default 0
//   --max-num-sentences <n>  default 1
//   --silence-scale <f>      default 0.2
//   --speed <f>              default 1.0, applied to every request
//
// Protocol: each stdin line is "<modelKey>\t<sid>\t<text>". Every line gets
// exactly one reply on stdout: a 4-byte int32 length followed by that many
// bytes of WAV data, where length 0 means the request failed. Engines are
// loaded lazily on first use. Returns 0 on stdin EOF and 2 on a bad
// command line.
int main(int argc, char** argv) {
  // Binary mode on stdout so Windows doesn't mangle \n -> \r\n in PCM.
  _setmode(_fileno(stdout), _O_BINARY);
  // Binary mode on stdin too: in text mode the CRT treats a 0x1A (Ctrl+Z)
  // byte as end-of-file, which would shut the server down mid-stream if a
  // request ever contained one. The loop below already drops every '\r',
  // so CRLF-terminated requests parse exactly as they did in text mode.
  _setmode(_fileno(stdin), _O_BINARY);
  const char* presetsPath = FindFlag(argc, argv, "--voice-presets");
  if (!presetsPath || !*presetsPath) {
    fprintf(stderr, "sherpa_server: --voice-presets <dir> is required\n");
    fflush(stderr);
    return 2;
  }
  DWORD presetsAttrs = GetFileAttributesA(presetsPath);
  if (presetsAttrs == INVALID_FILE_ATTRIBUTES ||
      !(presetsAttrs & FILE_ATTRIBUTE_DIRECTORY)) {
    fprintf(stderr, "sherpa_server: --voice-presets must name an existing "
                    "directory (got %s)\n", presetsPath);
    fflush(stderr);
    return 2;
  }
  EngineGlobals globals = {};
  {
    const char* provider = FindFlag(argc, argv, "--provider");
    globals.provider = (provider && provider[0]) ? provider : "cpu";
    globals.numThreads = FlagInt(argc, argv, "--num-threads", 2);
    globals.debug = FlagInt(argc, argv, "--debug", 0);
    // Kokoro ignores max_num_sentences != 1 (it streams the full text
    // through a single forward pass). Default 1 avoids a spurious warning.
    globals.maxNumSentences = FlagInt(argc, argv, "--max-num-sentences", 1);
    globals.silenceScale = FlagFloat(argc, argv, "--silence-scale", 0.2f);
  }
  float speed = FlagFloat(argc, argv, "--speed", 1.0f);
  fprintf(stderr, "sherpa_server: voice_presets=%s provider=%s threads=%d speed=%.2f\n",
          presetsPath, globals.provider, globals.numThreads, speed);
  fflush(stderr);
  std::vector<ModelDef> modelDefs;
  int modelCount = LoadModelDefsFromDir(presetsPath, modelDefs);
  if (modelCount == 0) {
    fprintf(stderr, "sherpa_server: voice_presets/ has no [Model.*] sections — "
                    "every request will fail until at least one model is declared\n");
    fflush(stderr);
  }
  // Loaded (or failed) engines, keyed by the lower-cased model key.
  std::unordered_map<std::string, EngineState> engines;
  // Lazy-load an engine on first reference. Subsequent requests for the
  // same key reuse the loaded handle. Failed loads are sticky — we
  // don't retry on every line, just log once and reply failure.
  auto GetEngine = [&](const std::string& key) -> EngineState* {
    std::string lower = key;
    for (auto& c : lower) c = (char)tolower((unsigned char)c);
    auto it = engines.find(lower);
    if (it != engines.end()) {
      return it->second.loadFailed ? nullptr : &it->second;
    }
    EngineState& st = engines[lower];
    st.tts = nullptr;
    st.sampleRate = 0;
    st.numSpeakers = 0;
    st.loadFailed = false;
    const ModelDef* def = FindModelDefCi(modelDefs, key);
    if (!def) {
      fprintf(stderr, "sherpa_server: model '%s' not declared in any voice_presets/*.ini\n",
              key.c_str());
      fflush(stderr);
      st.loadFailed = true;
      return nullptr;
    }
    st.tts = CreateEngineForDef(*def, globals, &st.sampleRate, &st.numSpeakers);
    if (!st.tts) {
      st.loadFailed = true;
      return nullptr;
    }
    return &st;
  };
  // Tray-icon indicator: lets the user see that sherpa_server.exe is
  // alive in the background. The exe is WIN32 subsystem (no console
  // window), so without this the only sign of life is the process in
  // Task Manager.
  {
    char tip[160];
    snprintf(tip, sizeof(tip),
             "Sherpa-onnx TTS (Numen) - %d model(s) registered",
             modelCount);
    sherpa::StartTrayIcon(tip);
  }
  // Utterance loop. Read one request line, synthesise, emit WAV.
  std::string line;
  line.reserve(4096);
  while (true) {
    int ch = fgetc(stdin);
    if (ch == EOF) break;
    if (ch == '\r') continue;
    if (ch != '\n') {
      line.push_back((char)ch);
      // Guard runaway input; drop the rest of the line on overflow.
      if (line.size() > 16 * 1024) {
        while ((ch = fgetc(stdin)) != EOF && ch != '\n') {}
        fprintf(stderr, "sherpa_server: dropped oversized request line\n");
        fflush(stderr);
        line.clear();
        WriteFailure();
      }
      continue;
    }
    // Parse "<modelKey>\t<sid>\t<text>". Two tabs minimum; the text
    // is everything past the second tab and may itself contain
    // anything (we don't strip).
    size_t tab1 = line.find('\t');
    size_t tab2 = (tab1 == std::string::npos) ? std::string::npos
                                              : line.find('\t', tab1 + 1);
    if (tab1 == std::string::npos || tab2 == std::string::npos) {
      fprintf(stderr, "sherpa_server: request missing tab separator(s): \"%.80s\"\n",
              line.c_str());
      fflush(stderr);
      WriteFailure();
      line.clear();
      continue;
    }
    std::string modelKey(line.data(), tab1);
    std::string sidStr(line.data() + tab1 + 1, tab2 - tab1 - 1);
    const char* text = line.c_str() + tab2 + 1;
    int32_t sid = atoi(sidStr.c_str());
    EngineState* eng = GetEngine(modelKey);
    if (!eng) {
      WriteFailure();
      line.clear();
      continue;
    }
    SherpaOnnxGenerationConfig gcfg = {};
    gcfg.sid = sid;
    gcfg.speed = speed;
    const SherpaOnnxGeneratedAudio* audio =
        SherpaOnnxOfflineTtsGenerateWithConfig(eng->tts, text, &gcfg, nullptr, nullptr);
    if (!audio || !audio->samples || audio->n <= 0) {
      fprintf(stderr, "sherpa_server: synthesis failed (model=%s sid=%d, text=\"%.80s\")\n",
              modelKey.c_str(), sid, text);
      fflush(stderr);
      if (audio) SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
      WriteFailure();
      line.clear();
      continue;
    }
    std::string wav = BuildWav(audio->samples, audio->n, audio->sample_rate);
    fprintf(stderr, "sherpa_server: model=%s sid=%d samples=%d sr=%d bytes=%zu\n",
            modelKey.c_str(), sid, audio->n, audio->sample_rate, wav.size());
    fflush(stderr);
    SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
    WriteResponse(wav.data(), (int32_t)wav.size());
    line.clear();
  }
  fprintf(stderr, "sherpa_server: stdin EOF, shutting down\n");
  fflush(stderr);
  sherpa::StopTrayIcon();
  for (auto& kv : engines) {
    if (kv.second.tts) SherpaOnnxDestroyOfflineTts(kv.second.tts);
  }
  return 0;
}