AxCaliberSMT

Axon diameter mapping based on diffusion MRI (dMRI).

gpuAxCaliberSMT

AxCaliberSMT with askAdam solver.

Usage

obj = gpuAxCaliberSMT(b, delta, Delta, D0, Da, DeL, Dcsf, varargin)
out = obj.estimate(s, mask, extraData, fitting);

Model parameters

% Model parameters estimated by gpuAxCaliberSMT (one entry per row):
%   a     : axon diameter [um]
%   f     : neurite fraction, f = fa/(fa+fe)
%   fcsf  : CSF fraction
%   DeR   : hindered diffusion diffusivity [um2/ms]
model_params = {'a'; 'f'; 'fcsf'; 'DeR'};

% upper bounds, same order as model_params
ub           = [20; 1; 1; 3];
% lower bounds, same order as model_params
lb           = [0.1; 0; 0; 0.01];
% default starting point, same order as model_params
startpoint   = [1.5925; 0.777777777777778; 0.1; 0.482105263157895];

I/O overview

obj = gpuAxCaliberSMT(b, delta, Delta, D0, Da, DeL, Dcsf);

Input

Description

b

1xNshell b-values vector [ms/um2]

delta

1xNshell diffusion gradient pulse width vector, aka little delta, same size as ‘b’ [ms]

Delta

1xNshell diffusion time, aka big delta, same size as ‘b’ [ms]

D0

intra-cellular intrinsic diffusivity [um2/ms]

Da

intra-cellular axial diffusivity [um2/ms]

DeL

extra-cellular axial diffusivity [um2/ms]

Dcsf

CSF diffusivity [um2/ms]

out = obj.estimate(dwi, mask, extraData, fitting);

Input

Description

dwi

4D dMRI data, can be either full acquisition or SMT signal [x,y,z,diffusion]

mask

3D mask, [x,y,z]

extradata

Structure array with additional data (Optional)

extradata.bval

1D b-values [1xdiffusion], same order as ‘dwi’ [ms/um2] (Optional, only if ‘dwi’ is full acquisition)

extradata.bvec

2D b-vector [3xdiffusion], same order as ‘dwi’ (Optional, only if ‘dwi’ is full acquisition)

extradata.ldelta

1D gradient duration [1xdiffusion], same order as ‘dwi’ [ms] (Optional, only if ‘dwi’ is full acquisition)

extradata.BDELTA

1D diffusion time [1xdiffusion], same order as ‘dwi’ [ms] (Optional, only if ‘dwi’ is full acquisition)

fitting

Structure array for model parameter estimation

fitting.optimiser

Algorithm for parameter update, ‘adam’ (default) | ‘sgdm’ | ‘rmsprop’

fitting.isdisplay

Boolean; if true, display the optimisation process in a graphic plot

fitting.convergenceValue

tolerance in loss gradient to stop the optimisation

fitting.convergenceWindow

Number of elements (window length) over which ‘convergenceValue’ is computed

fitting.iteration

maximum # of optimisation iterations

fitting.initialLearnRate

Initial learning rate of the optimiser (see ‘fitting.optimiser’)

fitting.tol

tolerance in loss

fitting.lambda

regularisation parameter(s)

fitting.regmap

model parameter(s) in which regularisation is applied

fitting.TVmode

Mode for total variation (TV) regularisation, ‘2D’ | ‘3D’

fitting.lossFunction

loss function, ‘L1’ | ‘L2’ | ‘huber’ | ‘mse’

fitting.isPrior

If true, the starting point is estimated with a likelihood-based method instead of using a fixed/random location

Example

Example script for noise propagation:

addpath(genpath('../../gacelle/'));
clear;
%% Simulate data

% for reproducibility
seed        = 8715; rng(seed); gpurng(seed);
Nsample     = 1e3;  % #voxel
SNR         = 100;

% fixed parameters
D0          = 1.7;
Da_fixed    = 1.7;
DeL_fixed   = 1.7;
Dcsf        = 3;

% get current DWI protocol for simulation
bval_sorted     = [0.05, 0.35, 0.80, 1.5, 2.401, 3.45, 4.75, 6, 0.2, 0.95, 2.3, 4.25, 6.75, 9.85, 13.5, 17.8];
ldelta_sorted   = ones(size(bval_sorted))* 6; % ms
BDELTA_sorted   = [13,13,13,13,13,13,13,13,30,30,30,30,30,30,30,30]; %ms

% Parameter range for forward simulation
axonDia_range   = [0.5 6];
f_range         = [0.3, 1];
fcsf_range      = [0 0.3];
DeR_range       = [0.5 1.5];

% generate ground truth, uniformly distributed within each range
% (keep the order of the rand calls so the seeded result is reproducible)
axonDia_GT   = single(rand(1,Nsample) * diff(axonDia_range)  + min(axonDia_range));
fcsf_GT      = single(rand(1,Nsample) * diff(fcsf_range)     + min(fcsf_range));
f_GT         = single(rand(1,Nsample) * diff(f_range)        + min(f_range));
DeR_GT       = single(rand(1,Nsample) * diff(DeR_range)      + min(DeR_range));

% Forward signal simulation (ground truth arrays are already single precision)
model       = 'VanGelderen';
pars        = [];
pars.a      = axonDia_GT;
pars.f      = f_GT;
pars.fcsf   = fcsf_GT;
pars.DeR    = DeR_GT;
objGPU      = gpuAxCaliberSMT(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
s           = objGPU.FWD(pars, model);

% Let assume Gaussian noise for simplicity
noiseLv = 1/SNR;
s       = s + randn(size(s)) .* noiseLv;
s       = permute(s,[2 3 4 1]);         % move the measurement dimension to the 4th dimension
mask    = ones(size(s,1:3))>0;          % create mask covering all simulated voxels

%% askAdam estimation
fitting                     = [];
fitting.iteration           = 4000;
fitting.initialLearnRate    = 0.001;
fitting.convergenceValue    = 1e-8;
fitting.lossFunction        = 'l1';
fitting.tol                 = 1e-4;
fitting.isDisplay           = false;
fitting.lambda              = 0;
fitting.start               = 'likelihood';
fitting.patience            = 5;
extraData                   = [];       % input is already the SMT signal

out   = objGPU.estimate(s, mask, extraData, fitting);

%% plot result: ground truth vs fitted value for each model parameter
figure;
field = fieldnames(pars);
tiledlayout(1,numel(field));
for k = 1:numel(field)
    nexttile;
    scatter(pars.(field{k}),out.final.(field{k}),5,'filled','MarkerFaceAlpha',.4);
    h = refline(1);
    h.Color = 'k';
    title(field{k});
    xlabel('GT');ylabel('Fitted');
end

Example script for real data:

%% demo_gpuAxCaliberSMT_RealData.m
%
% This demo provides several examples on the utilisation of gpuAxCaliberSMT.m
% (askAdam solver) for parameter estimation with in vivo data
%
% Kwok-Shing Chan
% kchan2@mgh.harvard.edu
%
% Date created: 25 March 2024
% Date modified: 15 August 2024
% Date modified: 24 September 2024
%
%% add paths
addpath(genpath('../../gacelle')); % this is the path to 'gacelle' package
clear;

%% I/O: Load data
data_dir = '/path/to/your/bids/derivatives/processed_dwi/sub-<label>';

% Nb: # of unique b-value per little delta per big delta
% Nd: # of unique little delta
% ND: # of unique big delta

dwi     = niftiread(fullfile(data_dir, 'sub-<label>_dwi.nii.gz'));                      % full DWI data
mask    = niftiread(fullfile(data_dir, 'sub-<label>_brain_mask.nii.gz'))>0;             % signal mask
bval    = readmatrix(fullfile(data_dir,'sub-<label>.bval'),         'FileType','text'); % 1x(Nb*Nd*ND) b-values, same length as the 4th dimension dwi
bvec    = readmatrix(fullfile(data_dir,'sub-<label>.bvec'),         'FileType','text'); % 3x(Nb*Nd*ND) gradient directions, 2nd dimension has the same length as the 4th dimension dwi
ldelta  = readmatrix(fullfile(data_dir,'sub-<label>.pulseWidth'),   'FileType','text'); % 1x(Nb*Nd*ND) little delta, same length as the 4th dimension dwi
BDELTA  = readmatrix(fullfile(data_dir,'sub-<label>.diffusionTime'),'FileType','text'); % 1x(Nb*Nd*ND) big delta, same length as the 4th dimension dwi

%% Algorithm parameters
bval        = bval/1e3; % s/mm2 to ms/um2
% fixed parameters
D0          = 1.7;
Da_fixed    = 1.7;
DeL_fixed   = 1.7;
Dcsf        = 3;

% get unique b-values for each little delta and big delta
[bval_sorted,ldelta_sorted,BDELTA_sorted] = DWIutility.unique_shell(bval,ldelta,BDELTA);

% acquisition information of the full DWI, same order as the 4th dimension of 'dwi'
extraData           = [];
extraData.bval      = bval;
extraData.bvec      = bvec;
extraData.ldelta    = ldelta;
extraData.BDELTA    = BDELTA;

% Optional: the spherical mean (SMT) signal can be computed beforehand and used as
% the input instead of the full DWI (extraData is then not needed), e.g.
%   obj     = DWIutility;
%   lmax    = 0;
%   dwi_smt = obj.get_Sl_all(dwi,bval,bvec,ldelta,BDELTA,lmax);
%   out     = smt_gpu.estimate(dwi_smt, mask, [], fitting);

%% Usage #1: Basic default setting
fitting                     = [];
fitting.iteration           = 4000;
fitting.initialLearnRate    = 0.001;
fitting.convergenceValue    = 1e-8;
fitting.tol                 = 1e-3;
fitting.isdisplay           = false;
fitting.lambda              = 0;
fitting.isPrior             = 1;

% reproducibility
seed = 892396; rng(seed); gpurng(seed);
% initiate askadam object
smt_gpu     = gpuAxCaliberSMT(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
% askadam estimation
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);

%% Usage #2: Applying spatial regularisation
fitting                     = [];
fitting.iteration           = 4000;
fitting.initialLearnRate    = 0.001;
fitting.convergenceValue    = 1e-8;
fitting.tol                 = 1e-3;
fitting.isdisplay           = false;
fitting.regmap              = {'a','f'};        % apply TV regularisation on 2 maps
fitting.lambda              = {0.0001, 0.0001}; % one regularisation parameter per map in 'regmap'
fitting.TVmode              = '3D';
fitting.voxelSize           = [2,2,2];
fitting.isPrior             = 1;

% reproducibility
seed = 892396; rng(seed); gpurng(seed);

% initiate askadam object
smt_gpu     = gpuAxCaliberSMT(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
% askadam estimation
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);

gpuAxCaliberSMTmcmc

AxCaliberSMT with MCMC solver.

Usage

obj = gpuAxCaliberSMTmcmc(b, delta, Delta, D0, Da, DeL, Dcsf, varargin)
out = obj.estimate(s, mask, extraData, fitting);

Model parameters

% Model parameters estimated by gpuAxCaliberSMTmcmc (one entry per row):
%   a     : axon diameter [um]
%   f     : neurite fraction, f = fa/(fa+fe)
%   fcsf  : CSF fraction
%   DeR   : hindered diffusion diffusivity [um2/ms]
%   noise : noise level
model_params = {'a'; 'f'; 'fcsf'; 'DeR'; 'noise'};

% upper bounds, same order as model_params
ub           = [20; 1; 1; 3; 0.1];
% lower bounds, same order as model_params
lb           = [0.1; 0; 0; 0.01; 0.01];
% default MCMC proposal step size, same order as model_params
step         = [0.24875; 0.05; 0.05; 0.0393421052631579; 0.005];
% default starting point, same order as model_params
startpoint   = [1.5925; 0.777777777777778; 0.1; 0.482105263157895; 0.05];

I/O overview

obj = gpuAxCaliberSMTmcmc(b, delta, Delta, D0, Da, DeL, Dcsf);

Input

Description

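b

1xNshell unique b-values vector [ms/um2]

delta

1xNshell diffusion gradient pulse width vector, aka little delta, same size as ‘b’ [ms]

Delta

1xNshell diffusion time, aka big delta, same size as ‘b’ [ms]

D0

intra-cellular intrinsic diffusivity [um2/ms]

Da

intra-cellular axial diffusivity [um2/ms]

DeL

extra-cellular axial diffusivity [um2/ms]

Dcsf

CSF diffusivity [um2/ms]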

out = obj.estimate(dwi, mask, extradata, fitting);

Input

Description

dwi

4D dMRI data, can be either full acquisition or SMT signal [x,y,z,diffusion]

mask

3D mask, [x,y,z]

extradata

Structure array with additional data (Optional)

extradata.bval

1D b-values [1xdiffusion], same order as ‘dwi’ [ms/um2] (Optional, only if ‘dwi’ is full acquisition)

extradata.bvec

2D b-vector [3xdiffusion], same order as ‘dwi’ (Optional, only if ‘dwi’ is full acquisition)

extradata.ldelta

1D gradient duration [1xdiffusion], same order as ‘dwi’ [ms] (Optional, only if ‘dwi’ is full acquisition)

extradata.BDELTA

1D diffusion time [1xdiffusion], same order as ‘dwi’ [ms] (Optional, only if ‘dwi’ is full acquisition)

fitting

Structure array for model parameter estimation

fitting.algorithm

MCMC algorithm, ‘MH’ (Metropolis-Hastings) | ‘GW’ (Affine-invariant ensemble)

fitting.iteration

# MCMC iterations

fitting.repetition

# repetition of MCMC proposal

fitting.thinning

Sampling interval between retained iterations (e.g., 10 keeps every 10th iteration)

fitting.burnin

Iterations to be discarded at the beginning: if >1, this exact number of iterations is discarded; otherwise it is treated as a fraction, i.e. iteration*burnin iterations are discarded

fitting.xStepSize

step size of model parameter in MCMC proposal, same size and order as ‘model_params’ (‘MH’ only)

fitting.StepSize

step size for ‘GW’ in MCMC proposal (‘GW’ only)

fitting.Nwalker

# random walkers (‘GW’ only)

fitting.metric

cell variable, metric(s) derived from posterior distribution, ‘mean’ | ‘std’ | ‘median’ | ‘iqr’ (can be multiple)

fitting.start

Starting point method, ‘likelihood’ | ‘default’ | 1xM parameters array

Output

Description

out

structure contains optimisation result

out.posterior

structure contains MCMC posterior samples

out.posterior.(model_params{k})

Model parameter MCMC posterior samples, stored for masked voxels only and not reshaped into image space, to save memory

out.{metric}.(model_params{k})

Posterior statistics chosen in fitting.metric

Example

Example script for noise propagation:

addpath(genpath('../../gacelle/'));
clear;
%% Simulate data

% for reproducibility
seed        = 8715; rng(seed); gpurng(seed);
Nsample     = 1e3;  % #voxel
SNR         = 50;

% fixed parameters
D0          = 1.7;
Da_fixed    = 1.7;
DeL_fixed   = 1.7;
Dcsf        = 3;

% get current DWI protocol for simulation
bval_sorted     = [0.05, 0.35, 0.80, 1.5, 2.401, 3.45, 4.75, 6, 0.2, 0.95, 2.3, 4.25, 6.75, 9.85, 13.5, 17.8];
ldelta_sorted   = ones(size(bval_sorted))* 6; % ms
BDELTA_sorted   = [13,13,13,13,13,13,13,13,30,30,30,30,30,30,30,30]; %ms

% Parameter range for forward simulation
axonDia_range   = [0.1 6];
f_range         = [0.3, 1];
fcsf_range      = [0 0.3];
DeR_range       = [0.5 1.5];

% generate ground truth, uniformly distributed within each range
% (keep the order of the rand calls so the seeded result is reproducible)
axonDia_GT   = single(rand(1,Nsample) * diff(axonDia_range)  + min(axonDia_range));
fcsf_GT      = single(rand(1,Nsample) * diff(fcsf_range)     + min(fcsf_range));
f_GT         = single(rand(1,Nsample) * diff(f_range)        + min(f_range));
DeR_GT       = single(rand(1,Nsample) * diff(DeR_range)      + min(DeR_range));

% Forward signal simulation
model       = 'VanGelderen';
pars        = [];
pars.a      = axonDia_GT;
pars.f      = f_GT;
pars.fcsf   = fcsf_GT;
pars.DeR    = DeR_GT;
objGPU      = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
s           = objGPU.FWD(pars, model);

% Let assume Gaussian noise for simplicity
noiseLv = 1/SNR;
s       = s + randn(size(s)) .* noiseLv;
s       = gather(permute(s,[2 3 4 1]));   % measurement dimension to 4th dim, copied to host memory
mask    = ones(size(s,1:3))>0;            % create mask covering all simulated voxels

%% MCMC estimation
fitting             = [];
fitting.algorithm   = 'GW';
fitting.Nwalker     = 50;
fitting.StepSize    = 2;
fitting.iteration   = 1e4;
fitting.thinning    = 10;        % Sample every 10 iterations
fitting.metric      = {'median','iqr'};
fitting.burnin      = 0.1;       % 10% burn-in
extraData           = [];        % input is already the SMT signal

out   = objGPU.estimate(s, mask, extraData, fitting);

%% plot result: ground truth vs posterior median for each model parameter
figure;
field = fieldnames(pars);
tiledlayout(1,numel(field));
for k = 1:numel(field)
    nexttile;
    scatter(pars.(field{k}),out.median.(field{k}),5,'filled','MarkerFaceAlpha',.4);
    h = refline(1);
    h.Color = 'k';
    title(field{k});
    xlabel('GT');ylabel('Estimate');
end

Example script for real data:

%% demo_gpuAxCaliberSMTmcmc_RealData.m
%
% This demo provides several examples on the utilisation of gpuAxCaliberSMTmcmc.m
% for parameter estimation with in vivo data
%
% Kwok-Shing Chan
% kchan2@mgh.harvard.edu
%
% Date created: 25 March 2024
% Date modified: 14 June 2024
%
%% add paths
addpath(genpath('../../gacelle')); % this is the path to 'gacelle' package
clear;

%% I/O: Load data
data_dir = '/path/to/your/bids/derivatives/processed_dwi/sub-<label>';

% Nb: # of unique b-value per little delta per big delta
% Nd: # of unique little delta
% ND: # of unique big delta

dwi     = niftiread(fullfile(data_dir, 'sub-<label>_dwi.nii.gz'));                      % full DWI data
mask    = niftiread(fullfile(data_dir, 'sub-<label>_brain_mask.nii.gz'))>0;             % signal mask
bval    = readmatrix(fullfile(data_dir,'sub-<label>.bval'),         'FileType','text'); % 1x(Nb*Nd*ND) b-values, same length as the 4th dimension dwi
bvec    = readmatrix(fullfile(data_dir,'sub-<label>.bvec'),         'FileType','text'); % 3x(Nb*Nd*ND) gradient directions, 2nd dimension has the same length as the 4th dimension dwi
ldelta  = readmatrix(fullfile(data_dir,'sub-<label>.pulseWidth'),   'FileType','text'); % 1x(Nb*Nd*ND) little delta, same length as the 4th dimension dwi
BDELTA  = readmatrix(fullfile(data_dir,'sub-<label>.diffusionTime'),'FileType','text'); % 1x(Nb*Nd*ND) big delta, same length as the 4th dimension dwi

%% Algorithm parameters
bval        = bval/1e3; % s/mm2 to ms/um2
% fixed parameters
D0          = 1.7;
Da_fixed    = 1.7;
DeL_fixed   = 1.7;
Dcsf        = 3;

% get unique b-values for each little delta and big delta
[bval_sorted,ldelta_sorted,BDELTA_sorted] = DWIutility.unique_shell(bval,ldelta,BDELTA);

% acquisition information of the full DWI, same order as the 4th dimension of 'dwi'
extraData           = [];
extraData.bval      = bval;
extraData.bvec      = bvec;
extraData.ldelta    = ldelta;
extraData.BDELTA    = BDELTA;

%% Usage #1: Basic default setting (same as Hong-Hsi's original implementation)
fitting             = [];
fitting.algorithm   = 'MH';     % Metropolis-Hastings
fitting.iteration   = 2e4;      % 2e4 for demo purpose. Original implementation used 2e5.
fitting.thinning    = 100;
fitting.metric      = {'median','mean'};
fitting.start       = 'default';

% get the GPU device
g = gpuDevice;
% initiate MCMC object
smt_gpu     = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
% MCMC estimation
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);
% reset GPU memory
reset(g)

% Here is an example to convert posterior distribution into image space
a_dist = mcmc.distribution2image(out.posterior.a,mask);

%% Usage #2: Affine invariant ensemble MCMC
fitting             = [];
fitting.algorithm   = 'GW';
fitting.Nwalker     = 50;
fitting.StepSize    = 2;
fitting.iteration   = 2e3;
fitting.thinning    = 20;       % Samples every 20 iterations
fitting.metric      = {'median'};
fitting.burnin      = 0.2;      % 20% burn-in

% get the GPU device
g = gpuDevice;
% initiate MCMC object
smt_gpu     = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
% MCMC estimation
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);
% reset GPU memory
reset(g)

%% Usage #3: You may use the internal save option to export the posterior distributions into disk
%  This may be useful when the #MCMC iterations is large to avoid memory issue
fitting                 = [];
fitting.algorithm       = 'MH';     % Metropolis-Hastings
fitting.iteration       = 2e4;
fitting.thinning        = 100;
fitting.metric          = {'median'};
fitting.start           = 'default';
fitting.outputFilename  = '/path/to/output_dir/filename.mat';   % provide your output info here

g = gpuDevice;
smt_gpu     = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);
reset(g)

% to load the distribution back to Workspace
load(fitting.outputFilename);
a_dist = mcmc.distribution2image(out.posterior.a,mask);

%% Usage #4: If you prefer to use the spherical mean signal as the input it is also possible
% This might be useful if you already have the SMT signal computed and don't want to load the entire DWI into memory
% You may use whatever method here, just make sure the DWI order in the 4th dimension matches the
% bval_sorted/ldelta_sorted/BDELTA_sorted order returned by DWIutility.unique_shell above
obj     = DWIutility;
lmax    = 0;
dwi_smt = obj.get_Sl_all(dwi,bval,bvec,ldelta,BDELTA,lmax);

fitting             = [];
fitting.algorithm   = 'GW';
fitting.Nwalker     = 50;
fitting.StepSize    = 2;
fitting.iteration   = 1e4;
fitting.thinning    = 20;       % Samples every 20 iterations
fitting.metric      = {'median'};
fitting.burnin      = 0.1;      % 10% burn-in

g = gpuDevice;
smt_gpu     = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
out         = smt_gpu.estimate(dwi_smt, mask, [], fitting);   % extraData not needed for SMT input
reset(g)

%% Usage #5: You might also define the starting points yourself
fitting                 = [];
fitting.algorithm       = 'GW';
fitting.Nwalker         = 50;
fitting.StepSize        = 2;
fitting.iteration       = 1e4;
fitting.thinning        = 20;       % Samples every 20 iterations
fitting.metric          = {'median'};
fitting.burnin          = 0.1;      % 10% burn-in
fitting.start           = [2,0.5,0.20,0.7,0.02];   % [a, f, fcsf, DeR, noise]

g = gpuDevice;
smt_gpu     = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);
reset(g)

%% Below are a few methods to reduce the number of outliers in the MH MCMC results
% The options are not mutually exclusive so you may use all or some of them together

%% Option 1: Use maximum likelihood to initiate estimation starting points
fitting             = [];
fitting.algorithm   = 'MH';     % Metropolis-Hastings
fitting.iteration   = 2e4;
fitting.thinning    = 100;
fitting.metric      = {'median'};
fitting.start       = 'likelihood';

g = gpuDevice;
smt_gpu     = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);
reset(g)

%% Option 2: Use narrower parameter range
fitting             = [];
fitting.algorithm   = 'MH';     % Metropolis-Hastings
fitting.iteration   = 2e4;
fitting.thinning    = 100;
fitting.metric      = {'median'};
fitting.start       = 'likelihood';
fitting.boundary    = [ 0.1 10        ;   % axon diameter, um
                          0 1         ;   % intra-cellular (neurite) volume fraction
                          0 1         ;   % isotropic (CSF) volume fraction
                       0.2 DeL_fixed  ;   % extra-cellular RD, um2/ms
                       0.01 0.1      ];   % noise level

g = gpuDevice;
smt_gpu     = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);
reset(g)

%% Option 3: Run MCMC with multiple proposal at the same starting position, the final results will be the median/mean across all repetitions
t = tic;

fitting             = [];
fitting.algorithm   = 'MH';     % Metropolis-Hastings
fitting.repetition  = 5;        % same starting position, different proposal
fitting.iteration   = 2e4;
fitting.thinning    = 100;
fitting.metric      = {'median'};
fitting.start       = 'default';

g = gpuDevice;
smt_gpu     = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
out         = smt_gpu.estimate(dwi, mask, extraData, fitting);
reset(g)

toc(t)

% Here is an example to convert posterior distribution into image space
a_dist = mcmc.distribution2image(out.posterior.a,mask);

%% Advanced Option 4: Run MCMC multiple times, with different starting points, using multiple GPUs
% which may or may not speed up the process depending on the setup
% if you want to save the posterior distributions it would be better to run in serial (i.e. the option above)
% Note that the display messages will be scrambled
t = tic;

random_bound = 0.2; % randomly scale the default starting points by a factor in [1-0.2, 1+0.2]

% check available GPU; only create (and later delete) a pool if none exists,
% so a user's existing pool is left untouched
availableGPUs = gpuDeviceCount("available");
isNewPool     = false;
if availableGPUs > 1 && isempty(gcp('nocreate'))
    pool      = parpool("Processes",availableGPUs);
    isNewPool = true;
end

repetition = 5;
a       = zeros([size(mask) repetition]);
f       = zeros([size(mask) repetition]);
fcsf    = zeros([size(mask) repetition]);
DeR     = zeros([size(mask) repetition]);
noise   = zeros([size(mask) repetition]);

g = gpuDevice;

if availableGPUs > 1

    parfor k = 1:repetition
        fitting                 = [];
        fitting.algorithm       = 'MH';     % Metropolis-Hastings
        fitting.repetition      = 1;
        fitting.iteration       = 2e4;
        fitting.thinning        = 100;
        fitting.metric          = {'median'};
        % multiply each default starting value by a random factor in [1-random_bound, 1+random_bound];
        % the parentheses are required because .* binds tighter than +
        fitting.start           = [2,0.5,0.20,0.7,0.02] .* ((1-random_bound) + (random_bound*2)*rand(1,5));
        % if you want the posterior distribution of each iteration, you need to save it in the disk space instead
        % fitting.outputFilename = strcat('/path/to/output_dir/filename_',num2str(k),'.mat');
        fitting.outputFilename  = [];

        smt_gpu = gpuAxCaliberSMTmcmc(bval_sorted, ldelta_sorted, BDELTA_sorted, D0, Da_fixed, DeL_fixed, Dcsf);
        out_k   = smt_gpu.estimate(dwi, mask, extraData, fitting);

        % collect the posterior median of each repetition
        a(:,:,:,k)      = out_k.median.a;
        f(:,:,:,k)      = out_k.median.f;
        fcsf(:,:,:,k)   = out_k.median.fcsf;
        DeR(:,:,:,k)    = out_k.median.DeR;
        noise(:,:,:,k)  = out_k.median.noise;
    end
    toc(t)

    reset(g)
    if isNewPool
        delete(pool)
    end

    a_final = median(a,4);

else
    disp('Advanced Option 4 requires more than one available GPU; skipped.');
end