Files
2024-12-07 02:00:39 +08:00

92 lines
3.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from Load_process.file_processing import Process_File
from keras.models import Model
from matplotlib import pyplot as plt
import cv2
import numpy as np
from keras import backend as K
from keras.preprocessing import image
import tensorflow as tf
import datetime
class Grad_CAM:
def __init__(self, Label, One_Hot, Experiment_Name, Layer_Name) -> None:
self.experiment_name = Experiment_Name
self.Layer_Name = Layer_Name
self.Label = Label
self.One_Hot_Label = One_Hot
self.Save_File_Name = self.Convert_One_Hot_To_int()
pass
def process_main(self, model, index, images):
for i in range(len(images)):
array = np.expand_dims(images[i], axis=0) # 替圖片增加一個維度,代表他的數量
heatmap = self.gradcam(array, model)
self.plot_heatmap(heatmap, images[i], self.Save_File_Name[i], index, i)
pass
def Convert_One_Hot_To_int(self):
return [np.argmax(Label)for Label in self.One_Hot_Label]
def gradcam(self, Image, model, pred_index = None):
# 首先,我們創建了一個模型,將輸入圖像映射為最後一個卷積層的激活值以及輸出的預測結果。
grad_model = Model(
[model.inputs], [model.get_layer(self.Layer_Name).output, model.output]
)
# 然後,我們計算了對於輸入圖像而言預測類別的最高梯度,並相對於最後一個卷積層的激活值進行了計算。
with tf.GradientTape() as tape: # 創建一個梯度紀錄器,並做前向傳播
last_conv_layer_output, preds = grad_model(Image)
if pred_index is None:
pred_index = tf.argmax(preds[0])
class_channel = preds[:, pred_index]
# 這是輸出神經元(預測的頂部或所選的)對於最後一個卷積層的輸出特徵圖的梯度。
grads = tape.gradient(class_channel, last_conv_layer_output)
# 這是一個向量,其中每個項目是在特定特徵圖通道上梯度的平均強度。
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
# 我們將特徵圖陣列中的每個通道乘以「這個通道相對於頂部預測類別的重要性」然後將所有通道加總起來以獲得熱圖類別激活。n
last_conv_layer_output = last_conv_layer_output[0]
heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
heatmap = tf.squeeze(heatmap)
# For visualization purpose, we will also normalize the heatmap between 0 & 1
heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
return heatmap.numpy()
def plot_heatmap(self, heatmap, img, Label, index, Title):
File = Process_File()
# ReLU
heatmap = np.maximum(heatmap, 0)
# 正規化
heatmap /= np.max(heatmap)
# 讀取影像
# img = cv2.imread(img)
fig, ax = plt.subplots()
# im = cv2.resize(cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB), (img.shape[1], img.shape[0]))
# 拉伸 heatmap
img_path = cv2.resize(img, (512, 512))
heatmap = cv2.resize(heatmap, (512, 512))
heatmap = np.uint8(255 * heatmap)
img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
# 以 0.6 透明度繪製原始影像
ax.imshow(img_path, alpha=1)
# 以 0.4 透明度繪製熱力圖
ax.imshow(heatmap, cmap='jet', alpha=0.3)
save_root = '../Result/CNN_result_of_reading('+ str(datetime.date.today()) + " )/" + str(Label)
File.JudgeRoot_MakeDir(save_root)
save_root = File.Make_Save_Root(self.experiment_name + "-" + str(index) + "-" + str(Title) + ".png", save_root)
# 存plt檔
plt.savefig(save_root)
plt.close("all") # 關閉圖表
pass