본문 바로가기

mingsDB

Pytorch to CoreML 변환

728x90

LPIENet 논문의 내용을 실제 iPhone에 적용하기 위해 pytorch를 이용하여 학습한 pth파일을 CoreML파일로 변환할 필요가 있어서 해당 방법을 진행하게 되었습니다. 


먼저 pth파일과 model의 정보가 필요합니다.

저는 해당 git 에 있는 model과 train 방식을 사용하에 pth 파일을 얻었습니다.

 

python에서 pytorch파일을 coreml파일로 변환하기 위해 coremltools 가 필요합니다. 따라서 

pip install coremltools

사용하여 패키지를 install 해줍시다.

 

그 후 아래 코드를 통해 변환을 진행했습니다.

 

import torch
import torch.nn as nn
import coremltools as ct
import coremltools as ct

# model.py 파일의 LPIENet class를 가져옵니다.
from model import LPIENet

model = LPIENet(3, 3, [16, 32, 64], [32, 16])

model.load_state_dict(torch.load('best.pth'))
model.eval()

# 학습 시 batch를 4로 하여 학습했기에 4를 적용했습니다.
example_input = torch.randn(4,3,540,960)
traced_model = torch.jit.trace(model, example_input)
out = model(example_input)


# 이미지를 사용하므로 ct.ImageType을 사용함

# Core ML로 변환
model = ct.convert(
    traced_model,
    convert_to="mlprogram",
    inputs=[ct.TensorType(shape=example_input.shape)]
 )
 
 # save model
model.save("newmodel.mlpackage")

 

위 코드를 실행시키면 newmodel.mlpackage 폴더가 생성되고 해당 폴더 안에는 .mlmodel 파일이 있습니다!!!!

 


주의!!) windows에서 변환 하니 

다음과 같은 에러가 발생해서 인터넷에 있는 해결법을 적용해도 안되길래  Linux에서 실행하니 바로 되네요...ㅠㅠ !!

 


참고 사이트