以下是三种读取 LIBSVM 格式数据的方式。
# -*- coding:utf-8 -*-
"""Three ways to read LIBSVM-formatted data in batches:

1. ``data_generator``  -- hand-rolled line parser with a fixed feature width.
2. ``get_data``        -- sklearn ``load_svmlight_file`` driven by byte offsets.
3. ``get_data_all``    -- load the whole file once, then slice fixed batches.
"""
import os

import numpy as np


def data_generator(input_filename, batch_size):
    """Yield ``(labels, features)`` batches parsed from a LIBSVM text file.

    Each input line looks like ``label id:value id:value ...`` with ascending
    feature ids.

    :param input_filename: path to the LIBSVM-formatted text file
    :param batch_size: number of rows per yielded batch (trailing partial
        batches are dropped, matching the original behaviour)
    :return: generator of ``(labels, features)`` with shapes
        ``(batch_size,)`` and ``(batch_size, 3)``
    """
    feature_size = 3
    labels = np.zeros(batch_size)
    features = np.empty(shape=[batch_size, feature_size])
    row = 0
    # ``with`` guarantees the file handle is closed (the original leaked it).
    with open(input_filename, "r") as fh:
        for line in fh:
            data = line.split(" ")
            label = int(float(data[0]))
            dense = np.zeros(feature_size)
            for fea in data[1:]:
                idx, value = fea.split(":")  # renamed from ``id`` (shadowed builtin)
                if int(idx) > feature_size - 1:
                    # LIBSVM ids are ascending, so every remaining id is
                    # out of range too -- stop scanning this line.
                    break
                dense[int(idx)] = float(value)
            labels[row] = label
            features[row] = dense
            row += 1
            if row > batch_size - 1:
                row = 0
                # Yield copies: the buffers are reused for the next batch, so
                # yielding views (as the original did) would clobber batches
                # the caller may still be holding.
                yield labels.copy(), features[:, 0:3].copy()


def get_data(input_filename, batch_size):
    """Stream batches via sklearn's ``load_svmlight_file`` using byte offsets.

    Assumes every line of the file is exactly ``oneline`` bytes long so that
    ``offset``/``length`` land on line boundaries.

    :param input_filename: path to the LIBSVM-formatted file
    :param batch_size: rows per batch
    :return: generator of ``(labels, features[:, 0:3])``
    """
    # Deferred import: only the sklearn-based readers need it.
    from sklearn import datasets

    oneline = 16294  # bytes per line -- TODO confirm against the actual data file
    feature_size = 1947
    batch = 0
    while True:
        data = datasets.load_svmlight_file(
            input_filename,
            offset=oneline * batch_size * batch,
            length=oneline * batch_size,
            n_features=feature_size,
        )
        features = data[0]
        labels = data[1]
        if features.shape[0] == 0:
            # PEP 479: ``raise StopIteration`` inside a generator becomes a
            # RuntimeError in Python 3.7+; ``return`` ends the generator cleanly.
            return
        batch += 1  # advance only when this read produced data
        yield labels, features[0:, 0:3]


def get_data_all(input_filename, batch_size):
    """Load the whole svmlight file once and yield it in fixed-size batches.

    :param input_filename: path to the LIBSVM-formatted file
    :param batch_size: rows per batch (a trailing partial batch is dropped)
    :return: generator of ``(labels, features[:, 0:3])``
    """
    from sklearn import datasets  # deferred: only needed by sklearn-based readers

    features, labels = datasets.load_svmlight_file(input_filename)
    batch = 0
    while True:
        start_index = batch * batch_size
        end_index = (batch + 1) * batch_size
        # ``>=`` (was ``>``) so an exact final full batch is not dropped.
        if features.shape[0] >= end_index:
            yield labels[start_index:end_index], features[start_index:end_index, 0:3]
            batch += 1
        else:
            return  # PEP 479 fix: was ``raise StopIteration``


if __name__ == "__main__":
    print("====", os.getcwd())
    filename = "/home/part-00000"
    generator = data_generator(filename, 10)
    labels, features = next(generator)  # Python 3: ``next(gen)``, not ``gen.next()``
    print([labels])
    print(features)
    # A plain ``for`` loop consumes the generator and stops cleanly when it
    # is exhausted (the original ``while True`` + ``.next()`` crashed).
    for labels, features in get_data_all(filename, 1000):
        print("data", len(labels), features.shape)