トップ 最新

かってきままな日々

2019-12-09 (Mo) [長年日記]

_ 初 PyTorch!

昨日くらいから PyTorch で書いてみた。

#!/usr/bin/env python

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 2)

    def forward(self, x):
        x = self.fc1(x)
        return x

    def print(self):
        print(self.fc1.weight)
        print(self.fc1.bias)

net = Net()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

inputs = []
labels = []
r0 = torch.rand(1000)
r1 = torch.rand(1000)
for i in range(1000):
    x = 0 + r0[0] - 0.5
    y = 0 + r1[0] - 0.5
    inputs.append([x, y])
    labels.append(0)

r0 = torch.rand(1000)
r1 = torch.rand(1000)
for i in range(1000):
    x = 1 + r0[0] - 0.5
    y = 1 + r1[0] - 0.5
    inputs.append([x, y])
    labels.append(1)

data = [
    torch.FloatTensor(inputs),   # inputs
    torch.LongTensor(labels),    # labels
]

for epoch in range(10000):
    inputs, labels = data

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # print loss
    # net.print()
    print(epoch, loss.item())

print('Finished Training')

inputs = []
r0 = torch.rand(1)
r1 = torch.rand(1)
x = r0[0]
y = r1[0]
inputs.append([x, y])
inputs = torch.FloatTensor(inputs)   # inputs
outputs = net(inputs)
print('inputs:')
print(inputs)
print('weight:')
print(net.fc1.weight)
print('bias:')
print(net.fc1.bias)
print('output:')
print(outputs)

一応、書けたことは書けた。結果が合ってるから、学習もできてるっぽい。

でも、Tensorflow とどこが違う? と考えるとよくわからん。

Tensorflow は Define and Run で、PyTorch は Define by Run らしい。 なんも違わねーじゃん! という…

https://eetimes.jp/ee/articles/1805/09/news030_4.html

言語処理などRNNで有効な手法だとした。

なるほど、RNN では確かに有効になるかもしれん。書いてみないとわからんけど。 ていうか、RNN 自体はライブラリにあるんだろうけど。 Tensorflow は、系列長の最大長を決めるしかなかったからな。その辺はなんとか なりそうな気がする。

RNN (LSTM でもいいけど) の簡単なサンプルタスクないかなぁ。 何がいいだろう。