tfgan折腾笔记(二):核心函数详述——gan_model族

摘要:
预先定义好的判别器函数的输入参数有两个:第一个是“真实数据(图像)”/“机器生成的图像”;第二个是生成器的输入,即此函数的第四个参数(在普通的gan当中,判别器只需要第一个参数。即使不需要第二个参数,也必须显式地定义出第二个参数,只不过定义了之后在判别器函数中可以不使用)。判别器的返回值必须在负无穷到正无穷之间。一般传入真实图像batch化后的引用。对于vallinagan,是tensor类型的噪声。

定义model的函数有:

1.gan_model

函数原型:

defgan_model(
    #Lambdas defining models.
generator_fn,
    discriminator_fn,
    #Real data and conditioning.
real_data,
    generator_inputs,
    #Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    #Options.
    check_shapes=True)

参数:

generator_fn:预先定义好的生成器网络的函数名称。预先定义好的生成器函数的输入参数应该是接下来要说明的第四个参数generator_inputs,生成网络的返回值是网络的输出(因为是GAN,所以生成器的输出一般是一幅机器生成的图像)。

discriminator_fn:预先定义好的判别器网络的函数名称。预先定义好的判别器函数的输入参数有两个:第一个是“真实数据(图像)”/“机器生成的图像(generator_fn的返回值)”;第二个是生成器的输入,即此函数的第四个参数(在普通的gan当中,判别器只需要第一个参数。即使不需要第二个参数,也必须显式地定义出第二个参数,只不过定义了之后在判别器函数中可以不使用)。判别器的返回值必须在负无穷到正无穷之间([-inf, +inf])。

real_data:真实图像。一般传入真实图像batch化后的引用。

generator_inputs:生成器的输入。对于vallina gan,是tensor类型的噪声。除此之外,如果是c-gan,还可以传入一个list或tuple作为参数(在下方的“其他说明“里详细说明c-gan(conditional-gan)的情况)。

generator_scope:传入这个参数可以定义生成器内参数的变量命名空间(variable_scope)。默认为"Generator"。

discriminator_scope:传入这个参数可以定义判别器内参数的变量命名空间(variable_scope)。默认为"Discriminator"。

check_shapes:如果为真,将检查生成器生成的数据与真实数据是否有相同的shape。如果为假,则跳过检查。

返回值:

返回一个“GANModel 命名管道”。实际上就是一个由生成器函数、判别器函数、生成的数据、变量空间等东西组成的一个List。这个返回值不需要我们写程序的时候用,就不过多解释了(具体用法见本系列上一篇文档:传送门)。

函数内部实现:

generator_fn和discriminator_fn在gan_model函数里这样调用:

#由机器生成数据
generated_data =generator_fn(generator_inputs)

#判别器判断机器生成图片的真实性
discriminator_gen_outputs =discriminator_fn(generated_data, generator_inputs)

#判别器判断真实图片的真实性
discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

其他说明:

  • gan_model支持conditional-gan。若需要训练c-gan,要通过generator_inputs额外传入标签信息。如:generator_inputs=(noise, one_hot_label)。同时,判别器网络与生成器网络应该按照c-gan论文中的模型重新定义。
  • real_data一般为一个next_batch。如:next_batch = tf.compat.v1.data.make_one_shot_iterator(image_ds).get_next()

2.infogan_model

函数原型:

definfogan_model(
    #Lambdas defining models.
generator_fn,
    discriminator_fn,
    #Real data and conditioning.
real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    #Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator')

参数:

generator_fn:预先定义好的生成器网络的函数名称。预先定义好的生成器函数的输入参数应该是接下来要说明的unstructrued_generator_inputs与structured_generator_inputs共同组成的列表,列表中的每一项是一个Tensor,生成网络的返回值是生成器的输出。

discriminator_fn:预先定义好的判别器网络的函数名称。预先定义好的判别器函数的输入参数应该有两个:第一个是“真实数据(图像)”/“机器生成的图像(generator_fn的返回值)”;第二个是生成器的输入,即(unstructrued_generator_inputs与structured_generator_inputs共同组成的列表)。预先定义好的判别器函数的输出应是一个二维Tuple。Tuple的第一维是生成器网络输出层的logits,范围在[-inf, +inf]。Tuple的第二维是分布的列表:此分布的第i个列表元素代表的是第i个structure noise 的分布。

real_data:真实图像。一般传入真实图像batch化后的引用。

unstructured_generator_inputs:Tensor的列表。表示非结构化的noise或条件。

structured_generator_inputs:Tensor的列表。这些Tensor必须与识别器具有较高的相互信息。

generator_scope:传入这个参数可以定义生成器内参数的变量命名空间(variable_scope)。默认为"Generator"。

discriminator_scope:传入这个参数可以定义判别器内参数的变量命名空间(variable_scope)。默认为"Discriminator"。

返回值:

返回一个“InfoGANModel 命名管道”。同“GANModel 命名管道”一样,我们无需关心管道中的具体内容。

函数内部实现:

生成器的输入这样定义:

generator_inputs = (unstructured_generator_inputs + structured_generator_inputs)

生成器和判别器这样调用:

#由机器生成数据
generated_data =generator_fn(generator_inputs)

#判别器判断机器生成图片的真实性
dis_gen_outputs, predicted_distributions = discriminator_fn(generated_data, generator_inputs)

#判别器判断真实图片的真实性
dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

其他说明:

  • 关于生成器和判别器网络模型的搭建,请参照Info-GAN的论文。
  • real_data一般为一个next_batch。如:next_batch = tf.compat.v1.data.make_one_shot_iterator(image_ds).get_next()

3.acgan_model:

函数原型:

defacgan_model(
    #Lambdas defining models.
generator_fn,
    discriminator_fn,
    #Real data and conditioning.
real_data,
    generator_inputs,
    one_hot_labels,
    #Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    #Options.
    check_shapes=True)

参数:

与gan_model中的参数基本一致,除了:

discriminator_fn:预定义的判别器函数应当返回一个二维Tuple。第一维是网络输出层的real或者fake的logits;第二维是分类器的logits。他们两个的范围都应该是[-inf, +inf]。

one_hot_labels:对应于一个batch图像的one_hot_label。

返回值:

返回“AcGANModel 命名管道”。同“GANModel 命名管道”一样,我们无需关心管道中的具体内容。

函数内部实现:

生成器和判别器这样调用:

#由机器生成数据
generated_data =generator_fn(generator_inputs)

#判别器判断机器生成图片的真实性
(discriminator_gen_outputs, discriminator_gen_classification_logits) =_validate_acgan_discriminator_outputs(discriminator_fn(generated_data, generator_inputs))

#判别器判断真实图片的真实性
(discriminator_real_outputs, discriminator_real_classification_logits) = _validate_acgan_discriminator_outputs(discriminator_fn(real_data, generator_inputs))

其他说明:

  • one_hot_labels在此函数内部没有被使用,而是直接通过命名管道(返回值)传递给gan_loss函数(下一篇详细说明)。
  • one_hot_labels与real_data均为batch。

4.cyclegan_model:

函数原型:

defcyclegan_model(
    #Lambdas defining models.
generator_fn,
    discriminator_fn,
    #data X and Y.
data_x,
    data_y,
    #Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    model_x2y_scope='ModelX2Y',
    model_y2x_scope='ModelY2X',
    #Options.
    check_shapes=True)

参数:

generator_fn:预先定义好的生成器函数。此生成器的输入有一个参数,与gan_model的generator_fn一样。返回值为生成器网络的输出。

discriminator_fn:预先定义好的判别器函数。与gan_model的discriminator_fn定义一样。

data_x:x域的真实数据。

data_y:y域的真实数据。

generator_scope:与gan_model的generator_scope意义一样。

discriminator_scope:与gan_model的discriminator_scope意义一样。

model_x2y_scope:x->y转换过程的variable_scope。

model_y2x_scope:y->x转换过程的variable_scope。

check_shapes:如果为真,将检查生成器生成的数据与真实数据是否有相同的shape。如果为假,则跳过检查。

返回值:

返回“CycleGANModel 命名空间”。

函数内部实现:

此函数实际上调用了gan_model函数,如下所示:

#Create models.
  def_define_partial_model(input_data, output_data):    # 内部函数定义
    returngan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=output_data,
        generator_inputs=input_data,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        check_shapes=check_shapes)

  with tf.compat.v1.variable_scope(model_x2y_scope):
    model_x2y =_define_partial_model(data_x, data_y)
  with tf.compat.v1.variable_scope(model_y2x_scope):
    model_y2x =_define_partial_model(data_y, data_x)

  with tf.compat.v1.variable_scope(model_y2x.generator_scope, reuse=True):
    reconstructed_x =model_y2x.generator_fn(model_x2y.generated_data)
  with tf.compat.v1.variable_scope(model_x2y.generator_scope, reuse=True):
    reconstructed_y =model_x2y.generator_fn(model_y2x.generated_data)

  returnnamedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
                                   reconstructed_y)

其他说明:

5.stargan_model

函数原型:

defstargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator')

参数:

generator_fn:预先定义好的函数的函数名称。函数的输入有两个,应分别为:input、target,返回值是根据inputs和targets由机器生成的图像。inputs的形状应该是(batch, height, width, channel),targets的形状是(batch, num_domain)。返回值有和inputs相同的形状。

discriminator_fn:预先定义好的函数的函数名称。此函数的输入有两个,分别为input和num_domain。返回值是一个Tuple:(`source_prediction`, `domain_prediction`)。`source_prediction`表示预测的图像(真实或生成的)真实度,“ domain_prediction”代表判别器对域分类的预测(真实度)。 `source_prediction`的形状是(batch), `domain_prediction`具有形状(batch,num_domains)。

input_data:Tensor或Tensor组成的列表。代表真实输入的图片。形状是(batch, height, width, channel)。

input_data_domain_label:Tensor或Tensor组成的列表。形状为(batch, num_domains)。表示真实数据相对应的代表域的标签。

generator_scope:与gan_model的此参数意义相同。

discriminator_scope:与gan_model的此参数意义相同。

返回值:

返回“StarGANModel 命名空间”。

函数内部实现:

函数内部重要代码如下:

  #Transform input_data to random target domains.
with tf.compat.v1.variable_scope(generator_scope) as generator_scope:
    generated_data_domain_target =generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data =generator_fn(input_data, generated_data_domain_target)

  #Transform generated_data back to the original input_data domain.
  with tf.compat.v1.variable_scope(generator_scope, reuse=True):
    reconstructed_data =generator_fn(generated_data, input_data_domain_label)

  #Predict source and domain for the generated_data using the discriminator.
with tf.compat.v1.variable_scope(discriminator_scope) as discriminator_scope:
    disc_gen_data_source_pred, disc_gen_data_domain_pred =discriminator_fn(
        generated_data, num_domains)

  #Predict source and domain for the input_data using the discriminator.
  with tf.compat.v1.variable_scope(discriminator_scope, reuse=True):
    disc_input_data_source_pred, disc_input_data_domain_pred =discriminator_fn(
        input_data, num_domains)

其他说明:

免责声明:文章转载自《tfgan折腾笔记(二):核心函数详述——gan_model族》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇贪念Openwrt自定义CGI实现下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

python-输入

1. python2版本中 咱们在银行ATM机器前取钱时,肯定需要输入密码,对不? 那么怎样才能让程序知道咱们刚刚输入的是什么呢?? 大家应该知道了,如果要完成ATM机取钱这件事情,需要先从键盘中输入一个数据,然后用一个变量来保存,是不是很好理解啊 1.1 raw_input() 在Python中,获取键盘输入的数据的方法是采用 raw_input...

jquery表单插件 jquery.form(异步提交)(学习总结)

不通过jquery.form实现异步提交 通过一个iframe 将form 的target定位到iframe中的name <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional...

mysql增删改和学生管理sql

importpymysql #2.建连 conn = pymysql.connect("localhost","root",'root','李森') print(conn) #3.获取游标 cur =conn.cursor() #4.增 sql="insert into student_1 values(default,%s,%s,%s,%s)"cur.e...

OpenCV——常用函数查询

1、cvLoadImage:将图像文件加载至内存; 2、cvNamedWindow:在屏幕上创建一个窗口; 3、cvShowImage:在一个已创建好的窗口中显示图像; 4、cvWaitKey:使程序暂停,等待用户触发一个按键操作; 5、cvReleaseImage:释放图像文件所分配的内存; 6、cvDestroyWindow:销毁显示图像文件的窗口;...

web读取本地文档

FileReader 对象允许Web应用程序异步读取存储在用户计算机上的文件(或原始数据缓冲区)的内容,使用 File 或 Blob 对象指定要读取的文件或数据。 其中File对象可以是来自用户在一个<input>元素上选择文件后返回的FileList对象,也可以来自拖放操作生成的 DataTransfer对象,还可以是来自在一个HTMLCan...

函数的返回值为结构体类型

可见,函数的返回值为结构体类型,其返回值既不是“值传递”也不是通过“寄存器”回传。编译器在编译此类函数时,为其附加了一个指针参数(指向的地址在caller的堆栈上),且作为函数的第一个参数(函数本身的参数依次后移),函数语义上的返回值通过该附加的指针参数回传,而函数真正的返回值就是该指针。 ————————————————版权声明:本文为CSDN博主「st...