轻识Logo
目录

    【深度学习】常见优化器的PyTorch实现


    这里主要讲不同常见优化器代码的实现,以及在一个小数据集上做一个简单的比较。

    备注:pytorch需要升级到最新版本

    其中,SGD和SGDM,还有Adam是pytorch自带的优化器,而RAdam是最近提出的一个说是Adam更强的优化器,但是一般情况下真正的大佬还在用SGDM来做优化器

    导入必要库:

    import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport matplotlib.pyplot as pltimport torch.utils.data as Datafrom torch.optim.optimizer import Optimizerimport math

    主程序部分:

    LR = 0.01BATCH_SIZE = 32EPOCH = 12
    # fake datasetx = torch.unsqueeze(torch.linspace(-1, 1, 300), dim=1)y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))
    torch_dataset = Data.TensorDataset(x, y)loader = Data.DataLoader( dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

    class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.hidden = nn.Linear(1, 20) self.prediction = nn.Linear(20, 1)
    def forward(self, x): x = F.relu(self.hidden(x)) x = self.prediction(x) return x

    def main(): net_SGD = Net() net_Momentum = Net() net_Adam = Net() net_RAdam = Net() nets = [net_SGD, net_Momentum, net_Adam, net_RAdam] opt_SGD = optim.SGD(net_SGD.parameters(), lr=LR) opt_Momentum = optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.9) opt_Adam = optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99)) opt_RAdam = RAdam(net_RAdam.parameters(),lr=LR,weight_decay=0) optimizers = [opt_SGD, opt_Momentum, opt_Adam, opt_RAdam] loss_func = nn.MSELoss() losses_his = [[], [], [], []] # training for epoch in range(EPOCH): print('EPOCH:', epoch) for step, (batch_x, batch_y) in enumerate(loader): b_x = batch_x b_y = batch_y for net, opt, l_his in zip(nets, optimizers, losses_his): out = net(b_x) loss = loss_func(out, b_y) opt.zero_grad() loss.backward() opt.step() l_his.append(loss.item()) labels = ['SGD', 'Momentum', 'Adam','RAdam'] for i, l_his in enumerate(losses_his): plt.plot(l_his, label=labels[i]) plt.legend(loc='best') plt.xlabel('Steps') plt.ylabel('Loss') plt.ylim((0, 0.2)) plt.show()

    if __name__ == '__main__': main()

    下图是优化器的对比:

    ba391ec0b462752f35f8dabf0dc8f2fa.webp

    可以看出来,Adam的效果可以说是非常好的。然后SGDM其次,SGDM是大佬们经常会使用的,所以在这里虽然看起来SGDM效果不如Adam,但是依然推荐在项目中,尝试一下SGDM的效果。


    往期精彩回顾





    • 适合初学者入门人工智能的路线及资料下载

    • 机器学习及深度学习笔记等资料打印

    • 机器学习在线手册

    • 深度学习笔记专辑

    • 《统计学习方法》的代码复现专辑

    • AI基础下载

    • 机器学习的数学基础专辑

    获取一折本站知识星球优惠券,复制链接直接打开:

    https://t.zsxq.com/yFQV7am

    本站qq群1003271085。

    加入微信群请扫码进群:

    浏览 29
    点赞
    评论
    收藏
    分享

    手机扫一扫分享

    举报
    DeepSpeed基于 PyTorch 的深度学习优化库
    DeepSpeed是一个深度学习优化库,它可以使分布式训练变得容易、高效和有效。10x更大的模型5x更快地训练最小的代码更改DeepSpeed可以在当前一代的GPU集群上训练具有超过千亿个参数的DL模
    DeepSpeed基于 PyTorch 的深度学习优化库
    0
    深度学习中的优化算法与实现
    GiantPandaCV
    0
    【深度学习】PyTorch训练一个CNN分类器
    机器学习算法与Python实战
    0
    pytorch优化器与学习率设置详解
    程序员大白
    0
    pytorch优化器与学习率设置详解
    视学算法
    0
    PyTorch深度学习实战
    PyTorch深度学习实战
    0
    pytorch优化器与学习率设置详解
    极市平台
    0
    PyTorch深度学习实战
    1.PyTorch核心开发者教你使用 PyTorch 创建神经网络和深度学习系统的实用指南。
    PyTorch深度学习实战
    0
    点赞
    评论
    收藏
    分享

    手机扫一扫分享

    举报

    PHP网站源码深圳百姓网标王光明设计网站坑梓网站改版沙井建网站坂田如何制作网站荷坳建站永湖SEO按天扣费石岩外贸网站制作福田百度竞价平湖网站定制坂田网站关键词优化深圳百姓网标王推广民治建网站观澜seo优化南山英文网站建设深圳百度seo大鹏建站平湖网站推广坑梓网站排名优化西乡SEO按天扣费坑梓网页制作双龙百度关键词包年推广大芬百度seo大鹏网站优化按天收费东莞模板网站建设南山SEO按效果付费观澜seo光明百度关键词包年推广大运百度标王福田网站关键词优化歼20紧急升空逼退外机英媒称团队夜以继日筹划王妃复出草木蔓发 春山在望成都发生巨响 当地回应60岁老人炒菠菜未焯水致肾病恶化男子涉嫌走私被判11年却一天牢没坐劳斯莱斯右转逼停直行车网传落水者说“没让你救”系谣言广东通报13岁男孩性侵女童不予立案贵州小伙回应在美国卖三蹦子火了淀粉肠小王子日销售额涨超10倍有个姐真把千机伞做出来了近3万元金手镯仅含足金十克呼北高速交通事故已致14人死亡杨洋拄拐现身医院国产伟哥去年销售近13亿男子给前妻转账 现任妻子起诉要回新基金只募集到26元还是员工自购男孩疑遭霸凌 家长讨说法被踢出群充个话费竟沦为间接洗钱工具新的一天从800个哈欠开始单亲妈妈陷入热恋 14岁儿子报警#春分立蛋大挑战#中国投资客涌入日本东京买房两大学生合买彩票中奖一人不认账新加坡主帅:唯一目标击败中国队月嫂回应掌掴婴儿是在赶虫子19岁小伙救下5人后溺亡 多方发声清明节放假3天调休1天张家界的山上“长”满了韩国人?开封王婆为何火了主播靠辱骂母亲走红被批捕封号代拍被何赛飞拿着魔杖追着打阿根廷将发行1万与2万面值的纸币库克现身上海为江西彩礼“减负”的“试婚人”因自嘲式简历走红的教授更新简介殡仪馆花卉高于市场价3倍还重复用网友称在豆瓣酱里吃出老鼠头315晚会后胖东来又人满为患了网友建议重庆地铁不准乘客携带菜筐特朗普谈“凯特王妃P图照”罗斯否认插足凯特王妃婚姻青海通报栏杆断裂小学生跌落住进ICU恒大被罚41.75亿到底怎么缴湖南一县政协主席疑涉刑案被控制茶百道就改标签日期致歉王树国3次鞠躬告别西交大师生张立群任西安交通大学校长杨倩无缘巴黎奥运

    PHP网站源码 XML地图 TXT地图 虚拟主机 SEO 网站制作 网站优化