mingsDB
Pytorch to CoreML 변환
kmings
2024. 5. 19. 21:38
728x90
LPIENet 논문의 내용을 실제 iPhone에 적용하기 위해 pytorch를 이용하여 학습한 pth파일을 CoreML파일로 변환할 필요가 있어서 해당 방법을 진행하게 되었습니다.
먼저 pth파일과 model의 정보가 필요합니다.
저는 해당 git 에 있는 model과 train 방식을 사용하에 pth 파일을 얻었습니다.
python에서 pytorch파일을 coreml파일로 변환하기 위해 coremltools 가 필요합니다. 따라서
pip install coremltools
사용하여 패키지를 install 해줍시다.
그 후 아래 코드를 통해 변환을 진행했습니다.
import torch
import torch.nn as nn
import coremltools as ct
import coremltools as ct
# model.py 파일의 LPIENet class를 가져옵니다.
from model import LPIENet
model = LPIENet(3, 3, [16, 32, 64], [32, 16])
model.load_state_dict(torch.load('best.pth'))
model.eval()
# 학습 시 batch를 4로 하여 학습했기에 4를 적용했습니다.
example_input = torch.randn(4,3,540,960)
traced_model = torch.jit.trace(model, example_input)
out = model(example_input)
# 이미지를 사용하므로 ct.ImageType을 사용함
# Core ML로 변환
model = ct.convert(
traced_model,
convert_to="mlprogram",
inputs=[ct.TensorType(shape=example_input.shape)]
)
# save model
model.save("newmodel.mlpackage")
위 코드를 실행시키면 newmodel.mlpackage 폴더가 생성되고 해당 폴더 안에는 .mlmodel 파일이 있습니다!!!!
주의!!) windows에서 변환 하니
다음과 같은 에러가 발생해서 인터넷에 있는 해결법을 적용해도 안되길래 Linux에서 실행하니 바로 되네요...ㅠㅠ !!