[workstation]

  • 환경 : RTX 3090
  • 데이터 : 박대리의 현실고증 직장생활 얼굴 10장
  • 사전학습 모델 : 위와 동일

 

*tmux session 시작

1) few-shot 학습용 도커 컨테이너 생성 및 실행 (nvidia pytorch 9월 이미지 사용)

NV_GPU=1 nvidia-docker run --name bernice-few-shot -it -v $(pwd):/workspace -v $(readlink -f disk1):/disk1 nvcr.io/nvidia/pytorch:21.09-py3 /bin/bash

 

2) 이미지 전처리

  • 박대리의 현실고증 직장생활 짤 10개 선정
  • 수기로 얼굴 크롭
  • 아래와 같이 투명 제거
import glob, os
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from pathlib import Path

def square_and_fill(im):
  x, y = im.size
  size = max(x, y)
  white = Image.new("RGBA", (size, size), "WHITE") 
  white.paste(im, (int((size - x) / 2), int((size - y) / 2)), mask=im)
  return white

for p in glob.glob('/content/*.png'):
  path = Path(p)
  im = Image.open(p)
  white = square_and_fill(im)
  white = white.convert('RGB')
  white.save(f'{path.stem}.jpg')

 

 

3) prepare_data.py

pip install visdom
python prepare_data.py --out ./processed_data/parkdaeri/ --size 256 ./raw_data/

 

4) train.py

python train.py --ckpt /disk1/zzalgun_psp/pretrained_models/ffhq_256_rosinality.pt --data_path ./processed_data/parkdaeri/ --exp /disk1/zzalgun_psp/few-shot_results

 

 

 

---

 

few-shot 모델에 projection 하기

 

1)projector.py

 
python projector.py --ckpt /disk1/zzalgun/pretrained_models/ffhq_256_rosinality.pt --size 256 /disk1/zzalgun/my_face/my_face03.jpg

 

2)load latent codes from stylegan2(rosinality)

import torch
from model import Generator
from torchvision import utils

# latent & noise from *.pt
latent = torch.load('my_face03.pt')['/disk1/zzalgun/my_face/my_face03.jpg']
noises = latent['noise']
latent_n = latent['latent']
latent_n.size() # torch.Size([512])
latent_n = latent_n.reshape(1, 512)

# model
# g_ema = Generator(256, 512, 8)
# 해당 레포지토리 코드로 few-shot 모델 load가 안 되므로 few-shot 코드에 있는 방식 사용
# g_ema.load_state_dict(torch.load('/disk1/zzalgun/few-shot_results/parkdaeri/002000.pt')['g_ema'], strict=False)
# g_ema.eval()
# g_ema = g_ema.to('cuda')

# generate & save
# img_gen, _ = g_ema([latent_n], input_is_latent=True, inject_index=1, noise=noises)
# utils.save_image(img_gen, 'ttest.png', nrow=1, normalize=True, range=(-1, 1))

 

3)latent code에서 이미지 generate하기 (few-shot)

g_list = []

g_source = Generator(256, 512, 8, channel_multiplier=2).to('cuda')
checkpoint = torch.load('/disk1/zzalgun/pretrained_models/ffhq_256_rosinality.pt')
g_source.load_state_dict(checkpoint['g_ema'], strict=False)
g_list.append(g_source)

g_target = Generator(256, 512, 8, channel_multiplier=2).to('cuda')
g_target = nn.parallel.DataParallel(g_target)
checkpoint = torch.load('/disk1/zzalgun/few-shot_results/parkdaeri/002000.pt')
g_target.load_state_dict(checkpoint['g_ema'], strict=False)
g_list.append(g_target)

with torch.no_grad():
    for i in range(len(g_list)):
        g_test = g_list[i]
        g_test.eval()
        sample, _ = g_test([latent_n], 
                           truncation=1, 
                           truncation_latent=None, 
                           input_is_latent=True, 
                           randomize_noise=True)
        if i == 0:
            tot_img = sample
        else:
            tot_img = torch.cat([tot_img, sample], dim = 0)

    utils.save_image(
      tot_img,
      f'test_sample/sample.png',
      nrow=1,
      normalize=True,
      range=(-1, 1),
      )

 

import torch
from model import Generator
from torchvision import utils

# latent & noise from *.pt
latent = torch.load('../stylegan2-pytorch/myface01.pt')['/disk1/zzalgun/my_face/myface01.jpg']
noises = latent['noise']
latent_n = latent['latent']
latent_n.size() # torch.Size([512])
latent_n = latent_n.reshape(1, 512)

g_list = []

g_source = Generator(256, 512, 8, channel_multiplier=2).to('cuda')
checkpoint = torch.load('/disk1/zzalgun/pretrained_models/ffhq_256_rosinality.pt')
g_source.load_state_dict(checkpoint['g_ema'], strict=False)
g_list.append(g_source)

g_target = Generator(256, 512, 8, channel_multiplier=2).to('cuda')
g_target = torch.nn.parallel.DataParallel(g_target)
checkpoint = torch.load('/disk1/zzalgun/few-shot_results/parkdaeri/002000.pt')
g_target.load_state_dict(checkpoint['g_ema'], strict=False)
g_list.append(g_target)

with torch.no_grad():
    for i in range(len(g_list)):
        g_test = g_list[i]
        g_test.eval()
        sample, _ = g_test([latent_n], 
                           truncation=1, 
                           truncation_latent=None, 
                           input_is_latent=True, 
                           randomize_noise=True)
        if i == 0:
            tot_img = sample
        else:
            tot_img = torch.cat([tot_img, sample], dim = 0)

    utils.save_image(
      tot_img,
      f'test_sample/sample_01.png',
      nrow=1,
      normalize=True,
      range=(-1, 1),
      )

 

---

트러블 슈팅

few-shot으로 학습한 모델(noseless)로 학습 재개하려고 보니까 아래와 같은 오류

 
    generator.load_state_dict(ckpt["g"], strict=False)
KeyError: 'g'

 

모델을 불러와보니 g_ema 밖에 없음

원래 사전학습 모델과 비교했을 때

 

train.py 을 살펴보면 g_ema 만 저장하도록 하고 있다 (이미지 생성용)

 

 

주석 처리를 해제하고 코 제거 모델을 재학습했으나,

이번에는 discriminator를 load 할 때 에러

 

 

train.py 에서 discriminator 불러오는 부분에 strict = False 추가하면 오류는 안 나지만 스냅샷 이미지가 다음과 같이 잘못 나온다.

 

 

에러를 잘 살펴보면 state_dict의 키 형식이 조금씩 다르다

"convs.0.0.weight" - "module.convs.0.0.weight"

 

앞에 module이 붙어 인식하지 못하는 듯. 직접 key를 바꿔주기로 했다.

 

from collections import OrderedDict
from model import Patch_Discriminator as Discriminator  # , Projection_head

discriminator = Discriminator(
        256, channel_multiplier=2
    ).to('cuda')
ckpt = torch.load('/disk1/zzalgun/few-shot_results/nose/ffhq_002000.pt')

new = OrderedDict()

for key in ckpt['d'].copy().keys():
   val = ckpt['d'][key] # pop 하면 삭제됨
   key = key.replace('module.', '')
   new[key] = val
   
discriminator.load_state_dict(new)

 

 

'인공지능 > computer vision' 카테고리의 다른 글

Audio Reactive styleGAN (pending)  (0) 2022.08.12
GAN실험 - Inversion of Input Images  (0) 2022.08.12
GAN실험 - AnimeGANv2 학습  (0) 2022.08.12
GAN실험 - FreezeG  (0) 2022.08.11
GAN실험 - GANspace  (0) 2022.08.10
복사했습니다!