summaryrefslogtreecommitdiff
path: root/src/discord
diff options
context:
space:
mode:
Diffstat (limited to 'src/discord')
-rw-r--r--src/discord/discord.cpp15
-rw-r--r--src/discord/discord.hpp8
-rw-r--r--src/discord/store.hpp46
3 files changed, 67 insertions, 2 deletions
diff --git a/src/discord/discord.cpp b/src/discord/discord.cpp
index ccc61b4..0618e72 100644
--- a/src/discord/discord.cpp
+++ b/src/discord/discord.cpp
@@ -261,6 +261,21 @@ Snowflake DiscordClient::GetMemberHoistedRole(Snowflake guild_id, Snowflake user
return top_role.has_value() ? top_role->ID : Snowflake::Invalid;
}
+std::optional<RoleData> DiscordClient::GetMemberHoistedRoleCached(const GuildMember &member, const std::unordered_map<Snowflake, RoleData> &roles, bool with_color) const {
+ std::optional<RoleData> top_role;
+ for (const auto id : member.Roles) {
+ if (const auto iter = roles.find(id); iter != roles.end()) {
+ const auto &role = iter->second;
+ if ((with_color && role.Color != 0x000000) || (!with_color && role.IsHoisted)) {
+ if (!top_role.has_value() || top_role->Position < role.Position) {
+ top_role = role;
+ }
+ }
+ }
+ }
+ return top_role;
+}
+
std::optional<RoleData> DiscordClient::GetMemberHighestRole(Snowflake guild_id, Snowflake user_id) const {
const auto data = GetMember(user_id, guild_id);
if (!data.has_value()) return std::nullopt;
diff --git a/src/discord/discord.hpp b/src/discord/discord.hpp
index ebbf5f9..cb14a52 100644
--- a/src/discord/discord.hpp
+++ b/src/discord/discord.hpp
@@ -18,7 +18,7 @@
#include <queue>
#ifdef GetMessage
- #undef GetMessage
+#undef GetMessage
#endif
class Abaddon;
@@ -55,6 +55,7 @@ public:
std::optional<GuildData> GetGuild(Snowflake id) const;
std::optional<GuildMember> GetMember(Snowflake user_id, Snowflake guild_id) const;
Snowflake GetMemberHoistedRole(Snowflake guild_id, Snowflake user_id, bool with_color = false) const;
+ std::optional<RoleData> GetMemberHoistedRoleCached(const GuildMember &member, const std::unordered_map<Snowflake, RoleData> &roles, bool with_color = false) const;
std::optional<RoleData> GetMemberHighestRole(Snowflake guild_id, Snowflake user_id) const;
std::set<Snowflake> GetUsersInGuild(Snowflake id) const;
std::set<Snowflake> GetChannelsInGuild(Snowflake id) const;
@@ -162,6 +163,11 @@ public:
});
}
+ template<typename Iter>
+ std::vector<UserData> GetUsersBulk(Iter begin, Iter end) {
+ return m_store.GetUsersBulk(begin, end);
+ }
+
// FetchGuildBans fetches all bans+reasons via api, this func fetches stored bans (so usually just GUILD_BAN_ADD data)
std::vector<BanData> GetBansInGuild(Snowflake guild_id);
void FetchGuildBan(Snowflake guild_id, Snowflake user_id, const sigc::slot<void(BanData)> &callback);
diff --git a/src/discord/store.hpp b/src/discord/store.hpp
index b6979d0..ccf8d3e 100644
--- a/src/discord/store.hpp
+++ b/src/discord/store.hpp
@@ -11,6 +11,9 @@
#endif
class Store {
+private:
+ class Statement;
+
public:
Store(bool mem_store = false);
~Store();
@@ -51,6 +54,48 @@ public:
std::unordered_set<Snowflake> GetMembersInGuild(Snowflake guild_id) const;
// ^ not the same as GetUsersInGuild since users in a guild may include users who do not have retrieved member data
+ template<typename Iter>
+ std::vector<UserData> GetUsersBulk(Iter begin, Iter end) {
+ const int size = std::distance(begin, end);
+ if (size == 0) return {};
+
+ std::string query = "SELECT * FROM USERS WHERE id IN (";
+ for (int i = 0; i < size; i++) {
+ query += "?, ";
+ }
+ query.resize(query.size() - 2); // chop off extra ", "
+ query += ")";
+
+ Statement s(m_db, query.c_str());
+ if (!s.OK()) {
+ printf("failed to prepare bulk users: %s\n", m_db.ErrStr());
+ return {};
+ }
+
+ for (int i = 0; begin != end; i++, begin++) {
+ s.Bind(i, *begin);
+ }
+
+ std::vector<UserData> r;
+ r.reserve(size);
+ while (s.FetchOne()) {
+ UserData u;
+ s.Get(0, u.ID);
+ s.Get(1, u.Username);
+ s.Get(2, u.Discriminator);
+ s.Get(3, u.Avatar);
+ s.Get(4, u.IsBot);
+ s.Get(5, u.IsSystem);
+ s.Get(6, u.IsMFAEnabled);
+ s.Get(7, u.PremiumType);
+ s.Get(8, u.PublicFlags);
+ s.Get(9, u.GlobalName);
+ r.push_back(u);
+ }
+ printf("fetched %llu\n", r.size());
+ return r;
+ }
+
void AddReaction(const MessageReactionAddObject &data, bool byself);
void RemoveReaction(const MessageReactionRemoveObject &data, bool byself);
@@ -69,7 +114,6 @@ public:
void EndTransaction();
private:
- class Statement;
class Database {
public:
Database(const char *path);