1 回答

TA貢獻1906條經(jīng)驗 獲得超3個贊
所以我最終要做的是手動進行所有預(yù)處理并為每個包含預(yù)處理序列的庫存保存一個 .npy 文件,然后使用手動創(chuàng)建的生成器進行批量處理:
class seq_generator():
def __init__(self, list_of_filepaths):
self.usedDict = dict()
for path in list_of_filepaths:
self.usedDict[path] = []
def generate(self):
while True:
path = np.random.choice(list(self.usedDict.keys()))
stock_array = np.load(path)
random_sequence = np.random.randint(stock_array.shape[0])
if random_sequence not in self.usedDict[path]:
self.usedDict[path].append(random_sequence)
yield stock_array[random_sequence, :, :]
train_generator = seq_generator(list_of_filepaths)
train_dataset = tf.data.Dataset.from_generator(seq_generator.generate(),
output_types=(tf.float32, tf.float32),
output_shapes=(n_timesteps, n_features))
train_dataset = train_dataset.batch(batch_size)
Wherelist_of_filepaths只是預(yù)處理 .npy 數(shù)據(jù)的路徑列表。
這將:
加載隨機股票的預(yù)處理 .npy 數(shù)據(jù)
隨機選擇一個序列
檢查序列的索引是否已經(jīng)被使用
usedDict
如果不:
附加該序列的索引
usedDict
以跟蹤不向模型提供兩次相同的數(shù)據(jù)產(chǎn)生序列
這意味著生成器將在每次“調(diào)用”時從隨機股票中提供單個唯一序列,使我能夠使用來自 Tensorflows數(shù)據(jù)集類型的.from_generator()
和方法。.batch()
添加回答
舉報