DCGAN代码解析(二)

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: DCGAN代码解析(二)

4 损失函数


使用 二项交叉熵(Binary Cross Entropy, BCE)Loss

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-o3tCo2KR-1664249499192)(figures/BCE-loss.png)]

# Loss function
adversarial_loss = torch.nn.BCELoss()


5 Cuda加速


cuda = True if torch.cuda.is_available() else False
print("cuda_is_available =", cuda)
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
cuda_is_available = True


6 优化器


使用Adam优化器

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
print("learning_rate =", opt.lr)
learning_rate = 0.0002


7 创建输入


分别从数据集和随机向量中获取输入

for i, (imgs, labels) in list(enumerate(dataloader))[:1]:
    # Configure input
    real_imgs = Variable(imgs.type(Tensor))
    # Sample noise as generator input
    z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
    print("i =", i, '\n')
    print("shape of z =", z.shape, '\n')
    print("shape of real_imgs =", real_imgs.shape, '\n')
    print("z =", z, '\n')
    print("real_imgs =")
    for img in real_imgs[:3]:
        show_img(img)
i = 0 
shape of z = torch.Size([64, 100]) 
shape of real_imgs = torch.Size([64, 1, 32, 32]) 
z = tensor([[ 3.1224e-01, -1.1344e-01, -1.0401e+00,  ...,  1.8232e-01,
         -1.2940e+00,  1.3365e+00],
        [ 7.3029e-01,  4.0669e-01, -1.3267e-01,  ..., -4.9197e-01,
         -7.5093e-01, -1.1240e+00],
        [ 1.2938e+00,  7.8608e-01,  1.8455e-01,  ..., -5.0269e-01,
          7.9739e-01, -5.3891e-02],
        ...,
        [-7.9207e-01, -4.8256e-02,  4.5883e-01,  ...,  1.2142e+00,
          6.2461e-01, -1.5289e+00],
        [-1.4916e-03,  4.8395e-01, -3.0754e-01,  ..., -1.8773e-01,
         -5.0988e-01, -1.2065e+00],
        [ 1.2712e+00, -5.0849e-01,  6.2769e-01,  ...,  1.0904e+00,
          2.1514e-01, -4.0929e-01]], device='cuda:0') 
real_imgs =


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xKXVVKIB-1664249499192)(test_files/test_21_1.png)]


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SxHmPOLV-1664249499193)(test_files/test_21_2.png)]


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3X41G5yf-1664249499194)(test_files/test_21_3.png)]


8 计算loss,反向传播


分别对生成器和判别器计算loss,使用反向传播更新模型参数

# Adversarial ground truths
    batch_size = imgs.shape[0]
    valid = Variable(Tensor(batch_size, 1).fill_(1.0), requires_grad=False) # 为1时判定为真
    fake = Variable(Tensor(batch_size, 1).fill_(0.0), requires_grad=False) # 为0时判定为假
    # ---------------------
    #  Train Generator
    # ---------------------
    optimizer_G.zero_grad()
    # Sample noise as generator input
    z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
    # Generate a batch of images
    gen_imgs = generator(z)
    # Loss measures generator's ability to fool the discriminator
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
    g_loss.backward()
    optimizer_G.step()
    # ---------------------
    #  Train Discriminator
    # ---------------------
    optimizer_D.zero_grad()
    # Measure discriminator's ability to classify real from generated samples
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2
    print("real_loss =", real_loss, '\n')
    print("fake_loss =", fake_loss, '\n')
    print("d_loss =", d_loss, '\n')    
    d_loss.backward()
    optimizer_D.step()
real_loss = tensor(0.7088, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) 
fake_loss = tensor(0.6778, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) 
d_loss = tensor(0.6933, device='cuda:0', grad_fn=<DivBackward0>)


9 保存生成图像和模型文件


from torchvision.utils import save_image
    def sample_image(n_row, batches_done):
        """Saves a grid of generated digits ranging from 0 to n_classes"""
        # Sample noise
        z = Variable(Tensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
        # Get labels ranging from 0 to n_classes for n rows
        gen_imgs = generator(z)
        save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
    epoch = 0 # temporary
    batches_done = epoch * len(dataloader) + i
    if batches_done % opt.sample_interval == 0:
        os.makedirs("images", exist_ok=True)
        sample_image(n_row=10, batches_done=batches_done)
        os.makedirs("model", exist_ok=True) # 保存模型
        torch.save(generator, 'model/generator.pkl') 
        torch.save(discriminator, 'model/discriminator.pkl')
        print("gen images saved!\n")
        print("model saved!")
gen images saved!
model saved!


rue)

epoch = 0 # temporary
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
    os.makedirs("images", exist_ok=True)
    sample_image(n_row=10, batches_done=batches_done)
    os.makedirs("model", exist_ok=True) # 保存模型
    torch.save(generator, 'model/generator.pkl') 
    torch.save(discriminator, 'model/discriminator.pkl')
    print("gen images saved!\n")
    print("model saved!")
gen images saved!
    model saved!
目录
相关文章
|
9天前
|
PHP 开发者 容器
PHP命名空间深度解析:避免命名冲突与提升代码组织####
本文深入探讨了PHP中命名空间的概念、用途及最佳实践,揭示其在解决全局命名冲突、提高代码可维护性方面的重要性。通过生动实例和详尽分析,本文将帮助开发者有效利用命名空间来优化大型项目结构,确保代码的清晰与高效。 ####
13 1
|
1月前
|
存储 安全 Java
系统安全架构的深度解析与实践:Java代码实现
【11月更文挑战第1天】系统安全架构是保护信息系统免受各种威胁和攻击的关键。作为系统架构师,设计一套完善的系统安全架构不仅需要对各种安全威胁有深入理解,还需要熟练掌握各种安全技术和工具。
109 10
|
1月前
|
前端开发 JavaScript 开发者
揭秘前端高手的秘密武器:深度解析递归组件与动态组件的奥妙,让你代码效率翻倍!
【10月更文挑战第23天】在Web开发中,组件化已成为主流。本文深入探讨了递归组件与动态组件的概念、应用及实现方式。递归组件通过在组件内部调用自身,适用于处理层级结构数据,如菜单和树形控件。动态组件则根据数据变化动态切换组件显示,适用于不同业务逻辑下的组件展示。通过示例,展示了这两种组件的实现方法及其在实际开发中的应用价值。
36 1
|
2月前
|
机器学习/深度学习 人工智能 算法
揭开深度学习与传统机器学习的神秘面纱:从理论差异到实战代码详解两者间的选择与应用策略全面解析
【10月更文挑战第10天】本文探讨了深度学习与传统机器学习的区别,通过图像识别和语音处理等领域的应用案例,展示了深度学习在自动特征学习和处理大规模数据方面的优势。文中还提供了一个Python代码示例,使用TensorFlow构建多层感知器(MLP)并与Scikit-learn中的逻辑回归模型进行对比,进一步说明了两者的不同特点。
80 2
|
2月前
|
存储 搜索推荐 数据库
运用LangChain赋能企业规章制度制定:深入解析Retrieval-Augmented Generation(RAG)技术如何革新内部管理文件起草流程,实现高效合规与个性化定制的完美结合——实战指南与代码示例全面呈现
【10月更文挑战第3天】构建公司规章制度时,需融合业务实际与管理理论,制定合规且促发展的规则体系。尤其在数字化转型背景下,利用LangChain框架中的RAG技术,可提升规章制定效率与质量。通过Chroma向量数据库存储规章制度文本,并使用OpenAI Embeddings处理文本向量化,将现有文档转换后插入数据库。基于此,构建RAG生成器,根据输入问题检索信息并生成规章制度草案,加快更新速度并确保内容准确,灵活应对法律与业务变化,提高管理效率。此方法结合了先进的人工智能技术,展现了未来规章制度制定的新方向。
39 3
|
2月前
|
SQL 监控 关系型数据库
SQL错误代码1303解析与处理方法
在SQL编程和数据库管理中,遇到错误代码是常有的事,其中错误代码1303在不同数据库系统中可能代表不同的含义
|
2月前
|
SQL 安全 关系型数据库
SQL错误代码1303解析与解决方案:深入理解并应对权限问题
在数据库管理和开发过程中,遇到错误代码是常见的事情,每个错误代码都代表着一种特定的问题
|
3月前
|
敏捷开发 安全 测试技术
软件测试的艺术:从代码到用户体验的全方位解析
本文将深入探讨软件测试的重要性和实施策略,通过分析不同类型的测试方法和工具,展示如何有效地提升软件质量和用户满意度。我们将从单元测试、集成测试到性能测试等多个角度出发,详细解释每种测试方法的实施步骤和最佳实践。此外,文章还将讨论如何通过持续集成和自动化测试来优化测试流程,以及如何建立有效的测试团队来应对快速变化的市场需求。通过实际案例的分析,本文旨在为读者提供一套系统而实用的软件测试策略,帮助读者在软件开发过程中做出更明智的决策。
|
3月前
|
SQL 人工智能 机器人
遇到的代码部份解析
/ 模拟后端返回的数据
19 0
|
3月前
|
设计模式 存储 算法
PHP中的设计模式:策略模式的深入解析与应用在软件开发的浩瀚海洋中,PHP以其独特的魅力和强大的功能吸引了无数开发者。作为一门历史悠久且广泛应用的编程语言,PHP不仅拥有丰富的内置函数和扩展库,还支持面向对象编程(OOP),为开发者提供了灵活而强大的工具集。在PHP的众多特性中,设计模式的应用尤为引人注目,它们如同精雕细琢的宝石,镶嵌在代码的肌理之中,让程序更加优雅、高效且易于维护。今天,我们就来深入探讨PHP中使用频率颇高的一种设计模式——策略模式。
本文旨在深入探讨PHP中的策略模式,从定义到实现,再到应用场景,全面剖析其在PHP编程中的应用价值。策略模式作为一种行为型设计模式,允许在运行时根据不同情况选择不同的算法或行为,极大地提高了代码的灵活性和可维护性。通过实例分析,本文将展示如何在PHP项目中有效利用策略模式来解决实际问题,并提升代码质量。

推荐镜像

更多