# Stomach_Cancer_Pytorch/analyze_training_full.py
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import glob
# Add project root to path
sys.path.append(os.getcwd())
from model_data_processing.Crop_And_Fill import Crop_Result
def ensure_dir(path):
    """Create directory *path* (including parents) if it does not exist.

    Uses exist_ok=True instead of a check-then-create pair, which avoids
    the race where another process creates the directory between the
    os.path.exists() test and the os.makedirs() call.
    """
    os.makedirs(path, exist_ok=True)
def _collect_image_paths(class_input_dir):
    """Return a sorted, de-duplicated list of image file paths in *class_input_dir*.

    Both lower- and upper-case extensions are globbed because glob is
    case-sensitive on some platforms (e.g. Linux); the set() removes the
    duplicates this produces on case-insensitive filesystems (Windows).
    """
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
    image_paths = []
    for ext in image_extensions:
        image_paths.extend(glob.glob(os.path.join(class_input_dir, ext)))
        image_paths.extend(glob.glob(os.path.join(class_input_dir, ext.upper())))
    return sorted(set(image_paths))


def _analyze_and_mark(img, corner_ratio=0.2):
    """Count near-black pixels in the four corner ROIs of *img*.

    A pixel counts as "black" when its grayscale value is below 30.
    Only pixels inside the four corner rectangles (each *corner_ratio*
    of the image's height/width) are counted.

    Returns
    -------
    (dict, ndarray)
        Quadrant-name -> black-pixel-count mapping, and a copy of *img*
        with the detected pixels painted red and the ROIs outlined and
        labelled in green.
    """
    if len(img.shape) == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        gray = img
    # Mask is white (255) where pixel < 30 (black-ish), black otherwise.
    _, mask = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY_INV)
    h, w = mask.shape
    h_roi = int(h * corner_ratio)
    w_roi = int(w * corner_ratio)
    # Single table of corner slices, used for masking, counting and drawing.
    corners = {
        'Top-Left': (slice(0, h_roi), slice(0, w_roi)),
        'Top-Right': (slice(0, h_roi), slice(w - w_roi, w)),
        'Bottom-Left': (slice(h - h_roi, h), slice(0, w_roi)),
        'Bottom-Right': (slice(h - h_roi, h), slice(w - w_roi, w)),
    }
    roi_mask = np.zeros_like(mask)
    for ys, xs in corners.values():
        roi_mask[ys, xs] = 255
    # Keep only black pixels (mask==255) that fall inside a corner ROI.
    final_mask = cv2.bitwise_and(mask, roi_mask)
    counts = {name: cv2.countNonZero(final_mask[ys, xs])
              for name, (ys, xs) in corners.items()}
    marked_img = img.copy()
    # Paint the detected pixels red (BGR order) for visualization.
    marked_img[final_mask == 255] = [0, 0, 255]
    green = (0, 255, 0)
    thickness = 2
    cv2.rectangle(marked_img, (0, 0), (w_roi, h_roi), green, thickness)
    cv2.rectangle(marked_img, (w - w_roi, 0), (w, h_roi), green, thickness)
    cv2.rectangle(marked_img, (0, h - h_roi), (w_roi, h), green, thickness)
    cv2.rectangle(marked_img, (w - w_roi, h - h_roi), (w, h), green, thickness)
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(marked_img, 'TL', (10, 20), font, 0.5, green, 1)
    cv2.putText(marked_img, 'TR', (w - w_roi + 10, 20), font, 0.5, green, 1)
    cv2.putText(marked_img, 'BL', (10, h - h_roi + 20), font, 0.5, green, 1)
    cv2.putText(marked_img, 'BR', (w - w_roi + 10, h - h_roi + 20), font, 0.5, green, 1)
    return counts, marked_img


def process_and_analyze(input_root, output_root, limit=None):
    """
    Process training data:
    1. Crop images using Crop_Result.
    2. Analyze cropped images for black pixels in the four corner ROIs.
    3. Generate statistics, marked images, and a comparison chart.

    Parameters
    ----------
    input_root : str
        Directory containing one subdirectory per class of images.
    output_root : str
        Destination root; per-class "Cropped" and "Marked" folders and
        the summary chart are written here.
    limit : int or None
        Maximum number of images to process per class (None = all).

    Returns
    -------
    dict or None
        {class_name: {quadrant: black_pixel_count}}, or None when no
        class directories were found.
    """
    class_dirs = [d for d in os.listdir(input_root)
                  if os.path.isdir(os.path.join(input_root, d))]
    if not class_dirs:
        print(f"No class directories found in {input_root}")
        return
    print(f"Found classes: {class_dirs}")
    # Charts are saved directly under output_root; create it up front so
    # plot_charts cannot fail on a missing directory.
    ensure_dir(output_root)
    all_stats = {}  # {class_name: {quadrant: count}}
    for class_name in class_dirs:
        print(f"\nProcessing Class: {class_name}")
        class_input_dir = os.path.join(input_root, class_name)
        class_output_dir = os.path.join(output_root, class_name)
        cropped_dir = os.path.join(class_output_dir, "Cropped")
        marked_dir = os.path.join(class_output_dir, "Marked")
        ensure_dir(cropped_dir)
        ensure_dir(marked_dir)
        image_paths = _collect_image_paths(class_input_dir)
        if not image_paths:
            print(f" No images found in {class_name}")
            continue
        # "is not None" so an explicit limit=0 is honoured; the original
        # truthiness test silently processed everything for limit=0.
        if limit is not None:
            print(f" Limiting to {limit} images (found {len(image_paths)})")
            image_paths = image_paths[:limit]
        else:
            print(f" Processing all {len(image_paths)} images")
        # 1. Run Crop_Result (takes a list of paths and an output directory).
        print(" Running Crop_Result...")
        try:
            Crop_Result(image_paths, cropped_dir)
        except Exception as e:
            print(f" Error cropping {class_name}: {e}")
            continue
        # 2. Analyze every cropped image and accumulate per-quadrant counts.
        print(" Analyzing cropped images...")
        class_stats = {
            'Top-Left': 0,
            'Top-Right': 0,
            'Bottom-Left': 0,
            'Bottom-Right': 0
        }
        for filename in os.listdir(cropped_dir):
            img = cv2.imread(os.path.join(cropped_dir, filename))
            if img is None:
                # Skip non-image or unreadable files.
                continue
            counts, marked_img = _analyze_and_mark(img)
            for quadrant, count in counts.items():
                class_stats[quadrant] += count
            cv2.imwrite(os.path.join(marked_dir, filename), marked_img)
        all_stats[class_name] = class_stats
        print(f" Stats for {class_name}: {class_stats}")
    # Aggregate: total black pixels over all four corners per class.
    class_totals = {name: sum(stats.values())
                    for name, stats in all_stats.items()}
    print(f"\nClass Totals: {class_totals}")
    plot_charts(class_totals, output_root, "Class_Comparison")
    print("\nProcessing Complete.")
    return all_stats
def plot_charts(counts, output_dir, title_prefix):
    """Save a bar + pie chart of *counts* as '<title_prefix>_analysis.png'.

    Parameters
    ----------
    counts : dict
        Label -> black-pixel-count mapping.
    output_dir : str
        Directory the PNG is written to (created if missing).
    title_prefix : str
        Used in both chart titles and the output file name.

    Skips chart generation entirely when all counts are zero, since a
    pie chart of zeros is undefined.
    """
    labels = list(counts.keys())
    values = list(counts.values())
    if sum(values) == 0:
        print(f" No black pixels detected for {title_prefix}, skipping charts.")
        return
    # Defect fix: savefig raises if the target directory does not exist.
    ensure_dir(output_dir)
    # Bar chart and pie chart side by side.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    colors = ['#FF9999', '#66B2FF', '#99FF99', '#FFCC99']
    bars = ax1.bar(labels, values, color=colors)
    ax1.set_title(f'{title_prefix} - Black Pixel Count (Threshold < 30)')
    ax1.set_ylabel('Count')
    # Annotate each bar with its exact (comma-grouped) count.
    for bar in bars:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:,}',
                 ha='center', va='bottom')
    ax2.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, colors=colors)
    ax2.set_title(f'{title_prefix} - Distribution')
    plt.tight_layout()
    chart_path = os.path.join(output_dir, f'{title_prefix}_analysis.png')
    plt.savefig(chart_path)
    # Close this specific figure so repeated calls don't leak figures.
    plt.close(fig)
def main():
    """Resolve dataset paths and run the full crop-and-analysis pipeline."""
    # Default: dataset lives under the current working directory.
    TRAIN_ROOT = os.path.join(os.getcwd(), 'Dataset', 'Training')
    if not os.path.exists(TRAIN_ROOT):
        # Try relative to the Pytorch folder if the script is run from there.
        TRAIN_ROOT = os.path.abspath(os.path.join('..', 'Dataset', 'Training'))
    if not os.path.exists(TRAIN_ROOT):
        # Fallback to the hardcoded path seen in a previous directory listing.
        TRAIN_ROOT = r'd:\Programing\stomach_cancer\Dataset\Training'
    # Defect fix: derive the output root from wherever TRAIN_ROOT actually
    # resolved, so the fallback paths don't scatter results under an
    # unrelated working directory. (For the default case this is identical
    # to the old os.getcwd()-based path.)
    OUTPUT_ROOT = os.path.join(os.path.dirname(TRAIN_ROOT), 'Training_Crop_Analysis')
    # All images are processed; pass e.g. limit=20 to process_and_analyze
    # to cap the number of images per class for a quick demonstration run.
    process_and_analyze(TRAIN_ROOT, OUTPUT_ROOT)
    print(f"Input Root: {TRAIN_ROOT}")
    print(f"Output Root: {OUTPUT_ROOT}")
# Script entry point: only run the pipeline when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()