二値分類におけるBinary Cross Entropyの勾配計算とモデル更新

作成日 2025年7月18日金曜日

更新日 2025年7月27日日曜日

二値分類タスクを機械学習を用いて解く際には、ロジスティック回帰や(最終層の活性化関数を(標準)シグモイド関数とした)深層学習などが利用される。

これらのモデルを勾配降下法で学習するためには、損失関数である Binary Cross Entropy の勾配を計算する必要があるので、その計算をする。

対数の底は情報理論の文脈では 22 を用いるが、勾配計算の微分の文脈では便宜上 ee を用いる(最終的な結果が定数倍変わってしまうが、学習率に吸収されるため勾配計算では問題ない)。

定義

入力と目的変数

入力を x\bm{x} 、目的変数を yy とすると以下のように表される。

x=(x1x2xn)Rny{0,1}\begin{aligned} \bm{x} &= \begin{pmatrix} x_1 & x_2 & \dots & x_n \end{pmatrix} \in \mathbb{R}^n\\ y &\in \left\{0, 1\right\}\\ \end{aligned}

ここで、入力とはロジスティック回帰の場合は特徴量のベクトル、深層学習の場合は隠れ層の最終層(出力層の直前の層)の出力のベクトルを指す。

重みとバイアス

線形結合の重みを w\bm{w} 、バイアスを bb とすると以下のように表される。

w=(w1w2wn)RnbR\begin{aligned} \bm{w} &= \begin{pmatrix} w_1 & w_2 & \dots & w_n \end{pmatrix} \in \mathbb{R}^n\\ b &\in \mathbb{R}\\ \end{aligned}

予測モデル

(標準)シグモイド関数を σ:R(0,1)\sigma:\mathbb{R} \to (0,1) とすると、ロジット zz と予測確率 tt は以下のように表される。

ここで、(0,1)(0,1) は開区間 {xR0<x<1}\{x\in\mathbb{R}|0<x<1\} を表す。

z=xw+b=x1w1+x2w2++xnwn+bt=σ(z)=11+exp(z)=11+ez\begin{aligned} z &= \bm{x} \bm{w} + b\\ &= x_1 w_1 + x_2 w_2 + \dots + x_n w_n + b\\ t &= \sigma(z) = \frac{1}{1 + \exp(-z)} = \frac{1}{1 + e^{-z}} \end{aligned}

交差エントロピーの意味

目的変数 yy の確率変数を YY 、予測値 y^\hat{y} の確率変数を Y^\hat{Y} とすると、 YY の分布 ppY^\hat{Y} の分布 qq は 、b{0,1}b \in \{0, 1\} を用いて以下のように表される。

p(b)=p(Y=bx)={yb=11yb=0=by+(1b)(1y)q(b)=q(Y^=bx)={tb=11tb=0=bt+(1b)(1t)\begin{aligned} p(b) &= p(Y=b|\bm{x})\\ &= \begin{cases} y & b=1\\1-y & b=0 \end{cases}\\ &= by + (1 - b)(1 - y)\\ q(b) &= q(\hat{Y}=b|\bm{x})\\ &= \begin{cases} t & b=1\\1-t & b=0 \end{cases}\\ &= bt + (1 - b)(1 - t) \end{aligned}

ppqq の交差エントロピー H(p,q)H(p, q) は以下のように表される。

H(p,q)=Ebp[Iq(b)]=Ebp[log(q(b))]=b{0,1}p(b)log(q(b))=(p(1)log(q(1))+p(0)log(q(0)))=(ylog(t)+(1y)log(1t))\begin{aligned} H(p, q) &= E_{b\sim p}\left[I_{q}(b)\right]\\ &= E_{b\sim p}\left[-\log\left(q(b)\right)\right]\\ &= -\sum_{b\in\{0,1\}} p(b) \log\left(q(b)\right)\\ &= -\left(p(1) \log(q(1)) + p(0) \log(q(0))\right)\\ &= -\left(y \log(t) + (1 - y) \log(1 - t)\right) \end{aligned}

ここで、 Iq(b)I_{q}(b) は分布 qq による確率変数 bb の自己情報量を表し、 Ebp[Iq(b)]E_{b\sim p}\left[I_{q}(b)\right] は確率変数 bb が分布 pp に従うときの Iq(b)I_{q}(b) の期待値を表す。

交差エントロピー H(p,q)H(p, q) は、以下のように変形できる。

H(p,q)=Ebp[Iq(b)]=Ebp[Iq(b)]Ebp[Ip(b)]+Ebp[Ip(b)]=Ebp[Iq(b)Ip(b)]+Ebp[Ip(b)]=D(pq)+Hbp(b)\begin{aligned} H(p, q) &= E_{b\sim p}\left[I_{q}(b)\right]\\ &= E_{b\sim p}\left[I_{q}(b)\right] - E_{b\sim p}\left[I_{p}(b)\right] + E_{b\sim p}\left[I_{p}(b)\right]\\ &= E_{b\sim p}\left[I_{q}(b) - I_{p}(b)\right] + E_{b\sim p}\left[I_{p}(b)\right]\\ &= D(p||q) + H_{b\sim p}(b)\\ \end{aligned}

ここで、 Hbp(b)H_{b\sim p}(b) は確率変数 bb が分布 pp に従うときの bb のエントロピーを表し、 D(pq)D(p||q) は分布 pp と分布 qq の KL ダイバージェンスを表す。

KL ダイバージェンス D(pq)D(p||q) は分布 pp と分布 qq の距離(遠さ)やズレを表す指標とされる。

エントロピー Hbp(b)H_{b\sim p}(b) は分布 qq によらず一定であるため、交差エントロピー H(p,q)H(p, q) を最小化することは、KL ダイバージェンス D(pq)D(p||q) を最小化することと同じであり、分布 pp に近い分布 qq を求めることができる。

損失関数

交差エントロピー H(p,q)=(ylog(t)+(1y)log(1t))H(p, q) = -\left(y \log(t) + (1 - y) \log(1 - t)\right)yytt の関数として表したものを損失関数とする。

損失関数 L:{0,1},(0,1)RL:\{0,1\},(0,1) \to \mathbb{R} は以下のように表される。

L(y,t)=(ylog(t)+(1y)log(1t))\begin{aligned} L(y, t) &= -\left(y \log(t) + (1 - y) \log(1 - t)\right)\\ \end{aligned}

勾配計算

重み w\bm{w} の勾配 Lw(y,t)\frac{\partial L}{\partial \bm{w}}(y, t) とバイアス bb の勾配 Lb(y,t)\frac{\partial L}{\partial b}(y, t) を計算する。

それぞれ、以下のように変形できる。

Lw(y,t)=zwtzLt(y,t)Lb(y,t)=zbtzLt(y,t)\begin{aligned} \frac{\partial L}{\partial \bm{w}}(y, t) &= \frac{\partial z}{\partial \bm{w}} \frac{\partial t}{\partial z} \frac{\partial L}{\partial t}(y, t)\\ \frac{\partial L}{\partial b}(y, t) &= \frac{\partial z}{\partial b} \frac{\partial t}{\partial z} \frac{\partial L}{\partial t}(y, t)\\ \end{aligned}

それぞれの値を求める。

Lt(y,t)=t((ylog(t)+(1y)log(1t)))=(yt1y1t)=(y(1t)(1y)tt(1t))=(yytt+ytt(1t))=(ytt(1t))=tyt(1t)tz=zσ(z)=z(11+ez)=z(1+ez)1=(1+ez)2z(1+ez)=(1+ez)2ezz(z)=(1+ez)2ez(1)=ez(1+ez)2=11+ezez1+ez=11+ezez+111+ez=11+ez(1+ez)11+ez=11+ez(111+ez)=t(1t)zw=w(xw+b)=xzb=b(xw+b)=1\begin{aligned} \frac{\partial L}{\partial t}(y, t) &= \frac{\partial}{\partial t}\left(-\left(y \log(t) + (1 - y) \log(1 - t)\right)\right)\\ &= -\left(\frac{y}{t} - \frac{1 - y}{1 - t}\right)\\ &= -\left(\frac{y(1 - t) - (1 - y)t}{t(1 - t)}\right)\\ &= -\left(\frac{y - yt - t + yt}{t(1 - t)}\right)\\ &= -\left(\frac{y - t}{t(1 - t)}\right)\\ &= \frac{t - y}{t(1 - t)}\\ \frac{\partial t}{\partial z} &= \frac{\partial}{\partial z}\sigma(z)\\ &= \frac{\partial}{\partial z}\left(\frac{1}{1 + e^{-z}}\right)\\ &= \frac{\partial}{\partial z}\left(1 + e^{-z}\right)^{-1}\\ &= -\left(1 + e^{-z}\right)^{-2} \cdot \frac{\partial}{\partial z}\left(1 + e^{-z}\right)\\ &= -\left(1 + e^{-z}\right)^{-2} \cdot e^{-z} \cdot \frac{\partial}{\partial z}(-z)\\ &= -\left(1 + e^{-z}\right)^{-2} \cdot e^{-z} \cdot (-1)\\ &= \frac{e^{-z}}{\left(1 + e^{-z}\right)^2}\\ &= \frac{1}{1 + e^{-z}} \cdot \frac{e^{-z}}{1 + e^{-z}}\\ &= \frac{1}{1 + e^{-z}} \cdot \frac{e^{-z} + 1 - 1}{1 + e^{-z}}\\ &= \frac{1}{1 + e^{-z}} \cdot \frac{\left(1 + e^{-z}\right) - 1}{1 + e^{-z}}\\ &= \frac{1}{1 + e^{-z}} \cdot \left(1 - \frac{1}{1 + e^{-z}}\right)\\ &= t(1 - t)\\ \frac{\partial z}{\partial \bm{w}} &= \frac{\partial}{\partial \bm{w}}\left(\bm{x} \bm{w} + b\right)\\ &= \bm{x}\\ \frac{\partial z}{\partial b} &= \frac{\partial}{\partial b}\left(\bm{x} \bm{w} + b\right)\\ &= 1 \end{aligned}

これらを代入して、重み w\bm{w} の勾配とバイアス bb の勾配を求める。

Lw(y,t)=xt(1t)tyt(1t)=(ty)xLb(y,t)=1t(1t)tyt(1t)=ty\begin{aligned} \frac{\partial L}{\partial \bm{w}}(y, t) &= \bm{x} \cdot t(1 - t) \cdot \frac{t - y}{t(1 - t)}\\ &= (t - y)\bm{x}\\ \frac{\partial L}{\partial b}(y, t) &= 1 \cdot t(1 - t) \cdot \frac{t - y}{t(1 - t)}\\ &= t - y \end{aligned}

勾配降下法

勾配降下法で学習を行う場合を考える。

学習率を η\eta とすると、重み w\bm{w} とバイアス bb の更新は、以下のように行う。

wwηLw(y,t)=wη(ty)xbbηLb(y,t)=bη(ty)\begin{aligned} \bm{w} &\leftarrow \bm{w} - \eta \frac{\partial L}{\partial \bm{w}}(y, t)\\ &= \bm{w} - \eta (t - y)\bm{x}\\ b &\leftarrow b - \eta \frac{\partial L}{\partial b}(y, t)\\ &= b - \eta (t - y) \end{aligned}

重み付け

正例に対する重みを α\alpha とすると、重み付き損失関数 Lα(y,t):{0,1},(0,1)RL_\alpha(y, t):\{0,1\},(0,1) \to \mathbb{R} は以下のように表される。

Lα(y,t)=(αylog(t)+(1y)log(1t))\begin{aligned} L_\alpha(y, t) &= -\left(\alpha y \log(t) + (1 - y) \log(1 - t)\right)\\ \end{aligned}

ここで、 α>1\alpha > 1 のときは正例に対して重みを大きくし、再現率を高め、 α<1\alpha < 1 のときは負例に対して重みを大きくし、適合率を高める。 α\alpha の値はデータセットに含まれる正例の割合を p1p_1 とすると、 α=1p1p1=1p11\alpha = \frac{1 - p_1}{p_1} = \frac{1}{p_1} - 1 とすることが多い。

重み w\bm{w} の勾配 Lαw(y,t)\frac{\partial L_\alpha}{\partial \bm{w}}(y, t) とバイアス bb の勾配 Lαb(y,t)\frac{\partial L_\alpha}{\partial b}(y, t) を計算する。

L(y,t)L(y, t) の際と同様に、それぞれ、以下のように変形できる。

Lαw(y,t)=zwtzLαt(y,t)Lαb(y,t)=zbtzLαt(y,t)\begin{aligned} \frac{\partial L_\alpha}{\partial \bm{w}}(y, t) &= \frac{\partial z}{\partial \bm{w}} \frac{\partial t}{\partial z} \frac{\partial L_\alpha}{\partial t}(y, t)\\ \frac{\partial L_\alpha}{\partial b}(y, t) &= \frac{\partial z}{\partial b} \frac{\partial t}{\partial z} \frac{\partial L_\alpha}{\partial t}(y, t)\\ \end{aligned}

Lαt(y,t)\frac{\partial L_\alpha}{\partial t}(y, t) を求める。

Lαt(y,t)=t((αylog(t)+(1y)log(1t)))=(αyt1y1t)=(αy(1t)(1y)tt(1t))=(1y)tαy(1t)t(1t)\begin{aligned} \frac{\partial L_\alpha}{\partial t}(y, t) &= \frac{\partial}{\partial t}\left(-\left(\alpha y \log(t) + (1 - y) \log(1 - t)\right)\right)\\ &= -\left(\frac{\alpha y}{t} - \frac{1 - y}{1 - t}\right)\\ &= -\left(\frac{\alpha y(1 - t) - (1 - y)t}{t(1 - t)}\right)\\ &= \frac{(1 - y)t - \alpha y(1 - t)}{t(1 - t)} \end{aligned}

これと先に求めた値を代入して、重み w\bm{w} の勾配とバイアス bb の勾配を求める。

Lαw(y,t)=xt(1t)(1y)tαy(1t)t(1t)=((1y)tαy(1t))xLαb(y,t)=1t(1t)(1y)tαy(1t)t(1t)=(1y)tαy(1t)\begin{aligned} \frac{\partial L_\alpha}{\partial \bm{w}}(y, t) &= \bm{x} \cdot t(1 - t) \cdot \frac{(1 - y)t - \alpha y(1 - t)}{t(1 - t)}\\ &= \left((1 - y)t - \alpha y(1 - t)\right)\bm{x}\\ \frac{\partial L_\alpha}{\partial b}(y, t) &= 1 \cdot t(1 - t) \cdot \frac{(1 - y)t - \alpha y(1 - t)}{t(1 - t)}\\ &= (1 - y)t - \alpha y(1 - t) \end{aligned}