diff options
Diffstat (limited to 'discord/discord.cpp')
-rw-r--r-- | discord/discord.cpp | 82 |
1 files changed, 75 insertions, 7 deletions
diff --git a/discord/discord.cpp b/discord/discord.cpp index 9aa8eed..15a719c 100644 --- a/discord/discord.cpp +++ b/discord/discord.cpp @@ -3,7 +3,11 @@ #include <cassert> DiscordClient::DiscordClient() - : m_http(DiscordAPI) { + : m_http(DiscordAPI) +#ifdef ABADDON_USE_COMPRESSED_SOCKET + , m_decompress_buf(InflateChunkSize) +#endif +{ LoadEventMap(); } @@ -17,7 +21,7 @@ void DiscordClient::Start() { m_client_connected = true; m_websocket.StartConnection(DiscordGateway); - m_websocket.SetJSONCallback(std::bind(&DiscordClient::HandleGatewayMessage, this, std::placeholders::_1)); + m_websocket.SetMessageCallback(std::bind(&DiscordClient::HandleGatewayMessageRaw, this, std::placeholders::_1)); } void DiscordClient::Stop() { @@ -25,7 +29,7 @@ void DiscordClient::Stop() { if (!m_client_connected) return; m_heartbeat_waiter.kill(); - m_heartbeat_thread.join(); + if (m_heartbeat_thread.joinable()) m_heartbeat_thread.join(); m_client_connected = false; m_websocket.Stop(); @@ -51,9 +55,21 @@ std::vector<std::pair<Snowflake, GuildData>> DiscordClient::GetUserSortedGuilds( std::vector<std::pair<Snowflake, GuildData>> sorted_guilds; if (m_user_settings.GuildPositions.size()) { + std::unordered_set<Snowflake> positioned_guilds(m_user_settings.GuildPositions.begin(), m_user_settings.GuildPositions.end()); + // guilds not in the guild_positions object are at the top of the list, descending by guild ID + std::set<Snowflake> unpositioned_guilds; + for (const auto &[id, guild] : m_guilds) { + if (positioned_guilds.find(id) == positioned_guilds.end()) + unpositioned_guilds.insert(id); + } + + // unpositioned_guilds now has unpositioned guilds in ascending order + for (auto it = unpositioned_guilds.rbegin(); it != unpositioned_guilds.rend(); it++) + sorted_guilds.push_back(std::make_pair(*it, m_guilds.at(*it))); + + // now the rest go at the end in the order they are sorted for (const auto &id : m_user_settings.GuildPositions) { - auto &guild = m_guilds.at(id); - sorted_guilds.push_back(std::make_pair(id, guild)); + sorted_guilds.push_back(std::make_pair(id, m_guilds.at(id))); } } else { // default sort is alphabetic for (auto &it : m_guilds) @@ -130,10 +146,62 @@ void DiscordClient::UpdateToken(std::string token) { m_http.SetAuth(token); } -void DiscordClient::HandleGatewayMessage(nlohmann::json j) { +std::string DiscordClient::DecompressGatewayMessage(std::string str) { + return std::string(); +} + +void DiscordClient::HandleGatewayMessageRaw(std::string str) { +#ifdef ABADDON_USE_COMPRESSED_SOCKET // fuck you work + // handles multiple zlib compressed messages, calling HandleGatewayMessage when a full message is received + std::vector<uint8_t> buf(str.begin(), str.end()); + int len = buf.size(); + bool has_suffix = buf[len - 4] == 0x00 && buf[len - 3] == 0x00 && buf[len - 2] == 0xFF && buf[len - 1] == 0xFF; + + m_compressed_buf.insert(m_compressed_buf.begin(), buf.begin(), buf.end()); + + if (!has_suffix) return; + + z_stream z; + std::memset(&z, 0, sizeof(z)); + + assert(inflateInit2(&z, 15) == 0); + + z.next_in = m_compressed_buf.data(); + z.avail_in = m_compressed_buf.size(); + + // loop in case of really big messages (e.g. READY) + while (true) { + z.next_out = m_decompress_buf.data() + z.total_out; + z.avail_out = m_decompress_buf.size() - z.total_out; + + int err = inflate(&z, Z_SYNC_FLUSH); + if ((err == Z_OK || err == Z_BUF_ERROR) && z.avail_in > 0) { + m_decompress_buf.resize(m_decompress_buf.size() + InflateChunkSize); + } else { + if (err != Z_OK) { + fprintf(stderr, "Error decompressing input buffer %d (%d/%d)\n", err, z.avail_in, z.avail_out); + } else { + HandleGatewayMessage(std::string(m_decompress_buf.begin(), m_decompress_buf.begin() + z.total_out)); + if (m_decompress_buf.size() > InflateChunkSize) + m_decompress_buf.resize(InflateChunkSize); + } + + inflateEnd(&z); + + break; + } + } + + m_compressed_buf.clear(); +#else + HandleGatewayMessage(str); +#endif +} + +void DiscordClient::HandleGatewayMessage(std::string str) { GatewayMessage m; try { - m = j; + m = nlohmann::json::parse(str); } catch (std::exception &e) { printf("Error decoding JSON. Discarding message: %s\n", e.what()); return; |