Keras (16): Generating and reading TFRecord files, and using them with tf.keras
Published: 2019-05-26


This article covers the following (a minimal single-record round-trip sketch follows this list):

  • Generating TFRecord files
  • Reading TFRecord files
  • Using the data read from the TFRecord files with tf.keras
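Before the full pipeline, here is a minimal, self-contained sketch of the round trip this article builds up: serialize one tf.train.Example, write it to a TFRecord file, and parse it back. The file name demo.tfrecord and the numeric values are hypothetical placeholders.

import tensorflow as tf

# Build one Example with an 8-float feature vector and a 1-float label.
example = tf.train.Example(features=tf.train.Features(feature={
    "input_features": tf.train.Feature(
        float_list=tf.train.FloatList(value=[0.1] * 8)),
    "label": tf.train.Feature(
        float_list=tf.train.FloatList(value=[1.5])),
}))

# Write the serialized record to disk.
with tf.io.TFRecordWriter("demo.tfrecord") as writer:
    writer.write(example.SerializeToString())

# Read it back and parse it with the same feature spec used later in the article.
feature_spec = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32),
}
for serialized in tf.data.TFRecordDataset(["demo.tfrecord"]):
    parsed = tf.io.parse_single_example(serialized, feature_spec)
    print(parsed["input_features"].numpy(), parsed["label"].numpy())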

I. Generating TFRecord files

1. Collect the CSV files and group them into training, validation, and test file lists
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras

# Print the versions of the Python libraries in use
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

# 1. Collect the CSV files and group them into training, validation, and test file lists
source_dir = "./generate_csv/"

def get_filenames_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results

train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")

import pprint
pprint.pprint(train_filenames)
pprint.pprint(valid_filenames)
pprint.pprint(test_filenames)
2. Convert the CSV files into a tf.data.Dataset
def parse_csv_line(line, n_fields=9):
    # Parse one CSV line into 8 feature columns and 1 label column
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y

def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_readers
    )
    # shuffle returns a new dataset, so the result must be reassigned
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 32
train_set = csv_reader_dataset(train_filenames, batch_size=batch_size)
valid_set = csv_reader_dataset(valid_filenames, batch_size=batch_size)
test_set = csv_reader_dataset(test_filenames, batch_size=batch_size)
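As a quick sanity check (not part of the original script), one batch can be pulled from train_set to confirm the expected shapes: 8 features per sample and a single regression target.

for x_batch, y_batch in train_set.take(1):
    print(x_batch.shape, y_batch.shape)   # expected: (32, 8) and (32, 1)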
3. Define a function that converts a CSV dataset into TFRecord files
def serialize_example(x, y):
    """Convert x, y to a tf.train.Example and serialize it."""
    input_features = tf.train.FloatList(value=x)
    label = tf.train.FloatList(value=y)
    features = tf.train.Features(
        feature={
            "input_features": tf.train.Feature(float_list=input_features),
            "label": tf.train.Feature(float_list=label)
        }
    )
    example = tf.train.Example(features=features)
    return example.SerializeToString()

def csv_dataset_to_tfrecords(base_filename, dataset, n_shards,
                             steps_per_shard, compression_type=None):
    # Write the dataset into n_shards TFRecord files, steps_per_shard batches per shard
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    all_filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames
4. Convert the CSV datasets into TFRecord files
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_filenames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard, None)
5. Convert the CSV datasets into GZIP-compressed TFRecord files
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords_zip"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard,
    compression_type="GZIP")
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard,
    compression_type="GZIP")
test_tfrecord_filenames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard,
    compression_type="GZIP")

pprint.pprint(train_tfrecord_filenames)
pprint.pprint(valid_tfrecord_filenames)
pprint.pprint(test_tfrecord_filenames)
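An optional check (a sketch, not in the original script): read one raw record back from a compressed shard with the matching compression_type to confirm the GZIP files are readable.

raw_dataset = tf.data.TFRecordDataset(train_tfrecord_filenames[:1],
                                      compression_type="GZIP")
for serialized in raw_dataset.take(1):
    print(len(serialized.numpy()), "bytes in the first serialized record")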

II. Reading TFRecord files

1. Define a function that turns TFRecord files into a tf.data.Dataset
# 6. Define a function that turns TFRecord files into a tf.data.Dataset
expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example, expected_features)
    return example["input_features"], example["label"]

def tfrecords_reader_dataset(filenames, n_readers=5,
                             batch_size=32, n_parse_threads=5,
                             shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(
            filename, compression_type="GZIP"),
        cycle_length=n_readers
    )
    # shuffle returns a new dataset, so the result must be reassigned
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames, batch_size=3)
for x_batch, y_batch in tfrecords_train.take(10):
    print(x_batch)
    print(y_batch)
2. Build the training, validation, and test datasets from the TFRecord files
# 7. Build the training, validation, and test datasets from the TFRecord files
batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
    train_tfrecord_filenames, batch_size=batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
    valid_tfrecord_filenames, batch_size=batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
    test_tfrecord_filenames, batch_size=batch_size)

III. Using the data read from TFRecord files with tf.keras

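The complete script in the summary below also contains the training step (step 8), which is what actually connects the TFRecord datasets to tf.keras; it is reproduced here so this section reads end to end:

# 8. Feed the TFRecord datasets into a tf.keras model for training
model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu', input_shape=[8]),
    keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]

history = model.fit(tfrecords_train_set,
                    validation_data=tfrecords_valid_set,
                    steps_per_epoch=11160 // batch_size,
                    validation_steps=3870 // batch_size,
                    epochs=100,
                    callbacks=callbacks)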
# 9. Evaluate the trained model on the test set
model.evaluate(tfrecords_test_set, steps=5160 // batch_size)
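Beyond evaluation, the same dataset can be used for inference; a minimal sketch (not in the original script), assuming the model trained in step 8:

predictions = model.predict(tfrecords_test_set, steps=5160 // batch_size)
print(predictions.shape)   # one predicted value per test sample drawn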

IV. Summary

The complete script:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras

# Print the versions of the Python libraries in use
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

# 1. Collect the CSV files and group them into training, validation, and test file lists
source_dir = "./generate_csv/"

def get_filenames_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results

train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")

import pprint
pprint.pprint(train_filenames)
pprint.pprint(valid_filenames)
pprint.pprint(test_filenames)

# 2. Convert the CSV files into a tf.data.Dataset
def parse_csv_line(line, n_fields=9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y

def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_readers
    )
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 32
train_set = csv_reader_dataset(train_filenames, batch_size=batch_size)
valid_set = csv_reader_dataset(valid_filenames, batch_size=batch_size)
test_set = csv_reader_dataset(test_filenames, batch_size=batch_size)

# 3. Define a function that converts a CSV dataset into TFRecord files
def serialize_example(x, y):
    """Convert x, y to a tf.train.Example and serialize it."""
    input_features = tf.train.FloatList(value=x)
    label = tf.train.FloatList(value=y)
    features = tf.train.Features(
        feature={
            "input_features": tf.train.Feature(float_list=input_features),
            "label": tf.train.Feature(float_list=label)
        }
    )
    example = tf.train.Example(features=features)
    return example.SerializeToString()

def csv_dataset_to_tfrecords(base_filename, dataset, n_shards,
                             steps_per_shard, compression_type=None):
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    all_filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames

# 4. Convert the CSV datasets into TFRecord files
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_filenames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard, None)

# 5. Convert the CSV datasets into GZIP-compressed TFRecord files
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords_zip"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard,
    compression_type="GZIP")
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard,
    compression_type="GZIP")
test_tfrecord_filenames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard,
    compression_type="GZIP")

pprint.pprint(train_tfrecord_filenames)
pprint.pprint(valid_tfrecord_filenames)
pprint.pprint(test_tfrecord_filenames)

# 6. Define a function that turns TFRecord files into a tf.data.Dataset
expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example, expected_features)
    return example["input_features"], example["label"]

def tfrecords_reader_dataset(filenames, n_readers=5,
                             batch_size=32, n_parse_threads=5,
                             shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(
            filename, compression_type="GZIP"),
        cycle_length=n_readers
    )
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames, batch_size=3)
for x_batch, y_batch in tfrecords_train.take(10):
    print(x_batch)
    print(y_batch)

# 7. Build the training, validation, and test datasets from the TFRecord files
batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
    train_tfrecord_filenames, batch_size=batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
    valid_tfrecord_filenames, batch_size=batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
    test_tfrecord_filenames, batch_size=batch_size)

# 8. Feed the TFRecord datasets into a tf.keras model for training
model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu', input_shape=[8]),
    keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]

history = model.fit(tfrecords_train_set,
                    validation_data=tfrecords_valid_set,
                    steps_per_epoch=11160 // batch_size,
                    validation_steps=3870 // batch_size,
                    epochs=100,
                    callbacks=callbacks)

# 9. Evaluate the trained model on the test set
model.evaluate(tfrecords_test_set, steps=5160 // batch_size)

Reposted from: http://ivili.baihongyu.com/
