本文参考了知乎的文章 机器学习 | 近端梯度下降法 (proximal gradient descent), 写的非常棒,但感觉有些微的赘余, 因此以这篇博客,希望更精简地介绍 近端梯度下降法 这种略显陌生的 算法。
对于传统的可微的目标函数, 直接使用梯度下降法即可。 而对于不可微的情况下, 就是 近端梯度法 表现的时机了。 简而言之, 目标函数必可写成如下的形式:
f(x)=g(x)+h(x)f(x)=g(x)+h(x) f(x)=g(x)+h(x)
其中 g(x)g(x)g(x) 可微, 而 h(x)h(x)h(x) 不可微 。 这时候, 近端梯度法的迭代公式为:
xk=proxth(⋅)(xk−1−t∇g(xk−1))(1)\boldsymbol{x}^{k}=\operatorname{prox}_{t h(\cdot)}\left(\boldsymbol{x}^{k-1}-t \nabla g\left(\boldsymbol{x}^{k-1}\right)\right)\tag{1} xk=proxth(⋅)(xk−1−t∇g(xk−1))(1)
其中, proxth(⋅)\operatorname{prox}_{t h(\cdot)}proxth(⋅) 为 近端投影算子, 由下式给出:
proxth(⋅)(x)=argminz12t∥x−z∥22+h(x)(2)\operatorname{prox}_{t h(\cdot)}(\boldsymbol{x})=\arg \min _{\boldsymbol{z}} \frac{1}{2 t}\|\boldsymbol{x}-\boldsymbol{z}\|_{2}^{2} + h(x)\tag{2} proxth(⋅)(x)=argzmin2t1∥x−z∥22+h(x)(2)
两式中, ttt 均代表步长。
好了,这就是近端梯度法的算法步骤了, 似乎非常简洁明了: 先根据g(x)g(x)g(x)做一个梯度下降, 再根据h(x)h(x)h(x)做一个近端投影。 那么问题来了, 为什么可以这样做, 意义又是什么?直接看下面的公式:
xk=proxth(⋅)(xk−1−t∇g(xk−1))=argminzh(z)+12t∥z−(xk−1−t∇g(xk−1))∥22=argminzh(z)+t2∥∇g(xk−1)∥22+∇g(xk−1)⊤(z−xk−1)+12t∥z−xk−1∥22=argminzh(z)+g(xk−1)+∇g(xk−1)⊤(z−xk−1)+12t∥z−xk−1∥22≈argminzh(z)+g(z)\begin{aligned} \boldsymbol{x}^{k} &=\operatorname{prox}_{t h(\cdot)}\left(\boldsymbol{x}^{k-1}-t \nabla g\left(\boldsymbol{x}^{k-1}\right)\right) \\ &=\arg \min _{\boldsymbol{z}} h(\boldsymbol{z})+\frac{1}{2 t}\left\|\boldsymbol{z}-\left(\boldsymbol{x}^{k-1}-t \nabla g\left(\boldsymbol{x}^{k-1}\right)\right)\right\|_{2}^{2} \\ &=\arg \min _{\boldsymbol{z}} h(\boldsymbol{z})+\frac{t}{2}\left\|\nabla g\left(\boldsymbol{x}^{k-1}\right)\right\|_{2}^{2}+\nabla g\left(\boldsymbol{x}^{k-1}\right)^{\top}\left(\boldsymbol{z}-\boldsymbol{x}^{k-1}\right)+\frac{1}{2 t}\left\|\boldsymbol{z}-\boldsymbol{x}^{k-1}\right\|_{2}^{2} \\ &=\arg \min _{\boldsymbol{z}} h(\boldsymbol{z})+g\left(\boldsymbol{x}^{k-1}\right)+\nabla g\left(\boldsymbol{x}^{k-1}\right)^{\top}\left(\boldsymbol{z}-\boldsymbol{x}^{k-1}\right)+\frac{1}{2 t}\left\|\boldsymbol{z}-\boldsymbol{x}^{k-1}\right\|_{2}^{2} \\ & \approx \arg \min _{\boldsymbol{z}} h(\boldsymbol{z})+g(\boldsymbol{z}) \end{aligned} xk=proxth(⋅)(xk−1−t∇g(xk−1))=argzminh(z)+2t1∥∥z−(xk−1−t∇g(xk−1))∥∥22=argzminh(z)+2t∥∥∇g(xk−1)∥∥22+∇g(xk−1)⊤(z−xk−1)+2t1∥∥z−xk−1∥∥22=argzminh(z)+g(xk−1)+∇g(xk−1)⊤(z−xk−1)+2t1∥∥z−xk−1∥∥22≈argzminh(z)+g(z)
第三个等式来自于把 12t∥z−(xk−1−t∇g(xk−1))∥22\frac{1}{2 t}\left\|\boldsymbol{z}-\left(\boldsymbol{x}^{k-1}-t \nabla g\left(\boldsymbol{x}^{k-1}\right)\right)\right\|_{2}^{2}2t1∥∥z−(xk−1−t∇g(xk−1))∥∥22 一项拆开, 而第四个等式 则是去掉了与 zzz 无关的项 t2∥∇g(xk−1)∥22\frac{t}{2}\left\|\nabla g\left(\boldsymbol{x}^{k-1}\right)\right\|_{2}^{2}2t∥∥∇g(xk−1)∥∥22, 增加了 g(xk−1)g\left(\boldsymbol{x}^{k-1}\right)g(xk−1) 一项, 第五步的不等式则是来自于泰勒展开的 二阶展开。 最后综合看结论就是:
xk≈argminzf(z)\boldsymbol{x}^{k} \approx \arg \min _{\boldsymbol{z}} f(z)xk≈argzminf(z)
因此, 近端梯度法实质上就是求取了目标函数的最小值。 而随着迭代的进行, 越逼近最优解, 泰勒展开也越精确。