#@title Figure 그리기
def box_with_points(ax, values, title, ylabel,
                    specials=None, jitter=0.05, ms=3):
    """Draw a single box plot of ``values`` and overlay optional highlighted points.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    values : sequence of numbers summarised by the box (converted to float).
    title, ylabel : axes title and y-axis label.
    specials : optional iterable of ``(label, value, ...)`` items. Each one is
        drawn as a single marker at the box's x position and listed in the
        legend as ``"label = value"``. Only the first two elements are used;
        extra elements (e.g. a linestyle string) are ignored. Label "parity"
        is drawn red, "magnitude" navy, anything else black.
    jitter : currently unused; kept so existing calls keep working.
    ms : marker size for the special points.

    The y-axis spans [0, 1] by default. It is widened only when a finite value
    or special point lies outside that range, so metrics that can go negative
    (such as a mean cosine) are not clipped.
    """
    vals = np.asarray(values, dtype=float)
    x0 = 1.0
    # medianprops color='none' hides the median line; only box/whiskers show.
    ax.boxplot([vals], positions=[x0], widths=[0.35], vert=True,
               showmeans=False,
               medianprops=dict(color='none'))

    special_vals = []
    if specials:
        color_map = {"parity": "red", "magnitude": "navy"}
        for item in specials:
            lab, v = item[0], float(item[1])
            special_vals.append(v)
            ax.plot([x0], [v],
                    marker='o', linestyle='none',
                    markersize=ms, color=color_map.get(lab, "black"),
                    label=f"{lab} = {v:.3f}", zorder=4)
        ax.legend(loc="best", fontsize=8, frameon=False)

    ax.set_xlim(0.6, 1.4)

    # Keep the [0, 1] default for accuracy-like metrics, but extend it (with a
    # small margin so markers are not cut at the edge) when any finite point
    # falls outside; otherwise out-of-range data would be silently hidden.
    ys = np.concatenate([vals.ravel(), np.asarray(special_vals, dtype=float)])
    ys = ys[np.isfinite(ys)]
    y_lo, y_hi = 0.0, 1.0
    if ys.size:
        if ys.min() < y_lo:
            y_lo = float(ys.min()) - 0.05
        if ys.max() > y_hi:
            y_hi = float(ys.max()) + 0.05
    ax.set_ylim(y_lo, y_hi)

    ax.set_xticks([x0]); ax.set_xticklabels(["all balanced dichotomies"])
    ax.set_title(title); ax.set_ylabel(ylabel)
# One row of three box-plot panels (SD, CCGP, PS). The CCGP and PS panels also
# highlight the "parity" and "magnitude" values as separate points.
fig, axes = plt.subplots(1, 3, figsize=(13, 4), dpi=140)
box_with_points(axes[0], sd_vals, "SD (balanced dichotomies)", "Accuracy")
box_with_points(
    axes[1], ccgp_all, "CCGP (balanced dichotomies)", "CCGP",
    specials=[("parity", ccgp_par, "--"), ("magnitude", ccgp_mag, ":")],
)
box_with_points(
    axes[2], ps_all, "PS (balanced dichotomies)", "PS (mean pairwise cosine)",
    specials=[("parity", ps_par, "--"), ("magnitude", ps_mag, ":")],
)
plt.tight_layout()
plt.show()
@torch.no_grad()
def collect_for_mds(loader, per_digit=60, digits=range(1, 9)):
    """Gather up to ``per_digit`` examples per digit label for MDS plots.

    Runs the module-level ``model`` over batches from ``loader``; images are
    moved to the module-level ``device``. For each example whose label is in
    ``digits`` (default 1..8), it keeps the three representations that
    ``model(imgs, return_reprs=True)`` returns in its last three positions
    (unpacked here as ``pix``, ``h1``, ``h2``). Iteration stops after the
    first batch at which every requested digit has ``per_digit`` examples.

    The model is switched to eval mode for the pass. Its previous train/eval
    mode is restored afterwards, even if an error occurs.

    Returns
    -------
    (X_pix, X_h1, X_h2, Y): the kept representations stacked along a new first
    axis, plus a 1-D array of their integer labels.

    Raises
    ------
    ValueError: if no example with a requested label was collected.
    """
    buckets = {int(d): 0 for d in digits}  # label -> examples kept so far
    X_pix, X_h1, X_h2, Y = [], [], [], []
    was_training = model.training
    model.eval()
    try:
        for imgs, labels in loader:
            imgs = imgs.to(device)
            _, _, pix, h1, h2 = model(imgs, return_reprs=True)
            pix, h1, h2 = pix.cpu().numpy(), h1.cpu().numpy(), h2.cpu().numpy()
            labels = labels.cpu().numpy()
            for i, lab in enumerate(labels):
                d = int(lab)
                if d in buckets and buckets[d] < per_digit:
                    X_pix.append(pix[i]); X_h1.append(h1[i]); X_h2.append(h2[i]); Y.append(d)
                    buckets[d] += 1
            # Stop reading batches once every requested digit is full.
            if all(v >= per_digit for v in buckets.values()):
                break
    finally:
        # Do not leak eval mode into later training cells.
        model.train(was_training)
    if not Y:
        raise ValueError(
            f"collect_for_mds: no examples with labels in {sorted(buckets)} "
            f"were collected (per_digit={per_digit})"
        )
    return np.stack(X_pix), np.stack(X_h1), np.stack(X_h2), np.array(Y)
# Up to 60 examples per digit from the training set, at three representation levels.
X_E, X_F, X_G, y_d = collect_for_mds(loader=train_loader, per_digit=60)
def run_mds(X, random_state=0, max_iter=300):
    """Return a 2-D metric-MDS embedding of the rows of ``X``.

    Uses Euclidean dissimilarities, a single initialisation, and the given
    ``random_state`` / ``max_iter``.
    """
    embedder = MDS(
        n_components=2,
        metric=True,
        n_init=1,
        max_iter=max_iter,
        random_state=random_state,
        dissimilarity='euclidean',
    )
    embedding = embedder.fit_transform(X)
    return embedding
# Embed each representation level independently (same MDS settings for all).
Z_E, Z_F, Z_G = [run_mds(feats) for feats in (X_E, X_F, X_G)]
def _mds_color_alpha_for_digit(d):
color = "#1f77b4" if (d % 2 == 0) else "#ff7f0e" # even blue / odd orange
alpha = 0.4 if (d >= 5) else 1.0 # big lighter / small solid
return color, alpha
def panel_with_numbers(ax, Z, y_digits, title):
    """Draw a 2-D embedding on ``ax`` using digit labels instead of markers.

    Point ``i`` is drawn as the text ``str(int(y_digits[i]))`` at
    ``(Z[i, 0], Z[i, 1])``. Its colour and alpha come from
    ``_mds_color_alpha_for_digit``. The axis limits are set to the data range
    plus a 5% margin, the aspect is made equal, and ticks are removed.
    """
    xs = Z[:, 0]
    ys = Z[:, 1]
    # Invisible scatter (size 0) so the axes register the data extent.
    ax.scatter(xs, ys, s=0)
    for idx, label in enumerate(y_digits):
        digit = int(label)
        color, alpha = _mds_color_alpha_for_digit(digit)
        ax.text(xs[idx], ys[idx], str(digit), ha="center", va="center",
                fontsize=8, color=color, alpha=alpha)

    # 5% margin on each side; the 1e-9 keeps the margin non-zero for a flat range.
    x_lo, x_hi = xs.min(), xs.max()
    y_lo, y_hi = ys.min(), ys.max()
    margin_x = 0.05 * (x_hi - x_lo + 1e-9)
    margin_y = 0.05 * (y_hi - y_lo + 1e-9)
    ax.set_xlim(x_lo - margin_x, x_hi + margin_x)
    ax.set_ylim(y_lo - margin_y, y_hi + margin_y)

    ax.set_aspect("equal", adjustable="datalim")
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
# One row of three MDS panels: pixel input, hidden layer 1, hidden layer 2.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(14, 4), dpi=140)
panel_with_numbers(axes[0], Z_E, y_d, title="E) Input (pixels)")
panel_with_numbers(axes[1], Z_F, y_d, title="F) Hidden layer 1")
panel_with_numbers(axes[2], Z_G, y_d, title="G) Hidden layer 2")
plt.tight_layout()
plt.show()