๐ท ์ค์ต 5 - ๋ชจ๋ฐฉํ์ต
- ๊ฐํํ์ต๋ง์ ์ฌ์ฉํ๋ฉด ํจ์จ์ด ์๋์ค๋ ๊ฒฝ์ฐ
- ์ฌ๋์ด ์๋ฒ๋ณด์ด๋ ๊ฒ์ ๋ ์ฝ๋ฉ์ ๋ ์, ๋ นํํ์ผ์ ์ค
- ์ปดํฌ๋ํธ : Demonstration Recorder
- Hint์ ์์๊ณผ ๊ฐ์ ์์์ ํ๋ธ๋ก ์ด๋
๐ท ํธ๋ ์ด๋ ํ๊ฒฝ ๊ตฌ์ถ
- ์ฌ ์์ฑ : MummyIL
- ๋น๊ฒ์์ค๋ธ์ ํธ : Stage
- ํ๋ธ : Floor
- Scale : 10, 0.1, 10
- ๋น๊ฒ์์ค๋ธ์ ํธ : Walls
- ํ๋ธ : Wall
- Scale : 10, 1, 1
- ๋ฒฝ์ธ์ฐ๊ธฐ
- Tag ๋ถ์ฌ : WALL
- agent ์ถ๊ฐ
- position : 0, 0.05, 0
- cube ์์ฑ : Hint
- position : 0, 0.55, 0
- cube : Black, Red, Green, Blue
- Scale : 2, 1, 2
- ๋ชจ์๋ฆฌ์ ์์น
- ํ๊ทธ ์ถ๊ฐ : BLACK, RED, BLUE, GREEN
- ๊ฐ ํ๋ธ์ ํด๋นํ๋ ํ๊ทธ ๋ถ์ฌ
- ์คํฌ๋ฆฝํธ ์์ฑ : StageManagerIL.cs
- Stage์ StageManagerIL ์คํฌ๋ฆฝํธ ์ถ๊ฐ
๐ถ StageManagerIL.cs
- Hint์ ์์์ ๋๋ค์ผ๋ก ๋ณ๊ฒฝ์ํด
- ์ด์ ์ ๋์๋ ์์์ ์๋์ค๋๋ก ์ธํ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Stage controller for the imitation-learning scene: recolours the child object
/// named "Hint" and records which colour was chosen.
/// </summary>
public class StageManagerIL : MonoBehaviour
{
    /// <summary>
    /// Selectable hint colours. The numeric value of each member is used as the
    /// index into <see cref="hintMt"/>, so both must be kept in the same order.
    /// </summary>
    public enum HINT_COLOR
    {
        BLACK,
        BLUE,
        GREEN,
        RED
    }

    // Colour currently assigned to the hint
    public HINT_COLOR hintColor = HINT_COLOR.BLACK;

    // Candidate materials for the Hint object; element i corresponds to HINT_COLOR value i
    public Material[] hintMt;

    // Renderer of the child object named "Hint" (resolved in Start)
    private Renderer hintRenderer;

    private void Start()
    {
        Transform hint = transform.Find("Hint");
        hintRenderer = hint.GetComponent<Renderer>();
    }

    /// <summary>
    /// Picks one entry of <see cref="hintMt"/> at random, applies it to the Hint
    /// renderer and stores the matching colour in <see cref="hintColor"/>.
    /// </summary>
    public void InitStage()
    {
        int pick = Random.Range(0, hintMt.Length);

        hintRenderer.material = hintMt[pick];

        // Remember the colour of the target the agent has to reach
        hintColor = (HINT_COLOR)pick;
    }
}
- Hint์ ์์ ๋ถ์ฌ
- ์์ ๋์ผํ๊ฒ ๋จธํฐ๋ฆฌ์ผ ๋ฃ์ด์ฃผ๊ธฐ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Stage controller for the imitation-learning scene: recolours the child object
/// named "Hint" and makes sure the same colour is not picked twice in a row.
/// </summary>
public class StageManagerIL : MonoBehaviour
{
    /// <summary>
    /// Selectable hint colours. The numeric value of each member is used as the
    /// index into <see cref="hintMt"/>, so both must be kept in the same order.
    /// </summary>
    public enum HINT_COLOR
    {
        BLACK, BLUE, GREEN, RED
    }

    // Colour currently assigned to the hint
    public HINT_COLOR hintColor = HINT_COLOR.BLACK;
    // Candidate materials for the Hint object; element i corresponds to HINT_COLOR value i
    public Material[] hintMt;
    // Renderer of the child object named "Hint" (resolved in Start)
    private new Renderer renderer;
    // Index chosen by the previous InitStage() call (-1 until the first call)
    private int prevTag = -1;

    void Start()
    {
        renderer = transform.Find("Hint").GetComponent<Renderer>();
    }

    /// <summary>
    /// Picks a random entry of <see cref="hintMt"/> that differs from the previous
    /// pick, applies it to the Hint renderer and stores the matching colour in
    /// <see cref="hintColor"/>.
    /// </summary>
    public void InitStage()
    {
        int idx;
        // Re-roll while the previous index comes up again. With a single material
        // no other index exists, so the retry is skipped; otherwise this loop
        // would never terminate.
        do
        {
            idx = Random.Range(0, hintMt.Length);
        } while (hintMt.Length > 1 && idx == prevTag);

        prevTag = idx;
        renderer.material = hintMt[idx];

        // Remember the colour of the target the agent has to reach
        hintColor = (HINT_COLOR)idx;
    }

    // Manual check: pressing mouse button 0 re-rolls the hint colour
    private void Update()
    {
        if (Input.GetMouseButtonDown(0))
        {
            InitStage();
        }
    }
}
- 같은 색상으로 중복되지 않도록 세팅
- Tag 추가 : HINT_BLACK, HINT_RED, HINT_BLUE, HINT_GREEN
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Stage controller for the imitation-learning scene: recolours the child object
/// named "Hint" and makes sure the same colour is not picked twice in a row.
/// </summary>
public class StageManagerIL : MonoBehaviour
{
    /// <summary>
    /// Selectable hint colours. The numeric value of each member is used as the
    /// index into <see cref="hintMt"/>, so both must be kept in the same order.
    /// </summary>
    public enum HINT_COLOR
    {
        BLACK, BLUE, GREEN, RED
    }

    // Colour currently assigned to the hint
    public HINT_COLOR hintColor = HINT_COLOR.BLACK;
    // Candidate materials for the Hint object; element i corresponds to HINT_COLOR value i
    public Material[] hintMt;
    // Tag names for the Hint object; declared here but not read by this version of the class
    public string[] hintTag;
    // Renderer of the child object named "Hint" (resolved in Start)
    private new Renderer renderer;
    // Index chosen by the previous InitStage() call (-1 until the first call)
    private int prevTag = -1;

    void Start()
    {
        renderer = transform.Find("Hint").GetComponent<Renderer>();
    }

    /// <summary>
    /// Picks a random entry of <see cref="hintMt"/> that differs from the previous
    /// pick, applies it to the Hint renderer and stores the matching colour in
    /// <see cref="hintColor"/>.
    /// </summary>
    public void InitStage()
    {
        int idx;
        // Re-roll while the previous index comes up again. With a single material
        // no other index exists, so the retry is skipped; otherwise this loop
        // would never terminate.
        do
        {
            idx = Random.Range(0, hintMt.Length);
        } while (hintMt.Length > 1 && idx == prevTag);

        prevTag = idx;
        renderer.material = hintMt[idx];

        // Remember the colour of the target the agent has to reach
        hintColor = (HINT_COLOR)idx;
    }

    // Manual check: pressing mouse button 0 re-rolls the hint colour
    private void Update()
    {
        if (Input.GetMouseButtonDown(0))
        {
            InitStage();
        }
    }
}
- public string[] hintTag; ์ถ๊ฐ
- ์์ ๋จธํฐ๋ฆฌ์ผ๊ณผ ๋์ผํ ์์๋ก Tag๊ฐ ์ ๋ ฅ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Stage controller for the imitation-learning scene: recolours and re-tags the
/// child object named "Hint" and makes sure the same colour is not picked twice
/// in a row.
/// </summary>
public class StageManagerIL : MonoBehaviour
{
    /// <summary>
    /// Selectable hint colours. The numeric value of each member is used as the
    /// index into <see cref="hintMt"/> and <see cref="hintTag"/>, so all three
    /// must be kept in the same order.
    /// </summary>
    public enum HINT_COLOR
    {
        BLACK, BLUE, GREEN, RED
    }

    // Colour currently assigned to the hint
    public HINT_COLOR hintColor = HINT_COLOR.BLACK;
    // Candidate materials for the Hint object; element i corresponds to HINT_COLOR value i
    public Material[] hintMt;
    // Tag applied to the Hint object for each material; indexed with the same value as hintMt
    public string[] hintTag;
    // Renderer of the child object named "Hint" (resolved in Start)
    private new Renderer renderer;
    // Index chosen by the previous InitStage() call (-1 until the first call)
    private int prevTag = -1;

    void Start()
    {
        renderer = transform.Find("Hint").GetComponent<Renderer>();
    }

    /// <summary>
    /// Picks a random entry of <see cref="hintMt"/> that differs from the previous
    /// pick, applies the material and the matching tag to the Hint object and
    /// stores the matching colour in <see cref="hintColor"/>.
    /// </summary>
    public void InitStage()
    {
        int idx;
        // Re-roll while the previous index comes up again. With a single material
        // no other index exists, so the retry is skipped; otherwise this loop
        // would never terminate.
        do
        {
            idx = Random.Range(0, hintMt.Length);
        } while (hintMt.Length > 1 && idx == prevTag);

        prevTag = idx;

        // Swap the material of the Hint object
        renderer.material = hintMt[idx];

        // Assign the tag that belongs to the chosen material
        renderer.gameObject.tag = hintTag[idx];

        // Remember the colour of the target the agent has to reach
        hintColor = (HINT_COLOR)idx;
    }

    // Manual check: pressing mouse button 0 re-rolls the hint colour
    private void Update()
    {
        if (Input.GetMouseButtonDown(0))
        {
            InitStage();
        }
    }
}
- Hint์ ์์๊ณผ ๋์ผํ Tag ๋ถ์ฌ
๐ถ Agent ์ธํ
- Capsule Collider ์ถ๊ฐ
- Center : 0, 0.5, 0
- Radius : 0.3
- Rigidbody ์ถ๊ฐ
- Collision Detection : Continuous
- Freeze Rotation
- animator controller ์ฐ๊ฒฐ
- MummyILAgent.cs ์์ฑ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// Skeleton agent for the imitation-learning scene. At this stage it only
/// re-initialises the stage at the beginning of every episode; observation,
/// action and heuristic handling are still empty.
/// </summary>
public class MummyILAgent : Agent
{
    // StageManagerIL component found on this agent's parent object
    private StageManagerIL stageManager;

    /// <summary>Looks up the StageManagerIL on the parent transform.</summary>
    public override void Initialize()
    {
        Transform stage = transform.parent;
        stageManager = stage.GetComponent<StageManagerIL>();
    }

    /// <summary>Asks the stage to set itself up again for the new episode.</summary>
    public override void OnEpisodeBegin() => stageManager.InitStage();

    // No vector observations are collected yet
    public override void CollectObservations(VectorSensor sensor) { }

    // Received actions are ignored for now
    public override void OnActionReceived(ActionBuffers actions) { }

    // No manual control yet
    public override void Heuristic(in ActionBuffers actionsOut) { }
}
- ๊ธฐ๋ณธ ํจ์ ์์ฑ
- ์ํผ์๋๊ฐ ์์ํ ๋๋ง๋ค InitStage() ํธ์ถ
- Agent์ MummyILAgent ์คํฌ๋ฆฝํธ ์ถ๊ฐ
- Max step : 10 : hint์ ์์๊ณผ ํ๊ทธ๊ฐ ๋ณํ๋์ง ํ์ธ
- ๋ค์ 0์ผ๋ก ๋๋๋ ค๋๊ธฐ
๐ถ MummyILAgent.cs : ์ด๋๋ก์ง
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// Agent driven by two discrete action branches: branch 0 moves it forward or
/// backward, branch 1 turns it left or right. A small penalty is added on every
/// received action.
/// </summary>
public class MummyILAgent : Agent
{
    // Cached component references (they hide the inherited members of the same name)
    private new Transform transform;
    private new Rigidbody rigidbody;

    // Magnitude of the velocity change applied per move action
    public float moveSpeed = 1.5f;
    // Multiplied by Time.fixedDeltaTime to get the rotation angle per turn action
    public float turnSpeed = 200.0f;

    // StageManagerIL component found on the parent object
    private StageManagerIL stageManager;

    // Floor renderer and its initial material (stored here, not used further in this version)
    private Renderer floorRd;
    private Material originMt;
    public Material goodMt, badMt;

    /// <summary>Sets MaxStep and caches the references this agent works with.</summary>
    public override void Initialize()
    {
        MaxStep = 2000;

        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();

        Transform stage = transform.parent;
        floorRd = stage.Find("Floor").GetComponent<Renderer>();
        originMt = floorRd.material;
        stageManager = stage.GetComponent<StageManagerIL>();
    }

    /// <summary>Re-initialises the stage and puts the agent back at its start pose.</summary>
    public override void OnEpisodeBegin()
    {
        stageManager.InitStage();

        // Clear any remaining physics motion
        rigidbody.angularVelocity = Vector3.zero;
        rigidbody.velocity = Vector3.zero;

        // Reset the agent's position and rotation
        transform.localPosition = new Vector3(0, 0.0f, -2.5f);
        transform.localRotation = Quaternion.identity;
    }

    // No vector observations are collected
    public override void CollectObservations(VectorSensor sensor)
    {
    }

    /// <summary>Applies the two discrete branches as movement and rotation.</summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        int move = action[0];
        int turn = action[1];

        // Branch 0 : stop / forward / backward
        Vector3 dir = Vector3.zero;
        if (move == 1)
        {
            dir = transform.forward;
        }
        else if (move == 2)
        {
            dir = -transform.forward;
        }

        // Branch 1 : stop / turn left / turn right
        Vector3 rot = Vector3.zero;
        if (turn == 1)
        {
            rot = -transform.up;
        }
        else if (turn == 2)
        {
            rot = transform.up;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Small negative reward per received action.
        // Since imitation learning will be used, this penalty is not strictly needed.
        AddReward(-1 / (float)MaxStep);
    }

    /// <summary>Keyboard control: W/S write branch 0, A/D write branch 1.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        actionsOut.Clear();

        // S takes precedence over W when both are held
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }

        // D takes precedence over A when both are held
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
    }
}
- ์ด๋ ์ฒ๋ฆฌ
- Space Size : 0
- Discrete Branches : 2
- Branch 0 : 3
- Branch 1 : 3
- ์ปดํฌ๋ํธ ์ถ๊ฐ : Decision Requester
๐ถ ๋ฆฌ์๋ ๋ถ์ฌ ๋ฐ floor ์์ ๋ณ๊ฒฝ
- ๋จธํฐ๋ฆฌ์ผ ์ฐ๊ฒฐ
- ์ปดํฌ๋ํธ ์ถ๊ฐ : Ray Perception Sensor 3D
- Ray๋ก ์ธ์งํด์ผํ ๋์์ ์ ๋ ฅ
- Hint์ 4๊ฐ์ ๋ธ๋ญ์ ์ธ์งํ ์ ์๋๋ก ์กฐ์
- Max Ray Degrees : 90
- Start Vertical Offset : 0.5
- End Vertical Offset : 0.5
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// Agent driven by two discrete action branches (move and turn) that is rewarded
/// for touching the block whose tag matches the stage's current hint colour and
/// penalised for touching anything else. The floor briefly changes material to
/// show the outcome of an episode.
/// </summary>
public class MummyILAgent : Agent
{
    // Cached component references (they hide the inherited members of the same name)
    private new Transform transform;
    private new Rigidbody rigidbody;

    // Magnitude of the velocity change applied per move action
    public float moveSpeed = 1.5f;
    // Multiplied by Time.fixedDeltaTime to get the rotation angle per turn action
    public float turnSpeed = 200.0f;

    // StageManagerIL component found on the parent object
    private StageManagerIL stageManager;

    // Floor renderer, its initial material, and the materials flashed on success / failure
    private Renderer floorRd;
    private Material originMt;
    public Material goodMt, badMt;

    /// <summary>Sets MaxStep and caches the references this agent works with.</summary>
    public override void Initialize()
    {
        MaxStep = 2000;
        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
        floorRd = transform.parent.Find("Floor").GetComponent<Renderer>();
        originMt = floorRd.material;
        stageManager = transform.parent.GetComponent<StageManagerIL>();
    }

    /// <summary>Re-initialises the stage and puts the agent back at its start pose.</summary>
    public override void OnEpisodeBegin()
    {
        stageManager.InitStage();

        // Clear any remaining physics motion
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;

        // Reset the agent's position and rotation
        transform.localPosition = new Vector3(0, 0.0f, -3.5f);
        transform.localRotation = Quaternion.identity;
    }

    // No vector observations are collected
    public override void CollectObservations(VectorSensor sensor)
    {
    }

    /// <summary>Applies the two discrete branches as movement and rotation.</summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        Vector3 dir = Vector3.zero;
        Vector3 rot = Vector3.zero;

        // Branch 0 : stop / forward / backward
        switch (action[0])
        {
            case 1: dir = transform.forward; break;
            case 2: dir = -transform.forward; break;
        }

        // Branch 1 : stop / turn left / turn right
        switch (action[1])
        {
            case 1: rot = -transform.up; break;
            case 2: rot = transform.up; break;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Small negative reward per received action.
        // Since imitation learning will be used, this penalty is not strictly needed.
        AddReward(-1 / (float)MaxStep);
    }

    /// <summary>Keyboard control: W/S write branch 0, A/D write branch 1.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        actionsOut.Clear();

        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
    }

    /// <summary>
    /// Rewards or penalises the agent depending on what it collided with.
    /// </summary>
    private void OnCollisionEnter(Collision coll)
    {
        // Contact with the floor is not a decision of the agent. Without this guard
        // it would fall into the "wrong target" branch below and end the episode
        // with -1 as soon as the agent touches the ground.
        if (coll.gameObject.name == "Floor") return;

        if (coll.collider.tag == stageManager.hintColor.ToString())
        {
            // Touched the block whose tag equals the current hint colour
            SetReward(+1.0f);
            EndEpisode();
            StartCoroutine(ReverMaterial(goodMt));
        }
        else
        {
            if (coll.collider.CompareTag("WALL") || coll.gameObject.name == "Hint")
            {
                // Bumping into a wall or the hint only costs a small penalty
                AddReward(-0.05f);
            }
            else
            {
                // Anything else is handled as a wrong target
                SetReward(-1.0f);
                EndEpisode();
                StartCoroutine(ReverMaterial(badMt));
            }
        }
    }

    // Shows changeMt on the floor for 0.2 seconds, then restores the original material
    IEnumerator ReverMaterial(Material changeMt)
    {
        floorRd.material = changeMt;
        yield return new WaitForSeconds(0.2f);
        floorRd.material = originMt;
    }
}
- 리워드 부여, floor 색상 변경
🔷 레코딩
🔶 코드 수정
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// Version of the agent used for recording demonstrations: the per-action penalty
/// is disabled and touching a wall or the hint sets (rather than adds) a small
/// negative reward.
/// </summary>
public class MummyILAgent : Agent
{
    // Cached component references (they hide the inherited members of the same name)
    private new Transform transform;
    private new Rigidbody rigidbody;

    // Magnitude of the velocity change applied per move action
    public float moveSpeed = 1.5f;
    // Multiplied by Time.fixedDeltaTime to get the rotation angle per turn action
    public float turnSpeed = 200.0f;

    // StageManagerIL component found on the parent object
    private StageManagerIL stageManager;

    // Floor renderer, its initial material, and the materials flashed on success / failure
    private Renderer floorRd;
    private Material originMt;
    public Material goodMt, badMt;

    /// <summary>Sets MaxStep and caches the references this agent works with.</summary>
    public override void Initialize()
    {
        MaxStep = 2000;

        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();

        Transform stage = transform.parent;
        floorRd = stage.Find("Floor").GetComponent<Renderer>();
        originMt = floorRd.material;
        stageManager = stage.GetComponent<StageManagerIL>();
    }

    /// <summary>Re-initialises the stage and puts the agent back at its start pose.</summary>
    public override void OnEpisodeBegin()
    {
        stageManager.InitStage();

        // Clear any remaining physics motion
        rigidbody.angularVelocity = Vector3.zero;
        rigidbody.velocity = Vector3.zero;

        // Reset the agent's position and rotation
        transform.localPosition = new Vector3(0, 0.0f, -3.5f);
        transform.localRotation = Quaternion.identity;
    }

    // No vector observations are collected
    public override void CollectObservations(VectorSensor sensor)
    {
    }

    /// <summary>Applies the two discrete branches as movement and rotation.</summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        int move = action[0];
        int turn = action[1];

        // Branch 0 : stop / forward / backward
        Vector3 dir = Vector3.zero;
        if (move == 1)
        {
            dir = transform.forward;
        }
        else if (move == 2)
        {
            dir = -transform.forward;
        }

        // Branch 1 : stop / turn left / turn right
        Vector3 rot = Vector3.zero;
        if (turn == 1)
        {
            rot = -transform.up;
        }
        else if (turn == 2)
        {
            rot = transform.up;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Per-action penalty, disabled here:
        // since imitation learning is used, the penalty does not have to be given.
        // AddReward(-1 / (float)MaxStep);
    }

    /// <summary>Keyboard control: W/S write branch 0, A/D write branch 1.</summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        actionsOut.Clear();

        // S takes precedence over W when both are held
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }

        // D takes precedence over A when both are held
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
    }

    /// <summary>
    /// Rewards or penalises the agent depending on what it collided with.
    /// Collisions with the object named "Floor" are ignored.
    /// </summary>
    private void OnCollisionEnter(Collision coll)
    {
        if (coll.gameObject.name == "Floor")
        {
            return;
        }

        // Touched the block whose tag equals the current hint colour
        if (coll.collider.tag == stageManager.hintColor.ToString())
        {
            FinishEpisode(+1.0f, goodMt);
            return;
        }

        // Bumping into a wall or the hint only sets a small penalty
        if (coll.collider.CompareTag("WALL") || coll.gameObject.name == "Hint")
        {
            SetReward(-0.05f);
            return;
        }

        // Anything else is handled as a wrong target
        FinishEpisode(-1.0f, badMt);
    }

    // Sets the final reward, ends the episode and flashes the floor with the given material
    private void FinishEpisode(float reward, Material flashMt)
    {
        SetReward(reward);
        EndEpisode();
        StartCoroutine(ReverMaterial(flashMt));
    }

    // Shows changeMt on the floor for 0.2 seconds, then restores the original material
    IEnumerator ReverMaterial(Material changeMt)
    {
        floorRd.material = changeMt;
        yield return new WaitForSeconds(0.2f);
        floorRd.material = originMt;
    }
}
- AddReward(-1 / (float)MaxStep) : ์ฃผ์์ฒ๋ฆฌ
- ๋ฒฝ๊ณผ Hint์ ๋ถ๋ชํ์ ๋ : AddReward(-0.05f) -> SetReward(-0.05f) : ๋ณ๊ฒฝ
🔶 레코딩 세팅
- 폴더 생성 : MummyIL
- Stage 프리팹화
- 컴포넌트 추가 : Demonstration Recorder
- 폴더 추가 : Demo : 레코딩한 파일을 저장할 폴더
- Record : 체크
- Num Steps To Record : 0 : 몇 스텝까지 레코딩할지 : 0이면 무제한
- Name : MummyIL : 저장할 파일명
- Directory : Assets/Demo : 파일을 저장할 위치
- 플레이
- 될 수 있으면 실수하면 안됨
- 레코딩 후 Record 언체크
- Apply All
๐ท ํธ๋ ์ด๋
- ์คํ ์ด์ง ๋ณต์ฌ
- PushBlock.yaml ๋ณต์ฌ
- behavior ์ด๋ฆ ๋ณ๊ฒฝ
- demo_path : MummyIL.demo
- max_steps : 30๋ง๋ฒ
- recordingํ ํ์ผ์ imitation ํด๋๋ก ๊ฐ์ ธ์ด
- ํธ๋ ์ด๋ ์์
- ์ ๋ํฐ์์ play
🔶 트레이닝 summary
🔶 트레이닝 결과 확인해보기
- 학습된 모델(.onnx)을 프로젝트창에 끌어다놓기
- 모델 연결하기
- Apply to Prefab
- Play
'Unity > ML-Agents' 카테고리의 다른 글
07. ML-Agents - Soccer (2) (0) | 2021.08.03 |
---|---|
06. ML-Agents - Soccer (1) (1) | 2021.08.02 |
04. ML-Agents - Camera Sensor (0) | 2021.08.01 |
03. ML-Agents - Ray Perception Sensor 3D (0) | 2021.07.30 |
02. ML-Agents - position,rigidbody 관측 (0) | 2021.07.30 |