```python import numpy as np import os from PIL import Image import matplotlib.pyplot as plt import matplotlib as mpl
# Reconstruction helper: rebuild one 8-bit channel from a truncated SVD.
def restore(sigma, u, v, K):
    """Rebuild an image channel from its first ``K`` singular triplets.

    Computes ``sum_{k<K} sigma[k] * outer(u[:, k], v[k])`` (the rank-``K``
    approximation), clamps it to ``[0, 255]`` and rounds it to ``uint8``.

    Args:
        sigma: Singular values, in the order returned by ``np.linalg.svd``.
        u: 2-D array of left singular vectors stored as columns; its row count
            is the height of the result.
        v: 2-D array of right singular vectors stored as rows (``Vh``); its
            column count is the width of the result.
        K: Number of leading singular values to keep. ``K == 0`` yields an
            all-zero channel; if ``K`` exceeds the number of singular values,
            all of them are used.

    Returns:
        ``uint8`` array of shape ``(len(u), len(v[0]))``.
    """
    # Scaling column k of U by sigma[k] and multiplying by the first K rows of
    # V sums the K rank-1 outer products in one matrix product instead of a
    # Python loop of np.dot calls.
    a = (u[:, :K] * sigma[:K]) @ v[:K]
    # Clamp to the 8-bit range before rounding so the uint8 cast cannot wrap.
    return np.rint(np.clip(a, 0, 255)).astype('uint8')
if __name__ == "__main__":
    # SVD image compression demo: for k = 1..K keep only the k largest
    # singular values of each colour channel, save every reconstruction and
    # plot the first 12 in a 3x4 grid.

    # Load the image. convert('RGB') drops an alpha channel / expands a
    # greyscale image so the three channel slices below always exist; the
    # `with` block closes the underlying file handle.
    img_path = 'son.png'
    with Image.open(img_path, 'r') as A:
        a = np.array(A.convert('RGB'))
    output_path = './Pic'
    os.makedirs(output_path, exist_ok=True)

    # Per-channel SVD (previously missing, so sigma_*/u_*/v_* and K were
    # undefined). The thin SVD (full_matrices=False) is all restore() needs
    # and avoids allocating a height x height U matrix.
    u_r, sigma_r, v_r = np.linalg.svd(a[:, :, 0], full_matrices=False)
    u_g, sigma_g, v_g = np.linalg.svd(a[:, :, 1], full_matrices=False)
    u_b, sigma_b, v_b = np.linalg.svd(a[:, :, 2], full_matrices=False)
    # Largest number of singular values to try; never more than exist.
    K = min(50, len(sigma_r))

    # The subplot titles are Chinese, so a CJK-capable font is required.
    # Assumes SimHei is installed; matplotlib falls back to DejaVu Sans otherwise.
    mpl.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
    mpl.rcParams['axes.unicode_minus'] = False
    plt.figure(figsize=(11, 9), facecolor='w')

    # Reconstruct and save each rank-k approximation.
    for k in range(1, K + 1):
        R = restore(sigma_r, u_r, v_r, k)
        G = restore(sigma_g, u_g, v_g, k)
        B = restore(sigma_b, u_b, v_b, k)
        I = np.stack((R, G, B), axis=2)
        Image.fromarray(I).save(f'{output_path}/svd_{k}.png')
        if k <= 12:
            plt.subplot(3, 4, k)
            plt.imshow(I)
            plt.axis('off')
            plt.title(f'奇异值个数:{k}')
    plt.tight_layout()
    plt.show()
# Convergence test used by the QR iteration.
def is_converged(a, b, tol=1e-6):
    """Return True if every element of ``a`` is within ``tol`` of ``b``.

    Args:
        a: Sequence or 1-D array of numbers.
        b: Sequence or 1-D array of the same length as ``a``.
        tol: Absolute tolerance applied to each element-wise difference.

    Returns:
        ``True`` when ``|a[i] - b[i]| <= tol`` for every ``i`` (vacuously
        ``True`` for empty input), otherwise ``False``. A NaN difference
        counts as not converged.
    """
    diff = np.abs(np.asarray(a) - np.asarray(b))
    # Test `<= tol` instead of negating `> tol`: NaN compares False in both
    # forms, so this one reports NaN as "not converged" rather than passing it.
    return bool(np.all(diff <= tol))
if __name__ == '__main__':
    # Unshifted QR algorithm: factor A_k = Q_k R_k and set A_{k+1} = R_k Q_k.
    # Because R_k = Q_k^T A_k, each step is the similarity transform
    # Q_k^T A_k Q_k, so the eigenvalues are preserved; iterate until the
    # diagonal stops changing.

    # Initialise the matrix.
    a = np.array([0.65, 0.28, 0.07,
                  0.15, 0.67, 0.18,
                  0.12, 0.36, 0.52]).reshape(3, 3)
    # Safety cap: with complex-conjugate or equal-modulus eigenvalues the
    # diagonal of the unshifted iteration may never settle, which would
    # otherwise loop forever.
    max_iter = 10000
    times = 0
    while True:
        prev_diag = np.diag(a)
        q, r = np.linalg.qr(a)
        a = np.dot(r, q)
        times += 1
        if is_converged(np.diag(a), prev_diag):
            break
        if times >= max_iter:
            print(f"未在 {max_iter} 次迭代内收敛")
            break
    print("迭代次数:", times)
    print("正交阵:\n", q)
    print("上三角阵:\n", r)
    print("近似矩阵:\n", a)