def shiftMask(src, target, mask, offset=(0, 0)):
    """Place ``src`` and ``mask`` on a canvas shaped like ``target``, shifted by ``offset``.

    Pixel ``(h, w)`` of ``src``/``mask`` lands at ``(h + off_h, w + off_w)``
    on the canvas. Pixels whose shifted position falls outside the canvas
    are dropped. Canvas pixels not covered by the shifted source stay 0.

    Args:
        src: source image, indexed as ``src[h, w, ...]``.
        target: image whose shape defines the canvas; only ``target.shape``
            is read.
        mask: per-pixel mask for ``src``, indexed as ``mask[h, w]``.
        offset: integer ``(row, column)`` shift. Defaults to ``(0, 0)``.

    Returns:
        ``(new_src, new_mask)``: a float array of shape ``target.shape`` and a
        float array of shape ``target.shape[:2]``.
    """
    off_h, off_w = offset
    tgt_h, tgt_w = target.shape[0], target.shape[1]
    new_src = np.zeros(target.shape)
    new_mask = np.zeros(target.shape[:2])

    # Range of src rows/cols whose shifted position lands inside the canvas.
    h_lo, h_hi = max(0, -off_h), min(src.shape[0], tgt_h - off_h)
    w_lo, w_hi = max(0, -off_w), min(src.shape[1], tgt_w - off_w)

    # Guard explicitly: with a large offset h_hi/w_hi can go negative, and a
    # negative slice end would wrap around instead of selecting nothing.
    if h_lo < h_hi and w_lo < w_hi:
        dst_rows = slice(h_lo + off_h, h_hi + off_h)
        dst_cols = slice(w_lo + off_w, w_hi + off_w)
        new_src[dst_rows, dst_cols] = src[h_lo:h_hi, w_lo:w_hi]
        new_mask[dst_rows, dst_cols] = mask[h_lo:h_hi, w_lo:w_hi]
    return new_src, new_mask
def in_omega(x, y, mask):
    """Report whether pixel ``(x, y)`` belongs to the region Ω.

    A pixel is in Ω exactly when ``mask[x, y]`` is nonzero.

    Returns:
        A plain Python ``bool``.
    """
    return bool(mask[x, y] != 0)
def on_border(x, y, mask):
    """Report whether pixel ``(x, y)`` lies on the inner border of Ω.

    A border pixel is inside Ω (see ``in_omega``) and has at least one
    neighbour, as yielded by ``get_neighbor``, that is outside Ω. Whether
    neighbours beyond the image edge are filtered depends on
    ``get_neighbor``, which is defined elsewhere.

    Returns:
        ``True`` for border pixels, ``False`` otherwise (including every
        pixel outside Ω).
    """
    if not in_omega(x, y, mask):
        return False
    # any() stops at the first outside neighbour, like the early return did.
    return any(not in_omega(i, j, mask) for i, j in get_neighbor(x, y))
def pixel_belong(x, y, mask):
    """Classify pixel ``(x, y)`` relative to the masked region Ω.

    Returns:
        ``OUTSIDE`` when the pixel is not in Ω, ``BORDER`` when it is in Ω
        and ``on_border`` reports an outside neighbour, and ``OMEGA``
        otherwise.
    """
    if not in_omega(x, y, mask):
        return OUTSIDE
    return BORDER if on_border(x, y, mask) else OMEGA
def get_fpoints(mask):
    """List the coordinates of every pixel in Ω.

    The scan runs over the first two axes of ``mask`` in row-major order,
    keeping each ``(row, col)`` for which ``in_omega`` is true.

    Returns:
        A list of ``(row, col)`` tuples.
    """
    n_rows, n_cols = mask.shape[0], mask.shape[1]
    return [
        (r, c)
        for r in range(n_rows)
        for c in range(n_cols)
        if in_omega(r, c, mask)
    ]
def poisson_matrix(f_points):
    """Construct the sparse Poisson matrix for the points in ``f_points``.

    Row ``k`` corresponds to ``f_points[k]``. Its diagonal entry is 4, and
    the column of every neighbour (from ``get_neighbor``) that is itself in
    ``f_points`` is set to -1. Neighbours that are not in ``f_points`` are
    skipped; the caller accounts for them on the right-hand side.

    The matrix is very sparse. With 4-connected neighbours there are at most
    5 nonzeros per row, though this depends on ``get_neighbor``. So a
    ``scipy.sparse.lil_matrix`` is used rather than a dense ``np.array``.

    Returns:
        An ``N x N`` ``lil_matrix`` with ``N = len(f_points)``.
    """
    N = len(f_points)

    # Map each point to its row index once. The previous `pt in f_points` /
    # `f_points.index(pt)` did a linear scan per neighbour, i.e. O(N^2)
    # overall. setdefault keeps the first occurrence, matching list.index.
    index_of = {}
    for idx, pt in enumerate(f_points):
        index_of.setdefault(tuple(pt), idx)

    A = lil_matrix((N, N))
    for row, (x, y) in enumerate(f_points):
        A[row, row] = 4
        for i, j in get_neighbor(x, y):
            col = index_of.get((i, j))
            if col is not None:
                A[row, col] = -1
    return A
def solve_equation(src, target, mask):
    """Solve the sparse linear system ``A f = b`` over the masked region.

    ``A`` is built by ``poisson_matrix`` from the nonzero pixels of ``mask``.
    Each entry of ``b`` starts as ``laplace_pixel(src, x, y)``. The target
    values of neighbours outside the mask are then added to it, which
    enforces the boundary condition at the border of Ω. The system is
    solved with ``scipy.sparse.linalg.spsolve``; no explicit inverse is
    formed.

    NOTE(review): ``A`` has +4 on the diagonal and -1 off it, so this
    presumably expects ``laplace_pixel`` to use the same sign convention
    (4 * centre minus neighbours). Confirm against its definition.

    Args:
        src: source image passed to ``laplace_pixel``.
        target: destination image. ``target[i, j]`` is added to a scalar, so
            it is presumably a single 2-D channel (TODO confirm).
        mask: array whose nonzero entries mark the pixels to solve for.

    Returns:
        An ``int`` copy of ``target`` in which the masked pixels are replaced
        by the solved values, rounded to the nearest integer.
    """
    f_points = get_fpoints(mask)
    result = np.copy(target).astype(int)
    N = len(f_points)
    if N == 0:
        # Nothing to solve, and spsolve cannot handle an empty 0x0 system.
        return result

    A = poisson_matrix(f_points).asformat("csr")
    b = np.zeros(N)
    for index, (x, y) in enumerate(f_points):
        b[index] = laplace_pixel(src, x, y)
        # Boundary condition: neighbours outside the mask have known values
        # (taken from target), so they move to the right-hand side. Interior
        # pixels have no outside neighbours, so nothing is added for them.
        for i, j in get_neighbor(x, y):
            if not in_omega(i, j, mask):
                b[index] += target[i, j]

    x_ans = linalg.spsolve(A, b)
    # Round explicitly; assigning floats into the int array would otherwise
    # truncate toward zero and bias every solved pixel downward.
    for (x, y), value in zip(f_points, np.rint(x_ans)):
        result[x, y] = value
    return result