diff options
-rw-r--r-- | src/discord/discord.cpp | 16 | ||||
-rw-r--r-- | src/discord/store.cpp | 42 | ||||
-rw-r--r-- | src/discord/store.hpp | 4 |
3 files changed, 53 insertions, 9 deletions
diff --git a/src/discord/discord.cpp b/src/discord/discord.cpp index 2808e17..d7cee4c 100644 --- a/src/discord/discord.cpp +++ b/src/discord/discord.cpp @@ -341,12 +341,10 @@ bool DiscordClient::HasChannelPermission(Snowflake user_id, Snowflake channel_id } Permission DiscordClient::ComputePermissions(Snowflake member_id, Snowflake guild_id) const { - const auto member = GetMember(member_id, guild_id); - const auto guild = GetGuild(guild_id); - if (!member.has_value() || !guild.has_value()) - return Permission::NONE; + const auto member_roles = m_store.GetMemberRoles(guild_id, member_id); + const auto guild_owner = m_store.GetGuildOwner(guild_id); - if (guild->OwnerID == member_id) + if (guild_owner == member_id) return Permission::ALL; const auto everyone = GetRole(guild_id); @@ -354,7 +352,7 @@ Permission DiscordClient::ComputePermissions(Snowflake member_id, Snowflake guil return Permission::NONE; Permission perms = everyone->Permissions; - for (const auto role_id : member->Roles) { + for (const auto role_id : member_roles) { const auto role = GetRole(role_id); if (role.has_value()) perms |= role->Permissions; @@ -371,8 +369,8 @@ Permission DiscordClient::ComputeOverwrites(Permission base, Snowflake member_id return Permission::ALL; const auto channel = GetChannel(channel_id); - const auto member = GetMember(member_id, *channel->GuildID); - if (!member.has_value() || !channel.has_value()) + const auto member_roles = m_store.GetMemberRoles(*channel->GuildID, member_id); + if (!channel.has_value()) return Permission::NONE; Permission perms = base; @@ -384,7 +382,7 @@ Permission DiscordClient::ComputeOverwrites(Permission base, Snowflake member_id Permission allow = Permission::NONE; Permission deny = Permission::NONE; - for (const auto role_id : member->Roles) { + for (const auto role_id : member_roles) { const auto overwrite = GetPermissionOverwrite(channel_id, role_id); if (overwrite.has_value()) { allow |= overwrite->Allow; diff --git a/src/discord/store.cpp b/src/discord/store.cpp index 892f4aa..7f674c4 100644 --- a/src/discord/store.cpp +++ b/src/discord/store.cpp @@ -473,6 +473,40 @@ std::vector<BanData> Store::GetBans(Snowflake guild_id) const { return ret; } +Snowflake Store::GetGuildOwner(Snowflake guild_id) const { + auto &s = m_stmt_get_guild_owner; + + s->Bind(1, guild_id); + if (s->FetchOne()) { + Snowflake ret; + s->Get(0, ret); + s->Reset(); + return ret; + } + + s->Reset(); + + return Snowflake::Invalid; +} + +std::vector<Snowflake> Store::GetMemberRoles(Snowflake guild_id, Snowflake user_id) const { + std::vector<Snowflake> ret; + + auto &s = m_stmt_get_member_roles; + + s->Bind(1, user_id); + s->Bind(2, guild_id); + + while (s->FetchOne()) { + auto &f = ret.emplace_back(); + s->Get(0, f); + } + + s->Reset(); + + return ret; +} + std::vector<Message> Store::GetLastMessages(Snowflake id, size_t num) const { auto &s = m_stmt_get_last_msgs; std::vector<Message> msgs; @@ -2198,6 +2232,14 @@ bool Store::CreateStatements() { return false; } + m_stmt_get_guild_owner = std::make_unique<Statement>(m_db, R"( + SELECT owner_id FROM guilds WHERE id = ? + )"); + if (!m_stmt_get_guild_owner->OK()) { + fprintf(stderr, "failed to prepare get guild owner statement: %s\n", m_db.ErrStr()); + return false; + } + return true; } diff --git a/src/discord/store.hpp b/src/discord/store.hpp index da97dd5..f1e2f05 100644 --- a/src/discord/store.hpp +++ b/src/discord/store.hpp @@ -39,6 +39,9 @@ public: std::optional<BanData> GetBan(Snowflake guild_id, Snowflake user_id) const; std::vector<BanData> GetBans(Snowflake guild_id) const; + Snowflake GetGuildOwner(Snowflake guild_id) const; + std::vector<Snowflake> GetMemberRoles(Snowflake guild_id, Snowflake user_id) const; + std::vector<Message> GetLastMessages(Snowflake id, size_t num) const; std::vector<Message> GetMessagesBefore(Snowflake channel_id, Snowflake message_id, size_t limit) const; std::vector<Message> GetPinnedMessages(Snowflake channel_id) const; @@ -308,5 +311,6 @@ private: STMT(get_chan_ids_parent); STMT(get_guild_member_ids); STMT(clr_role); + STMT(get_guild_owner); #undef STMT }; |