diff options
Diffstat (limited to 'discord/store.hpp')
-rw-r--r-- | discord/store.hpp | 343 |
1 files changed, 234 insertions, 109 deletions
diff --git a/discord/store.hpp b/discord/store.hpp index f84d13e..776b701 100644 --- a/discord/store.hpp +++ b/discord/store.hpp @@ -21,15 +21,13 @@ public: void SetUser(Snowflake id, const UserData &user); void SetChannel(Snowflake id, const ChannelData &chan); void SetGuild(Snowflake id, const GuildData &guild); - void SetRole(Snowflake id, const RoleData &role); + void SetRole(Snowflake guild_id, const RoleData &role); void SetMessage(Snowflake id, const Message &message); void SetGuildMember(Snowflake guild_id, Snowflake user_id, const GuildMember &data); void SetPermissionOverwrite(Snowflake channel_id, Snowflake id, const PermissionOverwrite &perm); void SetEmoji(Snowflake id, const EmojiData &emoji); void SetBan(Snowflake guild_id, Snowflake user_id, const BanData &ban); - // slap const on everything even tho its not *really* const - std::optional<ChannelData> GetChannel(Snowflake id) const; std::optional<EmojiData> GetEmoji(Snowflake id) const; std::optional<GuildData> GetGuild(Snowflake id) const; @@ -42,25 +40,20 @@ public: std::vector<BanData> GetBans(Snowflake guild_id) const; std::vector<Message> GetLastMessages(Snowflake id, size_t num) const; - std::vector<Snowflake> GetChannelMessageIDs(Snowflake id) const; + std::vector<Message> GetMessagesBefore(Snowflake channel_id, Snowflake message_id, size_t limit) const; std::vector<Message> GetPinnedMessages(Snowflake channel_id) const; std::vector<ChannelData> GetActiveThreads(Snowflake channel_id) const; // public + void AddReaction(const MessageReactionAddObject &data, bool byself); + void RemoveReaction(const MessageReactionRemoveObject &data, bool byself); + void ClearGuild(Snowflake id); void ClearChannel(Snowflake id); void ClearBan(Snowflake guild_id, Snowflake user_id); + void ClearRecipient(Snowflake channel_id, Snowflake user_id); - using users_type = std::unordered_map<Snowflake, UserData>; - using channels_type = std::unordered_map<Snowflake, ChannelData>; - using guilds_type = std::unordered_map<Snowflake, GuildData>; - using roles_type = std::unordered_map<Snowflake, RoleData>; - using messages_type = std::unordered_map<Snowflake, Message>; - using members_type = std::unordered_map<Snowflake, std::unordered_map<Snowflake, GuildMember>>; // [guild][user] - using permission_overwrites_type = std::unordered_map<Snowflake, std::unordered_map<Snowflake, PermissionOverwrite>>; // [channel][user/role] - using emojis_type = std::unordered_map<Snowflake, EmojiData>; - - const std::unordered_set<Snowflake> &GetChannels() const; - const std::unordered_set<Snowflake> &GetGuilds() const; + std::unordered_set<Snowflake> GetChannels() const; + std::unordered_set<Snowflake> GetGuilds() const; void ClearAll(); @@ -68,105 +61,237 @@ public: void EndTransaction(); private: - Message GetMessageBound(sqlite3_stmt *stmt) const; + class Statement; + class Database { + public: + Database(const char *path); + ~Database(); - void SetMessageInteractionPair(Snowflake message_id, const MessageInteractionData &interaction); + int Close(); + int StartTransaction(); + int EndTransaction(); + int Execute(const char *command); + int Error() const; + bool OK() const; + const char *ErrStr() const; + int SetError(int err); + sqlite3 *obj(); + + private: + sqlite3 *m_db; + int m_err = SQLITE_OK; + mutable char m_err_scratch[256] { 0 }; + + // stupid shit i dont like to allow closing properly + using type_signal_close = sigc::signal<void>; + type_signal_close m_signal_close; + + public: + type_signal_close signal_close(); + }; + + class Statement { + public: + Statement() = delete; + Statement(const Statement &other) = delete; + Statement(Database &db, const char *command); + ~Statement(); + Statement &operator=(Statement &other) = delete; + + bool OK() const; + + int Bind(int index, int32_t num); + int Bind(int index, uint32_t num); + int Bind(int index, uint64_t num); + int Bind(int index, const char *str, size_t len = -1); + int Bind(int index, const std::string &str); + int Bind(int index, bool val); + int Bind(int index); + + template<typename T> + int Bind(int index, std::optional<T> opt) { + if (opt.has_value()) + return Bind(index, opt.value()); + else + return Bind(index); + } + + template<typename Iter> + int BindIDsAsJSON(int index, Iter start, Iter end) { + std::vector<Snowflake> x; + for (Iter it = start; it != end; it++) { + x.push_back((*it).ID); + } + return Bind(index, nlohmann::json(x).dump()); + } + + template<typename T> + int BindAsJSONArray(int index, const std::optional<T> &obj) { + if (obj.has_value()) + return Bind(index, nlohmann::json(obj.value()).dump()); + else + return Bind(index, std::string("[]")); + } + + template<typename T> + int BindAsJSON(int index, const std::optional<T> &obj) { + if (obj.has_value()) + return Bind(index, nlohmann::json(obj.value()).dump()); + else + return Bind(index); + } + + template<typename T> + int BindAsJSON(int index, const T &obj) { + return Bind(index, nlohmann::json(obj).dump()); + } - std::unordered_set<Snowflake> m_channels; - std::unordered_set<Snowflake> m_guilds; + template<typename T> + inline typename std::enable_if<std::is_enum<T>::value, int>::type + Bind(int index, T val) { + return Bind(index, static_cast<std::underlying_type<T>::type>(val)); + } + + void Get(int index, uint8_t &out) const; + void Get(int index, int32_t &out) const; + void Get(int index, uint64_t &out) const; + void Get(int index, bool &out) const; + void Get(int index, Snowflake &out) const; + void Get(int index, std::string &out) const; + + template<typename T> + void GetJSON(int index, std::optional<T> &out) const { + if (IsNull(index)) + out = std::nullopt; + else { + std::string stuff; + Get(index, stuff); + if (stuff == "") + out = std::nullopt; + else + out = nlohmann::json::parse(stuff).get<T>(); + } + } + + template<typename T> + void GetJSON(int index, T &out) const { + std::string stuff; + Get(index, stuff); + nlohmann::json::parse(stuff).get_to(out); + } + + template<typename T> + void Get(int index, std::optional<T> &out) const { + if (IsNull(index)) + out = std::nullopt; + else { + T tmp; + Get(index, tmp); + out = std::optional<T>(std::move(tmp)); + } + } + + template<typename T> + inline typename std::enable_if<std::is_enum<T>::value, void>::type + Get(int index, T &val) const { + typename std::underlying_type<T>::type tmp; + Get(index, tmp); + val = static_cast<T>(tmp); + } + + template<typename T> + void GetIDOnlyStructs(int index, std::optional<std::vector<T>> &out) const { + out.emplace(); + std::string str; + Get(index, str); + for (const auto &id : nlohmann::json::parse(str)) + out->emplace_back().ID = id.get<Snowflake>(); + } + + template<typename T, typename OutputIt> + void GetArray(int index, OutputIt first) const { + std::string str; + Get(index, str); + for (const auto &id : nlohmann::json::parse(str)) + *first++ = id.get<T>(); + } + + bool IsNull(int index) const; + int Step(); + bool Insert(); + bool FetchOne(); + int Reset(); + + sqlite3_stmt *obj(); + + private: + Database *m_db; + sqlite3_stmt *m_stmt; + }; + + Message GetMessageBound(std::unique_ptr<Statement> &stmt) const; + + void SetMessageInteractionPair(Snowflake message_id, const MessageInteractionData &interaction); bool CreateTables(); bool CreateStatements(); - void Cleanup(); - - template<typename T> - void Bind(sqlite3_stmt *stmt, int index, const std::optional<T> &opt) const; - - template<typename T> - typename std::enable_if<std::is_enum<T>::value, void>::type - Bind(sqlite3_stmt *stmt, int index, T val) const; - - void Bind(sqlite3_stmt *stmt, int index, int num) const; - void Bind(sqlite3_stmt *stmt, int index, uint64_t num) const; - void Bind(sqlite3_stmt *stmt, int index, const std::string &str) const; - void Bind(sqlite3_stmt *stmt, int index, bool val) const; - void Bind(sqlite3_stmt *stmt, int index, std::nullptr_t) const; - bool RunInsert(sqlite3_stmt *stmt); - bool FetchOne(sqlite3_stmt *stmt) const; - - template<typename T> - void Get(sqlite3_stmt *stmt, int index, std::optional<T> &out) const; - - template<typename T> - typename std::enable_if<std::is_enum<T>::value, void>::type - Get(sqlite3_stmt *stmt, int index, T &out) const; - - void Get(sqlite3_stmt *stmt, int index, int &out) const; - void Get(sqlite3_stmt *stmt, int index, uint64_t &out) const; - void Get(sqlite3_stmt *stmt, int index, std::string &out) const; - void Get(sqlite3_stmt *stmt, int index, bool &out) const; - void Get(sqlite3_stmt *stmt, int index, Snowflake &out) const; - bool IsNull(sqlite3_stmt *stmt, int index) const; - void Reset(sqlite3_stmt *stmt) const; + + bool m_ok = true; std::filesystem::path m_db_path; - mutable sqlite3 *m_db; - mutable int m_db_err; - mutable sqlite3_stmt *m_set_user_stmt; - mutable sqlite3_stmt *m_get_user_stmt; - mutable sqlite3_stmt *m_set_perm_stmt; - mutable sqlite3_stmt *m_get_perm_stmt; - mutable sqlite3_stmt *m_set_msg_stmt; - mutable sqlite3_stmt *m_get_msg_stmt; - mutable sqlite3_stmt *m_set_role_stmt; - mutable sqlite3_stmt *m_get_role_stmt; - mutable sqlite3_stmt *m_set_emote_stmt; - mutable sqlite3_stmt *m_get_emote_stmt; - mutable sqlite3_stmt *m_set_member_stmt; - mutable sqlite3_stmt *m_get_member_stmt; - mutable sqlite3_stmt *m_set_guild_stmt; - mutable sqlite3_stmt *m_get_guild_stmt; - mutable sqlite3_stmt *m_set_chan_stmt; - mutable sqlite3_stmt *m_get_chan_stmt; - mutable sqlite3_stmt *m_set_ban_stmt; - mutable sqlite3_stmt *m_get_ban_stmt; - mutable sqlite3_stmt *m_clear_ban_stmt; - mutable sqlite3_stmt *m_get_bans_stmt; - mutable sqlite3_stmt *m_set_msg_interaction_stmt; - mutable sqlite3_stmt *m_get_last_msgs_stmt; - mutable sqlite3_stmt *m_get_msg_ids_stmt; - mutable sqlite3_stmt *m_get_pins_stmt; - mutable sqlite3_stmt *m_get_threads_stmt; - mutable sqlite3_stmt *m_clear_chan_stmt; + Database m_db; +#define STMT(x) mutable std::unique_ptr<Statement> m_stmt_##x + STMT(set_guild); + STMT(get_guild); + STMT(get_guild_ids); + STMT(clr_guild); + STMT(set_chan); + STMT(get_chan); + STMT(get_chan_ids); + STMT(clr_chan); + STMT(set_msg); + STMT(get_msg); + STMT(set_msg_ref); + STMT(get_last_msgs); + STMT(set_user); + STMT(get_user); + STMT(set_member); + STMT(get_member); + STMT(set_role); + STMT(get_role); + STMT(set_emoji); + STMT(get_emoji); + STMT(set_perm); + STMT(get_perm); + STMT(set_ban); + STMT(get_ban); + STMT(get_bans); + STMT(clr_ban); + STMT(set_interaction); + STMT(set_member_roles); + STMT(get_member_roles); + STMT(set_guild_emoji); + STMT(get_guild_emojis); + STMT(clr_guild_emoji); + STMT(set_guild_feature); + STMT(get_guild_features); + STMT(get_guild_chans); + STMT(set_thread); + STMT(get_threads); + STMT(get_active_threads); + STMT(get_messages_before); + STMT(get_pins); + STMT(set_emoji_role); + STMT(get_emoji_roles); + STMT(set_mention); + STMT(get_mentions); + STMT(set_attachment); + STMT(get_attachments); + STMT(set_recipient); + STMT(get_recipients); + STMT(clr_recipient); + STMT(add_reaction); + STMT(sub_reaction); + STMT(get_reactions); +#undef STMT }; - -template<typename T> -inline void Store::Bind(sqlite3_stmt *stmt, int index, const std::optional<T> &opt) const { - if (opt.has_value()) - Bind(stmt, index, *opt); - else - sqlite3_bind_null(stmt, index); -} - -template<typename T> -inline typename std::enable_if<std::is_enum<T>::value, void>::type -Store::Bind(sqlite3_stmt *stmt, int index, T val) const { - Bind(stmt, index, static_cast<typename std::underlying_type<T>::type>(val)); -} - -template<typename T> -inline void Store::Get(sqlite3_stmt *stmt, int index, std::optional<T> &out) const { - if (sqlite3_column_type(stmt, index) == SQLITE_NULL) - out = std::nullopt; - else { - T v; - Get(stmt, index, v); - out = std::optional<T>(v); - } -} - -template<typename T> -inline typename std::enable_if<std::is_enum<T>::value, void>::type -Store::Get(sqlite3_stmt *stmt, int index, T &out) const { - out = static_cast<T>(sqlite3_column_int(stmt, index)); -} |