diff --git a/common/pb/CMakeLists.txt b/common/pb/CMakeLists.txt index adf1e1b8..e698c95f 100644 --- a/common/pb/CMakeLists.txt +++ b/common/pb/CMakeLists.txt @@ -123,6 +123,7 @@ SET(PROTO_FILES response_join_room.proto response_list_users.proto response_login.proto + response_register.proto response_replay_download.proto response_replay_list.proto response.proto diff --git a/common/pb/response.proto b/common/pb/response.proto index 0f13c415..302b1801 100644 --- a/common/pb/response.proto +++ b/common/pb/response.proto @@ -24,6 +24,14 @@ message Response { RespAccessDenied = 20; RespUsernameInvalid = 21; RespRegistrationRequired = 22; + RespRegistrationAccepted = 23; // Server agrees to process client's registration request + RespUserAlreadyExists = 24; // Client attempted to register a name which is already registered + RespEmailRequiredToRegister = 25; // Server requires email to register accounts but client did not provide one + RespServerDoesNotUseAuth = 26; // Client attempted to register but server does not use authentication + RespTooManyRequests = 27; // Server refused to complete command because client has sent too many too quickly + RespAccountNotActivated = 28; // Client attempted to log into a registered username but the account hasn't been activated + RespRegistrationDisabled = 29; // Server does not allow clients to register + RespRegistrationFailed = 30; // Server accepted a reg request but failed to perform the registration } enum ResponseType { JOIN_ROOM = 1000; @@ -35,6 +43,7 @@ message Response { DECK_LIST = 1006; DECK_DOWNLOAD = 1007; DECK_UPLOAD = 1008; + REGISTER = 1009; REPLAY_LIST = 1100; REPLAY_DOWNLOAD = 1101; } diff --git a/common/pb/response_register.proto b/common/pb/response_register.proto new file mode 100644 index 00000000..9c6998ef --- /dev/null +++ b/common/pb/response_register.proto @@ -0,0 +1,9 @@ +import "response.proto"; + +message Response_Register { + extend Response { + optional Response_Register ext = 1009; + } + optional string denied_reason_str = 1; + optional uint64 denied_end_time = 2; +} \ No newline at end of file diff --git a/common/pb/session_commands.proto b/common/pb/session_commands.proto index bbb5e81d..be5cc1db 100644 --- a/common/pb/session_commands.proto +++ b/common/pb/session_commands.proto @@ -1,3 +1,5 @@ +import "serverinfo_user.proto"; + message SessionCommand { enum SessionCommandType { PING = 1000; @@ -16,6 +18,7 @@ message SessionCommand { DECK_UPLOAD = 1013; LIST_ROOMS = 1014; JOIN_ROOM = 1015; + REGISTER = 1016; REPLAY_LIST = 1100; REPLAY_DOWNLOAD = 1101; REPLAY_MODIFY_MATCH = 1102; @@ -94,3 +97,21 @@ message Command_JoinRoom { } optional uint32 room_id = 1; } + +// User wants to register a new account +message Command_Register { + extend SessionCommand { + optional Command_Register ext = 1016; + } + // User name client wants to register + required string user_name = 1; + // Hashed password to be inserted into database + required string password = 2; + // Email address of the client for user validation + optional string email = 3; + // Gender of the user + optional ServerInfo_User.Gender gender = 4; + // Country code of the user. 2 letter ISO format + optional string country = 5; + optional string real_name = 6; +} diff --git a/common/server.cpp b/common/server.cpp index 4b65739b..7fd083e2 100644 --- a/common/server.cpp +++ b/common/server.cpp @@ -175,6 +175,45 @@ AuthenticationResult Server::loginUser(Server_ProtocolHandler *session, QString return authState; } +RegistrationResult Server::registerUserAccount(const QString &ipAddress, const Command_Register &cmd, QString &banReason, int &banSecondsRemaining) +{ + // TODO + + if (!registrationEnabled) + return RegistrationDisabled; + + QString emailAddress = QString::fromStdString(cmd.email()); + if (requireEmailForRegistration && emailAddress.isEmpty()) + return EmailRequired; + + Server_DatabaseInterface *databaseInterface = getDatabaseInterface(); + + // TODO: Move this method outside of the db interface + QString userName = QString::fromStdString(cmd.user_name()); + if (!databaseInterface->usernameIsValid(userName)) + return InvalidUsername; + + if (databaseInterface->checkUserIsBanned(ipAddress, userName, banReason, banSecondsRemaining)) + return ClientIsBanned; + + if (tooManyRegistrationAttempts(ipAddress)) + return TooManyRequests; + + QString realName = QString::fromStdString(cmd.real_name()); + ServerInfo_User_Gender gender = cmd.gender(); + QString country = QString::fromStdString(cmd.country()); + QString passwordSha512 = QString::fromStdString(cmd.password()); + bool regSucceeded = databaseInterface->registerUser(userName, realName, gender, passwordSha512, emailAddress, country, false); + + return regSucceeded ? Accepted : Failed; +} + +bool Server::tooManyRegistrationAttempts(const QString &ipAddress) +{ + // TODO: implement + return false; +} + void Server::addPersistentPlayer(const QString &userName, int roomId, int gameId, int playerId) { QWriteLocker locker(&persistentPlayersLock); diff --git a/common/server.h b/common/server.h index 896f6ef8..ca039350 100644 --- a/common/server.h +++ b/common/server.h @@ -7,6 +7,7 @@ #include #include #include +#include "pb/commands.pb.h" #include "pb/serverinfo_user.pb.h" #include "server_player_reference.h" @@ -28,6 +29,7 @@ class CommandContainer; class Command_JoinGame; enum AuthenticationResult { NotLoggedIn = 0, PasswordRight = 1, UnknownUser = 2, WouldOverwriteOldSession = 3, UserIsBanned = 4, UsernameInvalid = 5, RegistrationRequired = 6 }; +enum RegistrationResult { Accepted = 0, UserAlreadyExists = 1, EmailRequired = 2, UnauthenticatedServer = 3, TooManyRequests = 4, InvalidUsername = 5, ClientIsBanned = 6, RegistrationDisabled = 7, Failed = 8}; class Server : public QObject { @@ -44,6 +46,19 @@ public: ~Server(); void setThreaded(bool _threaded) { threaded = _threaded; } AuthenticationResult loginUser(Server_ProtocolHandler *session, QString &name, const QString &password, QString &reason, int &secondsLeft); + + /** + * Registers a user account. + * @param ipAddress The address of the connection from the user + * @param userName The username to attempt to register + * @param emailAddress The email address to associate with the new account (and to use for activation) + * @param banReason If the client is banned, the reason for the ban will be included in this string. + * @param banSecondsRemaining If the client is banned, the time left will be included in this. 0 if the ban is permanent. + * @return RegistrationResult member indicating whether it succeeded or failed. + */ + RegistrationResult registerUserAccount(const QString &ipAddress, const Command_Register &cmd, QString &banReason, int &banSecondsRemaining); + + bool tooManyRegistrationAttempts(const QString &ipAddress); const QMap &getRooms() { return rooms; } Server_AbstractUserInterface *findUser(const QString &userName) const; @@ -115,6 +130,9 @@ protected: int getUsersCount() const; int getGamesCount() const; void addRoom(Server_Room *newRoom); + + bool registrationEnabled; + bool requireEmailForRegistration; }; #endif diff --git a/common/server_database_interface.h b/common/server_database_interface.h index 518bd5e6..2568f195 100644 --- a/common/server_database_interface.h +++ b/common/server_database_interface.h @@ -13,6 +13,7 @@ public: : QObject(parent) { } virtual AuthenticationResult checkUserPassword(Server_ProtocolHandler *handler, const QString &user, const QString &password, QString &reasonStr, int &secondsLeft) = 0; + virtual bool checkUserIsBanned(const QString &ipAddress, const QString &userName, QString &banReason, int &banSecondsRemaining) { return false; } virtual bool userExists(const QString & /* user */) { return false; } virtual QMap getBuddyList(const QString & /* name */) { return QMap(); } virtual QMap getIgnoreList(const QString & /* name */) { return QMap(); } @@ -23,6 +24,7 @@ public: virtual DeckList *getDeckFromDatabase(int /* deckId */, int /* userId */) { return 0; } virtual qint64 startSession(const QString & /* userName */, const QString & /* address */) { return 0; } + virtual bool usernameIsValid(const QString &userName) { return true; }; public slots: virtual void endSession(qint64 /* sessionId */ ) { } public: @@ -35,9 +37,11 @@ public: virtual bool userSessionExists(const QString & /* userName */) { return false; } virtual bool getRequireRegistration() { return false; } + virtual bool registerUser(const QString &userName, const QString &realName, ServerInfo_User_Gender const &gender, const QString &passwordSha512, const QString &emailAddress, const QString &country, bool active = false) { return false; } enum LogMessage_TargetType { MessageTargetRoom, MessageTargetGame, MessageTargetChat, MessageTargetIslRoom }; virtual void logMessage(const int /* senderId */, const QString & /* senderName */, const QString & /* senderIp */, const QString & /* logMessage */, LogMessage_TargetType /* targetType */, const int /* targetId */, const QString & /* targetName */) { }; + bool checkUserIsBanned(Server_ProtocolHandler *session, QString &banReason, int &banSecondsRemaining); }; #endif diff --git a/common/server_protocolhandler.cpp b/common/server_protocolhandler.cpp index 46d79139..cc5dd7f9 100644 --- a/common/server_protocolhandler.cpp +++ b/common/server_protocolhandler.cpp @@ -9,6 +9,7 @@ #include "pb/commands.pb.h" #include "pb/response.pb.h" #include "pb/response_login.pb.h" +#include "pb/response_register.pb.h" #include "pb/response_list_users.pb.h" #include "pb/response_get_games_of_user.pb.h" #include "pb/response_get_user_info.pb.h" @@ -134,12 +135,17 @@ Response::ResponseCode Server_ProtocolHandler::processSessionCommandContainer(co SessionCommand debugSc(sc); debugSc.MutableExtension(Command_Login::ext)->clear_password(); logDebugMessage(QString::fromStdString(debugSc.ShortDebugString())); + } else if (num == SessionCommand::REGISTER) { + SessionCommand logSc(sc); + logSc.MutableExtension(Command_Register::ext)->clear_password(); + logDebugMessage(QString::fromStdString(logSc.ShortDebugString())); } else logDebugMessage(QString::fromStdString(sc.ShortDebugString())); } switch ((SessionCommand::SessionCommandType) num) { case SessionCommand::PING: resp = cmdPing(sc.GetExtension(Command_Ping::ext), rc); break; case SessionCommand::LOGIN: resp = cmdLogin(sc.GetExtension(Command_Login::ext), rc); break; + case SessionCommand::REGISTER: resp = cmdRegisterAccount(sc.GetExtension(Command_Register::ext), rc); break; case SessionCommand::MESSAGE: resp = cmdMessage(sc.GetExtension(Command_Message::ext), rc); break; case SessionCommand::GET_GAMES_OF_USER: resp = cmdGetGamesOfUser(sc.GetExtension(Command_GetGamesOfUser::ext), rc); break; case SessionCommand::GET_USER_INFO: resp = cmdGetUserInfo(sc.GetExtension(Command_GetUserInfo::ext), rc); break; @@ -413,6 +419,49 @@ Response::ResponseCode Server_ProtocolHandler::cmdLogin(const Command_Login &cmd return Response::RespOk; } +Response::ResponseCode Server_ProtocolHandler::cmdRegisterAccount(const Command_Register &cmd, ResponseContainer &rc) +{ + qDebug() << "Got register command: " << QString::fromStdString(cmd.user_name()); + + QString banReason; + int banSecondsRemaining; + RegistrationResult result = + server->registerUserAccount( + this->getAddress(), + cmd, + banReason, + banSecondsRemaining); + qDebug() << "Register command result:" << result; + + switch (result) { + case RegistrationDisabled: + return Response::RespRegistrationDisabled; + case Accepted: + return Response::RespRegistrationAccepted; + case UserAlreadyExists: + return Response::RespUserAlreadyExists; + case EmailRequired: + return Response::RespEmailRequiredToRegister; + case UnauthenticatedServer: + return Response::RespServerDoesNotUseAuth; + case TooManyRequests: + return Response::RespTooManyRequests; + case InvalidUsername: + return Response::RespUsernameInvalid; + case Failed: + return Response::RespRegistrationFailed; + case ClientIsBanned: + Response_Register *re = new Response_Register; + re->set_denied_reason_str(banReason.toStdString()); + if (banSecondsRemaining != 0) + re->set_denied_end_time(QDateTime::currentDateTime().addSecs(banSecondsRemaining).toTime_t()); + rc.setResponseExtension(re); + return Response::RespUserIsBanned; + } + + return Response::RespInvalidCommand; +} + Response::ResponseCode Server_ProtocolHandler::cmdMessage(const Command_Message &cmd, ResponseContainer &rc) { if (authState == NotLoggedIn) diff --git a/common/server_protocolhandler.h b/common/server_protocolhandler.h index 5389b724..7ea34cf8 100644 --- a/common/server_protocolhandler.h +++ b/common/server_protocolhandler.h @@ -28,6 +28,7 @@ class AdminCommand; class Command_Ping; class Command_Login; +class Command_Register; class Command_Message; class Command_ListUsers; class Command_GetGamesOfUser; @@ -59,6 +60,7 @@ private: Response::ResponseCode cmdPing(const Command_Ping &cmd, ResponseContainer &rc); Response::ResponseCode cmdLogin(const Command_Login &cmd, ResponseContainer &rc); + Response::ResponseCode cmdRegisterAccount(const Command_Register &cmd, ResponseContainer &rc); Response::ResponseCode cmdMessage(const Command_Message &cmd, ResponseContainer &rc); Response::ResponseCode cmdGetGamesOfUser(const Command_GetGamesOfUser &cmd, ResponseContainer &rc); Response::ResponseCode cmdGetUserInfo(const Command_GetUserInfo &cmd, ResponseContainer &rc); diff --git a/servatrice/scripts/.gitignore b/servatrice/scripts/.gitignore new file mode 100644 index 00000000..eb47a082 --- /dev/null +++ b/servatrice/scripts/.gitignore @@ -0,0 +1 @@ +pypb/ diff --git a/servatrice/scripts/mk_pypb.sh b/servatrice/scripts/mk_pypb.sh new file mode 100755 index 00000000..e11d237d --- /dev/null +++ b/servatrice/scripts/mk_pypb.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +SRC_DIR=../../common/pb/ +DST_DIR=./pypb + +rm -rf "$DST_DIR" +mkdir -p "$DST_DIR" +protoc -I=$SRC_DIR --python_out=$DST_DIR $SRC_DIR/*.proto +touch "$DST_DIR/__init__.py" + diff --git a/servatrice/scripts/register.py b/servatrice/scripts/register.py new file mode 100755 index 00000000..427c9e72 --- /dev/null +++ b/servatrice/scripts/register.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python + +import socket, sys, struct, time + +from pypb.server_message_pb2 import ServerMessage +from pypb.session_commands_pb2 import Command_Register as Reg +from pypb.commands_pb2 import CommandContainer as Cmd +from pypb.event_server_identification_pb2 import Event_ServerIdentification as ServerId +from pypb.response_pb2 import Response + +HOST = "localhost" +PORT = 4747 + +CMD_ID = 1 + +def build_reg(): + global CMD_ID + cmd = Cmd() + sc = cmd.session_command.add() + + reg = sc.Extensions[Reg.ext] + reg.user_name = "testUser" + reg.email = "test@example.com" + reg.password = "password" + + cmd.cmd_id = CMD_ID + CMD_ID += 1 + return cmd + +def send(msg): + packed = struct.pack('>I', len(msg)) + sock.sendall(packed) + sock.sendall(msg) + +def print_resp(resp): + print "<<<" + print repr(resp) + m = ServerMessage() + m.ParseFromString(bytes(resp)) + print m + +def recv(sock): + print "< header" + header = sock.recv(4) + msg_size = struct.unpack('>I', header)[0] + print "< ", msg_size + raw_msg = sock.recv(msg_size) + print_resp(raw_msg) + +if __name__ == "__main__": + address = (HOST, PORT) + sock = socket.socket() + + print "Connecting to server ", address + sock.connect(address) + + # hack for old xml clients - server expects this and discards first message + print ">>> xml hack" + xmlClientHack = Cmd().SerializeToString() + send(xmlClientHack) + print sock.recv(60) + + recv(sock) + + print ">>> register" + r = build_reg() + print r + msg = r.SerializeToString() + send(msg) + recv(sock) + + print "Done" + diff --git a/servatrice/servatrice.ini.example b/servatrice/servatrice.ini.example index bd1cad55..f75a8958 100644 --- a/servatrice/servatrice.ini.example +++ b/servatrice/servatrice.ini.example @@ -54,6 +54,13 @@ password=123456 ; Accept only registered users? default is 0 (accept unregistered users) regonly=0 +[registration] + +; Servatrice can process registration requests to add new users on the fly. +; Enable this feature? Default false. +;enabled=false +; Require users to provide an email address in order to register. Default true. +;requireemail=true [database] diff --git a/servatrice/src/servatrice.cpp b/servatrice/src/servatrice.cpp index 66a9c977..e4430912 100644 --- a/servatrice/src/servatrice.cpp +++ b/servatrice/src/servatrice.cpp @@ -160,6 +160,13 @@ bool Servatrice::initServer() authenticationMethod = AuthenticationNone; } + registrationEnabled = settingsCache->value("registration/enabled", false).toBool(); + requireEmailForRegistration = settingsCache->value("registration/requireemail", true).toBool(); + + qDebug() << "Registration enabled: " << registrationEnabled; + if (registrationEnabled) + qDebug() << "Require email address to register: " << requireEmailForRegistration; + QString dbTypeStr = settingsCache->value("database/type").toString(); if (dbTypeStr == "mysql") databaseType = DatabaseMySql; @@ -172,12 +179,17 @@ bool Servatrice::initServer() if (databaseType != DatabaseNone) { settingsCache->beginGroup("database"); dbPrefix = settingsCache->value("prefix").toString(); - servatriceDatabaseInterface->initDatabase("QMYSQL", - settingsCache->value("hostname").toString(), - settingsCache->value("database").toString(), - settingsCache->value("user").toString(), - settingsCache->value("password").toString()); + bool dbOpened = + servatriceDatabaseInterface->initDatabase("QMYSQL", + settingsCache->value("hostname").toString(), + settingsCache->value("database").toString(), + settingsCache->value("user").toString(), + settingsCache->value("password").toString()); settingsCache->endGroup(); + if (!dbOpened) { + qDebug() << "Failed to open database"; + return false; + } updateServerList(); @@ -342,7 +354,7 @@ bool Servatrice::initServer() if (gameServer->listen(QHostAddress::Any, gamePort)) qDebug() << "Server listening."; else { - qDebug() << "gameServer->listen(): Error."; + qDebug() << "gameServer->listen(): Error:" << gameServer->errorString(); return false; } return true; diff --git a/servatrice/src/servatrice_database_interface.cpp b/servatrice/src/servatrice_database_interface.cpp index 77f31156..aabf2a2a 100644 --- a/servatrice/src/servatrice_database_interface.cpp +++ b/servatrice/src/servatrice_database_interface.cpp @@ -9,6 +9,7 @@ #include #include #include +#include Servatrice_DatabaseInterface::Servatrice_DatabaseInterface(int _instanceId, Servatrice *_server) : instanceId(_instanceId), @@ -34,7 +35,9 @@ void Servatrice_DatabaseInterface::initDatabase(const QSqlDatabase &_sqlDatabase } } -void Servatrice_DatabaseInterface::initDatabase(const QString &type, const QString &hostName, const QString &databaseName, const QString &userName, const QString &password) +bool Servatrice_DatabaseInterface::initDatabase(const QString &type, const QString &hostName, + const QString &databaseName, const QString &userName, + const QString &password) { sqlDatabase = QSqlDatabase::addDatabase(type, "main"); sqlDatabase.setHostName(hostName); @@ -42,7 +45,7 @@ void Servatrice_DatabaseInterface::initDatabase(const QString &type, const QStri sqlDatabase.setUserName(userName); sqlDatabase.setPassword(password); - openDatabase(); + return openDatabase(); } bool Servatrice_DatabaseInterface::openDatabase() @@ -102,11 +105,52 @@ bool Servatrice_DatabaseInterface::usernameIsValid(const QString &user) return re.exactMatch(user); } +// TODO move this to Server bool Servatrice_DatabaseInterface::getRequireRegistration() { return settingsCache->value("authentication/regonly", 0).toBool(); } +bool Servatrice_DatabaseInterface::registerUser(const QString &userName, const QString &realName, ServerInfo_User_Gender const &gender, const QString &passwordSha512, const QString &emailAddress, const QString &country, bool active) +{ + if (!checkSql()) + return false; + + QSqlQuery *query = prepareQuery("insert into {prefix}_users " + "(name, realname, gender, password_sha512, email, country, registrationDate, active) " + "values " + "(:userName, :realName, :gender, :password_sha512, :email, :country, UTC_TIMESTAMP(), :active)"); + query->bindValue(":userName", userName); + query->bindValue(":realName", realName); + query->bindValue(":gender", getGenderChar(gender)); + query->bindValue(":password_sha512", passwordSha512); + query->bindValue(":email", emailAddress); + query->bindValue(":country", country); + query->bindValue(":active", active ? 1 : 0); + + if (!execSqlQuery(query)) { + qDebug() << "Failed to insert user: " << query->lastError() << " sql: " << query->lastQuery(); + // TODO handle duplicate insert error + return false; + } + + return true; +} + +QChar Servatrice_DatabaseInterface::getGenderChar(ServerInfo_User_Gender const &gender) +{ + switch (gender) { + case ServerInfo_User_Gender_GenderUnknown: + return QChar('u'); + case ServerInfo_User_Gender_Male: + return QChar('m'); + case ServerInfo_User_Gender_Female: + return QChar('f'); + default: + return QChar('u'); + } +} + AuthenticationResult Servatrice_DatabaseInterface::checkUserPassword(Server_ProtocolHandler *handler, const QString &user, const QString &password, QString &reasonStr, int &banSecondsLeft) { switch (server->getAuthenticationMethod()) { @@ -125,43 +169,8 @@ AuthenticationResult Servatrice_DatabaseInterface::checkUserPassword(Server_Prot if (!usernameIsValid(user)) return UsernameInvalid; - QSqlQuery *ipBanQuery = prepareQuery("select time_to_sec(timediff(now(), date_add(b.time_from, interval b.minutes minute))), b.minutes <=> 0, b.visible_reason from {prefix}_bans b where b.time_from = (select max(c.time_from) from {prefix}_bans c where c.ip_address = :address) and b.ip_address = :address2"); - ipBanQuery->bindValue(":address", static_cast(handler)->getPeerAddress().toString()); - ipBanQuery->bindValue(":address2", static_cast(handler)->getPeerAddress().toString()); - if (!execSqlQuery(ipBanQuery)) { - qDebug("Login denied: SQL error"); - return NotLoggedIn; - } - - if (ipBanQuery->next()) { - const int secondsLeft = -ipBanQuery->value(0).toInt(); - const bool permanentBan = ipBanQuery->value(1).toInt(); - if ((secondsLeft > 0) || permanentBan) { - reasonStr = ipBanQuery->value(2).toString(); - banSecondsLeft = permanentBan ? 0 : secondsLeft; - qDebug("Login denied: banned by address"); - return UserIsBanned; - } - } - - QSqlQuery *nameBanQuery = prepareQuery("select time_to_sec(timediff(now(), date_add(b.time_from, interval b.minutes minute))), b.minutes <=> 0, b.visible_reason from {prefix}_bans b where b.time_from = (select max(c.time_from) from {prefix}_bans c where c.user_name = :name2) and b.user_name = :name1"); - nameBanQuery->bindValue(":name1", user); - nameBanQuery->bindValue(":name2", user); - if (!execSqlQuery(nameBanQuery)) { - qDebug("Login denied: SQL error"); - return NotLoggedIn; - } - - if (nameBanQuery->next()) { - const int secondsLeft = -nameBanQuery->value(0).toInt(); - const bool permanentBan = nameBanQuery->value(1).toInt(); - if ((secondsLeft > 0) || permanentBan) { - reasonStr = nameBanQuery->value(2).toString(); - banSecondsLeft = permanentBan ? 0 : secondsLeft; - qDebug("Login denied: banned by name"); - return UserIsBanned; - } - } + if (checkUserIsBanned(handler->getAddress(), user, reasonStr, banSecondsLeft)) + return UserIsBanned; QSqlQuery *passwordQuery = prepareQuery("select password_sha512 from {prefix}_users where name = :name and active = 1"); passwordQuery->bindValue(":name", user); @@ -188,6 +197,79 @@ AuthenticationResult Servatrice_DatabaseInterface::checkUserPassword(Server_Prot return UnknownUser; } +bool Servatrice_DatabaseInterface::checkUserIsBanned(const QString &ipAddress, const QString &userName, QString &banReason, int &banSecondsRemaining) +{ + if (server->getAuthenticationMethod() != Servatrice::AuthenticationSql) + return false; + + if (!checkSql()) { + qDebug("Failed to check if user is banned. Database invalid."); + return false; + } + + return + checkUserIsIpBanned(ipAddress, banReason, banSecondsRemaining) + || checkUserIsNameBanned(userName, banReason, banSecondsRemaining); + +} + +bool Servatrice_DatabaseInterface::checkUserIsNameBanned(const QString &userName, QString &banReason, int &banSecondsRemaining) +{ + QSqlQuery *nameBanQuery = prepareQuery("select time_to_sec(timediff(now(), date_add(b.time_from, interval b.minutes minute))), b.minutes <=> 0, b.visible_reason from {prefix}_bans b where b.time_from = (select max(c.time_from) from {prefix}_bans c where c.user_name = :name2) and b.user_name = :name1"); + nameBanQuery->bindValue(":name1", userName); + nameBanQuery->bindValue(":name2", userName); + if (!execSqlQuery(nameBanQuery)) { + qDebug() << "Name ban check failed: SQL error" << nameBanQuery->lastError(); + return false; + } + + if (nameBanQuery->next()) { + const int secondsLeft = -nameBanQuery->value(0).toInt(); + const bool permanentBan = nameBanQuery->value(1).toInt(); + if ((secondsLeft > 0) || permanentBan) { + banReason = nameBanQuery->value(2).toString(); + banSecondsRemaining = permanentBan ? 0 : secondsLeft; + qDebug() << "Username" << userName << "is banned by name"; + return true; + } + } + return false; +} + +bool Servatrice_DatabaseInterface::checkUserIsIpBanned(const QString &ipAddress, QString &banReason, int &banSecondsRemaining) +{ + QSqlQuery *ipBanQuery = prepareQuery( + "select" + " time_to_sec(timediff(now(), date_add(b.time_from, interval b.minutes minute)))," + " b.minutes <=> 0," + " b.visible_reason" + " from {prefix}_bans b" + " where" + " b.time_from = (select max(c.time_from)" + " from {prefix}_bans c" + " where c.ip_address = :address)" + " and b.ip_address = :address2"); + + ipBanQuery->bindValue(":address", ipAddress); + ipBanQuery->bindValue(":address2", ipAddress); + if (!execSqlQuery(ipBanQuery)) { + qDebug() << "IP ban check failed: SQL error." << ipBanQuery->lastError(); + return false; + } + + if (ipBanQuery->next()) { + const int secondsLeft = -ipBanQuery->value(0).toInt(); + const bool permanentBan = ipBanQuery->value(1).toInt(); + if ((secondsLeft > 0) || permanentBan) { + banReason = ipBanQuery->value(2).toString(); + banSecondsRemaining = permanentBan ? 0 : secondsLeft; + qDebug() << "User is banned by address" << ipAddress; + return true; + } + } + return false; +} + bool Servatrice_DatabaseInterface::userExists(const QString &user) { if (server->getAuthenticationMethod() == Servatrice::AuthenticationSql) { @@ -579,4 +661,4 @@ void Servatrice_DatabaseInterface::logMessage(const int senderId, const QString query->bindValue(":target_id", (targetType == MessageTargetChat && targetId < 1) ? QVariant() : targetId); query->bindValue(":target_name", targetName); execSqlQuery(query); -} +} \ No newline at end of file diff --git a/servatrice/src/servatrice_database_interface.h b/servatrice/src/servatrice_database_interface.h index 8850d7be..d352ba7a 100644 --- a/servatrice/src/servatrice_database_interface.h +++ b/servatrice/src/servatrice_database_interface.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "server.h" #include "server_database_interface.h" @@ -18,15 +19,22 @@ private: QHash preparedStatements; Servatrice *server; ServerInfo_User evalUserQueryResult(const QSqlQuery *query, bool complete, bool withId = false); - bool usernameIsValid(const QString &user); + /** Must be called after checkSql and server is known to be in auth mode. */ + bool checkUserIsIpBanned(const QString &ipAddress, QString &banReason, int &banSecondsRemaining); + /** Must be called after checkSql and server is known to be in auth mode. */ + bool checkUserIsNameBanned(QString const &userName, QString &banReason, int &banSecondsRemaining); + QChar getGenderChar(ServerInfo_User_Gender const &gender); protected: + bool usernameIsValid(const QString &user); AuthenticationResult checkUserPassword(Server_ProtocolHandler *handler, const QString &user, const QString &password, QString &reasonStr, int &secondsLeft); + bool checkUserIsBanned(const QString &ipAddress, const QString &userName, QString &banReason, int &banSecondsRemaining); public slots: void initDatabase(const QSqlDatabase &_sqlDatabase); public: Servatrice_DatabaseInterface(int _instanceId, Servatrice *_server); ~Servatrice_DatabaseInterface(); - void initDatabase(const QString &type, const QString &hostName, const QString &databaseName, const QString &userName, const QString &password); + bool initDatabase(const QString &type, const QString &hostName, const QString &databaseName, + const QString &userName, const QString &password); bool openDatabase(); bool checkSql(); QSqlQuery * prepareQuery(const QString &queryText); @@ -55,6 +63,7 @@ public: bool userSessionExists(const QString &userName); bool getRequireRegistration(); + bool registerUser(const QString &userName, const QString &realName, ServerInfo_User_Gender const &gender, const QString &passwordSha512, const QString &emailAddress, const QString &country, bool active = false); void logMessage(const int senderId, const QString &senderName, const QString &senderIp, const QString &logMessage, LogMessage_TargetType targetType, const int targetId, const QString &targetName); };