2 回答

TA貢獻(xiàn)1884條經(jīng)驗(yàn) 獲得超4個(gè)贊
您只需要撥打:
model.fit_generator(generator,?steps_per_epoch)
其中steps_per_epoch
是通常ceil(num_samples / batch_size)
并且generator
是一個(gè) python 生成器,它迭代數(shù)據(jù)并批量生成數(shù)據(jù)。每次調(diào)用生成器都應(yīng)該產(chǎn)生batch_size
許多元素。生成器的示例:
def generate_data(directory, batch_size):
? ? """Replaces Keras' native ImageDataGenerator."""
? ? i = 0
? ? file_list = os.listdir(directory)
? ? while True:
? ? ? ? image_batch = []
? ? ? ? for b in range(batch_size):
? ? ? ? ? ? if i == len(file_list):
? ? ? ? ? ? ? ? i = 0
? ? ? ? ? ? ? ? random.shuffle(file_list)
? ? ? ? ? ? sample = file_list[i]
? ? ? ? ? ? i += 1
? ? ? ? ? ? image = cv2.resize(cv2.imread(sample[0]), INPUT_SHAPE)
? ? ? ? ? ? image_batch.append((image.astype(float) - 128) / 128)
? ? ? ? yield np.array(image_batch)
由于這絕對(duì)是特定于問(wèn)題的,因此您必須編寫(xiě)自己的生成器,盡管使用此模板應(yīng)該很簡(jiǎn)單。

TA貢獻(xiàn)1802條經(jīng)驗(yàn) 獲得超4個(gè)贊
這是將訓(xùn)練數(shù)據(jù)分割成小批量的數(shù)據(jù)生成器:
def generate_data(X1,X2,Y,batch_size):
p_input=[]
c_input=[]
target=[]
batch_count=0
for i in range(len(X1)):
p_input.append(X1[i])
c_input.append(X2[i])
target.append(Y[i])
batch_count+=1
if batch_count>batch_size:
prev_X=np.array(p_input,dtype=np.int64)
cur_X=np.array(c_input,dtype=np.int64)
cur_y=np.array(target,dtype=np.int32)
print(len(prev_X),len(cur_X))
yield ([prev_X,cur_X],cur_y )
p_input=[]
c_input=[]
target=[]
batch_count=0
return
這里是 fit_generator 函數(shù)調(diào)用而不是 model.fit 方法:
batch_size=256
epoch_steps=math.ceil(len(previous_sentences)/ batch_size)
hist = model.fit_generator(generate_data(previous_sentences,current_sentences, y_train, batch_size),
steps_per_epoch=epoch_steps,
callbacks = [early_stopping_cb],
validation_data=generate_data(val_prev, val_curr,y_val,batch_size),
validation_steps=val_steps, class_weight=custom_weight_dict,
verbose=1)
添加回答
舉報(bào)