summaryrefslogtreecommitdiff
path: root/discord/discord.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'discord/discord.cpp')
-rw-r--r--discord/discord.cpp82
1 files changed, 75 insertions, 7 deletions
diff --git a/discord/discord.cpp b/discord/discord.cpp
index 9aa8eed..15a719c 100644
--- a/discord/discord.cpp
+++ b/discord/discord.cpp
@@ -3,7 +3,11 @@
#include <cassert>
DiscordClient::DiscordClient()
- : m_http(DiscordAPI) {
+ : m_http(DiscordAPI)
+#ifdef ABADDON_USE_COMPRESSED_SOCKET
+ , m_decompress_buf(InflateChunkSize)
+#endif
+{
LoadEventMap();
}
@@ -17,7 +21,7 @@ void DiscordClient::Start() {
m_client_connected = true;
m_websocket.StartConnection(DiscordGateway);
- m_websocket.SetJSONCallback(std::bind(&DiscordClient::HandleGatewayMessage, this, std::placeholders::_1));
+ m_websocket.SetMessageCallback(std::bind(&DiscordClient::HandleGatewayMessageRaw, this, std::placeholders::_1));
}
void DiscordClient::Stop() {
@@ -25,7 +29,7 @@ void DiscordClient::Stop() {
if (!m_client_connected) return;
m_heartbeat_waiter.kill();
- m_heartbeat_thread.join();
+ if (m_heartbeat_thread.joinable()) m_heartbeat_thread.join();
m_client_connected = false;
m_websocket.Stop();
@@ -51,9 +55,21 @@ std::vector<std::pair<Snowflake, GuildData>> DiscordClient::GetUserSortedGuilds(
std::vector<std::pair<Snowflake, GuildData>> sorted_guilds;
if (m_user_settings.GuildPositions.size()) {
+ std::unordered_set<Snowflake> positioned_guilds(m_user_settings.GuildPositions.begin(), m_user_settings.GuildPositions.end());
+ // guilds not in the guild_positions object are at the top of the list, descending by guild ID
+ std::set<Snowflake> unpositioned_guilds;
+ for (const auto &[id, guild] : m_guilds) {
+ if (positioned_guilds.find(id) == positioned_guilds.end())
+ unpositioned_guilds.insert(id);
+ }
+
+ // unpositioned_guilds now has unpositioned guilds in ascending order
+ for (auto it = unpositioned_guilds.rbegin(); it != unpositioned_guilds.rend(); it++)
+ sorted_guilds.push_back(std::make_pair(*it, m_guilds.at(*it)));
+
+ // now the rest go at the end in the order they are sorted
for (const auto &id : m_user_settings.GuildPositions) {
- auto &guild = m_guilds.at(id);
- sorted_guilds.push_back(std::make_pair(id, guild));
+ sorted_guilds.push_back(std::make_pair(id, m_guilds.at(id)));
}
} else { // default sort is alphabetic
for (auto &it : m_guilds)
@@ -130,10 +146,62 @@ void DiscordClient::UpdateToken(std::string token) {
m_http.SetAuth(token);
}
-void DiscordClient::HandleGatewayMessage(nlohmann::json j) {
+std::string DiscordClient::DecompressGatewayMessage(std::string str) {
+ return std::string();
+}
+
+void DiscordClient::HandleGatewayMessageRaw(std::string str) {
+#ifdef ABADDON_USE_COMPRESSED_SOCKET // fuck you work
+ // handles multiple zlib compressed messages, calling HandleGatewayMessage when a full message is received
+ std::vector<uint8_t> buf(str.begin(), str.end());
+ int len = buf.size();
+ bool has_suffix = buf[len - 4] == 0x00 && buf[len - 3] == 0x00 && buf[len - 2] == 0xFF && buf[len - 1] == 0xFF;
+
+ m_compressed_buf.insert(m_compressed_buf.begin(), buf.begin(), buf.end());
+
+ if (!has_suffix) return;
+
+ z_stream z;
+ std::memset(&z, 0, sizeof(z));
+
+ assert(inflateInit2(&z, 15) == 0);
+
+ z.next_in = m_compressed_buf.data();
+ z.avail_in = m_compressed_buf.size();
+
+ // loop in case of really big messages (e.g. READY)
+ while (true) {
+ z.next_out = m_decompress_buf.data() + z.total_out;
+ z.avail_out = m_decompress_buf.size() - z.total_out;
+
+ int err = inflate(&z, Z_SYNC_FLUSH);
+ if ((err == Z_OK || err == Z_BUF_ERROR) && z.avail_in > 0) {
+ m_decompress_buf.resize(m_decompress_buf.size() + InflateChunkSize);
+ } else {
+ if (err != Z_OK) {
+ fprintf(stderr, "Error decompressing input buffer %d (%d/%d)\n", err, z.avail_in, z.avail_out);
+ } else {
+ HandleGatewayMessage(std::string(m_decompress_buf.begin(), m_decompress_buf.begin() + z.total_out));
+ if (m_decompress_buf.size() > InflateChunkSize)
+ m_decompress_buf.resize(InflateChunkSize);
+ }
+
+ inflateEnd(&z);
+
+ break;
+ }
+ }
+
+ m_compressed_buf.clear();
+#else
+ HandleGatewayMessage(str);
+#endif
+}
+
+void DiscordClient::HandleGatewayMessage(std::string str) {
GatewayMessage m;
try {
- m = j;
+ m = nlohmann::json::parse(str);
} catch (std::exception &e) {
printf("Error decoding JSON. Discarding message: %s\n", e.what());
return;