![article thumbnail image](https://blog.kakaocdn.net/dn/YqnVA/btrJvk186iq/LBB78qVd4kVv5JS7KKu3lk/img.png)
[workstation]
- Environment: RTX 3090
- Data: 10 face images from 박대리의 현실고증 직장생활
- Pretrained model: same as above
* Start a tmux session
1) Create and run a Docker container for few-shot training (NVIDIA PyTorch 21.09 image)

```bash
NV_GPU=1 nvidia-docker run --name bernice-few-shot -it -v $(pwd):/workspace -v $(readlink -f disk1):/disk1 nvcr.io/nvidia/pytorch:21.09-py3 /bin/bash
```
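Once inside the container, a quick sanity check that PyTorch actually sees the GPU (a minimal sketch):

```python
# Minimal sanity check inside the container: confirm the GPU is visible.
import torch

print(torch.cuda.is_available())      # expect True
print(torch.cuda.get_device_name(0))  # expect an RTX 3090
```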
2) Image preprocessing
- Selected 10 박대리의 현실고증 직장생활 memes
- Cropped the faces by hand
- Removed the transparency as below
```python
import glob
from pathlib import Path
from PIL import Image

def square_and_fill(im):
    # Pad to a square canvas and flatten the alpha channel onto white
    x, y = im.size
    size = max(x, y)
    white = Image.new("RGBA", (size, size), "WHITE")
    white.paste(im, (int((size - x) / 2), int((size - y) / 2)), mask=im)
    return white

for p in glob.glob('/content/*.png'):
    path = Path(p)
    im = Image.open(p)
    white = square_and_fill(im)
    white = white.convert('RGB')
    white.save(f'{path.stem}.jpg')
```
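A quick check that every flattened image really came out square before packing (a minimal sketch; assumes the JPGs were written to the current directory as in the loop above):

```python
# Verify the flattened JPGs are square before running prepare_data.py
import glob
from PIL import Image

for p in glob.glob('*.jpg'):
    w, h = Image.open(p).size
    assert w == h, f'{p} is not square: {w}x{h}'
```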
3) prepare_data.py
```bash
pip install visdom
python prepare_data.py --out ./processed_data/parkdaeri/ --size 256 ./raw_data/
```
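prepare_data.py packs the resized images into an LMDB database. A minimal sanity check on the result, assuming the rosinality-style layout where the entry count is stored under the key "length":

```python
# Hedged sketch: count the entries in the processed LMDB dataset,
# assuming prepare_data.py stores the total under the key "length".
import lmdb

env = lmdb.open('./processed_data/parkdaeri/', readonly=True, lock=False)
with env.begin(write=False) as txn:
    length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
print(length)  # expect 10, one entry per cropped face
```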
4) train.py
```bash
python train.py --ckpt /disk1/zzalgun_psp/pretrained_models/ffhq_256_rosinality.pt --data_path ./processed_data/parkdaeri/ --exp /disk1/zzalgun_psp/few-shot_results
```
---
Projecting onto the few-shot model
1) projector.py

```bash
python projector.py --ckpt /disk1/zzalgun/pretrained_models/ffhq_256_rosinality.pt --size 256 /disk1/zzalgun/my_face/my_face03.jpg
```
2) Load latent codes from StyleGAN2 (rosinality)
```python
import torch
from model import Generator
from torchvision import utils

# latent & noise from *.pt
latent = torch.load('my_face03.pt')['/disk1/zzalgun/my_face/my_face03.jpg']
noises = latent['noise']
latent_n = latent['latent']
latent_n.size()  # torch.Size([512])
latent_n = latent_n.reshape(1, 512)

# model
# g_ema = Generator(256, 512, 8)
# The few-shot model cannot be loaded with this repository's code,
# so the loading method from the few-shot code is used instead (see below)
# g_ema.load_state_dict(torch.load('/disk1/zzalgun/few-shot_results/parkdaeri/002000.pt')['g_ema'], strict=False)
# g_ema.eval()
# g_ema = g_ema.to('cuda')

# generate & save
# img_gen, _ = g_ema([latent_n], input_is_latent=True, inject_index=1, noise=noises)
# utils.save_image(img_gen, 'ttest.png', nrow=1, normalize=True, range=(-1, 1))
```
3) Generating an image from the latent code (few-shot)
```python
g_list = []

# Source generator: pretrained FFHQ weights
g_source = Generator(256, 512, 8, channel_multiplier=2).to('cuda')
checkpoint = torch.load('/disk1/zzalgun/pretrained_models/ffhq_256_rosinality.pt')
g_source.load_state_dict(checkpoint['g_ema'], strict=False)
g_list.append(g_source)

# Target generator: few-shot weights, saved under DataParallel
g_target = Generator(256, 512, 8, channel_multiplier=2).to('cuda')
g_target = torch.nn.parallel.DataParallel(g_target)
checkpoint = torch.load('/disk1/zzalgun/few-shot_results/parkdaeri/002000.pt')
g_target.load_state_dict(checkpoint['g_ema'], strict=False)
g_list.append(g_target)

# Run the same latent through both generators and stack the outputs
with torch.no_grad():
    for i in range(len(g_list)):
        g_test = g_list[i]
        g_test.eval()
        sample, _ = g_test([latent_n],
                           truncation=1,
                           truncation_latent=None,
                           input_is_latent=True,
                           randomize_noise=True)
        if i == 0:
            tot_img = sample
        else:
            tot_img = torch.cat([tot_img, sample], dim=0)

utils.save_image(
    tot_img,
    'test_sample/sample.png',
    nrow=1,
    normalize=True,
    range=(-1, 1),
)
```
```python
# Same procedure for a second projected face (myface01)
import torch
from model import Generator
from torchvision import utils

# latent & noise from *.pt
latent = torch.load('../stylegan2-pytorch/myface01.pt')['/disk1/zzalgun/my_face/myface01.jpg']
noises = latent['noise']
latent_n = latent['latent']
latent_n.size()  # torch.Size([512])
latent_n = latent_n.reshape(1, 512)

g_list = []
g_source = Generator(256, 512, 8, channel_multiplier=2).to('cuda')
checkpoint = torch.load('/disk1/zzalgun/pretrained_models/ffhq_256_rosinality.pt')
g_source.load_state_dict(checkpoint['g_ema'], strict=False)
g_list.append(g_source)

g_target = Generator(256, 512, 8, channel_multiplier=2).to('cuda')
g_target = torch.nn.parallel.DataParallel(g_target)
checkpoint = torch.load('/disk1/zzalgun/few-shot_results/parkdaeri/002000.pt')
g_target.load_state_dict(checkpoint['g_ema'], strict=False)
g_list.append(g_target)

with torch.no_grad():
    for i in range(len(g_list)):
        g_test = g_list[i]
        g_test.eval()
        sample, _ = g_test([latent_n],
                           truncation=1,
                           truncation_latent=None,
                           input_is_latent=True,
                           randomize_noise=True)
        if i == 0:
            tot_img = sample
        else:
            tot_img = torch.cat([tot_img, sample], dim=0)

utils.save_image(
    tot_img,
    'test_sample/sample_01.png',
    nrow=1,
    normalize=True,
    range=(-1, 1),
)
```
---
Troubleshooting
When I tried to resume training from the few-shot-trained model (noseless), I hit the error below:

```python
generator.load_state_dict(ckpt["g"], strict=False)
KeyError: 'g'
```

Loading the checkpoint shows that only g_ema is present:
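The keys can be inspected directly (a minimal check; the checkpoint path for the noseless run is an assumption):

```python
# Inspect the checkpoint keys; the path is an assumed location for the
# noseless checkpoint.
import torch

ckpt = torch.load('/disk1/zzalgun/few-shot_results/noseless/002000.pt')
print(ckpt.keys())  # only 'g_ema'; no 'g' or 'd'
```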
![](https://blog.kakaocdn.net/dn/4UZUj/btrJvkA4Pnu/mcR2sL1F3P3WuLkt7jpTK0/img.png)
![](https://blog.kakaocdn.net/dn/qISov/btrJwZQwF9l/OWfvFpt0hkMOz7c3KzZdqk/img.png)
Looking at train.py, only g_ema is saved (for image generation); the lines that save the training states are commented out.
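A hedged sketch of what the save call looks like once those lines are uncommented; the variable names are assumptions based on the repository's conventions, not verified against the file:

```python
# Hedged sketch of the uncommented checkpoint save in train.py;
# g_module, d_module, g_ema, and i are assumed names.
torch.save(
    {
        "g": g_module.state_dict(),    # generator, needed to resume training
        "d": d_module.state_dict(),    # discriminator, needed to resume training
        "g_ema": g_ema.state_dict(),   # EMA generator, used for image generation
    },
    f"checkpoints/{str(i).zfill(6)}.pt",
)
```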
After uncommenting those lines I retrained the nose-removal model, but this time hit an error when loading the discriminator. Adding strict=False to the discriminator-loading code in train.py makes the error go away, but then the snapshot images come out wrong, as shown below.
![](https://blog.kakaocdn.net/dn/c6LZHp/btrJwZptbTN/HgLzGfadYmxsev93K8KSO1/img.png)
Looking closely at the error, the state_dict key formats differ slightly:

"convs.0.0.weight" vs. "module.convs.0.0.weight"

The module. prefix (added when a model is wrapped in DataParallel) keeps the keys from matching, so I renamed the keys directly:
```python
import torch
from collections import OrderedDict
from model import Patch_Discriminator as Discriminator  # , Projection_head

discriminator = Discriminator(
    256, channel_multiplier=2
).to('cuda')

ckpt = torch.load('/disk1/zzalgun/few-shot_results/nose/ffhq_002000.pt')

# Rebuild the state dict with the "module." prefix stripped
new = OrderedDict()
for key in ckpt['d'].copy().keys():
    val = ckpt['d'][key]  # index instead of pop(), which would delete the entry
    key = key.replace('module.', '')
    new[key] = val

discriminator.load_state_dict(new)
```
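If the generator weights in the checkpoint were also saved under DataParallel, the same prefix fix applies; a minimal sketch, assuming ckpt['g'] exists and carries the same module. prefix:

```python
# Hedged sketch: strip the DataParallel prefix from the generator weights too,
# assuming ckpt['g'] is "module."-prefixed like ckpt['d'].
new_g = OrderedDict(
    (key.replace('module.', ''), val) for key, val in ckpt['g'].items()
)
generator.load_state_dict(new_g)
```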