# pytorch中的grid_sample
# grid_sample
直译为网格采样,给定一个mask patch
,根据在目标图像上的坐标网格,将mask
变换到目标图像上。
如上图,是将一个2x2
的mask
根据坐标网格grid
变换到6x6
目标图像x0 y0 x1 y1 = 1,1,3,3
的位置上,值得注意的是grid
是经过运算得到的坐标网格,mask
在target image
对应位置的左上角处坐标应该为-1,-1
,右下角处坐标应该为1,1
,目标图像对应位置的像素值由mask
通过插值得到。
知道了grid_sample
的原理,再来看下torch
中的函数。
# grid_sample
函数原型
torch.nn.functional.grid_sample(input,
grid,
mode='bilinear',
padding_mode='zeros', align_corners=None)
input
输入image patch
,支持4d
或5d
输入。为4d
时shape
为grid
坐标网格,当input
为4d
时其shape
为 ,输出的shape
为N,C,H_{out},W_{out}
,对于输出的位置output[n, :, h, w]
,‵grid[n, h, w]是二维向量,指定了其对应的
input上的位置。
output[n, :, h, w]根据‵grid[n, h, w]
指定的对应input
位置上的像素插值得到。grid
指定了在input
输入维度上标准化后的坐标大小,input
左上角对应的应该是-1,-1
,右下角对应的是1,1
mode
插值方式,'bilinear' | 'nearest' | 'bicubic'
padding_mode
,在(-1,1)外的输出图像上的像素值处理方式'zeros' | 'border' | 'reflection'
align_corners
:是否对齐角
# 实例
以将一个100x100
的mask
,网格采样到500x300
的图像上(x,y,w,h)=(100, 100, 100, 200)
为例,看一下grid_sample
是如何使用的。
先计算grid
,
import torch
import numpy as np
import cv2
import torch.nn.functional as F
import matplotlib.pyplot as plt
h, w = 300, 500
x0, y0, x1, y1 = torch.tensor([[100]]), torch.tensor([[100]]), torch.tensor([[200]]), torch.tensor([[300]])
N = 1
x0_int, y0_int = 0, 0
x1_int, y1_int = 500, 300
img_y = torch.arange(y0_int, y1_int, dtype=torch.float32) + 0.5
img_x = torch.arange(x0_int, x1_int, dtype=torch.float32) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)
这里使用的是mask
在目标图像上的大小来对grid
归一化的。
mask = np.zeros((100, 100), dtype=np.uint8)
ct = np.array([[50, 0],[99, 50], [50, 99], [0, 50]], dtype=np.int32)
mask = cv2.drawContours(mask, [ct], -1, 255, cv2.FILLED)
plt.figure(1)
plt.imshow(mask)
mask = torch.from_numpy(mask)
masks = mask[None, None, :]
if not torch.jit.is_scripting():
if not masks.dtype.is_floating_point:
masks = masks.float()
img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)
plt.figure(2)
plt.imshow(img_masks.squeeze().numpy().astype(np.uint8))
根据grid
将mask
映射到目标图像上的指定区域指定大小。
grid_sample
的使用,如Mask RCNN
将对象实例分割的mask
映射到原图像尺寸上。
1.https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html (opens new window)