๐ท ์ค์ต2
- Agent๊ฐ Target์ ํฅํด ์ด๋
- position, rigidbody ๊ฐ ๊ด์ธก
- ์ฐ์ : actions.ContinuousActions
๐ถ ๊ธฐ๋ณธ ์ธํ
- ํด๋ ์ ๋ฆฌ
https://assetstore.unity.com/packages/3d/characters/free-mummy-monster-134212
Free Mummy Monster | 3D 캐릭터 | Unity Asset Store
Elevate your workflow with the Free Mummy Monster asset from amusedART. Find this & other 캐릭터 options on the Unity Asset Store.
assetstore.unity.com
- 패키지 다운로드 후 임포트
๐ถ Agent ์ธํ
- Mummy_Monํ๋ฆฌํน์ Agent๋ก ์ด๋ฆ ๋ณ๊ฒฝ
- ๋จธํฐ๋ฆฌ์ผ ์ฐ๊ฒฐ > Apply
- ๋น๊ฒ์์ค๋ธ์ ํธ(Stage)
- ํ๋ธ(Floor)
- ๋น๊ฒ์์ค๋ธ์ ํธ(DeadZone)
- ํ๋ธ(Wall) ์์ฑ
- Mesh Renderer ๋นํ์ฑํ
- Tag(DEAD_ZONE) ์์ฑ > ์ ์ฉ
- ํ๋ธ(Target) ์์ฑ
- Tag(TARGET) ์์ฑ > ์ ์ฉ
- Agent ์ถ๊ฐ
- Capsule Collider ์ถ๊ฐ > ์กฐ์
- Rigidbody ์ถ๊ฐ > Freeze Rotation
- Animator Controller ์ฐ๊ฒฐ
๐ถ Script ์์ฑ : ML-Agent ๊ธฐ๋ณธ ํจ์
- MummyAgent.cs ์์ฑ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents; // ML_Agents ๋ค์์คํ์ด์ค
/*
    Role of the agent
    1. Observe the surrounding environment (Observations)
    2. Act according to the policy (Action)
    3. Reward (Reward)
*/
public class MummyAgent : Agent
{
}
- Agent์ค๋ธ์ ํธ์ MummyAgent.cs ์ฐ๊ฒฐ
- Max Step : ํด๋น ๊ฐ๋งํผ ์์ง์์ผ๋ ๋ณด์์ด ์์ ๋ ์ฒ์๋ถํฐ ๋ค์ ์์ํจ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents; // ML_Agents ๋ค์์คํ์ด์ค
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/*
    Role of the agent
    1. Observe the surrounding environment (Observations)
    2. Act according to the policy (Action)
    3. Reward (Reward)
*/
public class MummyAgent : Agent
{
    // Initialization method.
    public override void Initialize()
    {
    }

    // Called every time an episode (the unit of training) begins.
    public override void OnEpisodeBegin()
    {
    }

    // Callback that observes the surrounding environment.
    public override void CollectObservations(VectorSensor sensor)
    {
    }

    // Executes an action based on the data received from the policy.
    public override void OnActionReceived(ActionBuffers actions)
    {
    }

    // For developer testing / imitation learning.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
    }
}
- 기본 함수들
🔶 Script 작성 :
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents; // ML_Agents ๋ค์์คํ์ด์ค
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/*
    Role of the agent
    1. Observe the surrounding environment (Observations)
    2. Act according to the policy (Action)
    3. Reward (Reward)
*/
public class MummyAgent : Agent
{
    // Cached component references, assigned in Initialize().
    private new Rigidbody rigidbody;
    private new Transform transform;
    // Sibling object named "Target", looked up under this agent's parent.
    private Transform targetTr;

    // Initialization method: caches the components this agent works with.
    public override void Initialize()
    {
        rigidbody = GetComponent<Rigidbody>();
        transform = GetComponent<Transform>();
        targetTr = transform.parent.Find("Target");
    }

    // Called every time an episode (the unit of training) begins.
    public override void OnEpisodeBegin()
    {
        // Reset physics forces.
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position.
        transform.localPosition = new Vector3(Random.Range(-4.0f, 4.0f), 0.05f, Random.Range(-4.0f, 4.0f));

        // Move the target to a random position.
        targetTr.localPosition = new Vector3(Random.Range(-4.0f, 4.0f), 0.55f, Random.Range(-4.0f, 4.0f));
    }

    // Callback that observes the surrounding environment.
    public override void CollectObservations(VectorSensor sensor)
    {
        // 8 values are observed in total.
        sensor.AddObservation(targetTr.localPosition);  // (x, y, z) 3 values
        sensor.AddObservation(transform.localPosition); // (x, y, z) 3 values
        sensor.AddObservation(rigidbody.velocity.x);    // (x) 1 value
        sensor.AddObservation(rigidbody.velocity.z);    // (z) 1 value
    }

    // Executes an action based on the data received from the policy.
    public override void OnActionReceived(ActionBuffers actions)
    {
    }

    // For developer testing / imitation learning.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
    }
}
- 관측하는 데이터 개수를 정확하게 알아야함
- Behavior Name 명시
- Space Size : 관측하는 값 개수
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents; // ML_Agents ๋ค์์คํ์ด์ค
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/*
    Role of the agent
    1. Observe the surrounding environment (Observations)
    2. Act according to the policy (Action)
    3. Reward (Reward)
*/
public class MummyAgent : Agent
{
    // Cached component references, assigned in Initialize().
    private new Rigidbody rigidbody;
    private new Transform transform;
    // Sibling object named "Target", looked up under this agent's parent.
    private Transform targetTr;

    // Initialization method: caches the components this agent works with.
    public override void Initialize()
    {
        rigidbody = GetComponent<Rigidbody>();
        transform = GetComponent<Transform>();
        targetTr = transform.parent.Find("Target");
    }

    // Called every time an episode (the unit of training) begins.
    public override void OnEpisodeBegin()
    {
        // Reset physics forces.
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position.
        transform.localPosition = new Vector3(Random.Range(-4.0f, 4.0f), 0.05f, Random.Range(-4.0f, 4.0f));

        // Move the target to a random position.
        targetTr.localPosition = new Vector3(Random.Range(-4.0f, 4.0f), 0.55f, Random.Range(-4.0f, 4.0f));
    }

    // Callback that observes the surrounding environment.
    public override void CollectObservations(VectorSensor sensor)
    {
        // 8 values are observed in total.
        sensor.AddObservation(targetTr.localPosition);  // (x, y, z) 3 values
        sensor.AddObservation(transform.localPosition); // (x, y, z) 3 values
        sensor.AddObservation(rigidbody.velocity.x);    // (x) 1 value
        sensor.AddObservation(rigidbody.velocity.z);    // (z) 1 value
    }

    // Executes an action based on the data received from the policy.
    public override void OnActionReceived(ActionBuffers actions)
    {
        // Continuous: actions.ContinuousActions
        // Discrete:   actions.DiscreteActions
        var action = actions.ContinuousActions;

        // [0] Up, Down
        // [1] Left, Right
        Vector3 dir = (Vector3.forward * action[0]) + (Vector3.right * action[1]);
        rigidbody.AddForce(dir.normalized * 50.0f);

        // Small negative reward set on every action, intended to discourage idling.
        SetReward(-0.001f);
    }

    // For developer testing / imitation learning.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
    }
}
- 연속적인 값 2개 : Up,Down / Left, Right
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents; // ML_Agents ๋ค์์คํ์ด์ค
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/*
    Role of the agent
    1. Observe the surrounding environment (Observations)
    2. Act according to the policy (Action)
    3. Reward (Reward)
*/
public class MummyAgent : Agent
{
    // Cached component references, assigned in Initialize().
    private new Rigidbody rigidbody;
    private new Transform transform;
    // Sibling object named "Target", looked up under this agent's parent.
    private Transform targetTr;

    // Initialization method: caches the components this agent works with.
    public override void Initialize()
    {
        rigidbody = GetComponent<Rigidbody>();
        transform = GetComponent<Transform>();
        targetTr = transform.parent.Find("Target");
    }

    // Called every time an episode (the unit of training) begins.
    public override void OnEpisodeBegin()
    {
        // Reset physics forces.
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position.
        transform.localPosition = new Vector3(Random.Range(-4.0f, 4.0f), 0.05f, Random.Range(-4.0f, 4.0f));

        // Move the target to a random position.
        targetTr.localPosition = new Vector3(Random.Range(-4.0f, 4.0f), 0.55f, Random.Range(-4.0f, 4.0f));
    }

    // Callback that observes the surrounding environment.
    public override void CollectObservations(VectorSensor sensor)
    {
        // 8 values are observed in total.
        sensor.AddObservation(targetTr.localPosition);  // (x, y, z) 3 values
        sensor.AddObservation(transform.localPosition); // (x, y, z) 3 values
        sensor.AddObservation(rigidbody.velocity.x);    // (x) 1 value
        sensor.AddObservation(rigidbody.velocity.z);    // (z) 1 value
    }

    // Executes an action based on the data received from the policy.
    public override void OnActionReceived(ActionBuffers actions)
    {
        // Continuous: actions.ContinuousActions
        // Discrete:   actions.DiscreteActions
        var action = actions.ContinuousActions;

        // [0] Up, Down
        // [1] Left, Right
        Vector3 dir = (Vector3.forward * action[0]) + (Vector3.right * action[1]);
        rigidbody.AddForce(dir.normalized * 50.0f);

        // Small negative reward set on every action, intended to discourage idling.
        SetReward(-0.001f);
    }

    // For developer testing / imitation learning:
    // fills the two continuous actions from the "Vertical" and "Horizontal" input axes.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.ContinuousActions;
        action[0] = Input.GetAxis("Vertical");
        action[1] = Input.GetAxis("Horizontal");
    }
}
- Decision Requester 컴포넌트 추가
// Decision Requester 컴포넌트 : 결정을 요청하는 요청자
// Decision Period : 5번에 한 번 결정을 요청함
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents; // ML_Agents ๋ค์์คํ์ด์ค
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/*
    Role of the agent
    1. Observe the surrounding environment (Observations)
    2. Act according to the policy (Action)
    3. Reward (Reward)
*/
public class MummyAgent : Agent
{
    // Cached component references, assigned in Initialize().
    private new Rigidbody rigidbody;
    private new Transform transform;
    // Sibling object named "Target", looked up under this agent's parent.
    private Transform targetTr;

    // Floor materials shown after hitting the target / the dead zone.
    public Material goodMAT;
    public Material badMAT;
    // Floor material captured in Initialize(), restored shortly after each episode starts.
    private Material originMAT;
    // Renderer of the sibling object named "Floor".
    private Renderer floorRd;

    // Initialization method: caches components and remembers the floor's starting material.
    public override void Initialize()
    {
        rigidbody = GetComponent<Rigidbody>();
        transform = GetComponent<Transform>();
        targetTr = transform.parent.Find("Target");
        floorRd = transform.parent.Find("Floor").GetComponent<Renderer>();
        originMAT = floorRd.material;
    }

    // Called every time an episode (the unit of training) begins.
    public override void OnEpisodeBegin()
    {
        // Reset physics forces.
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position.
        transform.localPosition = new Vector3(Random.Range(-4.0f, 4.0f), 0.05f, Random.Range(-4.0f, 4.0f));

        // Move the target to a random position.
        targetTr.localPosition = new Vector3(Random.Range(-4.0f, 4.0f), 0.55f, Random.Range(-4.0f, 4.0f));

        // Leave the result color of the previous episode visible briefly, then restore the floor.
        StartCoroutine(RevertMaterial());
    }

    // Waits 0.2 seconds and then puts the floor's starting material back.
    private IEnumerator RevertMaterial()
    {
        yield return new WaitForSeconds(0.2f);
        floorRd.material = originMAT;
    }

    // Callback that observes the surrounding environment.
    public override void CollectObservations(VectorSensor sensor)
    {
        // 8 values are observed in total.
        sensor.AddObservation(targetTr.localPosition);  // (x, y, z) 3 values
        sensor.AddObservation(transform.localPosition); // (x, y, z) 3 values
        sensor.AddObservation(rigidbody.velocity.x);    // (x) 1 value
        sensor.AddObservation(rigidbody.velocity.z);    // (z) 1 value
    }

    // Executes an action based on the data received from the policy.
    public override void OnActionReceived(ActionBuffers actions)
    {
        // Continuous: actions.ContinuousActions
        // Discrete:   actions.DiscreteActions
        var action = actions.ContinuousActions;

        // [0] Up, Down
        // [1] Left, Right
        Vector3 dir = (Vector3.forward * action[0]) + (Vector3.right * action[1]);
        rigidbody.AddForce(dir.normalized * 50.0f);

        // Small negative reward set on every action, intended to discourage idling.
        SetReward(-0.001f);
    }

    // For developer testing / imitation learning:
    // fills the two continuous actions from the "Vertical" and "Horizontal" input axes.
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.ContinuousActions;
        action[0] = Input.GetAxis("Vertical");
        action[1] = Input.GetAxis("Horizontal");
    }

    // Reward handling: colors the floor, sets the final reward and ends the episode
    // when the agent collides with a dead-zone wall or with the target.
    void OnCollisionEnter(Collision coll)
    {
        if (coll.collider.CompareTag("DEAD_ZONE"))
        {
            floorRd.material = badMAT;
            SetReward(-1.0f);
            EndEpisode(); // end the episode
        }
        else if (coll.collider.CompareTag("TARGET"))
        {
            floorRd.material = goodMAT;
            SetReward(1.0f);
            EndEpisode();
        }
    }
}
- ๋ฆฌ์๋ ์ถ๊ฐ
- ํธ๋ ์ด๋์ด ์ ๋๊ณ ์๋์ง ์๊ฐํ : ์ถฉ๋ ์ floor ์์ ๋ณํ
- Stage ํ๋ฆฌํนํ
- ์คํ ์ด์ง ๋ณต์ฌ
๐ถ ํธ๋ ์ด๋ ์ํค๊ธฐ
- 3DBall.yamlํ์ผ์ ๋ณต์ฌ > ์ด๋ฆ์ Mummy.yaml๋ก
- Mummy.yaml ํ์ผ์ด ์์ฑ๋์
- Mummy.yamlํ์ผ open
behaviors:
  Mummy:  # must match the Behavior Name in Unity's Behavior Parameters component
    trainer_type: ppo
    hyperparameters:
      batch_size: 64
      buffer_size: 12000
      learning_rate: 0.0003
      beta: 0.001
      epsilon: 0.2
      lambd: 0.99
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 128
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 300000  # training ends after 300000 steps
    time_horizon: 1000
    summary_freq: 10000  # write a summary once every 10000 steps
- Mummy.yaml
- behavior ์ด๋ฆ ์ฃผ์ํ๊ธฐ
- Mummy.yaml ํ์ผ์ด ์๋ ppoํด๋ ์์์ ์์ ์ฝ๋๋ฅผ ์คํ
- ์ ๋ํฐ์์๋ ์คํ
- 30000๋ฒ ๋๋ฉด ํ์ต ์ข ๋ฃ๋จ
๐ถ ํ์ต์ํจ ๋ชจ๋ธ ์ ์ฉ์ํค๊ธฐ
- ํ์ตํ ๋ชจ๋ธ(.onnx)์ ํ๋ก์ ํธ ์ฐฝ์ ๋์ด๋ค ๋๊ธฐ
- Agent์ model ํ๋ผ๋ฏธํฐ์ ํ์ตํ ๋ชจ๋ธ(.onnx) ๋์ด๋ค๋๊ธฐ
- apply to prefab
- playํด ๋ณด๋ฉด ์ ํ์ต๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณผ ์ ์์
'Unity > ML-Agents' 카테고리의 다른 글
06. ML-Agents - Soccer (1) (1) | 2021.08.02 |
---|---|
05. ML-Agents - Imitation Learning (0) | 2021.08.02 |
04. ML-Agents - Camera Sensor (0) | 2021.08.01 |
03. ML-Agents - Ray Perception Sensor 3D (0) | 2021.07.30 |
01. ML-Agents - 설치 및 간단한 실습 (1) | 2021.07.30 |