diff options
Diffstat (limited to 'discord/store.cpp')
-rw-r--r-- | discord/store.cpp | 262 |
1 files changed, 160 insertions, 102 deletions
diff --git a/discord/store.cpp b/discord/store.cpp index e7699ba..7aa74c0 100644 --- a/discord/store.cpp +++ b/discord/store.cpp @@ -315,6 +315,82 @@ void Store::SetUser(Snowflake id, const UserData &user) { } } +Message Store::GetMessageBound(sqlite3_stmt *stmt) const { + Message ret; + Get(stmt, 0, ret.ID); + Get(stmt, 1, ret.ChannelID); + Get(stmt, 2, ret.GuildID); + Get(stmt, 3, ret.Author.ID); // yike + Get(stmt, 4, ret.Content); + Get(stmt, 5, ret.Timestamp); + Get(stmt, 6, ret.EditedTimestamp); + Get(stmt, 7, ret.IsTTS); + Get(stmt, 8, ret.DoesMentionEveryone); + std::string tmps; + Get(stmt, 9, tmps); + nlohmann::json::parse(tmps).get_to(ret.Mentions); + Get(stmt, 10, tmps); + nlohmann::json::parse(tmps).get_to(ret.Attachments); + Get(stmt, 11, tmps); + nlohmann::json::parse(tmps).get_to(ret.Embeds); + Get(stmt, 12, ret.IsPinned); + Get(stmt, 13, ret.WebhookID); + uint64_t tmpi; + Get(stmt, 14, tmpi); + ret.Type = static_cast<MessageType>(tmpi); + + Get(stmt, 15, tmps); + if (tmps != "") + ret.Application = nlohmann::json::parse(tmps).get<MessageApplicationData>(); + + Get(stmt, 16, tmps); + if (tmps != "") + ret.MessageReference = nlohmann::json::parse(tmps).get<MessageReferenceData>(); + + Get(stmt, 17, tmpi); + ret.Flags = static_cast<MessageFlags>(tmpi); + + Get(stmt, 18, tmps); + if (tmps != "") + ret.Stickers = nlohmann::json::parse(tmps).get<std::vector<StickerData>>(); + + Get(stmt, 19, tmps); + if (tmps != "") + ret.Reactions = nlohmann::json::parse(tmps).get<std::vector<ReactionData>>(); + + bool tmpb = false; + Get(stmt, 20, tmpb); + if (tmpb) ret.SetDeleted(); + + Get(stmt, 21, tmpb); + if (tmpb) ret.SetEdited(); + + Get(stmt, 22, ret.IsPending); + Get(stmt, 23, ret.Nonce); + + // interaction data from join + + if (!IsNull(stmt, 24)) { + auto &interaction = ret.Interaction.emplace(); + Get(stmt, 24, interaction.ID); + Get(stmt, 25, interaction.Name); + Get(stmt, 26, interaction.Type); + Get(stmt, 27, interaction.User.ID); + } + + Reset(stmt); + + if (ret.MessageReference.has_value() && ret.MessageReference->MessageID.has_value()) { + auto ref = GetMessage(*ret.MessageReference->MessageID); + if (ref.has_value()) + ret.ReferencedMessage = std::make_unique<Message>(std::move(*ref)); + else + ret.ReferencedMessage = nullptr; + } + + return ret; +} + void Store::SetMessageInteractionPair(Snowflake message_id, const MessageInteractionData &interaction) { Bind(m_set_msg_interaction_stmt, 1, message_id); Bind(m_set_msg_interaction_stmt, 2, interaction.ID); @@ -362,6 +438,31 @@ std::vector<BanData> Store::GetBans(Snowflake guild_id) const { return ret; } +std::vector<Message> Store::GetLastMessages(Snowflake id, size_t num) const { + auto ids = GetChannelMessageIDs(id); + std::vector<Message> ret; + for (auto it = ids.cend() - std::min(ids.size(), num); it != ids.cend(); it++) + ret.push_back(*GetMessage(*it)); + return ret; +} + +std::vector<Snowflake> Store::GetChannelMessageIDs(Snowflake id) const { + std::vector<Snowflake> ret; + Bind(m_get_msg_ids_stmt, 1, id); + + while (FetchOne(m_get_msg_ids_stmt)) { + Snowflake x; + Get(m_get_msg_ids_stmt, 0, x); + ret.push_back(x); + } + + Reset(m_get_msg_ids_stmt); + + if (m_db_err != SQLITE_DONE) + fprintf(stderr, "error while fetching ids: %s\n", sqlite3_errstr(m_db_err)); + return ret; +} + std::optional<ChannelData> Store::GetChannel(Snowflake id) const { Bind(m_get_chan_stmt, 1, id); if (!FetchOne(m_get_chan_stmt)) { @@ -537,77 +638,7 @@ std::optional<Message> Store::GetMessage(Snowflake id) const { return std::nullopt; } - Message ret; - ret.ID = id; - Get(m_get_msg_stmt, 1, ret.ChannelID); - Get(m_get_msg_stmt, 2, ret.GuildID); - Get(m_get_msg_stmt, 3, ret.Author.ID); // yike - Get(m_get_msg_stmt, 4, ret.Content); - Get(m_get_msg_stmt, 5, ret.Timestamp); - Get(m_get_msg_stmt, 6, ret.EditedTimestamp); - Get(m_get_msg_stmt, 7, ret.IsTTS); - Get(m_get_msg_stmt, 8, ret.DoesMentionEveryone); - std::string tmps; - Get(m_get_msg_stmt, 9, tmps); - nlohmann::json::parse(tmps).get_to(ret.Mentions); - Get(m_get_msg_stmt, 10, tmps); - nlohmann::json::parse(tmps).get_to(ret.Attachments); - Get(m_get_msg_stmt, 11, tmps); - nlohmann::json::parse(tmps).get_to(ret.Embeds); - Get(m_get_msg_stmt, 12, ret.IsPinned); - Get(m_get_msg_stmt, 13, ret.WebhookID); - uint64_t tmpi; - Get(m_get_msg_stmt, 14, tmpi); - ret.Type = static_cast<MessageType>(tmpi); - - Get(m_get_msg_stmt, 15, tmps); - if (tmps != "") - ret.Application = nlohmann::json::parse(tmps).get<MessageApplicationData>(); - - Get(m_get_msg_stmt, 16, tmps); - if (tmps != "") - ret.MessageReference = nlohmann::json::parse(tmps).get<MessageReferenceData>(); - - Get(m_get_msg_stmt, 17, tmpi); - ret.Flags = static_cast<MessageFlags>(tmpi); - - Get(m_get_msg_stmt, 18, tmps); - if (tmps != "") - ret.Stickers = nlohmann::json::parse(tmps).get<std::vector<StickerData>>(); - - Get(m_get_msg_stmt, 19, tmps); - if (tmps != "") - ret.Reactions = nlohmann::json::parse(tmps).get<std::vector<ReactionData>>(); - - bool tmpb = false; - Get(m_get_msg_stmt, 20, tmpb); - if (tmpb) ret.SetDeleted(); - - Get(m_get_msg_stmt, 21, tmpb); - if (tmpb) ret.SetEdited(); - - Get(m_get_msg_stmt, 22, ret.IsPending); - Get(m_get_msg_stmt, 23, ret.Nonce); - - // interaction data from join - - if (!IsNull(m_get_msg_stmt, 24)) { - auto &interaction = ret.Interaction.emplace(); - Get(m_get_msg_stmt, 24, interaction.ID); - Get(m_get_msg_stmt, 25, interaction.Name); - Get(m_get_msg_stmt, 26, interaction.Type); - Get(m_get_msg_stmt, 27, interaction.User.ID); - } - - Reset(m_get_msg_stmt); - - if (ret.MessageReference.has_value() && ret.MessageReference->MessageID.has_value()) { - auto ref = GetMessage(*ret.MessageReference->MessageID); - if (ref.has_value()) - ret.ReferencedMessage = std::make_unique<Message>(std::move(*ref)); - else - ret.ReferencedMessage = nullptr; - } + auto ret = GetMessageBound(m_get_msg_stmt); return std::optional<Message>(std::move(ret)); } @@ -731,7 +762,7 @@ void Store::EndTransaction() { } bool Store::CreateTables() { - constexpr const char *create_users = R"( + const char *create_users = R"( CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY, username TEXT NOT NULL, @@ -749,7 +780,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_permissions = R"( + const char *create_permissions = R"( CREATE TABLE IF NOT EXISTS permissions ( id INTEGER NOT NULL, channel_id INTEGER NOT NULL, @@ -760,7 +791,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_messages = R"( + const char *create_messages = R"( CREATE TABLE IF NOT EXISTS messages ( id INTEGER PRIMARY KEY, channel_id INTEGER NOT NULL, @@ -789,7 +820,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_roles = R"( + const char *create_roles = R"( CREATE TABLE IF NOT EXISTS roles ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, @@ -802,7 +833,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_emojis = R"( + const char *create_emojis = R"( CREATE TABLE IF NOT EXISTS emojis ( id INTEGER PRIMARY KEY, /*though nullable, only custom emojis (with non-null ids) are stored*/ name TEXT NOT NULL, /*same as id*/ @@ -815,7 +846,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_members = R"( + const char *create_members = R"( CREATE TABLE IF NOT EXISTS members ( user_id INTEGER NOT NULL, guild_id INTEGER NOT NULL, @@ -830,7 +861,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_guilds = R"( + const char *create_guilds = R"( CREATE TABLE IF NOT EXISTS guilds ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, @@ -874,7 +905,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_channels = R"( + const char *create_channels = R"( CREATE TABLE IF NOT EXISTS channels ( id INTEGER PRIMARY KEY, type INTEGER NOT NULL, @@ -897,7 +928,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_bans = R"( + const char *create_bans = R"( CREATE TABLE IF NOT EXISTS bans ( guild_id INTEGER NOT NULL, user_id INTEGER NOT NULL, @@ -906,7 +937,7 @@ bool Store::CreateTables() { ) )"; - constexpr const char *create_interactions = R"( + const char *create_interactions = R"( CREATE TABLE IF NOT EXISTS message_interactions ( message_id INTEGER NOT NULL, interaction_id INTEGER NOT NULL, @@ -981,33 +1012,33 @@ bool Store::CreateTables() { } bool Store::CreateStatements() { - constexpr const char *set_user = R"( + const char *set_user = R"( REPLACE INTO users VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) )"; - constexpr const char *get_user = R"( + const char *get_user = R"( SELECT * FROM users WHERE id = ? )"; - constexpr const char *set_perm = R"( + const char *set_perm = R"( REPLACE INTO permissions VALUES ( ?, ?, ?, ?, ? ) )"; - constexpr const char *get_perm = R"( + const char *get_perm = R"( SELECT * FROM permissions WHERE id = ? AND channel_id = ? )"; - constexpr const char *set_msg = R"( + const char *set_msg = R"( REPLACE INTO messages VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) )"; - constexpr const char *get_msg = R"( + const char *get_msg = R"( SELECT messages.*, message_interactions.interaction_id as interaction_id, message_interactions.name as interaction_name, @@ -1020,80 +1051,93 @@ bool Store::CreateStatements() { WHERE id = ? )"; - constexpr const char *set_role = R"( + const char *set_role = R"( REPLACE INTO roles VALUES ( ?, ?, ?, ?, ?, ?, ?, ? ) )"; - constexpr const char *get_role = R"( + const char *get_role = R"( SELECT * FROM roles WHERE id = ? )"; - constexpr const char *set_emoji = R"( + const char *set_emoji = R"( REPLACE INTO emojis VALUES ( ?, ?, ?, ?, ?, ?, ?, ? ) )"; - constexpr const char *get_emoji = R"( + const char *get_emoji = R"( SELECT * FROM emojis WHERE id = ? )"; - constexpr const char *set_member = R"( + const char *set_member = R"( REPLACE INTO members VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ? ) )"; - constexpr const char *get_member = R"( + const char *get_member = R"( SELECT * FROM members WHERE user_id = ? AND guild_id = ? )"; - constexpr const char *set_guild = R"( + const char *set_guild = R"( REPLACE INTO guilds VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) )"; - constexpr const char *get_guild = R"( + const char *get_guild = R"( SELECT * FROM guilds WHERE id = ? )"; - constexpr const char *set_chan = R"( + const char *set_chan = R"( REPLACE INTO channels VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) )"; - constexpr const char *get_chan = R"( + const char *get_chan = R"( SELECT * FROM channels WHERE id = ? )"; - constexpr const char *set_ban = R"( + const char *set_ban = R"( REPLACE INTO bans VALUES ( ?, ?, ? ) )"; - constexpr const char *get_ban = R"( + const char *get_ban = R"( SELECT * FROM bans WHERE guild_id = ? AND user_id = ? )"; - constexpr const char *clear_ban = R"( + const char *clear_ban = R"( DELETE FROM bans WHERE guild_id = ? AND user_id = ? )"; - constexpr const char *get_bans = R"( + const char *get_bans = R"( SELECT * FROM bans WHERE guild_id = ? )"; - constexpr const char *set_interaction = R"( + const char *set_interaction = R"( REPLACE INTO message_interactions VALUES ( ?, ?, ?, ?, ? ) )"; + const char *get_last_msgs = R"( + SELECT * FROM ( + SELECT * FROM messages + WHERE channel_id = ? + ORDER BY id DESC + LIMIT ? + ) T1 ORDER BY id ASC + )"; + + const char *get_msg_ids = R"( + SELECT id FROM messages WHERE channel_id = ? ORDER BY id ASC + )"; + m_db_err = sqlite3_prepare_v2(m_db, set_user, -1, &m_set_user_stmt, nullptr); if (m_db_err != SQLITE_OK) { fprintf(stderr, "failed to prepare set user statement: %s\n", sqlite3_errstr(m_db_err)); @@ -1220,6 +1264,18 @@ bool Store::CreateStatements() { return false; } + m_db_err = sqlite3_prepare_v2(m_db, get_last_msgs, -1, &m_get_last_msgs_stmt, nullptr); + if (m_db_err != SQLITE_OK) { + fprintf(stderr, "failed to prepare get last messages statement: %s\n", sqlite3_errstr(m_db_err)); + return false; + } + + m_db_err = sqlite3_prepare_v2(m_db, get_msg_ids, -1, &m_get_msg_ids_stmt, nullptr); + if (m_db_err != SQLITE_OK) { + fprintf(stderr, "failed to prepare get msg ids statement: %s\n", sqlite3_errstr(m_db_err)); + return false; + } + return true; } @@ -1245,6 +1301,8 @@ void Store::Cleanup() { sqlite3_finalize(m_clear_ban_stmt); sqlite3_finalize(m_get_bans_stmt); sqlite3_finalize(m_set_msg_interaction_stmt); + sqlite3_finalize(m_get_last_msgs_stmt); + sqlite3_finalize(m_get_msg_ids_stmt); } void Store::Bind(sqlite3_stmt *stmt, int index, int num) const { |