# File: cv_state_ana/reproject/HeartRateMonitor.py
# Timestamp: 2026-01-27 00:09:11 +08:00
# Size: 152 lines, 5.2 KiB (Python)

import numpy as np
import collections
from scipy import signal
class HeartRateMonitor:
    """Estimate heart rate (BPM) from face video via remote PPG.

    Each call to :meth:`process_frame` appends the mean R/G/B value of a
    central face region to fixed-size ring buffers. Once the buffers are
    full, every ``process_interval``-th frame the buffered traces are fused
    with the POS algorithm, detrended, band-pass filtered and analysed with
    an FFT; the dominant in-band frequency is reported as the heart rate,
    smoothed over the most recent estimates.

    Frames are assumed to arrive at a constant rate of ``fps``; no
    per-frame timestamps are used.
    """

    # Butterworth pass-band in Hz (54-150 BPM). The lower edge was raised
    # from 0.75 Hz (45 BPM) to 0.9 Hz (54 BPM) to reject low-frequency noise.
    _BAND_LOW_HZ = 0.9
    _BAND_HIGH_HZ = 2.5
    # Upper limit of the FFT peak search in Hz (180 BPM). It is wider than
    # the filter pass-band; kept as-is to preserve existing behaviour.
    _PEAK_SEARCH_HIGH_HZ = 3.0
    # Fraction of the face box trimmed from every side, which leaves the
    # central 40% of the box in each dimension.
    _ROI_MARGIN = 0.3
    # Minimum standard deviation of the (dimensionless, mean-normalised)
    # pulse signal. Below this the signal is flat and carries no pulse.
    _MIN_SIGNAL_STD = 1e-9

    def __init__(self, fps=30, window_size=300):
        """Create a monitor.

        Args:
            fps: Frame rate of the incoming video in frames per second.
            window_size: Number of frames kept in the analysis window.
        """
        self.fps = fps
        self.buffer_size = window_size

        # Not read anywhere in this class; retained so the attribute set
        # stays the same for any external code that touches it.
        self.times = np.zeros(window_size)

        # Ring buffers with the per-frame ROI mean of each colour channel.
        self.r_buffer = collections.deque(maxlen=window_size)
        self.g_buffer = collections.deque(maxlen=window_size)
        self.b_buffer = collections.deque(maxlen=window_size)

        # Band-pass filter coefficients.
        self.bp_b, self.bp_a = self._create_bandpass_filter(
            self._BAND_LOW_HZ, self._BAND_HIGH_HZ, fps
        )

        # Recent raw estimates, averaged to keep the reported value steady.
        self.bpm_history = collections.deque(maxlen=10)

        # Throttling: samples are collected on every frame, but the
        # spectral analysis runs only every `process_interval` frames
        # (about 0.5 s at 30 fps).
        self.frame_counter = 0
        self.process_interval = 15
        self.last_bpm = None

    def _create_bandpass_filter(self, lowcut, highcut, fs, order=5):
        """Design a Butterworth band-pass filter.

        Args:
            lowcut: Lower cutoff frequency in Hz.
            highcut: Upper cutoff frequency in Hz.
            fs: Sampling rate in Hz.
            order: Order passed to ``scipy.signal.butter``.

        Returns:
            Tuple ``(b, a)`` of numerator and denominator coefficients.
        """
        nyq = 0.5 * fs
        low = lowcut / nyq
        high = highcut / nyq
        b, a = signal.butter(order, [low, high], btype='band')
        return b, a

    def _pos_algorithm(self, r, g, b):
        """Fuse three colour traces into one pulse signal with POS.

        POS (Plane-Orthogonal-to-Skin, Wang et al., 2017) is more robust
        to motion than using the green channel alone. Here the projection
        is applied once to the whole buffer rather than to short
        overlapping windows.

        Args:
            r: 1-D array of red-channel means.
            g: 1-D array of green-channel means.
            b: 1-D array of blue-channel means.

        Returns:
            1-D array with the fused pulse signal, same length as inputs.
        """
        # 1. Temporal normalisation; the 1e-6 avoids division by zero.
        r_n = r / (np.mean(r) + 1e-6)
        g_n = g / (np.mean(g) + 1e-6)
        b_n = b / (np.mean(b) + 1e-6)

        # 2. Projection onto the plane orthogonal to the skin tone:
        #    S1 = G - B,  S2 = G + B - 2R
        s1 = g_n - b_n
        s2 = g_n + b_n - 2 * r_n

        # 3. Alpha tuning: weight S2 so both components have comparable
        #    spread, which suppresses specular (motion) distortion.
        alpha = np.std(s1) / (np.std(s2) + 1e-6)

        # 4. Fused signal.
        return s1 + alpha * s2

    def _roi_channel_means(self, frame, face_loc):
        """Return (r, g, b) means of the central face region, or None."""
        top, right, bottom, left = face_loc
        h_img, w_img = frame.shape[:2]

        # Shrink the box towards its centre to avoid background and edges,
        # then clamp it to the image bounds.
        h_box = bottom - top
        w_box = right - left
        roi_top = int(max(0, top + h_box * self._ROI_MARGIN))
        roi_bottom = int(min(h_img, bottom - h_box * self._ROI_MARGIN))
        roi_left = int(max(0, left + w_box * self._ROI_MARGIN))
        roi_right = int(min(w_img, right - w_box * self._ROI_MARGIN))

        roi = frame[roi_top:roi_bottom, roi_left:roi_right]
        if roi.size == 0:
            return None

        # Channel order is BGR (OpenCV convention).
        b_mean = np.mean(roi[:, :, 0])
        g_mean = np.mean(roi[:, :, 1])
        r_mean = np.mean(roi[:, :, 2])
        return r_mean, g_mean, b_mean

    def _estimate_bpm(self):
        """Return the dominant in-band frequency of the buffer as BPM.

        Returns None when the buffered signal is flat or non-finite, or
        when no FFT bin falls inside the search band.
        """
        r = np.array(self.r_buffer)
        g = np.array(self.g_buffer)
        b = np.array(self.b_buffer)

        # A. Fuse the three channels with POS.
        pulse = self._pos_algorithm(r, g, b)

        # A flat signal (e.g. a static image) has an all-zero spectrum, so
        # argmax would pick the first in-band bin and report a bogus
        # 54 BPM. Refuse to estimate instead.
        spread = np.std(pulse)
        if not np.isfinite(spread) or spread < self._MIN_SIGNAL_STD:
            return None

        # B. Remove the linear trend (slow illumination drift).
        pulse = signal.detrend(pulse)

        # C. Zero-phase band-pass, keeping only 0.9-2.5 Hz.
        pulse = signal.filtfilt(self.bp_b, self.bp_a, pulse)

        # D. Hann window to reduce spectral leakage, then FFT magnitude.
        windowed = pulse * np.hanning(len(pulse))
        mag = np.abs(np.fft.rfft(windowed))
        freqs = np.fft.rfftfreq(len(pulse), 1.0 / self.fps)

        # E. Strongest peak inside the search band (54-180 BPM).
        in_band = (freqs >= self._BAND_LOW_HZ) & (
            freqs <= self._PEAK_SEARCH_HIGH_HZ
        )
        if not in_band.any():
            return None
        peak_freq = freqs[in_band][np.argmax(mag[in_band])]
        return peak_freq * 60.0

    def process_frame(self, frame, face_loc):
        """Feed one frame and return the current heart-rate estimate.

        Args:
            frame: Image array in BGR channel order, indexed as
                ``frame[row, col, channel]``.
            face_loc: Face box as ``(top, right, bottom, left)`` in pixels.

        Returns:
            The smoothed BPM as an int, or None when the ROI is empty or
            no estimate is available yet. Between analysis runs the most
            recent estimate is returned.
        """
        means = self._roi_channel_means(frame, face_loc)
        if means is None:
            return None

        r_mean, g_mean, b_mean = means
        self.r_buffer.append(r_mean)
        self.g_buffer.append(g_mean)
        self.b_buffer.append(b_mean)

        # Not enough samples for a full analysis window yet.
        if len(self.r_buffer) < self.buffer_size:
            return None

        # Samples are collected on every frame, but the spectral analysis
        # only runs every `process_interval` frames.
        self.frame_counter += 1
        if self.frame_counter % self.process_interval != 0:
            return self.last_bpm

        bpm = self._estimate_bpm()
        if bpm is None:
            return self.last_bpm

        # Average recent estimates so the reported value does not jump.
        self.bpm_history.append(bpm)
        self.last_bpm = int(np.mean(self.bpm_history))
        return self.last_bpm