This commit is contained in:
2026-05-21 18:24:46 +08:00
parent e5c843334e
commit 4d347f1747
3 changed files with 245 additions and 247 deletions

View File

@@ -2,10 +2,8 @@ import cv2
import mediapipe as mp import mediapipe as mp
import time import time
import numpy as np import numpy as np
import threading
import queue import queue
import multiprocessing as mp_proc import multiprocessing as mp_proc
from multiprocessing import shared_memory
from collections import deque from collections import deque
from geometry_utils import ( from geometry_utils import (
calculate_ear, calculate_ear,
@@ -65,46 +63,20 @@ class MonitorSystem:
self.current_emotion = "Neutral" self.current_emotion = "Neutral"
self.frame_shape = (720, 1280, 3) self.frame_shape = (720, 1280, 3)
frame_size = int(np.prod(self.frame_shape))
# 必须先解除可能存在的残留 (Windows上有时不需要但保持好习惯) # 使用 spawn 避免 fork 复制 OpenCV/MediaPipe/ONNXRuntime 等 C++ 运行时状态。
# 最好是随机生成一个名字,确保每次运行都是新的 self.mp_ctx = mp_proc.get_context("spawn")
import secrets self.task_queue = self.mp_ctx.Queue(maxsize=2)
auth_key = secrets.token_hex(4) self.result_queue = self.mp_ctx.Queue(maxsize=2)
shm_unique_name = f"monitor_shm_{auth_key}"
try: self.worker_proc = self.mp_ctx.Process(
self.shm = shared_memory.SharedMemory(create=True, size=frame_size, name=shm_unique_name)
except FileExistsError:
# 如果真的点背碰上了,就 connect 这一块
self.shm = shared_memory.SharedMemory(name=shm_unique_name)
print(f"[Main] 共享内存已创建: {self.shm.name} (Size: {frame_size} bytes)")
# 本地 numpy 包装器
self.shared_frame_array = np.ndarray(
self.frame_shape, dtype=np.uint8, buffer=self.shm.buf
)
# 初始化为全黑,避免噪音
self.shared_frame_array.fill(0)
# 跨进程队列
self.task_queue = mp_proc.Queue(maxsize=2)
self.result_queue = mp_proc.Queue(maxsize=2) # 1就够了最新的覆盖
# 3. 启动进程
# Windows下传参只传名字字符串是安全的
self.worker_proc = mp_proc.Process(
target=background_worker_process, target=background_worker_process,
args=( args=(
self.shm.name,
self.frame_shape,
self.task_queue, self.task_queue,
self.result_queue, self.result_queue,
face_db, face_db,
), ),
) )
self.worker_proc.daemon = True
self.worker_proc.start() self.worker_proc.start()
def _get_smoothed_value(self, history, current_val): def _get_smoothed_value(self, history, current_val):
@@ -125,12 +97,6 @@ class MonitorSystem:
frame = cv2.resize(frame, (target_w, target_h)) frame = cv2.resize(frame, (target_w, target_h))
h, w = frame.shape[:2] h, w = frame.shape[:2]
# 现在肯定匹配了,放心写入
try:
self.shared_frame_array[:] = frame[:]
except Exception:
# 极端情况:数组形状不匹配 (比如通道数变了)
pass
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
@@ -257,9 +223,7 @@ class MonitorSystem:
max(0, sface_loc[3] - spad), max(0, sface_loc[3] - spad),
) )
if self.task_queue.full(): self._put_latest_task((0, frame.copy(), sface_loc))
self.task_queue.get()
self.task_queue.put((sface_loc, 0))
self.last_identity_check_time = now self.last_identity_check_time = now
@@ -283,20 +247,14 @@ class MonitorSystem:
y_max = min(h, y_max + pad_y) y_max = min(h, y_max + pad_y)
face_loc = (y_min, x_max, y_max, x_min) face_loc = (y_min, x_max, y_max, x_min)
face_crop = frame[y_min:y_max, x_min:x_max].copy()
if self.task_queue.full(): if face_crop.size > 0:
self.task_queue.get() self._put_latest_task((1, face_crop, None))
self.task_queue.put((face_loc, 1))
self.last_emotion_check_time = now self.last_emotion_check_time = now
while not self.result_queue.empty(): self._drain_results()
type_, data = self.result_queue.get()
if type_ == "identity":
self.current_user = data
elif type_ == "emotion":
self.cached_emotion["label"] = data.get("emotion", "unknown")
self.cached_emotion["va"] = data.get("vaVal", (0.0, 0.0))
analysis_data["identity"] = self.current_user analysis_data["identity"] = self.current_user
analysis_data["emotion_label"] = self.cached_emotion["label"] analysis_data["emotion_label"] = self.cached_emotion["label"]
@@ -304,6 +262,52 @@ class MonitorSystem:
return analysis_data return analysis_data
def _put_latest_task(self, task):
try:
if self.task_queue.full():
self.task_queue.get_nowait()
self.task_queue.put_nowait(task)
except queue.Full:
try:
self.task_queue.get_nowait()
self.task_queue.put_nowait(task)
except (queue.Empty, queue.Full):
pass
except queue.Empty:
pass
def _drain_results(self):
while True:
try:
type_, data = self.result_queue.get_nowait()
except queue.Empty:
break
if type_ == "identity":
self.current_user = data
elif type_ == "emotion":
self.cached_emotion["label"] = data.get("emotion", "unknown")
self.cached_emotion["va"] = data.get("vaVal", (0.0, 0.0))
def close(self):
try:
self._put_latest_task(None)
except Exception:
pass
if self.worker_proc.is_alive():
self.worker_proc.join(timeout=3)
if self.worker_proc.is_alive():
print("[Worker] 未正常退出,强制结束")
self.worker_proc.terminate()
self.worker_proc.join(timeout=2)
try:
self.face_mesh.close()
except Exception:
pass
# def _id_emo_loop(self): # def _id_emo_loop(self):
# while True: # while True:
# try: # try:
@@ -337,16 +341,10 @@ class MonitorSystem:
def background_worker_process( def background_worker_process(
shm_name, # 共享内存的名字
frame_shape, # 图像大小 (h, w, 3)
task_queue, # 任务队列 (主 -> 从) task_queue, # 任务队列 (主 -> 从)
result_queue, # 结果队列 (从 -> 主) result_queue, # 结果队列 (从 -> 主)
face_db_data, # 把人脸库数据传过去初始化 face_db_data, # 把人脸库数据传过去初始化
): ):
existing_shm = shared_memory.SharedMemory(name=shm_name)
# 创建 numpy 数组视图,无需复制数据
shared_frame = np.ndarray(frame_shape, dtype=np.uint8, buffer=existing_shm.buf)
print("[Worker] 正在加载模型...") print("[Worker] 正在加载模型...")
from face_library import FaceLibrary from face_library import FaceLibrary
@@ -362,39 +360,38 @@ def background_worker_process(
while True: while True:
try: try:
# 阻塞等待任务 task = task_queue.get()
# task_info = (task_type, face_loc) if task is None:
face_loc, task_type = task_queue.get() break
# 注意:这里读取的是共享内存里的图,不需要传图! task_type, frame_data, face_loc = task
# 切片操作也是零拷贝
# 为了安全,这里 copy 一份出来处理,避免主进程修改
# 但实际上如果主进程只写新帧,这里读旧帧也问题不大
# 为了绝对安全和解耦,我们假定主进程已经写入了对应的帧
# (实战技巧:通常我们会用一个信号量或多块共享内存来实现乒乓缓存)
# 简化版:我们直接从 shared_frame 读。
# 由于主进程跑得快可能SharedMemory里已经是下一帧了。
# 但对于识别身份来说,差一两帧根本没区别!这才是优化的精髓。
current_frame_view = shared_frame.copy() # .copy() 如果你怕读写冲突
if task_type == 0: # Identity if task_type == 0: # Identity
# RGB转换 rgb = cv2.cvtColor(frame_data, cv2.COLOR_BGR2RGB)
rgb = cv2.cvtColor(current_frame_view, cv2.COLOR_BGR2RGB)
res = face_lib.identify(rgb, face_location=face_loc) res = face_lib.identify(rgb, face_location=face_loc)
if res: if res:
result_queue.put(("identity", res["info"])) _put_latest_result(result_queue, ("identity", res["info"]))
elif task_type == 1 and has_emo: # Emotion elif task_type == 1 and has_emo: # Emotion
# BGR 直接切 if frame_data.size > 0:
roi = current_frame_view[ emo_res = analyze_emotion_with_hsemotion(frame_data)
face_loc[0] : face_loc[2], face_loc[3] : face_loc[1]
]
if roi.size > 0:
emo_res = analyze_emotion_with_hsemotion(roi)
if emo_res: if emo_res:
result_queue.put(("emotion", emo_res[0])) _put_latest_result(result_queue, ("emotion", emo_res[0]))
except Exception as e: except Exception as e:
print(f"[Worker Error] {e}") print(f"[Worker Error] {e}")
def _put_latest_result(result_queue, result):
try:
if result_queue.full():
result_queue.get_nowait()
result_queue.put_nowait(result)
except queue.Full:
try:
result_queue.get_nowait()
result_queue.put_nowait(result)
except (queue.Empty, queue.Full):
pass
except queue.Empty:
pass

View File

@@ -15,8 +15,8 @@ from webrtc_server import WebRTCServer
from HeartRateMonitor import HeartRateMonitor from HeartRateMonitor import HeartRateMonitor
from hook_mocker import HookMocker from hook_mocker import HookMocker
API_URL = "http://10.128.48.48:5000/api/states" API_URL = "http://10.128.48.204:5000/api/states"
CAMERA_ID = 5 CAMERA_ID = 3
BASIC_FACE_DB = { BASIC_FACE_DB = {
"Zhihang": {"name": "Zhihang Deng", "age": 20, "image-path": "zhihang.png"}, "Zhihang": {"name": "Zhihang Deng", "age": 20, "image-path": "zhihang.png"},
@@ -42,6 +42,21 @@ ana_data_queue = queue.Queue(maxsize=10)
stop_event = threading.Event() stop_event = threading.Event()
def put_latest(q, item):
try:
if q.full():
q.get_nowait()
q.put_nowait(item)
except queue.Full:
try:
q.get_nowait()
q.put_nowait(item)
except (queue.Empty, queue.Full):
pass
except queue.Empty:
pass
def capture_thread(): def capture_thread():
""" """
采集线程:优化了分发逻辑,对视频流进行降频处理 采集线程:优化了分发逻辑,对视频流进行降频处理
@@ -97,6 +112,7 @@ def analysis_thread():
2. 队列满时丢弃旧数据,保证数据实时性。 2. 队列满时丢弃旧数据,保证数据实时性。
""" """
monitor = MonitorSystem(BASIC_FACE_DB) monitor = MonitorSystem(BASIC_FACE_DB)
try:
print("[Analysis] 分析系统启动...") print("[Analysis] 分析系统启动...")
freq = 0 freq = 0
gap = 60 gap = 60
@@ -115,12 +131,8 @@ def analysis_thread():
result["eye_close_freq"] = 0 result["eye_close_freq"] = 0
result["heart_rate"] = 0 result["heart_rate"] = 0
if video_queue.full(): put_latest(video_queue, result["frame"])
video_queue.get_nowait() put_latest(ana_video_queue, result["frame"])
video_queue.put(result["frame"])
if ana_video_queue.full():
ana_video_queue.get_nowait()
ana_video_queue.put(result["frame"])
# print(f"[Analysis] {time.strftime('%Y-%m-%d %H:%M:%S')} - Frame processed") # print(f"[Analysis] {time.strftime('%Y-%m-%d %H:%M:%S')} - Frame processed")
payload = { payload = {
@@ -178,18 +190,6 @@ def analysis_thread():
"pose_2": result["pose"][2], "pose_2": result["pose"][2],
} }
) )
# elif result["has_face"]:
# payload.update(
# {
# "name": "Unknown",
# "ear": result["ear"],
# "mar": result["mar"],
# "iris_ratio": result["iris_ratio"],
# "pose": result["pose"],
# "emo_label": result["emotion_label"],
# "emo_va": result["emotion_va"],
# }
# )
if result["has_face"] and result["ear"] < 0.2: if result["has_face"] and result["ear"] < 0.2:
if status == 0: if status == 0:
@@ -212,32 +212,15 @@ def analysis_thread():
result["heart_rate"] = bpm result["heart_rate"] = bpm
payload["heart_rate"] = bpm payload["heart_rate"] = bpm
front_data["heart_rate"] = bpm front_data["heart_rate"] = bpm
if data_queue.full():
try:
_ = data_queue.get_nowait()
except queue.Empty:
pass
if front_data_queue.full():
try:
_ = front_data_queue.get_nowait()
except queue.Empty:
pass
if ana_data_queue.full():
try:
_ = ana_data_queue.get_nowait()
except queue.Empty:
pass
data_queue.put(payload) put_latest(data_queue, payload)
put_latest(ana_data_queue, payload)
ana_data_queue.put(payload) put_latest(front_data_queue, front_data)
put_latest(show_queue, (result["frame"], result))
front_data_queue.put(front_data)
show_queue.put((result["frame"], result))
# draw_debug_info(frame, result) # draw_debug_info(frame, result)
# cv2.imshow("Monitor Client", frame) # cv2.imshow("Monitor Client", frame)
finally:
monitor.close()
print("[Analysis] 分析线程结束") print("[Analysis] 分析线程结束")
@@ -288,6 +271,13 @@ def video_stream_thread(server):
# print("✅ [Video] H.264 软编码启动成功!视频将保存为 MP4。") # print("✅ [Video] H.264 软编码启动成功!视频将保存为 MP4。")
# ----------------------------------------------------------- # -----------------------------------------------------------
out1 = cv2.VideoWriter('output1.mp4', fourcc, 30.0, (1280, 720)) out1 = cv2.VideoWriter('output1.mp4', fourcc, 30.0, (1280, 720))
if not out1.isOpened():
print("[Video] avc1 编码器不可用,回退到 mp4v")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out1 = cv2.VideoWriter('output1.mp4', fourcc, 30.0, (1280, 720))
if not out1.isOpened():
print("[Video] 视频写入初始化失败,跳过本地录像")
out1 = None
# out2 = cv2.VideoWriter('output2.mp4', fourcc, 30.0, (1280, 720)) # out2 = cv2.VideoWriter('output2.mp4', fourcc, 30.0, (1280, 720))
while not stop_event.is_set(): while not stop_event.is_set():
@@ -296,6 +286,7 @@ def video_stream_thread(server):
server.provide_frame(frame) server.provide_frame(frame)
data = front_data_queue.get(timeout=1) data = front_data_queue.get(timeout=1)
server.send_data(json.dumps(data)) server.send_data(json.dumps(data))
if out1 is not None:
out1.write(frame) out1.write(frame)
# out2.write(frame) # out2.write(frame)
except queue.Empty: except queue.Empty:
@@ -347,6 +338,7 @@ def video_stream_thread(server):
# except Exception as e: # except Exception as e:
# print(f"[Video] 重连中... {e}") # print(f"[Video] 重连中... {e}")
# time.sleep(3) # time.sleep(3)
if out1 is not None:
out1.release() out1.release()
# out2.release() # out2.release()
print("[Video] 线程结束") print("[Video] 线程结束")
@@ -541,7 +533,7 @@ def alert_thread(server):
if alert_status: if alert_status:
print(f"警报: {alert_st}") print(f"警报: {alert_st}")
alert = server.alert(int(time.time()), alert_st, info_level) alert = server.alert(int(time.time()), alert_st, info_level)
alert = HookMocker(alert, "http://10.128.48.48:5000/api/osshook") alert = HookMocker(alert, "http://10.128.48.204:5000/api/osshook")
alert.start(width=1280, height=720, fps=30) alert.start(width=1280, height=720, fps=30)
for f in buffered_frame: for f in buffered_frame:
alert.provide_frame(f) alert.provide_frame(f)
@@ -644,7 +636,7 @@ def draw_debug_info(frame, result):
if __name__ == "__main__": if __name__ == "__main__":
server = WebRTCServer(60, 5, "ws://10.128.48.48:5000") server = WebRTCServer(60, CAMERA_ID, "ws://10.128.48.204:5000")
server.start() server.start()
t1 = threading.Thread(target=capture_thread, daemon=True) t1 = threading.Thread(target=capture_thread, daemon=True)
t2 = threading.Thread(target=analysis_thread, daemon=True) t2 = threading.Thread(target=analysis_thread, daemon=True)
@@ -658,6 +650,7 @@ if __name__ == "__main__":
t4.start() t4.start()
t5.start() t5.start()
try:
try: try:
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
@@ -671,11 +664,13 @@ if __name__ == "__main__":
if cv2.waitKey(1) & 0xFF == ord("q"): if cv2.waitKey(1) & 0xFF == ord("q"):
stop_event.set() stop_event.set()
# time.sleep(1) # time.sleep(1)
cv2.destroyAllWindows()
except KeyboardInterrupt: except KeyboardInterrupt:
print("停止程序...") print("停止程序...")
stop_event.set() stop_event.set()
finally:
stop_event.set()
cv2.destroyAllWindows()
server.stop()
t1.join() t1.join()
t2.join() t2.join()

View File

@@ -93,6 +93,7 @@ class WebRTCServer:
self.pcs = set() self.pcs = set()
self.fps = fps self.fps = fps
self.frameContainer = [None] self.frameContainer = [None]
self.frame_lock = threading.Lock()
self.hub = MonitoringHub() self.hub = MonitoringHub()
self.sio = None self.sio = None
@@ -161,7 +162,7 @@ class WebRTCServer:
await pc.setRemoteDescription( await pc.setRemoteDescription(
RTCSessionDescription(offer["sdp"], offer.get("type", "offer")) RTCSessionDescription(offer["sdp"], offer.get("type", "offer"))
) )
pc.addTrack(VideoFrameTrack(self.fps, self.frameContainer)) pc.addTrack(VideoFrameTrack(self.fps, self.frameContainer, self.frame_lock))
dc = pc.createDataChannel("monitoring") dc = pc.createDataChannel("monitoring")
@@ -179,10 +180,12 @@ class WebRTCServer:
asyncio.run_coroutine_threadsafe(self._websocket_start(), self.background_loop) asyncio.run_coroutine_threadsafe(self._websocket_start(), self.background_loop)
def provide_frame(self, frame): def provide_frame(self, frame):
self.frameContainer[0] = frame with self.frame_lock:
self.frameContainer[0] = frame.copy()
def send_data(self, data): def send_data(self, data):
self.hub.send_data(data) if self.background_loop.is_running():
self.background_loop.call_soon_threadsafe(self.hub.send_data, data)
def alert(self, timestamp, summary, level): def alert(self, timestamp, summary, level):
payload = { payload = {
@@ -207,10 +210,11 @@ class WebRTCServer:
class VideoFrameTrack(VideoStreamTrack): class VideoFrameTrack(VideoStreamTrack):
def __init__(self, fps, fc): def __init__(self, fps, fc, lock):
super().__init__() super().__init__()
self.fps = fps self.fps = fps
self.frameContainer = fc self.frameContainer = fc
self.frame_lock = lock
async def next_timestamp(self): async def next_timestamp(self):
""" """
@@ -230,11 +234,13 @@ class VideoFrameTrack(VideoStreamTrack):
async def recv(self): async def recv(self):
pts, time_base = await self.next_timestamp() pts, time_base = await self.next_timestamp()
with self.frame_lock:
frame = self.frameContainer[0] frame = self.frameContainer[0]
if frame is not None:
frame = frame.copy()
if frame is None: if frame is None:
frame = np.zeros((480, 640, 3), dtype=np.uint8) frame = np.zeros((480, 640, 3), dtype=np.uint8)
else: else:
frame = self.frameContainer[0]
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
video_frame = av.VideoFrame.from_ndarray(frame, format="rgb24") video_frame = av.VideoFrame.from_ndarray(frame, format="rgb24")
video_frame.pts = pts video_frame.pts = pts