使用 Pillow and/or NumPy 进行快速像素操作

Quick pixel manipulation with Pillow and/or NumPy

我正在尝试提高我的图像处理速度,因为它对于实际使用来说太慢了。

我需要做的是对图像上每个像素的颜色应用复杂的变换。操作基本上是应用矢量变换,如 T(r, g, b, a) => (r * x, g * x, b * y, a) 或通俗地说,它是红色和绿色值乘以一个常数,蓝色和保持 Alpha 的不同乘法。但如果 R​​GB 颜色属于某些特定颜色,我还需要以不同方式操作它,在这些情况下,它们必须遵循 dictionary/transformation table,其中 RGB => newRGB 再次保持 alpha。

算法为:

for each pixel in image:
  if pixel[r, g, b] in special:
    return special[pixel[r, g, b]] + pixel[a]
  else:
    return T(pixel)

这很简单,但速度不是最理想的。我相信有一些方法可以使用 numpy 向量,但我找不到方法。

关于实施的重要细节:

buffer是从一个wxPython Bitmap中得到的,special(RG|B)_pal是变换table,最终结果也是一个wxPython Bitmap。它们是这样获得的:

# buffer
bitmap = wx.Bitmap # it's valid wxBitmap here, this is just to let you know it exists
buff = bytearray(bitmap.GetWidth() * bitmap.GetHeight() * 4)
bitmap.CopyToBuffer(buff, wx.BitmapBufferFormat_RGBA)

self.RG_mult= 0.75
self.B_mult = 0.83

self.RG_pal = []
self.B_pal = []

for i in range(0, 256):
    self.RG_pal.append(int(i * self.RG_mult))
    self.B_pal.append(int(i * self.B_mult))

self.special = {
    # RGB: new_RGB
    # Implementation specific for the fastest access
    # with buffer keys are 24bit numbers, with PIL keys are tuples
}

我尝试的实现包括直接缓冲区操作:

for x in range(0, bitmap.GetWidth() * bitmap.GetHeight()):
    index = x * 4
    r = buf[index]
    g = buf[index + 1]
    b = buf[index + 2]
    rgb = buf[index:index + 3]
    if rgb in self.special:
        special = self.special[rgb]
        buf[index] = special[0]
        buf[index + 1] = special[1]
        buf[index + 2] = special[2]
    else:
        buf[index] = self.RG_pal[r]
        buf[index + 1] = self.RG_pal[g]
        buf[index + 2] = self.B_pal[b]

将 Pillow 与 getdata() 一起使用:

pil = Image.frombuffer("RGBA", (bitmap.GetWidth(), bitmap.GetHeight()), buf)
pil_buf = []

for colour in pil.getdata():
    colour_idx = colour[0:3]

    if (colour_idx in self.special):
        special = self.special[colour_idx]
        pil_buf.append((
            special[0],
            special[1],
            special[2],
            colour[3],
        ))
    else:
        pil_buf.append((
            self.RG_pal[colour[0]],
            self.RG_pal[colour[1]],
            self.B_pal[colour[2]],
            colour[3],
        ))

pil.putdata(pil_buf)
buf = pil.tobytes()

point()getdata() 的枕头(我达到的最快,比其他人快两倍多)

pil = Image.frombuffer("RGBA", (bitmap.GetWidth(), bitmap.GetHeight()), buf)

r, g, b, a = pil.split()
r = r.point(lambda r: r * self.RG_mult)
g = g.point(lambda g: g * self.RG_mult)
b = b.point(lambda b: b * self.B_mult)
pil = Image.merge("RGBA", (r, g, b, a))

i = 0
for colour in pil.getdata():
    colour_idx = colour[0:3]

    if (colour_idx in self.special):
        special = self.special[colour_idx]
        pil.putpixel(
            (i % bitmap.GetWidth(), i // bitmap.GetWidth()),
            (
                special[0],
                special[1],
                special[2],
                colour[3],
            )
        )
    i += 1

buf = pil.tobytes()

我也尝试过使用 numpy.where 但后来我无法让它工作。使用 numpy.apply_along_axis 它可以工作,但性能很糟糕。其他使用 numpy 的尝试我无法一起访问 RGB,只能作为分离的波段。

纯 Numpy 版本

第一个优化依赖于这样一个事实,即像素的特殊颜色可能少得多。我使用 numpy 来完成所有的内部循环。这适用于最多 1MP 个图像。如果您有多个图像,我建议您使用并行方法。

让我们定义一个测试用例:

import requests
from io import BytesIO
from PIL import Image
import numpy as np

# Load some image, so we have the same
response = requests.get("https://upload.wikimedia.org/wikipedia/commons/4/41/Rick_Astley_Dallas.jpg")
# Make areas of known color
img = Image.open(BytesIO(response.content)).rotate(10, expand=True).rotate(-10,expand=True, fillcolor=(255,255,255)).convert('RGBA')

print("height: %d, width: %d (%.2f MP)"%(img.height, img.width, img.width*img.height/10e6))

height: 5034, width: 5792 (2.92 MP)

定义我们的特殊颜色

specials = {
    (4,1,6):(255,255,255), 
    (0, 0, 0):(255, 0, 255), 
    (255, 255, 255):(0, 255, 0)
}

算法

def transform_map(img, specials, R_factor, G_factor, B_factor):
    # Your transform
    def transform(x, a):
        a *= x
        return a.clip(0, 255).astype(np.uint8)

    # Convert to array
    img_array = np.asarray(img)
    # Extract channels
    R = img_array.T[0]
    G = img_array.T[1]
    B = img_array.T[2]
    A = img_array.T[3]

    # Find Special colors
    # First, calculate a uniqe hash
    color_hashes = (R + 2**8 * G + 2**16 * B)


    # Find inidices of special colors
    special_idxs = []
    for k, v in specials.items():
        key_arr = np.array(list(k))
        val_arr = np.array(list(v))

        spec_hash = key_arr[0] + 2**8 * key_arr[1] + 2**16 * key_arr[2]
        special_idxs.append(
            {
                'mask': np.where(np.isin(color_hashes, spec_hash)),
                'value': val_arr
            }
        )

    # Apply transform to whole image
    R = transform(R, R_factor)
    G = transform(G, G_factor)
    B = transform(B, B_factor)


    # Replace values where special colors were found
    for idx in special_idxs:
        R[idx['mask']] = idx['value'][0]
        G[idx['mask']] = idx['value'][1]
        B[idx['mask']] = idx['value'][2]

    return Image.fromarray(np.array([R,G,B,A]).T, mode='RGBA')

最后 Intel Core i5-6300U @ 2.40GHz

上的一些基准测试
import time
times = []
for i in range(10):
    t0 = time.time()
    # Test
    transform_map(img, specials, 1.2, .9, 1.2)
    #
    t1 = time.time()
    times.append(t1-t0)
np.round(times, 2)

print('average run time: %.2f +/-%.2f'%(np.mean(times), np.std(times)))

average run time: 9.72 +/-0.91

编辑并行化

使用与上述相同的设置,我们可以将处理大图像的速度提高 2 倍。 (没有numba的小的更快)

from numba import njit, prange
from numba.core import types
from numba.typed import Dict

# Map dict of special colors or transform over array of pixel values
@njit(parallel=True, locals={'px_hash': types.uint32})
def check_and_transform(img_array, d, T):
    #Save Shape for later
    shape = img_array.shape
    # Flatten image for 1-d iteration
    img_array_flat = img_array.reshape(-1,3).copy()
    N = img_array_flat.shape[0]
    # Replace or map
    for i in prange(N):
        px_hash = np.uint32(0)
        px_hash += img_array_flat[i,0]
        px_hash += types.uint32(2**8) * img_array_flat[i,1] 
        px_hash += types.uint32(2**16) * img_array_flat[i,2]
        
        try:
            img_array_flat[i] = d[px_hash]
        except Exception:
            img_array_flat[i] =  (img_array_flat[i] * T).astype(np.uint8)
    # return image
    return img_array_flat.reshape(shape) 

# Wrapper for function above
def map_or_transform_jit(image: Image, specials: dict, T: np.ndarray):
    # assemble numba typed dict
    d = Dict.empty(
        key_type=types.uint32,
        value_type=types.uint8[:],
    )
    for k, v in specials.items():
        k = types.uint32(k[0] + 2**8 * k[1] + 2**16 * k[2])
        v = np.array(v, dtype=np.uint8)
        d[k] = v
        
    # get rgb channels
    img_arr = np.array(img)
    rgb = img_arr[:,:,:3].copy()
    img_shape = img_arr.shape
    # apply map
    rgb = check_and_transform(rgb, d, T)
    # set color channels
    img_arr[:,:,:3] = rgb
    
    return Image.fromarray(img_arr, mode='RGBA')

# Benchmark
import time
times = []
for i in range(10):
    t0 = time.time()
    # Test
    test_img = map_or_transform_jit(img, specials, np.array([1, .5, .5]))
    #
    t1 = time.time()
    times.append(t1-t0)
np.round(times, 2)

print('average run time: %.2f +/- %.2f'%(np.mean(times), np.std(times)))
test_img

average run time: 3.76 +/- 0.08