<GLIDE 코드 분석>
Tokenizer : 자연어 처리에서 입력 문장을 일정한 단위로 분할
Vocabulary : 분할된 단위에 고유한 일련번호 구현
OOV(Out OF Vocabulary) : 기계가 모르는 단어로 인해 문제를 푸는 것이 까다로워짐
└ = UNK(Unknown Token)
이를 해결하기 위해 GLIDE 코드에서는 BPE를 사용해 tokenizer 구현.
GLIDE 에서는 tokenizer cache를 lru cache로 진행.
from PIL import Image
from IPython.display import display
import torch as th
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
create_model_and_diffusion,
model_and_diffusion_defaults,
model_and_diffusion_defaults_upsampler
)
PIL : Python Imaging Library
이미지 분석 및 처리 라이브러리
다양한 이미지 파일 형식을 지원하고, 이미지 프로세싱 기능을 제공.
import display는 이미지를 HTML로 출력하기 위해 Print() 대신 사용.
# Sampling parameters
prompt = "an oil painting of a corgi"
batch_size = 1
guidance_scale = 3.0
# Tune this parameter to control the sharpness of 256x256 images.
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997
upsample_temp
sampling을 통해 데이터를 키운다. 아티팩트가 적은 흐린 이미지의 경우 ~0.997로 낮추는 것이 좋음
##############################
# Sample from the base model #
##############################
# Create the text tokens to feed to the model.
tokens = model.tokenizer.encode(prompt)
tokens, mask = model.tokenizer.padded_tokens_and_mask(
tokens, options['text_ctx']
)
# Create the classifier-free guidance tokens (empty)
full_batch_size = batch_size * 2
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
[], options['text_ctx']
)
# Pack the tokens together into model kwargs.
model_kwargs = dict(
tokens=th.tensor(
[tokens] * batch_size + [uncond_tokens] * batch_size, device=device
),
mask=th.tensor(
[mask] * batch_size + [uncond_mask] * batch_size,
dtype=th.bool,
device=device,
),
)
# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // 2]
combined = th.cat([half, half], dim=0)
model_out = model(combined, ts, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return th.cat([eps, rest], dim=1)
# Sample from the base model.
model.del_cache()
samples = diffusion.p_sample_loop(
model_fn,
(full_batch_size, 3, options["image_size"], options["image_size"]),
device=device,
clip_denoised=True,
progress=True,
model_kwargs=model_kwargs,
cond_fn=None,
)[:batch_size]
model.del_cache()
# Show the output
show_images(samples)
이 코드를 분석하려면
glide-text2im/glide_text2im/tokenizer/
먼저 들어가준다.. 벌써 지친다 어카냐
Tokenizer : 자연어 처리에서 입력 문장을 일정한 단위로 분할
Vocabulary : 분할된 단위에 고유한 일련번호 구현
OOV(Out OF Vocabulary) : 기계가 모르는 단어로 인해 문제를 푸는 것이 까다로워짐
└ = UNK(Unknown Token)
이를 해결하기 위해 GLIDE 코드에서는 BPE를 사용해 tokenizer 구현.
BPE(Byte Pair Encoding) : 글자 단위에서 점차적으로 단어 집합을 만든다
GLIDE 에서는 tokenizer cache를 lru cache로 진행.
glide-text2im/glide_text2im/tokenizer/simple_tokenizer.py/
"""
Copied from: https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/simple_tokenizer.py
"""
import gzip
import html
import os
from functools import lru_cache
from typing import List, Tuple
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)
@property
def start_token(self):
return self.encoder["<|startoftext|>"]
@property
def end_token(self):
return self.encoder["<|endoftext|>"]
def padded_tokens_and_len(self, tokens: List[int], text_ctx: int) -> Tuple[List[int], int]:
tokens = [self.start_token] + tokens[: text_ctx - 2] + [self.end_token]
text_len = len(tokens)
padding = text_ctx - len(tokens)
padded_tokens = tokens + [0] * padding
return padded_tokens, text_len
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except: # pylint: disable=bare-except
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = (
bytearray([self.byte_decoder[c] for c in text])
.decode("utf-8", errors="replace")
.replace("</w>", " ")
)
return text
하이고
import ftfy
ftfy : python에서 결함이 존재한는 문자열을 유니코드 텍스트로 자동변환하는 라이브러리
import regex as re
정규 표현식 모듈.
복잡한 문자열의 검색과 치환을 위해 사용.
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
문자쌍이 bpe_ranks에서 발견되지 않으면 무한대를 반환 (= 무한대는 최소값이 될 수 없기 때문에 버려짐)
더이상 bpe_ranks에 유효한 바이트 쌍이 없으면
루프 종료
##############################
# Upsample the 64x64 samples #
##############################
tokens = model_up.tokenizer.encode(prompt)
tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
tokens, options_up['text_ctx']
)
# Create the model conditioning dict.
model_kwargs = dict(
# Low-res image to upsample.
low_res=((samples+1)*127.5).round()/127.5 - 1,
# Text tokens
tokens=th.tensor(
[tokens] * batch_size, device=device
),
mask=th.tensor(
[mask] * batch_size,
dtype=th.bool,
device=device,
),
)
# Sample from the base model.
model_up.del_cache()
up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
up_samples = diffusion_up.ddim_sample_loop(
model_up,
up_shape,
noise=th.randn(up_shape, device=device) * upsample_temp,
device=device,
clip_denoised=True,
progress=True,
model_kwargs=model_kwargs,
cond_fn=None,
)[:batch_size]
model_up.del_cache()
# Show the output
show_images(up_samples)
model_kwargs = dict(
kwargs를 알려면 먼저 args를 알아야함
*args : 함수에서 여러개의 인자, n개를 받을 때 사용.
정해지지 않은 수의 인자를 받을경우 1개의 함수만 구현해도 된다
**kwargw : 함수에서 여러개의 인자 n개를 받을때, 이를 key-value로 받을 시 사용.
함수에서 딕셔너리로 인자를 받을 수 있음.
나머지는 월요일에!
'혼자서' 카테고리의 다른 글
Textual inversion 하는 과정 (실패후 성공!) (0) | 2022.11.25 |
---|---|
Open AI Glide: Text-to image Generation Explained with code 따라해보기 2 (inpaint.ipynb) (0) | 2022.07.25 |
DALL2-pytorch (구현 실패) (0) | 2022.07.21 |
Open AI Glide: Text-to image Generation Explained with code 따라해보기 (0) | 2022.07.20 |
20220718 (0) | 2022.07.19 |