skka3134

skka3134

email
telegram

機器學習和量化投資:3.pytorch創建數據集

  1. 安裝 pytorch,pytorch 是一個以 Python 為主的深度學習框架。使用 pytorch 可以自動組合因子成策略。GPU 訓練的話只有 N 卡支援,這裡選擇 cpu 模式就行。https://pytorch.org/ ,torchvision 用來處理圖像,torchaudio 處理音頻用不到所以不安裝。
sudo /home/skka3134/folder/bot/bin/python -m pip install torch

圖片
2. 設置資料集,從 Dateset 繼承類,形成 TimeSeriesDataset

from torch.utils.data import Dataset

class TimeSeriesDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]
    
train_dataset = TimeSeriesDataset(X_train, y_train)
test_dataset = TimeSeriesDataset(X_test, y_test)
  1. 加載資料集
from torch.utils.data import DataLoader
batch_size = 16    #每批讀取資料16個,如果用的是GPU訓練,可以調大一點,128?
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) #shuffle=True代表打亂資料
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) #shuffle=True代表不打亂資料
  1. 可視化處理
for _, batch in enumerate(train_loader):
    x_batch, y_batch = batch[0].to(device), batch[1].to(device)
    print(x_batch.shape, y_batch.shape)
    break
載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。