Policy Gradient

class PolicyGradient(model, lr)

Bases: Algorithm

__init__(model, lr)

Initialize the policy gradient algorithm.

Parameters:
  • model (parl.Model) – the model that defines the forward network of the policy.

  • lr (float) – the learning rate.

learn(obs, action, reward)

Update the model with the policy gradient algorithm.

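Parameters:
  • obs (paddle tensor) – a batch of observations, with shape (batch_size, obs_dim)

  • action (paddle tensor) – the actions taken for those observations, with shape (batch_size, 1)

  • reward (paddle tensor) – the reward used to weight each action, with shape (batch_size, 1)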

Returns:

the training loss, with shape (1)

Return type:

loss (paddle tensor)
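The update that learn performs can be illustrated in plain Python. This is a minimal sketch, assuming a REINFORCE-style loss (the batch mean of -log π(a|s) · reward) and a hypothetical linear softmax policy standing in for `model`; the names `LinearSoftmaxPolicy` and `reinforce_step` are illustrative only, not part of PARL. A real implementation works on paddle tensors and would rely on autograd and an optimizer, whereas here the gradient is derived by hand and applied with plain gradient descent scaled by `lr`.

```python
import math


class LinearSoftmaxPolicy:
    """Hypothetical stand-in for `model`: logits = W @ obs, probs = softmax(logits)."""

    def __init__(self, obs_dim, act_dim):
        # W[k][j] is the weight from observation feature j to action k (zero init).
        self.W = [[0.0] * obs_dim for _ in range(act_dim)]

    def forward(self, obs):
        logits = [sum(w * x for w, x in zip(row, obs)) for row in self.W]
        m = max(logits)  # subtract the max logit for numerical stability
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]


def reinforce_step(model, lr, obs_batch, actions, rewards):
    """One gradient-descent step on mean(-log pi(a|s) * r).

    Returns the batch loss computed before the weights are updated.
    """
    n = len(obs_batch)
    act_dim, obs_dim = len(model.W), len(model.W[0])
    grad = [[0.0] * obs_dim for _ in range(act_dim)]
    loss = 0.0
    for obs, a, r in zip(obs_batch, actions, rewards):
        probs = model.forward(obs)
        loss += -math.log(probs[a]) * r / n
        # For a softmax policy, d(-log pi(a|s)) / d logit_k = pi_k - 1[k == a].
        for k in range(act_dim):
            coeff = (probs[k] - (1.0 if k == a else 0.0)) * r / n
            for j in range(obs_dim):
                grad[k][j] += coeff * obs[j]
    # Plain gradient descent scaled by the learning rate (an assumption of this sketch).
    for k in range(act_dim):
        for j in range(obs_dim):
            model.W[k][j] -= lr * grad[k][j]
    return loss


# Usage: one observation, action 0 taken, positive reward.
model = LinearSoftmaxPolicy(obs_dim=2, act_dim=2)
loss = reinforce_step(model, lr=0.1, obs_batch=[[1.0, 0.0]], actions=[0], rewards=[1.0])
# Both actions start at probability 0.5, so loss equals math.log(2).
```

With a positive reward the step raises the probability of the action that was taken; a negative reward would lower it, which is the core intuition behind weighting the log-probability by the reward.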

predict(obs)

Predict the probability of each action.

Parameters:

obs (paddle tensor) – a single observation, with shape (obs_dim,)

Returns:

the probability of each action, with shape (action_dim,)

Return type:

prob (paddle tensor)
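Because predict returns a probability vector rather than an action, the caller still has to turn it into a choice. The sketch below shows two common options in plain Python, assuming the tensor has already been converted to a list of floats; `sample_action` and `greedy_action` are hypothetical helpers, not PARL functions.

```python
import random


def sample_action(prob, rng=random):
    """Draw action index i with probability prob[i]; common during training for exploration."""
    return rng.choices(range(len(prob)), weights=prob, k=1)[0]


def greedy_action(prob):
    """Return the index of the most probable action; common at evaluation time."""
    return max(range(len(prob)), key=lambda i: prob[i])


prob = [0.1, 0.7, 0.2]  # e.g. the (action_dim,) output of predict(obs), as a Python list
rng = random.Random(0)  # seeded so the draw is reproducible
sampled = sample_action(prob, rng)
best = greedy_action(prob)
```

Sampling keeps the policy stochastic, which is what the policy gradient update assumes when collecting data; taking the argmax trades that exploration for deterministic behavior once training is done.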