Differences
This shows you the differences between two versions of the page.
| pytorch:regression [2022/06/02 13:49] – watalu | pytorch:regression [2022/06/02 13:56] (current) – watalu | ||
|---|---|---|---|
| Line 75: | Line 75: | ||
| model = Net(cols, size_hidden, | model = Net(cols, size_hidden, | ||
| - | optimizer = torch.optim.Adam(net.parameters(), | + | optimizer = torch.optim.Adam(model.parameters(), |
| criterion = torch.nn.MSELoss(reduction=' | criterion = torch.nn.MSELoss(reduction=' | ||
| </ | </ | ||
| Line 119: | Line 119: | ||
| ax.plot(loss_train_history, | ax.plot(loss_train_history, | ||
| ax.plot(loss_test_history, | ax.plot(loss_test_history, | ||
| + | plt.show() | ||
| + | fig, ax = plt.subplots() | ||
| + | ax.plot(loss_train_history, | ||
| + | ax.plot(loss_test_history, | ||
| + | plt.ylim(0, 8000) | ||
| + | plt.show() | ||
| + | plt.plot(loss_test_records) | ||
| plt.show() | plt.show() | ||
| Line 127: | Line 134: | ||
| pred=result.data[:, | pred=result.data[:, | ||
| print(len(pred), | print(len(pred), | ||
| - | r2_score(pred, | + | print(r2_score(pred, |
| X = Variable(torch.FloatTensor(X_test)) | X = Variable(torch.FloatTensor(X_test)) | ||
| Line 133: | Line 140: | ||
| pred=result.data[:, | pred=result.data[:, | ||
| print(len(pred), | print(len(pred), | ||
| - | r2_score(pred, | + | print(r2_score(pred, |
| </ | </ | ||
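
The first hunk fixes a reference error: the optimizer was created from `net.parameters()`, while the model object is named `model`. If no `net` exists, this raises a `NameError`; if one does (for example, left over from an earlier experiment), Adam updates that object's weights and `model` is never trained. The following is a minimal sketch of the corrected setup inside a complete training loop. Because the diff truncates the original arguments, the `Net` architecture, the synthetic data, the learning rate, and the `reduction` setting are assumptions for illustration only. The prediction step also drops `Variable`, which was merged into `torch.Tensor` in PyTorch 0.4.

```python
import torch

torch.manual_seed(0)  # reproducible sketch


class Net(torch.nn.Module):
    """Hypothetical stand-in for the page's Net(cols, size_hidden, ...):
    one hidden ReLU layer and a single regression output."""

    def __init__(self, n_in, n_hidden):
        super().__init__()
        self.hidden = torch.nn.Linear(n_in, n_hidden)
        self.out = torch.nn.Linear(n_hidden, 1)

    def forward(self, x):
        return self.out(torch.relu(self.hidden(x)))


# Synthetic data (assumption): y = 3*x0 - 2*x1 + small noise
cols, size_hidden = 2, 16
X_train = torch.randn(200, cols)
y_train = (3 * X_train[:, 0] - 2 * X_train[:, 1]).unsqueeze(1)
y_train = y_train + 0.1 * torch.randn(200, 1)

model = Net(cols, size_hidden)
# The fix from the diff: pass the parameters of `model`, the object that is
# trained and evaluated, not those of some other `net`.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss(reduction="mean")  # reduction is an assumption here

loss_train_history = []
for epoch in range(300):
    optimizer.zero_grad()
    loss = criterion(model(X_train), y_train)
    loss.backward()
    optimizer.step()
    loss_train_history.append(loss.item())

# Prediction: plain tensors go straight into the model; no Variable wrapper.
model.eval()
with torch.no_grad():
    pred = model(X_train)[:, 0].tolist()

print(len(pred), loss_train_history[0], loss_train_history[-1])
```

A quick way to catch this class of bug is to check that the optimizer's parameter groups contain exactly the tensors returned by `model.parameters()`.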
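
The second hunk redraws the training and test loss curves with `plt.ylim(0, 8000)`. When the first epochs have very large losses, they determine the autoscaled y range and squash the rest of both curves toward zero. Fixing the y limits zooms in on the later epochs, where the two curves can actually be compared. Here is a minimal sketch that uses made-up, decaying loss curves (illustrative only, not results from this page) and the headless Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; use an interactive one to display
import matplotlib.pyplot as plt

# Illustrative loss curves (not real results): steep early drop, slow tail.
loss_train_history = [50000 / (1 + e) + 1000 for e in range(200)]
loss_test_history = [60000 / (1 + e) + 1500 for e in range(200)]

# Full view: the first epochs dominate the autoscaled y axis.
fig_full, ax_full = plt.subplots()
ax_full.plot(loss_train_history, label="train")
ax_full.plot(loss_test_history, label="test")
ax_full.legend()

# Zoomed view, as in the diff: fix the y range to the region of interest.
fig_zoom, ax_zoom = plt.subplots()
ax_zoom.plot(loss_train_history, label="train")
ax_zoom.plot(loss_test_history, label="test")
ax_zoom.set_ylim(0, 8000)  # object-oriented equivalent of plt.ylim(0, 8000)
ax_zoom.legend()
```

`plt.ylim` acts on the current axes, so right after `fig, ax = plt.subplots()` it does the same thing as `ax.set_ylim`. In an interactive session, finish with `plt.show()` as the page does.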
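
The third hunk wraps `r2_score(...)` in `print(...)`. In a plain script, the value of a bare expression is discarded, so without `print` the score is never shown. The diff truncates the arguments, so argument order is worth checking separately: scikit-learn's `sklearn.metrics.r2_score` takes `(y_true, y_pred)` in that order, and R² is not symmetric because its denominator is the variance of `y_true`. The following dependency-free sketch uses a hypothetical helper `r2` that follows the same definition for 1-D inputs, and it shows that swapping the arguments changes the result:

```python
def r2(y_true, y_pred):
    """Hypothetical helper: coefficient of determination for 1-D inputs,
    1 - SS_res / SS_tot, with SS_tot taken around mean(y_true)."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


y = [1.0, 2.0, 3.0, 4.0]
pred = [1.0, 2.0, 3.0, 8.0]  # illustrative predictions with one large miss

print(r2(y, pred))  # (y_true, y_pred) order: about -2.2
print(r2(pred, y))  # swapped order: about 0.448, misleadingly better
```

Here the swapped call reports a much better fit, because the spread of the predictions inflates the denominator.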