Variational Approximation for Low-dose CT Reconstruction


In low-dose CT image reconstruction, supervised deep learning is widely adopted, but it generally demands a dataset of paired normal-dose and corresponding low-dose images, which is challenging to collect in clinical settings. The paper discussed here proposes an unsupervised deep learning method to tackle this problem.

ref: A Dataset-free Deep Learning Method for Low-Dose CT Image Reconstruction

Formulation

The image reconstruction problem can be formulated as an inverse problem

\[\boldsymbol{y}=\boldsymbol{A}\boldsymbol{x}+n\]

where

  • $\boldsymbol{A}$ denotes the projection matrix of CT imaging
  • $\boldsymbol{y}$ denotes the available measurement
  • $\boldsymbol{x}$ denotes the image to be reconstructed
  • $n$ denotes the measurement noise, often modeled by i.i.d. random variables
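To make the forward model concrete, here is a minimal Python sketch that simulates $\boldsymbol{y}=\boldsymbol{A}\boldsymbol{x}+n$. The 3×2 matrix `A`, the image `x_true`, and the helper names are illustrative assumptions, not a real CT geometry and not code from the paper.

```python
import random

def matvec(A, x):
    """Multiply matrix A (a list of rows) by vector x."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]

def simulate_measurement(A, x, sigma, rng):
    """Return y = A x + n, where n has i.i.d. N(0, sigma^2) entries."""
    clean = matvec(A, x)
    return [c + rng.gauss(0.0, sigma) for c in clean]

rng = random.Random(0)

# Toy "projection matrix": each row sums a subset of the two pixels.
A = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
x_true = [2.0, 3.0]

y_clean = simulate_measurement(A, x_true, sigma=0.0, rng=rng)  # no noise
y_noisy = simulate_measurement(A, x_true, sigma=0.1, rng=rng)  # low-dose-like noise
print(y_clean)
print(y_noisy)
```

Reconstruction is the reverse direction: recovering `x_true` when only `A` and `y_noisy` are known.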

Approach

There are two common approaches to extending an unsupervised denoising network to inverse problems:

  1. Treat the inverse problem as a denoising task, in which the network post-processes already reconstructed images
  2. Use the deep image prior (DIP), where early stopping typically acts as the regularizer for the network

In This Paper

Generally speaking, the proposed method is based on Bayesian inference, where the prior distribution of an image is re-parametrized by a DNN with random weights.

In Bayesian inference, there are two representative estimators:

  1. the maximum a posteriori (MAP) estimator \(\boldsymbol{x}_{\text{MAP}}=\arg\max\limits_{\boldsymbol{x}} p(\boldsymbol{x}|\boldsymbol{y}).\)
  2. the minimum mean squared error (MMSE) estimator, a.k.a. the conditional mean estimator \(\boldsymbol{x}_{\text{CM}}=\mathbb{E}[\boldsymbol{x}|\boldsymbol{y}]=\int \boldsymbol{x}\cdot p(\boldsymbol{x}|\boldsymbol{y})\mathrm{d}\boldsymbol{x}.\)

where

\[p(\boldsymbol{x}|\boldsymbol{y})\]

denotes the posterior distribution of $\boldsymbol{x}$ given the measurement $\boldsymbol{y}$. Both estimators clearly require a posterior distribution that models the data well.

A common practice in Bayesian inference is to re-express it via Bayes' rule

\[p(\boldsymbol{x}|\boldsymbol{y})=\frac{p(\boldsymbol{y}|\boldsymbol{x})p(\boldsymbol{x})}{p(\boldsymbol{y})},\]

where, in the presence of i.i.d. Gaussian noise $n\sim \mathcal{N}(0, \sigma^2I)$, the likelihood term reads

\[p(\boldsymbol{y}|\boldsymbol{x})\propto \exp\left(-\frac{1}{2\sigma^2}||\boldsymbol{y}-\boldsymbol{A}\boldsymbol{x}||^2_2\right).\]

The study of estimators then turns to defining a prior distribution $p(\boldsymbol{x})$ that accurately describes the statistical characteristics of the images to be reconstructed. In this work, the variable $\boldsymbol{x}$ is re-expressed by a DNN with random weights

\[\boldsymbol{x} = f(\boldsymbol{x_0};\boldsymbol{\theta}),\]

where $\boldsymbol{x_0}$ is some initial seed and $\boldsymbol{\theta}$ are random variables (network weights). After re-parametrization, the variables to infer become $\boldsymbol{\theta}$, and the key to Bayesian inference is again to define an appropriate posterior distribution

\[p(\boldsymbol{\theta}|\boldsymbol{y})\]

for $\boldsymbol{\theta}$. Since this posterior is not computationally tractable, variational approximation methods are adopted to approximate it by a family of distributions

\[q(\boldsymbol{\theta}|\boldsymbol{\mu})\]

parametrized by $\boldsymbol{\mu}$. The optimal approximation, with distribution parameters $\boldsymbol{\mu}^{\ast}$, is then found by minimizing the Kullback–Leibler (KL) divergence.

Brief Introduction to Variational Approximation

The goal of variational estimation is to approximate a posterior, say

\[p(\boldsymbol{\beta}|\boldsymbol{Y})\]

with a second distribution, called the approximating distribution, $q(\boldsymbol{\beta})$. To make this approximation as accurate as possible, we search over the space of approximating distributions for the one with the minimum KL divergence from the actual posterior. Formally, we minimize

\[\mathrm{KL}(q(\boldsymbol{\beta}) \| p(\boldsymbol{\beta} \mid \boldsymbol{Y})) \equiv \mathrm{KL}(q \| p)=-\int q(\boldsymbol{\beta}) \log \left\{\frac{p(\boldsymbol{\beta} \mid \boldsymbol{Y})}{q(\boldsymbol{\beta})}\right\} \mathrm{d} \boldsymbol{\beta}.\]
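As a small numerical illustration of this definition, the sketch below evaluates the discrete analogue of the integral for three-outcome distributions. The probability values are made up for illustration; the point is only that the divergence is zero when the two distributions coincide and grows as the approximation gets worse.

```python
import math

def kl_divergence(q, p):
    """KL(q || p) for two discrete distributions given as probability lists."""
    return sum(q_i * math.log(q_i / p_i) for q_i, p_i in zip(q, p) if q_i > 0)

p = [0.7, 0.2, 0.1]        # stand-in for the true posterior
q_close = [0.6, 0.3, 0.1]  # a reasonable approximation
q_far = [0.1, 0.1, 0.8]    # a poor approximation

print(kl_divergence(p, p))        # 0.0: the divergence vanishes when q equals p
print(kl_divergence(q_close, p))  # small positive value
print(kl_divergence(q_far, p))    # much larger value
```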

If no assumptions are made about the approximating distribution, the divergence is minimized when

\[q(\boldsymbol{\beta})=p(\boldsymbol{\beta}|\boldsymbol{Y}),\]

which is not helpful, since the posterior itself is generally intractable. The paper An Introduction to Bayesian Inference via Variational Approximations therefore uses a factorized approximation: an independence assumption divides the approximating distribution into a set of factors (or blocks of parameters). $\boldsymbol{\beta}$ is first partitioned into a set of $K$ blocks,

\[\boldsymbol{\beta}=\{\boldsymbol{\beta}_1, \boldsymbol{\beta}_2, \cdots, \boldsymbol{\beta}_K\}.\]

The approximating distribution then becomes

\[q(\boldsymbol{\beta})=\prod\limits_{k=1}^K q(\boldsymbol{\beta}_k).\]

After that, the variational algorithm will identify the specific parametric families that constitute each component of the factorized distribution, which I won’t elaborate on here.

Method

The method is a self-supervised approach to LDCT reconstruction from noisy measurements, built on the DNN-based re-parametrization for Bayesian inference.

Recall that CT reconstruction can be formulated as the following inverse problem:

💡 Given an observed measurement $\boldsymbol{y}\in \mathbb{R}^m$, which is produced by the forward model $\boldsymbol{A}$ and corrupted by noise $n$, find the unknown image $\boldsymbol{x}\in \mathbb{R}^n$ that satisfies the observation

\[\boldsymbol{y}=\boldsymbol{A}\boldsymbol{x}+n.\]

Consider a DNN with random weights for re-parametrization:

\[\boldsymbol{x}=f(\boldsymbol{x_0};\boldsymbol{\theta})\]

Then, inferring $\boldsymbol{x}$ from the noisy measurement $\boldsymbol{y}$ turns into inferring the network weights $\boldsymbol{\theta}$ from $\boldsymbol{y}$. We now need the posterior $p(\boldsymbol{\theta}|\boldsymbol{y})$, which can be approximated by a set of distributions

\[q(\boldsymbol{\theta}|\boldsymbol{\mu}), \boldsymbol{\theta}=\boldsymbol{\mu}\odot \boldsymbol{b},\]

where $\odot$ refers to element-wise multiplication, i.e.,

\[\boldsymbol{\theta}=\boldsymbol{\mu}\odot \boldsymbol{b}: \theta_i=\mu_i\times b_i, 1\leq i\leq N,\]

where $\mu_i$ denotes the distribution parameter of $\theta_i$, and $b_i\sim \mathbf{B}(p_i)$ follows a Bernoulli distribution with probability $p_i$, which reads

\[p(b_i)=p_i^{b_i}(1-p_i)^{1-b_i},\quad b_i\in \{0, 1\}.\]

In other words, the DNN with random weights used in this paper is the widely used network with dropout.
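A minimal sketch of this sampling step, assuming a flat list of weights: each $b_i$ is drawn from a Bernoulli distribution and multiplied into $\mu_i$, so every sampled weight is either $0$ or $\mu_i$. The function name and the numbers are illustrative, not taken from the paper.

```python
import random

def sample_weights(mu, p, rng):
    """Draw theta = mu ⊙ b with b_i ~ Bernoulli(p_i); p_i is the probability that b_i = 1."""
    b = [1 if rng.random() < p_i else 0 for p_i in p]
    theta = [mu_i * b_i for mu_i, b_i in zip(mu, b)]
    return theta, b

rng = random.Random(0)
mu = [0.5, -1.2, 2.0]   # distribution parameters (the trainable weights)
p = [0.9, 0.9, 0.9]     # probability of keeping each weight

theta, b = sample_weights(mu, p, rng)
print(b, theta)
```

Note the convention: here $p_i$ is the probability of keeping a weight, whereas `torch.nn.Dropout(p)` takes the probability of zeroing an element.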

Training

Since we use

\[q(\boldsymbol{\theta}|\boldsymbol{\mu})\]

to approximate

\[p(\boldsymbol{\theta}|\boldsymbol{y}),\]

the optimal estimation is approximated by minimizing the KL divergence between them, i.e.,

\[\begin{aligned} & \min _{\boldsymbol{\mu}} \mathrm{KL}(q(\boldsymbol{\theta} \mid \boldsymbol{\mu})|| p(\boldsymbol{\theta} \mid \boldsymbol{y})) \\ = & \min _{\boldsymbol{\mu}} \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} \mid \boldsymbol{\mu})}[\log q(\boldsymbol{\theta} \mid \boldsymbol{\mu})-\log p(\boldsymbol{\theta} \mid \boldsymbol{y})] \\ \propto & \min _{\boldsymbol{\mu}} \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} \mid \boldsymbol{\mu})}[\log q(\boldsymbol{\theta} \mid \boldsymbol{\mu})-(\log p(\boldsymbol{y} \mid \boldsymbol{\theta})+\log p(\boldsymbol{\theta}))] \\ = & \min _{\boldsymbol{\mu}} \mathrm{KL}(q(\boldsymbol{\theta} \mid \boldsymbol{\mu})|| p(\boldsymbol{\theta}))-\mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} \mid \boldsymbol{\mu})} \log p(\boldsymbol{y} \mid \boldsymbol{\theta}) . \end{aligned}\]

For the first term, suppose that $p(\boldsymbol{\theta})$ is a uniform distribution over a sufficiently large area $\Omega$. For $\theta_i\in \{0, \mu_i\}$, we have

\[q(\theta_i|\mu_i)=p_i^{\frac{\theta_i}{\mu_i}}(1-p_i)^{1-\frac{\theta_i}{\mu_i}},\]

and $p(\theta_i)=\frac{1}{s_i}$, where $s_i$ denotes the length of the domain of definition of $\theta_i$. Then the first term transforms to

\[\begin{aligned} D_{K L}(q(\boldsymbol{\theta} \mid \boldsymbol{\mu}) \| p(\boldsymbol{\theta})) & =\sum_i D_{K L}\left(q\left(\theta_i \mid \mu_i\right) \| p\left(\theta_i\right)\right) \\ & =\sum_i \sum_{\theta_i \in\{0, \mu_i\}} q\left(\theta_i \mid \mu_i\right) \log \frac{q\left(\theta_i \mid \mu_i\right)}{p\left(\theta_i\right)} \\ & =\sum_i\left[\left(1-p_i\right) \log \left(1-p_i\right)+p_i \log p_i+\log s_i\right] . \end{aligned}\]

It’s apparent from the last line that the first term does not involve the parameter $\boldsymbol{\mu}$, i.e., it’s a constant.
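The closed form can be checked numerically. The sketch below, with made-up values for $p_i$, $s_i$, and $\mu_i$, enumerates the two support points $\{0, \mu_i\}$ of each weight, sums $q\log(q/p)$ directly, and compares the result with the closed-form expression. Changing $\boldsymbol{\mu}$ moves the support points but not their probabilities, so the value stays the same. It assumes $0<p_i<1$ so that every logarithm is defined.

```python
import math

def kl_by_enumeration(mu, p, s):
    """Sum q * log(q / prior) over the support {mu_i, 0} of every weight."""
    total = 0.0
    for mu_i, p_i, s_i in zip(mu, p, s):
        prior = 1.0 / s_i                          # uniform prior density
        support = [(mu_i, p_i), (0.0, 1.0 - p_i)]  # (value of theta_i, probability under q)
        total += sum(q_val * math.log(q_val / prior) for _, q_val in support)
    return total

def kl_closed_form(p, s):
    """Closed-form expression from the last line of the derivation."""
    return sum((1 - p_i) * math.log(1 - p_i) + p_i * math.log(p_i) + math.log(s_i)
               for p_i, s_i in zip(p, s))

p = [0.5, 0.8]
s = [10.0, 4.0]
mu_a = [0.3, -1.0]
mu_b = [5.0, 2.5]

print(kl_by_enumeration(mu_a, p, s))
print(kl_by_enumeration(mu_b, p, s))  # same value: mu does not enter the sum
print(kl_closed_form(p, s))
```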

For the second term, suppose that the measurement noise $n$ is Gaussian so that

\[p(n)\propto e^{\frac{-n^2}{2\widetilde{\sigma}^2}},\]

then we have

\[\log(p(\boldsymbol{y}|\boldsymbol{\theta}))\propto -\frac{1}{2\widetilde{\sigma}^2}||\boldsymbol{A}f(\boldsymbol{x_0};\boldsymbol{\theta})-\boldsymbol{y}||_2^2.\]

Since the first term does not depend on $\boldsymbol{\mu}$, we conclude that

\[\min _{\boldsymbol{\mu}} D_{K L}(q(\boldsymbol{\theta} \mid \boldsymbol{\mu}) \| p(\boldsymbol{\theta} \mid \boldsymbol{y})) \propto \min _{\boldsymbol{\mu}} \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} \mid \boldsymbol{\mu})}\left\|\boldsymbol{A} f\left(\boldsymbol{x}_0, \boldsymbol{\theta}\right)-\boldsymbol{y}\right\|_2^2\]

It can be seen from the formula above that the KL divergence only constrains the estimate in the range space [1] of the projection matrix $\boldsymbol{A}$. Recalling the re-parametrization $\boldsymbol{\theta}=\boldsymbol{\mu}\odot\boldsymbol{b}$ and the definitions of the distributions $q$ and $\mathbf{B}(\boldsymbol{p})$, we deduce that

\[\begin{aligned} \min _{\boldsymbol{\mu}} & \mathbb{E}_{\boldsymbol{\theta} \sim q(\boldsymbol{\theta} \mid \boldsymbol{\mu})}\left\|\boldsymbol{A} f\left(\boldsymbol{x}_0, \boldsymbol{\theta}\right)-\boldsymbol{y}\right\|_2^2 \\ & =\min _{\boldsymbol{\mu}} \int\left\|\boldsymbol{A} f\left(\boldsymbol{x}_0, \boldsymbol{\theta}\right)-\boldsymbol{y}\right\|_2^2 q(\boldsymbol{\theta} \mid \boldsymbol{\mu}) \mathrm{d} \boldsymbol{\theta} \\ & \stackrel{\mathrm{d} \boldsymbol{\theta}=\boldsymbol{\mu} \odot \mathrm{d} \boldsymbol{b}}{=} \min _{\boldsymbol{\mu}} \int\left\|\boldsymbol{A} f\left(\boldsymbol{x}_0, \boldsymbol{\mu} \odot \boldsymbol{b}\right)-\boldsymbol{y}\right\|_2^2 \boldsymbol{B}(\boldsymbol{p}) \mathrm{d} \boldsymbol{b} \\ & =\min _{\boldsymbol{\mu}} \mathbb{E}_{\boldsymbol{b} \sim \boldsymbol{B}(\boldsymbol{p})}\left\|\boldsymbol{A} f\left(\boldsymbol{x}_0, \boldsymbol{\mu} \odot \boldsymbol{b}\right)-\boldsymbol{y}\right\|_2^2 . \end{aligned}\]

Adding total variation (TV) regularization [2], the loss function for training is defined as

\[\min _{\boldsymbol{\mu}} \mathbb{E}_{\boldsymbol{b} \sim \boldsymbol{B}(\boldsymbol{p})}\left\|\boldsymbol{A} f\left(\boldsymbol{x}_0, \boldsymbol{\mu} \odot \boldsymbol{b}\right)-\boldsymbol{y}\right\|_2^2+\alpha\left\|\nabla f\left(\boldsymbol{x}_0, \boldsymbol{\mu} \odot \boldsymbol{b}\right)\right\|_1,\]

where $\alpha$ is a hyperparameter. Note that dropout is active during training, so each forward pass samples a mask $\boldsymbol{b}$.
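The loss can be sketched in a few lines of Python by replacing the expectation with an average over sampled dropout masks. Everything specific here is an assumption for illustration: `toy_generator` is a hypothetical element-wise stand-in for the DNN $f$, the signal is 1-D so TV reduces to the sum of absolute differences of neighbours, and a single keep probability is shared by all weights. The gradient update of $\boldsymbol{\mu}$ is omitted; in practice an automatic differentiation framework would handle it.

```python
import random

def matvec(A, x):
    """Multiply matrix A (a list of rows) by vector x."""
    return [sum(a * v for a, v in zip(row, x)) for row in A]

def toy_generator(x0, theta):
    """Hypothetical stand-in for the DNN f(x0; theta): element-wise scaling."""
    return [t * v for t, v in zip(theta, x0)]

def total_variation(x):
    """TV of a 1-D signal: L1 norm of its finite differences."""
    return sum(abs(b - a) for a, b in zip(x, x[1:]))

def dropout_loss(mu, x0, A, y, keep_prob, alpha, num_samples, rng):
    """Monte Carlo estimate of E_b[ ||A f(x0; mu ⊙ b) - y||^2 + alpha * TV(f) ]."""
    total = 0.0
    for _ in range(num_samples):
        # Sample theta = mu ⊙ b with b_i ~ Bernoulli(keep_prob).
        theta = [m if rng.random() < keep_prob else 0.0 for m in mu]
        x = toy_generator(x0, theta)
        residual = [r - t for r, t in zip(matvec(A, x), y)]
        total += sum(r * r for r in residual) + alpha * total_variation(x)
    return total / num_samples

rng = random.Random(0)
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
x0 = [1.0, 1.0, 1.0]
y = [1.0, 3.0, 2.0]
mu = [1.0, 3.0, 2.0]

print(dropout_loss(mu, x0, identity, y, keep_prob=0.8, alpha=0.5, num_samples=100, rng=rng))
```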

Testing

Once the DNN is trained, we have an approximation to the target posterior, denoted by

\[q(\boldsymbol{\theta}|\boldsymbol{\mu}^{\ast}).\]

In this approach, the image $\boldsymbol{x}$ is estimated using the conditional mean estimator: given the measurement $\boldsymbol{y}$, the conditional mean estimator of $\boldsymbol{x}$ reads

\[\boldsymbol{x}_{\text{CM}}=\int \boldsymbol{x}p(\boldsymbol{x}|\boldsymbol{y})\mathrm{d}\boldsymbol{x}.\]

By the re-parametrization $\boldsymbol{x}=f(\boldsymbol{x_0};\boldsymbol{\theta})$, we have

\[\boldsymbol{x}_{\text{CM}}=\int \boldsymbol{x}p(\boldsymbol{x}|\boldsymbol{y})\mathrm{d}\boldsymbol{x}=\int f(\boldsymbol{x_0};\boldsymbol{\theta})p(\boldsymbol{\theta}|\boldsymbol{y})\mathrm{d}\boldsymbol{\theta}.\]

Replacing the posterior with its variational approximation gives an approximate conditional mean estimator of $\boldsymbol{x}$

\[\boldsymbol{x}_{\mathrm{CM}}^*=\int f\left(\boldsymbol{x}_0 ; \boldsymbol{\theta}\right) q\left(\boldsymbol{\theta} \mid \boldsymbol{\mu}^*\right) \mathrm{d} \boldsymbol{\theta}.\]

This integral is computed by Monte Carlo (MC) integration [3]: after the network is trained, we draw $K$ random samples of the network with dropout:

\[f(\boldsymbol{x}_0;\boldsymbol{\theta}_k)=f(\boldsymbol{x}_0;\boldsymbol{\mu}^{\ast}\odot \boldsymbol{b}_k), \quad \boldsymbol{b}_k\sim \mathbf{B(p)}.\]

Then, the estimate is defined by averaging these $K$ samples

\[\boldsymbol{x}^*=\frac{1}{K} \sum_{k=1}^K f\left(\boldsymbol{x}_0 ; \boldsymbol{\theta}_k\right)=\frac{1}{K} \sum_{k=1}^K f\left(\boldsymbol{x}_0 ; \boldsymbol{\mu}^* \odot \boldsymbol{b}_k\right)\]
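The averaging step can be sketched as follows, keeping dropout switched on at test time and averaging $K$ forward passes. As before, `toy_generator` is a hypothetical element-wise stand-in for the trained network, and a single shared keep probability is assumed.

```python
import random

def toy_generator(x0, theta):
    """Hypothetical stand-in for the trained DNN f(x0; theta)."""
    return [t * v for t, v in zip(theta, x0)]

def conditional_mean_estimate(mu_star, x0, keep_prob, num_samples, rng):
    """Average K dropout forward passes: x* = (1/K) * sum_k f(x0; mu* ⊙ b_k)."""
    running_sum = [0.0] * len(x0)
    for _ in range(num_samples):
        # Draw a fresh mask b_k and form theta_k = mu* ⊙ b_k.
        theta_k = [m if rng.random() < keep_prob else 0.0 for m in mu_star]
        sample = toy_generator(x0, theta_k)
        running_sum = [s + v for s, v in zip(running_sum, sample)]
    return [s / num_samples for s in running_sum]

rng = random.Random(0)
mu_star = [1.0, 3.0, 2.0]
x0 = [1.0, 1.0, 1.0]

x_hat = conditional_mean_estimate(mu_star, x0, keep_prob=0.8, num_samples=20000, rng=rng)
print(x_hat)
```

For this linear toy generator the average converges to `keep_prob * mu_star * x0`, which makes the result easy to check. For a real nonlinear network, the mean of the outputs generally differs from the output at the mean weights, which is why the samples are averaged in image space.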

Summary

Why Self-supervised?

To understand why it is formulated in a self-supervised manner, we may start with the plain supervised learning paradigm.

To learn to reconstruct images $\boldsymbol{\hat{x}}$ from noisy observations $\boldsymbol{y}$, we generally need a ground truth $\boldsymbol{x}$ to supervise the learning process, i.e., we need to measure the error between the model’s prediction and the ground truth to guide back-propagation. Given a large amount of data and corresponding labels $\{y_i, x_i\}$, we solve

\[\min_{\theta}\sum\limits_{i}\mathcal{L}(\mathcal{F}_{\theta}(y_i), x_i) + \hat{R}(\mathcal{F}_{\theta}).\]

However, in real-world scenarios, such a perfectly labeled dataset is hard to acquire. In medical image reconstruction in particular, the labels for supervised learning must themselves be reconstructed with one algorithm or another, which is not only time-consuming but also limits the performance and interpretability of the model. Training on the noisy measurements $\boldsymbol{y}$ alone is therefore both challenging and meaningful. We can turn the formula above into a self-supervised one by applying the projection matrix $A$ to the network output and comparing the result with the measurement:

\[\min_{\theta}\sum\limits_{i}\mathcal{L}(A\mathcal{F}_{\theta}(y_i), y_i) + \hat{R}(\mathcal{F}_{\theta}),\]

which resembles the formula derived in the Training section.
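A toy comparison of the two losses, with a made-up one-row matrix `A` and a two-pixel image: the supervised loss needs the ground truth and distinguishes the two candidates, while the self-supervised loss needs only the measurement and cannot tell them apart, because they differ by a vector that `A` maps to zero. This is one way to see why an extra regularizer such as TV is useful.

```python
def matvec(A, x):
    """Multiply matrix A (a list of rows) by vector x."""
    return [sum(a * v for a, v in zip(row, x)) for row in A]

def squared_error(u, v):
    """Squared L2 distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

A = [[1.0, 1.0]]         # a single measurement: the sum of both pixels
x_true = [2.0, 3.0]
y = matvec(A, x_true)    # noise-free measurement

candidate_a = [2.0, 3.0]  # the true image
candidate_b = [3.0, 2.0]  # differs by [1, -1], which A maps to zero

# Supervised loss: compares the reconstruction with the ground truth.
supervised_a = squared_error(candidate_a, x_true)
supervised_b = squared_error(candidate_b, x_true)

# Self-supervised loss: compares the re-projected reconstruction with y.
self_supervised_a = squared_error(matvec(A, candidate_a), y)
self_supervised_b = squared_error(matvec(A, candidate_b), y)

print(supervised_a, supervised_b)
print(self_supervised_a, self_supervised_b)
```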

  1. The range space of a matrix $M$ is the span (the set of all linear combinations) of its column vectors. 

  2. TV regularization refers to regularization based on the L1 norm of the image gradients. 

  3. Monte Carlo (MC) integration is a technique that numerically computes a definite integral using random samples.