import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Read the table written by mumax3 into a pandas DataFrame
def read_mumax3_table(filename):
    """Load a mumax3 output table into a pandas DataFrame.

    Parameters
    ----------
      filename: path of the table file, parsed with ``pandas.read_table``

    Returns
    -------
      DataFrame whose column labels are reduced to the bare quantity names.
    """
    import pandas

    frame = pandas.read_table(filename)

    # Flatten the header cells into whitespace-separated tokens, skip the
    # first token and keep every second one after it. This assumes a header
    # of the form "# name (unit) name (unit) ...", so that only the names
    # survive -- verify against the table files you actually feed in.
    header_tokens = ' '.join(frame.columns).split()
    frame.columns = header_tokens[1::2]

    return frame

# Read the ovf files
def read_mumax3_ovffiles(outputdir):
    """Convert the ovf files in outputdir to numpy files and load them.

    Parameters
    ----------
      outputdir: directory holding the ovf files (string)

    Returns
    -------
      Dictionary mapping each ``.npy`` file name in outputdir (without
      directory and extension) to the numpy array loaded from it.
    """
    import glob
    import os.path
    import subprocess

    import numpy

    # Ask mumax3-convert to write a numpy file for every ovf file. The
    # wildcard is handed to the tool unexpanded, because no shell is involved.
    conversion = subprocess.run(
        ["mumax3-convert", "-numpy", outputdir+"/*.ovf"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    # A failed conversion is only reported; loading below still proceeds with
    # whatever .npy files are present in the directory.
    if conversion.returncode != 0:
        print(conversion.stdout.decode('UTF-8'))

    # Key every array by the base name of the file it was loaded from.
    return {
        os.path.splitext(os.path.basename(npy_path))[0]: numpy.load(npy_path)
        for npy_path in glob.glob(outputdir+"/*.npy")
    }
# Run a mumax3 script and read back the data it produced
def run_mumax3(script, name, verbose=False):
    """ Executes a mumax3 script and loads its table and converted ovf files

    Parameters
    ----------
      script:  string containing the mumax3 input script
      name:    name of the simulation; the script is written to name + ".txt"
               and the output is looked up in name + ".out"
      verbose: print stdout of mumax3 when it is finished (the output is
               always printed when mumax3 exits with a non-zero code)

    Returns
    -------
      (table, fields): table is the DataFrame read from table.txt in the
      output directory, or None when that file does not exist; fields is
      the dictionary returned by read_mumax3_ovffiles for that directory.
    """

    import os.path
    import subprocess

    scriptfile = name + ".txt"
    outputdir = name + ".out"

    # Persist the input script so that mumax3 can be pointed at it
    with open(scriptfile, 'w') as handle:
        handle.write(script)

    # Run mumax3 on the script, capturing stdout and stderr together
    completed = subprocess.run(
        ["mumax3", "-f", scriptfile],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    if completed.returncode != 0 or verbose:
        print(completed.stdout.decode('UTF-8'))

    # The table is optional: it is only read when the file is present
    tablefile = outputdir + "/table.txt"
    table = read_mumax3_table(tablefile) if os.path.exists(tablefile) else None

    return table, read_mumax3_ovffiles(outputdir)

# Simulation parameters. All of them except gamma are interpolated into the
# mumax3 script below; gamma is only used for the analytical dispersion curve.
fmax =  6e9  # maximum frequency (Hz): drive frequency in B_ext and y-limit of the dispersion plot
T  = 30e-9   # simulated time handed to run() (s)
dt = 10e-12   # sampling interval (s) between autosaved snapshots / table rows -- an interval, not a frequency
dx = 2e-9   # cell size along x (m)
nx = 321   # number of cells along x

Bz = 0.001  # bias field along z
A = 13e-12    # exchange constant (interpolated as Aex)
Ms   = 800e3   # saturation magnetization (interpolated as Msat)
alpha = 0.02   # damping coefficient
gamma  = 1.76e11   # gyromagnetic ratio


# mumax3 input script. NOTE(review): inside the script text
#   - sizeX, sizeZ, Nx and f are declared but never referenced again; the grid
#     and the drive use the interpolated {nx}, {dx} and {fmax} instead.
#   - the inline remarks "// 1GHz" and "// 10mT" do not match the values next
#     to them (6e9 and 1) -- confirm which was intended.
#   - the script-level "A := 1" is the amplitude factor of the sine term in
#     B_ext; it is unrelated to the Python variable A, which is interpolated
#     as Aex.
script=f"""
sizeX := 642e-9
sizeY := 160e-9
sizeZ := 2e-9
Nx := 321
Ny := 80
setgridsize({nx}, Ny, 1)
setcellsize({dx}, sizeY/Ny,2e-9)
//setGeom(ellipse(500e-9, 160e-9))
enabledemag = false
//temp.setregion(2,30)
defregion(1,xrange(-2e-9,0))
defregion(2,xrange(0,600e-9))
defregion(3,xrange(600e-9,640e-9))

Msat = {Ms}
Aex = {A}
alpha= {alpha}
Ku1 = 1.3e3
anisU= vector(0,0,1)

alpha.setregion(2,{alpha})
alpha.setregion(3,1)

m=uniform(0.1, 0, 1)
relax()
f := 6e9 // 1GHz
A := 1// 10mT
B_ext.setregion(1,vector(A*sin(2*pi*{fmax}*t),0, {Bz}))

autosave(m, {dt})
tableAutosave({dt})
run({T})
"""
 
# Run the simulation. run_mumax3 returns None for the table when the output
# directory holds no table.txt, so stop here with an explicit message instead
# of failing later with "'NoneType' object is not subscriptable".
table, fields = run_mumax3(script,"spinwaves")

if table is None:
    raise RuntimeError(
        "No table.txt was found in the output directory of simulation "
        "'spinwaves'; if mumax3 failed, its output was printed above."
    )

print(table)

# Plot the three magnetization columns of the table against time
plt.figure()

# Convert the time column once; the axis label below is in nanoseconds
nanosecond = 1e-9
time_ns = table["t"]/nanosecond
for component in ("mx", "my", "mz"):
    plt.plot(time_ns, table[component])

plt.xlabel("Time (ns)")
plt.ylabel("Magnetization")

plt.show()

#print(fields.keys())






# Dispersion relation: 2D FFT of the saved magnetization snapshots compared
# with the analytical curve.

# np.stack raises an unhelpful "need at least one array to stack" on an empty
# list, so report the actual problem when no snapshot was loaded.
if not fields:
    raise RuntimeError(
        "No converted .npy snapshots were loaded from the mumax3 output "
        "directory; check that mumax3 and mumax3-convert ran successfully."
    )

# Stack all snapshots of the magnetization on top of each other, ordered by
# file name (presumably the autosave numbering makes this chronological --
# TODO confirm)
m = np.stack([fields[key] for key in sorted(fields)])

# Select the x component along the first row of cells
# (assumes each snapshot is indexed as (component, z, y, x) -- TODO confirm)
mx = m[:,0,0,0,:]

# Apply the two dimensional FFT and move the zero frequency to the centre
mx_fft = np.fft.fft2(mx)
mx_fft = np.fft.fftshift(mx_fft)

plt.figure(figsize=(10,6))

# Show the intensity plot of the 2D FFT
extent = [ -(2*np.pi)/(2*dx), (2*np.pi)/(2*dx), -1/(2*dt), 1/(2*dt)] # extent of k values and frequencies
plt.imshow(np.abs(mx_fft)**2, extent=extent, aspect='auto', origin='lower', cmap="inferno")

# Plot the analytically derived dispersion relation and its k = 0 offset
k = np.linspace(-6e8,6e8,2000)
freq_theory = A*gamma*k**2 /(np.pi*Ms) + gamma*Bz /(2*np.pi)
plt.plot(k,freq_theory,'r--',lw=1)
plt.axhline(gamma*Bz/(2*np.pi),c='g',ls='--',lw=1)

plt.xlim([-6e8,6e8])
plt.ylim([-fmax,fmax])
plt.ylabel("$f$ (Hz)")
plt.xlabel("$k$ (1/m)")

plt.show()