参考 https://github.com/tensorflow/models/tree/master/slim
使用TensorFlow-Slim进行图像分类
准备
安装TensorFlow
参考 https://www.tensorflow.org/install/
如在Ubuntu下安装TensorFlow with GPU support, python 2.7版本
wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
下载TF-slim图像模型库
cd $WORKSPACE git clone https://github.com/tensorflow/models/
准备数据
有不少公开数据集,这里以官网提供的Flowers为例。
官网提供了下载和转换数据的代码,为了理解代码并能使用自己的数据,这里参考官方提供的代码进行修改。
cd $WORKSPACE/data wget http://download.tensorflow.org/example_images/flower_photos.tgz tar zxf flower_photos.tgz
数据集文件夹结构如下:
flower_photos ├── daisy │ ├── 100080576_f52e8ee070_n.jpg │ └── ... ├── dandelion ├── LICENSE.txt ├── roses ├── sunflowers └── tulips
由于实际情况中我们自己的数据集并不一定把图片按类别放在不同的文件夹里,故我们生成list.txt来表示图片路径与标签的关系。
Python代码:
import os class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4} data_dir = 'flower_photos/' output_path = 'list.txt' fd = open(output_path, 'w') for class_name in class_names_to_ids.keys(): images_list = os.listdir(data_dir + class_name) for image_name in images_list: fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name])) fd.close()
为了方便后期查看label标签,也可以定义labels.txt:
daisy dandelion roses sunflowers tulips
随机生成训练集与验证集:
Python代码:
import random _NUM_VALIDATION = 350 _RANDOM_SEED = 0 list_path = 'list.txt' train_list_path = 'list_train.txt' val_list_path = 'list_val.txt' fd = open(list_path) lines = fd.readlines() fd.close() random.seed(_RANDOM_SEED) random.shuffle(lines) fd = open(train_list_path, 'w') for line in lines[_NUM_VALIDATION:]: fd.write(line) fd.close() fd = open(val_list_path, 'w') for line in lines[:_NUM_VALIDATION]: fd.write(line) fd.close()
生成TFRecord数据:
Python代码:
import sys sys.path.insert(0, '../models/slim/') from datasets import dataset_utils import math import os import tensorflow as tf def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5): fd = open(list_path) lines = [line.split() for line in fd] fd.close() num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS))) with tf.Graph().as_default(): decode_jpeg_data = tf.placeholder(dtype=tf.string) decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3) with tf.Session('') as sess: for shard_id in range(_NUM_SHARDS): output_path = os.path.join(output_dir, 'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS)) tfrecord_writer = tf.python_io.TFRecordWriter(output_path) start_ndx = shard_id * num_per_shard end_ndx = min((shard_id + 1) * num_per_shard, len(lines)) for i in range(start_ndx, end_ndx): sys.stdout.write('\r Converting image {}/{} shard {}'.format( i + 1, len(lines), shard_id)) sys.stdout.flush() image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read() image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data}) height, width = image.shape[0], image.shape[1] example = dataset_utils.image_to_tfexample( image_data, b'jpg', height, width, int(lines[i][1])) tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.close() sys.stdout.write('\n') sys.stdout.flush() os.system('mkdir -p train') convert_dataset('list_train.txt', 'flower_photos', 'train/') os.system('mkdir -p val') convert_dataset('list_val.txt', 'flower_photos', 'val/')
得到的文件夹结构如下:
data ├── flower_photos ├── labels.txt ├── list_train.txt ├── list.txt ├── list_val.txt ├── train │ ├── data_00000-of-00005.tfrecord │ ├── ... │ └── data_00004-of-00005.tfrecord └── val ├── data_00000-of-00005.tfrecord ├── ... └── data_00004-of-00005.tfrecord
(可选)下载模型
官方提供了不少预训练模型,这里以Inception-ResNet-v2以例。
cd $WORKSPACE/checkpoints wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz tar zxf inception_resnet_v2_2016_08_30.tar.gz
训练
读入数据
官方提供了读入Flowers数据集的代码models/slim/datasets/flowers.py,同样这里也是参考并修改成能读入上面定义的通用数据集。
把下面代码写入models/slim/datasets/dataset_classification.py。
import os import tensorflow as tf slim = tf.contrib.slim def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'): file_pattern = os.path.join(dataset_dir, file_pattern) keys_to_features = { 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 'image/class/label': tf.FixedLenFeature( [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), } items_to_handlers = { 'image': slim.tfexample_decoder.Image(), 'label': slim.tfexample_decoder.Tensor('image/class/label'), } decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) items_to_descriptions = { 'image': 'A color image of varying size.', 'label': 'A single integer between 0 and ' + str(num_classes - 1), } labels_to_names = None if labels_to_names_path is not None: fd = open(labels_to_names_path) labels_to_names = {i : line.strip() for i, line in enumerate(fd)} fd.close() return slim.dataset.Dataset( data_sources=file_pattern, reader=tf.TFRecordReader, decoder=decoder, num_samples=num_samples, items_to_descriptions=items_to_descriptions, num_classes=num_classes, labels_to_names=labels_to_names)
构建模型
官方提供了许多模型在models/slim/nets/。
如需要自定义模型,则参考官方提供的模型并放在对应的文件夹即可。
开始训练
官方提供了训练脚本,如果使用官方的数据读入和处理,可使用以下方式开始训练。
cd $WORKSPACE/models/slim CUDA_VISIBLE_DEVICES="0" python train_image_classifier.py --train_dir=train_logs --dataset_name=flowers --dataset_split_name=train --dataset_dir=../../data/flowers --model_name=inception_resnet_v2 --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits --max_number_of_steps=1000 --batch_size=32 --learning_rate=0.01 --learning_rate_decay_type=fixed --save_interval_secs=60 --save_summaries_secs=60 --log_every_n_steps=10 --optimizer=rmsprop --weight_decay=0.00004
不fine-tune把--checkpoint_path, --checkpoint_exclude_scopes和--trainable_scopes删掉。
fine-tune所有层把--checkpoint_exclude_scopes和--trainable_scopes删掉。
如果只使用CPU则加上--clone_on_cpu=True。
其它参数可删掉用默认值或自行修改。
使用自己的数据则需要修改models/slim/train_image_classifier.py:
把
from datasets import dataset_factory
修改为
from datasets import dataset_classification
把
dataset = dataset_factory.get_dataset( FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
修改为
dataset = dataset_classification.get_dataset( FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)
在
tf.app.flags.DEFINE_string( 'dataset_dir', None, 'The directory where the dataset files are stored.')
后加入
tf.app.flags.DEFINE_integer( 'num_samples', 3320, 'Number of samples.') tf.app.flags.DEFINE_integer( 'num_classes', 5, 'Number of classes.') tf.app.flags.DEFINE_string( 'labels_to_names_path', None, 'Label names file path.')
训练时执行以下命令即可:
cd $WORKSPACE/models/slim python train_image_classifier.py --train_dir=train_logs --dataset_dir=../../data/train --num_samples=3320 --num_classes=5 --labels_to_names_path=../../data/labels.txt --model_name=inception_resnet_v2 --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits
可视化log
可一边训练一边可视化训练的log,可看到Loss趋势。
tensorboard --logdir train_logs/
验证
官方提供了验证脚本。
python eval_image_classifier.py --checkpoint_path=train_logs --eval_dir=eval_logs --dataset_name=flowers --dataset_split_name=validation --dataset_dir=../../data/flowers --model_name=inception_resnet_v2
同样,如果是使用自己的数据集,则需要修改models/slim/eval_image_classifier.py:
把
from datasets import dataset_factory
修改为
from datasets import dataset_classification
把
dataset = dataset_factory.get_dataset( FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
修改为
dataset = dataset_classification.get_dataset( FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)
在
tf.app.flags.DEFINE_string( 'dataset_dir', None, 'The directory where the dataset files are stored.')
后加入
tf.app.flags.DEFINE_integer( 'num_samples', 350, 'Number of samples.') tf.app.flags.DEFINE_integer( 'num_classes', 5, 'Number of classes.') tf.app.flags.DEFINE_string( 'labels_to_names_path', None, 'Label names file path.')
验证时执行以下命令即可:
python eval_image_classifier.py --checkpoint_path=train_logs --eval_dir=eval_logs --dataset_dir=../../data/val --num_samples=350 --num_classes=5 --model_name=inception_resnet_v2
可以一边训练一边验证,,注意使用其它的GPU或合理分配显存。
同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:
tensorboard --logdir eval_logs/ --port 6007
测试
参考models/slim/eval_image_classifier.py,可编写读取图片用模型进行推导的脚本models/slim/test_image_classifier.py
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import math import tensorflow as tf from nets import nets_factory from preprocessing import preprocessing_factory slim = tf.contrib.slim tf.app.flags.DEFINE_string( 'master', '', 'The address of the TensorFlow master to use.') tf.app.flags.DEFINE_string( 'checkpoint_path', '/tmp/tfmodel/', 'The directory where the model was written to or an absolute path to a ' 'checkpoint file.') tf.app.flags.DEFINE_string( 'test_path', '', 'Test image path.') tf.app.flags.DEFINE_integer( 'num_classes', 5, 'Number of classes.') tf.app.flags.DEFINE_integer( 'labels_offset', 0, 'An offset for the labels in the dataset. This flag is primarily used to ' 'evaluate the VGG and ResNet architectures which do not use a background ' 'class for the ImageNet dataset.') tf.app.flags.DEFINE_string( 'model_name', 'inception_v3', 'The name of the architecture to evaluate.') tf.app.flags.DEFINE_string( 'preprocessing_name', None, 'The name of the preprocessing to use. If left ' 'as `None`, then the model_name flag is used.') tf.app.flags.DEFINE_integer( 'test_image_size', None, 'Eval image size') FLAGS = tf.app.flags.FLAGS def main(_): if not FLAGS.test_list: raise ValueError('You must supply the test list with --test_list') tf.logging.set_verbosity(tf.logging.INFO) with tf.Graph().as_default(): tf_global_step = slim.get_or_create_global_step() #################### # Select the model # #################### network_fn = nets_factory.get_network_fn( FLAGS.model_name, num_classes=(FLAGS.num_classes - FLAGS.labels_offset), is_training=False) ##################################### # Select the preprocessing function # ##################################### preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name image_preprocessing_fn = preprocessing_factory.get_preprocessing( preprocessing_name, is_training=False) test_image_size = FLAGS.test_image_size or network_fn.default_image_size if tf.gfile.IsDirectory(FLAGS.checkpoint_path): checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path) else: checkpoint_path = FLAGS.checkpoint_path tf.Graph().as_default() with tf.Session() as sess: image = open(FLAGS.test_path, 'rb').read() image = tf.image.decode_jpeg(image, channels=3) processed_image = image_preprocessing_fn(image, test_image_size, test_image_size) processed_images = tf.expand_dims(processed_image, 0) logits, _ = network_fn(processed_images) predictions = tf.argmax(logits, 1) saver = tf.train.Saver() saver.restore(sess, checkpoint_path) np_image, network_input, predictions = sess.run([image, processed_image, predictions]) print('{} {}'.format(FLAGS.test_path, predictions[0])) if __name__ == '__main__': tf.app.run()
测试时执行以下命令即可:
python test_image_classifier.py --checkpoint_path=train_logs/ --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg --num_classes=5 --model_name=inception_resnet_v2
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
《魔兽世界》大逃杀!60人新游玩模式《强袭风暴》3月21日上线
暴雪近日发布了《魔兽世界》10.2.6 更新内容,新游玩模式《强袭风暴》即将于3月21 日在亚服上线,届时玩家将前往阿拉希高地展开一场 60 人大逃杀对战。
艾泽拉斯的冒险者已经征服了艾泽拉斯的大地及遥远的彼岸。他们在对抗世界上最致命的敌人时展现出过人的手腕,并且成功阻止终结宇宙等级的威胁。当他们在为即将于《魔兽世界》资料片《地心之战》中来袭的萨拉塔斯势力做战斗准备时,他们还需要在熟悉的阿拉希高地面对一个全新的敌人──那就是彼此。在《巨龙崛起》10.2.6 更新的《强袭风暴》中,玩家将会进入一个全新的海盗主题大逃杀式限时活动,其中包含极高的风险和史诗级的奖励。
《强袭风暴》不是普通的战场,作为一个独立于主游戏之外的活动,玩家可以用大逃杀的风格来体验《魔兽世界》,不分职业、不分装备(除了你在赛局中捡到的),光是技巧和战略的强弱之分就能决定出谁才是能坚持到最后的赢家。本次活动将会开放单人和双人模式,玩家在加入海盗主题的预赛大厅区域前,可以从强袭风暴角色画面新增好友。游玩游戏将可以累计名望轨迹,《巨龙崛起》和《魔兽世界:巫妖王之怒 经典版》的玩家都可以获得奖励。
更新日志
- 明达年度发烧碟MasterSuperiorAudiophile2021[DSF]
- 英文DJ 《致命的温柔》24K德国HD金碟DTS 2CD[WAV+分轨][1.7G]
- 张学友1997《不老的传说》宝丽金首版 [WAV+CUE][971M]
- 张韶涵2024 《不负韶华》开盘母带[低速原抓WAV+CUE][1.1G]
- lol全球总决赛lcs三号种子是谁 S14全球总决赛lcs三号种子队伍介绍
- lol全球总决赛lck三号种子是谁 S14全球总决赛lck三号种子队伍
- 群星.2005-三里屯音乐之男孩女孩的情人节【太合麦田】【WAV+CUE】
- 崔健.2005-给你一点颜色【东西音乐】【WAV+CUE】
- 南台湾小姑娘.1998-心爱,等一下【大旗】【WAV+CUE】
- 【新世纪】群星-美丽人生(CestLaVie)(6CD)[WAV+CUE]
- ProteanQuartet-Tempusomniavincit(2024)[24-WAV]
- SirEdwardElgarconductsElgar[FLAC+CUE]
- 田震《20世纪中华歌坛名人百集珍藏版》[WAV+CUE][1G]
- BEYOND《大地》24K金蝶限量编号[低速原抓WAV+CUE][986M]
- 陈奕迅《准备中 SACD》[日本限量版] [WAV+CUE][1.2G]