markdown Tensorflow.js图像分类

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了markdown Tensorflow.js图像分类相关的知识,希望对你有一定的参考价值。

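Make sure you have local copies of TensorFlow.js (`tf.min.js`) and p5.js (`p5.min.js`) next to your `index.html`. p5.js is required, not just useful: its `setup()`/`draw()` loop and `frameCount` drive the continuous prediction loop.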

Put the following in your `index.html` file:

```html
<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8">
    <!-- Viewport settings must live inside content="..."; as bare attributes they are ignored. -->
    <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=0">
    <style> body {padding: 0; margin: 0;} </style>
    <script src="tf.min.js"></script>
    <script src="p5.min.js"></script>
    <script src="main.js"></script>
  </head>
  <body>
      <!-- Hidden scratch canvas: webcam frames are drawn here at model resolution before being read back. -->
      <canvas id="canvasId" style="display: none"></canvas>
      <!-- playsinline/muted allow the webcam preview to autoplay inline on mobile browsers. -->
      <video id="videoId" autoplay playsinline muted></video>
      <!-- Preview of the last captured (downsampled) frame. -->
      <img id="imageId"/>
      <br>
      <button style="margin-top: 15px; margin-left: 15px" onclick="categorize('RED')">Red</button>
      <button style="margin-top: 15px; margin-left: 15px" onclick="categorize('GREEN')">Green</button>
      <button style="margin-top: 15px; margin-left: 15px" onclick="categorize('BLUE')">Blue</button>
      <br>
      <button style="margin-top: 15px; margin-left: 15px" onclick="trainModel()">Train Model</button>
      <br>
      <p style="margin-top: 15px; margin-left: 15px" id="paragraphId"></p>
  </body>
</html>
```

Add the following to your main JavaScript file (`main.js`, the script loaded by `index.html`):

```javascript
/**
 * Label name -> class index. trainModel() uses it to turn each example's
 * Category into the integer fed to tf.oneHot().
 */
var categories = {
  RED: 0,
  GREEN: 1,
  BLUE: 2
};

/**
 * Class index -> label name, derived from `categories` so the two maps can
 * never drift apart. Keys end up as strings ('0', '1', '2'), exactly like the
 * numeric keys of an object literal, so `categoriesInverse[index]` works with
 * a numeric index.
 */
var categoriesInverse = Object.fromEntries(
  Object.entries(categories).map(([name, index]) => [index, name])
);

// ---- Training state ----

// Examples pushed by categorize(), each shaped { Image: number[], Category: string }.
var data = [];

// Model built by trainModel(); undefined until the first training run.
var model;
var isTraining = false;
var isTrained = false;

// ---- Capture / inference configuration ----

// Webcam frames are downsampled to imgWidth x imgHeight before classification.
var imgWidth = 100;
var imgHeight = 75;

// true: 3 values (R, G, B) per pixel; false: 1 grayscale value per pixel.
var retainColor = true;

// draw() calls predict() once every this many p5 frames.
var predictionFrequency = 30;

/**
 * p5.js entry point, called once at startup.
 *
 * This sketch never draws with p5 (frames are captured on #canvasId), so the
 * default canvas p5 creates when createCanvas() is not called is removed with
 * noCanvas() instead of being left as a stray blank element on the page.
 */
function setup() {
  noCanvas();
  loadVideoPreview();
}

/**
 * p5.js per-frame callback. Once a model has finished training, runs a
 * prediction every `predictionFrequency` frames; otherwise does nothing.
 */
function draw() {
  const modelReady = isTrained && !isTraining;
  if (!modelReady) {
    return;
  }

  const isPredictionFrame = frameCount % predictionFrequency === 0;
  if (!isPredictionFrame) {
    return;
  }

  predict();
}

/**
 * Sizes the <video>, capture <canvas> and preview <img> to imgWidth x imgHeight
 * and starts streaming the webcam into the <video> element.
 *
 * Uses the standard promise-based navigator.mediaDevices.getUserMedia when
 * available and falls back to the legacy (possibly vendor-prefixed) callback
 * API otherwise. The stream is also exposed as window.videoStream, which the
 * original code set via an undeclared (implicit) global.
 */
function loadVideoPreview() {
  const videoElement = document.querySelector("#videoId");
  const canvasElement = document.querySelector("#canvasId");
  const imageElement = document.querySelector("#imageId");

  for (const element of [videoElement, canvasElement, imageElement]) {
    element.width = imgWidth;
    element.height = imgHeight;
  }

  function handleVideo(stream) {
    videoElement.srcObject = stream;
    window.videoStream = stream;
  }

  function videoError(e) {
    console.log(e);
  }

  const constraints = { video: true };

  if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
    navigator.mediaDevices.getUserMedia(constraints).then(handleVideo, videoError);
    return;
  }

  const legacyGetUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia ||
    navigator.mozGetUserMedia || navigator.msGetUserMedia || navigator.oGetUserMedia;

  if (legacyGetUserMedia) {
    // The legacy API must be invoked with navigator as `this`.
    legacyGetUserMedia.call(navigator, constraints, handleVideo, videoError);
    return;
  }

  console.error('getUserMedia is not supported in this browser');
}

/**
 * Grabs the current webcam frame, downsamples it to the canvas size, converts
 * it to RGB or grayscale (per retainColor), writes the converted pixels back to
 * the canvas, and shows the result in the preview <img>.
 *
 * @returns {number[]} Flat pixel values in 0-255: R, G, B per pixel when
 *   retainColor is true, otherwise one grayscale value per pixel.
 */
function captureImage() {
  const videoElement = document.querySelector("#videoId");
  const canvasElement = document.querySelector("#canvasId");
  const imageElement = document.querySelector("#imageId");

  // getImageData() runs on every capture (including each prediction); this hint
  // lets the browser optimize the canvas for repeated reads. Context attributes
  // only take effect on the first getContext() call for a canvas.
  const canvasContext = canvasElement.getContext('2d', { willReadFrequently: true });

  canvasContext.drawImage(videoElement, 0, 0, canvasElement.width, canvasElement.height);
  const image = canvasContext.getImageData(0, 0, canvasElement.width, canvasElement.height);

  const pixels = retainColor ? convertToRgb(image.data) : convertToGrayscale(image.data);

  // Write the converted values back so the preview shows what the model sees.
  // In grayscale mode the single value is copied into R, G and B.
  const channels = retainColor ? 3 : 1;
  const greenOffset = retainColor ? 1 : 0;
  const blueOffset = retainColor ? 2 : 0;
  for (let i = 0; i < image.data.length; i += 4) {
    const px = (i / 4) * channels;
    image.data[i] = pixels[px];
    image.data[i + 1] = pixels[px + greenOffset];
    image.data[i + 2] = pixels[px + blueOffset];
    image.data[i + 3] = 255;
  }
  canvasContext.putImageData(image, 0, 0);

  imageElement.src = canvasElement.toDataURL();

  return pixels;
}

/**
 * Captures the current frame and stores it in `data` as a labelled training example.
 *
 * @param {string} categoryName - A key of `categories` ('RED', 'GREEN' or 'BLUE').
 *   Unknown names are rejected: trainModel() looks the name up in `categories`,
 *   and an unknown one would yield an undefined, invalid class label.
 */
function categorize(categoryName) {
  if (!Object.prototype.hasOwnProperty.call(categories, categoryName)) {
    console.error(`Unknown category: ${categoryName}`);
    return;
  }

  const pixels = captureImage();

  data.push({
    Image: pixels,
    Category: categoryName,
  });
}

/**
 * Drops the alpha channel from RGBA pixel data.
 *
 * @param {ArrayLike<number>} input - RGBA values, 4 per pixel (e.g. ImageData.data).
 * @returns {number[]} The R, G and B values of every pixel, in order.
 */
function convertToRgb(input) {
  const rgb = [];
  for (let offset = 0; offset < input.length; offset += 4) {
    rgb.push(input[offset], input[offset + 1], input[offset + 2]);
  }
  return rgb;
}

/**
 * Converts RGBA pixel data to one grayscale value per pixel, using the plain
 * average of R, G and B rounded down (alpha is ignored).
 *
 * Uses Math.floor rather than p5's global `floor`, so this pure helper works
 * without p5 loaded. The two give the same result for these non-negative values.
 *
 * @param {ArrayLike<number>} input - RGBA values, 4 per pixel (e.g. ImageData.data).
 * @returns {number[]} One integer per pixel.
 */
function convertToGrayscale(input) {
  const output = [];
  for (let i = 0; i < input.length; i += 4) {
    const r = input[i];
    const g = input[i + 1];
    const b = input[i + 2];
    output.push(Math.floor((r + g + b) / 3));
  }
  return output;
}

/**
 * Builds and compiles the classifier: a 15-unit sigmoid dense layer followed
 * by a softmax layer with one unit per entry in `categories`.
 *
 * The input length is imgWidth * imgHeight * (3 if retainColor else 1), which
 * matches the flat pixel arrays produced by captureImage().
 *
 * @returns {tf.Sequential} A model compiled with SGD (learning rate 0.25),
 *   categorical cross-entropy loss and an accuracy metric.
 */
function buildModel() {
  const LEARNING_RATE = 0.25;
  const colorSpace = retainColor ? 3 : 1;
  // Derived from `categories` rather than hard-coded, so adding a label only needs one change.
  const numCategories = Object.keys(categories).length;

  const md = tf.sequential();

  md.add(tf.layers.dense({
    units: 15,
    inputShape: [imgWidth * imgHeight * colorSpace],
    activation: 'sigmoid',
  }));

  md.add(tf.layers.dense({
    units: numCategories,
    activation: 'softmax',
  }));

  md.compile({
    optimizer: tf.train.sgd(LEARNING_RATE),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return md;
}

/**
 * Trains a fresh model on every example collected by categorize().
 *
 * - Examples are shuffled before being turned into tensors. tf.js takes the
 *   validationSplit samples from the END of the data, before its own shuffle.
 *   Collected data is grouped by button presses, so without this the
 *   validation set would be almost entirely the last category.
 * - The input/label tensors and any previous model are disposed, so repeated
 *   training does not leak tensor memory.
 * - Errors are logged and always reset `isTraining`, so a failed run does not
 *   block prediction forever. The returned promise never rejects, which makes
 *   it safe to ignore from the inline onclick handler.
 * - A click while training is already running is ignored.
 *
 * @returns {Promise<void>}
 */
async function trainModel() {
  if (data == null || data.length === 0) {
    console.error('No Training Data Found');
    return;
  }
  if (isTraining) {
    console.log('Already Training');
    return;
  }

  isTraining = true;
  let xs;
  let ys;

  try {
    const numCategories = Object.keys(categories).length;
    const examples = data.slice();
    tf.util.shuffle(examples);

    // tf.tidy frees the intermediate int32 label tensor; xs/ys are returned and kept.
    ({ xs, ys } = tf.tidy(() => {
      // Normalize color 0-255 to 0-1.
      const images = examples.map((datum) => datum.Image.map((p) => p / 255));
      const cats = examples.map((datum) => categories[datum.Category]);
      return {
        xs: tf.tensor2d(images),
        ys: tf.oneHot(tf.tensor1d(cats, 'int32'), numCategories).cast('float32'),
      };
    }));

    const previousModel = model;
    model = buildModel();
    isTrained = false;
    if (previousModel) {
      previousModel.dispose();
    }

    await model.fit(xs, ys, {
      shuffle: true,
      validationSplit: 0.1,
      epochs: 10,
      callbacks: {
        onEpochEnd: (epoch) => {
          console.log('EPOCH: ' + epoch);
        },
        // Yield to the browser between batches so the page stays responsive.
        onBatchEnd: async () => {
          await tf.nextFrame();
        },
      },
    });

    isTrained = true;
    console.log('finished');
  } catch (err) {
    console.error('Training failed', err);
  } finally {
    isTraining = false;
    if (xs) {
      xs.dispose();
    }
    if (ys) {
      ys.dispose();
    }
  }
}

/**
 * Classifies the current webcam frame and writes "I predict <LABEL>" into #paragraphId.
 *
 * draw() calls this every `predictionFrequency` frames, so all tensors created
 * here are released with tf.tidy. Previously the input, prediction and argMax
 * tensors leaked on every call.
 */
function predict() {
  if (isTraining) {
    console.log('Still Training');
    return;
  }
  if (!isTrained || model == null) {
    console.error('No Trained Model Found');
    return;
  }

  const pixels = captureImage();

  // Normalize color 0-255 to 0-1.
  const normalized = pixels.map((p) => p / 255);

  const index = tf.tidy(() => {
    const input = tf.tensor2d(normalized, [1, normalized.length]);
    return model.predict(input).argMax(1).dataSync()[0];
  });

  const prediction = categoriesInverse[index];

  const paragraphElement = document.getElementById("paragraphId");
  paragraphElement.textContent = 'I predict ' + prediction;
}
```

以上是关于markdown Tensorflow.js图像分类的主要内容,如果未能解决你的问题,请参考以下文章

机器学习Tensorflow.js:我在浏览器中使用机器学习实现了图像分类

如何利用TensorFlow.js部署简单的AI版「你画我猜」图像识别应用

教程 | 如何利用TensorFlow.js部署简单的AI版「你画我猜」图像识别应用

如何在 nodejs (tensorflow.js) 中训练模型?

使用 TensorFlow.js 在浏览器中自定义目标检测

使用 TensorFlow.js 在浏览器中自定义目标检测