Model, Algorithm, Agent
In the previous tutorial, we briefly demonstrated the three basic building blocks of PARL (Model, Algorithm, and Agent) and used them to construct a Cartpole Agent that interacts with the environment. This tutorial describes the role of each module in detail, along with its usage requirements.
Model
- Definition: Model defines the forward network, which is usually a policy network or a value function network. The current environment state is the input to the network.
- ⚠️Warning: a customized Model must inherit from parl.Model.
- Methods that must be implemented:
  - forward: defines the computation of the forward network, using the components created in __init__.
- Remarks: a target network is easy to implement in PARL using copy.deepcopy.
- Examples:
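    import copy

    import paddle
    import paddle.nn as nn
    import parl


    class CartpoleModel(parl.Model):
        def __init__(self, obs_dim, act_dim):
            super(CartpoleModel, self).__init__()
            hid1_size = act_dim * 10
            # Two fully connected layers: obs_dim -> hid1_size -> act_dim.
            self.fc1 = nn.Linear(obs_dim, hid1_size)
            self.fc2 = nn.Linear(hid1_size, act_dim)
            self.tanh = nn.Tanh()
            self.softmax = nn.Softmax()

        def forward(self, x):
            # Map an observation to a probability distribution over actions.
            out = self.tanh(self.fc1(x))
            prob = self.softmax(self.fc2(out))
            return prob


    if __name__ == '__main__':
        # Cartpole has a 4-dimensional observation and 2 discrete actions.
        model = CartpoleModel(obs_dim=4, act_dim=2)
        # The target network is an independent copy of the model.
        target_model = copy.deepcopy(model)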
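The remark about target networks can be illustrated without any deep learning framework. The sketch below uses a hypothetical TinyLinear class (not part of PARL) as a stand-in for a parl.Model subclass. It shows that copy.deepcopy produces an independent copy whose parameters stay unchanged when the original is updated.

```python
import copy


class TinyLinear:
    """Hypothetical stand-in for a model; holds its weights in a plain list."""

    def __init__(self, weights):
        self.weights = list(weights)

    def forward(self, x):
        # Dot product of the input with the weights.
        return sum(w * xi for w, xi in zip(self.weights, x))


model = TinyLinear([0.5, -1.0])
target_model = copy.deepcopy(model)  # independent snapshot of the parameters

# Simulate a training update on the online model only.
model.weights[0] += 1.0

online_out = model.forward([1.0, 1.0])         # computed with the updated weights
target_out = target_model.forward([1.0, 1.0])  # computed with the snapshot
```

Because the copy is deep, the two objects share no parameter storage, which is the property a target network relies on between synchronizations.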
Algorithm
- Definition: Algorithm updates the parameters of the Model passed to it. In general, the loss function is defined in Algorithm. Each Algorithm contains at least one Model.
- ⚠️Warning: it is recommended to import PARL's implementations of Algorithm directly.
- Methods that must be implemented:
  - learn: uses training data (observations, rewards, actions, etc.) to update the parameters of the Model.
  - predict: uses the current observation to predict the current action distribution or action value function.
- Examples:
    # CartpoleModel requires both obs_dim and act_dim.
    model = CartpoleModel(obs_dim=4, act_dim=2)
    algorithm = parl.algorithms.PolicyGradient(model, lr=1e-3)
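To make the role of learn more concrete, the following framework-free sketch computes a REINFORCE-style policy gradient loss, the mean of -log π(a|s) · R over a batch. It is an illustration only and is not presented as PARL's PolicyGradient implementation; the function name policy_gradient_loss and the sample numbers are hypothetical.

```python
import math


def policy_gradient_loss(probs, actions, rewards):
    """Mean of -log(pi(a|s)) * R over a batch.

    probs:   list of action distributions, one per step
    actions: list of chosen action indices, one per step
    rewards: list of returns, one per step
    """
    losses = [
        -math.log(p[a]) * r
        for p, a, r in zip(probs, actions, rewards)
    ]
    return sum(losses) / len(losses)


# Hypothetical batch of two steps with two possible actions each.
probs = [[0.5, 0.5], [0.25, 0.75]]
actions = [0, 1]
rewards = [1.0, 2.0]
loss = policy_gradient_loss(probs, actions, rewards)
```

Minimizing this quantity raises the probability of actions that were followed by high returns. In a real Algorithm, the loss would be computed on tensors and passed to an optimizer.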
Agent
- Definition: Agent interacts with the environment to generate training data, which is then passed to Algorithm to update the parameters of the Model. It also handles data preprocessing.
- ⚠️Warning: a customized Agent must inherit from parl.Agent and call the parent class's __init__ method inside its constructor.
- Methods that must be implemented:
  - learn: uses training data (observations, rewards, actions, etc.) to update the parameters of the Model.
  - predict: returns a predicted action based on the current observation. This method is often used for evaluation and deployment of the Agent.
  - sample: returns a sampled action based on the current observation. This method is often used in training to help the Agent explore the observation space.
- Examples:
    import numpy as np


    class CartpoleAgent(parl.Agent):
        def __init__(self, algorithm):
            super(CartpoleAgent, self).__init__(algorithm)

        def sample(self, obs):
            # Draw an action from the predicted distribution (exploration).
            obs = paddle.to_tensor(obs, dtype='float32')
            prob = self.alg.predict(obs)
            prob = prob.numpy()
            act = np.random.choice(len(prob), 1, p=prob)[0]
            return act

        def predict(self, obs):
            # Pick the most probable action (evaluation and deployment).
            obs = paddle.to_tensor(obs, dtype='float32')
            prob = self.alg.predict(obs)
            act = int(prob.argmax())
            return act

        def learn(self, obs, act, reward):
            # Add a trailing axis, then convert the batch to tensors.
            act = np.expand_dims(act, axis=-1)
            reward = np.expand_dims(reward, axis=-1)
            obs = paddle.to_tensor(obs, dtype='float32')
            act = paddle.to_tensor(act, dtype='int32')
            reward = paddle.to_tensor(reward, dtype='float32')
            loss = self.alg.learn(obs, act, reward)
            return float(loss)


    if __name__ == '__main__':
        # The dimensions belong to the model; the agent only takes the algorithm.
        model = CartpoleModel(obs_dim=4, act_dim=2)
        alg = parl.algorithms.PolicyGradient(model, lr=1e-3)
        agent = CartpoleAgent(alg)
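The difference between sample and predict can be seen in a small standard-library sketch. The helper functions and the probability vector below are hypothetical and stand in for the output of a policy network: predict always returns the most probable action, while sample draws actions in proportion to their probabilities.

```python
import random


def sample(prob, rng):
    """Draw an action index according to the distribution (exploration)."""
    return rng.choices(range(len(prob)), weights=prob, k=1)[0]


def predict(prob):
    """Pick the most probable action (evaluation and deployment)."""
    return max(range(len(prob)), key=lambda i: prob[i])


rng = random.Random(0)  # seeded only to make the sketch reproducible
prob = [0.2, 0.8]       # hypothetical output of the policy network

greedy_action = predict(prob)
sampled_actions = [sample(prob, rng) for _ in range(1000)]
share_of_action_1 = sum(sampled_actions) / len(sampled_actions)
```

Over many draws, the share of action 1 among the sampled actions approaches its probability, so the less likely action is still tried during training.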