Gaze in self-supervised learning
In our previous work on supervised learning for medical images, we used gaze to assist computer-aided diagnosis. However, CAM-style methods only apply to supervised learning, where annotations are available. In many cases we have gaze recordings for the images but cannot obtain annotations. We therefore proposed FocusContrast: the gaze-derived saliency map guides the augmentations of contrastive learning, so that crops are kept only when they retain enough of the salient region and cutouts are discarded when they remove too much of it. The implementation below is written for openselfsup, which feeds each sample as a special three-channel PIL image `[grayscale_img, saliency, zeros]`:
import torch
from PIL import Image
import numpy as np
import cv2 as cv
from torchvision import transforms
from typing import Callable, List, Tuple
CONFIG_ZOOMIN = 2   # upper bound of the random zoom-in ratio used by the crop operators
CONFIG_DEGREE = 45  # maximum rotation angle in degrees
CONFIG_CUTOUT = 48  # minimal side length (pixels) of a cutout rectangle
def decouple_to_numpy(pil_img: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
    """ Split the special three-channel openselfsup input into the image and
    the saliency map, both SINGLE-channel grayscale.
    Args:
        pil_img (Image.Image): three-channel input [grayscale_img, saliency, zeros]
    Returns:
        np.ndarray: img, 0-1 np.float32; np.ndarray: saliency, 0-1 np.float32
    """
    np_input = np.array(pil_img, dtype=np.uint8)
    np_img = np_input[:, :, 0] / 255
    np_saliency = np_input[:, :, 1] / 255
    return np_img.astype(np.float32), np_saliency.astype(np.float32)
def prob_wrapper(transform_function: Callable, p: float = 1.0) -> Callable:
    """ Wrap a transform so it is applied with probability p and skipped otherwise. """
    def prob_transform_func(img: np.ndarray, saliency: np.ndarray, *args, **kwargs):
        if np.random.random() < p:
            img, saliency = transform_function(img, saliency, *args, **kwargs)
        return img, saliency
    return prob_transform_func
def compose(list_of_transforms: list) -> Callable:
    """ Returns a composed transform that applies each function in
    list_of_transforms, in order, to the (img, saliency) pair.
    Args:
        list_of_transforms (list): transforms from Operator and FocusContrastOperator
    Returns:
        Callable: composed transform on (img, saliency)
    """
    def composed_transforms(img, saliency):
        for transform_func in list_of_transforms:
            img, saliency = transform_func(img, saliency)
        return img, saliency
    return composed_transforms
def pipeline(transforms_and_probs: List[Tuple[Callable, float]]) -> Callable:
    """ Returns the entire processing pipeline: a function that takes the
    special three-channel PIL image and returns the augmented image tensor.
    Args:
        transforms_and_probs (List[Tuple[Callable, float]]): (transform, probability) pairs
    Returns:
        Callable: PIL image -> torch.Tensor
    """
    prob_transform_funcs = []
    for transform_func, p in transforms_and_probs:
        prob_transform_funcs.append(prob_wrapper(transform_func, p))
    composed_func = compose(prob_transform_funcs)

    def inner(pil_img: Image.Image) -> torch.Tensor:
        img, saliency = decouple_to_numpy(pil_img)
        img, saliency = composed_func(img, saliency)
        # only the augmented image is returned; the saliency has served its purpose
        return torch.from_numpy(img.astype(np.float32))
    return inner
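# A minimal usage sketch of `pipeline` (assumption: these identity stand-ins
# replace the real operators defined below; `pil_img` is the three-channel input):
#     identity = lambda img, saliency: (img, saliency)
#     demo_pipeline = pipeline([(identity, 1.0), (identity, 0.5)])
#     tensor_out = demo_pipeline(pil_img)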
def torch_function_for_numpy(torch_func: Callable) -> Callable:
    """ Turns a torch.Tensor -> torch.Tensor function into an
    np.ndarray -> np.ndarray function.
    Args:
        torch_func (Callable): torch.Tensor -> torch.Tensor
    Raises:
        AssertionError: only 0.0-1.0 np.float32 images are supported
    Returns:
        Callable: np.ndarray -> np.ndarray
    """
    def numpy_func(np_img: np.ndarray):
        assert (np_img.dtype == np.float32 and np.max(np_img) <= 1
                and np.min(np_img) >= 0), "image has to be 0-1 and np.float32"
        torch_img = torch.from_numpy(np_img)
        torch_result = torch_func(torch_img)
        np_result = torch_result.numpy()
        return np_result
    return numpy_func
def same_random_transform_on_both(img: np.ndarray, saliency: np.ndarray,
                                  np_random_transform: Callable) -> Tuple[np.ndarray, np.ndarray]:
    """ Apply one random transform (np.ndarray -> np.ndarray) identically to
    image and saliency by stacking them into a single 3-channel input, so the
    transform's randomness hits both channels the same way.
    Args:
        img (np.ndarray): 0-1 np.float32 image
        saliency (np.ndarray): 0-1 np.float32 saliency
        np_random_transform (Callable): np.ndarray -> np.ndarray
    Returns:
        np.ndarray, np.ndarray: transformed img and saliency
    """
    h, w = img.shape
    coupled_input = np.zeros(shape=(3, h, w), dtype=np.float32)
    coupled_input[0, :, :] = img
    coupled_input[1, :, :] = saliency
    result = np_random_transform(coupled_input)
    return result[0], result[1]
class Operator:
    """ Augmentation operators that do not use the saliency; the saliency is
    only passed through (and transformed consistently where needed).
    """
    @staticmethod
    def color_distort(img: np.ndarray, saliency: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        torch_color_distort = transforms.ColorJitter(
            brightness=0.2, contrast=0.8)
        numpy_color_distort = torch_function_for_numpy(torch_color_distort)
        # ColorJitter expects a channel-first, three-channel tensor
        img = np.array([img, img, img])
        result = numpy_color_distort(img)
        # note: the jittered result can leave the 0-1 range
        return result[0], saliency
    @staticmethod
    def random_flip(img, saliency) -> Tuple[np.ndarray, np.ndarray]:
        torch_random_flip = transforms.RandomHorizontalFlip()
        numpy_random_flip = torch_function_for_numpy(torch_random_flip)
        return same_random_transform_on_both(img, saliency, numpy_random_flip)

    @staticmethod
    def random_rotate(img, saliency) -> Tuple[np.ndarray, np.ndarray]:
        torch_random_rotate = transforms.RandomRotation(degrees=CONFIG_DEGREE)
        numpy_random_rotate = torch_function_for_numpy(torch_random_rotate)
        return same_random_transform_on_both(img, saliency, numpy_random_rotate)
    @staticmethod
    def random_crop(img, saliency) -> Tuple[np.ndarray, np.ndarray]:
        # zoom in by a random ratio in [1.2, CONFIG_ZOOMIN), then take a random crop
        zoom_in_ratio = (CONFIG_ZOOMIN - 1.2) * np.random.rand() + 1.2
        h, w = img.shape  # numpy shape is (rows, cols)
        new_h, new_w = int(h * zoom_in_ratio), int(w * zoom_in_ratio)
        new_img = cv.resize(img, dsize=(new_w, new_h))  # cv.resize takes (width, height)
        new_saliency = cv.resize(saliency, dsize=(new_w, new_h))
        p = (np.random.randint(low=0, high=max(1, new_h - h)),
             np.random.randint(low=0, high=max(1, new_w - w)))
        cropped_img = new_img[p[0]:p[0] + h, p[1]:p[1] + w]
        cropped_saliency = new_saliency[p[0]:p[0] + h, p[1]:p[1] + w]
        return cropped_img, cropped_saliency
    @staticmethod
    def random_cutout(img, saliency, minimal_size=CONFIG_CUTOUT) -> Tuple[np.ndarray, np.ndarray]:
        h, w = img.shape
        canvas = np.zeros(img.shape)
        # cv.rectangle points are (x, y), i.e. (col, row)
        pt1 = (np.random.randint(low=0, high=w - minimal_size),
               np.random.randint(low=0, high=h - minimal_size))
        pt2 = (np.random.randint(low=pt1[0] + minimal_size, high=w),
               np.random.randint(low=pt1[1] + minimal_size, high=h))
        mask = cv.rectangle(canvas, pt1, pt2, color=1.0,
                            thickness=cv.FILLED)
        result = np.multiply(1 - mask, img)  # zero out the rectangle
        return result, saliency
    @staticmethod
    def reshape(img, saliency, target_shape=(224, 224)) -> Tuple[np.ndarray, np.ndarray]:
        new_img = cv.resize(img, dsize=target_shape)
        new_saliency = cv.resize(saliency, dsize=target_shape)
        return new_img, new_saliency

    @staticmethod
    def to_RGB(img, saliency) -> Tuple[np.ndarray, np.ndarray]:
        # replicate single-channel inputs to three channels (channel first)
        if len(img.shape) == 2:
            img = np.stack([img, img, img])
        if len(saliency.shape) == 2:
            saliency = np.stack([saliency, saliency, saliency])
        return img, saliency
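# Sketch of calling a single operator directly (assumption: `img` and
# `saliency` are the 0-1 float arrays produced by decouple_to_numpy):
#     img, saliency = Operator.random_crop(img, saliency)
#     img, saliency = Operator.reshape(img, saliency)  # back to 224x224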
class FocusContrastOperator:
    """ Augmentation operators guided by the gaze saliency map
    """
    @staticmethod
    def focus_drop(img, saliency, drop_to=0.1) -> Tuple[np.ndarray, np.ndarray]:
        # smooth the gaze map and normalize it to 0-1
        saliency = cv.GaussianBlur(saliency, (199, 199), 0)
        saliency = saliency / (0.01 + np.max(saliency))
        # add a floor value so non-salient regions are dimmed, not entirely black
        saliency += drop_to
        saliency = np.clip(saliency, 0, 1)
        result = np.multiply(img, saliency)
        return result, saliency
    @staticmethod
    def focus_crop(img, saliency, threshold=0.9, zoom_in_ratio=CONFIG_ZOOMIN) -> Tuple[np.ndarray, np.ndarray]:
        h, w = img.shape
        zoom_in_ratio = (zoom_in_ratio - 1) * np.random.rand() + 1  # sampled from [1, zoom_in_ratio)
        new_h, new_w = int(h * zoom_in_ratio), int(w * zoom_in_ratio)
        new_img = cv.resize(img, dsize=(new_w, new_h))  # cv.resize takes (width, height)
        new_saliency = cv.resize(saliency, dsize=(
            new_w, new_h), interpolation=cv.INTER_NEAREST)

        def get_new_pt1():
            # max(1, ...) guards the degenerate case where the ratio rounds to 1
            return (np.random.randint(low=0, high=max(1, new_h - h)),
                    np.random.randint(low=0, high=max(1, new_w - w)))

        def get_saliency_ratio(pt1):
            saliency_crop = new_saliency[pt1[0]:pt1[0] + h, pt1[1]:pt1[1] + w]
            # fraction of the (resized) saliency mass kept by this crop
            return np.sum(saliency_crop) / (np.sum(saliency) + 1e-4) / (zoom_in_ratio ** 2)

        # if a candidate crop overlaps too little with the saliency, discard it
        # and sample a new one until it passes (or we give up after 200 tries)
        rand_pt1 = get_new_pt1()
        counter = 0
        while get_saliency_ratio(rand_pt1) < threshold and counter < 200:
            rand_pt1 = get_new_pt1()
            counter += 1
        cropped_img = new_img[rand_pt1[0]:rand_pt1[0] + h,
                              rand_pt1[1]:rand_pt1[1] + w]
        cropped_saliency = new_saliency[rand_pt1[0]:rand_pt1[0] + h,
                                        rand_pt1[1]:rand_pt1[1] + w]
        return cropped_img, cropped_saliency
    @staticmethod
    def focus_cutout(img, saliency, threshold=400, minimal_size=CONFIG_CUTOUT) -> Tuple[np.ndarray, np.ndarray]:
        h, w = img.shape

        def get_new_cutout_mask():
            canvas = np.zeros(img.shape)
            # cv.rectangle points are (x, y), i.e. (col, row)
            pt1 = (np.random.randint(low=0, high=w - minimal_size),
                   np.random.randint(low=0, high=h - minimal_size))
            pt2 = (np.random.randint(low=pt1[0] + minimal_size, high=w),
                   np.random.randint(low=pt1[1] + minimal_size, high=h))
            return cv.rectangle(canvas, pt1, pt2, color=1.0,
                                thickness=cv.FILLED)

        # if a candidate mask overlaps too much with the saliency, discard it
        # and sample a new one until it passes (or we give up after 100 tries)
        rand_mask = get_new_cutout_mask()
        counter = 0
        while np.sum(rand_mask * saliency) > threshold and counter < 100:
            rand_mask = get_new_cutout_mask()
            counter += 1
        result = np.multiply(1 - rand_mask, img)
        return result, saliency
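# End-to-end sketch: in the original setup these transforms are plugged into
# openselfsup; the synthetic image, operator mix, and probabilities below are
# assumptions used only to exercise the pipeline.
if __name__ == "__main__":
    gray = (np.random.rand(224, 224) * 255).astype(np.uint8)
    gaze = np.zeros((224, 224), dtype=np.uint8)
    gaze[80:140, 80:140] = 255  # a square stand-in for the gaze saliency
    coupled = np.stack([gray, gaze, np.zeros_like(gray)], axis=-1)
    pil_img = Image.fromarray(coupled)  # the special three-channel input

    view_pipeline = pipeline([
        (FocusContrastOperator.focus_crop, 0.5),
        (FocusContrastOperator.focus_cutout, 0.5),
        (Operator.random_flip, 0.5),
        (Operator.reshape, 1.0),
    ])
    # two stochastic passes give the two augmented views for contrastive learning
    view_1, view_2 = view_pipeline(pil_img), view_pipeline(pil_img)
    print(view_1.shape, view_2.shape)  # both torch.Size([224, 224])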