From 4a9cc19b23554f6e959ddb26d2697aa4b8aed569 Mon Sep 17 00:00:00 2001 From: dijunkun Date: Wed, 11 Oct 2023 16:03:22 +0800 Subject: [PATCH] Add feature password --- .../remote_desk/remote_desk_gui/main.cpp | 4 +- application/signal_server/signal_server.cpp | 37 ++++++-- .../signal_server/transmission_manager.cpp | 54 +++++++++++ .../signal_server/transmission_manager.h | 7 ++ src/interface/x.h | 6 +- src/pc/peer_connection.cpp | 91 +++++++++++-------- src/pc/peer_connection.h | 10 +- src/rtc/x_inner.cpp | 17 ++-- 8 files changed, 168 insertions(+), 58 deletions(-) diff --git a/application/remote_desk/remote_desk_gui/main.cpp b/application/remote_desk/remote_desk_gui/main.cpp index 60862f7..e40c1b5 100644 --- a/application/remote_desk/remote_desk_gui/main.cpp +++ b/application/remote_desk/remote_desk_gui/main.cpp @@ -496,7 +496,7 @@ int main() { std::string user_id = "S-" + std::string(GetMac(mac_addr)); if (strcmp(online_label, "Online") == 0) { - CreateConnection(peer_server, mac_addr); + CreateConnection(peer_server, mac_addr, server_password); nv12_buffer_ = new char[NV12_BUFFER_SIZE]; #ifdef _WIN32 @@ -572,7 +572,7 @@ int main() { if (ImGui::Button(connect_label)) { if (strcmp(connect_label, "Connect") == 0 && !joined) { std::string user_id = "C-" + std::string(GetMac(mac_addr)); - JoinConnection(peer_client, remote_id); + JoinConnection(peer_client, remote_id, client_password); joined = true; } else if (strcmp(connect_label, "Disconnect") == 0 && joined) { LeaveConnection(peer_client); diff --git a/application/signal_server/signal_server.cpp b/application/signal_server/signal_server.cpp index 1a1cd33..d26308d 100644 --- a/application/signal_server/signal_server.cpp +++ b/application/signal_server/signal_server.cpp @@ -130,6 +130,7 @@ void SignalServer::on_message(websocketpp::connection_hdl hdl, switch (HASH_STRING_PIECE(type.c_str())) { case "create_transmission"_H: { std::string transmission_id = j["transmission_id"].get(); + std::string password = j["password"].get(); std::string user_id = j["user_id"].get(); LOG_INFO("Receive user id [{}] create transmission request with id [{}]", user_id, transmission_id); @@ -151,6 +152,8 @@ void SignalServer::on_message(websocketpp::connection_hdl hdl, transmission_manager_.BindUserIdToTransmission(user_id, transmission_id); transmission_manager_.BindUserIdToWsHandle(user_id, hdl); + transmission_manager_.BindPasswordToTransmission(password, + transmission_id); LOG_INFO("Create transmission id [{}]", transmission_id); json message = {{"type", "transmission_id"}, @@ -190,16 +193,36 @@ void SignalServer::on_message(websocketpp::connection_hdl hdl, } case "query_user_id_list"_H: { std::string transmission_id = j["transmission_id"].get(); - std::vector user_id_list = - transmission_manager_.GetAllUserIdOfTransmission(transmission_id); + std::string password = j["password"].get(); - json message = {{"type", "user_id_list"}, - {"transmission_id", transmission_id}, - {"user_id_list", user_id_list}, - {"status", "success"}}; + if (transmission_manager_.CheckPassword(password, transmission_id)) { + std::vector user_id_list = + transmission_manager_.GetAllUserIdOfTransmission(transmission_id); + + json message = {{"type", "user_id_list"}, + {"transmission_id", transmission_id}, + {"user_id_list", user_id_list}, + {"status", "success"}}; + + send_msg(hdl, message); + } else { + std::vector user_id_list; + json message = {{"type", "user_id_list"}, + {"transmission_id", transmission_id}, + {"user_id_list", user_id_list}, + {"status", "failed"}, + {"reason", "Incorrect password"}}; + // LOG_INFO( + // "Incorrect password [{}] for transmission [{}] with password is " + // "[{}]", + // password, transmission_id, + // transmission_manager_.GetPassword(transmission_id)); + + send_msg(hdl, message); + } // LOG_INFO("Send member_list: [{}]", message.dump()); - send_msg(hdl, message); + break; } case "offer"_H: { diff --git a/application/signal_server/transmission_manager.cpp b/application/signal_server/transmission_manager.cpp index 97c561c..deb8580 100644 --- a/application/signal_server/transmission_manager.cpp +++ b/application/signal_server/transmission_manager.cpp @@ -40,6 +40,25 @@ bool TransmissionManager::BindUserIdToTransmission( return true; } +bool TransmissionManager::BindPasswordToTransmission( + const std::string& password, const std::string& transmission_id) { + if (transmission_password_list_.find(transmission_id) == + transmission_password_list_.end()) { + transmission_password_list_[transmission_id] = password; + // LOG_INFO("Bind password [{}] to transmission [{}]", password, + // transmission_id); + return true; + } else { + auto old_password = transmission_password_list_[transmission_id]; + transmission_password_list_[transmission_id] = password; + // LOG_WARN("Update password [{}] to [{}] for transmission [{}]", + // old_password, password, transmission_id); + return true; + } + + return false; +} + bool TransmissionManager::BindUserIdToWsHandle( const std::string& user_id, websocketpp::connection_hdl hdl) { if (user_id_ws_hdl_list_.find(user_id) != user_id_ws_hdl_list_.end()) { @@ -94,6 +113,19 @@ bool TransmissionManager::ReleaseAllUserIdFromTransmission( return true; } +bool TransmissionManager::ReleasePasswordFromTransmission( + const std::string& transmission_id) { + if (transmission_password_list_.end() == + transmission_password_list_.find(transmission_id)) { + LOG_ERROR("No transmission with id [{}]", transmission_id); + return false; + } + + transmission_password_list_.erase(transmission_id); + + return true; +} + websocketpp::connection_hdl TransmissionManager::GetWsHandle( const std::string& user_id) { if (user_id_ws_hdl_list_.find(user_id) != user_id_ws_hdl_list_.end()) { @@ -111,4 +143,26 @@ std::string TransmissionManager::GetUserId(websocketpp::connection_hdl hdl) { if (it->second.lock().get() == hdl.lock().get()) return it->first; } return ""; +} + +bool TransmissionManager::CheckPassword(const std::string& password, + const std::string& transmission_id) { + if (transmission_password_list_.find(transmission_id) == + transmission_password_list_.end()) { + LOG_ERROR("No transmission with id [{}]", transmission_id); + return false; + } + + return transmission_password_list_[transmission_id] == password; +} + +std::string TransmissionManager::GetPassword( + const std::string& transmission_id) { + if (transmission_password_list_.find(transmission_id) == + transmission_password_list_.end()) { + LOG_ERROR("No transmission with id [{}]", transmission_id); + return ""; + } + + return transmission_password_list_[transmission_id]; } \ No newline at end of file diff --git a/application/signal_server/transmission_manager.h b/application/signal_server/transmission_manager.h index f262748..499e4ff 100644 --- a/application/signal_server/transmission_manager.h +++ b/application/signal_server/transmission_manager.h @@ -17,17 +17,24 @@ class TransmissionManager { public: bool BindUserIdToTransmission(const std::string& user_id, const std::string& transmission_id); + bool BindPasswordToTransmission(const std::string& password, + const std::string& transmission_id); bool BindUserIdToWsHandle(const std::string& user_id, websocketpp::connection_hdl hdl); std::string ReleaseUserIdFromTransmission(websocketpp::connection_hdl hdl); bool ReleaseAllUserIdFromTransmission(const std::string& transmission_id); + bool ReleasePasswordFromTransmission(const std::string& transmission_id); websocketpp::connection_hdl GetWsHandle(const std::string& user_id); std::string GetUserId(websocketpp::connection_hdl hdl); + bool CheckPassword(const std::string& password, + const std::string& transmission_id); + std::string GetPassword(const std::string& transmission_id); private: std::map> transmission_user_id_list_; + std::map transmission_password_list_; std::map user_id_ws_hdl_list_; }; diff --git a/src/interface/x.h b/src/interface/x.h index 5c45f98..349067c 100644 --- a/src/interface/x.h +++ b/src/interface/x.h @@ -29,9 +29,11 @@ PeerPtr* CreatePeer(const Params* params); int Init(PeerPtr* peer_ptr, const char* user_id); -int CreateConnection(PeerPtr* peer_ptr, const char* transmission_id); +int CreateConnection(PeerPtr* peer_ptr, const char* transmission_id, + const char* password); -int JoinConnection(PeerPtr* peer_ptr, const char* transmission_id); +int JoinConnection(PeerPtr* peer_ptr, const char* transmission_id, + const char* password); int LeaveConnection(PeerPtr* peer_ptr); diff --git a/src/pc/peer_connection.cpp b/src/pc/peer_connection.cpp index 084387e..a4db350 100644 --- a/src/pc/peer_connection.cpp +++ b/src/pc/peer_connection.cpp @@ -180,12 +180,17 @@ int PeerConnection::CreateVideoCodec(bool hardware_acceleration) { } int PeerConnection::Create(PeerConnectionParams params, - const std::string &transmission_id) { + const std::string &transmission_id, + const std::string &password) { int ret = 0; + password_ = password; + json message = {{"type", "create_transmission"}, {"user_id", user_id_}, - {"transmission_id", transmission_id}}; + {"transmission_id", transmission_id}, + {"password", password}}; + if (ws_transport_) { ws_transport_->Send(message.dump()); LOG_INFO("Send create transmission request, transmission_id [{}]", @@ -195,11 +200,14 @@ int PeerConnection::Create(PeerConnectionParams params, } int PeerConnection::Join(PeerConnectionParams params, - const std::string &transmission_id) { + const std::string &transmission_id, + const std::string &password) { int ret = 0; + password_ = password; + transmission_id_ = transmission_id; - ret = RequestTransmissionMemberList(transmission_id_); + ret = RequestTransmissionMemberList(transmission_id_, password); return ret; } @@ -246,41 +254,47 @@ void PeerConnection::ProcessSignal(const std::string &signal) { } case "user_id_list"_H: { user_id_list_ = j["user_id_list"]; - std::string transmission_id = j["transmission_id"]; - if (user_id_list_.empty()) { - LOG_WARN("Wait for host create transmission [{}]", transmission_id); - RequestTransmissionMemberList(transmission_id); - break; - } - - LOG_INFO("Transmission [{}] members: [", transmission_id); - for (auto user_id : user_id_list_) { - LOG_INFO("{}", user_id); - } - LOG_INFO("]"); - - for (auto &remote_user_id : user_id_list_) { - if (remote_user_id == user_id_) { - continue; + std::string transmission_id = j["transmission_id"].get(); + std::string status = j["status"].get(); + if (status == "failed") { + std::string reason = j["reason"].get(); + LOG_ERROR("{}", reason); + } else { + if (user_id_list_.empty()) { + LOG_WARN("Wait for host create transmission [{}]", transmission_id); + RequestTransmissionMemberList(transmission_id, password_); + break; } - ice_transmission_list_[remote_user_id] = - std::make_unique(true, transmission_id, user_id_, - remote_user_id, ws_transport_, - on_ice_status_change_); - ice_transmission_list_[remote_user_id]->SetOnReceiveVideoFunc( - on_receive_video_); - ice_transmission_list_[remote_user_id]->SetOnReceiveAudioFunc( - on_receive_audio_); - ice_transmission_list_[remote_user_id]->SetOnReceiveDataFunc( - on_receive_data_); + LOG_INFO("Transmission [{}] members: [", transmission_id); + for (auto user_id : user_id_list_) { + LOG_INFO("{}", user_id); + } + LOG_INFO("]"); - ice_transmission_list_[remote_user_id]->InitIceTransmission( - cfg_stun_server_ip_, stun_server_port_, cfg_turn_server_ip_, - turn_server_port_, cfg_turn_server_username_, - cfg_turn_server_password_); - ice_transmission_list_[remote_user_id]->JoinTransmission(); + for (auto &remote_user_id : user_id_list_) { + if (remote_user_id == user_id_) { + continue; + } + ice_transmission_list_[remote_user_id] = + std::make_unique(true, transmission_id, user_id_, + remote_user_id, ws_transport_, + on_ice_status_change_); + + ice_transmission_list_[remote_user_id]->SetOnReceiveVideoFunc( + on_receive_video_); + ice_transmission_list_[remote_user_id]->SetOnReceiveAudioFunc( + on_receive_audio_); + ice_transmission_list_[remote_user_id]->SetOnReceiveDataFunc( + on_receive_data_); + + ice_transmission_list_[remote_user_id]->InitIceTransmission( + cfg_stun_server_ip_, stun_server_port_, cfg_turn_server_ip_, + turn_server_port_, cfg_turn_server_username_, + cfg_turn_server_password_); + ice_transmission_list_[remote_user_id]->JoinTransmission(); + } } break; @@ -298,7 +312,7 @@ void PeerConnection::ProcessSignal(const std::string &signal) { if (std::string::npos != user_id.find("S-")) { LOG_INFO("Server leaves, try to rejoin transmission"); - RequestTransmissionMemberList(transmission_id_); + RequestTransmissionMemberList(transmission_id_, password_); } } break; @@ -365,11 +379,12 @@ void PeerConnection::ProcessSignal(const std::string &signal) { } int PeerConnection::RequestTransmissionMemberList( - const std::string &transmission_id) { + const std::string &transmission_id, const std::string &password) { LOG_INFO("Request member list"); json message = {{"type", "query_user_id_list"}, - {"transmission_id", transmission_id_}}; + {"transmission_id", transmission_id_}, + {"password", password}}; if (ws_transport_) { ws_transport_->Send(message.dump()); diff --git a/src/pc/peer_connection.h b/src/pc/peer_connection.h index a3663a3..c64e46b 100644 --- a/src/pc/peer_connection.h +++ b/src/pc/peer_connection.h @@ -34,9 +34,11 @@ class PeerConnection { int Init(PeerConnectionParams params, const std::string &user_id); int Create(PeerConnectionParams params, - const std::string &transmission_id = ""); + const std::string &transmission_id = "", + const std::string &password = ""); - int Join(PeerConnectionParams params, const std::string &transmission_id); + int Join(PeerConnectionParams params, const std::string &transmission_id, + const std::string &password = ""); int Leave(); @@ -53,7 +55,8 @@ class PeerConnection { void ProcessSignal(const std::string &signal); - int RequestTransmissionMemberList(const std::string &transmission_id); + int RequestTransmissionMemberList(const std::string &transmission_id, + const std::string &password); private: std::string uri_ = ""; @@ -98,6 +101,7 @@ class PeerConnection { OnReceiveBuffer on_receive_data_buffer_; char *nv12_data_ = nullptr; bool inited_ = false; + std::string password_; private: std::unique_ptr video_encoder_ = nullptr; diff --git a/src/rtc/x_inner.cpp b/src/rtc/x_inner.cpp index c42c54d..b9603d1 100644 --- a/src/rtc/x_inner.cpp +++ b/src/rtc/x_inner.cpp @@ -29,15 +29,20 @@ int Init(PeerPtr *peer_ptr, const char *user_id) { return 0; } -int CreateConnection(PeerPtr *peer_ptr, const char *transmission_id) { - peer_ptr->peer_connection->Create(peer_ptr->pc_params, transmission_id); - LOG_INFO("CreateConnection"); +int CreateConnection(PeerPtr *peer_ptr, const char *transmission_id, + const char *password) { + peer_ptr->peer_connection->Create(peer_ptr->pc_params, transmission_id, + password); + LOG_INFO("CreateConnection [{}] with password [{}]", transmission_id, + password); return 0; } -int JoinConnection(PeerPtr *peer_ptr, const char *transmission_id) { - peer_ptr->peer_connection->Join(peer_ptr->pc_params, transmission_id); - LOG_INFO("JoinConnection"); +int JoinConnection(PeerPtr *peer_ptr, const char *transmission_id, + const char *password) { + peer_ptr->peer_connection->Join(peer_ptr->pc_params, transmission_id, + password); + LOG_INFO("JoinConnection[{}] with password [{}]", transmission_id, password); return 0; }