1 Answer

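The TextLoader class below reads input.txt from data_dir, builds a character-level vocabulary, and slices the corpus into (x, y) training batches; the inline comments explain each step. A short usage sketch follows the class.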
import codecs
import collections
import os

import numpy as np
import pickle as cPickle  # the original uses cPickle; on Python 3, pickle is the equivalent


class TextLoader():
    def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.encoding = encoding
        # On the first run only input.txt exists; the other two files are
        # generated by preprocess below.
        input_file = os.path.join(data_dir, "input.txt")
        vocab_file = os.path.join(data_dir, "vocab.pkl")
        tensor_file = os.path.join(data_dir, "data.npy")
        # First run: call preprocess; later runs: call load_preprocessed.
        if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
            print("reading text file")
            self.preprocess(input_file, vocab_file, tensor_file)
        else:
            print("loading preprocessed files")
            self.load_preprocessed(vocab_file, tensor_file)
        self.create_batches()
        self.reset_batch_pointer()

    def preprocess(self, input_file, vocab_file, tensor_file):
        with codecs.open(input_file, "r", encoding=self.encoding) as f:
            data = f.read()
        # Tally the input with collections.Counter: counter maps each
        # character in data to the number of times it appears.
        counter = collections.Counter(data)
        # Sort so that the most frequent characters come first.
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        # Keep every distinct character in data; there are 65 here, so
        # vocab_size = 65.
        self.chars, _ = zip(*count_pairs)
        self.vocab_size = len(self.chars)
        # vocab maps each char to its frequency rank, which makes it easy to
        # convert data into index form.
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        with open(vocab_file, 'wb') as f:
            # Save chars for later runs.
            cPickle.dump(self.chars, f)
        # Convert every character in data to its index.
        self.tensor = np.array(list(map(self.vocab.get, data)))
        np.save(tensor_file, self.tensor)
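
    # Toy walk-through of preprocess (hypothetical input, for illustration):
    # if data were "aabbbc", Counter(data) gives {'b': 3, 'a': 2, 'c': 1}, so
    # chars = ('b', 'a', 'c'), vocab = {'b': 0, 'a': 1, 'c': 2}, and
    # tensor = [1, 1, 0, 0, 0, 2].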

    def load_preprocessed(self, vocab_file, tensor_file):
        # On later runs, read the chars and tensor saved by preprocess.
        with open(vocab_file, 'rb') as f:
            self.chars = cPickle.load(f)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.tensor = np.load(tensor_file)
        self.num_batches = int(self.tensor.size / (self.batch_size *
                                                   self.seq_length))

    def create_batches(self):
        # First cut the data into batch_size pieces, then cut each piece
        # into chunks of seq_length.
        self.num_batches = int(self.tensor.size / (self.batch_size *
                                                   self.seq_length))
        if self.num_batches == 0:
            assert False, "Not enough data. Make seq_length and batch_size smaller."
        # Drop the tail that does not fill a whole batch.
        self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
        xdata = self.tensor
        # Build the targets: each character predicts the next one, so y is x
        # shifted by one character, wrapping around at the end.
        ydata = np.copy(self.tensor)
        ydata[:-1] = xdata[1:]
        ydata[-1] = xdata[0]
        # Split the data. Suppose the total length is 10000, batch_size is 100
        # and seq_length is 10, so num_batches = 10. After the reshape, xdata
        # is [100, 100]; splitting it into 10 pieces along the second axis
        # gives a list of 10 arrays, each of shape [100, 10].
        self.x_batches = np.split(xdata.reshape(self.batch_size, -1),
                                  self.num_batches, 1)
        self.y_batches = np.split(ydata.reshape(self.batch_size, -1),
                                  self.num_batches, 1)
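
    # Shape sanity check for the split above (toy numbers, not from the
    # original answer): np.split(np.arange(12).reshape(2, 6), 3, 1) returns
    # three arrays of shape (2, 2), mirroring x_batches/y_batches.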

    def next_batch(self):
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return x, y

    def reset_batch_pointer(self):
        self.pointer = 0
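
For context, here is a minimal usage sketch. The data_dir value and the hyperparameters are made up for illustration; input.txt must already exist in that directory.

# Hypothetical driver code, assuming data/input.txt exists.
loader = TextLoader(data_dir="data", batch_size=50, seq_length=50)
print(loader.vocab_size)

for epoch in range(2):
    loader.reset_batch_pointer()        # restart at the first batch each epoch
    for _ in range(loader.num_batches):
        x, y = loader.next_batch()      # each has shape (batch_size, seq_length)
        # feed x and y to the model here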