# Skip to article frontmatter | Skip to article content
# Prepare Python environment

import scipy.io as sio
from pathlib import Path

# Imports

from pathlib import Path
import pandas as pd
import json
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import glob


# Configurations
# Directory holding the NIfTI images used in this section.
data_folder_name = Path("../../../data/04-B1-03-Filtering/images")


def get_image(filename, folder=None):
    """Load a NIfTI image and return its squeezed voxel array plus index axes.

    Parameters
    ----------
    filename : str
        Name of the image file, relative to ``folder``.
    folder : str or pathlib.Path, optional
        Directory containing the file. Defaults to ``data_folder_name``.

    Returns
    -------
    im : numpy.ndarray
        Voxel data (``get_fdata()``) with all singleton dimensions removed.
        Must have at least two dimensions after squeezing, because the axis
        construction below reads ``shape[1]``.
    xAxis, yAxis : numpy.ndarray
        Float index vectors ``0, 1, ..., n-1`` for the first and second
        dimensions of ``im``.
    """
    if folder is None:
        folder = data_folder_name

    # Squeeze drops singleton dimensions (e.g. a single-slice third axis).
    im = np.squeeze(nib.load(Path(folder) / filename).get_fdata())

    n_x, n_y = im.shape[0], im.shape[1]
    xAxis = np.linspace(0, n_x - 1, num=n_x)
    yAxis = np.linspace(0, n_y - 1, num=n_y)
    return im, xAxis, yAxis

# Load the two raw acquisitions and the fitted B1 map for each mapping method
# (DA = double angle, AFI = actual flip angle imaging, BS = Bloch-Siegert),
# plus the brain mask applied to every panel.
im_da_raw1, xAxis_da_raw1, yAxis_da_raw1 = get_image('raw_da_1.nii.gz')
im_da_raw2, xAxis_da_raw2, yAxis_da_raw2 = get_image('raw_da_2.nii.gz')
im_da_b1, xAxis_da_b1, yAxis_da_b1 = get_image('b1_clt_tse.nii.gz')

im_afi_raw1, xAxis_afi_raw1, yAxis_afi_raw1 = get_image('raw_afi_1.nii.gz')
im_afi_raw2, xAxis_afi_raw2, yAxis_afi_raw2 = get_image('raw_afi_2.nii.gz')
im_afi_b1, xAxis_afi_b1, yAxis_afi_b1 = get_image('b1_clt_afi.nii.gz')

# Like the DA and AFI pairs, the second BS acquisition comes from its own file.
im_bs_raw1, xAxis_bs_raw1, yAxis_bs_raw1 = get_image('raw_bs_1.nii.gz')
im_bs_raw2, xAxis_bs_raw2, yAxis_bs_raw2 = get_image('raw_bs_2.nii.gz')
im_bs_b1, xAxis_bs_b1, yAxis_bs_b1 = get_image('b1_clt_gre_bs_cr_fermi.nii.gz')

mask, xAxis_mask, yAxis_mask = get_image('brain_mask_es_2x2x5.nii.gz')

# Flip every image vertically. The mask gets the same flip so it stays
# aligned with the images.
# NOTE(review): presumably compensates for the heatmap's bottom-up row order.
im_da_raw1 = np.flipud(im_da_raw1)
im_da_raw2 = np.flipud(im_da_raw2)
im_da_b1 = np.flipud(im_da_b1)

im_afi_raw1 = np.flipud(im_afi_raw1)
im_afi_raw2 = np.flipud(im_afi_raw2)
im_afi_b1 = np.flipud(im_afi_b1)

im_bs_raw1 = np.flipud(im_bs_raw1)
im_bs_raw2 = np.flipud(im_bs_raw2)
im_bs_b1 = np.flipud(im_bs_b1)

mask = np.flipud(mask)

# Normalize raw
# Each method's raw pair is rescaled by ONE shared factor: the largest
# in-mask value across both acquisitions. This keeps the two images on a
# common scale. The factor is computed once, before either image changes.
# Recomputing it after rescaling the first image would give the second
# image a different scale.
_da_scale = np.max([np.max(im_da_raw1 * mask), np.max(im_da_raw2 * mask)])
im_da_raw1 = im_da_raw1 / _da_scale * 1000
im_da_raw2 = im_da_raw2 / _da_scale * 1000

_afi_scale = np.max([np.max(im_afi_raw1 * mask), np.max(im_afi_raw2 * mask)])
im_afi_raw1 = im_afi_raw1 / _afi_scale * 1000
im_afi_raw2 = im_afi_raw2 / _afi_scale * 1000

# BS raw images can be negative. They are scaled by the largest in-mask
# magnitude so that magnitude maps to pi, the edge of the +/-pi color window
# used for the BS raw heatmap.
_bs_scale = np.max([np.max(np.abs(im_bs_raw1) * mask),
                    np.max(np.abs(im_bs_raw2) * mask)])
im_bs_raw1 = im_bs_raw1 / _bs_scale * np.pi
im_bs_raw2 = im_bs_raw2 / _bs_scale * np.pi


## Plot
# PYTHON CODE
# Module imports

import matplotlib.pyplot as plt
import plotly.graph_objs as go
import numpy as np
from plotly import __version__
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# PYTHON CODE
# Module imports
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.image import imread
import scipy.io
import plotly.graph_objs as go
import numpy as np
from plotly import __version__
from plotly.offline import init_notebook_mode, iplot, plot

# Plotly display options passed to every `iplot` call: hide the plot link
# and the mode bar. 'responsive' presumably lets the figure follow the
# cell width.
config = {'showLink': False, 'displayModeBar': False, 'responsive': True}

# Set up offline plotly rendering in the notebook (done once).
init_notebook_mode(connected=True)

import os
import markdown
import random
from scipy.integrate import quad

import warnings
# Silence warnings to keep the rendered notebook output clean.
# NOTE(review): this filter is process-wide and also hides deprecation
# warnings from plotly/numpy/nibabel; consider narrowing it to specific
# categories.
warnings.filterwarnings('ignore')

# Pixel-index coordinates for the heatmaps.
_IMG_COLS = 128  # columns per image (assumed); the raw panel holds two side by side
_IMG_ROWS = 88   # rows per image (assumed)

xAxis_raw = np.linspace(0, 2 * _IMG_COLS - 1, num=2 * _IMG_COLS)
# NOTE(review): twice as many samples as columns over 0.._IMG_COLS-1
# (half-pixel spacing). This appears tuned to the xaxis2 range of the
# layout; confirm against the actual B1-map width.
xAxis_b1 = np.linspace(0, _IMG_COLS - 1, num=2 * _IMG_COLS)
yAxis = np.linspace(0, _IMG_ROWS - 1, num=_IMG_ROWS)

# Put each method's two raw acquisitions side by side (column-wise), and
# build a mask of matching width.
da_acqs = np.concatenate([im_da_raw1, im_da_raw2], axis=1)    # double angle
afi_acqs = np.concatenate([im_afi_raw1, im_afi_raw2], axis=1)  # actual flip angle
bs_acqs = np.concatenate([im_bs_raw1, im_bs_raw2], axis=1)    # Bloch-Siegert
masks_concat = np.concatenate([mask, mask], axis=1)

def _raw_trace(z, zmin, zmax, visible):
    """Return a grayscale heatmap of a side-by-side raw acquisition pair.

    The trace is drawn on the primary axes against ``xAxis_raw``/``yAxis``,
    with its colorbar hidden. ``zmin``/``zmax`` fix the gray-level window.
    """
    return go.Heatmap(x=xAxis_raw,
                      y=yAxis,
                      z=z,
                      zmin=zmin,
                      zmax=zmax,
                      colorscale='gray',
                      showscale=False,
                      visible=visible)


def _b1_trace(z, visible):
    """Return a red/blue heatmap of a B1 map on the secondary axes (x2/y2).

    Every method uses the same 0.7-1.3 color window, so the maps are
    directly comparable. The colorbar title carries no unit, consistent
    across methods; the window suggests a relative B1 scale.
    """
    return go.Heatmap(x=xAxis_b1,
                      y=yAxis,
                      z=z,
                      zmin=0.7,
                      zmax=1.3,
                      colorscale='RdBu',
                      # `title.font` replaces the removed `titlefont` attribute.
                      colorbar={'title': {'text': 'B<sub>1</sub>',
                                          'font': dict(family='Times New Roman',
                                                       size=26)}},
                      xaxis='x2',
                      yaxis='y2',
                      visible=visible)


# Double angle mapping is shown initially; the other methods are hidden
# until selected from the dropdown.
trace_da_raw = _raw_trace(masks_concat * da_acqs, 0, 1000, visible=True)
trace_da_b1 = _b1_trace(mask * im_da_b1, visible=True)

trace_afi_raw = _raw_trace(masks_concat * afi_acqs, 0, 1000, visible=False)
trace_afi_b1 = _b1_trace(mask * im_afi_b1, visible=False)

trace_bs_raw = _raw_trace(masks_concat * bs_acqs, -np.pi, np.pi, visible=False)
trace_bs_b1 = _b1_trace(mask * im_bs_b1, visible=False)

# Order matters: the dropdown's 'visible'/'showscale' lists index into it
# as (raw, B1) pairs per method.
data = [trace_da_raw, trace_da_b1, trace_afi_raw, trace_afi_b1, trace_bs_raw, trace_bs_b1]

# Dropdown that switches between the B1-mapping methods. Method i owns
# traces 2*i (raw pair) and 2*i + 1 (B1 map) in `data`. Selecting a method
# shows only its two traces and only its B1 map's colorbar.
_method_labels = ['Double Angle Mapping', 'Actual Flip angle Imaging', 'Bloch-Siegert shift']
_trace_count = 2 * len(_method_labels)

_method_buttons = [
    dict(
        label=_label,
        method='update',
        args=[{
            'visible': [_slot // 2 == _m for _slot in range(_trace_count)],
            'showscale': [_slot == 2 * _m + 1 for _slot in range(_trace_count)],
        }],
    )
    for _m, _label in enumerate(_method_labels)
]

updatemenus = [
    dict(
        active=0,
        x=0.3,
        xanchor='left',
        y=-0.2,
        yanchor='bottom',
        direction='up',
        font=dict(family='Times New Roman', size=16),
        buttons=_method_buttons,
    )
]

def _hidden_axis(axis_range, domain, **extra):
    """Return a fixed-range axis with grid, zero line, ticks and labels hidden."""
    return dict(range=axis_range, autorange=False,
                showgrid=False, zeroline=False, showticklabels=False,
                ticks='', domain=domain, **extra)


def _panel_title(x, text):
    """Return a paper-anchored panel heading placed above the plotting area."""
    return dict(x=x, y=1.13, showarrow=False, text=text,
                font=dict(family='Times New Roman', size=28),
                xref='paper', yref='paper')


# Figure layout: raw pair on the primary axes, B1 map on the x2/y2 axes.
# NOTE(review): the x and x2 domains overlap (0.65-0.83); the x range
# [0, 225] presumably crops the raw pair so the B1 map can sit over the
# overlap. Confirm visually before changing any of these numbers.
layout = dict(
    width=750,
    height=400,
    margin=dict(t=40, r=50, b=10, l=50),
    annotations=[
        _panel_title(0.09, 'ACQ 1'),
        _panel_title(0.49, 'ACQ 2'),
        _panel_title(0.9, 'B<sub>1</sub> (map)'),
    ],
    xaxis=_hidden_axis([0, 225], [0, 0.83]),
    yaxis=_hidden_axis([0, 120], [0, 1]),
    xaxis2=_hidden_axis([0, 44], [0.65, 0.98]),
    yaxis2=_hidden_axis([0, 120], [0, 1], anchor='x2'),
    showlegend=False,
    autosize=False,
    updatemenus=updatemenus,
)


# Assemble the figure and render it inline with the notebook-wide plotly
# display `config`.
fig = {'data': data, 'layout': layout}
iplot(fig, config=config, filename='basic-heatmap')

# Loading...