03. ML-Agents - Ray Perception Sensor 3D

🔷 Exercise 3

- Colliding with a bad item gives -1; colliding with a good item gives +1

- Observations are made by firing raycasts

- good item / bad item

- Discrete values: DiscreteActions

 

 

 

🔷 Building the training environment

🔶 Basic setup

- Create a scene (MummyRay)

 

 

 

- Empty GameObject (Stage)

- Cube (Floor)

    - Scale : 50, 0.1, 50

 

 

 

- Empty GameObject (Walls)

- Cube (Wall)

    - Scale : 50, 3, 1

 

 

 

 

- Add the Agent

    - Position : 0, 0.05, 0

 

 

 

https://assetstore.unity.com/packages/3d/characters/creatures/rpg-monster-duo-pbr-polyart-157762

- Download RPG Monster Duo PBR Polyart > Import

 

 

 

- Add Slime and TurtleShell

 

 

 

- Rename Slime, TurtleShell -> GoodItem, BadItem

 

 

 

- Add a Sphere Collider

    - Center : 0, 0.5, 0

    - Radius : 0.7

 

 

 

- Add the tags, then assign a tag to each object

 

 

 

- Create MummyRayAgent.cs

 

 

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class MummyRayAgent : Agent
{
    public override void Initialize()
    {

    }

    public override void OnEpisodeBegin()
    {

    }

    public override void CollectObservations(VectorSensor sensor)
    {

    }

    public override void OnActionReceived(ActionBuffers actions)
    {

    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {

    }
}

- Write the basic (empty) override functions

 

 

 

- Attach MummyRayAgent.cs to the Agent

- Behavior Name : MummyRay

 

 

 

- Add component : Ray Perception Sensor 3D

- It fires rays in a fan shape

 

 

 

- Adjust the start point of the rays

    - Start Vertical Offset : 0.5

    - End Vertical Offset : 0.5

 

 

 

- Detectable Tags : the tags that the rays can detect

    - Add 3 tags (the collision code later checks GOOD_ITEM, BAD_ITEM, and WALL)

 

 

 

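- Rays per Direction : 4 : the number of rays on each side of the center ray (2 x 4 + 1 = 9 rays in total)

- Max Ray Degrees : 80 : the angle from the forward direction to the outermost ray on each side; the rays are spread evenly within it

- Ray Length : 30 : the length of each ray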
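As a rough check of the ray settings above, the following sketch computes how many rays the sensor casts, at which angles, and how many observation values that should produce. `RayLayoutSketch` and its methods are hypothetical names written for this note, not ML-Agents API. The observation-size formula (rays x (tags + 2)) reflects my reading of the ML-Agents documentation and should be treated as an assumption.

```csharp
using System;

// Hypothetical helper for illustration only; not part of ML-Agents.
public static class RayLayoutSketch
{
    // One center ray plus raysPerDirection rays on each side.
    public static int TotalRays(int raysPerDirection)
    {
        return 2 * raysPerDirection + 1;
    }

    // Angular offset of every ray from the forward direction, in degrees.
    // Negative values are to the left, positive values to the right.
    public static float[] OffsetsFromForward(int raysPerDirection, float maxRayDegrees)
    {
        var offsets = new float[TotalRays(raysPerDirection)];
        if (raysPerDirection == 0)
        {
            return offsets; // only the center ray, at offset 0
        }

        float delta = maxRayDegrees / raysPerDirection;
        for (int i = 0; i < offsets.Length; i++)
        {
            offsets[i] = (i - raysPerDirection) * delta;
        }
        return offsets;
    }

    // Assumed size of the ray observation: per ray, a one-hot slot for each
    // detectable tag plus two extra values (hit/miss flag and hit distance).
    public static int ObservationSize(int raysPerDirection, int detectableTags)
    {
        return TotalRays(raysPerDirection) * (detectableTags + 2);
    }
}
```

With the values used here (4 rays per direction, 80 degrees, 3 tags) this gives 9 rays spaced 20 degrees apart, and 45 observation values under the stated assumption.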

 

 

 

🔶 Randomizing the Agent's position and rotation

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class MummyRayAgent : Agent
{
    private new Transform transform;
    private new Rigidbody rigidbody;

    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;



    public override void Initialize()
    {
        // Maximum number of steps per episode (the training unit)
        MaxStep = 5000;

        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        // Reset the stage

        // Reset physics (velocity and angular velocity)
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));

        // Give the agent a random rotation
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

    }

    public override void CollectObservations(VectorSensor sensor)
    {

    }

    public override void OnActionReceived(ActionBuffers actions)
    {

    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {

    }
}

- The Agent's position and rotation are randomized at the start of each episode

 

 

 

- Add components : Rigidbody, Capsule Collider

- Adjust the settings of each component

 

 

 

🔶 Spawning good items and bad items

- Turn GoodItem and BadItem into prefabs

 

 

 

- Set up the Animator Controller of GoodItem and of BadItem

- Cut the transitions so that only Idle plays

 

 

 

- Delete GoodItem and BadItem from the Hierarchy window

 

 

 

- Create StageManager.cs

 

 

 

- Attach StageManager.cs to the Stage object

 

 

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class StageManager : MonoBehaviour
{
    public GameObject goodItem;
    public GameObject badItem;

    public int goodItemCount = 30;
    public int badItemCount = 20;

    public List<GameObject> goodList = new List<GameObject>();
    public List<GameObject> badList = new List<GameObject>();

    void Start()
    {
        SetStageObject();
    }

    public void SetStageObject()
    {
        // Spawn the good items
        for (int i = 0; i < goodItemCount; i++)
        {
            Vector3 pos = new Vector3(Random.Range(-23.0f, 23.0f), 0.05f, Random.Range(-23.0f, 23.0f));
            Quaternion rot = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

            goodList.Add(Instantiate(goodItem, transform.position + pos, rot, transform));
        }

        // Spawn the bad items
        for (int i = 0; i < badItemCount; i++)
        {
            Vector3 pos = new Vector3(Random.Range(-23.0f, 23.0f), 0.05f, Random.Range(-23.0f, 23.0f));
            Quaternion rot = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

            badList.Add(Instantiate(badItem, transform.position + pos, rot, transform));
        }
    }
}

 

- A total of 50 items are spawned (30 GoodItem, 20 BadItem)

 

 

 

- Assign the GoodItem and BadItem prefabs to the StageManager fields

 

 

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class StageManager : MonoBehaviour
{
    public GameObject goodItem;
    public GameObject badItem;

    public int goodItemCount = 30;
    public int badItemCount = 20;

    public List<GameObject> goodList = new List<GameObject>();
    public List<GameObject> badList = new List<GameObject>();

    void Start()
    {
        SetStageObject();
    }

    public void SetStageObject()
    {
        foreach (var obj in goodList)
        {
            Destroy(obj);
        }
        foreach (var obj in badList)
        {
            Destroy(obj);
        }

        // Clear the lists
        goodList.Clear();
        badList.Clear();

        // Spawn the good items
        for (int i = 0; i < goodItemCount; i++)
        {
            Vector3 pos = new Vector3(Random.Range(-23.0f, 23.0f), 0.05f, Random.Range(-23.0f, 23.0f));
            Quaternion rot = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

            goodList.Add(Instantiate(goodItem, transform.position + pos, rot, transform));
        }

        // Spawn the bad items
        for (int i = 0; i < badItemCount; i++)
        {
            Vector3 pos = new Vector3(Random.Range(-23.0f, 23.0f), 0.05f, Random.Range(-23.0f, 23.0f));
            Quaternion rot = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

            badList.Add(Instantiate(badItem, transform.position + pos, rot, transform));
        }
    }
}

- Spawn the good items and bad items (items left over from the previous call are destroyed first)

 

 

 

🔶 Agent movement

- Space Size : 0 : no vector observations are collected in code (the observations come from the Ray Perception Sensor)

- Discrete Branches : 2 : discrete values (in this project: movement and rotation)

    - Branch 0 size : 3 : (stop, forward, backward)

    - Branch 1 size : 3 : (stop, turn left, turn right)
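To make the branch layout concrete, here is a minimal sketch of how the two branch indices can be read as movement and turn directions. `DiscreteActionSketch` and its method names are hypothetical and exist only for this illustration; the index meanings (0 = stop, 1 = forward or left, 2 = backward or right) follow the branch settings above.

```csharp
using System;

// Hypothetical helper for illustration only; not part of ML-Agents.
public static class DiscreteActionSketch
{
    // Branch 0: 0 = stop, 1 = forward, 2 = backward.
    // Returns a multiplier for the agent's forward direction.
    public static int MoveSign(int branch0)
    {
        switch (branch0)
        {
            case 1: return 1;
            case 2: return -1;
            default: return 0;
        }
    }

    // Branch 1: 0 = stop, 1 = turn left, 2 = turn right.
    // Returns a multiplier for rotation around the up axis.
    public static int TurnSign(int branch1)
    {
        switch (branch1)
        {
            case 1: return -1;
            case 2: return 1;
            default: return 0;
        }
    }

    // Number of distinct joint actions: the product of all branch sizes.
    public static int JointActionCount(params int[] branchSizes)
    {
        int count = 1;
        foreach (int size in branchSizes)
        {
            count *= size;
        }
        return count;
    }
}
```

Two branches of size 3 give 9 joint actions, for example "forward while turning left" is the pair (1, 1).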

 

 

 

- Add component : Decision Requester

 

 

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class MummyRayAgent : Agent
{
    private new Transform transform;
    private new Rigidbody rigidbody;

    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;



    public override void Initialize()
    {
        // Maximum number of steps per episode (the training unit)
        MaxStep = 5000;

        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        // Reset the stage

        // Reset physics (velocity and angular velocity)
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));

        // Give the agent a random rotation
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

    }

    public override void CollectObservations(VectorSensor sensor)
    {

    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;    // discrete action indices (0, 1, 2 in each branch)

        actionsOut.Clear();

        // Branch 0 - key mapping for the movement logic
        // Size of Branch 0 : 3
        //  Stop    /   Forward /   Backward
        //  None    /   W       /   S
        //  0       /   1       /   2
        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }

        // Branch 1 - key mapping for the rotation logic
        // Size of Branch 1 : 3
        //  Stop    /   Left    /   Right
        //  None    /   A       /   D
        //  0       /   1       /   2
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }

    }
}

- At this point the items are scattered across the stage, but the Agent does not move yet


 

 

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class StageManager : MonoBehaviour
{
    public GameObject goodItem;
    public GameObject badItem;

    public int goodItemCount = 30;
    public int badItemCount = 20;

    public List<GameObject> goodList = new List<GameObject>();
    public List<GameObject> badList = new List<GameObject>();

    public void SetStageObject()
    {
        foreach (var obj in goodList)
        {
            Destroy(obj);
        }
        foreach (var obj in badList)
        {
            Destroy(obj);
        }

        // Clear the lists
        goodList.Clear();
        badList.Clear();

        // Spawn the good items
        for (int i = 0; i < goodItemCount; i++)
        {
            Vector3 pos = new Vector3(Random.Range(-23.0f, 23.0f), 0.05f, Random.Range(-23.0f, 23.0f));
            Quaternion rot = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

            goodList.Add(Instantiate(goodItem, transform.position + pos, rot, transform));
        }

        // Spawn the bad items
        for (int i = 0; i < badItemCount; i++)
        {
            Vector3 pos = new Vector3(Random.Range(-23.0f, 23.0f), 0.05f, Random.Range(-23.0f, 23.0f));
            Quaternion rot = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

            badList.Add(Instantiate(badItem, transform.position + pos, rot, transform));
        }
    }
}

- Delete Start() : it used to call SetStageObject()

    - It is removed because MummyRayAgent.cs will call SetStageObject() instead

 

 

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class MummyRayAgent : Agent
{
    private new Transform transform;
    private new Rigidbody rigidbody;

    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;
    private StageManager stageManager;



    public override void Initialize()
    {
        // Maximum number of steps per episode (the training unit)
        MaxStep = 5000;

        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
        stageManager = transform.parent.GetComponent<StageManager>();
    }

    public override void OnEpisodeBegin()
    {
        // Reset the stage
        stageManager.SetStageObject();

        // Reset physics (velocity and angular velocity)
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));

        // Give the agent a random rotation
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

    }

    public override void CollectObservations(VectorSensor sensor)
    {

    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        Vector3 dir = Vector3.zero;
        Vector3 rot = Vector3.zero;

        // Branch 0 : stop / forward / backward
        switch (action[0])
        {
            case 1: dir = transform.forward; break;
            case 2: dir = -transform.forward; break;
        }

        // Branch 1 : stop / turn left / turn right
        switch (action[1])
        {
            case 1: rot = -transform.up; break;
            case 2: rot = transform.up; break;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Apply a small negative reward (penalty) on every step
        AddReward(-1 / (float)MaxStep);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;    // discrete action indices (0, 1, 2 in each branch)

        actionsOut.Clear();

        // Branch 0 - key mapping for the movement logic
        // Size of Branch 0 : 3
        //  Stop    /   Forward /   Backward
        //  None    /   W       /   S
        //  0       /   1       /   2
        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }

        // Branch 1 - key mapping for the rotation logic
        // Size of Branch 1 : 3
        //  Stop    /   Left    /   Right
        //  None    /   A       /   D
        //  0       /   1       /   2
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }

    }
}

- Stage reset : call SetStageObject()

- Handle movement and rotation

 

 

 

🔶 Assigning rewards

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class MummyRayAgent : Agent
{
    private new Transform transform;
    private new Rigidbody rigidbody;

    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;
    private StageManager stageManager;



    public override void Initialize()
    {
        // Maximum number of steps per episode (the training unit)
        MaxStep = 5000;

        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
        stageManager = transform.parent.GetComponent<StageManager>();
    }

    public override void OnEpisodeBegin()
    {
        // Reset the stage
        stageManager.SetStageObject();

        // Reset physics (velocity and angular velocity)
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));

        // Give the agent a random rotation
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

    }

    public override void CollectObservations(VectorSensor sensor)
    {

    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        Vector3 dir = Vector3.zero;
        Vector3 rot = Vector3.zero;

        // Branch 0 : stop / forward / backward
        switch (action[0])
        {
            case 1: dir = transform.forward; break;
            case 2: dir = -transform.forward; break;
        }

        // Branch 1 : stop / turn left / turn right
        switch (action[1])
        {
            case 1: rot = -transform.up; break;
            case 2: rot = transform.up; break;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Apply a small negative reward (penalty) on every step
        AddReward(-1 / (float)MaxStep);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;    // discrete action indices (0, 1, 2 in each branch)

        actionsOut.Clear();

        // Branch 0 - key mapping for the movement logic
        // Size of Branch 0 : 3
        //  Stop    /   Forward /   Backward
        //  None    /   W       /   S
        //  0       /   1       /   2
        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }

        // Branch 1 - key mapping for the rotation logic
        // Size of Branch 1 : 3
        //  Stop    /   Left    /   Right
        //  None    /   A       /   D
        //  0       /   1       /   2
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
    }

    void OnCollisionEnter(Collision coll)
    {
        if (coll.collider.CompareTag("GOOD_ITEM"))
        {
            // The collision can leave the agent with extra speed, so reset its velocity
            rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
            Destroy(coll.gameObject);
            AddReward(1.0f);
        }

        if (coll.collider.CompareTag("BAD_ITEM"))
        {
            AddReward(-1.0f);
            EndEpisode();
        }

        if (coll.collider.CompareTag("WALL"))
        {
            AddReward(-0.1f);
        }

    }

}

- Collision check

- Rewards

  • The Agent could simply stand still, so a per-step penalty is applied
    • AddReward(-1 / (float)MaxStep)
    • The penalty is scaled to the episode length: over a full episode of MaxStep (5000) steps it sums to -1
  • A penalty is applied so that the Agent learns not to hit the walls
    • AddReward(-0.1f)
    • This is a provisional value and needs tuning
    • If the Agent treats the wall penalty as an acceptable cost and keeps hitting walls, the penalty must be increased.
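The reward scheme above can be checked with simple arithmetic. The sketch below adds up one episode's reward from the values used in the agent code (+1 per good item, -1 for a bad item, -0.1 per wall hit, -1/MaxStep per step). `RewardBudgetSketch` and its parameters are hypothetical names chosen for this illustration.

```csharp
using System;

// Hypothetical helper for illustration only; not part of ML-Agents.
public static class RewardBudgetSketch
{
    // Per-step penalty, matching AddReward(-1 / (float)MaxStep).
    public static float StepPenalty(int maxStep)
    {
        return -1f / maxStep;
    }

    // Total reward of one episode given what happened in it.
    public static float EpisodeReturn(int steps, int maxStep, int goodHits, int wallHits, bool hitBadItem)
    {
        float total = steps * StepPenalty(maxStep); // time penalty
        total += goodHits * 1.0f;                   // +1 per good item
        total += wallHits * -0.1f;                  // -0.1 per wall collision
        if (hitBadItem)
        {
            total += -1.0f;                         // -1, and the episode ends
        }
        return total;
    }
}
```

For example, an agent that does nothing for all 5000 steps ends at about -1, while one that collects all 30 good items in 2500 steps without touching a wall ends at about 29.5.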

 

 

 

🔶 Changing the floor color whenever a reward is received

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class MummyRayAgent : Agent
{
    private new Transform transform;
    private new Rigidbody rigidbody;
    private StageManager stageManager;

    public float moveSpeed = 1.5f;
    public float turnSpeed = 200.0f;

    private Renderer floorRd;
    public Material goodMAT, badMAT;
    private Material originMAT;



    public override void Initialize()
    {
        // Maximum number of steps per episode (the training unit)
        MaxStep = 5000;

        transform = GetComponent<Transform>();
        rigidbody = GetComponent<Rigidbody>();
        stageManager = transform.parent.GetComponent<StageManager>();
        floorRd = transform.parent.Find("Floor").GetComponent<Renderer>();
        originMAT = floorRd.material;
    }
    // Shows the given material on the floor briefly, then restores the original
    IEnumerator RevertMaterial(Material changeMAT)
    {
        floorRd.material = changeMAT;
        yield return new WaitForSeconds(0.2f);
        floorRd.material = originMAT;
    }

    public override void OnEpisodeBegin()
    {
        // Reset the stage
        stageManager.SetStageObject();

        // Reset physics (velocity and angular velocity)
        rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;

        // Move the agent to a random position
        transform.localPosition = new Vector3(Random.Range(-22.0f, 22.0f), 0.05f, Random.Range(-22.0f, 22.0f));

        // Give the agent a random rotation
        transform.localRotation = Quaternion.Euler(Vector3.up * Random.Range(0, 360));

    }

    public override void CollectObservations(VectorSensor sensor)
    {

    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        Debug.Log($"[0] = {action[0]}, [1]: {action[1]}");

        Vector3 dir = Vector3.zero;
        Vector3 rot = Vector3.zero;

        // Branch 0 : stop / forward / backward
        switch (action[0])
        {
            case 1: dir = transform.forward; break;
            case 2: dir = -transform.forward; break;
        }

        // Branch 1 : stop / turn left / turn right
        switch (action[1])
        {
            case 1: rot = -transform.up; break;
            case 2: rot = transform.up; break;
        }

        transform.Rotate(rot, Time.fixedDeltaTime * turnSpeed);
        rigidbody.AddForce(dir * moveSpeed, ForceMode.VelocityChange);

        // Apply a small negative reward (penalty) on every step
        AddReward(-1 / (float)MaxStep);
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var action = actionsOut.DiscreteActions;    // discrete action indices (0, 1, 2 in each branch)

        actionsOut.Clear();

        // Branch 0 - key mapping for the movement logic
        // Size of Branch 0 : 3
        //  Stop    /   Forward /   Backward
        //  None    /   W       /   S
        //  0       /   1       /   2
        if (Input.GetKey(KeyCode.W))
        {
            action[0] = 1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            action[0] = 2;
        }

        // Branch 1 - key mapping for the rotation logic
        // Size of Branch 1 : 3
        //  Stop    /   Left    /   Right
        //  None    /   A       /   D
        //  0       /   1       /   2
        if (Input.GetKey(KeyCode.A))
        {
            action[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            action[1] = 2;
        }
    }

    void OnCollisionEnter(Collision coll)
    {
        if (coll.collider.CompareTag("GOOD_ITEM"))
        {
            // The collision can leave the agent with extra speed, so reset its velocity
            rigidbody.velocity = rigidbody.angularVelocity = Vector3.zero;
            Destroy(coll.gameObject);
            AddReward(1.0f);

            StartCoroutine(RevertMaterial(goodMAT));
        }

        if (coll.collider.CompareTag("BAD_ITEM"))
        {
            AddReward(-1.0f);
            EndEpisode();

            StartCoroutine(RevertMaterial(badMAT));
        }

        if (coll.collider.CompareTag("WALL"))
        {
            AddReward(-0.1f);
        }

    }

}

 

 

 

- Assign the materials (goodMAT, badMAT) in the Inspector

 

 

 

🔶 Stage layout

- Turn the Stage into a prefab

 

 

 

- Duplicate the stage, leaving enough distance that the rays fired by the Agent do not reach the neighboring stage
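One way to pick "enough distance" is a conservative bound from the numbers already used in this post: the floor is 50 units wide (half size 25), the Agent spawns within 22 units of the stage center, and Ray Length is 30. The sketch below is only an estimate under those assumptions. `StageSpacingSketch` is a hypothetical name, and the bound ignores that walls may block rays earlier and that the Agent can move beyond its spawn range.

```csharp
using System;

// Hypothetical helper for illustration only; not part of ML-Agents.
public static class StageSpacingSketch
{
    // Smallest center-to-center distance at which a ray fired from the
    // farthest agent offset in one stage just reaches the nearest edge of
    // the next stage. Use a spacing larger than this value.
    public static float MinCenterSpacing(float stageHalfSize, float agentMaxOffset, float rayLength)
    {
        return agentMaxOffset + rayLength + stageHalfSize;
    }
}
```

With the values above the bound is 22 + 30 + 25 = 77, so placing stage copies more than 77 units apart keeps their rays separated under these assumptions.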

 

 

 

🔷 Training

- FoodCollector.yaml is the closest match for this setup, so copy that file

 

 

 

- Change the behavior name in the file so that it matches the Behavior Name (MummyRay)

- max_steps : change to 1,000,000

 

 

 

- Start training

- If a run with the ID MummyRay01 has already been trained even once, starting it again produces an error

    - Restart from scratch : $ mlagents-learn MummyRay.yaml --run-id=MummyRay01 --force

    - Resume from where it stopped : $ mlagents-learn MummyRay.yaml --run-id=MummyRay01 --resume

 

 

 

- Press Play in Unity

 

 

 

🔶 Training summary

- There are 30 good items, so the maximum score is 30 (slightly less once the per-step penalty is subtracted)

- The items are spaced closely together, so the Agent collides with bad items by accident -> the score stays low

 

 

 

🔶 Checking the training result

- Import the model file (.onnx) into the Project window

 

 

 

- Assign the model file to the Agent

- Apply to Prefab

 

 

 

- If the training result is not satisfactory

    - increase max_steps,

    - detach the model assigned in Behavior Parameters, and then

    - train again with --resume
