summaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
authorouwou <26526779+ouwou@users.noreply.github.com>2020-08-25 22:10:39 -0400
committerouwou <26526779+ouwou@users.noreply.github.com>2020-08-25 22:10:39 -0400
commit82a21bd08558ba3e067f490525431be30f978b25 (patch)
tree5695ccc11939e0de3a0fa012e73d8071dd97ce3b /discord
parent6b72931ba729bf6ede33cfa6877a5ad21e913c45 (diff)
downloadabaddon-portaudio-82a21bd08558ba3e067f490525431be30f978b25.tar.gz
abaddon-portaudio-82a21bd08558ba3e067f490525431be30f978b25.zip
fix guild order, add copy id guild, add broken zlib, start member list
Diffstat (limited to 'discord')
-rw-r--r--discord/discord.cpp82
-rw-r--r--discord/discord.hpp17
-rw-r--r--discord/websocket.cpp17
-rw-r--r--discord/websocket.hpp6
4 files changed, 99 insertions, 23 deletions
diff --git a/discord/discord.cpp b/discord/discord.cpp
index 9aa8eed..15a719c 100644
--- a/discord/discord.cpp
+++ b/discord/discord.cpp
@@ -3,7 +3,11 @@
#include <cassert>
DiscordClient::DiscordClient()
- : m_http(DiscordAPI) {
+ : m_http(DiscordAPI)
+#ifdef ABADDON_USE_COMPRESSED_SOCKET
+ , m_decompress_buf(InflateChunkSize)
+#endif
+{
LoadEventMap();
}
@@ -17,7 +21,7 @@ void DiscordClient::Start() {
m_client_connected = true;
m_websocket.StartConnection(DiscordGateway);
- m_websocket.SetJSONCallback(std::bind(&DiscordClient::HandleGatewayMessage, this, std::placeholders::_1));
+ m_websocket.SetMessageCallback(std::bind(&DiscordClient::HandleGatewayMessageRaw, this, std::placeholders::_1));
}
void DiscordClient::Stop() {
@@ -25,7 +29,7 @@ void DiscordClient::Stop() {
if (!m_client_connected) return;
m_heartbeat_waiter.kill();
- m_heartbeat_thread.join();
+ if (m_heartbeat_thread.joinable()) m_heartbeat_thread.join();
m_client_connected = false;
m_websocket.Stop();
@@ -51,9 +55,21 @@ std::vector<std::pair<Snowflake, GuildData>> DiscordClient::GetUserSortedGuilds(
std::vector<std::pair<Snowflake, GuildData>> sorted_guilds;
if (m_user_settings.GuildPositions.size()) {
+ std::unordered_set<Snowflake> positioned_guilds(m_user_settings.GuildPositions.begin(), m_user_settings.GuildPositions.end());
+ // guilds not in the guild_positions object are at the top of the list, descending by guild ID
+ std::set<Snowflake> unpositioned_guilds;
+ for (const auto &[id, guild] : m_guilds) {
+ if (positioned_guilds.find(id) == positioned_guilds.end())
+ unpositioned_guilds.insert(id);
+ }
+
+ // unpositioned_guilds now has unpositioned guilds in ascending order
+ for (auto it = unpositioned_guilds.rbegin(); it != unpositioned_guilds.rend(); it++)
+ sorted_guilds.push_back(std::make_pair(*it, m_guilds.at(*it)));
+
+ // now the rest go at the end in the order they are sorted
for (const auto &id : m_user_settings.GuildPositions) {
- auto &guild = m_guilds.at(id);
- sorted_guilds.push_back(std::make_pair(id, guild));
+ sorted_guilds.push_back(std::make_pair(id, m_guilds.at(id)));
}
} else { // default sort is alphabetic
for (auto &it : m_guilds)
@@ -130,10 +146,62 @@ void DiscordClient::UpdateToken(std::string token) {
m_http.SetAuth(token);
}
-void DiscordClient::HandleGatewayMessage(nlohmann::json j) {
+std::string DiscordClient::DecompressGatewayMessage(std::string str) {
+ return std::string();
+}
+
+void DiscordClient::HandleGatewayMessageRaw(std::string str) {
+#ifdef ABADDON_USE_COMPRESSED_SOCKET // fuck you work
+ // handles multiple zlib compressed messages, calling HandleGatewayMessage when a full message is received
+ std::vector<uint8_t> buf(str.begin(), str.end());
+ int len = buf.size();
+ bool has_suffix = buf[len - 4] == 0x00 && buf[len - 3] == 0x00 && buf[len - 2] == 0xFF && buf[len - 1] == 0xFF;
+
+ m_compressed_buf.insert(m_compressed_buf.begin(), buf.begin(), buf.end());
+
+ if (!has_suffix) return;
+
+ z_stream z;
+ std::memset(&z, 0, sizeof(z));
+
+ assert(inflateInit2(&z, 15) == 0);
+
+ z.next_in = m_compressed_buf.data();
+ z.avail_in = m_compressed_buf.size();
+
+ // loop in case of really big messages (e.g. READY)
+ while (true) {
+ z.next_out = m_decompress_buf.data() + z.total_out;
+ z.avail_out = m_decompress_buf.size() - z.total_out;
+
+ int err = inflate(&z, Z_SYNC_FLUSH);
+ if ((err == Z_OK || err == Z_BUF_ERROR) && z.avail_in > 0) {
+ m_decompress_buf.resize(m_decompress_buf.size() + InflateChunkSize);
+ } else {
+ if (err != Z_OK) {
+ fprintf(stderr, "Error decompressing input buffer %d (%d/%d)\n", err, z.avail_in, z.avail_out);
+ } else {
+ HandleGatewayMessage(std::string(m_decompress_buf.begin(), m_decompress_buf.begin() + z.total_out));
+ if (m_decompress_buf.size() > InflateChunkSize)
+ m_decompress_buf.resize(InflateChunkSize);
+ }
+
+ inflateEnd(&z);
+
+ break;
+ }
+ }
+
+ m_compressed_buf.clear();
+#else
+ HandleGatewayMessage(str);
+#endif
+}
+
+void DiscordClient::HandleGatewayMessage(std::string str) {
GatewayMessage m;
try {
- m = j;
+ m = nlohmann::json::parse(str);
} catch (std::exception &e) {
printf("Error decoding JSON. Discarding message: %s\n", e.what());
return;
diff --git a/discord/discord.hpp b/discord/discord.hpp
index ab390dc..3a98b40 100644
--- a/discord/discord.hpp
+++ b/discord/discord.hpp
@@ -4,8 +4,12 @@
#include <nlohmann/json.hpp>
#include <thread>
#include <unordered_map>
+#include <set>
#include <unordered_set>
#include <mutex>
+#ifdef ABADDON_USE_COMPRESSED_SOCKET
+ #include <zlib.h>
+#endif
// bruh
#ifdef GetMessage
@@ -372,7 +376,11 @@ class DiscordClient {
friend class Abaddon;
public:
+#ifdef ABADDON_USE_COMPRESSED_SOCKET
+ static const constexpr char *DiscordGateway = "wss://gateway.discord.gg/?v=6&encoding=json&compress=zlib-stream";
+#else
static const constexpr char *DiscordGateway = "wss://gateway.discord.gg/?v=6&encoding=json";
+#endif
static const constexpr char *DiscordAPI = "https://discord.com/api";
static const constexpr char *GatewayIdentity = "Discord";
@@ -400,7 +408,14 @@ public:
void UpdateToken(std::string token);
private:
- void HandleGatewayMessage(nlohmann::json msg);
+#ifdef ABADDON_USE_COMPRESSED_SOCKET
+ static const constexpr int InflateChunkSize = 0x10000;
+ std::vector<uint8_t> m_compressed_buf;
+ std::vector<uint8_t> m_decompress_buf;
+#endif
+ std::string DecompressGatewayMessage(std::string str);
+ void HandleGatewayMessageRaw(std::string str);
+ void HandleGatewayMessage(std::string str);
void HandleGatewayReady(const GatewayMessage &msg);
void HandleGatewayMessageCreate(const GatewayMessage &msg);
void HeartbeatThread();
diff --git a/discord/websocket.cpp b/discord/websocket.cpp
index 8232ac6..2251a01 100644
--- a/discord/websocket.cpp
+++ b/discord/websocket.cpp
@@ -1,10 +1,10 @@
#include "websocket.hpp"
#include <functional>
-#include <nlohmann/json.hpp>
Websocket::Websocket() {}
void Websocket::StartConnection(std::string url) {
+ m_websocket.disableAutomaticReconnection();
m_websocket.setUrl(url);
m_websocket.setOnMessageCallback(std::bind(&Websocket::OnMessage, this, std::placeholders::_1));
m_websocket.start();
@@ -19,8 +19,8 @@ bool Websocket::IsOpen() const {
return state == ix::ReadyState::Open;
}
-void Websocket::SetJSONCallback(JSONCallback_t func) {
- m_json_callback = func;
+void Websocket::SetMessageCallback(MessageCallback_t func) {
+ m_callback = func;
}
void Websocket::Send(const std::string &str) {
@@ -39,15 +39,8 @@ void Websocket::OnMessage(const ix::WebSocketMessagePtr &msg) {
// printf("%s\n", msg->str.substr(0, 1000).c_str());
//else
// printf("%s\n", msg->str.c_str());
- nlohmann::json obj;
- try {
- obj = nlohmann::json::parse(msg->str);
- } catch (std::exception &e) {
- printf("Error decoding JSON. Discarding message: %s\n", e.what());
- return;
- }
- if (m_json_callback)
- m_json_callback(obj);
+ if (m_callback)
+ m_callback(msg->str);
} break;
}
}
diff --git a/discord/websocket.hpp b/discord/websocket.hpp
index dc8cbec..8e3aa94 100644
--- a/discord/websocket.hpp
+++ b/discord/websocket.hpp
@@ -10,8 +10,8 @@ public:
Websocket();
void StartConnection(std::string url);
- using JSONCallback_t = std::function<void(nlohmann::json)>;
- void SetJSONCallback(JSONCallback_t func);
+ using MessageCallback_t = std::function<void(std::string data)>;
+ void SetMessageCallback(MessageCallback_t func);
void Send(const std::string &str);
void Send(const nlohmann::json &j);
void Stop();
@@ -20,6 +20,6 @@ public:
private:
void OnMessage(const ix::WebSocketMessagePtr &msg);
- JSONCallback_t m_json_callback;
+ MessageCallback_t m_callback;
ix::WebSocket m_websocket;
};