"""Crop training images and analyze black pixels in their four corners.

For every class directory under the training root the script runs
``Crop_Result`` on the images, counts near-black pixels inside the four
corner regions of each cropped image, writes annotated ("marked") copies,
and finally renders a per-class comparison chart.
"""

import glob
import os
import sys
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np

# Add project root to path (must happen before the project import below).
sys.path.append(os.getcwd())

from model_data_processing.Crop_And_Fill import Crop_Result  # noqa: E402

# File extensions (compared case-insensitively) treated as images.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')
# Gray values <= this are counted as "black" (THRESH_BINARY_INV semantics).
_BLACK_THRESHOLD = 30
# Each corner ROI spans this fraction of the image height and width.
_CORNER_RATIO = 0.2


def ensure_dir(path):
    """Create directory *path* (including missing parents) if it is absent."""
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)


def _collect_image_paths(class_input_dir: str) -> List[str]:
    """Return the sorted image paths directly inside *class_input_dir*.

    Extensions are matched case-insensitively, so mixed-case names such as
    ``scan.Jpg`` are found on case-sensitive file systems as well.
    """
    # Escape the directory so characters like '[' are not read as glob syntax.
    pattern = os.path.join(glob.escape(class_input_dir), '*')
    return sorted(
        path for path in glob.glob(pattern)
        if os.path.splitext(path)[1].lower() in _IMAGE_EXTENSIONS
    )


def _analyze_image(img: np.ndarray) -> Tuple[Dict[str, int], np.ndarray]:
    """Count near-black corner pixels of *img* and build an annotated copy.

    Returns:
        A ``(counts, marked_img)`` tuple. ``counts`` maps each corner name
        ('Top-Left', 'Top-Right', 'Bottom-Left', 'Bottom-Right') to the
        number of pixels whose gray value is <= ``_BLACK_THRESHOLD`` inside
        that corner. ``marked_img`` is a copy of the image with those pixels
        painted red and the corner regions outlined and labelled in green.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img

    # mask is 255 where gray <= threshold (black-ish) and 0 elsewhere.
    _, mask = cv2.threshold(gray, _BLACK_THRESHOLD, 255, cv2.THRESH_BINARY_INV)

    h, w = mask.shape
    h_roi = int(h * _CORNER_RATIO)
    w_roi = int(w * _CORNER_RATIO)

    top, bottom = slice(0, h_roi), slice(h - h_roi, h)
    left, right = slice(0, w_roi), slice(w - w_roi, w)
    corners = {
        'Top-Left': (top, left),
        'Top-Right': (top, right),
        'Bottom-Left': (bottom, left),
        'Bottom-Right': (bottom, right),
    }

    # final_mask keeps only the black pixels that lie inside a corner ROI.
    final_mask = np.zeros_like(mask)
    counts = {}
    for name, region in corners.items():
        final_mask[region] = mask[region]
        # np.count_nonzero copes with empty slices (very small images).
        counts[name] = int(np.count_nonzero(mask[region]))

    # Painting red needs three channels; promote a single-channel image.
    if img.ndim == 3:
        marked_img = img.copy()
    else:
        marked_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    marked_img[final_mask == 255] = [0, 0, 255]

    # Outline the four ROIs.
    green = (0, 255, 0)
    thickness = 2
    cv2.rectangle(marked_img, (0, 0), (w_roi, h_roi), green, thickness)
    cv2.rectangle(marked_img, (w - w_roi, 0), (w, h_roi), green, thickness)
    cv2.rectangle(marked_img, (0, h - h_roi), (w_roi, h), green, thickness)
    cv2.rectangle(marked_img, (w - w_roi, h - h_roi), (w, h), green, thickness)

    # Label the four ROIs just inside their top-left corners.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    text_thickness = 1
    labels = (
        ('TL', (10, 20)),
        ('TR', (w - w_roi + 10, 20)),
        ('BL', (10, h - h_roi + 20)),
        ('BR', (w - w_roi + 10, h - h_roi + 20)),
    )
    for text, origin in labels:
        cv2.putText(marked_img, text, origin, font, font_scale, green,
                    text_thickness)

    return counts, marked_img


def _analyze_cropped_dir(cropped_dir: str, marked_dir: str) -> Dict[str, int]:
    """Analyze every readable image in *cropped_dir*; save marked copies.

    Returns the per-corner black pixel counts summed over all images.
    """
    class_stats = {
        'Top-Left': 0,
        'Top-Right': 0,
        'Bottom-Left': 0,
        'Bottom-Right': 0,
    }

    # NOTE(review): every file in cropped_dir is analyzed, including any
    # left over from an earlier run -- clear the directory first if that
    # matters for the statistics.
    for filename in sorted(os.listdir(cropped_dir)):
        img = cv2.imread(os.path.join(cropped_dir, filename))
        if img is None:
            # Not a readable image (or a sub-directory): skip it.
            continue

        counts, marked_img = _analyze_image(img)
        for corner, count in counts.items():
            class_stats[corner] += count

        marked_path = os.path.join(marked_dir, filename)
        if not cv2.imwrite(marked_path, marked_img):
            print(f" Warning: could not write {marked_path}")

    return class_stats


def process_and_analyze(input_root, output_root, limit: Optional[int] = None):
    """Crop the training images of every class and analyze corner pixels.

    Steps performed for each class directory found in *input_root*:
    1. Crop the images using ``Crop_Result`` into ``<class>/Cropped``.
    2. Count near-black pixels in the four corners of each cropped image
       and write annotated copies to ``<class>/Marked``.
    3. After all classes, plot the per-class totals via ``plot_charts``.

    Args:
        input_root: Directory containing one sub-directory per class.
        output_root: Directory that receives the per-class output folders
            and the comparison chart.
        limit: Maximum number of images per class; a falsy value
            (``None`` or ``0``) processes all images.

    Returns:
        ``{class_name: {corner_name: count}}`` for every class that was
        cropped successfully, or ``None`` when *input_root* is missing or
        has no class directories.
    """
    if not os.path.isdir(input_root):
        print(f"Input root not found: {input_root}")
        return None

    # Find all class directories (sorted for reproducible output order).
    class_dirs = sorted(
        d for d in os.listdir(input_root)
        if os.path.isdir(os.path.join(input_root, d))
    )
    if not class_dirs:
        print(f"No class directories found in {input_root}")
        return None

    print(f"Found classes: {class_dirs}")

    all_stats = {}  # {class_name: {quadrant: count}}

    for class_name in class_dirs:
        print(f"\nProcessing Class: {class_name}")
        class_input_dir = os.path.join(input_root, class_name)

        # Prepare output directories.
        class_output_dir = os.path.join(output_root, class_name)
        cropped_dir = os.path.join(class_output_dir, "Cropped")
        marked_dir = os.path.join(class_output_dir, "Marked")
        ensure_dir(cropped_dir)
        ensure_dir(marked_dir)

        image_paths = _collect_image_paths(class_input_dir)
        if not image_paths:
            print(f" No images found in {class_name}")
            continue

        if limit:
            print(f" Limiting to {limit} images (found {len(image_paths)})")
            image_paths = image_paths[:limit]
        else:
            print(f" Processing all {len(image_paths)} images")

        # 1. Run Crop_Result (takes a list of paths and an output directory).
        print(" Running Crop_Result...")
        try:
            Crop_Result(image_paths, cropped_dir)
        except Exception as e:
            # Per-class boundary: report the failure and move on.
            print(f" Error cropping {class_name}: {e}")
            continue

        # 2. Analyze cropped images.
        print(" Analyzing cropped images...")
        class_stats = _analyze_cropped_dir(cropped_dir, marked_dir)

        all_stats[class_name] = class_stats
        print(f" Stats for {class_name}: {class_stats}")

    # Aggregate stats by class (sum of all four corners for each class).
    class_totals = {
        class_name: sum(stats.values())
        for class_name, stats in all_stats.items()
    }

    print(f"\nClass Totals: {class_totals}")

    # Generate combined charts.
    plot_charts(class_totals, output_root, "Class_Comparison")

    print("\nProcessing Complete.")
    return all_stats


def plot_charts(counts, output_dir, title_prefix):
    """Save a bar chart and a pie chart of *counts* as one PNG file.

    Args:
        counts: Mapping of label to black pixel count.
        output_dir: Directory receiving ``<title_prefix>_analysis.png``;
            created if missing.
        title_prefix: Prefix for the chart titles and the file name.

    Nothing is written when all counts are zero (or *counts* is empty).
    """
    labels = list(counts.keys())
    values = list(counts.values())

    if sum(values) == 0:
        print(f" No black pixels detected for {title_prefix}, skipping charts.")
        return

    # One figure with two side-by-side subplots.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Bar chart, each bar annotated with its value.
    colors = ['#FF9999', '#66B2FF', '#99FF99', '#FFCC99']
    bars = ax1.bar(labels, values, color=colors)
    ax1.set_title(f'{title_prefix} - Black Pixel Count (Threshold < 30)')
    ax1.set_ylabel('Count')
    for bar in bars:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2., height, f'{height:,}',
                 ha='center', va='bottom')

    # Pie chart of the relative distribution.
    ax2.pie(values, labels=labels, autopct='%1.1f%%', startangle=90,
            colors=colors)
    ax2.set_title(f'{title_prefix} - Distribution')

    fig.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    chart_path = os.path.join(output_dir, f'{title_prefix}_analysis.png')
    fig.savefig(chart_path)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)


def main():
    """Locate the training data and run the crop analysis on all images."""
    # Define paths -- adjust these if necessary.
    TRAIN_ROOT = os.path.join(os.getcwd(), 'Dataset', 'Training')
    OUTPUT_ROOT = os.path.join(os.getcwd(), 'Dataset', 'Training_Crop_Analysis')

    if not os.path.exists(TRAIN_ROOT):
        # Try the parent directory, for runs started from a sub-folder.
        TRAIN_ROOT = os.path.abspath(os.path.join('..', 'Dataset', 'Training'))
        if not os.path.exists(TRAIN_ROOT):
            # Last resort: machine-specific hard-coded location.
            TRAIN_ROOT = r'd:\Programing\stomach_cancer\Dataset\Training'

    # Report the resolved paths up front so a wrong root is visible at once.
    print(f"Input Root: {TRAIN_ROOT}")
    print(f"Output Root: {OUTPUT_ROOT}")

    # No limit is passed, so all images are processed; pass limit=N to
    # process_and_analyze for a quicker partial run.
    process_and_analyze(TRAIN_ROOT, OUTPUT_ROOT)


if __name__ == "__main__":
    main()