classdef fitCurveLayer < nnet.layer.Layer ...
        & nnet.layer.Acceleratable
    % fitCurveLayer  Custom layer that applies a learnable quartic polynomial.
    %
    %   Y = a1*s + a2*s.^2 + a3*s.^3 + a4*s.^4,   where s = X ./ Scale
    %
    %   The polynomial has no constant term, so Y is zero wherever X is zero.
    %   The coefficients a1..a4 are learnable scalars. Scale is a fixed
    %   (non-learnable) divisor applied to the input; it defaults to 100.

    properties
        % Fixed divisor applied to the input before the polynomial is
        % evaluated. Defaults to 100.
        Scale = 100
    end

    properties (Learnable)
        % Polynomial coefficients of the s, s^2, s^3 and s^4 terms.
        a1
        a2
        a3
        a4
    end

    methods
        function layer = fitCurveLayer(args)
            % layer = fitCurveLayer creates a fit curve layer named "lm_fit".
            %
            % layer = fitCurveLayer(Name=name) also specifies the layer name.
            %
            % layer = fitCurveLayer(Scale=s) sets the divisor applied to the
            % input (default 100). s must be a nonzero numeric scalar.

            arguments
                args.Name = "lm_fit";
                args.Scale (1,1) {mustBeNumeric, mustBeNonzero} = 100;
            end

            layer.Name = args.Name;
            layer.Scale = args.Scale;
            layer.Description = "fit curve layer";
        end

        function layer = initialize(layer,~)
            % layer = initialize(layer,layout) initializes every coefficient
            % that is still empty to a uniform random scalar from rand().
            % Coefficients that already hold a value are left unchanged.
            % The layout argument is ignored: each coefficient is a scalar
            % regardless of the input size.

            coeffNames = {'a1', 'a2', 'a3', 'a4'};
            for k = 1:numel(coeffNames)
                if isempty(layer.(coeffNames{k}))
                    layer.(coeffNames{k}) = rand();
                end
            end
        end

        function Y = predict(layer, X)
            % Y = predict(layer, X) evaluates the polynomial elementwise on X.
            % With scalar coefficients, Y has the same size as X.

            % Scale the input once rather than once per term.
            s = X ./ layer.Scale;
            Y = layer.a1 .* s + layer.a2 .* s.^2 + layer.a3 .* s.^3 + layer.a4 .* s.^4;
        end
    end
end