๐ท ์ค์ต 3
- bad item์ ๋ถ๋ชํ๋ฉด -1, good item์ ๋ถ๋ชํ๋ฉด +1
- Raycast ๋ฐ์ฌํ์ฌ ๊ด์ธก
- good item / bad item
- ์ด์ฐ๊ฐ DiscreteActions
🔷 트레이닝 환경 구축
🔶 기본세팅
- 씬 생성(MummyRay)
- 빈 게임오브젝트 (Stage)
- Cube (Floor)
- Scale : 50, 0.1, 50
- 빈 게임오브젝트 (Walls)
- Cube (Wall)
- Scale : 50, 3, 1
- Agent 추가
- Position : 0, 0.05, 0
https://assetstore.unity.com/packages/3d/characters/creatures/rpg-monster-duo-pbr-polyart-157762
- RPG Monster Duo PBR Polyart ๋ค์ด > ์ํฌํธ
- Slime, TurtleShell ์ถ๊ฐ
- Slime, TurtleShell -> GoodItem, BadItem์ผ๋ก ์ด๋ฆ ๋ณ๊ฒฝ
- Sphere Collider ์ถ๊ฐ
- Center : 0, 0.5, 0
- Radius : 0.7
- ํ๊ทธ ์ถ๊ฐ ํ ๊ฐ๊ฐ์ ์ค๋ธ์ ํธ์ ํ๊ทธ ๋ฌ๊ธฐ
- MummyRayAgent.cs ์์ฑ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// Skeleton of the MummyRay agent: every ML-Agents override is declared
/// but intentionally left empty at this stage of the tutorial.
/// </summary>
public class MummyRayAgent : Agent
{
    /// <summary>Setup hook; empty in this skeleton.</summary>
    public override void Initialize() { }

    /// <summary>Episode-start hook; empty in this skeleton.</summary>
    public override void OnEpisodeBegin() { }

    /// <summary>Vector-observation hook; empty in this skeleton.</summary>
    public override void CollectObservations(VectorSensor sensor) { }

    /// <summary>Action hook; empty in this skeleton.</summary>
    public override void OnActionReceived(ActionBuffers actions) { }

    /// <summary>Manual-control hook; empty in this skeleton.</summary>
    public override void Heuristic(in ActionBuffers actionsOut) { }
}
- ๊ธฐ๋ณธ ํจ์ ์์ฑ
- MummyRayAgent.cs ๋ฃ๊ธฐ
- Behavior Name : MummyRay
- ์ปดํฌ๋ํธ ์ถ๊ฐ : Ray Perception Sensor 3D
- ๋ถ์ฑ๊ผด๋ชจ์์ ๋ ์ด๋ฅผ ์จ
- Ray์ ์์์ ์กฐ์
- Start Vertical Offset : 0.5
- End Vertical Offset : 0.5
- Detectable Tags : ๋ ์ด๋ก ๊ฒ์ถํ๋ ๋์
- Tag 3๊ฐ ์ถ๊ฐ
- Rays per Direction : 4 : ๋ช ๊ฐ์ ๋ ์ด๋ฅผ ์ ๊ฒ์ธ์ง
- Max Ray Degrees : 80 : ๋ ์ด ์ฌ์ด์ ๊ฐ๋
- Ray Length : 30 : ๋ ์ด์ ๊ธธ์ด
๐ถ Agent ์์น, ํ์ ๊ฐ ๋๋ค์ผ๋ก ์์ฑ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// MummyRay agent, step 2: every episode starts with the agent at a random
/// position and heading inside its stage.
/// </summary>
public class MummyRayAgent : Agent
{
    // Cached components; `new` hides the inherited members of the same name.
    private new Transform transform;
    private new Rigidbody rigidbody;

    // Movement tuning (not applied yet at this step).
    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;

    public override void Initialize()
    {
        // Step budget per episode (one training unit).
        MaxStep = 5000;
        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        // Stage reset: nothing to reset yet at this step.

        // Clear any leftover physics motion.
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
        // Random position, local to the parent stage.
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));
        // Random heading around the Y axis.
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));
    }

    public override void CollectObservations(VectorSensor sensor)
    {
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
    }
}
- Agent ์์น, ํ์ ๊ฐ ๋๋ค์ผ๋ก ์์ฑ
- ์ปดํฌ๋ํธ ์ถ๊ฐ : Rigidbody, Capsule Collider
- ๊ฐ ์ปดํฌ๋ํธ ์กฐ์
๐ถ good item, bad item ์์ฑ
- GoodItem, BadItem ํ๋ฆฌํนํ
- GoodItem, BadItem์ ๊ฐ๊ฐ Animator Controller ์ฒ๋ฆฌ
- Idle๋ง ํ๋๋ก Transition์ ๋์ด์ค
- GoodItem, BadItem ํ์ด์ด๋ผํค ์ฐฝ์์ ์ญ์
- StageManager.cs ์์ฑ
- Stage์ค๋ธ์ ํธ์ StageManager.cs ๋ฃ๊ธฐ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Scatters the good and bad items of one stage as children of this object.
/// </summary>
public class StageManager : MonoBehaviour
{
    public GameObject goodItem;
    public GameObject badItem;
    public int goodItemCount = 30;
    public int badItemCount = 20;
    public List<GameObject> goodList = new List<GameObject>();
    public List<GameObject> badList = new List<GameObject>();

    void Start()
    {
        SetStageObject();
    }

    /// <summary>
    /// Spawns goodItemCount copies of goodItem, then badItemCount copies of badItem,
    /// at random offsets/headings and records each instance in goodList / badList.
    /// </summary>
    public void SetStageObject()
    {
        SpawnInto(goodList, goodItem, goodItemCount);
        SpawnInto(badList, badItem, badItemCount);
    }

    // Instantiates `count` copies of `prefab` under this stage and appends them to `target`.
    private void SpawnInto(List<GameObject> target, GameObject prefab, int count)
    {
        for (int n = 0; n < count; n++)
        {
            Vector3 offset = new Vector3(Random.Range(-23.0f, 23.0f), 0.05f, Random.Range(-23.0f, 23.0f));
            Quaternion heading = Quaternion.Euler(Vector3.up * Random.Range(0, 360));
            GameObject spawned = Instantiate(prefab, transform.position + offset, heading, transform);
            target.Add(spawned);
        }
    }
}
- GoodItem, BadItem 총 50개 생성 (goodItemCount 30 + badItemCount 20)
- GoodItem, BadItem 프리팹 연결
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Owns the good and bad items of one stage: removes the previous set and
/// scatters a fresh set as children of this object.
/// </summary>
public class StageManager : MonoBehaviour
{
    public GameObject goodItem;
    public GameObject badItem;
    public int goodItemCount = 30;
    public int badItemCount = 20;
    // Half-width of the square (X/Z offset from the stage position) in which items are scattered.
    public float spawnRange = 23.0f;
    // Y offset from the stage position at which items are placed.
    public float spawnHeight = 0.05f;
    public List<GameObject> goodList = new List<GameObject>();
    public List<GameObject> badList = new List<GameObject>();

    void Start()
    {
        SetStageObject();
    }

    /// <summary>
    /// Destroys all previously spawned items, then spawns goodItemCount good items
    /// followed by badItemCount bad items at random offsets and headings.
    /// </summary>
    public void SetStageObject()
    {
        ClearItems(goodList);
        ClearItems(badList);
        SpawnItems(goodList, goodItem, goodItemCount);
        SpawnItems(badList, badItem, badItemCount);
    }

    // Destroys every still-alive item in `items` and empties the list.
    private static void ClearItems(List<GameObject> items)
    {
        foreach (GameObject item in items)
        {
            // Unity's overloaded == reports destroyed objects (e.g. items the agent
            // already collected) as null; those need no further Destroy call.
            if (item != null)
            {
                Destroy(item);
            }
        }
        items.Clear();
    }

    // Instantiates `count` copies of `prefab` under this stage and records them in `items`.
    private void SpawnItems(List<GameObject> items, GameObject prefab, int count)
    {
        for (int i = 0; i < count; i++)
        {
            Vector3 pos = new Vector3(Random.Range(-spawnRange, spawnRange), spawnHeight, Random.Range(-spawnRange, spawnRange));
            Quaternion rot = Quaternion.Euler(Vector3.up * Random.Range(0, 360));
            items.Add(Instantiate(prefab, transform.position + pos, rot, transform));
        }
    }
}
- good item, bad item 생성
🔶 Agent 이동
- Space Size : 0 : 벡터로 관측하는 값 없음 (관측은 Ray Perception Sensor 3D가 담당)
- Discrete Branches : 2 : 이산 행동 브랜치 수 (해당 프로젝트에서는 이동, 회전)
- Branch 0 size : 3 : (정지, 전진, 후진)
- Branch 1 size : 3 : (정지, 좌회전, 우회전)
- 컴포넌트 추가 : Decision Requester
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// MummyRay agent, step 3: two discrete action branches are wired up and the
/// chosen values are only logged; the agent does not move yet.
/// </summary>
public class MummyRayAgent : Agent
{
    // Cached components; `new` hides the inherited members of the same name.
    private new Transform transform;
    private new Rigidbody rigidbody;

    // Movement tuning (not applied yet at this step).
    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;

    public override void Initialize()
    {
        // Step budget per episode (one training unit).
        MaxStep = 5000;
        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        // Stage reset: nothing to reset yet at this step.

        // Clear any leftover physics motion.
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
        // Random position, local to the parent stage.
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));
        // Random heading around the Y axis.
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // No vector observations; observations come from the Ray Perception Sensor 3D component.
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");
    }

    /// <summary>
    /// Keyboard control for testing. Discrete actions are integer indices per branch:
    /// branch 0: 0 = stop, 1 = forward (W), 2 = backward (S);
    /// branch 1: 0 = no turn, 1 = turn left (A), 2 = turn right (D).
    /// If both keys of a branch are held, the later check (S or D) wins.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        // Start from "no action" (index 0) on every branch.
        actionsOut.Clear();

        // Branch 0: movement.
        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }

        // Branch 1: rotation.
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
    }
}
- ์์ดํ ์ด ํฉ๋ฟ๋ ค์ง๊ณ , Agent๊ฐ ์์ง์ด์ง๋ ์๋ ์ํ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Owns the good and bad items of one stage: removes the previous set and
/// scatters a fresh set as children of this object. SetStageObject is invoked
/// externally (no Start()), so the stage is rebuilt on demand.
/// </summary>
public class StageManager : MonoBehaviour
{
    public GameObject goodItem;
    public GameObject badItem;
    public int goodItemCount = 30;
    public int badItemCount = 20;
    // Half-width of the square (X/Z offset from the stage position) in which items are scattered.
    public float spawnRange = 23.0f;
    // Y offset from the stage position at which items are placed.
    public float spawnHeight = 0.05f;
    public List<GameObject> goodList = new List<GameObject>();
    public List<GameObject> badList = new List<GameObject>();

    /// <summary>
    /// Destroys all previously spawned items, then spawns goodItemCount good items
    /// followed by badItemCount bad items at random offsets and headings.
    /// </summary>
    public void SetStageObject()
    {
        ClearItems(goodList);
        ClearItems(badList);
        SpawnItems(goodList, goodItem, goodItemCount);
        SpawnItems(badList, badItem, badItemCount);
    }

    // Destroys every still-alive item in `items` and empties the list.
    private static void ClearItems(List<GameObject> items)
    {
        foreach (GameObject item in items)
        {
            // Unity's overloaded == reports destroyed objects (e.g. items the agent
            // already collected) as null; those need no further Destroy call.
            if (item != null)
            {
                Destroy(item);
            }
        }
        items.Clear();
    }

    // Instantiates `count` copies of `prefab` under this stage and records them in `items`.
    private void SpawnItems(List<GameObject> items, GameObject prefab, int count)
    {
        for (int i = 0; i < count; i++)
        {
            Vector3 pos = new Vector3(Random.Range(-spawnRange, spawnRange), spawnHeight, Random.Range(-spawnRange, spawnRange));
            Quaternion rot = Quaternion.Euler(Vector3.up * Random.Range(0, 360));
            items.Add(Instantiate(prefab, transform.position + pos, rot, transform));
        }
    }
}
- Start() ์ญ์ : SetStageObject() ํจ์๋ฅผ ํธ์ถํ์์
- MummyRayAgent.cs์์ ํธ์ถํ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ ์ญ์
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// MummyRay agent, step 4: each episode rebuilds the stage and randomizes the
/// agent; two discrete action branches drive movement and turning.
/// </summary>
public class MummyRayAgent : Agent
{
    // Cached components; `new` hides the inherited members of the same name.
    private new Transform transform;
    private new Rigidbody rigidbody;

    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;

    // StageManager on the parent stage object.
    private StageManager stageManager;

    public override void Initialize()
    {
        // Step budget per episode (one training unit).
        MaxStep = 5000;
        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
        stageManager = transform.parent.GetComponent<StageManager>();
    }

    public override void OnEpisodeBegin()
    {
        // Reset the stage: destroy the old items and spawn a fresh set.
        stageManager.SetStageObject();
        // Clear any leftover physics motion.
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
        // Random position, local to the parent stage.
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));
        // Random heading around the Y axis.
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // No vector observations; observations come from the Ray Perception Sensor 3D component.
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        // Logs every decision: handy while testing, noisy during training.
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        // Branch 0: 0 = stop, 1 = forward, 2 = backward.
        Vector3 dir = Vector3.zero;
        switch (action[0])
        {
            case 1: dir = transform.forward; break;
            case 2: dir = -transform.forward; break;
        }

        // Branch 1: 0 = no turn, 1 = turn left, 2 = turn right.
        // Rotate() defaults to Space.Self, so the axis must be a LOCAL vector.
        // The world-space transform.up is only correct while the agent is perfectly upright.
        Vector3 rot = Vector3.zero;
        switch (action[1])
        {
            case 1: rot = Vector3.down; break;
            case 2: rot = Vector3.up; break;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Small per-step penalty so that standing still is discouraged.
        AddReward(-1 / (float)MaxStep);
    }

    /// <summary>
    /// Keyboard control for testing. Discrete actions are integer indices per branch:
    /// branch 0: 0 = stop, 1 = forward (W), 2 = backward (S);
    /// branch 1: 0 = no turn, 1 = turn left (A), 2 = turn right (D).
    /// If both keys of a branch are held, the later check (S or D) wins.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        // Start from "no action" (index 0) on every branch.
        actionsOut.Clear();

        // Branch 0: movement.
        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }

        // Branch 1: rotation.
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
    }
}
- 스테이지 초기화 : OnEpisodeBegin()에서 SetStageObject() 함수 호출
- 이동 및 회전 처리
🔶 리워드 부여
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// MummyRay agent, step 5: moves/turns from two discrete action branches and
/// is rewarded for good items (+1) and penalized for bad items (-1, episode
/// ends), walls (-0.1) and elapsed time.
/// </summary>
public class MummyRayAgent : Agent
{
    // Cached components; `new` hides the inherited members of the same name.
    private new Transform transform;
    private new Rigidbody rigidbody;

    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;

    // StageManager on the parent stage object.
    private StageManager stageManager;

    public override void Initialize()
    {
        // Step budget per episode (one training unit).
        MaxStep = 5000;
        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
        stageManager = transform.parent.GetComponent<StageManager>();
    }

    public override void OnEpisodeBegin()
    {
        // Reset the stage: destroy the old items and spawn a fresh set.
        stageManager.SetStageObject();
        // Clear any leftover physics motion.
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
        // Random position, local to the parent stage.
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));
        // Random heading around the Y axis.
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // No vector observations; observations come from the Ray Perception Sensor 3D component.
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        // Logs every decision: handy while testing, noisy during training.
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        // Branch 0: 0 = stop, 1 = forward, 2 = backward.
        Vector3 dir = Vector3.zero;
        switch (action[0])
        {
            case 1: dir = transform.forward; break;
            case 2: dir = -transform.forward; break;
        }

        // Branch 1: 0 = no turn, 1 = turn left, 2 = turn right.
        // Rotate() defaults to Space.Self, so the axis must be a LOCAL vector.
        // The world-space transform.up is only correct while the agent is perfectly upright.
        Vector3 rot = Vector3.zero;
        switch (action[1])
        {
            case 1: rot = Vector3.down; break;
            case 2: rot = Vector3.up; break;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Small per-step penalty so that standing still is discouraged.
        AddReward(-1 / (float)MaxStep);
    }

    /// <summary>
    /// Keyboard control for testing. Discrete actions are integer indices per branch:
    /// branch 0: 0 = stop, 1 = forward (W), 2 = backward (S);
    /// branch 1: 0 = no turn, 1 = turn left (A), 2 = turn right (D).
    /// If both keys of a branch are held, the later check (S or D) wins.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        // Start from "no action" (index 0) on every branch.
        actionsOut.Clear();

        // Branch 0: movement.
        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }

        // Branch 1: rotation.
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
    }

    // Rewards/penalties by the tag of whatever the agent bumps into.
    void OnCollisionEnter(Collision coll)
    {
        if (coll.collider.CompareTag("GOOD_ITEM"))
        {
            // Clear velocity so the agent does not keep accelerating after the pickup.
            rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
            Destroy(coll.gameObject);
            AddReward(1.0f);
        }
        else if (coll.collider.CompareTag("BAD_ITEM"))
        {
            AddReward(-1.0f);
            EndEpisode();
        }
        else if (coll.collider.CompareTag("WALL"))
        {
            // Tunable; raise it if the agent learns to accept wall hits.
            AddReward(-0.1f);
        }
    }
}
- ์ถฉ๋ ์ฒดํฌ
- ๋ฆฌ์๋
- ๊ฐ๋งํ ์์ ์๋ ์์ผ๋ฏ๋ก ํ๋ํฐ ๋ถ์ฌ
- AddReward(-1 / (float)MaxStep)
- ์ต์ ํ๋ ํ๋ํฐ์
- ๋ฒฝ์ ๋ถ๋ชํ์ง ์๋๋ก ํ๋ํฐ ๋ถ์ฌ
- AddReward(-0.1f)
- ์์๊ฐ์ด๋ฏ๋ก ์กฐ์ ํ์
- ๋ง์ฝ Agent๊ฐ ๋ฒฝ์ ๋ถ๋ชํ๋๊ฒ ๊ฐ์ํ ๋งํ ํจ๋ํฐ๋ก ํ๋จํ์ฌ ๋ฒฝ์ ๋ถ๋ชํ๋ค๋ฉด, ํจ๋ํฐ ๊ฐ์ ์ฌ๋ ค์ผ ํจ.
๐ถ ๋ฆฌ์๋ ๋ฐ์ ๋๋ง๋ค ๋ฐ๋ฅ ์์ ๋ณ๊ฒฝ
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// MummyRay agent, final step: moves/turns from two discrete action branches,
/// is rewarded for good items and penalized for bad items, walls and elapsed
/// time, and briefly flashes the stage floor on item pickups.
/// </summary>
public class MummyRayAgent : Agent
{
    // Cached components; `new` hides the inherited members of the same name.
    private new Transform transform;
    private new Rigidbody rigidbody;
    // StageManager on the parent stage object.
    private StageManager stageManager;

    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;

    // Floor renderer of the parent stage and the materials used to flash it.
    private Renderer floorRd;
    public Material goodMAT, badMAT;
    private Material originMAT;

    // Floor flash currently running (if any), so a newer flash can cancel it.
    private Coroutine flashRoutine;

    public override void Initialize()
    {
        // Step budget per episode (one training unit).
        MaxStep = 5000;
        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
        stageManager = transform.parent.GetComponent<StageManager>();
        floorRd = transform.parent.Find("Floor").GetComponent<Renderer>();
        // Remember the floor's starting material so every flash can restore it.
        originMAT = floorRd.material;
    }

    // Shows changeMAT on the floor for 0.2 s, then restores the original material.
    IEnumerator RevertMaterial(Material changeMAT)
    {
        floorRd.material = changeMAT;
        yield return new WaitForSeconds(0.2f);
        floorRd.material = originMAT;
        flashRoutine = null;
    }

    // Starts a floor flash, stopping any flash still in progress first. Otherwise the
    // older coroutine's pending restore would cut the newer flash short when two
    // pickups happen less than 0.2 s apart.
    private void FlashFloor(Material changeMAT)
    {
        if (flashRoutine != null)
        {
            StopCoroutine(flashRoutine);
        }
        flashRoutine = StartCoroutine(RevertMaterial(changeMAT));
    }

    public override void OnEpisodeBegin()
    {
        // Reset the stage: destroy the old items and spawn a fresh set.
        stageManager.SetStageObject();
        // Clear any leftover physics motion.
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
        // Random position, local to the parent stage.
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));
        // Random heading around the Y axis.
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // No vector observations; observations come from the Ray Perception Sensor 3D component.
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        // Logs every decision: handy while testing, noisy during training.
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        // Branch 0: 0 = stop, 1 = forward, 2 = backward.
        Vector3 dir = Vector3.zero;
        switch (action[0])
        {
            case 1: dir = transform.forward; break;
            case 2: dir = -transform.forward; break;
        }

        // Branch 1: 0 = no turn, 1 = turn left, 2 = turn right.
        // Rotate() defaults to Space.Self, so the axis must be a LOCAL vector.
        // The world-space transform.up is only correct while the agent is perfectly upright.
        Vector3 rot = Vector3.zero;
        switch (action[1])
        {
            case 1: rot = Vector3.down; break;
            case 2: rot = Vector3.up; break;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Small per-step penalty so that standing still is discouraged.
        AddReward(-1 / (float)MaxStep);
    }

    /// <summary>
    /// Keyboard control for testing. Discrete actions are integer indices per branch:
    /// branch 0: 0 = stop, 1 = forward (W), 2 = backward (S);
    /// branch 1: 0 = no turn, 1 = turn left (A), 2 = turn right (D).
    /// If both keys of a branch are held, the later check (S or D) wins.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;
        // Start from "no action" (index 0) on every branch.
        actionsOut.Clear();

        // Branch 0: movement.
        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }

        // Branch 1: rotation.
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
    }

    // Rewards/penalties by the tag of whatever the agent bumps into.
    void OnCollisionEnter(Collision coll)
    {
        if (coll.collider.CompareTag("GOOD_ITEM"))
        {
            // Clear velocity so the agent does not keep accelerating after the pickup.
            rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
            Destroy(coll.gameObject);
            AddReward(1.0f);
            FlashFloor(goodMAT);
        }
        else if (coll.collider.CompareTag("BAD_ITEM"))
        {
            AddReward(-1.0f);
            EndEpisode();
            FlashFloor(badMAT);
        }
        else if (coll.collider.CompareTag("WALL"))
        {
            // Tunable; raise it if the agent learns to accept wall hits.
            AddReward(-0.1f);
        }
    }
}
- Materials ์ฐ๊ฒฐ
๐ถ ์คํ ์ด์ง ๊ตฌ์ฑ
- Stage ํ๋ฆฌํนํ
- Agent๊ฐ ์๋ Ray๊ฐ ์ ์คํ ์ด์ง์ ๋ฟ์ง ์๋๋ก ์ถฉ๋ถํ ๊ฑฐ๋ฆฌ๋ฅผ ๋๊ณ ๋ณต์ฌํ๊ธฐ
๐ท ํธ๋ ์ด๋
- FoodCollector.yaml์ด ๊ฐ์ฅ ์ ํฉํด์ ํด๋น ํ์ผ์ ๋ณต์ฌ
- behavior ์ด๋ฆ๊ณผ ๋์ผํ๊ฒ ๋ณ๊ฒฝ
- max_steps : 100๋ง๋ฒ์ผ๋ก ์์
- ํธ๋ ์ด๋ ์์
- MummyRay01ํ์ผ๋ก ํ๋ฒ์ด๋ผ๋ ํธ๋ ์ด๋์์ผฐ์ผ๋ฉด ์๋ฌ๊ฐ ๋ ๊ฒ
- ์ฒ์๋ถํฐ ๋ค์ ์์ : $ mlagents-learn MummyRay.yaml --run-id=MummyRay01 --force
- ๋ฉ์ถ ๊ณณ๋ถํฐ ์ด์ด์ ์์ : $ mlagents-learn MummyRay.yaml --run-id=MummyRay01 --resume
- ์ ๋ํฐ์์ play
๐ถ ํธ๋ ์ด๋ summary
- good Item์ด 30๊ฐ์ด๊ธฐ ๋๋ฌธ์ 30์ ์ด ๋ง์
- item๊ฐ์ ๊ฐ๊ฒฉ์ด ์ข๊ธฐ ๋๋ฌธ์ ์ค์๋ก ์ถฉ๋ -> ์ ์๊ฐ ๋ฎ์
๐ถ ํธ๋ ์ด๋ ๊ฒฐ๊ณผ ํ์ธํด๋ณด๊ธฐ
- ๋ชจ๋ธํ์ผ(.onnx)์ ํ๋ก์ ํธ ์ฐฝ์ผ๋ก ๊ฐ์ ธ์ค๊ธฐ
- Agent์ ๋ชจ๋ธํ์ผ ์ฐ๊ฒฐ
- Apply to Prefab
- ํธ๋ ์ด๋ ๊ฒฐ๊ณผ๊ฐ ๋ง์์ ๋ค์ง ์์ผ๋ฉด
- max_steps๋ฅผ ๋๋ฆฌ๊ณ
- Behavior Parameter์ ์ฐ๊ฒฐ๋์ด์๋ ๋ชจ๋ธ์ ํด์ ์ํจ ํ
- resume์ผ๋ก ๋ค์ ํธ๋ ์ด๋
'Unity > ML-Agents' 카테고리의 다른 글
06. ML-Agents - Soccer (1) (1) | 2021.08.02 |
---|---|
05. ML-Agents - Imitation Learning (0) | 2021.08.02 |
04. ML-Agents - Camera Sensor (0) | 2021.08.01 |
02. ML-Agents - position,rigidbody 관측 (0) | 2021.07.30 |
01. ML-Agents - 설치 및 간단한 실습 (1) | 2021.07.30 |