博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【pytorch】回归拟合
阅读量:2225 次
发布时间:2019-05-09

本文共 1826 字,大约阅读时间需要 6 分钟。

import torchfrom torch.autograd import Variableimport torch.nn.functional as Fimport matplotlib.pyplot as pltx = torch.unsqueeze(torch.linspace(-1,1,100),dim = 1)  #压缩为2维,因为torch 中 只会处理二维的数据# print(x)y = x.pow(2) + 0.2 * torch.rand(x.size())# print(y)x,y = Variable(x),Variable(y)# 神经网络中只用Variable的方法# plt.scatter(x.data.numpy(),y.data.numpy())# plt.show()  # 散点图class Net(torch.nn.Module):  # 继承 torch 的 Module    def __init__(self, n_feature, n_hidden, n_output):        super(Net, self).__init__()     # 继承 __init__ 功能        # 定义每层用什么样的形式        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # 隐藏层线性输出        self.predict = torch.nn.Linear(n_hidden, n_output)   # 输出层线性输出    def forward(self, x):   # 这同时也是 Module 中的 forward 功能        # 正向传播输入值, 神经网络分析出输出值        x = F.relu(self.hidden(x))      # 激励函数(隐藏层的线性值)        x = self.predict(x)             # 输出值        return xnet = Net(n_feature=1, n_hidden=10, n_output=1)# print(net)  # net 的结构"""Net (  (hidden): Linear (1 -> 10)  (predict): Linear (10 -> 1))"""plt.ion()   # 画图plt.show()optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率loss_func = torch.nn.MSELoss()      # 预测值和真实值的误差计算公式 (均方差)for t in range(100):    prediction = net(x)     # 喂给 net 训练数据 x, 输出预测值    loss = loss_func(prediction, y)     # 计算两者的误差    optimizer.zero_grad()   # 清空上一步的残余更新参数值    loss.backward()         # 误差反向传播, 计算参数更新值    optimizer.step()        # 将参数更新值施加到 net 的 parameters 上    # 接着上面来    if t % 5 == 0:        # plot and show learning process        plt.cla()        plt.scatter(x.data.numpy(), y.data.numpy())        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={
'size': 20, 'color': 'red'}) plt.pause(0.1) # 误差为0.1 的时候退出 #但是模拟的结果参数是多少?plt.ioff()plt.show()

转载地址:http://bbmfb.baihongyu.com/

你可能感兴趣的文章
(四)alin’s mysql学习笔记----索引简介
查看>>
分布式系统中的幂等性的理解
查看>>
spring的注解开发中的常用注解(一)------@bean @Configuration @ComponentScan @Import @Scope @Lazy
查看>>
(五)alin’s mysql学习笔记----索引性能分析
查看>>
Spring中使用@Transactional注解进行事务管理的时候只有应用到 public 方法才有效
查看>>
springboot整合rabbitmq及rabbitmq的简单入门
查看>>
mysql事务和隔离级别笔记
查看>>
事务的传播属性(有坑点)自调用失效学习笔记
查看>>
REDIS缓存穿透,缓存击穿,缓存雪崩原因+解决方案
查看>>
动态代理实现AOP
查看>>
23种常见的java设计模式
查看>>
关于被final修饰的基本数据类型一些注意事项
查看>>
java Thread中,run方法和start方法的区别
查看>>
在 XML 中有 5 个预定义的实体引用
查看>>
XML 元素是可扩展的
查看>>
避免 XML 属性?针对元数据的 XML 属性
查看>>
XML DOM nodeType 属性值代表的意思
查看>>
JSP相关知识
查看>>
JDBC的基本知识
查看>>
《Head first设计模式》学习笔记 - 适配器模式
查看>>