# Skip to article frontmatter | Skip to article content
# Prepare Python environment

import scipy.io as sio
from pathlib import Path

# Imports

from pathlib import Path
import pandas as pd
import json
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import glob

from contextlib import contextmanager
import sys, os
from pathlib import Path

@contextmanager
def suppress_stdout():
    """Temporarily redirect everything written to ``sys.stdout`` to the null device.

    The previously active stream is reinstated when the ``with`` block exits,
    including when the block raises.
    """
    saved_stream = sys.stdout
    with open(os.devnull, "w") as sink:
        sys.stdout = sink
        try:
            yield
        finally:
            # Always hand stdout back, even if the body raised.
            sys.stdout = saved_stream

import os
from pathlib import Path

def find_myst_yml_directories(start_dir=None):
    """
    Recursively search for directories containing a myst.yml file.

    The walk does not descend below a directory that contains myst.yml, so
    nested projects under it are not reported. Each physical directory is
    visited at most once, so a symlink pointing back up the tree can neither
    make the walk revisit directories nor report the same project twice.
    Subdirectories that cannot be listed or inspected are skipped.

    Args:
        start_dir (str or Path): Starting directory (defaults to current directory)

    Returns:
        list: Resolved full paths (as str) of directories containing myst.yml,
        in sorted traversal order.

    Raises:
        OSError: If ``start_dir`` itself cannot be listed (e.g. it does not exist).
    """
    start_dir = Path.cwd() if start_dir is None else Path(start_dir)

    myst_dirs = []
    visited = set()  # resolved directories already walked; guards against symlink cycles

    def _search_directory(current_dir):
        real_dir = current_dir.resolve()
        if real_dir in visited:
            return
        visited.add(real_dir)

        if (current_dir / "myst.yml").exists():
            myst_dirs.append(str(real_dir))
            # Don't search subdirectories if we found myst.yml here
            return

        # sorted() makes the result order deterministic (iterdir order is arbitrary).
        for item in sorted(current_dir.iterdir()):
            try:
                # is_dir() can itself raise PermissionError, so it belongs inside the try.
                if item.is_dir():
                    _search_directory(item)
            except OSError:
                # Skip directories we can't access (PermissionError is an OSError).
                continue

    _search_directory(start_dir)
    return myst_dirs

def find_myst_yml_directories_upwards(start_dir=None):
    """
    Search for myst.yml in the start directory, then in each parent in turn.

    The start directory is first made absolute. This lets relative inputs such
    as "." or "notebooks" walk up through their real parents. The filesystem
    root itself is checked too.

    Args:
        start_dir (str or Path): Starting directory (defaults to current directory)

    Returns:
        str or None: Full path of the nearest directory containing myst.yml,
        or None if no directory up to and including the root contains one.
    """
    current_dir = Path.cwd() if start_dir is None else Path(start_dir).resolve()

    # Path.parents yields every ancestor, ending with the filesystem root.
    for directory in (current_dir, *current_dir.parents):
        if (directory / "myst.yml").exists():
            return str(directory.resolve())

    return None

# Locate the repository root (the nearest directory, from the current one
# upwards, that contains myst.yml) and derive the data locations from it.
_repo_root = find_myst_yml_directories_upwards()
if _repo_root is None:
    # Without this check, Path(None) would fail with an unhelpful TypeError.
    raise FileNotFoundError(
        "Could not find myst.yml in the current directory or any parent; "
        "run this notebook from inside the repository."
    )
repo_path = Path(_repo_root)
data_req_path = repo_path / "binder" / "data_requirement.json"
data_path = repo_path / "data"
dataset_path = data_path / "qmrlab-mooc"

# Configurations
data_folder_name = dataset_path / "04-B1-03-Filtering" / "04-B1-03-Filtering" / "images"
    
def get_image(filename, folder=None):
    """Load a NIfTI image and return it with singleton axes removed, plus index axes.

    Args:
        filename (str): Name of the image file inside ``folder``.
        folder (str or Path, optional): Directory containing the image.
            Defaults to the module-level ``data_folder_name``.

    Returns:
        tuple: ``(im, xAxis, yAxis)``.
            - ``im`` is the image data from ``get_fdata()`` with singleton
              dimensions squeezed out.
            - ``xAxis`` / ``yAxis`` are float arrays ``0, 1, ..., n-1`` of
              length ``im.shape[0]`` and ``im.shape[1]``.
    """
    if folder is None:
        folder = data_folder_name

    # np.squeeze drops length-1 axes so the result can be treated as a 2-D image.
    im = np.squeeze(nib.load(Path(folder) / filename).get_fdata())

    xAxis = np.linspace(0, im.shape[0] - 1, num=im.shape[0])
    yAxis = np.linspace(0, im.shape[1] - 1, num=im.shape[1])
    return im, xAxis, yAxis

# Load the two raw acquisitions and the resulting B1 map for each method
# (DA = double angle mapping, AFI = actual flip-angle imaging,
# BS = Bloch-Siegert shift), plus the brain mask.
# Each acquisition is unpacked into its own xAxis/yAxis variables.
im_da_raw1, xAxis_da_raw1, yAxis_da_raw1 = get_image('raw_da_1.nii.gz')
im_da_raw2, xAxis_da_raw2, yAxis_da_raw2 = get_image('raw_da_2.nii.gz')
im_da_b1, xAxis_da_b1, yAxis_da_b1 = get_image('b1_clt_tse.nii.gz')

im_afi_raw1, xAxis_afi_raw1, yAxis_afi_raw1 = get_image('raw_afi_1.nii.gz')
im_afi_raw2, xAxis_afi_raw2, yAxis_afi_raw2 = get_image('raw_afi_2.nii.gz')
im_afi_b1, xAxis_afi_b1, yAxis_afi_b1 = get_image('b1_clt_afi.nii.gz')

im_bs_raw1, xAxis_bs_raw1, yAxis_bs_raw1 = get_image('raw_bs_1.nii.gz')
# Second Bloch-Siegert acquisition, matching the *_1/*_2 pattern of DA and AFI.
# NOTE(review): assumes raw_bs_2.nii.gz ships with the dataset -- TODO confirm.
im_bs_raw2, xAxis_bs_raw2, yAxis_bs_raw2 = get_image('raw_bs_2.nii.gz')
im_bs_b1, xAxis_bs_b1, yAxis_bs_b1 = get_image('b1_clt_gre_bs_cr_fermi.nii.gz')

mask, xAxis_mask, yAxis_mask = get_image('brain_mask_es_2x2x5.nii.gz')

# Reverse the row order of every image; presumably this is so they render
# upright in the plotly heatmaps (whose y axis grows upwards) -- confirm.
# The mask gets the same flip so it stays aligned with the images.
im_da_raw1, im_da_raw2, im_da_b1 = [
    np.flipud(img) for img in (im_da_raw1, im_da_raw2, im_da_b1)
]
im_afi_raw1, im_afi_raw2, im_afi_b1 = [
    np.flipud(img) for img in (im_afi_raw1, im_afi_raw2, im_afi_b1)
]
im_bs_raw1, im_bs_raw2, im_bs_b1 = [
    np.flipud(img) for img in (im_bs_raw1, im_bs_raw2, im_bs_b1)
]
mask = np.flipud(mask)

# Normalize raw acquisitions.
# Both images of a pair are divided by ONE shared scale factor, so their
# relative intensities are preserved. That factor is the larger of the two
# in-mask peak values, computed before either image is modified.
# - DA and AFI pairs are mapped to a 1000 peak.
# - The Bloch-Siegert pair is scaled by its peak absolute value, so its signed
#   values fall within -pi..pi.
def _normalize_pair(first, second, roi, peak, use_abs=False):
    """Scale two images by a shared factor so the larger in-``roi`` maximum equals ``peak``.

    Args:
        first, second (np.ndarray): Images to scale.
        roi (np.ndarray): Mask multiplied into the images before taking the maxima.
        peak (float): Value the shared maximum is mapped to.
        use_abs (bool): Use absolute values when finding the maximum (signed data).

    Returns:
        tuple: ``(first_scaled, second_scaled)``.
    """
    measure = np.abs if use_abs else (lambda arr: arr)
    scale = np.max([np.max(measure(first) * roi), np.max(measure(second) * roi)])
    return first / scale * peak, second / scale * peak


im_da_raw1, im_da_raw2 = _normalize_pair(im_da_raw1, im_da_raw2, mask, 1000)
im_afi_raw1, im_afi_raw2 = _normalize_pair(im_afi_raw1, im_afi_raw2, mask, 1000)
im_bs_raw1, im_bs_raw2 = _normalize_pair(im_bs_raw1, im_bs_raw2, mask, np.pi, use_abs=True)


## Plot
# PYTHON CODE
# Module imports

import matplotlib.pyplot as plt
import plotly.graph_objs as go
import numpy as np
from plotly import __version__
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# PYTHON CODE
# Module imports
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.image import imread
import scipy.io
import plotly.graph_objs as go
import numpy as np
from plotly import __version__
from plotly.offline import init_notebook_mode, iplot, plot

# Plotly display options used by iplot() below: no share link, no mode bar,
# and the figure resizes with its container. Defined once (an earlier
# duplicate definition without 'responsive' was immediately overwritten).
config = {'showLink': False, 'displayModeBar': False, 'responsive': True}

# Set up inline plotly rendering in the notebook; one call is sufficient.
init_notebook_mode(connected=True)

import os
import markdown
import random
from scipy.integrate import quad

import warnings
# NOTE(review): this silences every warning for the rest of the session,
# not just plotting ones -- consider narrowing the filter.
warnings.filterwarnings('ignore')

# Heatmap coordinates. The raw panel shows two acquisitions side by side,
# hence twice the column count of a single image.
# NOTE(review): assumes each image is 88 rows x 128 columns -- TODO confirm.
# NOTE(review): xAxis_b1 spreads 2*128 samples over 0..127 (spacing ~0.5);
# presumably deliberate so the B1 map fits the xaxis2 range -- confirm before changing.
_N_COLS, _N_ROWS = 128, 88
xAxis_raw = np.linspace(0, 2 * _N_COLS - 1, num=2 * _N_COLS)
xAxis_b1 = np.linspace(0, _N_COLS - 1, num=2 * _N_COLS)
yAxis = np.linspace(0, _N_ROWS - 1, num=_N_ROWS)

# Put acquisition 1 and acquisition 2 of each method next to each other
# (column-wise), and tile the mask the same way so it lines up with both.
da_acqs, afi_acqs, bs_acqs, masks_concat = [
    np.concatenate(pair, axis=1)
    for pair in (
        (im_da_raw1, im_da_raw2),    # DA acqs
        (im_afi_raw1, im_afi_raw2),  # AFI acqs
        (im_bs_raw1, im_bs_raw2),    # BS acqs
        (mask, mask),                # Mask
    )
]

def _raw_trace(z, zmin, zmax, visible):
    """Grayscale heatmap of two side-by-side raw acquisitions, without a colour bar."""
    return go.Heatmap(x=xAxis_raw,
                      y=yAxis,
                      z=z,
                      zmin=zmin,
                      zmax=zmax,
                      colorscale='gray',
                      showscale=False,
                      visible=visible)


def _b1_trace(z, visible):
    """Diverging (RdBu) heatmap of a B1 map on the x2/y2 axes, colour range 0.7..1.3.

    The colour-bar title carries no unit, matching the DA map: the values are
    on a relative 0.7..1.3 scale, so the former "(ms)" label on AFI/BS was wrong.
    """
    return go.Heatmap(x=xAxis_b1,
                      y=yAxis,
                      z=z,
                      zmin=0.7,
                      zmax=1.3,
                      colorscale='RdBu',
                      # title given as {text, font}; the old `titlefont` property is
                      # deprecated and removed in newer plotly releases.
                      colorbar={'title': {'text': 'B<sub>1</sub>',
                                          'font': dict(family='Times New Roman', size=26)}},
                      xaxis='x2',
                      yaxis='y2',
                      visible=visible)


# Only the DA pair is visible initially; the dropdown buttons switch pairs.
trace_da_raw = _raw_trace(masks_concat*da_acqs, 0, 1000, True)
trace_da_b1 = _b1_trace(mask*im_da_b1, True)

trace_afi_raw = _raw_trace(masks_concat*afi_acqs, 0, 1000, False)
trace_afi_b1 = _b1_trace(mask*im_afi_b1, False)

# Bloch-Siegert raw data are signed, so the grey scale spans -pi..pi.
trace_bs_raw = _raw_trace(masks_concat*bs_acqs, -np.pi, np.pi, False)
trace_bs_b1 = _b1_trace(mask*im_bs_b1, False)

data=[trace_da_raw, trace_da_b1, trace_afi_raw, trace_afi_b1, trace_bs_raw, trace_bs_b1]

# One dropdown button per B1-mapping method. `data` holds the traces as
# (raw, B1 map) pairs, so button k shows traces 2k and 2k+1 and gives a
# colour bar only to the B1 map (trace 2k+1).
_METHOD_LABELS = ('Double Angle Mapping', 'Actual Flip angle Imaging', 'Bloch-Siegert shift')
_N_TRACES = 2 * len(_METHOD_LABELS)


def _method_button(k, label):
    """Dropdown entry that switches the figure to the k-th method's trace pair."""
    return dict(
        label=label,
        method='update',
        args=[{
            'visible': [i // 2 == k for i in range(_N_TRACES)],
            'showscale': [i == 2 * k + 1 for i in range(_N_TRACES)],
        }],
    )


updatemenus = [
    dict(
        active=0,
        x=0.3,
        xanchor='left',
        y=-0.2,
        yanchor='bottom',
        direction='up',
        font=dict(family='Times New Roman', size=16),
        buttons=[_method_button(k, label) for k, label in enumerate(_METHOD_LABELS)],
    )
]

def _title_annotation(x, text):
    """Panel heading placed above the plotting area, in paper coordinates."""
    return dict(
        x=x,
        y=1.13,
        showarrow=False,
        text=text,
        font=dict(family='Times New Roman', size=28),
        xref='paper',
        yref='paper',
    )


def _bare_axis(axis_range, domain, **extra):
    """Fixed-range axis with grid, zero line, ticks and tick labels hidden."""
    return dict(range=axis_range, autorange=False,
                showgrid=False, zeroline=False, showticklabels=False,
                ticks='', domain=domain, **extra)


layout = dict(
    width=700,
    height=400,
    margin=dict(t=40, r=50, b=10, l=50),
    annotations=[
        _title_annotation(0.09, 'ACQ 1'),
        _title_annotation(0.49, 'ACQ 2'),
        _title_annotation(0.9, 'B<sub>1</sub> (map)'),
    ],
    # Raw acquisitions use x/y; the B1 map uses x2/y2, whose horizontal domain
    # overlaps the right part of the raw panel's domain.
    xaxis=_bare_axis([0, 225], [0, 0.83]),
    yaxis=_bare_axis([0, 120], [0, 1]),
    xaxis2=_bare_axis([0, 44], [0.65, 0.98]),
    yaxis2=_bare_axis([0, 120], [0, 1], anchor='x2'),
    showlegend=False,
    autosize=False,
    updatemenus=updatemenus,
)


# Assemble the figure spec and render it inline with the shared display config.
fig = {'data': data, 'layout': layout}
iplot(fig, filename='basic-heatmap', config=config)

# Loading...