一看就懂的Tensorflow实战(随机森林)

Posted AI异构

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了一看就懂的Tensorflow实战(随机森林)相关的知识,希望对你有一定的参考价值。

随机森林简介

随机森林是一种集成学习方法。训练时每个树分类器从样本集里面随机有放回的抽取一部分进行训练。预测时将要分类的样本带入一个个树分类器,然后以少数服从多数的原则,表决出这个样本的最终分类类型。[4]

设有N个样本,M个变量(维度)个数,该算法具体流程如下:

  1. 确定一个值m,它用来表示每个树分类器选取多少个变量;

  2. 从数据集中有放回的抽取 k 个样本集,用它们创建 k 个树分类器。另外还伴随生成了 k 个袋外数据,用来后面做检测。

  3. 输入待分类样本之后,每个树分类器都会对它进行分类,然后所有分类器按照少数服从多数原则,确定分类结果。

重要参数:

  1. 预选变量个数 (即框架流程中的m);

  2. 随机森林中树的个数。

Tensorflow 随机森林

from __future__ import print_function

import tensorflow as tf
from tensorflow.python.ops import resources
from tensorflow.contrib.tensor_forest.python import tensor_forest

# Ignore all GPUs, tf random forest does not benefit from it.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

补充:__futrure__[1]
简单来说,Python 的每个新版本都会增加一些新的功能,或者对原来的功能作一些改动。有些改动是不兼容旧版本的。从 Python 2.7 到 Python 3 就有不兼容的一些改动,如果你想在 Python 2.7 中使用 Python 3 的新特性,那么你就需要从__future__模块导入。

  • division
    python2.7 中,不导入__future__,10/3 = 3
    python2.7 中,导入__future__,10/3 = 3.3333333333333335
    很容易看出来,2.7中默认的整数除法是结果向下取整,而导入了 __future__ 之后除法就是真正的除法了。

  • absolute_import
    python2.7 中,默认导入模块是相对导入的(relative import),即以'.'点导入:
    from . import json
    from .json import json_dump
    绝对导入(absolute import)则是指从系统路径sys.path最底层的模块导入:
    import os
    from os import sys

  • print_function
    python2 中 print 不需要括号,而在 python3 中则需要。
    print "Hello world" # python2.7
    print("Hello world") # python3

导入数据集

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=False)

Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz

设置参数

# Parameters
num_steps = 500 # Total steps to train
batch_size = 1024 # The number of samples per batch
num_classes = 10 # The 10 digits
num_features = 784 # Each image is 28x28 pixels
num_trees = 10
max_nodes = 1000

补充:Estimator API
Estimator 跟 Dataset 都是 Tensorflow 中的高级API。
Estimator 是一种创建 TensorFlow 模型的高级方法,它包括了用于常见机器学习任务的预制模型,当然,你也可以使用它们来创建你的自定义模型。[3]

contrib.tensor_forest 详细的实现了随机森林算法(Random Forests)评估器,并对外提供 high-level API。你只需传入 params 到构造器,params 使用 params.fill() 来填充,而不用传入所有的超参数,Tensor Forest 自己的 RandomForestGraphs 就能使用这些参数来构建整幅图。[2]

# Input and Target data
X = tf.placeholder(tf.float32, shape=[None, num_features])
# For random forest, labels must be integers (the class id)
Y = tf.placeholder(tf.int32, shape=[None])

# Random Forest Parameters
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                     num_features=num_features,
                                     num_trees=num_trees,
                                     max_nodes=max_nodes).fill()
# Build the Random Forest
forest_graph = tensor_forest.RandomForestGraphs(hparams)

INFO:tensorflow:Constructing forest with params =
INFO:tensorflow:{'num_trees': 10, 'max_nodes': 1000, 'bagging_fraction': 1.0, 'feature_bagging_fraction': 1.0, 'num_splits_to_consider': 28, 'max_fertile_nodes': 0, 'split_after_samples': 250, 'valid_leaf_threshold': 1, 'dominate_method': 'bootstrap', 'dominate_fraction': 0.99, 'model_name': 'all_dense', 'split_finish_name': 'basic', 'split_pruning_name': 'none', 'collate_examples': False, 'checkpoint_stats': False, 'use_running_stats_method': False, 'initialize_average_splits': False, 'inference_tree_paths': False, 'param_file': None, 'split_name': 'less_or_equal', 'early_finish_check_every_samples': 0, 'prune_every_samples': 0, 'num_classes': 10, 'num_features': 784, 'bagged_num_features': 784, 'bagged_features': None, 'regression': False, 'num_outputs': 1, 'num_output_columns': 11, 'base_random_seed': 0, 'leaf_model_type': 0, 'stats_model_type': 0, 'finish_type': 0, 'pruning_type': 0, 'split_type': 0}

损失函数

# Get training graph and loss
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)

# Measure the accuracy
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

训练

# Initialize the variables (i.e. assign their default value) and forest resources
init_vars = tf.group(tf.global_variables_initializer(),
   resources.initialize_resources(resources.shared_resources()))

# Start TensorFlow session
sess = tf.train.MonitoredSession()

# Run the initializer
sess.run(init_vars)

# Training
for i in range(1, num_steps + 1):
   # Prepare Data
   # Get the next batch of MNIST data (only images are needed, not labels)
   batch_x, batch_y = mnist.train.next_batch(batch_size)
   _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
   if i % 50 == 0 or i == 1:
       acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
       print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

# Test Model
test_x, test_y = mnist.test.images, mnist.test.labels
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))

INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Step 1, Loss: -1.000000, Acc: 0.411133
Step 50, Loss: -254.800003, Acc: 0.892578
Step 100, Loss: -538.799988, Acc: 0.915039
Step 150, Loss: -826.599976, Acc: 0.922852
Step 200, Loss: -1001.000000, Acc: 0.926758
Step 250, Loss: -1001.000000, Acc: 0.919922
Step 300, Loss: -1001.000000, Acc: 0.933594
Step 350, Loss: -1001.000000, Acc: 0.916992
Step 400, Loss: -1001.000000, Acc: 0.916992
Step 450, Loss: -1001.000000, Acc: 0.927734
Step 500, Loss: -1001.000000, Acc: 0.917969
Test Accuracy: 0.9212

参考

[1] Python __future__ 模块:https://blog.csdn.net/langb2014/article/details/53305246#t1

[2] 【机器学习】在TensorFlow中构建自定义Estimator:深度解析TensorFlow组件Estimator:http://www.doc88.com/p-1834979585771.html

[3] TensorFlow 1.3的Datasets和Estimator知多少?谷歌大神来解答:http://www.sohu.com/a/191717118_390227

[4] 穆晨:随机森林(Random Forest):https://www.cnblogs.com/muchen/p/6883263.html


-长按关注-


以上是关于一看就懂的Tensorflow实战(随机森林)的主要内容,如果未能解决你的问题,请参考以下文章

一看就懂的快速排序

一看就懂的快速排序

一看就懂的ESLint

一看就懂的DDD-(Domain Drive Design领域驱动设计)设计思想

一看就懂的DDD-(Domain Drive Design领域驱动设计)设计思想

新手一看就懂的线程池