summaryrefslogtreecommitdiff
path: root/src/discord
diff options
context:
space:
mode:
authorouwou <26526779+ouwou@users.noreply.github.com>2023-10-03 02:10:20 +0000
committerGitHub <noreply@github.com>2023-10-03 02:10:20 +0000
commit483b547a6447f370da20f68c59cfe6c5e0059f32 (patch)
tree8fba364dbbadc6ad1d0c884226abfe9c9732eb89 /src/discord
parent5288694563f84e269802df14733ae8c7a7fc6901 (diff)
parent20cfbfad272126c9dce607e4594c5e05194b7629 (diff)
downloadabaddon-portaudio-483b547a6447f370da20f68c59cfe6c5e0059f32.tar.gz
abaddon-portaudio-483b547a6447f370da20f68c59cfe6c5e0059f32.zip
Merge pull request #220 from uowuo/member-list
Rewrite member list (for real this time)
Diffstat (limited to 'src/discord')
-rw-r--r--src/discord/discord.cpp15
-rw-r--r--src/discord/discord.hpp8
-rw-r--r--src/discord/store.cpp64
-rw-r--r--src/discord/store.hpp43
4 files changed, 86 insertions, 44 deletions
diff --git a/src/discord/discord.cpp b/src/discord/discord.cpp
index ccc61b4..0618e72 100644
--- a/src/discord/discord.cpp
+++ b/src/discord/discord.cpp
@@ -261,6 +261,21 @@ Snowflake DiscordClient::GetMemberHoistedRole(Snowflake guild_id, Snowflake user
return top_role.has_value() ? top_role->ID : Snowflake::Invalid;
}
+std::optional<RoleData> DiscordClient::GetMemberHoistedRoleCached(const GuildMember &member, const std::unordered_map<Snowflake, RoleData> &roles, bool with_color) const {
+ std::optional<RoleData> top_role;
+ for (const auto id : member.Roles) {
+ if (const auto iter = roles.find(id); iter != roles.end()) {
+ const auto &role = iter->second;
+ if ((with_color && role.Color != 0x000000) || (!with_color && role.IsHoisted)) {
+ if (!top_role.has_value() || top_role->Position < role.Position) {
+ top_role = role;
+ }
+ }
+ }
+ }
+ return top_role;
+}
+
std::optional<RoleData> DiscordClient::GetMemberHighestRole(Snowflake guild_id, Snowflake user_id) const {
const auto data = GetMember(user_id, guild_id);
if (!data.has_value()) return std::nullopt;
diff --git a/src/discord/discord.hpp b/src/discord/discord.hpp
index ebbf5f9..cb14a52 100644
--- a/src/discord/discord.hpp
+++ b/src/discord/discord.hpp
@@ -18,7 +18,7 @@
#include <queue>
#ifdef GetMessage
- #undef GetMessage
+#undef GetMessage
#endif
class Abaddon;
@@ -55,6 +55,7 @@ public:
std::optional<GuildData> GetGuild(Snowflake id) const;
std::optional<GuildMember> GetMember(Snowflake user_id, Snowflake guild_id) const;
Snowflake GetMemberHoistedRole(Snowflake guild_id, Snowflake user_id, bool with_color = false) const;
+ std::optional<RoleData> GetMemberHoistedRoleCached(const GuildMember &member, const std::unordered_map<Snowflake, RoleData> &roles, bool with_color = false) const;
std::optional<RoleData> GetMemberHighestRole(Snowflake guild_id, Snowflake user_id) const;
std::set<Snowflake> GetUsersInGuild(Snowflake id) const;
std::set<Snowflake> GetChannelsInGuild(Snowflake id) const;
@@ -162,6 +163,11 @@ public:
});
}
+ template<typename Iter>
+ std::vector<UserData> GetUsersBulk(Iter begin, Iter end) {
+ return m_store.GetUsersBulk(begin, end);
+ }
+
// FetchGuildBans fetches all bans+reasons via api, this func fetches stored bans (so usually just GUILD_BAN_ADD data)
std::vector<BanData> GetBansInGuild(Snowflake guild_id);
void FetchGuildBan(Snowflake guild_id, Snowflake user_id, const sigc::slot<void(BanData)> &callback);
diff --git a/src/discord/store.cpp b/src/discord/store.cpp
index d8994c4..dfeb7d1 100644
--- a/src/discord/store.cpp
+++ b/src/discord/store.cpp
@@ -27,18 +27,7 @@ Store::Store(bool mem_store)
m_ok &= CreateStatements();
}
-Store::~Store() {
- m_db.Close();
- if (!m_db.OK()) {
- fprintf(stderr, "error closing database: %s\n", m_db.ErrStr());
- return;
- }
-
- if (m_db_path != ":memory:") {
- std::error_code ec;
- std::filesystem::remove(m_db_path, ec);
- }
-}
+Store::~Store() {}
bool Store::IsValid() const {
return m_db.OK() && m_ok;
@@ -519,7 +508,6 @@ std::optional<WebhookMessageData> Store::GetWebhookMessage(Snowflake message_id)
return data;
}
-
Snowflake Store::GetGuildOwner(Snowflake guild_id) const {
auto &s = m_stmt_get_guild_owner;
@@ -961,6 +949,21 @@ std::optional<Message> Store::GetMessage(Snowflake id) const {
return top;
}
+UserData Store::GetUserBound(Statement *stmt) const {
+ UserData u;
+ stmt->Get(0, u.ID);
+ stmt->Get(1, u.Username);
+ stmt->Get(2, u.Discriminator);
+ stmt->Get(3, u.Avatar);
+ stmt->Get(4, u.IsBot);
+ stmt->Get(5, u.IsSystem);
+ stmt->Get(6, u.IsMFAEnabled);
+ stmt->Get(7, u.PremiumType);
+ stmt->Get(8, u.PublicFlags);
+ stmt->Get(9, u.GlobalName);
+ return u;
+}
+
Message Store::GetMessageBound(std::unique_ptr<Statement> &s) const {
Message r;
@@ -1137,18 +1140,7 @@ std::optional<UserData> Store::GetUser(Snowflake id) const {
return {};
}
- UserData r;
-
- r.ID = id;
- s->Get(1, r.Username);
- s->Get(2, r.Discriminator);
- s->Get(3, r.Avatar);
- s->Get(4, r.IsBot);
- s->Get(5, r.IsSystem);
- s->Get(6, r.IsMFAEnabled);
- s->Get(7, r.PremiumType);
- s->Get(8, r.PublicFlags);
- s->Get(9, r.GlobalName);
+ auto r = GetUserBound(s.get());
s->Reset();
@@ -2360,7 +2352,8 @@ bool Store::CreateStatements() {
return true;
}
-Store::Database::Database(const char *path) {
+Store::Database::Database(const char *path)
+ : m_db_path(path) {
if (path != ":memory:"s) {
std::error_code ec;
if (std::filesystem::exists(path, ec) && !std::filesystem::remove(path, ec)) {
@@ -2377,9 +2370,18 @@ Store::Database::~Database() {
int Store::Database::Close() {
if (m_db == nullptr) return m_err;
- m_signal_close.emit();
m_err = sqlite3_close(m_db);
m_db = nullptr;
+
+ if (!OK()) {
+ fprintf(stderr, "error closing database: %s\n", ErrStr());
+ } else {
+ if (m_db_path != ":memory:") {
+ std::error_code ec;
+ std::filesystem::remove(m_db_path, ec);
+ }
+ }
+
return m_err;
}
@@ -2420,17 +2422,9 @@ sqlite3 *Store::Database::obj() {
return m_db;
}
-Store::Database::type_signal_close Store::Database::signal_close() {
- return m_signal_close;
-}
-
Store::Statement::Statement(Database &db, const char *command)
: m_db(&db) {
if (m_db->SetError(sqlite3_prepare_v2(m_db->obj(), command, -1, &m_stmt, nullptr)) != SQLITE_OK) return;
- m_db->signal_close().connect([this] {
- sqlite3_finalize(m_stmt);
- m_stmt = nullptr;
- });
}
Store::Statement::~Statement() {
diff --git a/src/discord/store.hpp b/src/discord/store.hpp
index b6979d0..6157f09 100644
--- a/src/discord/store.hpp
+++ b/src/discord/store.hpp
@@ -11,6 +11,9 @@
#endif
class Store {
+private:
+ class Statement;
+
public:
Store(bool mem_store = false);
~Store();
@@ -51,6 +54,36 @@ public:
std::unordered_set<Snowflake> GetMembersInGuild(Snowflake guild_id) const;
// ^ not the same as GetUsersInGuild since users in a guild may include users who do not have retrieved member data
+ template<typename Iter>
+ std::vector<UserData> GetUsersBulk(Iter begin, Iter end) {
+ const int size = std::distance(begin, end);
+ if (size == 0) return {};
+
+ std::string query = "SELECT * FROM USERS WHERE id IN (";
+ for (int i = 0; i < size; i++) {
+ query += "?, ";
+ }
+ query.resize(query.size() - 2); // chop off extra ", "
+ query += ")";
+
+ Statement s(m_db, query.c_str());
+ if (!s.OK()) {
+ printf("failed to prepare bulk users: %s\n", m_db.ErrStr());
+ return {};
+ }
+
+ for (int i = 0; begin != end; i++, begin++) {
+ s.Bind(i, *begin);
+ }
+
+ std::vector<UserData> r;
+ r.reserve(size);
+ while (s.FetchOne()) {
+ r.push_back(GetUserBound(&s));
+ }
+ return r;
+ }
+
void AddReaction(const MessageReactionAddObject &data, bool byself);
void RemoveReaction(const MessageReactionRemoveObject &data, bool byself);
@@ -69,7 +102,6 @@ public:
void EndTransaction();
private:
- class Statement;
class Database {
public:
Database(const char *path);
@@ -89,13 +121,7 @@ private:
sqlite3 *m_db;
int m_err = SQLITE_OK;
mutable char m_err_scratch[256] { 0 };
-
- // stupid shit i dont like to allow closing properly
- using type_signal_close = sigc::signal<void>;
- type_signal_close m_signal_close;
-
- public:
- type_signal_close signal_close();
+ std::filesystem::path m_db_path;
};
class Statement {
@@ -242,6 +268,7 @@ private:
sqlite3_stmt *m_stmt;
};
+ UserData GetUserBound(Statement *stmt) const;
Message GetMessageBound(std::unique_ptr<Statement> &stmt) const;
static RoleData GetRoleBound(std::unique_ptr<Statement> &stmt);