728x90
Flask와 pytorch로 웹페이지에서 손글씨를 받아 인식하고 label로 업데이터까지 해보자
mnist의 데이터셋이 28*28 인것 처럼 웹페이지에서 입력받을 손글씨의 픽셀을 28*28로 하자.
우선 28*28 사이즈의 table에 마우스 드래그로 숫자를 그릴수 있도록 html을 짜주자.
<!DOCTYPE html>
<html lang="en">
<head>
<title>Number Recognition</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="https://code.jquery.com/jquery-3.6.4.min.js"></script>
<style>
table {
border-collapse: collapse;
width: 160 px;
}
table,
td,
th {
border: 1px solid black;
}
td {
padding: 5px
}
td.drawn {
background-color: black;
}
</style>
</head>
<body>
<h2>Number Recognition</h2>
<h3>Draw one from 0 to 9</h3>
<input type = "text" id = "text" >
<button type="button" onclick="updatemodel()">Update</button>
<button onclick="predictHandwriting()">Predict</button>
<br><span id="output"></span><br>
<script>
const numRows = 28;
const numCols = 28;
let drawing = false;
let binaryNumber = '';
document.write('<table>');
for (let i = 0; i < numRows; i++) {
document.write('<tr>');
for (let j = 0; j < numCols; j++) {
document.write('<td onmousedown="startDrawing()" onmouseup="stopDrawing()" onmousemove="cellHovered(this)"></td>');
}
document.write('</tr>');
}
document.write('</table>');
function startDrawing() {
drawing = true;
}
function stopDrawing() {
drawing = false;
updateNumberInput();
}
function cellHovered(cell) {
if (drawing) {
cell.classList.add('drawn');
}
}
function updateNumberInput() {
const numRows = 28;
const numCols = 28;
const tableCells = document.querySelectorAll('td');
binaryNumber = '';
tableCells.forEach((cell, index) => {
if (cell.classList.contains('drawn')) {
binaryNumber += '1';
} else {
binaryNumber += '0';
}
// 현재 인덱스가 행의 끝에 도달하면 ',' 추가
if ((index + 1) % numCols !== 0) {
binaryNumber += ',';
} else {
binaryNumber += '\n'; // 행의 끝에 도달하면 줄 바꿈 추가
}});}
function predictHandwriting() {
$.ajax({
url: '/predict',
type: 'POST',
contentType: 'application/json', // JSON 형식으로 데이터 보내기
data: JSON.stringify({ inputdata: binaryNumber }),
success: function(result) {
$('#output').text('Prediction Result: ' + result);
}});}
function updatemodel() {
var label = document.getElementById("text").value;
$.ajax({
url: '/update',
type: 'POST',
contentType: 'application/json',
data: JSON.stringify({ label: label, inputdata: binaryNumber }),
success: function(result) {
console.log(result);
},
error: function(error) {
console.error('Error:', error);
}});}
</script>
</body>
</html>
서버에 28*28 table을 검정이라면 0, 흰색이라면 1로 치환해, binaryNumber를 json 형식으로 보내자.
업데이트때도 필요한 label 데이터 역시 json 형식으로 전송한다.
from flask import Flask, render_template, request, jsonify
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
app = Flask(__name__)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 28 * 28)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
model = CNN()
model.load_state_dict(torch.load('mnist_model.pth', map_location=torch.device('cpu')))
model.eval()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 루트 URL에 대한 라우트 추가
@app.route('/')
def index():
return render_template('index.html')
# '/predict' URL에 대한 라우트 추가
@app.route('/predict', methods=['POST'])
def predict():
input_data = request.json['inputdata']
print(input_data)
image_matrix = process_tensor(input_data)
image_matrix = image_matrix
# 예측 수행
with torch.no_grad():
prediction = model(image_matrix)
predicted_digit = torch.argmax(prediction).item()
# 예측된 숫자를 문자열로 반환
return str(predicted_digit)
def process_tensor(number_data):
rows = number_data.strip().split('\n')
tensor_data = []
for row in rows:
values = [float(val) for val in row.strip().split(',')]
tensor_data.append(values)
return torch.tensor(tensor_data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
@app.route('/update', methods=['post'])
def update_model():
input_data = request.json['inputdata']
label = request.json['label']
print(label)
image_matrix = process_tensor(input_data)
label = torch.tensor([int(label)], dtype = torch.long)
output = model(image_matrix)
loss = criterion(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(model.state_dict(), 'mnist_model.pth')
return str("model updated successfully")
# 앱 실행
if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=80)
json형식의 label과 inputdata를 각각 /update, /predict 엔트포인트에서 전달받는다
process_tensor에서 해당 데이터를 숫자형식으로 바꿔 28,28 형태의 텐서로 바꾼다
원래는 CNN을 통해 해당 데이터셋을 학습시키고자 했지만 cnn을 처리하는 과정에서
자꾸 차원 오류가나 그냥 ReLU의 단일 층 신경망을 이용해다
새로 추가될, 기존의 데이터 셋을 따로 관리하지 않고 바로바로 신경망 업데이트하고 휘발시키기에
관련 과정을 구현해 보는것도 좋을 듯 하다.
728x90
'Dev > DL' 카테고리의 다른 글
Termux Ubuntu에서 Pytroch 설치 (0) | 2024.08.09 |
---|---|
강화학습 자동매매 proj - 라즈베리파이 Pytorch 설치 & MNIST 라즈베리파이 colab 비교 (1) | 2024.01.28 |
강화학습 자동매매 proj - 라즈비안 (Raspbian) OS + ssh 접속 (0) | 2024.01.27 |
안드로이드 termux 리눅스 - pytorch 설치 (0) | 2024.01.27 |