読者です 読者をやめる 読者になる 読者になる

記号微分

R

deriv を利用して、R で記号微分してみます。

正規分布対数尤度関数を記号微分してみよう

正規分布対数尤度関数は、
\[
l(\mu, \s; x)=-\frac{\log(2\pi\s^2)}{2}-\frac{(x - \mu)^2}{2\s^2}
\] ですので、これを偏微分すると、
\[
\begin{align*}
\frac{\partial l(\mu, \s; x)}{\partial \mu} &= \frac{x-\mu}{\s^2} \\
\frac{\partial l(\mu, \s; x)}{\partial \s} &= -\frac{1}{\s}+\frac{(x-\mu)^2}{\s^3}
\end{align*}
\] となる予定です。


deriv を用いた数値微分では、まず被微分関数を expressioncall、または formula の形で定義します。

fx <- expression(
  -log(2*pi*sd^2)/2 - (x - mean)^2/(2*sd^2)
)


あとは、

  • 定義した被微分関数
  • 微分したい変数名
  • 出力される微分された関数に対する引数名
  • 必要に応じてオプション

を指定して実行するだけです。

dfx <- deriv(
  fx,
  c("mean", "sd"),
  function.arg = c("x", "mean", "sd")
)
dfx


結果は以下のようになりました。

> dfx
function (x, mean, sd) 
{
    .expr1 <- 2 * pi
    .expr2 <- sd^2
    .expr3 <- .expr1 * .expr2
    .expr7 <- x - mean
    .expr8 <- .expr7^2
    .expr9 <- 2 * .expr2
    .expr14 <- 2 * sd
    .value <- -log(.expr3)/2 - .expr8/.expr9
    .grad <- array(0, c(length(.value), 2L), list(NULL, c("mean", 
        "sd")))
    .grad[, "mean"] <- 2 * .expr7/.expr9
    .grad[, "sd"] <- -(.expr1 * .expr14/.expr3/2 - .expr8 * (2 * 
        .expr14)/.expr9^2)
    attr(.value, "gradient") <- .grad
    .value
}


一応確認しておくと、

.expr2 <- sd^2
.expr7 <- x - mean
.expr9 <- 2 * .expr2
.grad[, "mean"] <- 2 * .expr7/.expr9

# 2 * .expr7/.expr9 =
#   2 * (x - mean) / 2 * sd^2 =
#   (x - mean) / sd^2

および、

.expr1 <- 2 * pi
.expr2 <- sd^2
.expr3 <- .expr1 * .expr2
.expr7 <- x - mean
.expr8 <- .expr7^2
.expr9 <- 2 * .expr2
.expr14 <- 2 * sd
.grad[, "sd"] <- -(.expr1 * .expr14/.expr3/2 - .expr8 * (2 * 
        .expr14)/.expr9^2)

# -.expr1 * .expr14/.expr3/2 =
#   -2 * pi * 2 * sd / (2 * pi * sd^2) / 2 =
#   -1/sd
# .expr8 * (2 * .expr14)/.expr9^2 =
#   (x - mean)^2 * (2 * 2 * sd) / (2 * sd^2)^2 =
#   (x - mean)^2 / sd^3

となります。
このぐらいなら手計算の方が早いですが、いろいろな場面で利用できるのではないでしょうか。


また、Hessian を自動で出してくれる deriv3 もあります。

注意

deriv のヘルプのタイトルは、

Symbolic and Algorithmic Derivatives of Simple Expressions

となっていて、以下の説明があります。

The internal code knows about the arithmetic operators +, -, *, / and ^, and the single-variable functions exp, log, sin, cos, tan, sinh, cosh, sqrt, pnorm, dnorm, asin, acos, atan, gamma, lgamma, digamma and trigamma, as well as psigamma for one or two arguments (but derivative only with respect to the first). (Note that only the standard normal distribution is considered.)

deriv では基本的な演算子と、いくつかの単変数の関数ぐらいしか使えないため、被微分関数を作るときに注意が必要です。