markdown Tensorflow.js图像分类
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了markdown Tensorflow.js图像分类相关的知识,希望对你有一定的参考价值。
Make sure you have local copies of TensorFlow.js (`tf.min.js`) and p5.js (`p5.min.js`) next to your `index.html`. p5.js is used for its continuous `draw()` loop, which runs predictions repeatedly.
Put the following in your `index.html` file:
```html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<!-- Viewport settings only take effect inside the content attribute. -->
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=0">
<style> body {padding: 0; margin: 0;} </style>
<script src="tf.min.js"></script>
<script src="p5.min.js"></script>
<script src="main.js"></script>
</head>
<body>
<!-- Hidden canvas used to grab frames from the webcam video. -->
<canvas id="canvasId" style="display: none"></canvas>
<!-- playsinline stops iOS Safari from forcing the live preview fullscreen. -->
<video id="videoId" autoplay playsinline></video>
<img id="imageId"/>
<br>
<button style="margin-top: 15px; margin-left: 15px" onclick="categorize('RED')">Red</button>
<button style="margin-top: 15px; margin-left: 15px" onclick="categorize('GREEN')">Green</button>
<button style="margin-top: 15px; margin-left: 15px" onclick="categorize('BLUE')">Blue</button>
<br>
<button style="margin-top: 15px; margin-left: 15px" onclick="trainModel()">Train Model</button>
<br>
<p style="margin-top: 15px; margin-left: 15px" id="paragraphId"></p>
</body>
</html>
```
Add the following to your main JavaScript file (`main.js`, which `index.html` loads):
```javascript
// Category name -> class index. This is the single source of truth for the
// labels; categoriesInverse below is derived from it so the two never drift.
var categories = {
  RED: 0,
  GREEN: 1,
  BLUE: 2
};
// Class index -> category name, used to turn an argMax index back into a label.
var categoriesInverse = {};
Object.keys(categories).forEach(function (name) {
  categoriesInverse[categories[name]] = name;
});
// Labelled training samples collected so far: { Image: number[], Category: string }.
var data = [];
// Model produced by buildModel(); undefined until the first training run.
var model;
var isTraining = false;
var isTrained = false;
// Capture resolution applied to the video, canvas and snapshot image.
var imgWidth = 100;
var imgHeight = 75;
// true: keep R, G, B per pixel; false: reduce each pixel to one grayscale value.
var retainColor = true;
// draw() runs a prediction once every this many p5 frames.
var predictionFrequency = 30;
/**
 * p5.js start-up hook: begins streaming the webcam into the preview elements.
 */
function setup() {
  // All further work happens in draw() and the button handlers.
  loadVideoPreview();
}
/**
 * p5.js per-frame hook. Once a model has been trained and no training run is
 * in progress, classify the webcam image every `predictionFrequency` frames.
 */
function draw() {
  var modelReady = isTrained && !isTraining;
  if (!modelReady) {
    return;
  }
  var isPredictionFrame = frameCount % predictionFrequency === 0;
  if (!isPredictionFrame) {
    return;
  }
  predict();
}
/**
 * Size the video, capture canvas and snapshot image to the capture resolution,
 * then request the webcam and stream it into the <video> element.
 *
 * Prefers the standard promise-based navigator.mediaDevices.getUserMedia and
 * falls back to the legacy callback-based / vendor-prefixed variants. Errors,
 * including "no API available", are reported on the console instead of failing
 * silently.
 */
function loadVideoPreview() {
  var videoElement = document.querySelector("#videoId");
  var canvasElement = document.querySelector("#canvasId");
  var imageElement = document.querySelector("#imageId");
  [videoElement, canvasElement, imageElement].forEach(function (element) {
    element.width = imgWidth;
    element.height = imgHeight;
  });

  function handleVideo(stream) {
    videoElement.srcObject = stream;
    // Published as the page-global `videoStream` (previously an implicit global).
    window.videoStream = stream;
  }
  function videoError(e) {
    console.error(e);
  }

  var constraints = { video: true };
  if (navigator.mediaDevices && navigator.mediaDevices.getUserMedia) {
    navigator.mediaDevices.getUserMedia(constraints).then(handleVideo).catch(videoError);
    return;
  }

  // Legacy fallback. Called with navigator as `this` instead of monkey-patching navigator.
  var legacyGetUserMedia = navigator.getUserMedia || navigator.webkitGetUserMedia ||
    navigator.mozGetUserMedia || navigator.msGetUserMedia || navigator.oGetUserMedia;
  if (legacyGetUserMedia) {
    legacyGetUserMedia.call(navigator, constraints, handleVideo, videoError);
  } else {
    videoError(new Error('getUserMedia is not supported in this browser'));
  }
}
/**
 * Grab the current webcam frame and return its pixels as a flat array.
 *
 * With `retainColor` the result holds R, G, B values per pixel; otherwise it
 * holds one grayscale value per pixel. The processed values are painted back
 * onto the canvas with alpha forced to 255, and the canvas is copied into the
 * snapshot <img> so the page shows the frame that was captured.
 *
 * @returns {number[]} the processed pixel values (0-255).
 */
function captureImage() {
  var video = document.querySelector("#videoId");
  var canvas = document.querySelector("#canvasId");
  var snapshot = document.querySelector("#imageId");
  var context = canvas.getContext('2d');

  context.drawImage(video, 0, 0, canvas.width, canvas.height);
  var frame = context.getImageData(0, 0, canvas.width, canvas.height);
  var rgba = frame.data;

  var pixels = retainColor ? convertToRgb(rgba) : convertToGrayscale(rgba);
  var valuesPerPixel = retainColor ? 3 : 1;

  // Grayscale copies its single value into all three colour channels.
  var source = 0;
  for (var offset = 0; offset < rgba.length; offset += 4) {
    var first = pixels[source];
    rgba[offset] = first;
    rgba[offset + 1] = valuesPerPixel === 3 ? pixels[source + 1] : first;
    rgba[offset + 2] = valuesPerPixel === 3 ? pixels[source + 2] : first;
    rgba[offset + 3] = 255;
    source += valuesPerPixel;
  }
  context.putImageData(frame, 0, 0);

  snapshot.src = canvas.toDataURL();
  return pixels;
}
/**
 * Capture the current webcam frame and store it as a training sample.
 *
 * @param {string} categoryName - label for the sample, e.g. 'RED'.
 */
function categorize(categoryName) {
  var sample = {
    Image: captureImage(),
    Category: categoryName
  };
  data.push(sample);
}
/**
 * Drop the alpha channel from RGBA pixel data.
 *
 * @param {ArrayLike<number>} input - RGBA values, 4 per pixel.
 * @returns {number[]} R, G, B values, 3 per pixel.
 */
function convertToRgb(input) {
  var rgb = [];
  for (var offset = 0; offset < input.length; offset += 4) {
    rgb.push(input[offset], input[offset + 1], input[offset + 2]);
  }
  return rgb;
}
/**
 * Reduce RGBA pixel data to one grayscale value per pixel.
 * Each value is the floored mean of R, G and B; alpha is ignored.
 *
 * @param {ArrayLike<number>} input - RGBA values, 4 per pixel.
 * @returns {number[]} one integer grayscale value per pixel.
 */
function convertToGrayscale(input) {
  var output = [];
  for (var i = 0; i < input.length; i += 4) {
    var sum = input[i] + input[i + 1] + input[i + 2];
    // Math.floor instead of p5's global floor(): same result, no dependency on
    // p5 global mode being active.
    output.push(Math.floor(sum / 3));
  }
  return output;
}
/**
 * Build and compile the classifier.
 *
 * The network has a 15-unit sigmoid hidden layer and a softmax output with one
 * unit per entry in `categories`. It is trained with SGD (learning rate 0.25)
 * on categorical cross-entropy. The input size is imgWidth * imgHeight values,
 * multiplied by 3 when `retainColor` is set.
 *
 * @returns the compiled tf.Sequential model.
 */
function buildModel() {
  const md = tf.sequential();
  const colorSpace = retainColor ? 3 : 1;
  // Derived from the label table so that adding a category needs no change here.
  const numClasses = Object.keys(categories).length;
  md.add(tf.layers.dense({
    units: 15,
    inputShape: [imgWidth * imgHeight * colorSpace],
    activation: 'sigmoid'
  }));
  md.add(tf.layers.dense({
    units: numClasses,
    activation: 'softmax'
  }));
  const LEARNING_RATE = 0.25;
  md.compile({
    optimizer: tf.train.sgd(LEARNING_RATE),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  return md;
}
/**
 * Train a fresh model on every labelled sample collected so far.
 *
 * Pixel values are scaled from 0-255 to 0-1 and labels are one-hot encoded.
 * Any previous model is disposed before it is replaced. While training runs,
 * `isTraining` is true, which pauses prediction. The flag is always reset when
 * training ends or fails, and the input tensors are always disposed.
 * A click while a run is already in progress is ignored.
 *
 * @returns {Promise<void>} resolves once training has finished or failed.
 */
async function trainModel() {
  if (data == null || data.length === 0) {
    console.error('No Training Data Found');
    return;
  }
  if (isTraining) {
    // A second run would race the first on the shared `model` and flags.
    console.log('Still Training');
    return;
  }

  const images = [];
  const cats = [];
  for (const datum of data) {
    // Normalize color 0-255 to 0-1
    images.push(datum.Image.map((p) => p / 255));
    cats.push(categories[datum.Category]);
  }
  const numClasses = Object.keys(categories).length;

  let xs = null;
  let ys = null;
  isTraining = true;
  isTrained = false;
  try {
    xs = tf.tensor2d(images);
    // tidy() frees the int32 label tensor and the un-cast one-hot tensor.
    ys = tf.tidy(() => tf.oneHot(tf.tensor1d(cats, 'int32'), numClasses).cast('float32'));
    if (model) {
      // Release the previous model's weights before replacing it.
      model.dispose();
      model = undefined;
    }
    model = buildModel();
    await model.fit(xs, ys, {
      shuffle: true,
      validationSplit: 0.1,
      epochs: 10,
      callbacks: {
        onEpochEnd: (epoch) => {
          console.log('EPOCH: ' + epoch);
        },
        onBatchEnd: async () => {
          // Awaiting actually yields to the browser so the page stays responsive.
          await tf.nextFrame();
        },
        onTrainEnd: () => {
          isTraining = false;
          isTrained = true;
          console.log('finished');
        },
      },
    });
  } catch (err) {
    console.error('Training failed', err);
  } finally {
    isTraining = false;
    if (xs) {
      xs.dispose();
    }
    if (ys) {
      ys.dispose();
    }
  }
}
/**
 * Classify the current webcam frame and show the predicted category.
 *
 * Does nothing except log a message while training is in progress or before
 * any model has been built.
 */
function predict() {
  if (isTraining) {
    console.log('Still Training');
    return;
  }
  if (!model) {
    console.log('No Trained Model');
    return;
  }
  const pixels = captureImage();
  // tidy() disposes the input, output and argMax tensors. draw() calls this
  // repeatedly, so leaving them undisposed would grow tf.js memory every call.
  const index = tf.tidy(() => {
    // Normalize color 0-255 to 0-1
    const normalized = pixels.map((p) => p / 255);
    const input = tf.tensor2d(normalized, [1, normalized.length]);
    return model.predict(input).argMax(1).dataSync()[0];
  });
  const prediction = categoriesInverse[index];
  document.getElementById("paragraphId").textContent = 'I predict ' + prediction;
}
```
以上是关于markdown Tensorflow.js图像分类的主要内容,如果未能解决你的问题,请参考以下文章
机器学习Tensorflow.js:我在浏览器中使用机器学习实现了图像分类
如何利用TensorFlow.js部署简单的AI版「你画我猜」图像识别应用
教程 | 如何利用TensorFlow.js部署简单的AI版「你画我猜」图像识别应用