# -*- coding: utf-8 -*-
"""
plot.py
図形出力
"""

import numpy as np
import matplotlib.pyplot as plt

Y0 = 10000

# 年データプロット(共通関数, 複数データを1図に上書きする)
def fyear(figtitle, figsize, idata0, idata1, f, n0data, n1data, n0, n1, xdiv, ymin, ymax, ydiv, label, title):
    plt.figure(figtitle, figsize=figsize, layout='tight')
    ax = plt.subplot()
    assert(f.ndim == 2)

    # 年
    n0 += Y0
    n1 += Y0

    # X軸
    x = np.arange(n1 + 1)

    # データ
    lw = 1  # 線幅
    for idata in range(idata0, idata1 + 1):
        m0 = max(n0, n0data[idata])
        m1 = min(n1, n1data[idata])
        if label is not None:
            ax.plot(x[m0: m1 + 1], f[idata][m0: m1 + 1], label=label[idata], lw=lw)
        else:
            ax.plot(x[m0: m1 + 1], f[idata][m0: m1 + 1], lw=lw)

    # グリッド
    ax.grid(True)

    # 凡例
    if label is not None:
        ax.legend(loc='best')
    
    # 横軸
    div = (n1 - n0) // xdiv
    ax.set_xlim(n0, n1)
    ax.set_xticks(np.linspace(n0, n1, div + 1), labels=np.linspace(n0 - Y0, n1 - Y0, div + 1, dtype=np.int32))

    # 縦軸
    ax.set_ylim(ymin, ymax)
    ax.set_yticks(np.linspace(ymin, ymax, ydiv + 1))

    # タイトル
    if title is not None:
        ax.set_title(title)

    plt.show()

# 散布図
def scatter(figtitle, figsize, idata0, idata1, b, t, n0data, n1data, fmin, fmax, fdiv, marker):
    plt.figure(figtitle, figsize=figsize, layout='tight')

    idata1 = min(idata1, idata0 + 7)   # データ数は最大8個
    ndata = idata1 - idata0 + 1        # データ数
    assert(ndata > 1)
    nfig = (ndata * (ndata - 1)) // 2  # 図数(右上三角)
    mfig = [[1, 1], [1, 3], [2, 3], [3, 4], [3, 5], [4, 6], [4, 7]]  #  図の縦横数
    assert(len(mfig) > ndata - 2)
    my = mfig[ndata - 2][0]  # 縦の図数
    mx = mfig[ndata - 2][1]  # 横の図数

    # 1組データ=1図
    ifig = 1
    for idata in range(idata0, idata1 + 1):
        for jdata in range(idata + 1, idata1 + 1):
            n0 = max(n0data[idata], n0data[jdata])  # 開始年(大きい方)
            n1 = min(n1data[idata], n1data[jdata])  # 終了年(小さい方)

            # 図作成
            ax = plt.subplot(my, mx, ifig)

            # 散布図
            ax.scatter(b[idata, n0: n1 + 1], b[jdata, n0: n1 + 1], s=marker)

            # 横軸
            ax.set_xlim(fmin, fmax)
            ax.set_xticks(np.linspace(fmin, fmax, fdiv + 1))
    
            # 縦軸
            ax.set_ylim(fmin, fmax)
            ax.set_yticks(np.linspace(fmin, fmax, fdiv + 1))

            # アスペクト比=1
            ax.set_aspect('equal')

            # タイトル
            title = '(%d, %d) n=%d t=%.2f' % (idata + 1, jdata + 1, n1 - n0 + 1, t[idata][jdata])
            ax.set_title(title, fontsize=10, pad=0)

            # 図数終了判定
            ifig += 1
            if (ifig > nfig):
                break

    plt.show()

# r/tマップ（全データ1図）
def map(figtitle, figsize, ndata, r, t, rmin, rmax, tmin, tmax):
    fig = plt.figure(figtitle, figsize=figsize, layout='tight')

    # 対角要素は0にする
    for i in range(ndata):
        r[i, i] = t[i, i] = 0

    # r/t平均
    ndmask = (np.identity(ndata) == 0)  # 非対角マスク
    rmean = np.mean(r, where=ndmask)
    tmean = np.mean(t, where=ndmask)

    # r
    ax1 = plt.subplot(1, 2, 1)
    im1 = ax1.imshow(r, vmin=rmin, vmax=rmax, interpolation='auto', cmap='rainbow', origin='upper')
    ax1.tick_params(top=True, labeltop=True)  # 目盛り上
    ax1.tick_params(bottom=False, labelbottom=False)
    ax1.set_title('r data=%d average=%.3f' % (ndata, rmean))
    fig.colorbar(im1, location='bottom')

    # t
    ax2 = plt.subplot(1, 2, 2)
    im2 = ax2.imshow(t, vmin=tmin, vmax=tmax, interpolation='auto', cmap='rainbow', origin='upper')
    ax2.tick_params(top=True, labeltop=True)  # 目盛り上
    ax2.tick_params(bottom=False, labelbottom=False)
    ax2.set_title('t data=%d average=%.3f' % (ndata, tmean))
    fig.colorbar(im2, location='bottom')

    plt.show()

# r/tヒストグラム（全データ1図）
def histogram(figtitle, figsize, ndata, r, t, bins, rmin, rmax, tmin, tmax):
    plt.figure(figtitle, figsize=figsize, layout='tight')

    # 右上三角成分(対角成分は除く)
    rdata = []
    tdata = []
    for i in range(ndata):
        for j in range(i + 1, ndata):
            rdata.append(r[i, j])
            tdata.append(t[i, j])

    # r
    ax1 = plt.subplot(1, 2, 1)
    ax1.hist(rdata, bins=bins, range=(rmin, rmax))
    ax1.set_title('data=%d n=%d max=%.3f' % (ndata, len(rdata), np.max(rdata)))
    ax1.set_xlim(rmin, rmax)
    ax1.set_xlabel('r')

    # t
    ax2 = plt.subplot(1, 2, 2)
    ax2.hist(tdata, bins=bins, range=(tmin, tmax))
    ax2.set_title('data=%d n=%d max=%.3f' % (ndata, len(tdata), np.max(tdata)))
    ax2.set_xlim(tmin, tmax)
    ax2.set_xlabel('t')

    plt.show()

# 照合（1図）
def compare(figtitle, figsize, cmp, rt, cmpfn, ymax):
    assert (cmp is not None) and (len(cmp) > 0) and (cmpfn is not None)
    assert (rt == 'r') or (rt == 't')
    plt.figure(figtitle, figsize=figsize, layout='tight')
    ax = plt.subplot()

    # X,Yデータ
    x = [item[0] for item in cmp]
    y = [item[2 if rt == 'r' else 3] for item in cmp]

    # プロット
    ax.plot(x, y)

    # 横軸
    ax.set_xlim(np.min(x), np.max(x))
    ax.set_xlabel('offset')

    # 縦軸
    if ymax is not None:
        ax.set_ylim(0, ymax)
    else:
        ax.set_ylim(bottom=0)

    # グリッド
    ax.grid(True)

    # タイトル
    if cmpfn is not None:
        title = 'compare %s: %s.txt, %s.txt (center=%d, max=%d, %.3f)' % (rt, cmpfn[0], cmpfn[1], cmpfn[2], np.min(x) + np.argmax(y), np.max(y))
        ax.set_title(title)

    plt.show()

# 標準化データヒストグラム（1図）
def histo2(figtitle, figsize, f, bins, xmin, xmax):
    plt.figure(figtitle, figsize=figsize, layout='tight')
    ax = plt.subplot()

    # 平均、標準偏差
    mean = np.mean(f)
    std = np.std(f)

    # ヒストグラム
    n = ax.hist(f, bins=bins, range=(xmin, xmax))

    # 横軸
    ax.set_xlim(xmin, xmax)
    ax.set_xlabel('X_std')

    # タイトル
    ax.set_title('data=%d mean=%.4f std=%.4f' % (len(f), mean, std))

    # 正規分布グラフ
    ymax = np.max(n[0])
    x = np.linspace(xmin, xmax, 101)
    y = ymax * np.exp(-(x - mean) * (x - mean) / (2 * std * std))
    ax.plot(x, y)

    plt.show()

# データ年代統計（2図）
def datastatis(figtitle, figsize, ndata, n0data, n1data, n0, n1, xdiv):
    plt.figure(figtitle, figsize, layout='tight')

    # (1) データの期間を横バーで表示する
    ax1 = plt.subplot(2, 1, 1)

    # データ
    bmap = np.zeros((ndata, 2 * Y0), dtype=np.float32)  # 0/1
    for idata in range(ndata):
        bmap[idata, n0data[idata]: n1data[idata] + 1] = 1

    # 等高線
    ax1.imshow(bmap[:, n0 + Y0: n1 + 1 + Y0], cmap='gray_r', vmin=0, vmax=1, interpolation='none', origin='lower')

    # アスペクト比: ウィンドウサイズに合わせる
    ax1.set_aspect('auto')

    # グリッド
    ax1.grid(axis='x')

    # 横軸
    div = (n1 - n0) // xdiv
    ax1.set_xlim(0, n1 - n0)
    ax1.set_xticks(np.linspace(0, n1 - n0, div + 1), labels=np.linspace(n0, n1, div + 1, dtype=np.int32))

    # 縦軸
    ax1.set_ylim(-0.5, ndata - 0.5)
    ax1.set_ylabel('data number')

    # タイトル
    title = 'data=%d year=(%d, %d)' % (ndata, np.min(n0data) - Y0, np.max(n1data - Y0))
    ax1.set_title(title)

    # (2) 年度別データ数ヒストグラム
    ax2 = plt.subplot(2, 1, 2)

    # データ数
    num = np.zeros(2 * Y0, dtype=np.int32)
    for idata in range(ndata):
        num[n0data[idata]: n1data[idata] + 1] += 1

    # ヒストグラム
    xmin = np.min(n0data)
    xmax = np.max(n1data)
    x = np.linspace(xmin, xmax, xmax - xmin + 1) - Y0
    ax2.plot(x, num[xmin: xmax + 1])

    # アスペクト比: ウィンドウサイズに合わせる
    ax1.set_aspect('auto')

    # グリッド
    ax2.grid(True)

    # 横軸
    div = (n1 - n0) // xdiv
    ax2.set_xlim(n0, n1)
    ax2.set_xticks(np.linspace(n0, n1, div + 1), labels=np.linspace(n0, n1, div + 1, dtype=np.int32))

    # 縦軸
    ax2.set_ylim(bottom=0)
    ax2.set_ylabel('number of data')

    # タイトル
    title = 'data=%d max=%d total=%d' % (ndata, np.max(num), np.sum(num))
    ax2.set_title(title)

    plt.show()
