用变分自编码器生成图像
目前最流行也是最成功的创造性人工智能应用就是图像生成:学习潜在视觉空间,并从空
间中进行采样来创造全新图片。这些图片是在真实图片中进行插值得到的,可以是想象中的人、
想象中的地方、想象中的猫和狗等。
本节和12.5 节将介绍一些与图像生成有关的概念,还会介绍该领域中的两种主要技术的实
现细节。这两种技术分别是变分自编码器(variational autoencoder,VAE)和生成式对抗网络
(generative adversarial network,GAN)。请注意,这里介绍的技术不仅适用于图像,你还可以使
用VAE 和GAN 探索声音、音乐甚至文本的潜在空间。但在实践中,最有趣的结果都是利用图
片得到的,这也是我们介绍的重点。
从图像潜在空间中采样
图像生成的关键思想就是找到图像表示的低维潜在空间(latent space,与深度学习其他内
容一样,它也是一个向量空间),其中任意一点都可以被映射为一张“有效”图像,即看起来真实的图像。能够实现这种映射的模块,接收一个潜在点作为输入,并输出一张图像(像素网格)。
这种模块叫作生成器(generator,对于GAN 而言)或解码器(decoder,对于VAE 而言)。学到
这种潜在空间之后,我们可以从中对点进行采样,然后将其映射到图像空间,从而生成前所未
见的图像,如图12-13 所示。新图像是训练图像的插值。
要学习图像表示的这种潜在空间,VAE 和GAN 采用了不同的策略,二者各有特点。VAE 非
常适合学习具有良好结构的潜在空间,空间中的特定方向表示数据中有意义的变化轴,如图12-14
所示。GAN 可以生成非常逼真的图像,但它的潜在空间可能没有良好的结构,连续性也不强。
图像编辑的概念向量
第11 章介绍词嵌入时,我们已经暗示了概念向量(concept vector)这一想法。这里用到了
同样的想法:给定一个数据表示的潜在空间或一个嵌入空间,空间中的某些方向可能表示原始
数据中有趣的变化轴。比如在人脸图像的潜在空间中,可能存在一个微笑向量(smile vector)
s:如果潜在点z 是某张人脸的嵌入表示,那么潜在点z + s 就是同一张脸面带微笑的嵌入表
示。一旦找到了这样的向量,就可以这样编辑图像:将图像投影到潜在空间中,然后沿着有意
义的方向移动图像表示,再将其解码到图像空间中。在图像空间中任何独立的变化维度都有概
念向量。对于人脸而言,你可能会发现向人脸添加墨镜的向量、去掉眼镜的向量、将男性面孔
变为女性面孔的向量等。图12-15 是微笑向量的一个例子,它是由新西兰维多利亚大学设计学
院的Tom White 发现的概念向量,他使用的是在名人脸部属性数据集(CelebA 数据集)上训
练的VAE。
变分自编码器
变分自编码器由Diederik P. Kingma 和Max Welling 于2013 年12 月a 以及Danilo Jimenez
Rezende、Shakir Mohamed 和Daan Wierstra 于2014 年1 月b 几乎同时提出。它是一种生成式模型,
特别适用于利用概念向量进行图像编辑。这种现代化的自编码器将深度学习思想与贝叶斯推断
结合在一起。自编码器是一类网络,其目的是将输入编码到低维潜在空间,然后再解码回来。
经典的图像自编码器接收一张图像,通过编码器模块将其映射到潜在向量空间,然后再通
过解码器模块将其解码为与原始图像具有相同尺寸的输出,如图12-16 所示。然后,使用与输
入图像相同的图像作为目标数据来训练这个自编码器,也就是说,自编码器学习对原始输入进
行重构。通过对编码(编码器输出)施加各种限制,我们可以让自编码器学到比较有趣的数据潜在表示。最常见的情况是,将编码限制为低维且稀疏的(大部分元素为0)。在这种情况下,
编码器可以压缩输入数据。
VAE 并没有将输入图像压缩为潜在空间中的固定编码,而是将图像转换为统计分布的参数,
即均值和方差。这本质上意味着我们假设输入图像是由统计过程生成的,在编码和解码的过程
中应该考虑随机性。VAE 使用均值和方差这两个参数从分布中随机采样一个元素,并将这个元
素解码为原始输入,如图12-17 所示。这一过程的随机性提高了其稳健性,并迫使潜在空间的
任何位置都对应有意义的表示,即在潜在空间中采样的每个点都能解码为有效输出。
从技术角度来说,VAE 的工作原理如下。
(1) 编码器模块将输入样本input_img 转换为图像表示潜在空间中的两个参数z_mean 和
z_log_var。
(2) 我们假定潜在正态分布能够生成输入图像,并从这个分布中随机采样一个点z:z =
z_mean + exp(0.5 * z_log_var) * epsilon,其中epsilon 是一个取值很小的
随机张量。
(3) 解码器模块将潜在空间中的这个点映射回原始输入图像。
由于epsilon 是随机的,因此这个过程可以保证,与编码input_img 的潜在位置(z-mean)
接近的每个点都能被解码为与input_img 相似的图像,从而迫使潜在空间能够连续地有意义。
潜在空间中任意两个相邻的点都可以被解码为高度相似的图像。潜在空间的连续性和低维度,
迫使其中的每个方向都表示数据中的一个有意义的变化轴。这使得潜在空间具有非常好的结构,
因此非常适合通过概念向量来进行操作。
VAE 的参数可以通过两个损失函数来训练:一个是重构损失(reconstruction loss),它迫使
解码后的样本匹配原始输入;另一个是正则化损失(regularization loss),它有助于学习良好的
潜在分布,并降低在训练数据上的过拟合。整个过程的原理如下所示。
接下来,我们可以使用重构损失和正则化损失来训练模型。对于正则化损失,我们通常使
用一个表达式(Kullback−Leibler 散度),旨在让编码器输出的分布趋向于以0 为中心的光滑正
态分布。这为编码器提供了一个关于潜在空间结构的合理假设。
下面我们来看一下如何在实践中实现VAE。
用Keras 实现变分自编码器
我们将实现一个能够生成MNIST 数字图像的VAE,它包含以下3 部分。
- 编码器网络:将真实图像转换为潜在空间中的均值和方差。
- 采样层:接收上述均值和方差,并利用它们从潜在空间中随机采样一个点。
- 解码器网络:将潜在空间中的点重新转换为图像。
代码清单12-24 给出了我们要使用的编码器网络,它将图像映射为潜在空间中的概率分布
参数。它是一个简单的卷积神经网络,将输入图像x 映射为z_mean 和z_log_var 两个向量。
这里有一个重要的细节:我们使用步幅对特征图进行下采样,而没有使用最大汇聚。上次我们
这样做是在第9 章的图像分割示例中。回想一下,一般来说,对于关注信息位置(物体在图像
中的位置)的模型来说,步幅比最大汇聚更适合。本模型需要关注信息位置,因为它需要生成
图像编码并将其用于重构有效图像。
代码清单12-24 VAE 编码器网络
代码清单12-25 利用z_mean 和z_log_var 来生成一个潜在空间点z,假设二者是生成
input_img 的统计分布参数。
代码清单12-25 潜在空间采样层
代码清单12-26 给出了解码器网络的实现。我们将向量z 的形状调整为图像尺寸,然后使
用几个卷积层来得到最终的图像输出,其尺寸与原始图像input_img 相同。
代码清单12-26 VAE 解码器网络,将潜在空间点映射为图像
它的架构如下所示。
下面来创建VAE 模型。这是我们的第一个非监督学习模型示例(自编码器是一种自监督
学习,因为它使用输入作为目标)。如果你要做的不是经典的监督学习,那么常见的做法是将
Model 类子类化,并实现自定义的train_ step() 来给出新的训练逻辑,这是第7 章介绍过
的工作流程。我们在这里也会这样做,如代码清单12-27 所示。
代码清单12-27 使用自定义train_step() 的VAE 模型
class VAE(keras.Model): def __init__(self, encoder, decoder, **kwargs): super().__init__(**kwargs) self.encoder = encoder self.decoder = decoder self.sampler = Sampler()
最后,我们将模型实例化并在MNIST 数字上进行训练,如代码清单12-28 所示。由于在自
定义层中给定了损失,因此在编译时无须指定外部损失(loss=None),这又意味着在训练过程
中无须传入目标数据(如你所见,我们在调用fit() 时只向模型传入了x_train)。
代码清单12-28 训练VAE
模型训练完成之后,我们可以使用decoder 网络将任意潜在空间向量转换为图像,如代码
清单12-29 所示。
代码清单12-29 从二维潜在空间中采样图像网格
采样数字的网格(见图12-18)展示了不同数字类别之间完全连续的分布:沿着潜在空间
的一条路径观察,你会发现一个数字逐渐变形为另一个数字。这个空间的特定方向是有意义的,
比如,有些方向表示“逐渐变为5”“逐渐变为1”等。
12.5 节将详细介绍生成人造图像的另一个重要工具:生成式对抗网络(GAN)。
完整代码
importnumpyasnpimportmatplotlib.pyplotaspltimporttensorflowastffromtensorflowimportkerasfromtensorflow.kerasimportlayers# 设置随机种子np.random.seed(42)tf.random.set_seed(42)# 加载并预处理MNIST数据(x_train,_),(x_test,_)=keras.datasets.mnist.load_data()x_train=x_train.astype('float32')/255.0x_test=x_test.astype('float32')/255.0x_train=np.expand_dims(x_train,-1)x_test=np.expand_dims(x_test,-1)print(f"训练数据形状:{x_train.shape}")print(f"测试数据形状:{x_test.shape}")# 定义编码器网络latent_dim=2encoder_inputs=keras.Input(shape=(28,28,1))x=layers.Conv2D(32,3,activation="relu",strides=2,padding="same")(encoder_inputs)x=layers.Conv2D(64,3,activation="relu",strides=2,padding="same")(x)x=layers.Flatten()(x)x=layers.Dense(16,activation="relu")(x)z_mean=layers.Dense(latent_dim,name="z_mean")(x)z_log_var=layers.Dense(latent_dim,name="z_log_var")(x)encoder=keras.Model(encoder_inputs,[z_mean,z_log_var],name="encoder")encoder.summary()# 定义采样层classSampler(layers.Layer):defcall(self,z_mean,z_log_var):batch_size=tf.shape(z_mean)[0]z_size=tf.shape(z_mean)[1]epsilon=tf.random.normal(shape=(batch_size,z_size))returnz_mean+tf.exp(0.5*z_log_var)*epsilon# 定义解码器网络latent_inputs=keras.Input(shape=(latent_dim,))x=layers.Dense(7*7*64,activation="relu")(latent_inputs)x=layers.Reshape((7,7,64))(x)x=layers.Conv2DTranspose(64,3,activation="relu",strides=2,padding="same")(x)x=layers.Conv2DTranspose(32,3,activation="relu",strides=2,padding="same")(x)decoder_outputs=layers.Conv2D(1,3,activation="sigmoid",padding="same")(x)decoder=keras.Model(latent_inputs,decoder_outputs,name="decoder")decoder.summary()# 定义VAE模型classVAE(keras.Model):def__init__(self,encoder,decoder,**kwargs):super().__init__(**kwargs)self.encoder=encoder self.decoder=decoder self.sampler=Sampler()self.total_loss_tracker=keras.metrics.Mean(name="total_loss")self.reconstruction_loss_tracker=keras.metrics.Mean(name="reconstruction_loss")self.kl_loss_tracker=keras.metrics.Mean(name="kl_loss")@propertydefmetrics(self):return[self.total_loss_tracker,self.reconstruction_loss_tracker,self.kl_loss_tracker,]deftrain_step(self,data):withtf.GradientTape()astape:z_mean,z_log_var=self.encoder(data)z=self.sampler(z_mean,z_log_var)reconstruction=self.decoder(z)# 计算重构损失reconstruction_loss=tf.reduce_mean(tf.reduce_sum(keras.losses.binary_crossentropy(data,reconstruction),axis=(1,2)))# 计算KL散度损失kl_loss=-0.5*(1+z_log_var-tf.square(z_mean)-tf.exp(z_log_var))kl_loss=tf.reduce_mean(tf.reduce_sum(kl_loss,axis=1))# 总损失total_loss=reconstruction_loss+kl_loss# 计算梯度并更新权重grads=tape.gradient(total_loss,self.trainable_weights)self.optimizer.apply_gradients(zip(grads,self.trainable_weights))# 更新指标self.total_loss_tracker.update_state(total_loss)self.reconstruction_loss_tracker.update_state(reconstruction_loss)self.kl_loss_tracker.update_state(kl_loss)return{"loss":self.total_loss_tracker.result(),"reconstruction_loss":self.reconstruction_loss_tracker.result(),"kl_loss":self.kl_loss_tracker.result(),}# 创建并编译VAEvae=VAE(encoder,decoder)vae.compile(optimizer=keras.optimizers.Adam())# 训练VAEprint("开始训练VAE...")history=vae.fit(x_train,epochs=30,batch_size=128,validation_data=(x_test,None))# 可视化训练过程defplot_training_history(history):fig,axes=plt.subplots(1,3,figsize=(15,4))axes[0].plot(history.history['loss'],label='训练损失')axes[0].set_title('总损失')axes[0].set_xlabel('Epoch')axes[0].legend()axes[1].plot(history.history['reconstruction_loss'],label='训练重构损失')axes[1].set_title('重构损失')axes[1].set_xlabel('Epoch')axes[1].legend()axes[2].plot(history.history['kl_loss'],label='训练KL损失')axes[2].set_title('KL散度损失')axes[2].set_xlabel('Epoch')axes[2].legend()plt.tight_layout()plt.show()plot_training_history(history)# 从潜在空间生成图像网格defplot_latent_space(decoder,n=30,figsize=15):# 显示数字的n×n网格digit_size=28scale=2.0figure=np.zeros((digit_size*n,digit_size*n))# 在标准正态分布的潜在空间中构建点网格grid_x=np.linspace(-scale,scale,n)grid_y=np.linspace(-scale,scale,n)[::-1]fori,yiinenumerate(grid_y):forj,xiinenumerate(grid_x):z_sample=np.array([[xi,yi]])x_decoded=decoder.predict(z_sample,verbose=0)digit=x_decoded[0].reshape(digit_size,digit_size)figure[i*digit_size:(i+1)*digit_size,j*digit_size:(j+1)*digit_size,]=digit plt.figure(figsize=(figsize,figsize))start_range=digit_size//2end_range=n*digit_size+start_range pixel_range=np.arange(start_range,end_range,digit_size)sample_range_x=np.round(grid_x,1)sample_range_y=np.round(grid_y,1)plt.xticks(pixel_range,sample_range_x)plt.yticks(pixel_range,sample_range_y)plt.xlabel("z[0]")plt.ylabel("z[1]")plt.imshow(figure,cmap="Greys_r")plt.title("从潜在空间解码得到的数字网格")plt.show()print("\n生成潜在空间图像网格...")plot_latent_space(decoder,n=20,figsize=10)# 可视化重构结果defplot_reconstruction(model,x_test,n=10):plt.figure(figsize=(20,4))foriinrange(n):# 显示原始图像ax=plt.subplot(2,n,i+1)plt.imshow(x_test[i].reshape(28,28),cmap='gray')plt.title("原始")plt.axis("off")# 显示重构图像ax=plt.subplot(2,n,i+1+n)# 获取编码并解码z_mean,z_log_var=model.encoder.predict(x_test[i:i+1],verbose=0)z=model.sampler(z_mean,z_log_var)reconstruction=model.decoder.predict(z,verbose=0)plt.imshow(reconstruction[0].reshape(28,28),cmap='gray')plt.title("重构")plt.axis("off")plt.suptitle("原始图像 vs 重构图像")plt.show()print("\n展示重构结果...")plot_reconstruction(vae,x_test)# 在潜在空间中插值defplot_interpolation(decoder,start_point,end_point,n=10):"""在潜在空间中的两个点之间进行插值"""interpolated_points=np.linspace(start_point,end_point,n)plt.figure(figsize=(20,2))fori,pointinenumerate(interpolated_points):ax=plt.subplot(1,n,i+1)img=decoder.predict(point.reshape(1,-1),verbose=0)[0]plt.imshow(img.reshape(28,28),cmap='gray')plt.axis('off')plt.suptitle("潜在空间插值")plt.show()# 选择两个不同的潜在点进行插值print("\n展示潜在空间插值...")start_point=np.array([[-2.0,0.0]])end_point=np.array([[2.0,0.0]])plot_interpolation(decoder,start_point,end_point)