#@title Figure 그리기
def box_with_points(ax, values, title, ylabel,
                    specials=None, jitter=0.05, ms=3, ylim=(0.0, 1.0)):
    """Draw one box plot of ``values`` and overlay highlighted reference points.

    Args:
        ax: matplotlib Axes to draw into.
        values: array-like of numbers summarised by a single box at x = 1.
        title: Axes title.
        ylabel: y-axis label.
        specials: optional sequence of ``(label, value, ...)`` items. Each is
            drawn as a marker at the box's x position and listed in the legend
            as ``"label = value"`` (3 decimals). Only the first two fields are
            read; further fields (e.g. a style string) are ignored. Labels
            "parity" and "magnitude" are drawn red and navy, others black.
        jitter: currently unused; kept so existing call sites keep working.
        ms: marker size of the special points.
        ylim: nominal y-limits. If any box value or special value lies outside
            them, the limits are widened (with a 5% margin of the nominal span)
            so nothing is clipped out of view, e.g. a negative mean cosine.
    """
    vals = np.asarray(values, dtype=float)
    x0 = 1.0
    # Vertical orientation is matplotlib's default, so no `vert=` keyword is
    # passed (that keyword is being deprecated in recent matplotlib).
    # The median line is hidden on purpose by giving it a transparent colour.
    ax.boxplot([vals], positions=[x0], widths=[0.35],
               showmeans=False,
               medianprops=dict(color='none'))

    plotted = [vals.ravel()]  # everything drawn, used for the y-limit check
    if specials:
        color_map = {"parity": "red", "magnitude": "navy"}
        for item in specials:
            lab, v = item[0], float(item[1])
            ax.plot([x0], [v],
                    marker='o', linestyle='none',
                    markersize=ms, color=color_map.get(lab, "black"),
                    label=f"{lab} = {v:.3f}", zorder=4)
            plotted.append(np.array([v]))
        ax.legend(loc="best", fontsize=8, frameon=False)

    ax.set_xlim(0.6, 1.4)

    # Start from the nominal limits and only widen them when data would
    # otherwise be invisible; NaN/inf values are ignored for this purpose.
    lo, hi = float(ylim[0]), float(ylim[1])
    finite = np.concatenate(plotted)
    finite = finite[np.isfinite(finite)]
    if finite.size:
        margin = 0.05 * (hi - lo)
        if finite.min() < lo:
            lo = float(finite.min()) - margin
        if finite.max() > hi:
            hi = float(finite.max()) + margin
    ax.set_ylim(lo, hi)

    ax.set_xticks([x0]); ax.set_xticklabels(["all balanced dichotomies"])
    ax.set_title(title); ax.set_ylabel(ylabel)
fig, axes = plt.subplots(1, 3, figsize=(13, 4), dpi=140)
# One panel per metric: (scores, title, y-label, highlighted dichotomies).
# Highlighted entries are (label, value, style-string) tuples; None = no extras.
_box_panels = [
    (sd_vals, "SD (balanced dichotomies)", "Accuracy", None),
    (ccgp_all, "CCGP (balanced dichotomies)", "CCGP",
     [("parity", ccgp_par, "--"), ("magnitude", ccgp_mag, ":")]),
    (ps_all, "PS (balanced dichotomies)", "PS (mean pairwise cosine)",
     [("parity", ps_par, "--"), ("magnitude", ps_mag, ":")]),
]
for _ax, (_vals, _title, _ylabel, _specials) in zip(axes, _box_panels):
    box_with_points(_ax, _vals, title=_title, ylabel=_ylabel,
                    specials=_specials)
plt.tight_layout()
plt.show()
@torch.no_grad()
def collect_for_mds(loader, per_digit=60, digits=range(1, 9)):
    """Collect up to ``per_digit`` examples per digit with their representations.

    Runs the module-level ``model`` (inputs moved to ``device``) over the
    ``(imgs, labels)`` batches of ``loader``. For each example whose label is in
    ``digits`` and whose per-digit quota is not yet full, it keeps the 3rd, 4th
    and 5th values returned by ``model(imgs, return_reprs=True)`` (used here as
    the pixel, hidden-layer-1 and hidden-layer-2 representations). Iteration
    stops once every digit has ``per_digit`` examples; if the loader runs out
    first, fewer examples are returned.

    Note: calls ``model.eval()`` and leaves the model in eval mode.

    Args:
        loader: iterable of ``(imgs, labels)`` tensor batches.
        per_digit: maximum number of examples kept per digit.
        digits: labels to collect; the default (1..8) matches the original
            hard-coded set.

    Returns:
        ``(X_pix, X_h1, X_h2, Y)``: three arrays stacked along a new first axis
        (one row per kept example, in loader order) and an array of the
        integer labels.

    Raises:
        ValueError: if no example with a label in ``digits`` was found.
    """
    model.eval()
    counts = {int(d): 0 for d in digits}
    X_pix, X_h1, X_h2, Y = [], [], [], []
    for imgs, labels in loader:
        imgs = imgs.to(device)
        _, _, pix, h1, h2 = model(imgs, return_reprs=True)
        pix, h1, h2 = pix.cpu().numpy(), h1.cpu().numpy(), h2.cpu().numpy()
        # .cpu() so label tensors that already live on an accelerator work too.
        labels = labels.cpu().numpy()
        for i, lab in enumerate(labels):
            d = int(lab)
            if d in counts and counts[d] < per_digit:
                X_pix.append(pix[i]); X_h1.append(h1[i]); X_h2.append(h2[i]); Y.append(d)
                counts[d] += 1
        if all(c >= per_digit for c in counts.values()):
            break
    if not Y:
        # Same exception type np.stack would raise, but with a useful message.
        raise ValueError(
            f"collect_for_mds: no examples with labels in {sorted(counts)} were found")
    return np.stack(X_pix), np.stack(X_h1), np.stack(X_h2), np.array(Y)
# Representations of up to 60 training examples per digit at the input (E),
# first hidden (F) and second hidden (G) layers, plus their digit labels.
X_E, X_F, X_G, y_d = collect_for_mds(loader=train_loader, per_digit=60)
def run_mds(X, random_state=0, max_iter=300):
    """Embed the rows of ``X`` in 2-D with metric MDS on Euclidean distances.

    Args:
        X: samples to embed, one per row.
        random_state: seed passed to MDS.
        max_iter: maximum number of SMACOF iterations.

    Returns:
        The 2-D coordinates produced by ``MDS.fit_transform`` (single init).
    """
    embedder = MDS(
        n_components=2,
        metric=True,
        n_init=1,
        max_iter=max_iter,
        random_state=random_state,
        dissimilarity='euclidean',
    )
    return embedder.fit_transform(X)
# 2-D MDS embedding of each representation, kept in the same E/F/G order.
Z_E, Z_F, Z_G = map(run_mds, (X_E, X_F, X_G))
def _mds_color_alpha_for_digit(d):
    color = "#1f77b4" if (d % 2 == 0) else "#ff7f0e"  # even blue / odd orange
    alpha = 0.4 if (d >= 5) else 1.0                  # big lighter / small solid
    return color, alpha
def panel_with_numbers(ax, Z, y_digits, title):
    """Draw a 2-D embedding using each point's digit label as its marker.

    Args:
        ax: matplotlib Axes to draw into.
        Z: embedding coordinates; column 0 is x and column 1 is y.
        y_digits: digit label of each row, drawn as text at that point and
            styled via ``_mds_color_alpha_for_digit``.
        title: Axes title.

    The limits hug the points with a 5% margin, the aspect ratio is equal
    (adjusting data limits), and ticks are hidden.
    """
    xs = Z[:, 0]
    ys = Z[:, 1]
    # Invisible scatter: text artists do not update the data limits, so this
    # registers the points for autoscaling / the aspect adjustment below.
    ax.scatter(xs, ys, s=0)
    for idx, label in enumerate(y_digits):
        digit = int(label)
        colour, opacity = _mds_color_alpha_for_digit(digit)
        ax.text(xs[idx], ys[idx], str(digit), ha="center", va="center",
                fontsize=8, color=colour, alpha=opacity)
    # The tiny epsilon keeps a non-zero margin even if all points coincide.
    margin_x = 0.05 * (xs.max() - xs.min() + 1e-9)
    margin_y = 0.05 * (ys.max() - ys.min() + 1e-9)
    ax.set_xlim(xs.min() - margin_x, xs.max() + margin_x)
    ax.set_ylim(ys.min() - margin_y, ys.max() + margin_y)
    ax.set_aspect("equal", adjustable="datalim")
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
fig, axes = plt.subplots(1, 3, figsize=(14, 4), dpi=140)
# (embedding, panel title) for the input and both hidden layers.
_mds_panels = (
    (Z_E, "E) Input (pixels)"),
    (Z_F, "F) Hidden layer 1"),
    (Z_G, "G) Hidden layer 2"),
)
for _ax, (_Z, _title) in zip(axes, _mds_panels):
    panel_with_numbers(_ax, _Z, y_d, _title)
plt.tight_layout()
plt.show()