
Image Classification with TensorFlow-Slim


Reference: https://github.com/tensorflow/models/tree/master/slim


Preparation

  1. Install TensorFlow

    See https://www.tensorflow.org/install/

    For example, to install TensorFlow with GPU support for Python 2.7 on Ubuntu:

    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
  2. Download the TF-Slim image models library

    cd $WORKSPACE
    git clone https://github.com/tensorflow/models/
  3. Prepare the data

    There are plenty of public datasets; here we use the Flowers dataset from the official site as an example.

    The official repository provides code to download and convert the data. To understand the code and be able to use our own data, we adapt the official code below.

    cd $WORKSPACE/data
    wget http://download.tensorflow.org/example_images/flower_photos.tgz
    tar zxf flower_photos.tgz

    The dataset directory structure is as follows:

    flower_photos
    ├── daisy
    │   ├── 100080576_f52e8ee070_n.jpg
    │   └── ...
    ├── dandelion
    ├── LICENSE.txt
    ├── roses
    ├── sunflowers
    └── tulips

    In practice our own dataset will not necessarily have its images sorted into one folder per class, so we generate a list.txt that maps each image path to its label.

    Python code:

    import os
    
    class_names_to_ids = {"daisy": 0, "dandelion": 1, "roses": 2, "sunflowers": 3, "tulips": 4}
    data_dir = "flower_photos/"
    output_path = "list.txt"
    
    fd = open(output_path, "w")
    for class_name in class_names_to_ids.keys():
        images_list = os.listdir(data_dir + class_name)
        for image_name in images_list:
            fd.write("{}/{} {}\n".format(class_name, image_name, class_names_to_ids[class_name]))
    
    fd.close()

    To make it easier to look up label names later, you can also define a labels.txt (a short sketch for generating it programmatically follows the list):

    daisy
    dandelion
    roses
    sunflowers
    tulips
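
    If you prefer not to write labels.txt by hand, a minimal sketch for generating it from the class_names_to_ids mapping defined above (line i will hold the name of class i):

    class_names_to_ids = {"daisy": 0, "dandelion": 1, "roses": 2, "sunflowers": 3, "tulips": 4}
    # Invert the mapping so that line i of labels.txt holds the name of class i.
    labels_to_class_names = {v: k for k, v in class_names_to_ids.items()}

    with open("labels.txt", "w") as fd:
        for i in range(len(labels_to_class_names)):
            fd.write(labels_to_class_names[i] + "\n")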

    Randomly split the data into training and validation sets:

    Python code:

    import random
    
    _NUM_VALIDATION = 350
    _RANDOM_SEED = 0
    list_path = "list.txt"
    train_list_path = "list_train.txt"
    val_list_path = "list_val.txt"
    
    fd = open(list_path)
    lines = fd.readlines()
    fd.close()
    random.seed(_RANDOM_SEED)
    random.shuffle(lines)
    
    fd = open(train_list_path, "w")
    for line in lines[_NUM_VALIDATION:]:
        fd.write(line)
    
    fd.close()
    fd = open(val_list_path, "w")
    for line in lines[:_NUM_VALIDATION]:
        fd.write(line)
    
    fd.close()

    Generate the TFRecord data:

    Python code:

    import sys
    sys.path.insert(0, "../models/slim/")
    from datasets import dataset_utils
    import math
    import os
    import tensorflow as tf
    
    def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
        fd = open(list_path)
        lines = [line.split() for line in fd]
        fd.close()
        num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
        with tf.Graph().as_default():
            decode_jpeg_data = tf.placeholder(dtype=tf.string)
            decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
            with tf.Session("") as sess:
                for shard_id in range(_NUM_SHARDS):
                    output_path = os.path.join(output_dir,
                        "data_{:05}-of-{:05}.tfrecord".format(shard_id, _NUM_SHARDS))
                    tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write("\r>> Converting image {}/{} shard {}".format(
                            i + 1, len(lines), shard_id))
                        sys.stdout.flush()
                        image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), "rb").read()
                        image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
                        height, width = image.shape[0], image.shape[1]
                        example = dataset_utils.image_to_tfexample(
                            image_data, b"jpg", height, width, int(lines[i][1]))
                        tfrecord_writer.write(example.SerializeToString())
                    tfrecord_writer.close()
        sys.stdout.write("\n")
        sys.stdout.flush()
    
    os.system("mkdir -p train")
    convert_dataset("list_train.txt", "flower_photos", "train/")
    os.system("mkdir -p val")
    convert_dataset("list_val.txt", "flower_photos", "val/")

    The resulting directory structure is:

    data
    ├── flower_photos
    ├── labels.txt
    ├── list_train.txt
    ├── list.txt
    ├── list_val.txt
    ├── train
    │   ├── data_00000-of-00005.tfrecord
    │   ├── ...
    │   └── data_00004-of-00005.tfrecord
    └── val
        ├── data_00000-of-00005.tfrecord
        ├── ...
        └── data_00004-of-00005.tfrecord
  4. (Optional) Download a pre-trained model

    Quite a few pre-trained models are provided officially; here we use Inception-ResNet-v2 as an example.

    cd $WORKSPACE/checkpoints
    wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
    tar zxf inception_resnet_v2_2016_08_30.tar.gz

Training

  1. Read the data

    The official code for reading the Flowers dataset is models/slim/datasets/flowers.py; we likewise adapt it so it can read the generic dataset defined above.

    Save the following code as models/slim/datasets/dataset_classification.py (a quick sanity check follows the code):

    import os
    import tensorflow as tf
    slim = tf.contrib.slim
    
    def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern="*.tfrecord"):
        file_pattern = os.path.join(dataset_dir, file_pattern)
        keys_to_features = {
            "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
            "image/format": tf.FixedLenFeature((), tf.string, default_value="png"),
            "image/class/label": tf.FixedLenFeature(
                [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        }
        items_to_handlers = {
            "image": slim.tfexample_decoder.Image(),
            "label": slim.tfexample_decoder.Tensor("image/class/label"),
        }
        decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
        items_to_descriptions = {
            "image": "A color image of varying size.",
            "label": "A single integer between 0 and " + str(num_classes - 1),
        }
        labels_to_names = None
        if labels_to_names_path is not None:
            fd = open(labels_to_names_path)
            labels_to_names = {i : line.strip() for i, line in enumerate(fd)}
            fd.close()
        return slim.dataset.Dataset(
                data_sources=file_pattern,
                reader=tf.TFRecordReader,
                decoder=decoder,
                num_samples=num_samples,
                items_to_descriptions=items_to_descriptions,
                num_classes=num_classes,
                labels_to_names=labels_to_names)
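
    To sanity-check this dataset definition before training, the returned slim Dataset can be read with slim's DatasetDataProvider. A minimal sketch, assuming it is run from the data directory like the conversion script above (the directory, sample count and labels path are the values from the flowers example):

    import sys
    sys.path.insert(0, "../models/slim/")
    import tensorflow as tf
    from datasets import dataset_classification

    slim = tf.contrib.slim

    with tf.Graph().as_default():
        # Assumed paths and counts; adjust to your own data.
        dataset = dataset_classification.get_dataset(
            "train", num_samples=3320, num_classes=5,
            labels_to_names_path="labels.txt")
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset, common_queue_capacity=32, common_queue_min=1)
        image, label = provider.get(["image", "label"])

        with tf.Session() as sess:
            with slim.queues.QueueRunners(sess):
                # Pull one decoded example to verify the TFRecords are readable.
                np_image, np_label = sess.run([image, label])
                print(np_image.shape, np_label)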
  2. Build the model

    Many models are already provided in models/slim/nets/.

    If you need a custom model, use the official models as a reference and place yours in the same folder; a minimal sketch is shown below.
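
    For illustration only, here is a sketch of what such a custom model might look like (the file name mynet.py and every name in it are hypothetical; to be selectable via --model_name it would also have to be registered in models/slim/nets/nets_factory.py):

    # Hypothetical models/slim/nets/mynet.py -- a toy classifier in slim style.
    import tensorflow as tf

    slim = tf.contrib.slim

    def mynet(images, num_classes, is_training=True, scope="MyNet"):
        with tf.variable_scope(scope, "MyNet", [images]):
            net = slim.conv2d(images, 32, [3, 3], scope="conv1")
            net = slim.max_pool2d(net, [2, 2], scope="pool1")
            net = slim.conv2d(net, 64, [3, 3], scope="conv2")
            net = slim.max_pool2d(net, [2, 2], scope="pool2")
            net = slim.flatten(net)
            net = slim.fully_connected(net, 128, scope="fc1")
            net = slim.dropout(net, 0.5, is_training=is_training, scope="dropout1")
            logits = slim.fully_connected(net, num_classes, activation_fn=None, scope="logits")
            end_points = {"Logits": logits, "Predictions": tf.nn.softmax(logits)}
        return logits, end_points

    # train_image_classifier.py reads this attribute to pick the input size.
    mynet.default_image_size = 224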

  3. Start training

    An official training script is provided. If you use the official data reading and preprocessing, you can start training as follows:

    cd $WORKSPACE/models/slim
    CUDA_VISIBLE_DEVICES="0" python train_image_classifier.py \
        --train_dir=train_logs \
        --dataset_name=flowers \
        --dataset_split_name=train \
        --dataset_dir=../../data/flowers \
        --model_name=inception_resnet_v2 \
        --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
        --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
        --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
        --max_number_of_steps=1000 \
        --batch_size=32 \
        --learning_rate=0.01 \
        --learning_rate_decay_type=fixed \
        --save_interval_secs=60 \
        --save_summaries_secs=60 \
        --log_every_n_steps=10 \
        --optimizer=rmsprop \
        --weight_decay=0.00004

    To train from scratch rather than fine-tune, remove --checkpoint_path, --checkpoint_exclude_scopes and --trainable_scopes.

    To fine-tune all layers, remove --checkpoint_exclude_scopes and --trainable_scopes.

    If you only have a CPU, add --clone_on_cpu=True.

    The other flags can be removed to use their defaults, or adjusted as needed.

    To use your own data, you need to modify models/slim/train_image_classifier.py. Change

    from datasets import dataset_factory

    to

    from datasets import dataset_classification

    and change

    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    to

    dataset = dataset_classification.get_dataset(
        FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

    Then, after

    tf.app.flags.DEFINE_string(
        "dataset_dir", None, "The directory where the dataset files are stored.")

    add:

    tf.app.flags.DEFINE_integer(
        "num_samples", 3320, "Number of samples.")
    
    tf.app.flags.DEFINE_integer(
        "num_classes", 5, "Number of classes.")
    
    tf.app.flags.DEFINE_string(
        "labels_to_names_path", None, "Label names file path.")

    To train, run:

    cd $WORKSPACE/models/slim
    python train_image_classifier.py \
        --train_dir=train_logs \
        --dataset_dir=../../data/train \
        --num_samples=3320 \
        --num_classes=5 \
        --labels_to_names_path=../../data/labels.txt \
        --model_name=inception_resnet_v2 \
        --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
        --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
        --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits
  4. Visualize the logs

    You can visualize the training logs while training is running to watch the loss trend:

    tensorboard --logdir train_logs/

Validation

An official evaluation script is provided:

python eval_image_classifier.py \
    --checkpoint_path=train_logs \
    --eval_dir=eval_logs \
    --dataset_name=flowers \
    --dataset_split_name=validation \
    --dataset_dir=../../data/flowers \
    --model_name=inception_resnet_v2

Similarly, to use your own dataset you need to modify models/slim/eval_image_classifier.py. Change

from datasets import dataset_factory

to

from datasets import dataset_classification

and change

dataset = dataset_factory.get_dataset(
    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

to

dataset = dataset_classification.get_dataset(
    FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

Then, after

tf.app.flags.DEFINE_string(
    "dataset_dir", None, "The directory where the dataset files are stored.")

add:

tf.app.flags.DEFINE_integer(
    "num_samples", 350, "Number of samples.")

tf.app.flags.DEFINE_integer(
    "num_classes", 5, "Number of classes.")

tf.app.flags.DEFINE_string(
    "labels_to_names_path", None, "Label names file path.")

To evaluate, run:

python eval_image_classifier.py \
    --checkpoint_path=train_logs \
    --eval_dir=eval_logs \
    --dataset_dir=../../data/val \
    --num_samples=350 \
    --num_classes=5 \
    --model_name=inception_resnet_v2

You can evaluate while training is in progress; just make sure evaluation runs on another GPU (e.g. via CUDA_VISIBLE_DEVICES) or that GPU memory is allocated sensibly between the two jobs.

The evaluation logs can be visualized as well; if TensorBoard is already running on the training logs, use a different port, e.g.:

tensorboard --logdir eval_logs/ --port 6007

Testing

Following models/slim/eval_image_classifier.py, you can write a script models/slim/test_image_classifier.py that reads an image and runs inference with the model:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
    "master", "", "The address of the TensorFlow master to use.")

tf.app.flags.DEFINE_string(
    "checkpoint_path", "/tmp/tfmodel/",
    "The directory where the model was written to or an absolute path to a "
    "checkpoint file.")

tf.app.flags.DEFINE_string(
    "test_path", "", "Test image path.")

tf.app.flags.DEFINE_integer(
    "num_classes", 5, "Number of classes.")

tf.app.flags.DEFINE_integer(
    "labels_offset", 0,
    "An offset for the labels in the dataset. This flag is primarily used to "
    "evaluate the VGG and ResNet architectures which do not use a background "
    "class for the ImageNet dataset.")

tf.app.flags.DEFINE_string(
    "model_name", "inception_v3", "The name of the architecture to evaluate.")

tf.app.flags.DEFINE_string(
    "preprocessing_name", None, "The name of the preprocessing to use. If left "
    "as `None`, then the model_name flag is used.")

tf.app.flags.DEFINE_integer(
    "test_image_size", None, "Eval image size")

FLAGS = tf.app.flags.FLAGS


def main(_):
    if not FLAGS.test_path:
        raise ValueError("You must supply the test image path with --test_path")

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)

        test_image_size = FLAGS.test_image_size or network_fn.default_image_size

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        with tf.Session() as sess:
            image = open(FLAGS.test_path, "rb").read()
            image = tf.image.decode_jpeg(image, channels=3)
            processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
            processed_images = tf.expand_dims(processed_image, 0)
            logits, _ = network_fn(processed_images)
            predictions = tf.argmax(logits, 1)
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)
            np_image, network_input, predictions = sess.run([image, processed_image, predictions])
            print("{} {}".format(FLAGS.test_path, predictions[0]))

if __name__ == "__main__":
    tf.app.run()

To test, run:

python test_image_classifier.py \
    --checkpoint_path=train_logs/ \
    --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg \
    --num_classes=5 \
    --model_name=inception_resnet_v2