線形回帰の最小二乗法をベクトル偏微分で解く

||y-Ab||^2 の最小化を偏微分で求めるが、ベクトル演算とベクトル微分が「これは知ってて当然でしょ」という感じでさっくり飛ばされることが多いのでしつこいくらいにひとつずつやる。

下準備
微分されるベクトル変数はx, 係数のベクトルはa とする。基本的にx またはa とかくと、列ベクトルつまり縦に長いベクトルである。つまり、

x=\begin{pmatrix}x_1\\x_2\\\vdots\\x_n\end{pmatrix}, a=\begin{pmatrix}a_1\\a_2\\\vdots\\a_n\end{pmatrix} である。縦に書くと長いので、x もしくは転置T して

x=(x_1,x_2,\dots,x_n)^T, a=(a_1,a_2,\dots,a_n)^T と書く。

ベクトルの演算として,

a^Tx=x^Ta

である。

転置は(XY)^T=Y^TX^T である。

偏微分\frac{\partial}{\partial x}a^Tx=a=\frac{\partial}{\partial x}x^Ta である。

二次形式という表現x^TAx というものがある。これは、ただ単純にノルムを考えるとa^Ta (普通の二乗和)になるが,ベクトル表現のときに正方行列X を間にかませると、x^2, y^2, xy, x, y の組み合わせを得ることができる。

二次形式の微分は、A のひだりとみぎでそれぞれ微分するから、
\frac{\partial}{\partial x} x^TAx=\frac{\partial}{\partial x} x^T\cdot Ax+x^TA\frac{\partial}{\partial x} x

Ax=M, x^TA=N とおけば、\frac{\partial}{\partial x} x^TAx=\frac{\partial}{\partial x}x^TM+\frac{\partial}{\partial x}Nx=M+N^T

M+N^T=Ax+(x^TA)^T=Ax+A^Tx=(A+A^T)x となる。


というわけで、E=||y-Ab||^2 を展開して、偏微分で0 になるときのb を求めにかかるが

E=(y-Ab)^T(y-Ab)
普通に二乗の形にするが、ベクトル(というか行列表示)ではこうする
E=(y^T-(Ab)^T)(y-Ab)
転置を中にいれた
E=(y^T-b^TA^T)(y-Ab)
転置の公式
E=||y||^2-y^TAb-b^TA^Ty+b^TA^TAb
展開する
E=||y||^2-(Ab)^Ty-b^TA^Ty+b^TA^TAb
a^Tx=x^Ta を思い出すと、ここで、y^TAb をかたまりとみて転置する。
E=||y||^2-b^TA^Ty-b^TA^Ty+b^TA^TAb
転置の公式を使って(Ab)^T=b^TA^T とした
E=||y||^2-2b^TA^Ty+b^TA^TAb
まとめた
 
さてこれをb偏微分する。||y||^2b に関係ないので0 になるので
\frac{\partial}{\partial b}b^TA^Ty=A^Ty
b^T(A^Ty) と思えば、\frac{\partial}{\partial x}x^Ta=a である
\frac{\partial}{\partial b}b^TA^TAb=(A^TA+(A^TA)^T)b
A^TA=B とみなせば\frac{\partial}{\partial b}b^TBb=(B+B^T)b である
転置の公式(A^TA)^T=A^TA であり、
\frac{\partial}{\partial b}b^TA^TAb=(A^TA+(A^TA)^T)b=2A^TAb

というわけで、
\frac{\partial}{\partial b}E=-2A^Ty+2A^TAb

偏微分で0 となるときに求めるb だから

b=(A^TA)^{-1}A^Ty

が答え。


nr <- 20
nc <- 7
y <- runif(nr)
X <- matrix(runif(nr*nc), nr, nc)

lm(y ~ X - 1)
Call:
lm(formula = y ~ X - 1)

Coefficients:
      X1        X2        X3        X4        X5        X6        X7  
 0.01803   0.43939   0.28782   0.26241  -0.13173   0.02898   0.35624  
solve(t(X)%*%X)%*%t(X)%*%y
            [,1]
[1,]  0.01803270
[2,]  0.43938587
[3,]  0.28782293
[4,]  0.26241033
[5,] -0.13172939
[6,]  0.02897629
[7,]  0.35623832