import torch
import torch.nn as nn
import numpy as np
class PolicyNet(nn.Module):
    """Three-layer fully connected network with sigmoid-squashed outputs.

    Layout: ``input_size -> hidden1 -> hidden2 -> output_size`` with ReLU
    after the two hidden layers and an element-wise sigmoid on the output.

    Note that the sigmoid is applied per element, so the outputs lie in
    (0, 1) but do NOT sum to 1 on their own. Consumers that need a proper
    probability vector must normalize them.
    """

    def __init__(self, input_size, output_size, *, hidden1=24, hidden2=36):
        """Create the three linear layers.

        Args:
            input_size: Number of features in each input vector.
            output_size: Number of output units.
            hidden1: Width of the first hidden layer (default 24).
            hidden2: Width of the second hidden layer (default 36).
        """
        super().__init__()
        # Attribute names fc1/fc2/fc3 are kept so existing state dicts load.
        self.fc1 = nn.Linear(input_size, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, output_size)

    def forward(self, x):
        """Run a forward pass.

        Args:
            x: Float tensor whose last dimension equals ``input_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``output_size``; every element is a sigmoid output.
        """
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))
from torch.distributions import Categorical

# Demo: push one random 24-feature observation through a fresh policy network.
p = PolicyNet(24, 8)
i = np.random.rand(24)

# NumPy yields float64; the network's parameters are float32, so convert here.
prob = p(torch.tensor(i, dtype=torch.float32))

# Categorical rescales `probs` to sum to 1 along the last dimension, so the
# raw network outputs are accepted as unnormalized weights.
m = Categorical(probs=prob)

# Results of the next two calls are not stored (notebook-style exploration):
# draw one action index, then evaluate the log-probability of action 1.
m.sample()
m.log_prob(torch.tensor(1))