Pytorch风格迁移代码

发布时间:2022-06-20 发布网站:脚本宝典
脚本宝典收集整理的这篇文章主要介绍了Pytorch风格迁移代码,脚本宝典觉得挺不错的,现在分享给大家,也给大家做个参考。

最近研究了一下风格迁移,主要是想应用于某些主题节日时动态融合背景,生成一些抽象的艺术图片,这里给大家分享一个现成的代码,我本地把环境搭建好后跑了试试,有兴趣的可以直接拿去运行:

  1 import torch
  2 import torch.nn as nn
  3 import torch.nn.functional as F
  4 import torch.optim as optim
  5 
  6 from PIL import Image
  7 import matplotlib.pyplot as plt
  8 
  9 import torchvision.transforms as transforms
 10 import torchvision.models as models
 11 import datetime
 12 
 13 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 14 
 15 
 16 num_steps = 10000  # cpu跑的话,低于300吧,不然耗时很长
 17 save_path = "data/drew/img/end_%s.jpg" % datetime.datetime.now().strftime("%Y%m%d%H%M%S")
 18 content_img_path = "data/drew/img/dancing.jpg"
 19 style_img_path = "data/drew/img/picasso.jpg"
 20 
 21 
 22 def get_img_size(img_name):
 23     im = Image.open(img_name).convert('RGB')  # 这里要转成RGB
 24     return im, im.height, im.width
 25 
 26 
 27 def image_loader(img, im_h, im_w):
 28     loader = transforms.Compose([transforms.Resize([im_h, im_w]), transforms.ToTensor()])    # 如果跑不动,这里的Resize设置小一点,我这用的是适配融入内容的尺寸
 29     im_l = loader(img).unsqueeze(0)
 30     return im_l.to(device, torch.float)
 31 
 32 
 33 c_image, c_im_h, c_im_w = get_img_size(content_img_path)
 34 s_image, s_im_h, s_im_w = get_img_size(style_img_path)
 35 content_img = image_loader(c_image, c_im_h, c_im_w)
 36 style_img = image_loader(s_image, c_im_h, c_im_w)
 37 
 38 
 39 assert style_img.size() == content_img.size(), "we need to import style and content images of the same size"
 40 unloader = transforms.ToPILImage()
 41 
 42 plt.ion()
 43 
 44 
 45 def imshow(tensor, title=None):
 46     image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
 47     image = image.squeeze(0)      # remove the fake batch dimension
 48     image = unloader(image)
 49     plt.imshow(image)
 50     if title is not None:
 51         plt.title(title)
 52     plt.pause(0.001) # pause a bit so that plots are updated
 53 
 54 
 55 # plt.figure()
 56 # imshow(style_img, title='Style Image')
 57 #
 58 # plt.figure()
 59 # imshow(content_img, title='Content Image')
 60 
 61 
 62 class ContentLoss(nn.Module):
 63 
 64     def __init__(self, target,):
 65         super(ContentLoss, self).__init__()
 66         self.target = target.detach()
 67 
 68     def forward(self, input):
 69         self.loss = F.mse_loss(input, self.target)
 70         return input
 71 
 72 
 73 def gram_matrix(input):
 74     a, b, c, d = input.size()  # a=batch size(=1)
 75 
 76     features = input.view(a * b, c * d)  # resise F_XL into hat F_XL
 77 
 78     G = torch.mm(features, features.t())  # compute the gram product
 79 
 80     return G.div(a * b * c * d)
 81 
 82 
 83 class StyleLoss(nn.Module):
 84 
 85     def __init__(self, target_feature):
 86         super(StyleLoss, self).__init__()
 87         self.target = gram_matrix(target_feature).detach()
 88 
 89     def forward(self, input):
 90         G = gram_matrix(input)
 91         self.loss = F.mse_loss(G, self.target)
 92         return input
 93 
 94 
 95 cnn = models.vgg19(pretrained=True).features.to(device).eval()
 96 
 97 
 98 cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
 99 cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
100 
101 
class Normalization(nn.Module):
    """Normalise an image with per-channel mean and std before it enters the CNN.

    Args:
        mean: per-channel means (tensor or sequence of numbers).
        std: per-channel standard deviations, same length as ``mean``.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # Reshape to (C, 1, 1) so the statistics broadcast over the trailing
        # (C, H, W) dimensions of the image. Registering them as buffers (rather
        # than plain attributes) makes `.to(device)` / `.to(dtype)` on this
        # module move them too.
        self.register_buffer("mean", torch.as_tensor(mean).clone().detach().view(-1, 1, 1))
        self.register_buffer("std", torch.as_tensor(std).clone().detach().view(-1, 1, 1))

    def forward(self, img):
        # normalize img: (x - mean) / std per channel
        return (img - self.mean) / self.std
111 
112 
# Layers after which loss probes are inserted. Tuples (not lists) because these
# objects are shared as default arguments and must not be mutated by callers.
content_layers_default = ('conv_4',)
style_layers_default = ('conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5')
115 
116 
def get_style_model_and_losses(cnn, normalization_mean, normalization_std, style_img, content_img,
                               content_layers=content_layers_default, style_layers=style_layers_default):
    """Build a truncated copy of ``cnn`` with loss probes inserted after selected layers.

    The returned ``nn.Sequential`` starts with a Normalization module, then the
    children of ``cnn`` in order, named conv_i / relu_i / pool_i / bn_i where i is
    the index of the most recent convolution. A ContentLoss follows every layer
    named in ``content_layers`` and a StyleLoss every layer named in
    ``style_layers``. Everything after the last loss module is dropped.

    Returns:
        ``(model, style_losses, content_losses)``. The loss modules are also part
        of ``model``, so running ``model(x)`` refreshes their ``.loss`` values.

    Raises:
        RuntimeError: if ``cnn`` contains a layer that is not Conv2d, ReLU,
            MaxPool2d or BatchNorm2d.
    """
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    conv_idx = 0  # incremented every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            conv_idx += 1
            name = 'conv_{}'.format(conv_idx)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(conv_idx)
            # Out-of-place copy: an in-place ReLU would overwrite the tensor the
            # preceding loss module returned and used to compute its loss.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(conv_idx)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(conv_idx)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        # Targets are constants. Computing them under no_grad avoids recording an
        # autograd graph through every VGG layer; the weights still require grad here.
        if name in content_layers:
            with torch.no_grad():
                target = model(content_img)
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(conv_idx), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_img)
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(conv_idx), style_loss)
            style_losses.append(style_loss)

    # Trim off the layers after the last content/style loss; they would only cost compute.
    # With no loss modules at all, only the Normalization module (index 0) is kept.
    last_loss_idx = 0
    for idx, module in enumerate(model):
        if isinstance(module, (ContentLoss, StyleLoss)):
            last_loss_idx = idx

    model = model[:last_loss_idx + 1]

    return model, style_losses, content_losses
165 
166 
# Start the optimisation from a copy of the content image, so the original stays untouched.
input_img = torch.clone(content_img)
168 
169 # plt.figure()
170 # imshow(input_img, title='Input Image')
171 
172 
def get_input_optimizer(input_img):
    """Return an L-BFGS optimizer whose only parameter is the image tensor itself.

    The pixels are optimised, not the network weights.
    """
    params = [input_img]
    return optim.LBFGS(params)
176 
177 
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=num_steps,
                       style_weight=1000000, content_weight=1):
    """Run the style transfer.

    Optimises ``input_img`` in place with L-BFGS. The weighted sum of the style
    losses (against ``style_img``) and content losses (against ``content_img``)
    is minimised.

    Args:
        cnn: feature extractor that the loss model is built from.
        normalization_mean / normalization_std: per-channel statistics for the
            Normalization module placed in front of ``cnn``.
        content_img / style_img: reference images.
        input_img: image tensor to optimise; modified in place.
        num_steps: budget of loss evaluations. L-BFGS may evaluate the closure
            several times per ``step`` call, so the final count can overshoot.
        style_weight / content_weight: multipliers for the two loss terms.

    Returns:
        ``input_img`` itself, clamped to the [0, 1] range.
    """
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std, style_img, content_img)

    # We want to optimize the input and not the model parameters, so we
    # update all the requires_grad fields accordingly.
    input_img.requires_grad_(True)
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = [0]  # a list so the closure below can mutate the counter
    while run[0] <= num_steps:

        def closure():
            # Keep pixel values in the valid [0, 1] range between updates.
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            # The forward pass refreshes `.loss` on every loss module.
            model(input_img)

            style_score = sum(sl.loss for sl in style_losses) * style_weight
            content_score = sum(cl.loss for cl in content_losses) * content_weight

            loss = style_score + content_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run[0]))
                # float() also copes with a plain 0 when a loss list is empty.
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                    float(style_score), float(content_score)))
                print()

            # Return the tensor that was actually back-propagated.
            return loss

        optimizer.step(closure)

    # The last optimizer update may have pushed pixels outside [0, 1].
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img
234 
235 
begin_time = datetime.datetime.now()
print("******************开始时间*****************", begin_time)
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img)

# Save the optimised tensor itself at its native resolution. plt.savefig would
# instead store the matplotlib figure: resampled to the figure size and wrapped
# in axes and a title.
try:
    unloader(output.detach().cpu().squeeze(0)).save(save_path)
except Exception as e:  # best effort: report the problem and keep going
    print(e)

# Preview in a matplotlib window. This is kept separate so that a display problem
# (e.g. a server without a GUI backend) cannot prevent the file from being written.
try:
    plt.figure()
    imshow(output, title='Output Image')
    plt.ioff()
except Exception as e:
    print(e)

# Read the clock once so the printed end time and the elapsed time agree.
end_time = datetime.datetime.now()
print("******************结束时间*****************", end_time)
print("******************耗时*****************", end_time - begin_time)

dancing.jpg

Pytorch风格迁移代码

picasso.jpg

Pytorch风格迁移代码

我这迁移后的图像,还是不错的。

Pytorch风格迁移代码

 

 

 

 风格:

 

Pytorch风格迁移代码

 内容:

 

Pytorch风格迁移代码

 

 迁移融合后:

Pytorch风格迁移代码

 风格:

Pytorch风格迁移代码

 

融入:

Pytorch风格迁移代码

 

 迁移后:

Pytorch风格迁移代码

 

 还可以吧,哈哈~

 

有兴趣的可以去研究一下原文:

原文地址:

https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

原GitHub代码地址:

https://github.com/pytorch/tutorials/blob/master/advanced_source/neural_style_tutorial.py

 

需要准备:

一台有显卡并且支持PyTorch训练的服务器;只有CPU的话就算了:GPU服务器跑几分钟就行,CPU服务器要跑一小时,而且CPU占用还是100%!

 

脚本宝典总结

以上是脚本宝典为你收集整理的Pytorch风格迁移代码全部内容,希望文章能够帮你解决Pytorch风格迁移代码所遇到的问题。

如果觉得脚本宝典网站内容还不错,欢迎将脚本宝典推荐好友。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。
标签: