Skip to article frontmatterSkip to article content
# Prepare Python environment

import scipy.io as sio
from pathlib import Path

# Archived figure data; the relative path is resolved against the current
# working directory when the file is opened.
data_dir = Path("../../../data/04-B1-03-Filtering")
data_file = "b1filt_fig6.mat"

# Load the archived plot variables from the MATLAB .mat file.
mat_contents = sio.loadmat(data_dir / data_file)

# Per-acquisition structs (double angle, AFI, Bloch-Siegert) and the shared mask.
da_data, afi_data, bs_data, mask = (
    mat_contents["da_data"],
    mat_contents["afi_data"],
    mat_contents["bs_data"],
    mat_contents["mask"],
)

# loadmat returns MATLAB vectors as 2-D arrays; [0] keeps the first row.
median_smoothing_factors, gaussian_smoothing_factors, spline_smoothing_factors = (
    mat_contents["median_smoothing_factors"][0],
    mat_contents["gaussian_smoothing_factors"][0],
    mat_contents["spline_smoothing_factors"][0],
)

from os import path
from pathlib import Path
import os


# Imports

from pathlib import Path
import pandas as pd
import json
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import glob

# xAxis / yAxis for the heatmaps are defined next to mask_grid, right before the
# traces are built. An earlier per-tile definition that stood here was removed:
# it was overwritten before any use, and range(1, n) produced n - 1 coordinates.

# Each figure row is one B1-mapping acquisition (DA, AFI, BS); each column is one
# processing variant: unfiltered ('Raw'), then median, gaussian and spline filtering.
_FILTER_KINDS = ("median", "gaussian", "spline")


def _filter_row(acq_data, strength):
    """Return one acquisition's unfiltered and filtered maps side by side.

    Parameters
    ----------
    acq_data
        MATLAB struct as loaded by ``scipy.io.loadmat``, indexed by field name.
        Each field is read as ``acq_data[field][0][0]``; the fields used are
        ``'Raw'`` and ``'<kind>_<strength>'`` for each kind in ``_FILTER_KINDS``.
    strength : str
        Filter-strength suffix, e.g. ``'low'``, ``'medium'`` or ``'high'``.

    Returns
    -------
    numpy.ndarray
        The four maps, each rotated with ``np.rot90(..., -1)``, concatenated
        along axis 1 in the order Raw, median, gaussian, spline.
    """
    fields = ["Raw"] + [f"{kind}_{strength}" for kind in _FILTER_KINDS]
    return np.concatenate(
        [np.rot90(acq_data[field][0][0], -1) for field in fields],
        axis=1,
    )


# DA acqs
da_low = _filter_row(da_data, "low")
da_medium = _filter_row(da_data, "medium")
da_high = _filter_row(da_data, "high")

# AFI acqs
afi_low = _filter_row(afi_data, "low")
afi_medium = _filter_row(afi_data, "medium")
afi_high = _filter_row(afi_data, "high")

# BS acqs
bs_low = _filter_row(bs_data, "low")
bs_medium = _filter_row(bs_data, "medium")
bs_high = _filter_row(bs_data, "high")

# Concat methods: stack the rows along axis 0 in the order DA, AFI, BS.
data_low = np.concatenate((da_low, afi_low, bs_low), axis=0)
data_medium = np.concatenate((da_medium, afi_medium, bs_medium), axis=0)
data_high = np.concatenate((da_high, afi_high, bs_high), axis=0)

# PYTHON CODE
# Module imports
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.image import imread
import scipy.io
import plotly.graph_objs as go
import numpy as np
from plotly import __version__
from plotly.offline import init_notebook_mode, iplot, plot
# Plotly display options passed to iplot below: no Chart Studio link, no mode
# bar, and a figure that resizes with its container.
config = dict(showLink=False, displayModeBar=False, responsive=True)

# Set up offline plotly rendering in the notebook (connected=True loads
# plotly.js from a CDN rather than embedding it).
init_notebook_mode(connected=True)

import os
import markdown
import random
from scipy.integrate import quad

import warnings
# Ignore every warning category for the rest of the process so rendered
# output stays clean.
# NOTE(review): this is process-wide and also hides deprecation warnings from
# plotly/numpy/scipy; consider narrowing it with category= or module=.
warnings.simplefilter('ignore')

# The per-filter rows (da_*, afi_*, bs_*) and the stacked data_low /
# data_medium / data_high grids are built once, earlier in this file, from
# da_data / afi_data / bs_data. Nothing modifies those inputs in between, so the
# identical second copy of that computation that stood here was removed.


# Mask
# Tile the rotated mask into the same grid layout as data_low/medium/high:
# 4 processing variants side by side (axis 1), 3 acquisitions stacked (axis 0),
# so one element-wise multiply masks every panel.
masks_concat = np.tile(np.rot90(mask, -1), (1, 4))
mask_grid = np.tile(masks_concat, (3, 1))

# Heatmap pixel coordinates 0 .. N-1 (float), derived from the grid shape
# instead of a hard-coded tile size so they always match the plotted data.
xAxis = np.arange(mask_grid.shape[1], dtype=float)
yAxis = np.arange(mask_grid.shape[0], dtype=float)


def _b1_heatmap(z_data, visible):
    """Build one B1 heatmap trace over the masked 3 x 4 figure grid.

    Parameters
    ----------
    z_data : numpy.ndarray
        Stacked B1 grid (same shape as ``mask_grid``); it is multiplied
        element-wise by ``mask_grid`` and row-flipped with ``np.flipud``
        before plotting.
    visible : bool
        Initial visibility; the dropdown menu toggles this per trace.

    Returns
    -------
    plotly.graph_objs.Heatmap
    """
    return go.Heatmap(
        x=xAxis,
        y=yAxis,
        # flipud reverses the row order; presumably so the first stacked row
        # (DA) is drawn at the top, since Plotly puts z row 0 at the bottom.
        z=np.flipud(mask_grid * z_data),
        zmin=0.7,
        zmax=1.3,
        colorscale='RdBu',
        # colorbar.title.font replaces the deprecated `titlefont` key, which
        # newer plotly releases no longer accept.
        colorbar={
            "title": {
                "text": 'B<sub>1</sub>',
                "font": dict(family='Times New Roman', size=26),
            },
        },
        visible=visible,
    )


# Only the medium-filter trace starts visible (matches active=1 in the dropdown).
trace_low = _b1_heatmap(data_low, visible=False)
trace_medium = _b1_heatmap(data_medium, visible=True)
trace_high = _b1_heatmap(data_high, visible=False)

data = [trace_low, trace_medium, trace_high]

# One dropdown (opening upward) that switches which of the three heatmap traces
# is shown. Button i makes only trace i visible; active=1 pre-selects "Medium".
updatemenus = [
    dict(
        active=1,
        x=0.4,
        xanchor='left',
        y=-0.08,
        yanchor='bottom',
        direction='up',
        font=dict(family='Times New Roman', size=16),
        buttons=[
            dict(
                label=label,
                method='update',
                args=[{'visible': [slot == choice for slot in range(3)]}],
            )
            for choice, label in enumerate(
                ('Weak filter', 'Medium filter', 'Strong filter')
            )
        ],
    )
]

def _paper_label(text, x, y, textangle=None):
    """Return a 30 pt Times New Roman annotation placed in paper coordinates.

    ``textangle`` is only added to the annotation when given.
    """
    annotation = dict(
        x=x,
        y=y,
        showarrow=False,
        text=text,
        font=dict(family='Times New Roman', size=30),
        xref='paper',
        yref='paper',
    )
    if textangle is not None:
        annotation['textangle'] = textangle
    return annotation


def _hidden_axis(upper):
    """Return a fixed-range axis [0, upper] with grid, zero line and ticks hidden."""
    return dict(
        range=[0, upper],
        autorange=False,
        showgrid=False,
        zeroline=False,
        showticklabels=False,
        ticks='',
        domain=[0, 1],
    )


# Axis extents follow the plotted grid instead of hard-coded tile sizes.
_grid_rows, _grid_cols = np.shape(mask_grid)

layout = dict(
    width=750,
    height=750,
    margin=dict(t=40, r=50, b=10, l=50),
    annotations=[
        # Column headers: one per processing variant.
        _paper_label('Unfiltered', 0.03, 1.05),
        _paper_label('Median', 0.305, 1.05),
        _paper_label('Gaussian', 0.625, 1.05),
        _paper_label('Spline', 0.94, 1.05),
        # Row headers (one per acquisition), drawn vertically left of the grid.
        _paper_label('Double Angle', -0.06, 0.95, textangle=-90),
        _paper_label('AFI', -0.06, 0.5, textangle=-90),
        _paper_label('Bloch-Siegert', -0.06, 0.05, textangle=-90),
    ],
    xaxis=_hidden_axis(_grid_cols - 1),
    yaxis=_hidden_axis(_grid_rows - 1),
    showlegend=False,
    autosize=False,
    updatemenus=updatemenus,
)


# Assemble the figure spec and render it with plotly's offline iplot.
fig = {'data': data, 'layout': layout}
iplot(fig, config=config, filename='basic-heatmap')
Loading...