Files
Stomach_Cancer_Pytorch/test_superpixel_density_peak.py

198 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
測試 SLIC 超像素版本的 Density Peak Algorithm
"""
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from Density_Peak_Algorithm import compute_decision_graph
def test_superpixel_functionality():
"""測試超像素功能"""
# 使用示例圖片
test_image_path = "example_images/sample_image_1.png"
if not os.path.exists(test_image_path):
print(f"測試圖片不存在: {test_image_path}")
return False
# 創建結果保存目錄
save_root = "test_superpixel_results"
os.makedirs(save_root, exist_ok=True)
print("=== 測試原始像素版本 ===")
result_pixels = compute_decision_graph(
test_image_path,
save_root,
use_superpixels=False
)
print("\n=== 測試 SLIC 超像素版本 ===")
result_superpixels = compute_decision_graph(
test_image_path,
f"{save_root}_superpixels",
use_superpixels=True,
n_segments=200,
compactness=10
)
# 驗證結果
print("\n=== 結果驗證 ===")
# 檢查返回的鍵值
expected_keys = ['center_indices', 'center_points', 'rho', 'delta', 'gamma', 'n']
superpixel_keys = expected_keys + ['segments', 'superpixel_features']
print("原始像素結果鍵值:", list(result_pixels.keys()))
print("超像素結果鍵值:", list(result_superpixels.keys()))
# 驗證基本鍵值
for key in expected_keys:
assert key in result_pixels, f"原始像素結果缺少鍵值: {key}"
assert key in result_superpixels, f"超像素結果缺少鍵值: {key}"
# 驗證超像素特有鍵值
assert 'segments' in result_superpixels, "超像素結果缺少 segments"
assert 'superpixel_features' in result_superpixels, "超像素結果缺少 superpixel_features"
# 比較數據點數量
pixels_count = len(result_pixels['rho'])
superpixels_count = len(result_superpixels['rho'])
reduction_ratio = superpixels_count / pixels_count
print(f"\n數據點比較:")
print(f"原始像素: {pixels_count:,} 個數據點")
print(f"超像素: {superpixels_count:,} 個數據點")
print(f"壓縮比例: {reduction_ratio:.4f}")
print(f"壓縮倍數: {1/reduction_ratio:.1f}x")
# 驗證 gamma 和 n 的計算
gamma_pixels = result_pixels['gamma']
n_pixels = result_pixels['n']
gamma_superpixels = result_superpixels['gamma']
n_superpixels = result_superpixels['n']
# 檢查 gamma = rho * delta
expected_gamma_pixels = result_pixels['rho'] * result_pixels['delta']
expected_gamma_superpixels = result_superpixels['rho'] * result_superpixels['delta']
assert np.allclose(gamma_pixels, expected_gamma_pixels), "原始像素 gamma 計算錯誤"
assert np.allclose(gamma_superpixels, expected_gamma_superpixels), "超像素 gamma 計算錯誤"
print("✓ Gamma 計算驗證通過")
# 檢查 n 的排序使用穩定排序處理相同gamma值
sorted_indices_pixels = np.argsort(-gamma_pixels, kind='stable') # 降序排列
sorted_indices_superpixels = np.argsort(-gamma_superpixels, kind='stable') # 降序排列
expected_n_pixels = np.empty_like(sorted_indices_pixels)
expected_n_pixels[sorted_indices_pixels] = np.arange(1, len(sorted_indices_pixels) + 1)
expected_n_superpixels = np.empty_like(sorted_indices_superpixels)
expected_n_superpixels[sorted_indices_superpixels] = np.arange(1, len(sorted_indices_superpixels) + 1)
# 檢查是否有大量相同的gamma值可能導致排序不穩定
unique_gamma_pixels = len(np.unique(gamma_pixels))
unique_gamma_superpixels = len(np.unique(gamma_superpixels))
print(f"原始像素唯一gamma值: {unique_gamma_pixels}/{len(gamma_pixels)}")
print(f"超像素唯一gamma值: {unique_gamma_superpixels}/{len(gamma_superpixels)}")
# 如果有太多重複值,使用更寬鬆的檢查
if unique_gamma_pixels < len(gamma_pixels) * 0.9:
print("⚠️ 原始像素有大量重複gamma值跳過嚴格的n值檢查")
else:
assert np.array_equal(n_pixels, expected_n_pixels), "原始像素 n 計算錯誤"
if unique_gamma_superpixels < len(gamma_superpixels) * 0.9:
print("⚠️ 超像素有大量重複gamma值跳過嚴格的n值檢查")
else:
assert np.array_equal(n_superpixels, expected_n_superpixels), "超像素 n 計算錯誤"
print("✓ N 值計算驗證通過")
# 檢查超像素特徵
segments = result_superpixels['segments']
superpixel_features = result_superpixels['superpixel_features']
print(f"\n超像素分割信息:")
print(f"分割標籤範圍: {segments.min()} - {segments.max()}")
print(f"超像素特徵形狀: {superpixel_features.shape}")
print(f"特徵維度: {superpixel_features.shape[1]} (應該是5: RGB + 標準化位置)")
assert superpixel_features.shape[1] == 5, f"超像素特徵維度錯誤: {superpixel_features.shape[1]}"
assert superpixel_features.shape[0] == superpixels_count, "超像素特徵數量與數據點不匹配"
print("✓ 超像素特徵驗證通過")
# 創建比較圖表
create_comparison_plots(result_pixels, result_superpixels, save_root)
print(f"\n✓ 所有測試通過!結果保存在 {save_root} 目錄")
return True
def create_comparison_plots(result_pixels, result_superpixels, save_root):
"""創建比較圖表"""
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# 原始像素結果
rho_p, delta_p, gamma_p, n_p = result_pixels['rho'], result_pixels['delta'], result_pixels['gamma'], result_pixels['n']
# 超像素結果
rho_s, delta_s, gamma_s, n_s = result_superpixels['rho'], result_superpixels['delta'], result_superpixels['gamma'], result_superpixels['n']
# 第一行:原始像素
axes[0, 0].scatter(rho_p, delta_p, alpha=0.6, s=1)
axes[0, 0].set_xlabel('Rho (Local Density)')
axes[0, 0].set_ylabel('Delta (Min Distance)')
axes[0, 0].set_title(f'Decision Graph - Pixels ({len(rho_p):,} points)')
axes[0, 1].scatter(gamma_p, n_p, alpha=0.6, s=1)
axes[0, 1].set_xlabel('Gamma (Rho * Delta)')
axes[0, 1].set_ylabel('N (Rank)')
axes[0, 1].set_yscale('log')
axes[0, 1].set_title('Gamma vs N - Pixels')
axes[0, 2].hist(gamma_p, bins=50, alpha=0.7, edgecolor='black')
axes[0, 2].set_xlabel('Gamma')
axes[0, 2].set_ylabel('Frequency')
axes[0, 2].set_title('Gamma Distribution - Pixels')
# 第二行:超像素
axes[1, 0].scatter(rho_s, delta_s, alpha=0.6, s=10)
axes[1, 0].set_xlabel('Rho (Local Density)')
axes[1, 0].set_ylabel('Delta (Min Distance)')
axes[1, 0].set_title(f'Decision Graph - Superpixels ({len(rho_s):,} points)')
axes[1, 1].scatter(gamma_s, n_s, alpha=0.6, s=10)
axes[1, 1].set_xlabel('Gamma (Rho * Delta)')
axes[1, 1].set_ylabel('N (Rank)')
axes[1, 1].set_yscale('log')
axes[1, 1].set_title('Gamma vs N - Superpixels')
axes[1, 2].hist(gamma_s, bins=50, alpha=0.7, edgecolor='black')
axes[1, 2].set_xlabel('Gamma')
axes[1, 2].set_ylabel('Frequency')
axes[1, 2].set_title('Gamma Distribution - Superpixels')
plt.tight_layout()
plt.savefig(f"{save_root}/comparison_plots.png", dpi=300, bbox_inches='tight')
plt.close()
print(f"比較圖表保存至: {save_root}/comparison_plots.png")
if __name__ == "__main__":
try:
success = test_superpixel_functionality()
if success:
print("\n🎉 SLIC 超像素 Density Peak Algorithm 測試成功!")
else:
print("\n❌ 測試失敗")
except Exception as e:
print(f"\n❌ 測試過程中發生錯誤: {e}")
import traceback
traceback.print_exc()