ชุดข้อมูลขนาดใหญ่มากอาจไม่พอดีกับหน่วยความจำที่จัดสรรให้กับกระบวนการของคุณ ใน
ขั้นตอนก่อนหน้านี้ เราได้สร้างไปป์ไลน์ที่สามารถนำชุดข้อมูลทั้งหมด
ลงในหน่วยความจำ เตรียมข้อมูล และส่งชุดการทำงานไปยังการฝึก
แต่ Keras มีฟังก์ชันการฝึกทางเลือก
(fit_generator)
ที่ดึงข้อมูลเป็นกลุ่ม ซึ่งทำให้เราสามารถ ประยุกต์ใช้การเปลี่ยนรูปแบบ
ไปป์ไลน์ข้อมูลไปยังส่วนเล็กๆ (หลายส่วนของ batch_size) ของข้อมูล
ระหว่างการทดสอบ เราใช้การจัดกลุ่ม (โค้ดใน GitHub) สำหรับชุดข้อมูล เช่น
DBPedia, รีวิวจาก Amazon, ข่าวสาร Ag และรีวิว Yelp
โค้ดต่อไปนี้แสดงวิธีสร้างกลุ่มข้อมูลและป้อนให้กับกลุ่มข้อมูล
fit_generator
def _data_generator(x, y, num_features, batch_size):
"""Generates batches of vectorized texts for training/validation.
# Arguments
x: np.matrix, feature matrix.
y: np.ndarray, labels.
num_features: int, number of features.
batch_size: int, number of samples per batch.
# Returns
Yields feature and label data in batches.
"""
num_samples = x.shape[0]
num_batches = num_samples // batch_size
if num_samples % batch_size:
num_batches += 1
while 1:
for i in range(num_batches):
start_idx = i * batch_size
end_idx = (i + 1) * batch_size
if end_idx > num_samples:
end_idx = num_samples
x_batch = x[start_idx:end_idx]
y_batch = y[start_idx:end_idx]
yield x_batch, y_batch
# Create training and validation generators.
training_generator = _data_generator(
x_train, train_labels, num_features, batch_size)
validation_generator = _data_generator(
x_val, val_labels, num_features, batch_size)
# Get number of training steps. This indicated the number of steps it takes
# to cover all samples in one epoch.
steps_per_epoch = x_train.shape[0] // batch_size
if x_train.shape[0] % batch_size:
steps_per_epoch += 1
# Get number of validation steps.
validation_steps = x_val.shape[0] // batch_size
if x_val.shape[0] % batch_size:
validation_steps += 1
# Train and validate model.
history = model.fit_generator(
generator=training_generator,
steps_per_epoch=steps_per_epoch,
validation_data=validation_generator,
validation_steps=validation_steps,
callbacks=callbacks,
epochs=epochs,
verbose=2) # Logs once per epoch.