差分
このページの2つのバージョン間の差分を表示します。
| 次のリビジョン | 前のリビジョン | ||
| pytorch:regression [2022/06/02 13:48] – 作成 watalu | pytorch:regression [2022/06/02 13:56] (現在) – watalu | ||
|---|---|---|---|
| 行 61: | 行 61: | ||
| < | < | ||
| device = torch.device(" | device = torch.device(" | ||
| - | # Assume that we are on a CUDA machine, then this should print a CUDA device: | + | |
| print(" | print(" | ||
| class Net(torch.nn.Module): | class Net(torch.nn.Module): | ||
| def __init__(self, | def __init__(self, | ||
| super(Net, self).__init__() | super(Net, self).__init__() | ||
| - | self.hidden = torch.nn.Linear(cols, | + | self.hidden = torch.nn.Linear(cols, |
| - | self.predict = torch.nn.Linear(size_hidden, | + | self.predict = torch.nn.Linear(size_hidden, |
| def forward(self, | def forward(self, | ||
| - | x = F.relu(self.hidden(x)) | + | x = F.relu(self.hidden(x)) |
| - | x = self.predict(x) | + | x = self.predict(x) |
| return x | return x | ||
| model = Net(cols, size_hidden, | model = Net(cols, size_hidden, | ||
| - | optimizer = torch.optim.Adam(net.parameters(), | + | optimizer = torch.optim.Adam(model.parameters(), |
| - | criterion = torch.nn.MSELoss(reduction=' | + | criterion = torch.nn.MSELoss(reduction=' |
| </ | </ | ||
| 行 119: | 行 119: | ||
| ax.plot(loss_train_history, | ax.plot(loss_train_history, | ||
| ax.plot(loss_test_history, | ax.plot(loss_test_history, | ||
| + | plt.show() | ||
| + | fig, ax = plt.subplots() | ||
| + | ax.plot(loss_train_history, | ||
| + | ax.plot(loss_test_history, | ||
| + | plt.ylim(0, 8000) | ||
| + | plt.show() | ||
| + | plt.plot(loss_test_records) | ||
| plt.show() | plt.show() | ||
| 行 127: | 行 134: | ||
| pred=result.data[:, | pred=result.data[:, | ||
| print(len(pred), | print(len(pred), | ||
| - | r2_score(pred, | + | print(r2_score(pred, |
| X = Variable(torch.FloatTensor(X_test)) | X = Variable(torch.FloatTensor(X_test)) | ||
| 行 133: | 行 140: | ||
| pred=result.data[:, | pred=result.data[:, | ||
| print(len(pred), | print(len(pred), | ||
| - | r2_score(pred, | + | print(r2_score(pred, |
| </ | </ | ||